352 lines
12 KiB
Python
352 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
from datetime import date
|
|
from collections.abc import Iterator
|
|
from typing import cast
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.testclient import TestClient
|
|
from pydantic import ValidationError
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from config.database import Base
|
|
from dependencies import get_auth_session, get_unit_of_work
|
|
from models import (
|
|
MiningOperationType,
|
|
Project,
|
|
ResourceType,
|
|
Scenario,
|
|
ScenarioStatus,
|
|
User,
|
|
)
|
|
from schemas.scenario import (
|
|
ScenarioComparisonRequest,
|
|
ScenarioComparisonResponse,
|
|
)
|
|
from services.exceptions import ScenarioValidationError
|
|
from services.scenario_validation import ScenarioComparisonValidator
|
|
from services.unit_of_work import UnitOfWork
|
|
from services.session import AuthSession, SessionTokens
|
|
from routes.scenarios import router as scenarios_router
|
|
|
|
|
|
@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
    """Provide a fresh ``ScenarioComparisonValidator`` for each test."""
    comparison_validator = ScenarioComparisonValidator()
    return comparison_validator
|
|
|
|
|
|
def _make_scenario(**overrides) -> Scenario:
    """Build an unsaved ``Scenario`` populated with test defaults.

    Keyword overrides may replace ``project_id`` (default 1), ``name``
    (default "Scenario"), ``status`` (default DRAFT), ``start_date`` /
    ``end_date`` (default 2025-01-01 .. 2025-12-31), ``currency`` (default
    "USD") and ``primary_resource`` (default DIESEL). When an ``id`` keyword
    is supplied it is assigned to the instance after construction.
    """
    lookup = overrides.get
    scenario = Scenario(
        project_id=int(lookup("project_id", 1)),
        name=str(lookup("name", "Scenario")),
        status=cast(ScenarioStatus, lookup("status", ScenarioStatus.DRAFT)),
        start_date=lookup("start_date", date(2025, 1, 1)),
        end_date=lookup("end_date", date(2025, 12, 31)),
        currency=cast(str, lookup("currency", "USD")),
        primary_resource=cast(
            ResourceType, lookup("primary_resource", ResourceType.DIESEL)
        ),
    )

    # ``id`` is not a constructor argument here; set it afterwards only when
    # the caller explicitly asked for one.
    if "id" in overrides:
        scenario.id = overrides["id"]

    return scenario
