from __future__ import annotations from collections.abc import Callable, Iterator import pytest import pytest_asyncio from fastapi import FastAPI, Request from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from config.database import Base from dependencies import get_auth_session, get_import_ingestion_service, get_unit_of_work from models import User from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router from routes.imports import router as imports_router from routes.exports import router as exports_router from routes.reports import router as reports_router from services.importers import ImportIngestionService from services.unit_of_work import UnitOfWork from services.session import AuthSession, SessionTokens from tests.utils.security import random_password, random_token @pytest.fixture() def engine() -> Iterator[Engine]: engine = create_engine( "sqlite+pysqlite:///:memory:", future=True, connect_args={"check_same_thread": False}, poolclass=StaticPool, ) Base.metadata.create_all(bind=engine) try: yield engine finally: Base.metadata.drop_all(bind=engine) engine.dispose() @pytest.fixture() def session_factory(engine: Engine) -> Iterator[sessionmaker]: testing_session = sessionmaker( bind=engine, expire_on_commit=False, future=True) yield testing_session @pytest.fixture() def app(session_factory: sessionmaker) -> FastAPI: application = FastAPI() application.include_router(auth_router) application.include_router(dashboard_router) application.include_router(projects_router) application.include_router(scenarios_router) application.include_router(imports_router) application.include_router(exports_router) application.include_router(reports_router) def _override_uow() -> Iterator[UnitOfWork]: with UnitOfWork(session_factory=session_factory) as uow: yield uow application.dependency_overrides[get_unit_of_work] = _override_uow def _ingestion_uow_factory() -> UnitOfWork: return UnitOfWork(session_factory=session_factory) ingestion_service = ImportIngestionService(_ingestion_uow_factory) def _override_ingestion_service() -> ImportIngestionService: return ingestion_service application.dependency_overrides[ get_import_ingestion_service ] = _override_ingestion_service with UnitOfWork(session_factory=session_factory) as uow: assert uow.users is not None uow.ensure_default_roles() user = User( email="test-superuser@example.com", username="test-superuser", password_hash=User.hash_password(random_password()), is_active=True, is_superuser=True, ) uow.users.create(user) user = uow.users.get(user.id, with_roles=True) def _override_auth_session(request: Request) -> AuthSession: session = AuthSession( tokens=SessionTokens( access_token=random_token(), refresh_token=random_token(), ) ) session.user = user request.state.auth_session = session return session application.dependency_overrides[get_auth_session] = _override_auth_session return application @pytest.fixture() def client(app: FastAPI) -> Iterator[TestClient]: test_client = TestClient(app) try: yield test_client finally: test_client.close() @pytest_asyncio.fixture() async def async_client(app: FastAPI) -> AsyncClient: return AsyncClient( transport=ASGITransport(app=app), base_url="http://testserver" ) @pytest.fixture() def unit_of_work_factory(session_factory: sessionmaker) -> Callable[[], UnitOfWork]: def _factory() -> UnitOfWork: return UnitOfWork(session_factory=session_factory) return _factory