"""Pytest fixtures for API tests backed by an in-memory SQLite database."""

from typing import Generator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from config.database import Base
from main import app

SQLALCHEMY_TEST_URL = "sqlite:///:memory:"

# An in-memory SQLite database exists only for the lifetime of its connection.
# StaticPool hands out one shared connection, so schema setup, the test
# session and the app under test all see the same database.
# check_same_thread=False allows that connection to be used from threads
# other than the one that created it (e.g. request handling via TestClient).
engine = create_engine(
    SQLALCHEMY_TEST_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
)


@pytest.fixture(scope="session", autouse=True)
def setup_database() -> Generator[None, None, None]:
    """Create all tables once for the test session and drop them afterwards."""
    # Importing the model modules registers their tables on Base.metadata;
    # create_all() only creates tables that have been registered.
    from models import (
        capex,
        consumption,
        distribution,
        equipment,
        maintenance,
        opex,
        parameters,
        production_output,
        scenario,
        simulation_result,
    )  # noqa: F401 - imported for side effects

    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


@pytest.fixture()
def db_session() -> Generator[Session, None, None]:
    """Yield a fresh Session bound to the test engine, closing it afterwards.

    NOTE: the schema is created once per test session, and closing a Session
    only discards *uncommitted* work. Rows committed during one test therefore
    remain visible to later tests.
    """
    with TestingSessionLocal() as session:
        yield session


@pytest.fixture()
def api_client(db_session: Session) -> Generator[TestClient, None, None]:
    """Yield a TestClient whose ``get_db`` dependency returns ``db_session``.

    The override is always removed on teardown, even if app startup or
    shutdown raises, so it cannot leak into subsequent tests.
    """
    from routes import dependencies as route_dependencies

    def override_get_db() -> Generator[Session, None, None]:
        # The session's lifecycle is owned by the db_session fixture,
        # so nothing is closed here.
        yield db_session

    app.dependency_overrides[route_dependencies.get_db] = override_get_db
    try:
        # Entering the context runs the app's lifespan (startup/shutdown).
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.pop(route_dependencies.get_db, None)