|
|
|
|
|
|
class TestScenarioComparisonValidator:
    """Unit tests for the rules enforced by ``ScenarioComparisonValidator``."""

    def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
        # Two scenarios built purely from defaults agree on every attribute,
        # so ``validate`` must complete without raising.
        validator.validate([_make_scenario(id=1), _make_scenario(id=2)])

    @pytest.mark.parametrize(
        "kwargs_a, kwargs_b, expected_code",
        [
            ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
            ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
            ({"currency": "USD"}, {"currency": "CAD"}, "SCENARIO_CURRENCY_MISMATCH"),
            (
                {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
                {"start_date": date(2025, 7, 1), "end_date": date(2025, 12, 31)},
                "SCENARIO_TIMELINE_DISJOINT",
            ),
            (
                {"primary_resource": ResourceType.DIESEL},
                {"primary_resource": ResourceType.ELECTRICITY},
                "SCENARIO_RESOURCE_MISMATCH",
            ),
        ],
    )
    def test_validate_raises_for_conflicts(
        self,
        validator: ScenarioComparisonValidator,
        kwargs_a: dict[str, object],
        kwargs_b: dict[str, object],
        expected_code: str,
    ) -> None:
        # Each case differs in exactly one attribute; the validator must
        # report the error code for that specific rule.
        pair = [_make_scenario(id=10, **kwargs_a), _make_scenario(id=20, **kwargs_b)]

        with pytest.raises(ScenarioValidationError) as exc_info:
            validator.validate(pair)

        assert exc_info.value.code == expected_code

    def test_timeline_rule_skips_when_insufficient_ranges(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        undated = _make_scenario(id=1, start_date=None, end_date=None)
        dated = _make_scenario(
            id=2, start_date=date(2025, 1, 1), end_date=date(2025, 12, 31)
        )

        # With only one scenario carrying a date range, the timeline rule has
        # nothing to compare and is expected not to raise.
        validator.validate([undated, dated])
|
|
|
|
|
|
class TestScenarioComparisonRequest:
    """Schema-level checks for ``ScenarioComparisonRequest`` parsing."""

    def test_requires_two_unique_identifiers(self) -> None:
        single_id_payload = {"scenario_ids": [1]}

        with pytest.raises(ValidationError):
            ScenarioComparisonRequest.model_validate(single_id_payload)

    def test_deduplicates_ids_preserving_order(self) -> None:
        raw_payload = {"scenario_ids": [1, 1, 2, 2, 3]}

        request = ScenarioComparisonRequest.model_validate(raw_payload)

        # Repeated ids collapse to their first occurrence, in input order.
        assert request.scenario_ids == [1, 2, 3]
|
|
|
|
|
|
@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with every ORM table created.

    ``StaticPool`` keeps a single shared connection so the in-memory
    database persists across sessions, and ``check_same_thread=False``
    allows that connection to be used from threads other than the one that
    created it. On teardown all tables are dropped and the engine disposed.
    """
    db_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        future=True,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    schema = Base.metadata
    schema.create_all(bind=db_engine)

    try:
        yield db_engine
    finally:
        schema.drop_all(bind=db_engine)
        db_engine.dispose()
|
|
|
|
|
|
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a ``sessionmaker`` bound to the per-test ``engine``.

    ``expire_on_commit=False`` keeps already-loaded ORM attributes readable
    after a commit, so tests can still read values such as generated ids
    once a ``UnitOfWork`` block has exited.
    """
    yield sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
|
|
|
|
|
@pytest.fixture()
def api_client(session_factory) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for an app exposing only the scenarios router.

    Two dependencies are overridden:

    * ``get_unit_of_work`` opens a ``UnitOfWork`` on the test
      ``session_factory`` for each request;
    * ``get_auth_session`` returns a session with random tokens whose user
      is an active superuser seeded here, and stores it on
      ``request.state.auth_session``.

    The client is closed on teardown.
    """
    app = FastAPI()
    app.include_router(scenarios_router)

    def _override_uow() -> Iterator[UnitOfWork]:
        with UnitOfWork(session_factory=session_factory) as request_uow:
            yield request_uow

    app.dependency_overrides[get_unit_of_work] = _override_uow

    # Seed the user that every request will be authenticated as.
    with UnitOfWork(session_factory=session_factory) as seed_uow:
        assert seed_uow.users is not None
        seed_uow.ensure_default_roles()
        new_user = User(
            email="test-scenarios@example.com",
            username="scenario-tester",
            password_hash=User.hash_password("password"),
            is_active=True,
            is_superuser=True,
        )
        seed_uow.users.create(new_user)
        # Re-fetch with roles loaded so permission checks can use them
        # after this unit of work has closed.
        current_user = seed_uow.users.get(new_user.id, with_roles=True)

    def _override_auth_session(request: Request) -> AuthSession:
        auth_session = AuthSession(
            tokens=SessionTokens(
                access_token=secrets.token_urlsafe(16),
                refresh_token=secrets.token_urlsafe(16),
            )
        )
        auth_session.user = current_user
        request.state.auth_session = auth_session
        return auth_session

    app.dependency_overrides[get_auth_session] = _override_auth_session

    test_client = TestClient(app)
    try:
        yield test_client
    finally:
        test_client.close()
|
|
|
|
|
|
def _create_project_with_scenarios(
    session_factory: sessionmaker,
    scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
    """Persist a uniquely named project plus one scenario per overrides dict.

    Scenarios are named ``"Scenario 1"``, ``"Scenario 2"``, ... in input
    order. Each dict may override ``status`` (default DRAFT), ``start_date``
    / ``end_date`` (default 2025-01-01 .. 2025-12-31), ``currency`` (default
    "USD") and ``primary_resource`` (default DIESEL); other keys are ignored.

    Returns:
        A ``(project_id, scenario_ids)`` tuple, with scenario ids in the same
        order as ``scenario_overrides``.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.projects is not None
        assert uow.scenarios is not None

        # A random suffix keeps project names unique when a test creates
        # several projects.
        project = Project(
            name=f"Project {uuid4()}",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)

        created_ids: list[int] = []
        for position, fields in enumerate(scenario_overrides, start=1):
            scenario = Scenario(
                project_id=project.id,
                name=f"Scenario {position}",
                status=fields.get("status", ScenarioStatus.DRAFT),
                start_date=fields.get("start_date", date(2025, 1, 1)),
                end_date=fields.get("end_date", date(2025, 12, 31)),
                currency=fields.get("currency", "USD"),
                primary_resource=fields.get("primary_resource", ResourceType.DIESEL),
            )
            uow.scenarios.create(scenario)
            created_ids.append(scenario.id)

        return project.id, created_ids
|
|
|
|
|
|
class TestScenarioComparisonEndpoint:
    """Integration tests for ``POST /projects/{project_id}/scenarios/compare``."""

    @staticmethod
    def _post_compare(client: TestClient, project_id: int, scenario_ids: list[int]):
        """Submit a comparison request for ``scenario_ids`` under ``project_id``."""
        return client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

    def test_returns_scenarios_when_validation_passes(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        # Overlapping timelines and otherwise identical attributes: valid.
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [
                {},
                {"start_date": date(2025, 6, 1), "end_date": date(2025, 12, 31)},
            ],
        )

        response = self._post_compare(api_client, project_id, scenario_ids)

        assert response.status_code == 200
        payload = ScenarioComparisonResponse.model_validate(response.json())
        assert payload.project_id == project_id
        assert {scenario.id for scenario in payload.scenarios} == set(scenario_ids)

    def test_returns_422_when_currency_mismatch(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [{"currency": "USD"}, {"currency": "CAD"}],
        )

        response = self._post_compare(api_client, project_id, scenario_ids)

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"

    def test_returns_422_when_second_scenario_from_other_project(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        project_a_id, scenario_ids_a = _create_project_with_scenarios(
            session_factory, [{}])
        project_b_id, scenario_ids_b = _create_project_with_scenarios(
            session_factory, [{}])
        # Precondition checked up front: if both helpers returned the same
        # project the request below would not exercise the mismatch rule,
        # and a later failure would be misattributed to the endpoint.
        assert project_a_id != project_b_id

        response = self._post_compare(
            api_client, project_a_id, [scenario_ids_a[0], scenario_ids_b[0]]
        )

        assert response.status_code == 422
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
|
|
|
|
|
|
class TestScenarioApiCurrencyValidation:
    """Currency validation and normalisation on ``POST /projects/{project_id}/scenarios``."""

    @staticmethod
    def _create_project(session_factory: sessionmaker, name: str) -> int:
        """Persist an open-pit project called ``name`` and return its id."""
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.projects is not None
            project = Project(
                name=name,
                operation_type=MiningOperationType.OPEN_PIT,
            )
            uow.projects.create(project)
            return project.id

    def test_create_api_rejects_invalid_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        project_id = self._create_project(
            session_factory, "Currency Validation Project"
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Invalid Currency Scenario",
                "currency": "US",
            },
        )

        assert response.status_code == 422
        detail = response.json().get("detail", [])
        # Request-validation errors are expected as a list of error objects;
        # fail with the payload shown if the endpoint returned another shape
        # instead of crashing on ``item.get`` below.
        assert isinstance(detail, list), detail
        assert any(
            "Invalid currency code" in item.get("msg", "") for item in detail
        ), detail

        # The rejected request must not have persisted anything.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert scenarios == []

    def test_create_api_normalises_currency(
        self,
        api_client: TestClient,
        session_factory: sessionmaker,
    ) -> None:
        project_id = self._create_project(
            session_factory, "Currency Normalisation Project"
        )

        response = api_client.post(
            f"/projects/{project_id}/scenarios",
            json={
                "name": "Normalised Currency Scenario",
                "currency": "cad",
            },
        )

        assert response.status_code == 201

        # Lower-case input is expected to be stored upper-cased.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.scenarios is not None
            scenarios = uow.scenarios.list_for_project(project_id)
            assert len(scenarios) == 1
            assert scenarios[0].currency == "CAD"
|