"""Shared pytest fixtures: an in-memory SQLite database and a FastAPI test client."""

from typing import Generator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from config.database import Base
from main import app

SQLALCHEMY_TEST_URL = "sqlite:///:memory:"

# An in-memory SQLite database exists only for the lifetime of its connection.
# StaticPool reuses one single connection for every checkout, so the tables
# created by ``setup_database`` stay visible to every session built below.
# ``check_same_thread=False`` allows that connection to be used from threads
# other than the one that opened it (the test client may call the app from
# another thread).
engine = create_engine(
    SQLALCHEMY_TEST_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine
)


@pytest.fixture(scope="session", autouse=True)
def setup_database() -> Generator[None, None, None]:
    """Create all model tables once per test session and drop them at the end."""
    # Ensure all model metadata is registered on ``Base.metadata`` before
    # creating tables; the imported names themselves are unused.
    from models import (  # noqa: F401 - imported for side effects
        capex,
        consumption,
        distribution,
        equipment,
        maintenance,
        opex,
        parameters,
        production_output,
        scenario,
        simulation_result,
    )

    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


@pytest.fixture()
def db_session() -> Generator[Session, None, None]:
    """Yield a fresh session bound to the test engine and close it afterwards."""
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()


@pytest.fixture()
def api_client(db_session: Session) -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` whose routes use ``db_session`` as their DB session.

    Every route module's ``get_db`` dependency is overridden on the global
    ``app`` for the duration of the test. The overrides are removed in a
    ``finally`` block so they cannot leak into other tests, even when
    starting or stopping the client raises.
    """

    def override_get_db() -> Generator[Session, None, None]:
        # The session's lifecycle is owned by the ``db_session`` fixture,
        # so there is nothing to clean up here.
        yield db_session

    # Override all routers that use get_db.
    from routes import (
        consumption,
        costs,
        distributions,
        equipment,
        maintenance,
        parameters,
        production,
        reporting,
        scenarios,
        simulations,
    )

    route_modules = (
        consumption,
        costs,
        distributions,
        equipment,
        maintenance,
        parameters,
        production,
        reporting,
        scenarios,
        simulations,
    )
    # Modules that share the same ``get_db`` object collapse to one key.
    overrides = {module.get_db: override_get_db for module in route_modules}

    app.dependency_overrides.update(overrides)
    try:
        with TestClient(app) as client:
            yield client
    finally:
        for dependency in overrides:
            app.dependency_overrides.pop(dependency, None)