feat: enhance project and scenario management with role-based access control
- Implemented role-based access control for project and scenario routes. - Added authorization checks to ensure users have appropriate roles for viewing and managing projects and scenarios. - Introduced utility functions for ensuring project and scenario access based on user roles. - Refactored project and scenario routes to utilize new authorization helpers. - Created initial data seeding script to set up default roles and an admin user. - Added tests for authorization helpers and initial data seeding functionality. - Updated exception handling to include authorization errors.
This commit is contained in:
@@ -9,6 +9,3 @@ DATABASE_PASSWORD=<password>
|
||||
DATABASE_NAME=calminer
|
||||
# Optional: set a schema (comma-separated for multiple entries)
|
||||
# DATABASE_SCHEMA=public
|
||||
|
||||
# Legacy fallback (still supported, but granular settings are preferred)
|
||||
# DATABASE_URL=postgresql://<user>:<password>@localhost:5432/calminer
|
||||
@@ -18,3 +18,10 @@
|
||||
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
||||
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
|
||||
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
|
||||
- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour.
|
||||
- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning.
|
||||
- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns.
|
||||
|
||||
## 2025-11-10
|
||||
|
||||
- Extended authorization helper layer with project/scenario ownership lookups, integrated them into FastAPI dependencies, refreshed pytest fixtures to keep the suite authenticated, and documented the new patterns across RBAC plan and security guides.
|
||||
|
||||
153
dependencies.py
153
dependencies.py
@@ -1,11 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Iterable, Generator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import User
|
||||
from models import Project, Role, Scenario, User
|
||||
from services.authorization import (
|
||||
ensure_project_access as ensure_project_access_helper,
|
||||
ensure_scenario_access as ensure_scenario_access_helper,
|
||||
ensure_scenario_in_project as ensure_scenario_in_project_helper,
|
||||
)
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.security import JWTSettings
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
@@ -90,9 +96,150 @@ def require_current_user(
|
||||
) -> User:
|
||||
"""Ensure that a request is authenticated and return the user context."""
|
||||
|
||||
if session.user is None:
|
||||
if session.user is None or session.tokens.is_empty:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required.",
|
||||
)
|
||||
return session.user
|
||||
|
||||
|
||||
def require_authenticated_user(
|
||||
user: User = Depends(require_current_user),
|
||||
) -> User:
|
||||
"""Ensure the current user account is active."""
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User account is disabled.",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def _user_role_names(user: User) -> set[str]:
|
||||
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||
return {role.name for role in roles}
|
||||
|
||||
|
||||
def require_roles(*roles: str) -> Callable[[User], User]:
|
||||
"""Dependency factory enforcing membership in one of the given roles."""
|
||||
|
||||
required = tuple(role.strip() for role in roles if role.strip())
|
||||
if not required:
|
||||
raise ValueError("require_roles requires at least one role name")
|
||||
|
||||
def _dependency(user: User = Depends(require_authenticated_user)) -> User:
|
||||
if user.is_superuser:
|
||||
return user
|
||||
|
||||
role_names = _user_role_names(user)
|
||||
if not any(role in role_names for role in required):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions for this action.",
|
||||
)
|
||||
return user
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_any_role(*roles: str) -> Callable[[User], User]:
|
||||
"""Alias of require_roles for readability in some contexts."""
|
||||
|
||||
return require_roles(*roles)
|
||||
|
||||
|
||||
def require_project_resource(*, require_manage: bool = False) -> Callable[[int], Project]:
|
||||
"""Dependency factory that resolves a project with authorization checks."""
|
||||
|
||||
def _dependency(
|
||||
project_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Project:
|
||||
try:
|
||||
return ensure_project_access_helper(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_scenario_resource(
|
||||
*, require_manage: bool = False, with_children: bool = False
|
||||
) -> Callable[[int], Scenario]:
|
||||
"""Dependency factory that resolves a scenario with authorization checks."""
|
||||
|
||||
def _dependency(
|
||||
scenario_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Scenario:
|
||||
try:
|
||||
return ensure_scenario_access_helper(
|
||||
uow,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_project_scenario_resource(
|
||||
*, require_manage: bool = False, with_children: bool = False
|
||||
) -> Callable[[int, int], Scenario]:
|
||||
"""Dependency factory ensuring a scenario belongs to the given project and is accessible."""
|
||||
|
||||
def _dependency(
|
||||
project_id: int,
|
||||
scenario_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Scenario:
|
||||
try:
|
||||
return ensure_scenario_in_project_helper(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
|
||||
@@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import get_unit_of_work
|
||||
from dependencies import get_unit_of_work, require_authenticated_user
|
||||
from models import User
|
||||
from models import ScenarioStatus
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
@@ -27,6 +28,8 @@ def _format_timestamp_with_time(moment: datetime | None) -> str | None:
|
||||
|
||||
|
||||
def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
||||
if not uow.projects or not uow.scenarios or not uow.financial_inputs:
|
||||
raise RuntimeError("UnitOfWork repositories not initialised")
|
||||
total_projects = uow.projects.count()
|
||||
active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE)
|
||||
pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT)
|
||||
@@ -40,11 +43,15 @@ def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
||||
|
||||
|
||||
def _load_recent_projects(uow: UnitOfWork) -> list:
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return list(uow.projects.recent(limit=5))
|
||||
|
||||
|
||||
def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]:
|
||||
updates: list[dict[str, object]] = []
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
scenarios = uow.scenarios.recent(limit=5, with_project=True)
|
||||
for scenario in scenarios:
|
||||
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
||||
@@ -65,6 +72,9 @@ def _load_scenario_alerts(
|
||||
) -> list[dict[str, object]]:
|
||||
alerts: list[dict[str, object]] = []
|
||||
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
|
||||
drafts = uow.scenarios.list_by_status(
|
||||
ScenarioStatus.DRAFT, limit=3, with_project=True
|
||||
)
|
||||
@@ -102,6 +112,7 @@ def _load_scenario_alerts(
|
||||
@router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home")
|
||||
def dashboard_home(
|
||||
request: Request,
|
||||
_: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
context = {
|
||||
|
||||
@@ -6,8 +6,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import get_unit_of_work
|
||||
from models import MiningOperationType, Project, ScenarioStatus
|
||||
from dependencies import (
|
||||
get_unit_of_work,
|
||||
require_any_role,
|
||||
require_project_resource,
|
||||
require_roles,
|
||||
)
|
||||
from models import MiningOperationType, Project, ScenarioStatus, User
|
||||
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
from services.unit_of_work import UnitOfWork
|
||||
@@ -15,11 +20,20 @@ from services.unit_of_work import UnitOfWork
|
||||
router = APIRouter(prefix="/projects", tags=["Projects"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||
MANAGE_ROLES = ("project_manager", "admin")
|
||||
|
||||
|
||||
def _to_read_model(project: Project) -> ProjectRead:
|
||||
return ProjectRead.model_validate(project)
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork):
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _operation_type_choices() -> list[tuple[str, str]]:
|
||||
return [
|
||||
(op.value, op.name.replace("_", " ").title()) for op in MiningOperationType
|
||||
@@ -27,18 +41,23 @@ def _operation_type_choices() -> list[tuple[str, str]]:
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProjectRead])
|
||||
def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]:
|
||||
projects = uow.projects.list()
|
||||
def list_projects(
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> List[ProjectRead]:
|
||||
projects = _require_project_repo(uow).list()
|
||||
return [_to_read_model(project) for project in projects]
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_project(
|
||||
payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
payload: ProjectCreate,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ProjectRead:
|
||||
project = Project(**payload.model_dump())
|
||||
try:
|
||||
created = uow.projects.create(project)
|
||||
created = _require_project_repo(uow).create(project)
|
||||
except EntityConflictError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)
|
||||
@@ -53,9 +72,11 @@ def create_project(
|
||||
name="projects.project_list_page",
|
||||
)
|
||||
def project_list_page(
|
||||
request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
request: Request,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
projects = uow.projects.list(with_children=True)
|
||||
projects = _require_project_repo(uow).list(with_children=True)
|
||||
for project in projects:
|
||||
setattr(project, "scenario_count", len(project.scenarios))
|
||||
return templates.TemplateResponse(
|
||||
@@ -73,7 +94,9 @@ def project_list_page(
|
||||
include_in_schema=False,
|
||||
name="projects.create_project_form",
|
||||
)
|
||||
def create_project_form(request: Request) -> HTMLResponse:
|
||||
def create_project_form(
|
||||
request: Request, _: User = Depends(require_roles(*MANAGE_ROLES))
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
@@ -93,6 +116,7 @@ def create_project_form(request: Request) -> HTMLResponse:
|
||||
)
|
||||
def create_project_submit(
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
name: str = Form(...),
|
||||
location: str | None = Form(None),
|
||||
operation_type: str = Form(...),
|
||||
@@ -128,7 +152,7 @@ def create_project_submit(
|
||||
description=_normalise(description),
|
||||
)
|
||||
try:
|
||||
uow.projects.create(project)
|
||||
_require_project_repo(uow).create(project)
|
||||
except EntityConflictError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -150,29 +174,18 @@ def create_project_submit(
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectRead)
|
||||
def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead:
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead:
|
||||
return _to_read_model(project)
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectRead)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ProjectRead:
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
@@ -182,13 +195,11 @@ def update_project(
|
||||
|
||||
|
||||
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None:
|
||||
try:
|
||||
uow.projects.delete(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
def delete_project(
|
||||
project: Project = Depends(require_project_resource(require_manage=True)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> None:
|
||||
_require_project_repo(uow).delete(project.id)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -198,14 +209,11 @@ def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work))
|
||||
name="projects.view_project",
|
||||
)
|
||||
def view_project(
|
||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
request: Request,
|
||||
project: Project = Depends(require_project_resource()),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
project = uow.projects.get(project_id, with_children=True)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
project = _require_project_repo(uow).get(project.id, with_children=True)
|
||||
|
||||
scenarios = sorted(project.scenarios, key=lambda s: s.created_at)
|
||||
scenario_stats = {
|
||||
@@ -236,15 +244,11 @@ def view_project(
|
||||
name="projects.edit_project_form",
|
||||
)
|
||||
def edit_project_form(
|
||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
request: Request,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
@@ -252,10 +256,10 @@ def edit_project_form(
|
||||
"project": project,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"projects.edit_project_submit", project_id=project_id
|
||||
"projects.edit_project_submit", project_id=project.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project_id
|
||||
"projects.view_project", project_id=project.id
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -267,21 +271,16 @@ def edit_project_form(
|
||||
name="projects.edit_project_submit",
|
||||
)
|
||||
def edit_project_submit(
|
||||
project_id: int,
|
||||
request: Request,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
name: str = Form(...),
|
||||
location: str | None = Form(None),
|
||||
operation_type: str | None = Form(None),
|
||||
description: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
def _normalise(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -301,10 +300,10 @@ def edit_project_submit(
|
||||
"project": project,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"projects.edit_project_submit", project_id=project_id
|
||||
"projects.edit_project_submit", project_id=project.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project_id
|
||||
"projects.view_project", project_id=project.id
|
||||
),
|
||||
"error": "Invalid operation type.",
|
||||
},
|
||||
@@ -315,6 +314,6 @@ def edit_project_submit(
|
||||
uow.flush()
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("projects.view_project", project_id=project_id),
|
||||
request.url_for("projects.view_project", project_id=project.id),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
@@ -7,8 +7,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import get_unit_of_work
|
||||
from models import ResourceType, Scenario, ScenarioStatus
|
||||
from dependencies import (
|
||||
get_unit_of_work,
|
||||
require_any_role,
|
||||
require_roles,
|
||||
require_scenario_resource,
|
||||
)
|
||||
from models import ResourceType, Scenario, ScenarioStatus, User
|
||||
from schemas.scenario import (
|
||||
ScenarioComparisonRequest,
|
||||
ScenarioComparisonResponse,
|
||||
@@ -26,6 +31,9 @@ from services.unit_of_work import UnitOfWork
|
||||
router = APIRouter(tags=["Scenarios"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||
MANAGE_ROLES = ("project_manager", "admin")
|
||||
|
||||
|
||||
def _to_read_model(scenario: Scenario) -> ScenarioRead:
|
||||
return ScenarioRead.model_validate(scenario)
|
||||
@@ -44,20 +52,36 @@ def _scenario_status_choices() -> list[tuple[str, str]]:
|
||||
]
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork):
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _require_scenario_repo(uow: UnitOfWork):
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
return uow.scenarios
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/scenarios",
|
||||
response_model=List[ScenarioRead],
|
||||
)
|
||||
def list_scenarios_for_project(
|
||||
project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
project_id: int,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> List[ScenarioRead]:
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
uow.projects.get(project_id)
|
||||
project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
|
||||
scenarios = uow.scenarios.list_for_project(project_id)
|
||||
scenarios = scenario_repo.list_for_project(project_id)
|
||||
return [_to_read_model(scenario) for scenario in scenarios]
|
||||
|
||||
|
||||
@@ -69,10 +93,11 @@ def list_scenarios_for_project(
|
||||
def compare_scenarios(
|
||||
project_id: int,
|
||||
payload: ScenarioComparisonRequest,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioComparisonResponse:
|
||||
try:
|
||||
uow.projects.get(project_id)
|
||||
_require_project_repo(uow).get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
@@ -116,10 +141,13 @@ def compare_scenarios(
|
||||
def create_scenario_for_project(
|
||||
project_id: int,
|
||||
payload: ScenarioCreate,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioRead:
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
uow.projects.get(project_id)
|
||||
project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
@@ -127,7 +155,7 @@ def create_scenario_for_project(
|
||||
scenario = Scenario(project_id=project_id, **payload.model_dump())
|
||||
|
||||
try:
|
||||
created = uow.scenarios.create(scenario)
|
||||
created = scenario_repo.create(scenario)
|
||||
except EntityConflictError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||
@@ -136,28 +164,19 @@ def create_scenario_for_project(
|
||||
|
||||
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||
def get_scenario(
|
||||
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
scenario: Scenario = Depends(require_scenario_resource()),
|
||||
) -> ScenarioRead:
|
||||
try:
|
||||
scenario = uow.scenarios.get(scenario_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
return _to_read_model(scenario)
|
||||
|
||||
|
||||
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||
def update_scenario(
|
||||
scenario_id: int,
|
||||
payload: ScenarioUpdate,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioRead:
|
||||
try:
|
||||
scenario = uow.scenarios.get(scenario_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(scenario, field, value)
|
||||
@@ -168,13 +187,12 @@ def update_scenario(
|
||||
|
||||
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_scenario(
|
||||
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> None:
|
||||
try:
|
||||
uow.scenarios.delete(scenario_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
_require_scenario_repo(uow).delete(scenario.id)
|
||||
|
||||
|
||||
def _normalise(value: str | None) -> str | None:
|
||||
@@ -208,10 +226,13 @@ def _parse_discount_rate(value: str | None) -> float | None:
|
||||
name="scenarios.create_scenario_form",
|
||||
)
|
||||
def create_scenario_form(
|
||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
project_id: int,
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
project = _require_project_repo(uow).get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
@@ -243,6 +264,7 @@ def create_scenario_form(
|
||||
def create_scenario_submit(
|
||||
project_id: int,
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
name: str = Form(...),
|
||||
description: str | None = Form(None),
|
||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||
@@ -253,8 +275,10 @@ def create_scenario_submit(
|
||||
primary_resource: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
project = uow.projects.get(project_id)
|
||||
project = project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
@@ -288,7 +312,7 @@ def create_scenario_submit(
|
||||
)
|
||||
|
||||
try:
|
||||
uow.scenarios.create(scenario)
|
||||
scenario_repo.create(scenario)
|
||||
except EntityConflictError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -322,16 +346,13 @@ def create_scenario_submit(
|
||||
name="scenarios.view_scenario",
|
||||
)
|
||||
def view_scenario(
|
||||
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(with_children=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
scenario = uow.scenarios.get(scenario_id, with_children=True)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
project = uow.projects.get(scenario.project_id)
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
financial_inputs = sorted(
|
||||
scenario.financial_inputs, key=lambda item: item.created_at
|
||||
)
|
||||
@@ -366,16 +387,13 @@ def view_scenario(
|
||||
name="scenarios.edit_scenario_form",
|
||||
)
|
||||
def edit_scenario_form(
|
||||
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
scenario = uow.scenarios.get(scenario_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
project = uow.projects.get(scenario.project_id)
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -386,10 +404,10 @@ def edit_scenario_form(
|
||||
"scenario_statuses": _scenario_status_choices(),
|
||||
"resource_types": _resource_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"scenarios.edit_scenario_submit", scenario_id=scenario_id
|
||||
"scenarios.edit_scenario_submit", scenario_id=scenario.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"scenarios.view_scenario", scenario_id=scenario_id
|
||||
"scenarios.view_scenario", scenario_id=scenario.id
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -401,8 +419,10 @@ def edit_scenario_form(
|
||||
name="scenarios.edit_scenario_submit",
|
||||
)
|
||||
def edit_scenario_submit(
|
||||
scenario_id: int,
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
name: str = Form(...),
|
||||
description: str | None = Form(None),
|
||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||
@@ -413,14 +433,7 @@ def edit_scenario_submit(
|
||||
primary_resource: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
try:
|
||||
scenario = uow.scenarios.get(scenario_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
project = uow.projects.get(scenario.project_id)
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
|
||||
scenario.name = name.strip()
|
||||
scenario.description = _normalise(description)
|
||||
@@ -447,6 +460,6 @@ def edit_scenario_submit(
|
||||
uow.flush()
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("scenarios.view_scenario", scenario_id=scenario_id),
|
||||
request.url_for("scenarios.view_scenario", scenario_id=scenario.id),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
20
scripts/00_initial_data.py
Normal file
20
scripts/00_initial_data.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from scripts.initial_data import load_config, seed_initial_data
|
||||
|
||||
|
||||
def main() -> int:
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
try:
|
||||
config = load_config()
|
||||
seed_initial_data(config)
|
||||
except Exception as exc: # pragma: no cover - operational guard
|
||||
logging.exception("Seeding failed: %s", exc)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
1
scripts/__init__.py
Normal file
1
scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utility scripts for CalMiner maintenance tasks."""
|
||||
183
scripts/initial_data.py
Normal file
183
scripts/initial_data.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from models import Role, User
|
||||
from services.repositories import DEFAULT_ROLE_DEFINITIONS, RoleRepository, UserRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeedConfig:
|
||||
admin_email: str
|
||||
admin_username: str
|
||||
admin_password: str
|
||||
admin_roles: tuple[str, ...]
|
||||
force_reset: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoleSeedResult:
|
||||
created: int
|
||||
updated: int
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminSeedResult:
|
||||
created_user: bool
|
||||
updated_user: bool
|
||||
password_rotated: bool
|
||||
roles_granted: int
|
||||
|
||||
|
||||
def parse_bool(value: str | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def normalise_role_list(raw_value: str | None) -> tuple[str, ...]:
|
||||
if not raw_value:
|
||||
return ("admin",)
|
||||
parts = [segment.strip() for segment in raw_value.split(",") if segment.strip()]
|
||||
if "admin" not in parts:
|
||||
parts.insert(0, "admin")
|
||||
seen: set[str] = set()
|
||||
ordered: list[str] = []
|
||||
for role_name in parts:
|
||||
if role_name not in seen:
|
||||
ordered.append(role_name)
|
||||
seen.add(role_name)
|
||||
return tuple(ordered)
|
||||
|
||||
|
||||
def load_config() -> SeedConfig:
|
||||
load_dotenv()
|
||||
admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local")
|
||||
admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin")
|
||||
admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!")
|
||||
admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES"))
|
||||
force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE"))
|
||||
return SeedConfig(
|
||||
admin_email=admin_email,
|
||||
admin_username=admin_username,
|
||||
admin_password=admin_password,
|
||||
admin_roles=admin_roles,
|
||||
force_reset=force_reset,
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_roles(
|
||||
role_repo: RoleRepository,
|
||||
definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS,
|
||||
) -> RoleSeedResult:
|
||||
created = 0
|
||||
updated = 0
|
||||
total = 0
|
||||
for definition in definitions:
|
||||
total += 1
|
||||
existing = role_repo.get_by_name(definition["name"])
|
||||
if existing is None:
|
||||
role_repo.create(Role(**definition))
|
||||
created += 1
|
||||
continue
|
||||
changed = False
|
||||
if existing.display_name != definition["display_name"]:
|
||||
existing.display_name = definition["display_name"]
|
||||
changed = True
|
||||
if existing.description != definition["description"]:
|
||||
existing.description = definition["description"]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
role_repo.session.flush()
|
||||
return RoleSeedResult(created=created, updated=updated, total=total)
|
||||
|
||||
|
||||
def ensure_admin_user(
|
||||
user_repo: UserRepository,
|
||||
role_repo: RoleRepository,
|
||||
config: SeedConfig,
|
||||
) -> AdminSeedResult:
|
||||
created_user = False
|
||||
updated_user = False
|
||||
password_rotated = False
|
||||
roles_granted = 0
|
||||
|
||||
user = user_repo.get_by_email(config.admin_email, with_roles=True)
|
||||
if user is None:
|
||||
user = User(
|
||||
email=config.admin_email,
|
||||
username=config.admin_username,
|
||||
password_hash=User.hash_password(config.admin_password),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
user_repo.create(user)
|
||||
created_user = True
|
||||
else:
|
||||
if user.username != config.admin_username:
|
||||
user.username = config.admin_username
|
||||
updated_user = True
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
updated_user = True
|
||||
if not user.is_superuser:
|
||||
user.is_superuser = True
|
||||
updated_user = True
|
||||
if config.force_reset:
|
||||
user.password_hash = User.hash_password(config.admin_password)
|
||||
password_rotated = True
|
||||
updated_user = True
|
||||
user_repo.session.flush()
|
||||
|
||||
for role_name in config.admin_roles:
|
||||
role = role_repo.get_by_name(role_name)
|
||||
if role is None:
|
||||
logging.warning("Role '%s' is not defined and will be skipped", role_name)
|
||||
continue
|
||||
already_assigned = any(assignment.role_id == role.id for assignment in user.role_assignments)
|
||||
if already_assigned:
|
||||
continue
|
||||
user_repo.assign_role(user_id=user.id, role_id=role.id, granted_by=user.id)
|
||||
roles_granted += 1
|
||||
|
||||
return AdminSeedResult(
|
||||
created_user=created_user,
|
||||
updated_user=updated_user,
|
||||
password_rotated=password_rotated,
|
||||
roles_granted=roles_granted,
|
||||
)
|
||||
|
||||
|
||||
def seed_initial_data(
|
||||
config: SeedConfig,
|
||||
*,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] | None = None,
|
||||
) -> None:
|
||||
logging.info("Starting initial data seeding")
|
||||
factory = unit_of_work_factory or UnitOfWork
|
||||
with factory() as uow:
|
||||
assert uow.roles is not None and uow.users is not None
|
||||
role_result = ensure_default_roles(uow.roles)
|
||||
admin_result = ensure_admin_user(uow.users, uow.roles, config)
|
||||
logging.info(
|
||||
"Roles processed: %s total, %s created, %s updated",
|
||||
role_result.total,
|
||||
role_result.created,
|
||||
role_result.updated,
|
||||
)
|
||||
logging.info(
|
||||
"Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s",
|
||||
admin_result.created_user,
|
||||
admin_result.updated_user,
|
||||
admin_result.password_rotated,
|
||||
admin_result.roles_granted,
|
||||
)
|
||||
logging.info("Initial data seeding completed successfully")
|
||||
104
services/authorization.py
Normal file
104
services/authorization.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from models import Project, Role, Scenario, User
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.repositories import ProjectRepository, ScenarioRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
READ_ROLES: frozenset[str] = frozenset(
|
||||
{"viewer", "analyst", "project_manager", "admin"}
|
||||
)
|
||||
MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"})
|
||||
|
||||
|
||||
def _user_role_names(user: User) -> set[str]:
|
||||
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||
return {role.name for role in roles}
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork) -> ProjectRepository:
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository:
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
return uow.scenarios
|
||||
|
||||
|
||||
def _assert_user_can_access(user: User, *, require_manage: bool) -> None:
|
||||
if not user.is_active:
|
||||
raise AuthorizationError("User account is disabled.")
|
||||
if user.is_superuser:
|
||||
return
|
||||
|
||||
allowed = MANAGE_ROLES if require_manage else READ_ROLES
|
||||
if not _user_role_names(user) & allowed:
|
||||
raise AuthorizationError(
|
||||
"Insufficient role permissions for this action.")
|
||||
|
||||
|
||||
def ensure_project_access(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
project_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
) -> Project:
|
||||
"""Resolve a project and ensure the user holds the required permissions."""
|
||||
|
||||
repo = _require_project_repo(uow)
|
||||
project = repo.get(project_id)
|
||||
_assert_user_can_access(user, require_manage=require_manage)
|
||||
return project
|
||||
|
||||
|
||||
def ensure_scenario_access(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
scenario_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
with_children: bool = False,
|
||||
) -> Scenario:
|
||||
"""Resolve a scenario and ensure the user holds the required permissions."""
|
||||
|
||||
repo = _require_scenario_repo(uow)
|
||||
scenario = repo.get(scenario_id, with_children=with_children)
|
||||
_assert_user_can_access(user, require_manage=require_manage)
|
||||
return scenario
|
||||
|
||||
|
||||
def ensure_scenario_in_project(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
project_id: int,
|
||||
scenario_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
with_children: bool = False,
|
||||
) -> Scenario:
|
||||
"""Resolve a scenario ensuring it belongs to the project and the user may access it."""
|
||||
|
||||
project = ensure_project_access(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
)
|
||||
scenario = ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
if scenario.project_id != project.id:
|
||||
raise EntityNotFoundError(
|
||||
f"Scenario {scenario_id} does not belong to project {project_id}."
|
||||
)
|
||||
return scenario
|
||||
@@ -12,6 +12,10 @@ class EntityConflictError(Exception):
|
||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Raised when a user lacks permission to perform an action."""
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ScenarioValidationError(Exception):
|
||||
"""Raised when scenarios fail comparison validation rules."""
|
||||
|
||||
@@ -57,6 +57,10 @@ class ProjectRepository:
|
||||
raise EntityNotFoundError(f"Project {project_id} not found")
|
||||
return project
|
||||
|
||||
def exists(self, project_id: int) -> bool:
|
||||
stmt = select(Project.id).where(Project.id == project_id).limit(1)
|
||||
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def create(self, project: Project) -> Project:
|
||||
self.session.add(project)
|
||||
try:
|
||||
@@ -133,6 +137,10 @@ class ScenarioRepository:
|
||||
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
||||
return scenario
|
||||
|
||||
def exists(self, scenario_id: int) -> bool:
|
||||
stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
|
||||
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def create(self, scenario: Scenario) -> Scenario:
|
||||
self.session.add(scenario)
|
||||
try:
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
@@ -11,12 +11,14 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_unit_of_work
|
||||
from dependencies import get_auth_session, get_unit_of_work
|
||||
from models import User
|
||||
from routes.auth import router as auth_router
|
||||
from routes.dashboard import router as dashboard_router
|
||||
from routes.projects import router as projects_router
|
||||
from routes.scenarios import router as scenarios_router
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -55,6 +57,28 @@ def app(session_factory: sessionmaker) -> FastAPI:
|
||||
yield uow
|
||||
|
||||
application.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
uow.ensure_default_roles()
|
||||
user = User(
|
||||
email="test-superuser@example.com",
|
||||
username="test-superuser",
|
||||
password_hash=User.hash_password("test-password"),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
user = uow.users.get(user.id, with_roles=True)
|
||||
|
||||
def _override_auth_session(request: Request) -> AuthSession:
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="test", refresh_token="test"))
|
||||
session.user = user
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
application.dependency_overrides[get_auth_session] = _override_auth_session
|
||||
return application
|
||||
|
||||
|
||||
|
||||
156
tests/scripts/test_initial_data_seed.py
Normal file
156
tests/scripts/test_initial_data_seed.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from scripts import initial_data
|
||||
from scripts.initial_data import AdminSeedResult, RoleSeedResult, SeedConfig
|
||||
from services.repositories import DEFAULT_ROLE_DEFINITIONS
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def in_memory_session_factory() -> Callable[[], Session]:
|
||||
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(engine)
|
||||
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
|
||||
|
||||
def _session_factory() -> Session:
|
||||
return factory()
|
||||
|
||||
return _session_factory
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def uow(in_memory_session_factory: Callable[[], Session]) -> UnitOfWork:
|
||||
return UnitOfWork(session_factory=in_memory_session_factory)
|
||||
|
||||
|
||||
def test_ensure_default_roles_idempotent(uow: UnitOfWork) -> None:
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
result_first = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
assert result_first == RoleSeedResult(created=4, updated=0, total=4)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
result_second = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
assert result_second == RoleSeedResult(created=0, updated=0, total=4)
|
||||
|
||||
|
||||
def test_ensure_admin_user_creates_and_assigns_roles(uow: UnitOfWork) -> None:
|
||||
config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="secret",
|
||||
admin_roles=("admin", "viewer"),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
result = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||
assert result == AdminSeedResult(
|
||||
created_user=True,
|
||||
updated_user=False,
|
||||
password_rotated=False,
|
||||
roles_granted=2,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
result_again = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||
assert result_again == AdminSeedResult(
|
||||
created_user=False,
|
||||
updated_user=False,
|
||||
password_rotated=False,
|
||||
roles_granted=0,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user = working.users.get_by_email("admin@example.com", with_roles=True)
|
||||
assert user is not None
|
||||
assert user.is_active is True
|
||||
assert user.is_superuser is True
|
||||
role_names = {role.name for role in user.roles}
|
||||
assert role_names == {"admin", "viewer"}
|
||||
|
||||
|
||||
def test_ensure_admin_user_force_reset_rotates_password(uow: UnitOfWork) -> None:
|
||||
base_config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="first",
|
||||
admin_roles=("admin",),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
initial_data.ensure_admin_user(working.users, working.roles, base_config)
|
||||
|
||||
rotate_config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="second",
|
||||
admin_roles=("admin",),
|
||||
force_reset=True,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user_before = working.users.get_by_email("admin@example.com")
|
||||
assert user_before is not None
|
||||
old_hash = user_before.password_hash
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
result = initial_data.ensure_admin_user(working.users, working.roles, rotate_config)
|
||||
assert result.password_rotated is True
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user_after = working.users.get_by_email("admin@example.com")
|
||||
assert user_after is not None
|
||||
assert user_after.password_hash != old_hash
|
||||
|
||||
|
||||
def test_seed_initial_data_logs_results(
|
||||
caplog,
|
||||
in_memory_session_factory: Callable[[], Session],
|
||||
) -> None:
|
||||
caplog.set_level(logging.INFO)
|
||||
config = SeedConfig(
|
||||
admin_email="seed@example.com",
|
||||
admin_username="seed",
|
||||
admin_password="seed-pass",
|
||||
admin_roles=("admin",),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
initial_data.seed_initial_data(
|
||||
config,
|
||||
unit_of_work_factory=lambda: UnitOfWork(session_factory=in_memory_session_factory),
|
||||
)
|
||||
|
||||
assert "Starting initial data seeding" in caplog.text
|
||||
assert "Initial data seeding completed successfully" in caplog.text
|
||||
|
||||
with UnitOfWork(session_factory=in_memory_session_factory) as check_uow:
|
||||
assert check_uow.users is not None
|
||||
assert check_uow.roles is not None
|
||||
user = check_uow.users.get_by_email("seed@example.com")
|
||||
assert user is not None
|
||||
assert check_uow.roles.get_by_name("admin") is not None
|
||||
165
tests/test_authorization_helpers.py
Normal file
165
tests/test_authorization_helpers.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from models import MiningOperationType, Project, Scenario, ScenarioStatus, User
|
||||
from services.authorization import (
|
||||
ensure_project_access,
|
||||
ensure_scenario_access,
|
||||
ensure_scenario_in_project,
|
||||
)
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
def _create_user_with_roles(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
roles: set[str],
|
||||
) -> User:
|
||||
assert uow.users is not None
|
||||
assert uow.roles is not None
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash=User.hash_password("secret"),
|
||||
is_active=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
uow.ensure_default_roles()
|
||||
|
||||
for role_name in roles:
|
||||
role = uow.roles.get_by_name(role_name)
|
||||
assert role is not None, f"Role {role_name} should exist"
|
||||
uow.users.assign_role(user_id=user.id, role_id=role.id)
|
||||
|
||||
return uow.users.get(user.id, with_roles=True)
|
||||
|
||||
|
||||
def _create_project(uow: UnitOfWork, name: str) -> Project:
|
||||
assert uow.projects is not None
|
||||
project = Project(
|
||||
name=name,
|
||||
location=None,
|
||||
operation_type=MiningOperationType.OTHER,
|
||||
description=None,
|
||||
)
|
||||
uow.projects.create(project)
|
||||
return project
|
||||
|
||||
|
||||
def _create_scenario(uow: UnitOfWork, project: Project, name: str) -> Scenario:
|
||||
assert uow.scenarios is not None
|
||||
scenario = Scenario(
|
||||
project_id=project.id,
|
||||
name=name,
|
||||
description=None,
|
||||
status=ScenarioStatus.DRAFT,
|
||||
)
|
||||
uow.scenarios.create(scenario)
|
||||
return scenario
|
||||
|
||||
|
||||
def test_ensure_project_access_allows_view_roles(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project A")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="viewer@example.com",
|
||||
username="viewer",
|
||||
roles={"viewer"},
|
||||
)
|
||||
|
||||
resolved = ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
)
|
||||
assert resolved.id == project.id
|
||||
|
||||
with pytest.raises(AuthorizationError):
|
||||
ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_project_access_allows_manage_roles(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project B")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="manager@example.com",
|
||||
username="manager",
|
||||
roles={"project_manager"},
|
||||
)
|
||||
|
||||
resolved = ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
assert resolved.id == project.id
|
||||
|
||||
|
||||
def test_ensure_scenario_access(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project C")
|
||||
scenario = _create_scenario(uow, project, "Scenario C1")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="analyst@example.com",
|
||||
username="analyst",
|
||||
roles={"analyst"},
|
||||
)
|
||||
|
||||
resolved = ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
)
|
||||
assert resolved.id == scenario.id
|
||||
|
||||
with pytest.raises(AuthorizationError):
|
||||
ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_scenario_in_project_validates_membership(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project_one = _create_project(uow, "Project D")
|
||||
project_two = _create_project(uow, "Project E")
|
||||
scenario = _create_scenario(uow, project_one, "Scenario D1")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="manager2@example.com",
|
||||
username="manager2",
|
||||
roles={"project_manager"},
|
||||
)
|
||||
|
||||
resolved = ensure_scenario_in_project(
|
||||
uow,
|
||||
project_id=project_one.id,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
assert resolved.id == scenario.id
|
||||
|
||||
with pytest.raises(EntityNotFoundError):
|
||||
ensure_scenario_in_project(
|
||||
uow,
|
||||
project_id=project_two.id,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import create_engine
|
||||
@@ -15,13 +15,14 @@ from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_unit_of_work
|
||||
from dependencies import get_auth_session, get_unit_of_work
|
||||
from models import (
|
||||
MiningOperationType,
|
||||
Project,
|
||||
ResourceType,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
User,
|
||||
)
|
||||
from schemas.scenario import (
|
||||
ScenarioComparisonRequest,
|
||||
@@ -30,6 +31,7 @@ from schemas.scenario import (
|
||||
from services.exceptions import ScenarioValidationError
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from services.session import AuthSession, SessionTokens
|
||||
from routes.scenarios import router as scenarios_router
|
||||
|
||||
|
||||
@@ -159,6 +161,28 @@ def api_client(session_factory) -> Iterator[TestClient]:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
uow.ensure_default_roles()
|
||||
user = User(
|
||||
email="test-scenarios@example.com",
|
||||
username="scenario-tester",
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
user = uow.users.get(user.id, with_roles=True)
|
||||
|
||||
def _override_auth_session(request: Request) -> AuthSession:
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="test", refresh_token="test"))
|
||||
session.user = user
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
app.dependency_overrides[get_auth_session] = _override_auth_session
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
@@ -171,6 +195,8 @@ def _create_project_with_scenarios(
|
||||
scenario_overrides: list[dict[str, object]],
|
||||
) -> tuple[int, list[int]]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.projects is not None
|
||||
assert uow.scenarios is not None
|
||||
project_name = f"Project {uuid4()}"
|
||||
project = Project(name=project_name,
|
||||
operation_type=MiningOperationType.OPEN_PIT)
|
||||
|
||||
Reference in New Issue
Block a user