feat: implement scenario comparison validation and API endpoint with comprehensive unit tests

This commit is contained in:
2025-11-09 18:42:04 +01:00
parent c39dde3198
commit 02da881d3e
7 changed files with 483 additions and 5 deletions

View File

@@ -1,17 +1,19 @@
from __future__ import annotations
from contextlib import AbstractContextManager
from typing import Callable
from typing import Callable, Sequence
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models import Scenario
from services.repositories import (
FinancialInputRepository,
ProjectRepository,
ScenarioRepository,
SimulationParameterRepository,
)
from services.scenario_validation import ScenarioComparisonValidator
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
@@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
self._session_factory = session_factory
self.session: Session | None = None
self._scenario_validator: ScenarioComparisonValidator | None = None
def __enter__(self) -> "UnitOfWork":
self.session = self._session_factory()
self.projects = ProjectRepository(self.session)
self.scenarios = ScenarioRepository(self.session)
self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(
self.session)
self._scenario_validator = ScenarioComparisonValidator()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
else:
self.session.rollback()
self.session.close()
self._scenario_validator = None
def flush(self) -> None:
if not self.session:
@@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
if not self.session:
raise RuntimeError("UnitOfWork session is not initialised")
self.session.rollback()
def validate_scenarios_for_comparison(
self, scenario_ids: Sequence[int]
) -> list[Scenario]:
if not self.session or not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
scenarios = [self.scenarios.get(scenario_id)
for scenario_id in scenario_ids]
self._scenario_validator.validate(scenarios)
return scenarios
def validate_scenario_models_for_comparison(
self, scenarios: Sequence[Scenario]
) -> None:
if not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
self._scenario_validator.validate(scenarios)