Files
calminer/tests/test_scenario_validation.py
zwitschi 795a9f99f4 feat: Enhance currency handling and validation across scenarios
- Updated form template to prefill currency input with default value and added help text for clarity.
- Modified integration tests to assert more descriptive error messages for invalid currency codes.
- Introduced new tests for currency normalization and validation in various scenarios, including imports and exports.
- Added comprehensive tests for pricing calculations, ensuring defaults are respected and overrides function correctly.
- Implemented unit tests for pricing settings repository, ensuring CRUD operations and default settings are handled properly.
- Enhanced scenario pricing evaluation tests to validate currency handling and metadata defaults.
- Added simulation tests to ensure Monte Carlo runs are accurate and handle various distribution scenarios.
2025-11-11 18:29:59 +01:00

348 lines
12 KiB
Python

from __future__ import annotations
from datetime import date
from collections.abc import Iterator
from typing import cast
from uuid import uuid4
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from config.database import Base
from dependencies import get_auth_session, get_unit_of_work
from models import (
MiningOperationType,
Project,
ResourceType,
Scenario,
ScenarioStatus,
User,
)
from schemas.scenario import (
ScenarioComparisonRequest,
ScenarioComparisonResponse,
)
from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork
from services.session import AuthSession, SessionTokens
from routes.scenarios import router as scenarios_router
@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
    """Provide a brand-new ``ScenarioComparisonValidator`` to each test."""
    comparison_validator = ScenarioComparisonValidator()
    return comparison_validator
def _make_scenario(**overrides) -> Scenario:
    """Build a ``Scenario`` instance (not added to any session) for unit tests.

    Defaults describe a DRAFT, USD, diesel scenario named ``"Scenario"`` in
    project 1, spanning 2025-01-01 to 2025-12-31.

    Keyword arguments named ``project_id``, ``name``, ``status``,
    ``start_date``, ``end_date``, ``currency`` or ``primary_resource`` replace
    the matching default; an explicitly passed ``None`` is kept as ``None``.
    ``project_id`` is coerced with ``int`` and ``name`` with ``str``. An
    ``id`` keyword, when present, is assigned to the instance after it is
    constructed. Any other keyword is ignored.
    """
    defaults = {
        "project_id": 1,
        "name": "Scenario",
        "status": ScenarioStatus.DRAFT,
        "start_date": date(2025, 1, 1),
        "end_date": date(2025, 12, 31),
        "currency": "USD",
        "primary_resource": ResourceType.DIESEL,
    }
    # ``**overrides`` always builds a fresh dict for this call, so filling in
    # the missing keys here cannot leak back into the caller's mapping.
    for field_name, fallback in defaults.items():
        overrides.setdefault(field_name, fallback)

    scenario = Scenario(
        project_id=int(overrides["project_id"]),
        name=str(overrides["name"]),
        status=overrides["status"],
        start_date=overrides["start_date"],
        end_date=overrides["end_date"],
        currency=overrides["currency"],
        primary_resource=overrides["primary_resource"],
    )
    if "id" in overrides:
        scenario.id = overrides["id"]
    return scenario
