feat: implement scenario comparison validation and API endpoint with comprehensive unit tests
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""Domain-level exceptions for service and repository layers."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
class EntityNotFoundError(Exception):
|
||||
"""Raised when a requested entity cannot be located."""
|
||||
@@ -7,3 +10,15 @@ class EntityNotFoundError(Exception):
|
||||
|
||||
class EntityConflictError(Exception):
|
||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ScenarioValidationError(Exception):
|
||||
"""Raised when scenarios fail comparison validation rules."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
scenario_ids: Sequence[int] | None = None
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
|
||||
return self.message
|
||||
|
||||
106
services/scenario_validation.py
Normal file
106
services/scenario_validation.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
from models import Scenario, ScenarioStatus
|
||||
from services.exceptions import ScenarioValidationError
|
||||
|
||||
ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset(
|
||||
{ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ValidationContext:
|
||||
scenarios: Sequence[Scenario]
|
||||
|
||||
@property
|
||||
def scenario_ids(self) -> list[int]:
|
||||
return [scenario.id for scenario in self.scenarios if scenario.id is not None]
|
||||
|
||||
|
||||
class ScenarioComparisonValidator:
|
||||
"""Validates scenarios prior to comparison workflows."""
|
||||
|
||||
def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None:
|
||||
scenario_list = list(scenarios)
|
||||
if len(scenario_list) < 2:
|
||||
# Nothing to validate when fewer than two scenarios are provided.
|
||||
return
|
||||
|
||||
context = _ValidationContext(scenario_list)
|
||||
|
||||
self._ensure_same_project(context)
|
||||
self._ensure_allowed_status(context)
|
||||
self._ensure_shared_currency(context)
|
||||
self._ensure_timeline_overlap(context)
|
||||
self._ensure_shared_primary_resource(context)
|
||||
|
||||
def _ensure_same_project(self, context: _ValidationContext) -> None:
|
||||
project_ids = {scenario.project_id for scenario in context.scenarios}
|
||||
if len(project_ids) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_PROJECT_MISMATCH",
|
||||
message="Selected scenarios do not belong to the same project.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_allowed_status(self, context: _ValidationContext) -> None:
|
||||
disallowed = [
|
||||
scenario
|
||||
for scenario in context.scenarios
|
||||
if scenario.status not in ALLOWED_STATUSES
|
||||
]
|
||||
if disallowed:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_STATUS_INVALID",
|
||||
message="Archived scenarios cannot be compared.",
|
||||
scenario_ids=[
|
||||
scenario.id for scenario in disallowed if scenario.id is not None],
|
||||
)
|
||||
|
||||
def _ensure_shared_currency(self, context: _ValidationContext) -> None:
|
||||
currencies = {
|
||||
scenario.currency
|
||||
for scenario in context.scenarios
|
||||
if scenario.currency is not None
|
||||
}
|
||||
if len(currencies) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_CURRENCY_MISMATCH",
|
||||
message="Scenarios use different currencies and cannot be compared.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_timeline_overlap(self, context: _ValidationContext) -> None:
|
||||
ranges = [
|
||||
(scenario.start_date, scenario.end_date)
|
||||
for scenario in context.scenarios
|
||||
if scenario.start_date and scenario.end_date
|
||||
]
|
||||
if len(ranges) < 2:
|
||||
return
|
||||
|
||||
latest_start: date = max(start for start, _ in ranges)
|
||||
earliest_end: date = min(end for _, end in ranges)
|
||||
if latest_start > earliest_end:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_TIMELINE_DISJOINT",
|
||||
message="Scenario timelines do not overlap; adjust the comparison window.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None:
|
||||
resources = {
|
||||
scenario.primary_resource
|
||||
for scenario in context.scenarios
|
||||
if scenario.primary_resource is not None
|
||||
}
|
||||
if len(resources) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_RESOURCE_MISMATCH",
|
||||
message="Scenarios target different primary resources and cannot be compared.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
@@ -1,17 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Callable
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from models import Scenario
|
||||
from services.repositories import (
|
||||
FinancialInputRepository,
|
||||
ProjectRepository,
|
||||
ScenarioRepository,
|
||||
SimulationParameterRepository,
|
||||
)
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
|
||||
|
||||
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
@@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
|
||||
self._session_factory = session_factory
|
||||
self.session: Session | None = None
|
||||
self._scenario_validator: ScenarioComparisonValidator | None = None
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
self.session = self._session_factory()
|
||||
self.projects = ProjectRepository(self.session)
|
||||
self.scenarios = ScenarioRepository(self.session)
|
||||
self.financial_inputs = FinancialInputRepository(self.session)
|
||||
self.simulation_parameters = SimulationParameterRepository(self.session)
|
||||
self.simulation_parameters = SimulationParameterRepository(
|
||||
self.session)
|
||||
self._scenario_validator = ScenarioComparisonValidator()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
@@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
self._scenario_validator = None
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.session:
|
||||
@@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.rollback()
|
||||
|
||||
def validate_scenarios_for_comparison(
|
||||
self, scenario_ids: Sequence[int]
|
||||
) -> list[Scenario]:
|
||||
if not self.session or not self._scenario_validator:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
|
||||
scenarios = [self.scenarios.get(scenario_id)
|
||||
for scenario_id in scenario_ids]
|
||||
self._scenario_validator.validate(scenarios)
|
||||
return scenarios
|
||||
|
||||
def validate_scenario_models_for_comparison(
|
||||
self, scenarios: Sequence[Scenario]
|
||||
) -> None:
|
||||
if not self._scenario_validator:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self._scenario_validator.validate(scenarios)
|
||||
|
||||
Reference in New Issue
Block a user