"""Tests for scenario comparison and scenario-create currency handling.

Covers the ``ScenarioComparisonValidator`` rules, the
``ScenarioComparisonRequest`` schema, the
``POST /projects/{project_id}/scenarios/compare`` endpoint, and currency
validation/normalisation on ``POST /projects/{project_id}/scenarios``.
"""

from __future__ import annotations

import secrets
from collections.abc import Iterator
from datetime import date
from typing import Any, cast
from uuid import uuid4

import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from config.database import Base
from dependencies import get_auth_session, get_unit_of_work
from models import (
    MiningOperationType,
    Project,
    ResourceType,
    Scenario,
    ScenarioStatus,
    User,
)
from schemas.scenario import (
    ScenarioComparisonRequest,
    ScenarioComparisonResponse,
)
from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork
from services.session import AuthSession, SessionTokens
from routes.scenarios import router as scenarios_router


@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
    """Return a fresh comparison validator for each test."""
    return ScenarioComparisonValidator()


def _make_scenario(**overrides: Any) -> Scenario:
    """Build an unsaved ``Scenario`` whose defaults are mutually comparable.

    Defaults: project 1, name "Scenario", DRAFT status, a 2025-01-01 to
    2025-12-31 timeline, USD currency and DIESEL as the primary resource.
    Two scenarios built from the defaults pass validation, so a test only
    needs to override the field it is exercising.

    Args:
        **overrides: Any of ``project_id``, ``name``, ``status``,
            ``start_date``, ``end_date``, ``currency``, ``primary_resource``
            and ``id``. Other keys are ignored. ``start_date``/``end_date``
            may be ``None`` to model a scenario without a timeline.

    Returns:
        A transient ``Scenario``. ``id`` is assigned only when supplied.
    """
    scenario = Scenario(
        project_id=int(overrides.get("project_id", 1)),
        name=str(overrides.get("name", "Scenario")),
        status=cast(
            ScenarioStatus, overrides.get("status", ScenarioStatus.DRAFT)
        ),
        start_date=overrides.get("start_date", date(2025, 1, 1)),
        end_date=overrides.get("end_date", date(2025, 12, 31)),
        currency=cast(str, overrides.get("currency", "USD")),
        primary_resource=cast(
            ResourceType,
            overrides.get("primary_resource", ResourceType.DIESEL),
        ),
    )
    if "id" in overrides:
        scenario.id = overrides["id"]
    return scenario


