from __future__ import annotations from contextlib import AbstractContextManager from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal from models import Role, Scenario from services.repositories import ( FinancialInputRepository, ProjectRepository, RoleRepository, ScenarioRepository, SimulationParameterRepository, UserRepository, ensure_admin_user as ensure_admin_user_record, ensure_default_roles, ) from services.scenario_validation import ScenarioComparisonValidator class UnitOfWork(AbstractContextManager["UnitOfWork"]): """Simple unit-of-work wrapper around SQLAlchemy sessions.""" def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: self._session_factory = session_factory self.session: Session | None = None self._scenario_validator: ScenarioComparisonValidator | None = None self.projects: ProjectRepository | None = None self.scenarios: ScenarioRepository | None = None self.financial_inputs: FinancialInputRepository | None = None self.simulation_parameters: SimulationParameterRepository | None = None self.users: UserRepository | None = None self.roles: RoleRepository | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() self.projects = ProjectRepository(self.session) self.scenarios = ScenarioRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository( self.session) self.users = UserRepository(self.session) self.roles = RoleRepository(self.session) self._scenario_validator = ScenarioComparisonValidator() return self def __exit__(self, exc_type, exc_value, traceback) -> None: assert self.session is not None if exc_type is None: self.session.commit() else: self.session.rollback() self.session.close() self._scenario_validator = None self.projects = None self.scenarios = None self.financial_inputs = None self.simulation_parameters = None self.users = None self.roles = None def flush(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.flush() def commit(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.commit() def rollback(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.rollback() def validate_scenarios_for_comparison( self, scenario_ids: Sequence[int] ) -> list[Scenario]: if not self.session or not self._scenario_validator or not self.scenarios: raise RuntimeError("UnitOfWork session is not initialised") scenarios = [self.scenarios.get(scenario_id) for scenario_id in scenario_ids] self._scenario_validator.validate(scenarios) return scenarios def validate_scenario_models_for_comparison( self, scenarios: Sequence[Scenario] ) -> None: if not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") self._scenario_validator.validate(scenarios) def ensure_default_roles(self) -> list[Role]: if not self.roles: raise RuntimeError("UnitOfWork session is not initialised") return ensure_default_roles(self.roles) def ensure_admin_user( self, *, email: str, username: str, password: str, ) -> None: if not self.users or not self.roles: raise RuntimeError("UnitOfWork session is not initialised") ensure_default_roles(self.roles) ensure_admin_user_record( self.users, self.roles, email=email, username=username, password=password, )