Files
calminer/tests/test_scenario_validation.py
zwitschi 0f79864188 feat: enhance project and scenario management with role-based access control
- Implemented role-based access control for project and scenario routes.
- Added authorization checks to ensure users have appropriate roles for viewing and managing projects and scenarios.
- Introduced utility functions for ensuring project and scenario access based on user roles.
- Refactored project and scenario routes to utilize new authorization helpers.
- Created initial data seeding script to set up default roles and an admin user.
- Added tests for authorization helpers and initial data seeding functionality.
- Updated exception handling to include authorization errors.
2025-11-09 23:14:54 +01:00

281 lines
9.4 KiB
Python

from __future__ import annotations
from datetime import date
from collections.abc import Iterator
from typing import cast
from uuid import uuid4
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from config.database import Base
from dependencies import get_auth_session, get_unit_of_work
from models import (
MiningOperationType,
Project,
ResourceType,
Scenario,
ScenarioStatus,
User,
)
from schemas.scenario import (
ScenarioComparisonRequest,
ScenarioComparisonResponse,
)
from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork
from services.session import AuthSession, SessionTokens
from routes.scenarios import router as scenarios_router
@pytest.fixture()
def validator() -> ScenarioComparisonValidator:
    """Provide a fresh ``ScenarioComparisonValidator`` instance for each test."""
    comparison_validator = ScenarioComparisonValidator()
    return comparison_validator
def _make_scenario(**overrides) -> Scenario:
    """Build an unsaved ``Scenario`` populated with comparison-friendly defaults.

    Recognised keyword overrides:
        project_id: coerced with ``int``; default ``1``.
        name: coerced with ``str``; default ``"Scenario"``.
        status: default ``ScenarioStatus.DRAFT``.
        start_date / end_date: default 2025-01-01 / 2025-12-31.
        currency: default ``"USD"``.
        primary_resource: default ``ResourceType.DIESEL``.
        id: assigned after construction, only when supplied.

    Any other keys are ignored.
    """
    scenario = Scenario(
        project_id=int(overrides.get("project_id", 1)),
        name=str(overrides.get("name", "Scenario")),
        # cast() only informs the type checker; values pass through unchanged.
        status=cast(ScenarioStatus, overrides.get("status", ScenarioStatus.DRAFT)),
        start_date=overrides.get("start_date", date(2025, 1, 1)),
        end_date=overrides.get("end_date", date(2025, 12, 31)),
        currency=cast(str, overrides.get("currency", "USD")),
        primary_resource=cast(
            ResourceType,
            overrides.get("primary_resource", ResourceType.DIESEL),
        ),
    )
    if "id" in overrides:
        scenario.id = overrides["id"]
    return scenario
