115 lines
3.6 KiB
Python
115 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Iterator
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from config.database import Base
|
|
from dependencies import get_auth_session, get_import_ingestion_service, get_unit_of_work
|
|
from models import User
|
|
from routes.auth import router as auth_router
|
|
from routes.dashboard import router as dashboard_router
|
|
from routes.projects import router as projects_router
|
|
from routes.scenarios import router as scenarios_router
|
|
from routes.imports import router as imports_router
|
|
from services.importers import ImportIngestionService
|
|
from services.unit_of_work import UnitOfWork
|
|
from services.session import AuthSession, SessionTokens
|
|
|
|
|
|
@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with the ORM schema created on it.

    ``StaticPool`` keeps a single underlying connection, so the one
    ``:memory:`` database persists for every checkout, and
    ``check_same_thread=False`` allows that sqlite3 connection to be used
    from threads other than the one that created it.

    On teardown every table in ``Base.metadata`` is dropped and the
    engine's pool is disposed, even if the test failed.
    """
    db_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
        future=True,
    )
    schema = Base.metadata
    schema.create_all(bind=db_engine)
    try:
        yield db_engine
    finally:
        schema.drop_all(bind=db_engine)
        db_engine.dispose()
|
|
|
|
|
|
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a ``sessionmaker`` bound to the per-test ``engine``.

    ``expire_on_commit=False`` keeps already-loaded attribute values on ORM
    instances after a commit. Presumably this is so that objects produced
    inside one ``UnitOfWork`` (e.g. the seeded user in ``app``) stay
    readable afterwards — confirm against ``UnitOfWork``'s commit behaviour.
    """
    yield sessionmaker(bind=engine, future=True, expire_on_commit=False)
|
|
|
|
|
|
def _seed_superuser(session_factory: sessionmaker) -> User:
    """Create the test superuser and return it reloaded together with its roles.

    Default roles are ensured first. The new user is then created and
    fetched back with ``with_roles=True`` inside a single ``UnitOfWork``.

    Raises:
        RuntimeError: if the freshly created user cannot be loaded back,
            rather than letting every request run with ``user=None``.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.users is not None
        uow.ensure_default_roles()
        user = User(
            email="test-superuser@example.com",
            username="test-superuser",
            password_hash=User.hash_password("test-password"),
            is_active=True,
            is_superuser=True,
        )
        uow.users.create(user)
        seeded = uow.users.get(user.id, with_roles=True)
    if seeded is None:
        raise RuntimeError("seeded test superuser could not be reloaded")
    return seeded


@pytest.fixture()
def app(session_factory: sessionmaker) -> FastAPI:
    """Build a FastAPI app wired to the test database with auth bypassed.

    The app includes the auth, dashboard, projects, scenarios and imports
    routers. Three dependencies are overridden:

    * ``get_unit_of_work`` yields a ``UnitOfWork`` on ``session_factory``,
      entered for the duration of the request's dependency scope.
    * ``get_import_ingestion_service`` returns one shared
      ``ImportIngestionService``, which builds its own ``UnitOfWork``
      objects from ``session_factory``.
    * ``get_auth_session`` returns an ``AuthSession`` with placeholder
      tokens for a superuser seeded into the database. The session is also
      stored on ``request.state.auth_session``.
    """
    application = FastAPI()
    for router in (
        auth_router,
        dashboard_router,
        projects_router,
        scenarios_router,
        imports_router,
    ):
        application.include_router(router)

    def _override_uow() -> Iterator[UnitOfWork]:
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    def _ingestion_uow_factory() -> UnitOfWork:
        return UnitOfWork(session_factory=session_factory)

    # One service instance is shared across all requests in this test.
    ingestion_service = ImportIngestionService(_ingestion_uow_factory)

    def _override_ingestion_service() -> ImportIngestionService:
        return ingestion_service

    superuser = _seed_superuser(session_factory)

    def _override_auth_session(request: Request) -> AuthSession:
        # A fresh session object per request, always carrying the seeded
        # superuser, mirrored onto request.state for code that reads it there.
        session = AuthSession(
            tokens=SessionTokens(access_token="test", refresh_token="test")
        )
        session.user = superuser
        request.state.auth_session = session
        return session

    application.dependency_overrides[get_unit_of_work] = _override_uow
    application.dependency_overrides[
        get_import_ingestion_service
    ] = _override_ingestion_service
    application.dependency_overrides[get_auth_session] = _override_auth_session

    return application
|
|
|
|
|
|
@pytest.fixture()
def client(app: FastAPI) -> Iterator[TestClient]:
    """Yield a ``TestClient`` bound to ``app`` and close it after the test.

    NOTE(review): the client is constructed directly instead of being used
    as ``with TestClient(app) as c:``. Starlette only runs lifespan
    startup/shutdown handlers in the context-manager form — confirm that
    skipping them is intended.
    """
    http_client = TestClient(app)
    try:
        yield http_client
    finally:
        # Release the client's resources whether or not the test passed.
        http_client.close()
|
|
|
|
|
|
@pytest.fixture()
def unit_of_work_factory(session_factory: sessionmaker) -> Callable[[], UnitOfWork]:
    """Return a zero-argument callable that builds a new ``UnitOfWork``.

    Every call constructs a fresh ``UnitOfWork`` on the test
    ``session_factory``. Entering it (``with factory() as uow: ...``) is
    left to the caller.
    """
    def make_unit_of_work() -> UnitOfWork:
        return UnitOfWork(session_factory=session_factory)

    return make_unit_of_work
|