from __future__ import annotations from contextlib import AbstractContextManager from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal from models import Scenario from services.repositories import ( FinancialInputRepository, ProjectRepository, ScenarioRepository, SimulationParameterRepository, ) from services.scenario_validation import ScenarioComparisonValidator class UnitOfWork(AbstractContextManager["UnitOfWork"]): """Simple unit-of-work wrapper around SQLAlchemy sessions.""" def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: self._session_factory = session_factory self.session: Session | None = None self._scenario_validator: ScenarioComparisonValidator | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() self.projects = ProjectRepository(self.session) self.scenarios = ScenarioRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository( self.session) self._scenario_validator = ScenarioComparisonValidator() return self def __exit__(self, exc_type, exc_value, traceback) -> None: assert self.session is not None if exc_type is None: self.session.commit() else: self.session.rollback() self.session.close() self._scenario_validator = None def flush(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.flush() def commit(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.commit() def rollback(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.rollback() def validate_scenarios_for_comparison( self, scenario_ids: Sequence[int] ) -> list[Scenario]: if not self.session or not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") scenarios = [self.scenarios.get(scenario_id) for scenario_id in scenario_ids] self._scenario_validator.validate(scenarios) return scenarios def validate_scenario_models_for_comparison( self, scenarios: Sequence[Scenario] ) -> None: if not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") self._scenario_validator.validate(scenarios)