Files
calminer/routes/scenarios.py
zwitschi 0f79864188 feat: enhance project and scenario management with role-based access control
- Implemented role-based access control for project and scenario routes.
- Added authorization checks to ensure users have appropriate roles for viewing and managing projects and scenarios.
- Introduced utility functions for ensuring project and scenario access based on user roles.
- Refactored project and scenario routes to utilize new authorization helpers.
- Created initial data seeding script to set up default roles and an admin user.
- Added tests for authorization helpers and initial data seeding functionality.
- Updated exception handling to include authorization errors.
2025-11-09 23:14:54 +01:00

466 lines
14 KiB
Python

from __future__ import annotations

import math
from datetime import date
from typing import List

from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates

from dependencies import (
    get_unit_of_work,
    require_any_role,
    require_roles,
    require_scenario_resource,
)
from models import ResourceType, Scenario, ScenarioStatus, User
from schemas.scenario import (
    ScenarioComparisonRequest,
    ScenarioComparisonResponse,
    ScenarioCreate,
    ScenarioRead,
    ScenarioUpdate,
)
from services.exceptions import (
    EntityConflictError,
    EntityNotFoundError,
    ScenarioValidationError,
)
from services.unit_of_work import UnitOfWork
# Router for both the JSON scenario API and the server-rendered HTML scenario
# views; every route in this module is registered on it under "Scenarios".
router = APIRouter(tags=["Scenarios"])
# Jinja2 templates resolved from the "templates" directory; the HTML views
# below render "scenarios/form.html" and "scenarios/detail.html".
templates = Jinja2Templates(directory="templates")
# Role names passed to require_any_role(...) on read-only endpoints.
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
# Role names passed to require_roles(...) on create/edit endpoints.
MANAGE_ROLES = ("project_manager", "admin")
def _to_read_model(scenario: Scenario) -> ScenarioRead:
    """Convert a ``Scenario`` model instance into the ``ScenarioRead`` schema."""
    read_model = ScenarioRead.model_validate(scenario)
    return read_model
def _resource_type_choices() -> list[tuple[str, str]]:
    """Return ``(value, label)`` pairs for every ``ResourceType`` member.

    The label is the enum value with underscores replaced by spaces and then
    title-cased, e.g. ``"some_value"`` becomes ``"Some Value"``.
    """
    choices: list[tuple[str, str]] = []
    for member in ResourceType:
        raw = member.value
        choices.append((raw, raw.replace("_", " ").title()))
    return choices
def _scenario_status_choices() -> list[tuple[str, str]]:
    """Return ``(value, label)`` pairs for every ``ScenarioStatus`` member.

    The label is simply the title-cased enum value.
    """
    # Iterate with ``member`` rather than ``status`` so the fastapi ``status``
    # module imported by this file is not shadowed, even locally.
    pairs = []
    for member in ScenarioStatus:
        label = member.value.title()
        pairs.append((member.value, label))
    return pairs
