feat: implement scenario comparison validation and API endpoint with comprehensive unit tests
This commit is contained in:
@@ -9,3 +9,4 @@
|
|||||||
- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application.
|
- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application.
|
||||||
- Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects.
|
- Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects.
|
||||||
- Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates.
|
- Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates.
|
||||||
|
- Added scenario comparison validator, FastAPI comparison endpoint, and comprehensive unit tests to enforce FR-009 validation rules through API errors.
|
||||||
|
|||||||
@@ -9,8 +9,18 @@ from fastapi.templating import Jinja2Templates
|
|||||||
|
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import get_unit_of_work
|
||||||
from models import ResourceType, Scenario, ScenarioStatus
|
from models import ResourceType, Scenario, ScenarioStatus
|
||||||
from schemas.scenario import ScenarioCreate, ScenarioRead, ScenarioUpdate
|
from schemas.scenario import (
|
||||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
ScenarioComparisonRequest,
|
||||||
|
ScenarioComparisonResponse,
|
||||||
|
ScenarioCreate,
|
||||||
|
ScenarioRead,
|
||||||
|
ScenarioUpdate,
|
||||||
|
)
|
||||||
|
from services.exceptions import (
|
||||||
|
EntityConflictError,
|
||||||
|
EntityNotFoundError,
|
||||||
|
ScenarioValidationError,
|
||||||
|
)
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
router = APIRouter(tags=["Scenarios"])
|
router = APIRouter(tags=["Scenarios"])
|
||||||
@@ -51,6 +61,53 @@ def list_scenarios_for_project(
|
|||||||
return [_to_read_model(scenario) for scenario in scenarios]
|
return [_to_read_model(scenario) for scenario in scenarios]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/scenarios/compare",
|
||||||
|
response_model=ScenarioComparisonResponse,
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
)
|
||||||
|
def compare_scenarios(
|
||||||
|
project_id: int,
|
||||||
|
payload: ScenarioComparisonRequest,
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
) -> ScenarioComparisonResponse:
|
||||||
|
try:
|
||||||
|
uow.projects.get(project_id)
|
||||||
|
except EntityNotFoundError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
scenarios = uow.validate_scenarios_for_comparison(payload.scenario_ids)
|
||||||
|
if any(scenario.project_id != project_id for scenario in scenarios):
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_PROJECT_MISMATCH",
|
||||||
|
message="Selected scenarios do not belong to the same project.",
|
||||||
|
scenario_ids=[
|
||||||
|
scenario.id for scenario in scenarios if scenario.id is not None
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except EntityNotFoundError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||||
|
) from exc
|
||||||
|
except ScenarioValidationError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||||
|
detail={
|
||||||
|
"code": exc.code,
|
||||||
|
"message": exc.message,
|
||||||
|
"scenario_ids": list(exc.scenario_ids or []),
|
||||||
|
},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return ScenarioComparisonResponse(
|
||||||
|
project_id=project_id,
|
||||||
|
scenarios=[_to_read_model(scenario) for scenario in scenarios],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/projects/{project_id}/scenarios",
|
"/projects/{project_id}/scenarios",
|
||||||
response_model=ScenarioRead,
|
response_model=ScenarioRead,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||||
|
|
||||||
from models import ResourceType, ScenarioStatus
|
from models import ResourceType, ScenarioStatus
|
||||||
|
|
||||||
@@ -64,3 +64,24 @@ class ScenarioRead(ScenarioBase):
|
|||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioComparisonRequest(BaseModel):
|
||||||
|
scenario_ids: list[int]
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_minimum_ids(self) -> "ScenarioComparisonRequest":
|
||||||
|
unique_ids: list[int] = list(dict.fromkeys(self.scenario_ids))
|
||||||
|
if len(unique_ids) < 2:
|
||||||
|
raise ValueError("At least two unique scenario identifiers are required for comparison.")
|
||||||
|
self.scenario_ids = unique_ids
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioComparisonResponse(BaseModel):
|
||||||
|
project_id: int
|
||||||
|
scenarios: list[ScenarioRead]
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Domain-level exceptions for service and repository layers."""
|
"""Domain-level exceptions for service and repository layers."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
|
||||||
class EntityNotFoundError(Exception):
|
class EntityNotFoundError(Exception):
|
||||||
"""Raised when a requested entity cannot be located."""
|
"""Raised when a requested entity cannot be located."""
|
||||||
@@ -7,3 +10,15 @@ class EntityNotFoundError(Exception):
|
|||||||
|
|
||||||
class EntityConflictError(Exception):
|
class EntityConflictError(Exception):
|
||||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ScenarioValidationError(Exception):
|
||||||
|
"""Raised when scenarios fail comparison validation rules."""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
message: str
|
||||||
|
scenario_ids: Sequence[int] | None = None
|
||||||
|
|
||||||
|
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
|
||||||
|
return self.message
|
||||||
|
|||||||
106
services/scenario_validation.py
Normal file
106
services/scenario_validation.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date
|
||||||
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
|
from models import Scenario, ScenarioStatus
|
||||||
|
from services.exceptions import ScenarioValidationError
|
||||||
|
|
||||||
|
ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset(
|
||||||
|
{ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _ValidationContext:
|
||||||
|
scenarios: Sequence[Scenario]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scenario_ids(self) -> list[int]:
|
||||||
|
return [scenario.id for scenario in self.scenarios if scenario.id is not None]
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioComparisonValidator:
|
||||||
|
"""Validates scenarios prior to comparison workflows."""
|
||||||
|
|
||||||
|
def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None:
|
||||||
|
scenario_list = list(scenarios)
|
||||||
|
if len(scenario_list) < 2:
|
||||||
|
# Nothing to validate when fewer than two scenarios are provided.
|
||||||
|
return
|
||||||
|
|
||||||
|
context = _ValidationContext(scenario_list)
|
||||||
|
|
||||||
|
self._ensure_same_project(context)
|
||||||
|
self._ensure_allowed_status(context)
|
||||||
|
self._ensure_shared_currency(context)
|
||||||
|
self._ensure_timeline_overlap(context)
|
||||||
|
self._ensure_shared_primary_resource(context)
|
||||||
|
|
||||||
|
def _ensure_same_project(self, context: _ValidationContext) -> None:
|
||||||
|
project_ids = {scenario.project_id for scenario in context.scenarios}
|
||||||
|
if len(project_ids) > 1:
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_PROJECT_MISMATCH",
|
||||||
|
message="Selected scenarios do not belong to the same project.",
|
||||||
|
scenario_ids=context.scenario_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_allowed_status(self, context: _ValidationContext) -> None:
|
||||||
|
disallowed = [
|
||||||
|
scenario
|
||||||
|
for scenario in context.scenarios
|
||||||
|
if scenario.status not in ALLOWED_STATUSES
|
||||||
|
]
|
||||||
|
if disallowed:
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_STATUS_INVALID",
|
||||||
|
message="Archived scenarios cannot be compared.",
|
||||||
|
scenario_ids=[
|
||||||
|
scenario.id for scenario in disallowed if scenario.id is not None],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_shared_currency(self, context: _ValidationContext) -> None:
|
||||||
|
currencies = {
|
||||||
|
scenario.currency
|
||||||
|
for scenario in context.scenarios
|
||||||
|
if scenario.currency is not None
|
||||||
|
}
|
||||||
|
if len(currencies) > 1:
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_CURRENCY_MISMATCH",
|
||||||
|
message="Scenarios use different currencies and cannot be compared.",
|
||||||
|
scenario_ids=context.scenario_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_timeline_overlap(self, context: _ValidationContext) -> None:
|
||||||
|
ranges = [
|
||||||
|
(scenario.start_date, scenario.end_date)
|
||||||
|
for scenario in context.scenarios
|
||||||
|
if scenario.start_date and scenario.end_date
|
||||||
|
]
|
||||||
|
if len(ranges) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
latest_start: date = max(start for start, _ in ranges)
|
||||||
|
earliest_end: date = min(end for _, end in ranges)
|
||||||
|
if latest_start > earliest_end:
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_TIMELINE_DISJOINT",
|
||||||
|
message="Scenario timelines do not overlap; adjust the comparison window.",
|
||||||
|
scenario_ids=context.scenario_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None:
|
||||||
|
resources = {
|
||||||
|
scenario.primary_resource
|
||||||
|
for scenario in context.scenarios
|
||||||
|
if scenario.primary_resource is not None
|
||||||
|
}
|
||||||
|
if len(resources) > 1:
|
||||||
|
raise ScenarioValidationError(
|
||||||
|
code="SCENARIO_RESOURCE_MISMATCH",
|
||||||
|
message="Scenarios target different primary resources and cannot be compared.",
|
||||||
|
scenario_ids=context.scenario_ids,
|
||||||
|
)
|
||||||
@@ -1,17 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
from typing import Callable
|
from typing import Callable, Sequence
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from config.database import SessionLocal
|
from config.database import SessionLocal
|
||||||
|
from models import Scenario
|
||||||
from services.repositories import (
|
from services.repositories import (
|
||||||
FinancialInputRepository,
|
FinancialInputRepository,
|
||||||
ProjectRepository,
|
ProjectRepository,
|
||||||
ScenarioRepository,
|
ScenarioRepository,
|
||||||
SimulationParameterRepository,
|
SimulationParameterRepository,
|
||||||
)
|
)
|
||||||
|
from services.scenario_validation import ScenarioComparisonValidator
|
||||||
|
|
||||||
|
|
||||||
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||||
@@ -20,13 +22,16 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
|
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
self.session: Session | None = None
|
self.session: Session | None = None
|
||||||
|
self._scenario_validator: ScenarioComparisonValidator | None = None
|
||||||
|
|
||||||
def __enter__(self) -> "UnitOfWork":
|
def __enter__(self) -> "UnitOfWork":
|
||||||
self.session = self._session_factory()
|
self.session = self._session_factory()
|
||||||
self.projects = ProjectRepository(self.session)
|
self.projects = ProjectRepository(self.session)
|
||||||
self.scenarios = ScenarioRepository(self.session)
|
self.scenarios = ScenarioRepository(self.session)
|
||||||
self.financial_inputs = FinancialInputRepository(self.session)
|
self.financial_inputs = FinancialInputRepository(self.session)
|
||||||
self.simulation_parameters = SimulationParameterRepository(self.session)
|
self.simulation_parameters = SimulationParameterRepository(
|
||||||
|
self.session)
|
||||||
|
self._scenario_validator = ScenarioComparisonValidator()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||||
@@ -36,6 +41,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
else:
|
else:
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
self._scenario_validator = None
|
||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
if not self.session:
|
if not self.session:
|
||||||
@@ -51,3 +57,21 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
if not self.session:
|
if not self.session:
|
||||||
raise RuntimeError("UnitOfWork session is not initialised")
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
|
|
||||||
|
def validate_scenarios_for_comparison(
|
||||||
|
self, scenario_ids: Sequence[int]
|
||||||
|
) -> list[Scenario]:
|
||||||
|
if not self.session or not self._scenario_validator:
|
||||||
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
|
|
||||||
|
scenarios = [self.scenarios.get(scenario_id)
|
||||||
|
for scenario_id in scenario_ids]
|
||||||
|
self._scenario_validator.validate(scenarios)
|
||||||
|
return scenarios
|
||||||
|
|
||||||
|
def validate_scenario_models_for_comparison(
|
||||||
|
self, scenarios: Sequence[Scenario]
|
||||||
|
) -> None:
|
||||||
|
if not self._scenario_validator:
|
||||||
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
|
self._scenario_validator.validate(scenarios)
|
||||||
|
|||||||
254
tests/test_scenario_validation.py
Normal file
254
tests/test_scenario_validation.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import cast
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from config.database import Base
|
||||||
|
from dependencies import get_unit_of_work
|
||||||
|
from models import (
|
||||||
|
MiningOperationType,
|
||||||
|
Project,
|
||||||
|
ResourceType,
|
||||||
|
Scenario,
|
||||||
|
ScenarioStatus,
|
||||||
|
)
|
||||||
|
from schemas.scenario import (
|
||||||
|
ScenarioComparisonRequest,
|
||||||
|
ScenarioComparisonResponse,
|
||||||
|
)
|
||||||
|
from services.exceptions import ScenarioValidationError
|
||||||
|
from services.scenario_validation import ScenarioComparisonValidator
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
from routes.scenarios import router as scenarios_router
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def validator() -> ScenarioComparisonValidator:
|
||||||
|
return ScenarioComparisonValidator()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_scenario(**overrides) -> Scenario:
|
||||||
|
project_id: int = int(overrides.get("project_id", 1))
|
||||||
|
name: str = str(overrides.get("name", "Scenario"))
|
||||||
|
status = cast(ScenarioStatus, overrides.get(
|
||||||
|
"status", ScenarioStatus.DRAFT))
|
||||||
|
start_date = overrides.get("start_date", date(2025, 1, 1))
|
||||||
|
end_date = overrides.get("end_date", date(2025, 12, 31))
|
||||||
|
currency = cast(str, overrides.get("currency", "USD"))
|
||||||
|
primary_resource = cast(ResourceType, overrides.get(
|
||||||
|
"primary_resource", ResourceType.DIESEL))
|
||||||
|
|
||||||
|
scenario = Scenario(
|
||||||
|
project_id=project_id,
|
||||||
|
name=name,
|
||||||
|
status=status,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
currency=currency,
|
||||||
|
primary_resource=primary_resource,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "id" in overrides:
|
||||||
|
scenario.id = overrides["id"]
|
||||||
|
|
||||||
|
return scenario
|
||||||
|
|
||||||
|
|
||||||
|
class TestScenarioComparisonValidator:
|
||||||
|
def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
|
||||||
|
scenario_a = _make_scenario(id=1)
|
||||||
|
scenario_b = _make_scenario(id=2)
|
||||||
|
|
||||||
|
validator.validate([scenario_a, scenario_b])
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"kwargs_a, kwargs_b, expected_code",
|
||||||
|
[
|
||||||
|
({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
|
||||||
|
({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
|
||||||
|
({"currency": "USD"}, {"currency": "CAD"},
|
||||||
|
"SCENARIO_CURRENCY_MISMATCH"),
|
||||||
|
(
|
||||||
|
{"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
|
||||||
|
{"start_date": date(2025, 7, 1),
|
||||||
|
"end_date": date(2025, 12, 31)},
|
||||||
|
"SCENARIO_TIMELINE_DISJOINT",
|
||||||
|
),
|
||||||
|
({"primary_resource": ResourceType.DIESEL}, {
|
||||||
|
"primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_validate_raises_for_conflicts(
|
||||||
|
self,
|
||||||
|
validator: ScenarioComparisonValidator,
|
||||||
|
kwargs_a: dict[str, object],
|
||||||
|
kwargs_b: dict[str, object],
|
||||||
|
expected_code: str,
|
||||||
|
) -> None:
|
||||||
|
scenario_a = _make_scenario(id=10, **kwargs_a)
|
||||||
|
scenario_b = _make_scenario(id=20, **kwargs_b)
|
||||||
|
|
||||||
|
with pytest.raises(ScenarioValidationError) as exc_info:
|
||||||
|
validator.validate([scenario_a, scenario_b])
|
||||||
|
|
||||||
|
assert exc_info.value.code == expected_code
|
||||||
|
|
||||||
|
def test_timeline_rule_skips_when_insufficient_ranges(
|
||||||
|
self, validator: ScenarioComparisonValidator
|
||||||
|
) -> None:
|
||||||
|
scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
|
||||||
|
scenario_b = _make_scenario(id=2, start_date=date(
|
||||||
|
2025, 1, 1), end_date=date(2025, 12, 31))
|
||||||
|
|
||||||
|
validator.validate([scenario_a, scenario_b])
|
||||||
|
|
||||||
|
|
||||||
|
class TestScenarioComparisonRequest:
|
||||||
|
def test_requires_two_unique_identifiers(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})
|
||||||
|
|
||||||
|
def test_deduplicates_ids_preserving_order(self) -> None:
|
||||||
|
payload = ScenarioComparisonRequest.model_validate(
|
||||||
|
{"scenario_ids": [1, 1, 2, 2, 3]})
|
||||||
|
|
||||||
|
assert payload.scenario_ids == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def engine() -> Iterator[Engine]:
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite+pysqlite:///:memory:",
|
||||||
|
future=True,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
try:
|
||||||
|
yield engine
|
||||||
|
finally:
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||||
|
testing_session = sessionmaker(
|
||||||
|
bind=engine, expire_on_commit=False, future=True)
|
||||||
|
yield testing_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def api_client(session_factory) -> Iterator[TestClient]:
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(scenarios_router)
|
||||||
|
|
||||||
|
def _override_uow() -> Iterator[UnitOfWork]:
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
yield uow
|
||||||
|
|
||||||
|
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||||
|
client = TestClient(app)
|
||||||
|
try:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_project_with_scenarios(
|
||||||
|
session_factory: sessionmaker,
|
||||||
|
scenario_overrides: list[dict[str, object]],
|
||||||
|
) -> tuple[int, list[int]]:
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
project_name = f"Project {uuid4()}"
|
||||||
|
project = Project(name=project_name,
|
||||||
|
operation_type=MiningOperationType.OPEN_PIT)
|
||||||
|
uow.projects.create(project)
|
||||||
|
|
||||||
|
scenario_ids: list[int] = []
|
||||||
|
for index, overrides in enumerate(scenario_overrides, start=1):
|
||||||
|
scenario = Scenario(
|
||||||
|
project_id=project.id,
|
||||||
|
name=f"Scenario {index}",
|
||||||
|
status=overrides.get("status", ScenarioStatus.DRAFT),
|
||||||
|
start_date=overrides.get("start_date", date(2025, 1, 1)),
|
||||||
|
end_date=overrides.get("end_date", date(2025, 12, 31)),
|
||||||
|
currency=overrides.get("currency", "USD"),
|
||||||
|
primary_resource=overrides.get(
|
||||||
|
"primary_resource", ResourceType.DIESEL),
|
||||||
|
)
|
||||||
|
uow.scenarios.create(scenario)
|
||||||
|
scenario_ids.append(scenario.id)
|
||||||
|
|
||||||
|
return project.id, scenario_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestScenarioComparisonEndpoint:
|
||||||
|
def test_returns_scenarios_when_validation_passes(
|
||||||
|
self, api_client: TestClient, session_factory: sessionmaker
|
||||||
|
) -> None:
|
||||||
|
project_id, scenario_ids = _create_project_with_scenarios(
|
||||||
|
session_factory,
|
||||||
|
[
|
||||||
|
{},
|
||||||
|
{"start_date": date(2025, 6, 1),
|
||||||
|
"end_date": date(2025, 12, 31)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
f"/projects/{project_id}/scenarios/compare",
|
||||||
|
json={"scenario_ids": scenario_ids},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = ScenarioComparisonResponse.model_validate(response.json())
|
||||||
|
assert payload.project_id == project_id
|
||||||
|
assert {scenario.id for scenario in payload.scenarios} == set(
|
||||||
|
scenario_ids)
|
||||||
|
|
||||||
|
def test_returns_422_when_currency_mismatch(
|
||||||
|
self, api_client: TestClient, session_factory: sessionmaker
|
||||||
|
) -> None:
|
||||||
|
project_id, scenario_ids = _create_project_with_scenarios(
|
||||||
|
session_factory,
|
||||||
|
[{"currency": "USD"}, {"currency": "CAD"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
f"/projects/{project_id}/scenarios/compare",
|
||||||
|
json={"scenario_ids": scenario_ids},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
detail = response.json()["detail"]
|
||||||
|
assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"
|
||||||
|
|
||||||
|
def test_returns_422_when_second_scenario_from_other_project(
|
||||||
|
self, api_client: TestClient, session_factory: sessionmaker
|
||||||
|
) -> None:
|
||||||
|
project_a_id, scenario_ids_a = _create_project_with_scenarios(
|
||||||
|
session_factory, [{}])
|
||||||
|
project_b_id, scenario_ids_b = _create_project_with_scenarios(
|
||||||
|
session_factory, [{}])
|
||||||
|
|
||||||
|
response = api_client.post(
|
||||||
|
f"/projects/{project_a_id}/scenarios/compare",
|
||||||
|
json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
detail = response.json()["detail"]
|
||||||
|
assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
|
||||||
|
assert project_a_id != project_b_id
|
||||||
Reference in New Issue
Block a user