feat: implement scenario comparison validation and API endpoint with comprehensive unit tests

This commit is contained in:
2025-11-09 18:42:04 +01:00
parent c39dde3198
commit 02da881d3e
7 changed files with 483 additions and 5 deletions

View File

@@ -1,5 +1,8 @@
"""Domain-level exceptions for service and repository layers."""
from dataclasses import dataclass
from typing import Sequence
class EntityNotFoundError(Exception):
"""Raised when a requested entity cannot be located."""
@@ -7,3 +10,15 @@ class EntityNotFoundError(Exception):
class EntityConflictError(Exception):
"""Raised when attempting to create or update an entity that violates uniqueness."""
@dataclass(eq=False)
class ScenarioValidationError(Exception):
"""Raised when scenarios fail comparison validation rules."""
code: str
message: str
scenario_ids: Sequence[int] | None = None
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
return self.message

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Iterable, Sequence
from models import Scenario, ScenarioStatus
from services.exceptions import ScenarioValidationError
ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset(
{ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE}
)
@dataclass(frozen=True)
class _ValidationContext:
scenarios: Sequence[Scenario]
@property
def scenario_ids(self) -> list[int]:
return [scenario.id for scenario in self.scenarios if scenario.id is not None]
class ScenarioComparisonValidator:
"""Validates scenarios prior to comparison workflows."""
def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None:
scenario_list = list(scenarios)
if len(scenario_list) < 2:
# Nothing to validate when fewer than two scenarios are provided.
return
context = _ValidationContext(scenario_list)
self._ensure_same_project(context)
self._ensure_allowed_status(context)
self._ensure_shared_currency(context)
self._ensure_timeline_overlap(context)
self._ensure_shared_primary_resource(context)
def _ensure_same_project(self, context: _ValidationContext) -> None:
project_ids = {scenario.project_id for scenario in context.scenarios}
if len(project_ids) > 1:
raise ScenarioValidationError(
code="SCENARIO_PROJECT_MISMATCH",
message="Selected scenarios do not belong to the same project.",
scenario_ids=context.scenario_ids,
)
def _ensure_allowed_status(self, context: _ValidationContext) -> None:
disallowed = [
scenario
for scenario in context.scenarios
if scenario.status not in ALLOWED_STATUSES
]
if disallowed:
raise ScenarioValidationError(
code="SCENARIO_STATUS_INVALID",
message="Archived scenarios cannot be compared.",
scenario_ids=[
scenario.id for scenario in disallowed if scenario.id is not None],
)
def _ensure_shared_currency(self, context: _ValidationContext) -> None:
currencies = {
scenario.currency
for scenario in context.scenarios
if scenario.currency is not None
}
if len(currencies) > 1:
raise ScenarioValidationError(
code="SCENARIO_CURRENCY_MISMATCH",
message="Scenarios use different currencies and cannot be compared.",
scenario_ids=context.scenario_ids,
)
def _ensure_timeline_overlap(self, context: _ValidationContext) -> None:
ranges = [
(scenario.start_date, scenario.end_date)
for scenario in context.scenarios
if scenario.start_date and scenario.end_date
]
if len(ranges) < 2:
return
latest_start: date = max(start for start, _ in ranges)
earliest_end: date = min(end for _, end in ranges)
if latest_start > earliest_end:
raise ScenarioValidationError(
code="SCENARIO_TIMELINE_DISJOINT",
message="Scenario timelines do not overlap; adjust the comparison window.",
scenario_ids=context.scenario_ids,
)
def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None:
resources = {
scenario.primary_resource
for scenario in context.scenarios
if scenario.primary_resource is not None
}
if len(resources) > 1:
raise ScenarioValidationError(
code="SCENARIO_RESOURCE_MISMATCH",
message="Scenarios target different primary resources and cannot be compared.",
scenario_ids=context.scenario_ids,
)

View File

@@ -1,17 +1,19 @@
from __future__ import annotations
from contextlib import AbstractContextManager
from typing import Callable
from typing import Callable, Sequence
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models import Scenario
from services.repositories import (
FinancialInputRepository,
ProjectRepository,
ScenarioRepository,
SimulationParameterRepository,
)
from services.scenario_validation import ScenarioComparisonValidator
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
@@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
self._session_factory = session_factory
self.session: Session | None = None
self._scenario_validator: ScenarioComparisonValidator | None = None
def __enter__(self) -> "UnitOfWork":
self.session = self._session_factory()
self.projects = ProjectRepository(self.session)
self.scenarios = ScenarioRepository(self.session)
self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(
self.session)
self._scenario_validator = ScenarioComparisonValidator()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
else:
self.session.rollback()
self.session.close()
self._scenario_validator = None
def flush(self) -> None:
if not self.session:
@@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
if not self.session:
raise RuntimeError("UnitOfWork session is not initialised")
self.session.rollback()
def validate_scenarios_for_comparison(
self, scenario_ids: Sequence[int]
) -> list[Scenario]:
if not self.session or not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
scenarios = [self.scenarios.get(scenario_id)
for scenario_id in scenario_ids]
self._scenario_validator.validate(scenarios)
return scenarios
def validate_scenario_models_for_comparison(
self, scenarios: Sequence[Scenario]
) -> None:
if not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
self._scenario_validator.validate(scenarios)