from __future__ import annotations from datetime import date from collections.abc import Iterator from typing import cast from uuid import uuid4 import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from pydantic import ValidationError from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.engine import Engine from sqlalchemy.pool import StaticPool from config.database import Base from dependencies import get_unit_of_work from models import ( MiningOperationType, Project, ResourceType, Scenario, ScenarioStatus, ) from schemas.scenario import ( ScenarioComparisonRequest, ScenarioComparisonResponse, ) from services.exceptions import ScenarioValidationError from services.scenario_validation import ScenarioComparisonValidator from services.unit_of_work import UnitOfWork from routes.scenarios import router as scenarios_router @pytest.fixture() def validator() -> ScenarioComparisonValidator: return ScenarioComparisonValidator() def _make_scenario(**overrides) -> Scenario: project_id: int = int(overrides.get("project_id", 1)) name: str = str(overrides.get("name", "Scenario")) status = cast(ScenarioStatus, overrides.get( "status", ScenarioStatus.DRAFT)) start_date = overrides.get("start_date", date(2025, 1, 1)) end_date = overrides.get("end_date", date(2025, 12, 31)) currency = cast(str, overrides.get("currency", "USD")) primary_resource = cast(ResourceType, overrides.get( "primary_resource", ResourceType.DIESEL)) scenario = Scenario( project_id=project_id, name=name, status=status, start_date=start_date, end_date=end_date, currency=currency, primary_resource=primary_resource, ) if "id" in overrides: scenario.id = overrides["id"] return scenario class TestScenarioComparisonValidator: def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None: scenario_a = _make_scenario(id=1) scenario_b = _make_scenario(id=2) validator.validate([scenario_a, scenario_b]) 
@pytest.mark.parametrize( "kwargs_a, kwargs_b, expected_code", [ ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"), ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"), ({"currency": "USD"}, {"currency": "CAD"}, "SCENARIO_CURRENCY_MISMATCH"), ( {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)}, {"start_date": date(2025, 7, 1), "end_date": date(2025, 12, 31)}, "SCENARIO_TIMELINE_DISJOINT", ), ({"primary_resource": ResourceType.DIESEL}, { "primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"), ], ) def test_validate_raises_for_conflicts( self, validator: ScenarioComparisonValidator, kwargs_a: dict[str, object], kwargs_b: dict[str, object], expected_code: str, ) -> None: scenario_a = _make_scenario(id=10, **kwargs_a) scenario_b = _make_scenario(id=20, **kwargs_b) with pytest.raises(ScenarioValidationError) as exc_info: validator.validate([scenario_a, scenario_b]) assert exc_info.value.code == expected_code def test_timeline_rule_skips_when_insufficient_ranges( self, validator: ScenarioComparisonValidator ) -> None: scenario_a = _make_scenario(id=1, start_date=None, end_date=None) scenario_b = _make_scenario(id=2, start_date=date( 2025, 1, 1), end_date=date(2025, 12, 31)) validator.validate([scenario_a, scenario_b]) class TestScenarioComparisonRequest: def test_requires_two_unique_identifiers(self) -> None: with pytest.raises(ValidationError): ScenarioComparisonRequest.model_validate({"scenario_ids": [1]}) def test_deduplicates_ids_preserving_order(self) -> None: payload = ScenarioComparisonRequest.model_validate( {"scenario_ids": [1, 1, 2, 2, 3]}) assert payload.scenario_ids == [1, 2, 3] @pytest.fixture() def engine() -> Iterator[Engine]: engine = create_engine( "sqlite+pysqlite:///:memory:", future=True, connect_args={"check_same_thread": False}, poolclass=StaticPool, ) Base.metadata.create_all(bind=engine) try: yield engine finally: Base.metadata.drop_all(bind=engine) engine.dispose() 
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a session factory bound to the per-test in-memory ``engine``.

    ``expire_on_commit=False`` keeps loaded ORM attributes readable after a
    unit of work commits.
    """
    testing_session = sessionmaker(
        bind=engine, expire_on_commit=False, future=True)
    yield testing_session


@pytest.fixture()
def api_client(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for an app that mounts only the scenarios router.

    ``get_unit_of_work`` is overridden so each request runs inside a
    ``UnitOfWork`` built from the test ``session_factory``. The client is
    closed on teardown.
    """
    app = FastAPI()
    app.include_router(scenarios_router)

    def _override_uow() -> Iterator[UnitOfWork]:
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    app.dependency_overrides[get_unit_of_work] = _override_uow
    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()


def _create_project_with_scenarios(
    session_factory: sessionmaker,
    scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
    """Persist a new project plus one scenario per entry in ``scenario_overrides``.

    Each override dict may set ``status``, ``start_date``, ``end_date``,
    ``currency`` and ``primary_resource``. Unset keys fall back to DRAFT,
    2025-01-01, 2025-12-31, ``"USD"`` and DIESEL. Scenarios are named
    ``"Scenario 1"``, ``"Scenario 2"``, ... in input order.

    Returns:
        ``(project_id, scenario_ids)``, with ``scenario_ids`` in the same
        order as ``scenario_overrides``.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        # A uuid suffix keeps project names distinct across calls
        # (presumably names must be unique - confirm against the model).
        project = Project(
            name=f"Project {uuid4()}",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)
        # Read generated ids while the unit of work is still open. Reading
        # them after it exits can fail for a session factory that expires
        # instances on commit, and this helper accepts any sessionmaker.
        project_id = project.id
        scenario_ids: list[int] = []
        for index, overrides in enumerate(scenario_overrides, start=1):
            scenario = Scenario(
                project_id=project_id,
                name=f"Scenario {index}",
                status=overrides.get("status", ScenarioStatus.DRAFT),
                start_date=overrides.get("start_date", date(2025, 1, 1)),
                end_date=overrides.get("end_date", date(2025, 12, 31)),
                currency=overrides.get("currency", "USD"),
                primary_resource=overrides.get(
                    "primary_resource", ResourceType.DIESEL),
            )
            uow.scenarios.create(scenario)
            scenario_ids.append(scenario.id)
    return project_id, scenario_ids


class TestScenarioComparisonEndpoint:
    """End-to-end tests for ``POST /projects/{project_id}/scenarios/compare``."""

    def test_returns_scenarios_when_validation_passes(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        # Same project, same currency and resource, overlapping timelines.
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [
                {},
                {"start_date": date(2025, 6, 1), "end_date": date(2025, 12, 31)},
            ],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 200
        payload = ScenarioComparisonResponse.model_validate(response.json())
        assert payload.project_id == project_id
        assert {scenario.id for scenario in payload.scenarios} == set(
            scenario_ids)

    def test_returns_422_when_currency_mismatch(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [{"currency": "USD"}, {"currency": "CAD"}],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"

    def test_returns_422_when_second_scenario_from_other_project(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_a_id, scenario_ids_a = _create_project_with_scenarios(
            session_factory, [{}])
        project_b_id, scenario_ids_b = _create_project_with_scenarios(
            session_factory, [{}])
        # Precondition: the test is only meaningful if the two scenarios
        # really belong to different projects. Check it up front so a broken
        # setup is not misreported as an endpoint failure.
        assert project_a_id != project_b_id

        response = api_client.post(
            f"/projects/{project_a_id}/scenarios/compare",
            json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
        )

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"