feat: implement scenario comparison validation and API endpoint with comprehensive unit tests
This commit is contained in:
254
tests/test_scenario_validation.py
Normal file
254
tests/test_scenario_validation.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_unit_of_work
|
||||
from models import (
|
||||
MiningOperationType,
|
||||
Project,
|
||||
ResourceType,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
)
|
||||
from schemas.scenario import (
|
||||
ScenarioComparisonRequest,
|
||||
ScenarioComparisonResponse,
|
||||
)
|
||||
from services.exceptions import ScenarioValidationError
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from routes.scenarios import router as scenarios_router
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def validator() -> ScenarioComparisonValidator:
|
||||
return ScenarioComparisonValidator()
|
||||
|
||||
|
||||
def _make_scenario(**overrides) -> Scenario:
|
||||
project_id: int = int(overrides.get("project_id", 1))
|
||||
name: str = str(overrides.get("name", "Scenario"))
|
||||
status = cast(ScenarioStatus, overrides.get(
|
||||
"status", ScenarioStatus.DRAFT))
|
||||
start_date = overrides.get("start_date", date(2025, 1, 1))
|
||||
end_date = overrides.get("end_date", date(2025, 12, 31))
|
||||
currency = cast(str, overrides.get("currency", "USD"))
|
||||
primary_resource = cast(ResourceType, overrides.get(
|
||||
"primary_resource", ResourceType.DIESEL))
|
||||
|
||||
scenario = Scenario(
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
status=status,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
currency=currency,
|
||||
primary_resource=primary_resource,
|
||||
)
|
||||
|
||||
if "id" in overrides:
|
||||
scenario.id = overrides["id"]
|
||||
|
||||
return scenario
|
||||
|
||||
|
||||
class TestScenarioComparisonValidator:
|
||||
def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
|
||||
scenario_a = _make_scenario(id=1)
|
||||
scenario_b = _make_scenario(id=2)
|
||||
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs_a, kwargs_b, expected_code",
|
||||
[
|
||||
({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
|
||||
({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
|
||||
({"currency": "USD"}, {"currency": "CAD"},
|
||||
"SCENARIO_CURRENCY_MISMATCH"),
|
||||
(
|
||||
{"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
|
||||
{"start_date": date(2025, 7, 1),
|
||||
"end_date": date(2025, 12, 31)},
|
||||
"SCENARIO_TIMELINE_DISJOINT",
|
||||
),
|
||||
({"primary_resource": ResourceType.DIESEL}, {
|
||||
"primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"),
|
||||
],
|
||||
)
|
||||
def test_validate_raises_for_conflicts(
|
||||
self,
|
||||
validator: ScenarioComparisonValidator,
|
||||
kwargs_a: dict[str, object],
|
||||
kwargs_b: dict[str, object],
|
||||
expected_code: str,
|
||||
) -> None:
|
||||
scenario_a = _make_scenario(id=10, **kwargs_a)
|
||||
scenario_b = _make_scenario(id=20, **kwargs_b)
|
||||
|
||||
with pytest.raises(ScenarioValidationError) as exc_info:
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
assert exc_info.value.code == expected_code
|
||||
|
||||
def test_timeline_rule_skips_when_insufficient_ranges(
|
||||
self, validator: ScenarioComparisonValidator
|
||||
) -> None:
|
||||
scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
|
||||
scenario_b = _make_scenario(id=2, start_date=date(
|
||||
2025, 1, 1), end_date=date(2025, 12, 31))
|
||||
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
|
||||
class TestScenarioComparisonRequest:
|
||||
def test_requires_two_unique_identifiers(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})
|
||||
|
||||
def test_deduplicates_ids_preserving_order(self) -> None:
|
||||
payload = ScenarioComparisonRequest.model_validate(
|
||||
{"scenario_ids": [1, 1, 2, 2, 3]})
|
||||
|
||||
assert payload.scenario_ids == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator[Engine]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
future=True,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||
testing_session = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
yield testing_session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def api_client(session_factory) -> Iterator[TestClient]:
|
||||
app = FastAPI()
|
||||
app.include_router(scenarios_router)
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
def _create_project_with_scenarios(
|
||||
session_factory: sessionmaker,
|
||||
scenario_overrides: list[dict[str, object]],
|
||||
) -> tuple[int, list[int]]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
project_name = f"Project {uuid4()}"
|
||||
project = Project(name=project_name,
|
||||
operation_type=MiningOperationType.OPEN_PIT)
|
||||
uow.projects.create(project)
|
||||
|
||||
scenario_ids: list[int] = []
|
||||
for index, overrides in enumerate(scenario_overrides, start=1):
|
||||
scenario = Scenario(
|
||||
project_id=project.id,
|
||||
name=f"Scenario {index}",
|
||||
status=overrides.get("status", ScenarioStatus.DRAFT),
|
||||
start_date=overrides.get("start_date", date(2025, 1, 1)),
|
||||
end_date=overrides.get("end_date", date(2025, 12, 31)),
|
||||
currency=overrides.get("currency", "USD"),
|
||||
primary_resource=overrides.get(
|
||||
"primary_resource", ResourceType.DIESEL),
|
||||
)
|
||||
uow.scenarios.create(scenario)
|
||||
scenario_ids.append(scenario.id)
|
||||
|
||||
return project.id, scenario_ids
|
||||
|
||||
|
||||
class TestScenarioComparisonEndpoint:
|
||||
def test_returns_scenarios_when_validation_passes(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_id, scenario_ids = _create_project_with_scenarios(
|
||||
session_factory,
|
||||
[
|
||||
{},
|
||||
{"start_date": date(2025, 6, 1),
|
||||
"end_date": date(2025, 12, 31)},
|
||||
],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_id}/scenarios/compare",
|
||||
json={"scenario_ids": scenario_ids},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = ScenarioComparisonResponse.model_validate(response.json())
|
||||
assert payload.project_id == project_id
|
||||
assert {scenario.id for scenario in payload.scenarios} == set(
|
||||
scenario_ids)
|
||||
|
||||
def test_returns_422_when_currency_mismatch(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_id, scenario_ids = _create_project_with_scenarios(
|
||||
session_factory,
|
||||
[{"currency": "USD"}, {"currency": "CAD"}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_id}/scenarios/compare",
|
||||
json={"scenario_ids": scenario_ids},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"
|
||||
|
||||
def test_returns_422_when_second_scenario_from_other_project(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_a_id, scenario_ids_a = _create_project_with_scenarios(
|
||||
session_factory, [{}])
|
||||
project_b_id, scenario_ids_b = _create_project_with_scenarios(
|
||||
session_factory, [{}])
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_a_id}/scenarios/compare",
|
||||
json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
|
||||
assert project_a_id != project_b_id
|
||||
Reference in New Issue
Block a user