From 02da881d3e842511730458fcd10a770c57e1ac77 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Sun, 9 Nov 2025 18:42:04 +0100 Subject: [PATCH] feat: implement scenario comparison validation and API endpoint with comprehensive unit tests --- changelog.md | 1 + routes/scenarios.py | 61 ++++++- schemas/scenario.py | 23 ++- services/exceptions.py | 15 ++ services/scenario_validation.py | 106 +++++++++++++ services/unit_of_work.py | 28 +++- tests/test_scenario_validation.py | 254 ++++++++++++++++++++++++++++++ 7 files changed, 483 insertions(+), 5 deletions(-) create mode 100644 services/scenario_validation.py create mode 100644 tests/test_scenario_validation.py diff --git a/changelog.md b/changelog.md index 189f9ad..8c5d180 100644 --- a/changelog.md +++ b/changelog.md @@ -9,3 +9,4 @@ - Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application. - Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects. - Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates. +- Added scenario comparison validator, FastAPI comparison endpoint, and comprehensive unit tests to enforce FR-009 validation rules through API errors. diff --git a/routes/scenarios.py b/routes/scenarios.py index 465534d..bce46f1 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -9,8 +9,18 @@ from fastapi.templating import Jinja2Templates from dependencies import get_unit_of_work from models import ResourceType, Scenario, ScenarioStatus -from schemas.scenario import ScenarioCreate, ScenarioRead, ScenarioUpdate -from services.exceptions import EntityConflictError, EntityNotFoundError +from schemas.scenario import ( + ScenarioComparisonRequest, + ScenarioComparisonResponse, + ScenarioCreate, + ScenarioRead, + ScenarioUpdate, +) +from services.exceptions import ( + EntityConflictError, + EntityNotFoundError, + ScenarioValidationError, +) from services.unit_of_work import UnitOfWork router = APIRouter(tags=["Scenarios"]) @@ -51,6 +61,53 @@ def list_scenarios_for_project( return [_to_read_model(scenario) for scenario in scenarios] +@router.post( + "/projects/{project_id}/scenarios/compare", + response_model=ScenarioComparisonResponse, + status_code=status.HTTP_200_OK, +) +def compare_scenarios( + project_id: int, + payload: ScenarioComparisonRequest, + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ScenarioComparisonResponse: + try: + uow.projects.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + + try: + scenarios = uow.validate_scenarios_for_comparison(payload.scenario_ids) + if any(scenario.project_id != project_id for scenario in scenarios): + raise ScenarioValidationError( + code="SCENARIO_PROJECT_MISMATCH", + message="Selected scenarios do not belong to the same project.", + scenario_ids=[ + scenario.id for scenario in scenarios if scenario.id is not None + ], + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + except ScenarioValidationError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail={ + "code": exc.code, + "message": exc.message, + "scenario_ids": list(exc.scenario_ids or []), + }, + ) from exc + + return ScenarioComparisonResponse( + project_id=project_id, + scenarios=[_to_read_model(scenario) for scenario in scenarios], + ) + + @router.post( "/projects/{project_id}/scenarios", response_model=ScenarioRead, diff --git a/schemas/scenario.py b/schemas/scenario.py index 6681cba..73573d6 100644 --- a/schemas/scenario.py +++ b/schemas/scenario.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import date, datetime -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator from models import ResourceType, ScenarioStatus @@ -64,3 +64,24 @@ class ScenarioRead(ScenarioBase): updated_at: datetime model_config = ConfigDict(from_attributes=True) + + +class ScenarioComparisonRequest(BaseModel): + scenario_ids: list[int] + + model_config = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def ensure_minimum_ids(self) -> "ScenarioComparisonRequest": + unique_ids: list[int] = list(dict.fromkeys(self.scenario_ids)) + if len(unique_ids) < 2: + raise ValueError("At least two unique scenario identifiers are required for comparison.") + self.scenario_ids = unique_ids + return self + + +class ScenarioComparisonResponse(BaseModel): + project_id: int + scenarios: list[ScenarioRead] + + model_config = ConfigDict(from_attributes=True) diff --git a/services/exceptions.py b/services/exceptions.py index 6e58c29..bfd1ce4 100644 --- a/services/exceptions.py +++ b/services/exceptions.py @@ -1,5 +1,8 @@ """Domain-level exceptions for service and repository layers.""" +from dataclasses import dataclass +from typing import Sequence + class EntityNotFoundError(Exception): """Raised when a requested entity cannot be located.""" @@ -7,3 +10,15 @@ class EntityNotFoundError(Exception): class EntityConflictError(Exception): """Raised when attempting to create or update an entity that violates uniqueness.""" + + +@dataclass(eq=False) +class ScenarioValidationError(Exception): + """Raised when scenarios fail comparison validation rules.""" + + code: str + message: str + scenario_ids: Sequence[int] | None = None + + def __str__(self) -> str: # pragma: no cover - mirrors message for logging + return self.message diff --git a/services/scenario_validation.py b/services/scenario_validation.py new file mode 100644 index 0000000..cf4b4a2 --- /dev/null +++ b/services/scenario_validation.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date +from typing import Iterable, Sequence + +from models import Scenario, ScenarioStatus +from services.exceptions import ScenarioValidationError + +ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset( + {ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE} +) + + +@dataclass(frozen=True) +class _ValidationContext: + scenarios: Sequence[Scenario] + + @property + def scenario_ids(self) -> list[int]: + return [scenario.id for scenario in self.scenarios if scenario.id is not None] + + +class ScenarioComparisonValidator: + """Validates scenarios prior to comparison workflows.""" + + def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None: + scenario_list = list(scenarios) + if len(scenario_list) < 2: + # Nothing to validate when fewer than two scenarios are provided. + return + + context = _ValidationContext(scenario_list) + + self._ensure_same_project(context) + self._ensure_allowed_status(context) + self._ensure_shared_currency(context) + self._ensure_timeline_overlap(context) + self._ensure_shared_primary_resource(context) + + def _ensure_same_project(self, context: _ValidationContext) -> None: + project_ids = {scenario.project_id for scenario in context.scenarios} + if len(project_ids) > 1: + raise ScenarioValidationError( + code="SCENARIO_PROJECT_MISMATCH", + message="Selected scenarios do not belong to the same project.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_allowed_status(self, context: _ValidationContext) -> None: + disallowed = [ + scenario + for scenario in context.scenarios + if scenario.status not in ALLOWED_STATUSES + ] + if disallowed: + raise ScenarioValidationError( + code="SCENARIO_STATUS_INVALID", + message="Archived scenarios cannot be compared.", + scenario_ids=[ + scenario.id for scenario in disallowed if scenario.id is not None], + ) + + def _ensure_shared_currency(self, context: _ValidationContext) -> None: + currencies = { + scenario.currency + for scenario in context.scenarios + if scenario.currency is not None + } + if len(currencies) > 1: + raise ScenarioValidationError( + code="SCENARIO_CURRENCY_MISMATCH", + message="Scenarios use different currencies and cannot be compared.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_timeline_overlap(self, context: _ValidationContext) -> None: + ranges = [ + (scenario.start_date, scenario.end_date) + for scenario in context.scenarios + if scenario.start_date and scenario.end_date + ] + if len(ranges) < 2: + return + + latest_start: date = max(start for start, _ in ranges) + earliest_end: date = min(end for _, end in ranges) + if latest_start > earliest_end: + raise ScenarioValidationError( + code="SCENARIO_TIMELINE_DISJOINT", + message="Scenario timelines do not overlap; adjust the comparison window.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None: + resources = { + scenario.primary_resource + for scenario in context.scenarios + if scenario.primary_resource is not None + } + if len(resources) > 1: + raise ScenarioValidationError( + code="SCENARIO_RESOURCE_MISMATCH", + message="Scenarios target different primary resources and cannot be compared.", + scenario_ids=context.scenario_ids, + ) diff --git a/services/unit_of_work.py b/services/unit_of_work.py index 6bb8cb8..2e7b9e8 100644 --- a/services/unit_of_work.py +++ b/services/unit_of_work.py @@ -1,17 +1,19 @@ from __future__ import annotations from contextlib import AbstractContextManager -from typing import Callable +from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal +from models import Scenario from services.repositories import ( FinancialInputRepository, ProjectRepository, ScenarioRepository, SimulationParameterRepository, ) +from services.scenario_validation import ScenarioComparisonValidator class UnitOfWork(AbstractContextManager["UnitOfWork"]): @@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: self._session_factory = session_factory self.session: Session | None = None + self._scenario_validator: ScenarioComparisonValidator | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() self.projects = ProjectRepository(self.session) self.scenarios = ScenarioRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session) - self.simulation_parameters = SimulationParameterRepository(self.session) + self.simulation_parameters = SimulationParameterRepository( + self.session) + self._scenario_validator = ScenarioComparisonValidator() return self def __exit__(self, exc_type, exc_value, traceback) -> None: @@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): else: self.session.rollback() self.session.close() + self._scenario_validator = None def flush(self) -> None: if not self.session: @@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.rollback() + + def validate_scenarios_for_comparison( + self, scenario_ids: Sequence[int] + ) -> list[Scenario]: + if not self.session or not self._scenario_validator: + raise RuntimeError("UnitOfWork session is not initialised") + + scenarios = [self.scenarios.get(scenario_id) + for scenario_id in scenario_ids] + self._scenario_validator.validate(scenarios) + return scenarios + + def validate_scenario_models_for_comparison( + self, scenarios: Sequence[Scenario] + ) -> None: + if not self._scenario_validator: + raise RuntimeError("UnitOfWork session is not initialised") + self._scenario_validator.validate(scenarios) diff --git a/tests/test_scenario_validation.py b/tests/test_scenario_validation.py new file mode 100644 index 0000000..bb26203 --- /dev/null +++ b/tests/test_scenario_validation.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from datetime import date +from collections.abc import Iterator +from typing import cast +from uuid import uuid4 + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pydantic import ValidationError +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.engine import Engine +from sqlalchemy.pool import StaticPool + +from config.database import Base +from dependencies import get_unit_of_work +from models import ( + MiningOperationType, + Project, + ResourceType, + Scenario, + ScenarioStatus, +) +from schemas.scenario import ( + ScenarioComparisonRequest, + ScenarioComparisonResponse, +) +from services.exceptions import ScenarioValidationError +from services.scenario_validation import ScenarioComparisonValidator +from services.unit_of_work import UnitOfWork +from routes.scenarios import router as scenarios_router + + +@pytest.fixture() +def validator() -> ScenarioComparisonValidator: + return ScenarioComparisonValidator() + + +def _make_scenario(**overrides) -> Scenario: + project_id: int = int(overrides.get("project_id", 1)) + name: str = str(overrides.get("name", "Scenario")) + status = cast(ScenarioStatus, overrides.get( + "status", ScenarioStatus.DRAFT)) + start_date = overrides.get("start_date", date(2025, 1, 1)) + end_date = overrides.get("end_date", date(2025, 12, 31)) + currency = cast(str, overrides.get("currency", "USD")) + primary_resource = cast(ResourceType, overrides.get( + "primary_resource", ResourceType.DIESEL)) + + scenario = Scenario( + project_id=project_id, + name=name, + status=status, + start_date=start_date, + end_date=end_date, + currency=currency, + primary_resource=primary_resource, + ) + + if "id" in overrides: + scenario.id = overrides["id"] + + return scenario + + +class TestScenarioComparisonValidator: + def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None: + scenario_a = _make_scenario(id=1) + scenario_b = _make_scenario(id=2) + + validator.validate([scenario_a, scenario_b]) + + @pytest.mark.parametrize( + "kwargs_a, kwargs_b, expected_code", + [ + ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"), + ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"), + ({"currency": "USD"}, {"currency": "CAD"}, + "SCENARIO_CURRENCY_MISMATCH"), + ( + {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)}, + {"start_date": date(2025, 7, 1), + "end_date": date(2025, 12, 31)}, + "SCENARIO_TIMELINE_DISJOINT", + ), + ({"primary_resource": ResourceType.DIESEL}, { + "primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"), + ], + ) + def test_validate_raises_for_conflicts( + self, + validator: ScenarioComparisonValidator, + kwargs_a: dict[str, object], + kwargs_b: dict[str, object], + expected_code: str, + ) -> None: + scenario_a = _make_scenario(id=10, **kwargs_a) + scenario_b = _make_scenario(id=20, **kwargs_b) + + with pytest.raises(ScenarioValidationError) as exc_info: + validator.validate([scenario_a, scenario_b]) + + assert exc_info.value.code == expected_code + + def test_timeline_rule_skips_when_insufficient_ranges( + self, validator: ScenarioComparisonValidator + ) -> None: + scenario_a = _make_scenario(id=1, start_date=None, end_date=None) + scenario_b = _make_scenario(id=2, start_date=date( + 2025, 1, 1), end_date=date(2025, 12, 31)) + + validator.validate([scenario_a, scenario_b]) + + +class TestScenarioComparisonRequest: + def test_requires_two_unique_identifiers(self) -> None: + with pytest.raises(ValidationError): + ScenarioComparisonRequest.model_validate({"scenario_ids": [1]}) + + def test_deduplicates_ids_preserving_order(self) -> None: + payload = ScenarioComparisonRequest.model_validate( + {"scenario_ids": [1, 1, 2, 2, 3]}) + + assert payload.scenario_ids == [1, 2, 3] + + +@pytest.fixture() +def engine() -> Iterator[Engine]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + future=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + try: + yield engine + finally: + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +@pytest.fixture() +def session_factory(engine: Engine) -> Iterator[sessionmaker]: + testing_session = sessionmaker( + bind=engine, expire_on_commit=False, future=True) + yield testing_session + + +@pytest.fixture() +def api_client(session_factory) -> Iterator[TestClient]: + app = FastAPI() + app.include_router(scenarios_router) + + def _override_uow() -> Iterator[UnitOfWork]: + with UnitOfWork(session_factory=session_factory) as uow: + yield uow + + app.dependency_overrides[get_unit_of_work] = _override_uow + client = TestClient(app) + try: + yield client + finally: + client.close() + + +def _create_project_with_scenarios( + session_factory: sessionmaker, + scenario_overrides: list[dict[str, object]], +) -> tuple[int, list[int]]: + with UnitOfWork(session_factory=session_factory) as uow: + project_name = f"Project {uuid4()}" + project = Project(name=project_name, + operation_type=MiningOperationType.OPEN_PIT) + uow.projects.create(project) + + scenario_ids: list[int] = [] + for index, overrides in enumerate(scenario_overrides, start=1): + scenario = Scenario( + project_id=project.id, + name=f"Scenario {index}", + status=overrides.get("status", ScenarioStatus.DRAFT), + start_date=overrides.get("start_date", date(2025, 1, 1)), + end_date=overrides.get("end_date", date(2025, 12, 31)), + currency=overrides.get("currency", "USD"), + primary_resource=overrides.get( + "primary_resource", ResourceType.DIESEL), + ) + uow.scenarios.create(scenario) + scenario_ids.append(scenario.id) + + return project.id, scenario_ids + + +class TestScenarioComparisonEndpoint: + def test_returns_scenarios_when_validation_passes( + self, api_client: TestClient, session_factory: sessionmaker + ) -> None: + project_id, scenario_ids = _create_project_with_scenarios( + session_factory, + [ + {}, + {"start_date": date(2025, 6, 1), + "end_date": date(2025, 12, 31)}, + ], + ) + + response = api_client.post( + f"/projects/{project_id}/scenarios/compare", + json={"scenario_ids": scenario_ids}, + ) + + assert response.status_code == 200 + payload = ScenarioComparisonResponse.model_validate(response.json()) + assert payload.project_id == project_id + assert {scenario.id for scenario in payload.scenarios} == set( + scenario_ids) + + def test_returns_422_when_currency_mismatch( + self, api_client: TestClient, session_factory: sessionmaker + ) -> None: + project_id, scenario_ids = _create_project_with_scenarios( + session_factory, + [{"currency": "USD"}, {"currency": "CAD"}], + ) + + response = api_client.post( + f"/projects/{project_id}/scenarios/compare", + json={"scenario_ids": scenario_ids}, + ) + + assert response.status_code == 422 + detail = response.json()["detail"] + assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH" + + def test_returns_422_when_second_scenario_from_other_project( + self, api_client: TestClient, session_factory: sessionmaker + ) -> None: + project_a_id, scenario_ids_a = _create_project_with_scenarios( + session_factory, [{}]) + project_b_id, scenario_ids_b = _create_project_with_scenarios( + session_factory, [{}]) + + response = api_client.post( + f"/projects/{project_a_id}/scenarios/compare", + json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]}, + ) + + assert response.status_code == 422 + detail = response.json()["detail"] + assert detail["code"] == "SCENARIO_PROJECT_MISMATCH" + assert project_a_id != project_b_id