class TestScenarioComparisonValidator:
    """Rule-level tests for ``ScenarioComparisonValidator.validate``."""

    def test_validate_allows_matching_scenarios(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        scenario_a = _make_scenario(id=1)
        scenario_b = _make_scenario(id=2)

        # Success means no ScenarioValidationError is raised.
        validator.validate([scenario_a, scenario_b])

    @pytest.mark.parametrize(
        ("kwargs_a", "kwargs_b", "expected_code"),
        [
            ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
            ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
            ({"currency": "USD"}, {"currency": "CAD"}, "SCENARIO_CURRENCY_MISMATCH"),
            (
                {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
                {"start_date": date(2025, 7, 1), "end_date": date(2025, 12, 31)},
                "SCENARIO_TIMELINE_DISJOINT",
            ),
            (
                {"primary_resource": ResourceType.DIESEL},
                {"primary_resource": ResourceType.ELECTRICITY},
                "SCENARIO_RESOURCE_MISMATCH",
            ),
        ],
    )
    def test_validate_raises_for_conflicts(
        self,
        validator: ScenarioComparisonValidator,
        kwargs_a: dict[str, object],
        kwargs_b: dict[str, object],
        expected_code: str,
    ) -> None:
        scenario_a = _make_scenario(id=10, **kwargs_a)
        scenario_b = _make_scenario(id=20, **kwargs_b)

        with pytest.raises(ScenarioValidationError) as exc_info:
            validator.validate([scenario_a, scenario_b])

        assert exc_info.value.code == expected_code

    def test_timeline_rule_skips_when_insufficient_ranges(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        # Only one scenario has a date range, so there is nothing to compare
        # it against and the timeline rule must not raise.
        scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
        scenario_b = _make_scenario(
            id=2,
            start_date=date(2025, 1, 1),
            end_date=date(2025, 12, 31),
        )

        validator.validate([scenario_a, scenario_b])


class TestScenarioComparisonRequest:
    """Schema-level tests for ``ScenarioComparisonRequest``."""

    def test_requires_two_unique_identifiers(self) -> None:
        with pytest.raises(ValidationError):
            ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})

    def test_deduplicates_ids_preserving_order(self) -> None:
        payload = ScenarioComparisonRequest.model_validate(
            {"scenario_ids": [1, 1, 2, 2, 3]}
        )

        assert payload.scenario_ids == [1, 2, 3]


@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with the full schema created.

    ``StaticPool`` reuses a single connection, so every session sees the same
    in-memory database. ``check_same_thread=False`` allows that connection to
    be used from a thread other than the one that created it. Presumably this
    is needed because TestClient may run the app on a worker thread.
    """
    test_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        future=True,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=test_engine)
    try:
        yield test_engine
    finally:
        Base.metadata.drop_all(bind=test_engine)
        test_engine.dispose()


@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a session factory bound to the test engine.

    ``expire_on_commit=False`` keeps loaded attributes from being expired on
    commit. ORM objects created inside one ``UnitOfWork`` (e.g. project and
    scenario ids) therefore remain readable afterwards.
    """
    testing_session = sessionmaker(
        bind=engine, expire_on_commit=False, future=True
    )
    yield testing_session


@pytest.fixture()
def api_client(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a TestClient for an app that mounts only the scenarios router.

    ``get_unit_of_work`` is overridden to use the test database.
    ``get_auth_session`` is overridden to always authenticate as an active
    superuser seeded here.
    """
    app = FastAPI()
    app.include_router(scenarios_router)

    def _override_uow() -> Iterator[UnitOfWork]:
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    app.dependency_overrides[get_unit_of_work] = _override_uow

    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.users is not None
        uow.ensure_default_roles()
        new_user = User(
            email="test-scenarios@example.com",
            username="scenario-tester",
            password_hash=User.hash_password("password"),
            is_active=True,
            is_superuser=True,
        )
        uow.users.create(new_user)
        # Re-fetch with roles (with_roles=True) for the auth override below,
        # which runs after this unit of work has closed.
        user = uow.users.get(new_user.id, with_roles=True)

    def _override_auth_session(request: Request) -> AuthSession:
        tokens = SessionTokens(
            access_token=secrets.token_urlsafe(16),
            refresh_token=secrets.token_urlsafe(16),
        )
        session = AuthSession(tokens=tokens)
        session.user = user
        request.state.auth_session = session
        return session

    app.dependency_overrides[get_auth_session] = _override_auth_session

    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()


def _create_project_with_scenarios(
    session_factory: sessionmaker,
    scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
    """Persist a uniquely named project plus one scenario per overrides dict.

    Scenario fields default exactly as in ``_make_scenario``. ``project_id``
    is always the new project and ``name`` is always "Scenario <n>" (1-based),
    regardless of what the overrides contain.

    Returns:
        ``(project_id, scenario_ids)`` with ids in the same order as
        ``scenario_overrides``.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.projects is not None
        assert uow.scenarios is not None
        project = Project(
            name=f"Project {uuid4()}",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)

        scenario_ids: list[int] = []
        for index, overrides in enumerate(scenario_overrides, start=1):
            # Forced keys come last so they win over anything in overrides.
            scenario = _make_scenario(
                **{
                    **overrides,
                    "project_id": project.id,
                    "name": f"Scenario {index}",
                }
            )
            uow.scenarios.create(scenario)
            scenario_ids.append(scenario.id)
        return project.id, scenario_ids


class TestScenarioComparisonEndpoint:
    """HTTP tests for ``POST /projects/{project_id}/scenarios/compare``."""

    def test_returns_scenarios_when_validation_passes(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [
                {},
                {"start_date": date(2025, 6, 1), "end_date": date(2025, 12, 31)},
            ],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 200
        payload = ScenarioComparisonResponse.model_validate(response.json())
        assert payload.project_id == project_id
        assert {scenario.id for scenario in payload.scenarios} == set(scenario_ids)

    def test_returns_422_when_currency_mismatch(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [{"currency": "USD"}, {"currency": "CAD"}],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"

    def test_returns_422_when_second_scenario_from_other_project(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_a_id, scenario_ids_a = _create_project_with_scenarios(
            session_factory, [{}]
        )
        project_b_id, scenario_ids_b = _create_project_with_scenarios(
            session_factory, [{}]
        )
        # Precondition: the second scenario really belongs to another project.
        # Checked before the request so a setup problem is not reported as an
        # endpoint failure.
        assert project_a_id != project_b_id

        response = api_client.post(
            f"/projects/{project_a_id}/scenarios/compare",
            json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
        )

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"


class TestScenarioApiCurrencyValidation:
    """Currency handling on ``POST /projects/{project_id}/scenarios``."""

    def test_create_api_rejects_invalid_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.projects is not None
            assert uow.scenarios is not None
            project = Project(
                name="Currency Validation Project",
                operation_type=MiningOperationType.OPEN_PIT,
            )
            uow.projects.create(project)
            project_id = project.id

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Invalid Currency Scenario",
                "currency": "US",
            },
        )

        assert response.status_code == 422
        detail = response.json().get("detail", [])
        assert any(
            "Invalid currency code" in item.get("msg", "") for item in detail
        ), detail

        # The rejected request must not have persisted anything.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert scenarios == []

    def test_create_api_normalises_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.projects is not None
            assert uow.scenarios is not None
            project = Project(
                name="Currency Normalisation Project",
                operation_type=MiningOperationType.OPEN_PIT,
            )
            uow.projects.create(project)
            project_id = project.id

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Normalised Currency Scenario",
                "currency": "cad",
            },
        )

        assert response.status_code == 201

        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert len(scenarios) == 1
            assert scenarios[0].currency == "CAD"