78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import AbstractContextManager
|
|
from typing import Callable, Sequence
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from config.database import SessionLocal
|
|
from models import Scenario
|
|
from services.repositories import (
|
|
FinancialInputRepository,
|
|
ProjectRepository,
|
|
ScenarioRepository,
|
|
SimulationParameterRepository,
|
|
)
|
|
from services.scenario_validation import ScenarioComparisonValidator
|
|
|
|
|
|
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
    """Simple unit-of-work wrapper around SQLAlchemy sessions.

    Intended usage::

        with UnitOfWork() as uow:
            uow.projects...
            uow.scenarios...

    Entering the block opens a session from ``session_factory`` and binds the
    repositories (``projects``, ``scenarios``, ``financial_inputs``,
    ``simulation_parameters``) to it. Leaving the block commits if the body
    raised nothing and rolls back otherwise. The session is closed in both
    cases, including when the commit itself fails.
    """

    # Populated by __enter__; declared here so type checkers know about them.
    projects: ProjectRepository
    scenarios: ScenarioRepository
    financial_inputs: FinancialInputRepository
    simulation_parameters: SimulationParameterRepository

    def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
        """Store the session factory; no session is opened until ``__enter__``.

        Args:
            session_factory: Zero-argument callable returning a new ``Session``.
        """
        self._session_factory = session_factory
        self.session: Session | None = None
        self._scenario_validator: ScenarioComparisonValidator | None = None

    def __enter__(self) -> "UnitOfWork":
        """Open a session and bind the repositories and validator to it.

        Returns:
            This unit of work.
        """
        session = self._session_factory()
        try:
            self.projects = ProjectRepository(session)
            self.scenarios = ScenarioRepository(session)
            self.financial_inputs = FinancialInputRepository(session)
            self.simulation_parameters = SimulationParameterRepository(session)
            self._scenario_validator = ScenarioComparisonValidator()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so release the
            # freshly opened session here instead of leaking it.
            session.close()
            raise
        self.session = session
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Commit on success, roll back on error, and always close the session.

        A failure during ``commit`` triggers a rollback and is re-raised.
        Exceptions raised inside the ``with`` body are never suppressed
        (this method returns ``None``).

        Raises:
            RuntimeError: If called without a prior successful ``__enter__``.
        """
        session = self._require_session()
        try:
            if exc_type is None:
                try:
                    session.commit()
                except BaseException:
                    # Discard the failed transaction before closing.
                    session.rollback()
                    raise
            else:
                session.rollback()
        finally:
            # Runs even if commit/rollback raised, so the session never leaks.
            session.close()
            self._scenario_validator = None

    def _require_session(self) -> Session:
        """Return the active session, or raise if the unit of work was not entered."""
        if self.session is None:
            raise RuntimeError("UnitOfWork session is not initialised")
        return self.session

    def _require_validator(self) -> ScenarioComparisonValidator:
        """Return the scenario validator, or raise if not inside the ``with`` block."""
        if self._scenario_validator is None:
            raise RuntimeError("UnitOfWork session is not initialised")
        return self._scenario_validator

    def flush(self) -> None:
        """Flush pending changes on the active session.

        Raises:
            RuntimeError: If no session has been opened.
        """
        self._require_session().flush()

    def commit(self) -> None:
        """Commit the active session's current transaction.

        Raises:
            RuntimeError: If no session has been opened.
        """
        self._require_session().commit()

    def rollback(self) -> None:
        """Roll back the active session's current transaction.

        Raises:
            RuntimeError: If no session has been opened.
        """
        self._require_session().rollback()

    def validate_scenarios_for_comparison(
        self, scenario_ids: Sequence[int]
    ) -> list[Scenario]:
        """Load scenarios by id and check that they can be compared.

        Args:
            scenario_ids: Ids of the scenarios to load.

        Returns:
            The objects returned by ``self.scenarios.get`` for each id, in the
            same order as ``scenario_ids``.

        Raises:
            RuntimeError: If called outside an active ``with`` block.
            Any error raised by ``ScenarioComparisonValidator.validate``
            propagates unchanged.
        """
        self._require_session()
        validator = self._require_validator()
        scenarios = [self.scenarios.get(scenario_id) for scenario_id in scenario_ids]
        validator.validate(scenarios)
        return scenarios

    def validate_scenario_models_for_comparison(
        self, scenarios: Sequence[Scenario]
    ) -> None:
        """Check already-loaded scenarios for comparability.

        Args:
            scenarios: Scenario objects to pass to the validator.

        Raises:
            RuntimeError: If called outside an active ``with`` block.
            Any error raised by ``ScenarioComparisonValidator.validate``
            propagates unchanged.
        """
        self._require_validator().validate(scenarios)
|