from __future__ import annotations from collections.abc import Callable, Iterator import pytest import pytest_asyncio from fastapi import FastAPI, Request from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from config.database import Base from dependencies import get_auth_session, get_import_ingestion_service, get_unit_of_work from models import User from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.calculations import router as calculations_router from routes.navigation import router as navigation_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router from routes.imports import router as imports_router from routes.exports import router as exports_router from routes.reports import router as reports_router from routes.ui import router as ui_router from services.importers import ImportIngestionService from services.unit_of_work import UnitOfWork from services.session import AuthSession, SessionTokens from tests.utils.security import random_password, random_token BASE_TESTSERVER_URL = "http://testserver" TEST_USER_HEADER = "X-Test-User" @pytest.fixture() def engine() -> Iterator[Engine]: engine = create_engine( "sqlite+pysqlite:///:memory:", future=True, connect_args={"check_same_thread": False}, poolclass=StaticPool, ) Base.metadata.create_all(bind=engine) try: yield engine finally: Base.metadata.drop_all(bind=engine) engine.dispose() @pytest.fixture() def session_factory(engine: Engine) -> Iterator[sessionmaker]: testing_session = sessionmaker( bind=engine, expire_on_commit=False, future=True) yield testing_session @pytest.fixture() def app(session_factory: sessionmaker) -> FastAPI: application = FastAPI() application.include_router(auth_router) application.include_router(dashboard_router) application.include_router(calculations_router) application.include_router(projects_router) application.include_router(navigation_router) application.include_router(scenarios_router) application.include_router(imports_router) application.include_router(exports_router) application.include_router(reports_router) application.include_router(ui_router) def _override_uow() -> Iterator[UnitOfWork]: with UnitOfWork(session_factory=session_factory) as uow: yield uow application.dependency_overrides[get_unit_of_work] = _override_uow def _ingestion_uow_factory() -> UnitOfWork: return UnitOfWork(session_factory=session_factory) ingestion_service = ImportIngestionService(_ingestion_uow_factory) def _override_ingestion_service() -> ImportIngestionService: return ingestion_service application.dependency_overrides[ get_import_ingestion_service ] = _override_ingestion_service with UnitOfWork(session_factory=session_factory) as uow: assert uow.users is not None and uow.roles is not None roles = {role.name: role for role in uow.ensure_default_roles()} admin_user = User( email="test-superuser@example.com", username="test-superuser", password_hash=User.hash_password(random_password()), is_active=True, is_superuser=True, ) viewer_user = User( email="test-viewer@example.com", username="test-viewer", password_hash=User.hash_password(random_password()), is_active=True, is_superuser=False, ) uow.users.create(admin_user) uow.users.create(viewer_user) uow.users.assign_role( user_id=admin_user.id, role_id=roles["admin"].id, granted_by=admin_user.id, ) uow.users.assign_role( user_id=viewer_user.id, role_id=roles["viewer"].id, granted_by=admin_user.id, ) admin_user = uow.users.get(admin_user.id, with_roles=True) viewer_user = uow.users.get(viewer_user.id, with_roles=True) application.state.test_users = { "admin": admin_user, "viewer": viewer_user, } def _resolve_user(alias: str) -> tuple[User, tuple[str, ...]]: normalised = alias.strip().lower() user = application.state.test_users.get(normalised) if user is None: raise ValueError(f"Unknown test user alias: {alias}") roles = tuple(role.name for role in user.roles) return user, roles def _override_auth_session(request: Request) -> AuthSession: alias = request.headers.get(TEST_USER_HEADER, "admin").strip().lower() if alias == "anonymous": session = AuthSession.anonymous() else: user, role_slugs = _resolve_user(alias or "admin") session = AuthSession( tokens=SessionTokens( access_token=random_token(), refresh_token=random_token(), ), user=user, ) session.set_role_slugs(role_slugs) request.state.auth_session = session return session application.dependency_overrides[get_auth_session] = _override_auth_session return application @pytest.fixture() def client(app: FastAPI) -> Iterator[TestClient]: test_client = TestClient(app, headers={TEST_USER_HEADER: "admin"}) try: yield test_client finally: test_client.close() @pytest_asyncio.fixture() async def async_client(app: FastAPI) -> AsyncClient: return AsyncClient( transport=ASGITransport(app=app), base_url="http://testserver", headers={TEST_USER_HEADER: "admin"}, ) @pytest.fixture() def test_user_headers() -> Callable[[str | None], dict[str, str]]: def _factory(alias: str | None = "admin") -> dict[str, str]: if alias is None: return {} return {TEST_USER_HEADER: alias.lower()} return _factory @pytest.fixture() def unit_of_work_factory(session_factory: sessionmaker) -> Callable[[], UnitOfWork]: def _factory() -> UnitOfWork: return UnitOfWork(session_factory=session_factory) return _factory @pytest.fixture() def app_url_for(app: FastAPI) -> Callable[..., str]: def _builder(route_name: str, **path_params: object) -> str: normalised_params = { key: str(value) for key, value in path_params.items() if value is not None } return f"{BASE_TESTSERVER_URL}{app.url_path_for(route_name, **normalised_params)}" return _builder @pytest.fixture() def scenario_calculation_url( app_url_for: Callable[..., str] ) -> Callable[[str, int, int], str]: def _builder(route_name: str, project_id: int, scenario_id: int) -> str: return app_url_for( route_name, project_id=project_id, scenario_id=scenario_id, ) return _builder