class TestScenarioComparisonValidator:
    """Unit tests for ``ScenarioComparisonValidator`` on in-memory scenarios."""

    def test_validate_allows_matching_scenarios(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        """Two scenarios differing only by id pass validation without raising."""
        validator.validate([_make_scenario(id=1), _make_scenario(id=2)])

    @pytest.mark.parametrize(
        "kwargs_a, kwargs_b, expected_code",
        [
            ({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
            ({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
            ({"currency": "USD"}, {"currency": "CAD"}, "SCENARIO_CURRENCY_MISMATCH"),
            (
                {"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
                {"start_date": date(2025, 7, 1), "end_date": date(2025, 12, 31)},
                "SCENARIO_TIMELINE_DISJOINT",
            ),
            (
                {"primary_resource": ResourceType.DIESEL},
                {"primary_resource": ResourceType.ELECTRICITY},
                "SCENARIO_RESOURCE_MISMATCH",
            ),
        ],
    )
    def test_validate_raises_for_conflicts(
        self,
        validator: ScenarioComparisonValidator,
        kwargs_a: dict[str, object],
        kwargs_b: dict[str, object],
        expected_code: str,
    ) -> None:
        """Each conflicting pair is rejected with the expected error code."""
        pair = [
            _make_scenario(id=10, **kwargs_a),
            _make_scenario(id=20, **kwargs_b),
        ]
        with pytest.raises(ScenarioValidationError) as raised:
            validator.validate(pair)
        assert raised.value.code == expected_code

    def test_timeline_rule_skips_when_insufficient_ranges(
        self, validator: ScenarioComparisonValidator
    ) -> None:
        """A pair where one scenario has no dates is not rejected."""
        undated = _make_scenario(id=1, start_date=None, end_date=None)
        dated = _make_scenario(
            id=2, start_date=date(2025, 1, 1), end_date=date(2025, 12, 31)
        )
        validator.validate([undated, dated])
class TestScenarioComparisonRequest:
    """Schema-level tests for ``ScenarioComparisonRequest`` input handling."""

    def test_requires_two_unique_identifiers(self) -> None:
        """A payload with a single scenario id fails model validation."""
        single_id_payload = {"scenario_ids": [1]}
        with pytest.raises(ValidationError):
            ScenarioComparisonRequest.model_validate(single_id_payload)

    def test_deduplicates_ids_preserving_order(self) -> None:
        """Repeated ids collapse to first occurrences, keeping input order."""
        raw_ids = [1, 1, 2, 2, 3]
        request = ScenarioComparisonRequest.model_validate({"scenario_ids": raw_ids})
        assert request.scenario_ids == [1, 2, 3]
@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with the ORM schema created.

    ``StaticPool`` keeps a single connection, so the ``:memory:`` database is
    shared by every session. ``check_same_thread=False`` lets that connection
    be used from threads other than the creating one. On teardown the schema
    is dropped and the engine disposed.
    """
    sqlite_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        future=True,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=sqlite_engine)
    try:
        yield sqlite_engine
    finally:
        Base.metadata.drop_all(bind=sqlite_engine)
        sqlite_engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
    """Yield a ``sessionmaker`` bound to the test engine.

    ``expire_on_commit=False`` keeps ORM objects' loaded attribute values
    after a commit instead of expiring them.
    """
    yield sessionmaker(bind=engine, expire_on_commit=False, future=True)
@pytest.fixture()
def api_client(session_factory) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for the scenarios router with auth stubbed out.

    Each request's ``get_unit_of_work`` dependency is served from
    ``session_factory``. Default roles and an active superuser are seeded up
    front. ``get_auth_session`` is overridden so every request is
    authenticated as that user.
    """
    app = FastAPI()
    app.include_router(scenarios_router)

    def _override_uow() -> Iterator[UnitOfWork]:
        # Each request receives its own UnitOfWork context.
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    app.dependency_overrides[get_unit_of_work] = _override_uow

    def _seed_superuser():
        # Ensure default roles exist, persist the test user, then reload it
        # via with_roles=True so its roles are available to the auth stub.
        with UnitOfWork(session_factory=session_factory) as uow:
            assert uow.users is not None
            uow.ensure_default_roles()
            candidate = User(
                email="test-scenarios@example.com",
                username="scenario-tester",
                password_hash=User.hash_password("password"),
                is_active=True,
                is_superuser=True,
            )
            uow.users.create(candidate)
            return uow.users.get(candidate.id, with_roles=True)

    superuser = _seed_superuser()

    def _override_auth_session(request: Request) -> AuthSession:
        # Fabricate a session for the seeded user and also expose it on
        # request.state (presumably where route helpers may look for it).
        auth = AuthSession(
            tokens=SessionTokens(access_token="test", refresh_token="test")
        )
        auth.user = superuser
        request.state.auth_session = auth
        return auth

    app.dependency_overrides[get_auth_session] = _override_auth_session

    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()
def _create_project_with_scenarios(
    session_factory: sessionmaker,
    scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]:
    """Persist an open-pit project plus one scenario per override dict.

    Scenarios are named ``Scenario 1``, ``Scenario 2``, ... in input order.
    Each dict may override ``status``, ``start_date``, ``end_date``,
    ``currency`` and ``primary_resource``; other keys are ignored.

    Returns:
        The new project's id and the created scenario ids, in input order.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.projects is not None
        assert uow.scenarios is not None

        # uuid4 suffix lets several projects coexist in one test
        # (project names are presumably unique).
        project = Project(
            name=f"Project {uuid4()}",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)

        created_ids: list[int] = []
        for ordinal, spec in enumerate(scenario_overrides, start=1):
            scenario = Scenario(
                project_id=project.id,
                name=f"Scenario {ordinal}",
                status=spec.get("status", ScenarioStatus.DRAFT),
                start_date=spec.get("start_date", date(2025, 1, 1)),
                end_date=spec.get("end_date", date(2025, 12, 31)),
                currency=spec.get("currency", "USD"),
                primary_resource=spec.get("primary_resource", ResourceType.DIESEL),
            )
            uow.scenarios.create(scenario)
            created_ids.append(scenario.id)

        return project.id, created_ids
class TestScenarioComparisonEndpoint:
    """End-to-end tests for ``POST /projects/{project_id}/scenarios/compare``."""

    @staticmethod
    def _post_compare(client: TestClient, project_id: int, scenario_ids: list[int]):
        """Send a comparison request for ``scenario_ids`` scoped to ``project_id``."""
        return client.post(
            f"/projects/{project_id}/scenarios/compare",
            json={"scenario_ids": scenario_ids},
        )

    def test_returns_scenarios_when_validation_passes(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        """Compatible, overlapping scenarios are returned for their project."""
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [
                {},
                {"start_date": date(2025, 6, 1), "end_date": date(2025, 12, 31)},
            ],
        )

        response = self._post_compare(api_client, project_id, scenario_ids)

        # Include the body so an unexpected status is diagnosable from the log.
        assert response.status_code == 200, response.text
        payload = ScenarioComparisonResponse.model_validate(response.json())
        assert payload.project_id == project_id
        assert {scenario.id for scenario in payload.scenarios} == set(scenario_ids)

    def test_returns_422_when_currency_mismatch(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        """Scenarios in different currencies are rejected with a 422."""
        project_id, scenario_ids = _create_project_with_scenarios(
            session_factory,
            [{"currency": "USD"}, {"currency": "CAD"}],
        )

        response = self._post_compare(api_client, project_id, scenario_ids)

        assert response.status_code == 422, response.text
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"

    def test_returns_422_when_second_scenario_from_other_project(
        self, api_client: TestClient, session_factory: sessionmaker
    ) -> None:
        """A scenario belonging to another project is rejected with a 422."""
        project_a_id, scenario_ids_a = _create_project_with_scenarios(
            session_factory, [{}]
        )
        project_b_id, scenario_ids_b = _create_project_with_scenarios(
            session_factory, [{}]
        )
        # Precondition, checked before the request: if the two projects were
        # not distinct, the endpoint would not see a mismatch and the failure
        # would surface as a misleading status-code assertion instead.
        assert project_a_id != project_b_id

        response = self._post_compare(
            api_client, project_a_id, [scenario_ids_a[0], scenario_ids_b[0]]
        )

        assert response.status_code == 422, response.text
        detail = response.json()["detail"]
        assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"