def _require_project_repo(uow: UnitOfWork):
if not uow.projects:
raise RuntimeError("Project repository not initialised")
return uow.projects
def _require_scenario_repo(uow: UnitOfWork):
if not uow.scenarios:
raise RuntimeError("Scenario repository not initialised")
return uow.scenarios
@router.get(
    "/projects/{project_id}/scenarios",
    response_model=List[ScenarioRead],
)
def list_scenarios_for_project(
    project_id: int,
    _: User = Depends(require_any_role(*READ_ROLES)),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> List[ScenarioRead]:
    """List every scenario of ``project_id`` as ``ScenarioRead`` models.

    Guarded by ``require_any_role(*READ_ROLES)``. Responds 404 when the
    project repository raises ``EntityNotFoundError``.
    """
    projects = _require_project_repo(uow)
    scenarios_repo = _require_scenario_repo(uow)

    try:
        projects.get(project_id)
    except EntityNotFoundError as exc:
        # Surface the repository's not-found message as the 404 detail.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc

    found = scenarios_repo.list_for_project(project_id)
    return list(map(_to_read_model, found))
@router.post(
    "/projects/{project_id}/scenarios/compare",
    response_model=ScenarioComparisonResponse,
    status_code=status.HTTP_200_OK,
)
def compare_scenarios(
    project_id: int,
    payload: ScenarioComparisonRequest,
    _: User = Depends(require_any_role(*READ_ROLES)),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioComparisonResponse:
    """Validate the requested scenarios for comparison inside one project.

    Responds 404 when the project lookup, or the comparison validation, raises
    ``EntityNotFoundError``. Responds 422 with a ``{code, message,
    scenario_ids}`` body when validation raises ``ScenarioValidationError``,
    including the case where a selected scenario belongs to another project.
    """
    try:
        _require_project_repo(uow).get(project_id)
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc

    try:
        selected = uow.validate_scenarios_for_comparison(payload.scenario_ids)
        crosses_projects = any(
            item.project_id != project_id for item in selected
        )
        if crosses_projects:
            # Raise the domain validation error so the handler below renders
            # the mismatch exactly like any other comparison failure.
            known_ids = [item.id for item in selected if item.id is not None]
            raise ScenarioValidationError(
                code="SCENARIO_PROJECT_MISMATCH",
                message="Selected scenarios do not belong to the same project.",
                scenario_ids=known_ids,
            )
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc
    except ScenarioValidationError as exc:
        error_body = {
            "code": exc.code,
            "message": exc.message,
            "scenario_ids": list(exc.scenario_ids or []),
        }
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail=error_body,
        ) from exc

    read_models = [_to_read_model(item) for item in selected]
    return ScenarioComparisonResponse(
        project_id=project_id,
        scenarios=read_models,
    )
@router.post(
    "/projects/{project_id}/scenarios",
    response_model=ScenarioRead,
    status_code=status.HTTP_201_CREATED,
)
def create_scenario_for_project(
    project_id: int,
    payload: ScenarioCreate,
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead:
    """Create a scenario under ``project_id`` from a JSON payload.

    Responds 404 when the project repository raises ``EntityNotFoundError``
    and 409 when the scenario repository raises ``EntityConflictError``.
    """
    projects = _require_project_repo(uow)
    scenarios = _require_scenario_repo(uow)

    try:
        projects.get(project_id)
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc

    fields = payload.model_dump()
    new_scenario = Scenario(project_id=project_id, **fields)

    try:
        persisted = scenarios.create(new_scenario)
    except EntityConflictError as exc:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(exc),
        ) from exc
    return _to_read_model(persisted)
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
def get_scenario(
    scenario: Scenario = Depends(require_scenario_resource()),
) -> ScenarioRead:
    """Return one scenario, as resolved by the ``require_scenario_resource()``
    dependency, serialised to ``ScenarioRead``."""
    read_model = _to_read_model(scenario)
    return read_model
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
def update_scenario(
    payload: ScenarioUpdate,
    scenario: Scenario = Depends(
        require_scenario_resource(require_manage=True)
    ),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead:
    """Partially update a scenario from a JSON payload.

    Only fields explicitly present in the payload (``exclude_unset=True``)
    are copied onto the scenario; the unit of work is then flushed and the
    updated scenario returned.
    """
    changes = payload.model_dump(exclude_unset=True)
    for attribute_name in changes:
        setattr(scenario, attribute_name, changes[attribute_name])
    uow.flush()
    return _to_read_model(scenario)
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_scenario(
    scenario: Scenario = Depends(
        require_scenario_resource(require_manage=True)
    ),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> None:
    """Delete the resolved scenario by id; the route responds 204 No Content."""
    scenarios = _require_scenario_repo(uow)
    scenarios.delete(scenario.id)
def _normalise(value: str | None) -> str | None:
if value is None:
return None
value = value.strip()
return value or None
def _parse_date(value: str | None) -> date | None:
    """Parse an ISO-format date string from a form field into a ``date``.

    Returns ``None`` for missing or blank input. Input that
    ``date.fromisoformat`` rejects also yields ``None`` instead of letting the
    ``ValueError`` escape. This matches how ``_parse_discount_rate`` treats
    unparseable input, and keeps a malformed form field from turning the
    form handlers into an unhandled server error.
    """
    text = _normalise(value)
    if not text:
        return None
    try:
        return date.fromisoformat(text)
    except ValueError:
        return None
def _parse_discount_rate(value: str | None) -> float | None:
    """Parse a discount-rate form field into a finite ``float``.

    Returns ``None`` in three cases: blank or missing input, text that
    ``float()`` cannot parse, and non-finite results. ``float()`` accepts
    ``"nan"``, ``"inf"`` and overflowing literals such as ``"1e400"``, which
    would otherwise be stored on the scenario as a discount rate.
    """
    text = _normalise(value)
    if not text:
        return None
    try:
        rate = float(text)
    except ValueError:
        return None
    # Treat NaN/infinity exactly like any other unusable input.
    return rate if math.isfinite(rate) else None
@router.get(
    "/projects/{project_id}/scenarios/new",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="scenarios.create_scenario_form",
)
def create_scenario_form(
    project_id: int,
    request: Request,
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse:
    """Render the empty "new scenario" HTML form for ``project_id``.

    Responds 404 when the project repository raises ``EntityNotFoundError``.
    """
    projects = _require_project_repo(uow)
    try:
        project = projects.get(project_id)
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc

    submit_url = request.url_for(
        "scenarios.create_scenario_submit", project_id=project_id
    )
    back_url = request.url_for("projects.view_project", project_id=project_id)
    context = {
        "project": project,
        "scenario": None,  # nothing to pre-fill on a brand-new form
        "scenario_statuses": _scenario_status_choices(),
        "resource_types": _resource_type_choices(),
        "form_action": submit_url,
        "cancel_url": back_url,
    }
    return templates.TemplateResponse(request, "scenarios/form.html", context)
@router.post(
    "/projects/{project_id}/scenarios/new",
    include_in_schema=False,
    name="scenarios.create_scenario_submit",
)
def create_scenario_submit(
    project_id: int,
    request: Request,
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    name: str = Form(...),
    description: str | None = Form(None),
    status_value: str = Form(ScenarioStatus.DRAFT.value),
    start_date: str | None = Form(None),
    end_date: str | None = Form(None),
    discount_rate: str | None = Form(None),
    currency: str | None = Form(None),
    primary_resource: str | None = Form(None),
    uow: UnitOfWork = Depends(get_unit_of_work),
):
    """Handle the HTML "new scenario" form submission.

    The handler coerces several fields:

    * unrecognised status values fall back to ``ScenarioStatus.DRAFT``;
    * unrecognised resource types become ``None``;
    * the currency code is upper-cased.

    On success it redirects (303) to the project page. On
    ``EntityConflictError`` it re-renders the form with an error message and
    status 409. It responds 404 when the project does not exist.
    """
    projects = _require_project_repo(uow)
    scenarios = _require_scenario_repo(uow)
    try:
        project = projects.get(project_id)
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc

    try:
        chosen_status = ScenarioStatus(status_value)
    except ValueError:
        chosen_status = ScenarioStatus.DRAFT

    chosen_resource = None
    if primary_resource:
        try:
            chosen_resource = ResourceType(primary_resource)
        except ValueError:
            chosen_resource = None  # unknown resource type: leave unset

    cleaned_currency = _normalise(currency)
    if cleaned_currency:
        cleaned_currency = cleaned_currency.upper()
    else:
        cleaned_currency = None

    draft = Scenario(
        project_id=project_id,
        name=name.strip(),
        description=_normalise(description),
        status=chosen_status,
        start_date=_parse_date(start_date),
        end_date=_parse_date(end_date),
        discount_rate=_parse_discount_rate(discount_rate),
        currency=cleaned_currency,
        primary_resource=chosen_resource,
    )

    try:
        scenarios.create(draft)
    except EntityConflictError:
        # Re-show the form, pre-filled with what the user submitted.
        form_context = {
            "project": project,
            "scenario": draft,
            "scenario_statuses": _scenario_status_choices(),
            "resource_types": _resource_type_choices(),
            "form_action": request.url_for(
                "scenarios.create_scenario_submit", project_id=project_id
            ),
            "cancel_url": request.url_for(
                "projects.view_project", project_id=project_id
            ),
            "error": "Scenario could not be created.",
        }
        return templates.TemplateResponse(
            request,
            "scenarios/form.html",
            form_context,
            status_code=status.HTTP_409_CONFLICT,
        )

    project_url = request.url_for("projects.view_project", project_id=project_id)
    return RedirectResponse(project_url, status_code=status.HTTP_303_SEE_OTHER)
@router.get(
    "/scenarios/{scenario_id}/view",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="scenarios.view_scenario",
)
def view_scenario(
    request: Request,
    scenario: Scenario = Depends(
        require_scenario_resource(with_children=True)
    ),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse:
    """Render the scenario detail page.

    Financial inputs and simulation parameters are listed oldest first (sorted
    by ``created_at``), together with a small summary-metrics dict.
    """
    project = _require_project_repo(uow).get(scenario.project_id)

    def _created_at(item):
        return item.created_at

    financial_inputs = sorted(scenario.financial_inputs, key=_created_at)
    simulation_parameters = sorted(
        scenario.simulation_parameters, key=_created_at
    )

    resource = scenario.primary_resource
    if resource:
        # Same humanisation as the resource-type choices: "a_b" -> "A B".
        resource_label = resource.value.replace("_", " ").title()
    else:
        resource_label = None

    scenario_metrics = {
        "financial_count": len(financial_inputs),
        "parameter_count": len(simulation_parameters),
        "currency": scenario.currency,
        "primary_resource": resource_label,
    }
    context = {
        "project": project,
        "scenario": scenario,
        "scenario_metrics": scenario_metrics,
        "financial_inputs": financial_inputs,
        "simulation_parameters": simulation_parameters,
    }
    return templates.TemplateResponse(request, "scenarios/detail.html", context)
@router.get(
    "/scenarios/{scenario_id}/edit",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="scenarios.edit_scenario_form",
)
def edit_scenario_form(
    request: Request,
    scenario: Scenario = Depends(
        require_scenario_resource(require_manage=True)
    ),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse:
    """Render the HTML edit form pre-filled with the existing scenario."""
    project = _require_project_repo(uow).get(scenario.project_id)
    submit_url = request.url_for(
        "scenarios.edit_scenario_submit", scenario_id=scenario.id
    )
    back_url = request.url_for("scenarios.view_scenario", scenario_id=scenario.id)
    context = {
        "project": project,
        "scenario": scenario,
        "scenario_statuses": _scenario_status_choices(),
        "resource_types": _resource_type_choices(),
        "form_action": submit_url,
        "cancel_url": back_url,
    }
    return templates.TemplateResponse(request, "scenarios/form.html", context)
@router.post(
    "/scenarios/{scenario_id}/edit",
    include_in_schema=False,
    name="scenarios.edit_scenario_submit",
)
def edit_scenario_submit(
    request: Request,
    scenario: Scenario = Depends(
        require_scenario_resource(require_manage=True)
    ),
    name: str = Form(...),
    description: str | None = Form(None),
    status_value: str = Form(ScenarioStatus.DRAFT.value),
    start_date: str | None = Form(None),
    end_date: str | None = Form(None),
    discount_rate: str | None = Form(None),
    currency: str | None = Form(None),
    primary_resource: str | None = Form(None),
    uow: UnitOfWork = Depends(get_unit_of_work),
):
    """Apply the submitted edit form to ``scenario`` and redirect to its page.

    Every form field is re-parsed and written back, so blank optional fields
    clear the stored value. Fields are coerced as follows:

    * unrecognised status values fall back to ``ScenarioStatus.DRAFT``;
    * unrecognised resource types become ``None``;
    * the currency code is upper-cased.

    All values are parsed before the scenario is modified, so an exception
    from a parser cannot leave the entity half-updated. Responds with a 303
    redirect to the scenario detail view.
    """
    # Parse everything first; only touch the scenario once all parsing
    # has succeeded.
    try:
        new_status = ScenarioStatus(status_value)
    except ValueError:
        new_status = ScenarioStatus.DRAFT

    new_resource = None
    if primary_resource:
        try:
            new_resource = ResourceType(primary_resource)
        except ValueError:
            new_resource = None

    currency_value = _normalise(currency)
    new_values = {
        "name": name.strip(),
        "description": _normalise(description),
        "status": new_status,
        "start_date": _parse_date(start_date),
        "end_date": _parse_date(end_date),
        "discount_rate": _parse_discount_rate(discount_rate),
        "currency": currency_value.upper() if currency_value else None,
        "primary_resource": new_resource,
    }

    # Only the scenario itself is modified here, so no parent-project lookup
    # is needed.
    for field, value in new_values.items():
        setattr(scenario, field, value)
    uow.flush()

    return RedirectResponse(
        request.url_for("scenarios.view_scenario", scenario_id=scenario.id),
        status_code=status.HTTP_303_SEE_OTHER,
    )