class TestScenarioComparisonValidator:
    """Unit tests for the rules applied by ``ScenarioComparisonValidator``."""

    # Each case: overrides for scenario A, overrides for scenario B, and the
    # error code the validator must raise when the two are compared.
    _CONFLICT_CASES = [
        ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
        ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
        ({"currency": "USD"}, {"currency": "CAD"}, "SCENARIO_CURRENCY_MISMATCH"),
        (
            {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
            {"start_date": date(2025, 7, 1), "end_date": date(2025, 12, 31)},
            "SCENARIO_TIMELINE_DISJOINT",
        ),
        (
            {"primary_resource": ResourceType.DIESEL},
            {"primary_resource": ResourceType.ELECTRICITY},
            "SCENARIO_RESOURCE_MISMATCH",
        ),
    ]

    def test_validate_allows_matching_scenarios(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        # Two default scenarios differ only by id; validate() must not raise.
        first = _make_scenario(id=1)
        second = _make_scenario(id=2)
        validator.validate([first, second])

    @pytest.mark.parametrize("kwargs_a, kwargs_b, expected_code", _CONFLICT_CASES)
    def test_validate_raises_for_conflicts(
        self,
        validator: ScenarioComparisonValidator,
        kwargs_a: dict[str, object],
        kwargs_b: dict[str, object],
        expected_code: str,
    ) -> None:
        pair = [
            _make_scenario(id=10, **kwargs_a),
            _make_scenario(id=20, **kwargs_b),
        ]
        with pytest.raises(ScenarioValidationError) as exc_info:
            validator.validate(pair)
        assert exc_info.value.code == expected_code

    def test_timeline_rule_skips_when_insufficient_ranges(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        # Only one scenario carries a date range, so there is nothing to
        # compare timelines against; validate() must not raise.
        undated = _make_scenario(id=1, start_date=None, end_date=None)
        dated = _make_scenario(
            id=2,
            start_date=date(2025, 1, 1),
            end_date=date(2025, 12, 31),
        )
        validator.validate([undated, dated])
class TestScenarioComparisonRequest:
    """Schema-level checks for ``ScenarioComparisonRequest``."""

    def test_requires_two_unique_identifiers(self) -> None:
        # A single identifier cannot form a comparison and must be rejected.
        single_id_payload = {"scenario_ids": [1]}
        with pytest.raises(ValidationError):
            ScenarioComparisonRequest.model_validate(single_id_payload)

    def test_deduplicates_ids_preserving_order(self) -> None:
        raw_payload = {"scenario_ids": [1, 1, 2, 2, 3]}
        request = ScenarioComparisonRequest.model_validate(raw_payload)
        assert request.scenario_ids == [1, 2, 3]
@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with every ORM table created.

    ``StaticPool`` hands out the same single connection on every checkout, so
    all sessions share one in-memory database. ``check_same_thread=False``
    allows that connection to be used from threads other than its creator.
    The schema is dropped and the engine disposed when the test finishes.
    """
    sqlite_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        future=True,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    metadata = Base.metadata
    metadata.create_all(bind=sqlite_engine)
    try:
        yield sqlite_engine
    finally:
        metadata.drop_all(bind=sqlite_engine)
        sqlite_engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a ``sessionmaker`` bound to the per-test ``engine``.

    ``expire_on_commit=False`` keeps already-loaded attributes readable after
    a commit instead of expiring them.
    """
    yield sessionmaker(bind=engine, expire_on_commit=False, future=True)
@pytest.fixture()
def api_client(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for an app serving only the scenarios router.

    Before the client is created, a superuser is seeded into the test
    database. Two dependencies are overridden:

    * ``get_unit_of_work`` yields a ``UnitOfWork`` bound to
      ``session_factory`` for each request.
    * ``get_auth_session`` returns a fixed ``AuthSession`` carrying the seeded
      user and stores it on ``request.state.auth_session``.

    The client is closed on teardown.
    """

    def _seed_superuser():
        # Ensure default roles exist, insert the test superuser, and reload it
        # with its roles so it can be attached to the fake auth session.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.users is not None
            uow.ensure_default_roles()
            candidate = User(
                email="test-scenarios@example.com",
                username="scenario-tester",
                password_hash=User.hash_password("password"),
                is_active=True,
                is_superuser=True,
            )
            uow.users.create(candidate)
            return uow.users.get(candidate.id, with_roles=True)

    seeded_user = _seed_superuser()

    def _override_uow() -> Iterator[UnitOfWork]:
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    def _override_auth_session(request: Request) -> AuthSession:
        auth_session = AuthSession(
            tokens=SessionTokens(access_token="test", refresh_token="test")
        )
        auth_session.user = seeded_user
        request.state.auth_session = auth_session
        return auth_session

    app = FastAPI()
    app.include_router(scenarios_router)
    app.dependency_overrides.update(
        {
            get_unit_of_work: _override_uow,
            get_auth_session: _override_auth_session,
        }
    )

    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()
def _create_project_with_scenarios(
    session_factory: sessionmaker,
    scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
    """Persist a uniquely named open-pit project and one scenario per dict.

    Scenarios are named ``Scenario 1``, ``Scenario 2``, ... in input order.
    Each defaults to DRAFT status, USD, diesel, and a 2025-01-01 to 2025-12-31
    timeline. Each dict may override ``status``, ``start_date``, ``end_date``,
    ``currency`` and ``primary_resource``. An empty list creates the project
    alone.

    Returns:
        The new project's id and the created scenario ids, in input order.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.projects is not None
        assert uow.scenarios is not None
        project = Project(
            name=f"Project {uuid4()}",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)

        created_ids: list[int] = []
        for position, spec in enumerate(scenario_overrides, start=1):
            scenario = Scenario(
                project_id=project.id,
                name=f"Scenario {position}",
                status=cast(
                    ScenarioStatus, spec.get("status", ScenarioStatus.DRAFT)
                ),
                start_date=spec.get("start_date", date(2025, 1, 1)),
                end_date=spec.get("end_date", date(2025, 12, 31)),
                currency=cast(str, spec.get("currency", "USD")),
                primary_resource=cast(
                    ResourceType,
                    spec.get("primary_resource", ResourceType.DIESEL),
                ),
            )
            uow.scenarios.create(scenario)
            created_ids.append(scenario.id)
    return project.id, created_ids
class TestScenarioComparisonEndpoint:
    """Integration tests for ``POST /projects/{project_id}/scenarios/compare``."""

    def test_returns_scenarios_when_validation_passes(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        # The second scenario runs June–December, overlapping the default
        # full-year scenario, so the comparison is expected to succeed.
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [
                {},
                {"start_date": date(2025, 6, 1), "end_date": date(2025, 12, 31)},
            ],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 200, response.text
        payload = ScenarioComparisonResponse.model_validate(response.json())
        assert payload.project_id == project_id
        assert {scenario.id for scenario in payload.scenarios} == set(scenario_ids)

    def test_returns_422_when_currency_mismatch(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [{"currency": "USD"}, {"currency": "CAD"}],
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

        assert response.status_code == 422, response.text
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"

    def test_returns_422_when_second_scenario_from_other_project(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_a_id, scenario_ids_a = _create_project_with_scenarios(
            session_factory, [{}]
        )
        project_b_id, scenario_ids_b = _create_project_with_scenarios(
            session_factory, [{}]
        )
        # Precondition checked before exercising the endpoint. If both
        # scenarios landed in the same project, the test setup is broken and
        # should be reported as such, not as a wrong error code.
        assert project_a_id != project_b_id

        response = api_client.post(
            f"/projects/{project_a_id}/scenarios/compare",
            json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
        )

        assert response.status_code == 422, response.text
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
class TestScenarioApiCurrencyValidation:
    """Currency checks on ``POST /projects/{project_id}/scenarios``."""

    def test_create_api_rejects_invalid_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        # Start from a project with no scenarios, using the shared setup helper
        # rather than repeating the project-creation boilerplate.
        project_id, _ = _create_project_with_scenarios(session_factory, [])

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Invalid Currency Scenario",
                "currency": "US",
            },
        )

        assert response.status_code == 422, response.text
        detail = response.json().get("detail", [])
        # The assertion below iterates error items with ``.get("msg")``. Check
        # the shape first so an unexpected payload is reported verbatim
        # instead of surfacing as an AttributeError.
        assert isinstance(detail, list), detail
        assert any(
            "Invalid currency code" in item.get("msg", "") for item in detail
        ), detail

        # A rejected request must not have persisted a scenario.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert scenarios == []

    def test_create_api_normalises_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        project_id, _ = _create_project_with_scenarios(session_factory, [])

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Normalised Currency Scenario",
                "currency": "cad",
            },
        )

        assert response.status_code == 201, response.text
        # Assertions stay inside the unit of work so the loaded scenarios are
        # inspected while their session is still open.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert len(scenarios) == 1
            assert scenarios[0].currency == "CAD"