feat: implement scenario comparison validation and API endpoint with comprehensive unit tests

This commit is contained in:
2025-11-09 18:42:04 +01:00
parent c39dde3198
commit 02da881d3e
7 changed files with 483 additions and 5 deletions

View File

@@ -9,3 +9,4 @@
- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application. - Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application.
- Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects. - Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects.
- Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates. - Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates.
- Added scenario comparison validator, FastAPI comparison endpoint, and comprehensive unit tests to enforce FR-009 validation rules through API errors.

View File

@@ -9,8 +9,18 @@ from fastapi.templating import Jinja2Templates
from dependencies import get_unit_of_work from dependencies import get_unit_of_work
from models import ResourceType, Scenario, ScenarioStatus from models import ResourceType, Scenario, ScenarioStatus
from schemas.scenario import ScenarioCreate, ScenarioRead, ScenarioUpdate from schemas.scenario import (
from services.exceptions import EntityConflictError, EntityNotFoundError ScenarioComparisonRequest,
ScenarioComparisonResponse,
ScenarioCreate,
ScenarioRead,
ScenarioUpdate,
)
from services.exceptions import (
EntityConflictError,
EntityNotFoundError,
ScenarioValidationError,
)
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
router = APIRouter(tags=["Scenarios"]) router = APIRouter(tags=["Scenarios"])
@@ -51,6 +61,53 @@ def list_scenarios_for_project(
return [_to_read_model(scenario) for scenario in scenarios] return [_to_read_model(scenario) for scenario in scenarios]
@router.post(
"/projects/{project_id}/scenarios/compare",
response_model=ScenarioComparisonResponse,
status_code=status.HTTP_200_OK,
)
def compare_scenarios(
project_id: int,
payload: ScenarioComparisonRequest,
uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioComparisonResponse:
try:
uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
try:
scenarios = uow.validate_scenarios_for_comparison(payload.scenario_ids)
if any(scenario.project_id != project_id for scenario in scenarios):
raise ScenarioValidationError(
code="SCENARIO_PROJECT_MISMATCH",
message="Selected scenarios do not belong to the same project.",
scenario_ids=[
scenario.id for scenario in scenarios if scenario.id is not None
],
)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ScenarioValidationError as exc:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail={
"code": exc.code,
"message": exc.message,
"scenario_ids": list(exc.scenario_ids or []),
},
) from exc
return ScenarioComparisonResponse(
project_id=project_id,
scenarios=[_to_read_model(scenario) for scenario in scenarios],
)
@router.post( @router.post(
"/projects/{project_id}/scenarios", "/projects/{project_id}/scenarios",
response_model=ScenarioRead, response_model=ScenarioRead,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import date, datetime from datetime import date, datetime
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from models import ResourceType, ScenarioStatus from models import ResourceType, ScenarioStatus
@@ -64,3 +64,24 @@ class ScenarioRead(ScenarioBase):
updated_at: datetime updated_at: datetime
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ScenarioComparisonRequest(BaseModel):
scenario_ids: list[int]
model_config = ConfigDict(extra="forbid")
@model_validator(mode="after")
def ensure_minimum_ids(self) -> "ScenarioComparisonRequest":
unique_ids: list[int] = list(dict.fromkeys(self.scenario_ids))
if len(unique_ids) < 2:
raise ValueError("At least two unique scenario identifiers are required for comparison.")
self.scenario_ids = unique_ids
return self
class ScenarioComparisonResponse(BaseModel):
project_id: int
scenarios: list[ScenarioRead]
model_config = ConfigDict(from_attributes=True)

View File

@@ -1,5 +1,8 @@
"""Domain-level exceptions for service and repository layers.""" """Domain-level exceptions for service and repository layers."""
from dataclasses import dataclass
from typing import Sequence
class EntityNotFoundError(Exception): class EntityNotFoundError(Exception):
"""Raised when a requested entity cannot be located.""" """Raised when a requested entity cannot be located."""
@@ -7,3 +10,15 @@ class EntityNotFoundError(Exception):
class EntityConflictError(Exception): class EntityConflictError(Exception):
"""Raised when attempting to create or update an entity that violates uniqueness.""" """Raised when attempting to create or update an entity that violates uniqueness."""
@dataclass(eq=False)
class ScenarioValidationError(Exception):
"""Raised when scenarios fail comparison validation rules."""
code: str
message: str
scenario_ids: Sequence[int] | None = None
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
return self.message

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Iterable, Sequence
from models import Scenario, ScenarioStatus
from services.exceptions import ScenarioValidationError
ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset(
{ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE}
)
@dataclass(frozen=True)
class _ValidationContext:
scenarios: Sequence[Scenario]
@property
def scenario_ids(self) -> list[int]:
return [scenario.id for scenario in self.scenarios if scenario.id is not None]
class ScenarioComparisonValidator:
"""Validates scenarios prior to comparison workflows."""
def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None:
scenario_list = list(scenarios)
if len(scenario_list) < 2:
# Nothing to validate when fewer than two scenarios are provided.
return
context = _ValidationContext(scenario_list)
self._ensure_same_project(context)
self._ensure_allowed_status(context)
self._ensure_shared_currency(context)
self._ensure_timeline_overlap(context)
self._ensure_shared_primary_resource(context)
def _ensure_same_project(self, context: _ValidationContext) -> None:
project_ids = {scenario.project_id for scenario in context.scenarios}
if len(project_ids) > 1:
raise ScenarioValidationError(
code="SCENARIO_PROJECT_MISMATCH",
message="Selected scenarios do not belong to the same project.",
scenario_ids=context.scenario_ids,
)
def _ensure_allowed_status(self, context: _ValidationContext) -> None:
disallowed = [
scenario
for scenario in context.scenarios
if scenario.status not in ALLOWED_STATUSES
]
if disallowed:
raise ScenarioValidationError(
code="SCENARIO_STATUS_INVALID",
message="Archived scenarios cannot be compared.",
scenario_ids=[
scenario.id for scenario in disallowed if scenario.id is not None],
)
def _ensure_shared_currency(self, context: _ValidationContext) -> None:
currencies = {
scenario.currency
for scenario in context.scenarios
if scenario.currency is not None
}
if len(currencies) > 1:
raise ScenarioValidationError(
code="SCENARIO_CURRENCY_MISMATCH",
message="Scenarios use different currencies and cannot be compared.",
scenario_ids=context.scenario_ids,
)
def _ensure_timeline_overlap(self, context: _ValidationContext) -> None:
ranges = [
(scenario.start_date, scenario.end_date)
for scenario in context.scenarios
if scenario.start_date and scenario.end_date
]
if len(ranges) < 2:
return
latest_start: date = max(start for start, _ in ranges)
earliest_end: date = min(end for _, end in ranges)
if latest_start > earliest_end:
raise ScenarioValidationError(
code="SCENARIO_TIMELINE_DISJOINT",
message="Scenario timelines do not overlap; adjust the comparison window.",
scenario_ids=context.scenario_ids,
)
def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None:
resources = {
scenario.primary_resource
for scenario in context.scenarios
if scenario.primary_resource is not None
}
if len(resources) > 1:
raise ScenarioValidationError(
code="SCENARIO_RESOURCE_MISMATCH",
message="Scenarios target different primary resources and cannot be compared.",
scenario_ids=context.scenario_ids,
)

View File

@@ -1,17 +1,19 @@
from __future__ import annotations from __future__ import annotations
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from typing import Callable from typing import Callable, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
from models import Scenario
from services.repositories import ( from services.repositories import (
FinancialInputRepository, FinancialInputRepository,
ProjectRepository, ProjectRepository,
ScenarioRepository, ScenarioRepository,
SimulationParameterRepository, SimulationParameterRepository,
) )
from services.scenario_validation import ScenarioComparisonValidator
class UnitOfWork(AbstractContextManager["UnitOfWork"]): class UnitOfWork(AbstractContextManager["UnitOfWork"]):
@@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
self._session_factory = session_factory self._session_factory = session_factory
self.session: Session | None = None self.session: Session | None = None
self._scenario_validator: ScenarioComparisonValidator | None = None
def __enter__(self) -> "UnitOfWork": def __enter__(self) -> "UnitOfWork":
self.session = self._session_factory() self.session = self._session_factory()
self.projects = ProjectRepository(self.session) self.projects = ProjectRepository(self.session)
self.scenarios = ScenarioRepository(self.session) self.scenarios = ScenarioRepository(self.session)
self.financial_inputs = FinancialInputRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(self.session) self.simulation_parameters = SimulationParameterRepository(
self.session)
self._scenario_validator = ScenarioComparisonValidator()
return self return self
def __exit__(self, exc_type, exc_value, traceback) -> None: def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
else: else:
self.session.rollback() self.session.rollback()
self.session.close() self.session.close()
self._scenario_validator = None
def flush(self) -> None: def flush(self) -> None:
if not self.session: if not self.session:
@@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
if not self.session: if not self.session:
raise RuntimeError("UnitOfWork session is not initialised") raise RuntimeError("UnitOfWork session is not initialised")
self.session.rollback() self.session.rollback()
def validate_scenarios_for_comparison(
self, scenario_ids: Sequence[int]
) -> list[Scenario]:
if not self.session or not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
scenarios = [self.scenarios.get(scenario_id)
for scenario_id in scenario_ids]
self._scenario_validator.validate(scenarios)
return scenarios
def validate_scenario_models_for_comparison(
self, scenarios: Sequence[Scenario]
) -> None:
if not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
self._scenario_validator.validate(scenarios)

View File

@@ -0,0 +1,254 @@
from __future__ import annotations
from datetime import date
from collections.abc import Iterator
from typing import cast
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from config.database import Base
from dependencies import get_unit_of_work
from models import (
MiningOperationType,
Project,
ResourceType,
Scenario,
ScenarioStatus,
)
from schemas.scenario import (
ScenarioComparisonRequest,
ScenarioComparisonResponse,
)
from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork
from routes.scenarios import router as scenarios_router
@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
return ScenarioComparisonValidator()
def _make_scenario(**overrides) -> Scenario:
project_id: int = int(overrides.get("project_id", 1))
name: str = str(overrides.get("name", "Scenario"))
status = cast(ScenarioStatus, overrides.get(
"status", ScenarioStatus.DRAFT))
start_date = overrides.get("start_date", date(2025, 1, 1))
end_date = overrides.get("end_date", date(2025, 12, 31))
currency = cast(str, overrides.get("currency", "USD"))
primary_resource = cast(ResourceType, overrides.get(
"primary_resource", ResourceType.DIESEL))
scenario = Scenario(
project_id=project_id,
name=name,
status=status,
start_date=start_date,
end_date=end_date,
currency=currency,
primary_resource=primary_resource,
)
if "id" in overrides:
scenario.id = overrides["id"]
return scenario
class TestScenarioComparisonValidator:
def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
scenario_a = _make_scenario(id=1)
scenario_b = _make_scenario(id=2)
validator.validate([scenario_a, scenario_b])
@pytest.mark.parametrize(
"kwargs_a, kwargs_b, expected_code",
[
({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
({"currency": "USD"}, {"currency": "CAD"},
"SCENARIO_CURRENCY_MISMATCH"),
(
{"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
{"start_date": date(2025, 7, 1),
"end_date": date(2025, 12, 31)},
"SCENARIO_TIMELINE_DISJOINT",
),
({"primary_resource": ResourceType.DIESEL}, {
"primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"),
],
)
def test_validate_raises_for_conflicts(
self,
validator: ScenarioComparisonValidator,
kwargs_a: dict[str, object],
kwargs_b: dict[str, object],
expected_code: str,
) -> None:
scenario_a = _make_scenario(id=10, **kwargs_a)
scenario_b = _make_scenario(id=20, **kwargs_b)
with pytest.raises(ScenarioValidationError) as exc_info:
validator.validate([scenario_a, scenario_b])
assert exc_info.value.code == expected_code
def test_timeline_rule_skips_when_insufficient_ranges(
self, validator: ScenarioComparisonValidator
) -> None:
scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
scenario_b = _make_scenario(id=2, start_date=date(
2025, 1, 1), end_date=date(2025, 12, 31))
validator.validate([scenario_a, scenario_b])
class TestScenarioComparisonRequest:
def test_requires_two_unique_identifiers(self) -> None:
with pytest.raises(ValidationError):
ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})
def test_deduplicates_ids_preserving_order(self) -> None:
payload = ScenarioComparisonRequest.model_validate(
{"scenario_ids": [1, 1, 2, 2, 3]})
assert payload.scenario_ids == [1, 2, 3]
@pytest.fixture()
def engine() -> Iterator[Engine]:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
future=True,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
try:
yield engine
finally:
Base.metadata.drop_all(bind=engine)
engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
testing_session = sessionmaker(
bind=engine, expire_on_commit=False, future=True)
yield testing_session
@pytest.fixture()
def api_client(session_factory) -> Iterator[TestClient]:
app = FastAPI()
app.include_router(scenarios_router)
def _override_uow() -> Iterator[UnitOfWork]:
with UnitOfWork(session_factory=session_factory) as uow:
yield uow
app.dependency_overrides[get_unit_of_work] = _override_uow
client = TestClient(app)
try:
yield client
finally:
client.close()
def _create_project_with_scenarios(
session_factory: sessionmaker,
scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
with UnitOfWork(session_factory=session_factory) as uow:
project_name = f"Project {uuid4()}"
project = Project(name=project_name,
operation_type=MiningOperationType.OPEN_PIT)
uow.projects.create(project)
scenario_ids: list[int] = []
for index, overrides in enumerate(scenario_overrides, start=1):
scenario = Scenario(
project_id=project.id,
name=f"Scenario {index}",
status=overrides.get("status", ScenarioStatus.DRAFT),
start_date=overrides.get("start_date", date(2025, 1, 1)),
end_date=overrides.get("end_date", date(2025, 12, 31)),
currency=overrides.get("currency", "USD"),
primary_resource=overrides.get(
"primary_resource", ResourceType.DIESEL),
)
uow.scenarios.create(scenario)
scenario_ids.append(scenario.id)
return project.id, scenario_ids
class TestScenarioComparisonEndpoint:
def test_returns_scenarios_when_validation_passes(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_id, scenario_ids = _create_project_with_scenarios(
session_factory,
[
{},
{"start_date": date(2025, 6, 1),
"end_date": date(2025, 12, 31)},
],
)
response = api_client.post(
f"/projects/{project_id}/scenarios/compare",
json={"scenario_ids": scenario_ids},
)
assert response.status_code == 200
payload = ScenarioComparisonResponse.model_validate(response.json())
assert payload.project_id == project_id
assert {scenario.id for scenario in payload.scenarios} == set(
scenario_ids)
def test_returns_422_when_currency_mismatch(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_id, scenario_ids = _create_project_with_scenarios(
session_factory,
[{"currency": "USD"}, {"currency": "CAD"}],
)
response = api_client.post(
f"/projects/{project_id}/scenarios/compare",
json={"scenario_ids": scenario_ids},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"
def test_returns_422_when_second_scenario_from_other_project(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_a_id, scenario_ids_a = _create_project_with_scenarios(
session_factory, [{}])
project_b_id, scenario_ids_b = _create_project_with_scenarios(
session_factory, [{}])
response = api_client.post(
f"/projects/{project_a_id}/scenarios/compare",
json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
assert project_a_id != project_b_id