feat: implement scenario comparison validation and API endpoint with comprehensive unit tests

This commit is contained in:
2025-11-09 18:42:04 +01:00
parent c39dde3198
commit 02da881d3e
7 changed files with 483 additions and 5 deletions

View File

@@ -0,0 +1,254 @@
from __future__ import annotations
from datetime import date
from collections.abc import Iterator
from typing import cast
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from config.database import Base
from dependencies import get_unit_of_work
from models import (
MiningOperationType,
Project,
ResourceType,
Scenario,
ScenarioStatus,
)
from schemas.scenario import (
ScenarioComparisonRequest,
ScenarioComparisonResponse,
)
from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork
from routes.scenarios import router as scenarios_router
@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
return ScenarioComparisonValidator()
def _make_scenario(**overrides) -> Scenario:
project_id: int = int(overrides.get("project_id", 1))
name: str = str(overrides.get("name", "Scenario"))
status = cast(ScenarioStatus, overrides.get(
"status", ScenarioStatus.DRAFT))
start_date = overrides.get("start_date", date(2025, 1, 1))
end_date = overrides.get("end_date", date(2025, 12, 31))
currency = cast(str, overrides.get("currency", "USD"))
primary_resource = cast(ResourceType, overrides.get(
"primary_resource", ResourceType.DIESEL))
scenario = Scenario(
project_id=project_id,
name=name,
status=status,
start_date=start_date,
end_date=end_date,
currency=currency,
primary_resource=primary_resource,
)
if "id" in overrides:
scenario.id = overrides["id"]
return scenario
class TestScenarioComparisonValidator:
def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
scenario_a = _make_scenario(id=1)
scenario_b = _make_scenario(id=2)
validator.validate([scenario_a, scenario_b])
@pytest.mark.parametrize(
"kwargs_a, kwargs_b, expected_code",
[
({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
({"currency": "USD"}, {"currency": "CAD"},
"SCENARIO_CURRENCY_MISMATCH"),
(
{"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
{"start_date": date(2025, 7, 1),
"end_date": date(2025, 12, 31)},
"SCENARIO_TIMELINE_DISJOINT",
),
({"primary_resource": ResourceType.DIESEL}, {
"primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"),
],
)
def test_validate_raises_for_conflicts(
self,
validator: ScenarioComparisonValidator,
kwargs_a: dict[str, object],
kwargs_b: dict[str, object],
expected_code: str,
) -> None:
scenario_a = _make_scenario(id=10, **kwargs_a)
scenario_b = _make_scenario(id=20, **kwargs_b)
with pytest.raises(ScenarioValidationError) as exc_info:
validator.validate([scenario_a, scenario_b])
assert exc_info.value.code == expected_code
def test_timeline_rule_skips_when_insufficient_ranges(
self, validator: ScenarioComparisonValidator
) -> None:
scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
scenario_b = _make_scenario(id=2, start_date=date(
2025, 1, 1), end_date=date(2025, 12, 31))
validator.validate([scenario_a, scenario_b])
class TestScenarioComparisonRequest:
def test_requires_two_unique_identifiers(self) -> None:
with pytest.raises(ValidationError):
ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})
def test_deduplicates_ids_preserving_order(self) -> None:
payload = ScenarioComparisonRequest.model_validate(
{"scenario_ids": [1, 1, 2, 2, 3]})
assert payload.scenario_ids == [1, 2, 3]
@pytest.fixture()
def engine() -> Iterator[Engine]:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
future=True,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
try:
yield engine
finally:
Base.metadata.drop_all(bind=engine)
engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
testing_session = sessionmaker(
bind=engine, expire_on_commit=False, future=True)
yield testing_session
@pytest.fixture()
def api_client(session_factory) -> Iterator[TestClient]:
app = FastAPI()
app.include_router(scenarios_router)
def _override_uow() -> Iterator[UnitOfWork]:
with UnitOfWork(session_factory=session_factory) as uow:
yield uow
app.dependency_overrides[get_unit_of_work] = _override_uow
client = TestClient(app)
try:
yield client
finally:
client.close()
def _create_project_with_scenarios(
session_factory: sessionmaker,
scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
with UnitOfWork(session_factory=session_factory) as uow:
project_name = f"Project {uuid4()}"
project = Project(name=project_name,
operation_type=MiningOperationType.OPEN_PIT)
uow.projects.create(project)
scenario_ids: list[int] = []
for index, overrides in enumerate(scenario_overrides, start=1):
scenario = Scenario(
project_id=project.id,
name=f"Scenario {index}",
status=overrides.get("status", ScenarioStatus.DRAFT),
start_date=overrides.get("start_date", date(2025, 1, 1)),
end_date=overrides.get("end_date", date(2025, 12, 31)),
currency=overrides.get("currency", "USD"),
primary_resource=overrides.get(
"primary_resource", ResourceType.DIESEL),
)
uow.scenarios.create(scenario)
scenario_ids.append(scenario.id)
return project.id, scenario_ids
class TestScenarioComparisonEndpoint:
def test_returns_scenarios_when_validation_passes(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_id, scenario_ids = _create_project_with_scenarios(
session_factory,
[
{},
{"start_date": date(2025, 6, 1),
"end_date": date(2025, 12, 31)},
],
)
response = api_client.post(
f"/projects/{project_id}/scenarios/compare",
json={"scenario_ids": scenario_ids},
)
assert response.status_code == 200
payload = ScenarioComparisonResponse.model_validate(response.json())
assert payload.project_id == project_id
assert {scenario.id for scenario in payload.scenarios} == set(
scenario_ids)
def test_returns_422_when_currency_mismatch(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_id, scenario_ids = _create_project_with_scenarios(
session_factory,
[{"currency": "USD"}, {"currency": "CAD"}],
)
response = api_client.post(
f"/projects/{project_id}/scenarios/compare",
json={"scenario_ids": scenario_ids},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"
def test_returns_422_when_second_scenario_from_other_project(
self, api_client: TestClient, session_factory: sessionmaker
) -> None:
project_a_id, scenario_ids_a = _create_project_with_scenarios(
session_factory, [{}])
project_b_id, scenario_ids_b = _create_project_with_scenarios(
session_factory, [{}])
response = api_client.post(
f"/projects/{project_a_id}/scenarios/compare",
json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
assert project_a_id != project_b_id