feat: enhance project and scenario management with role-based access control

- Implemented role-based access control for project and scenario routes.
- Added authorization checks to ensure users have appropriate roles for viewing and managing projects and scenarios.
- Introduced utility functions for ensuring project and scenario access based on user roles.
- Refactored project and scenario routes to utilize new authorization helpers.
- Created initial data seeding script to set up default roles and an admin user.
- Added tests for authorization helpers and initial data seeding functionality.
- Updated exception handling to include authorization errors.
This commit is contained in:
2025-11-09 23:14:54 +01:00
parent 27262bdfa3
commit 0f79864188
16 changed files with 997 additions and 132 deletions

View File

@@ -9,6 +9,3 @@ DATABASE_PASSWORD=<password>
DATABASE_NAME=calminer DATABASE_NAME=calminer
# Optional: set a schema (comma-separated for multiple entries) # Optional: set a schema (comma-separated for multiple entries)
# DATABASE_SCHEMA=public # DATABASE_SCHEMA=public
# Legacy fallback (still supported, but granular settings are preferred)
# DATABASE_URL=postgresql://<user>:<password>@localhost:5432/calminer

View File

@@ -18,3 +18,10 @@
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. - Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios. - Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows. - Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour.
- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning.
- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns.
## 2025-11-10
- Extended authorization helper layer with project/scenario ownership lookups, integrated them into FastAPI dependencies, refreshed pytest fixtures to keep the suite authenticated, and documented the new patterns across RBAC plan and security guides.

View File

@@ -1,11 +1,17 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator from collections.abc import Callable, Iterable, Generator
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
from config.settings import Settings, get_settings from config.settings import Settings, get_settings
from models import User from models import Project, Role, Scenario, User
from services.authorization import (
ensure_project_access as ensure_project_access_helper,
ensure_scenario_access as ensure_scenario_access_helper,
ensure_scenario_in_project as ensure_scenario_in_project_helper,
)
from services.exceptions import AuthorizationError, EntityNotFoundError
from services.security import JWTSettings from services.security import JWTSettings
from services.session import ( from services.session import (
AuthSession, AuthSession,
@@ -90,9 +96,150 @@ def require_current_user(
) -> User: ) -> User:
"""Ensure that a request is authenticated and return the user context.""" """Ensure that a request is authenticated and return the user context."""
if session.user is None: if session.user is None or session.tokens.is_empty:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required.", detail="Authentication required.",
) )
return session.user return session.user
def require_authenticated_user(
user: User = Depends(require_current_user),
) -> User:
"""Ensure the current user account is active."""
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is disabled.",
)
return user
def _user_role_names(user: User) -> set[str]:
roles: Iterable[Role] = getattr(user, "roles", []) or []
return {role.name for role in roles}
def require_roles(*roles: str) -> Callable[[User], User]:
"""Dependency factory enforcing membership in one of the given roles."""
required = tuple(role.strip() for role in roles if role.strip())
if not required:
raise ValueError("require_roles requires at least one role name")
def _dependency(user: User = Depends(require_authenticated_user)) -> User:
if user.is_superuser:
return user
role_names = _user_role_names(user)
if not any(role in role_names for role in required):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for this action.",
)
return user
return _dependency
def require_any_role(*roles: str) -> Callable[[User], User]:
"""Alias of require_roles for readability in some contexts."""
return require_roles(*roles)
def require_project_resource(*, require_manage: bool = False) -> Callable[[int], Project]:
"""Dependency factory that resolves a project with authorization checks."""
def _dependency(
project_id: int,
user: User = Depends(require_authenticated_user),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> Project:
try:
return ensure_project_access_helper(
uow,
project_id=project_id,
user=user,
require_manage=require_manage,
)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(exc),
) from exc
except AuthorizationError as exc:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(exc),
) from exc
return _dependency
def require_scenario_resource(
*, require_manage: bool = False, with_children: bool = False
) -> Callable[[int], Scenario]:
"""Dependency factory that resolves a scenario with authorization checks."""
def _dependency(
scenario_id: int,
user: User = Depends(require_authenticated_user),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> Scenario:
try:
return ensure_scenario_access_helper(
uow,
scenario_id=scenario_id,
user=user,
require_manage=require_manage,
with_children=with_children,
)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(exc),
) from exc
except AuthorizationError as exc:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(exc),
) from exc
return _dependency
def require_project_scenario_resource(
*, require_manage: bool = False, with_children: bool = False
) -> Callable[[int, int], Scenario]:
"""Dependency factory ensuring a scenario belongs to the given project and is accessible."""
def _dependency(
project_id: int,
scenario_id: int,
user: User = Depends(require_authenticated_user),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> Scenario:
try:
return ensure_scenario_in_project_helper(
uow,
project_id=project_id,
scenario_id=scenario_id,
user=user,
require_manage=require_manage,
with_children=with_children,
)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(exc),
) from exc
except AuthorizationError as exc:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(exc),
) from exc
return _dependency

View File

@@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from dependencies import get_unit_of_work from dependencies import get_unit_of_work, require_authenticated_user
from models import User
from models import ScenarioStatus from models import ScenarioStatus
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
@@ -27,6 +28,8 @@ def _format_timestamp_with_time(moment: datetime | None) -> str | None:
def _load_metrics(uow: UnitOfWork) -> dict[str, object]: def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
if not uow.projects or not uow.scenarios or not uow.financial_inputs:
raise RuntimeError("UnitOfWork repositories not initialised")
total_projects = uow.projects.count() total_projects = uow.projects.count()
active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE) active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE)
pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT) pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT)
@@ -40,11 +43,15 @@ def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
def _load_recent_projects(uow: UnitOfWork) -> list: def _load_recent_projects(uow: UnitOfWork) -> list:
if not uow.projects:
raise RuntimeError("Project repository not initialised")
return list(uow.projects.recent(limit=5)) return list(uow.projects.recent(limit=5))
def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]: def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]:
updates: list[dict[str, object]] = [] updates: list[dict[str, object]] = []
if not uow.scenarios:
raise RuntimeError("Scenario repository not initialised")
scenarios = uow.scenarios.recent(limit=5, with_project=True) scenarios = uow.scenarios.recent(limit=5, with_project=True)
for scenario in scenarios: for scenario in scenarios:
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}" project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
@@ -65,6 +72,9 @@ def _load_scenario_alerts(
) -> list[dict[str, object]]: ) -> list[dict[str, object]]:
alerts: list[dict[str, object]] = [] alerts: list[dict[str, object]] = []
if not uow.scenarios:
raise RuntimeError("Scenario repository not initialised")
drafts = uow.scenarios.list_by_status( drafts = uow.scenarios.list_by_status(
ScenarioStatus.DRAFT, limit=3, with_project=True ScenarioStatus.DRAFT, limit=3, with_project=True
) )
@@ -102,6 +112,7 @@ def _load_scenario_alerts(
@router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home") @router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home")
def dashboard_home( def dashboard_home(
request: Request, request: Request,
_: User = Depends(require_authenticated_user),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
context = { context = {

View File

@@ -6,8 +6,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from dependencies import get_unit_of_work from dependencies import (
from models import MiningOperationType, Project, ScenarioStatus get_unit_of_work,
require_any_role,
require_project_resource,
require_roles,
)
from models import MiningOperationType, Project, ScenarioStatus, User
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
from services.exceptions import EntityConflictError, EntityNotFoundError from services.exceptions import EntityConflictError, EntityNotFoundError
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
@@ -15,11 +20,20 @@ from services.unit_of_work import UnitOfWork
router = APIRouter(prefix="/projects", tags=["Projects"]) router = APIRouter(prefix="/projects", tags=["Projects"])
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
MANAGE_ROLES = ("project_manager", "admin")
def _to_read_model(project: Project) -> ProjectRead: def _to_read_model(project: Project) -> ProjectRead:
return ProjectRead.model_validate(project) return ProjectRead.model_validate(project)
def _require_project_repo(uow: UnitOfWork):
if not uow.projects:
raise RuntimeError("Project repository not initialised")
return uow.projects
def _operation_type_choices() -> list[tuple[str, str]]: def _operation_type_choices() -> list[tuple[str, str]]:
return [ return [
(op.value, op.name.replace("_", " ").title()) for op in MiningOperationType (op.value, op.name.replace("_", " ").title()) for op in MiningOperationType
@@ -27,18 +41,23 @@ def _operation_type_choices() -> list[tuple[str, str]]:
@router.get("", response_model=List[ProjectRead]) @router.get("", response_model=List[ProjectRead])
def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]: def list_projects(
projects = uow.projects.list() _: User = Depends(require_any_role(*READ_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> List[ProjectRead]:
projects = _require_project_repo(uow).list()
return [_to_read_model(project) for project in projects] return [_to_read_model(project) for project in projects]
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED) @router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
def create_project( def create_project(
payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work) payload: ProjectCreate,
_: User = Depends(require_roles(*MANAGE_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> ProjectRead: ) -> ProjectRead:
project = Project(**payload.model_dump()) project = Project(**payload.model_dump())
try: try:
created = uow.projects.create(project) created = _require_project_repo(uow).create(project)
except EntityConflictError as exc: except EntityConflictError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=str(exc) status_code=status.HTTP_409_CONFLICT, detail=str(exc)
@@ -53,9 +72,11 @@ def create_project(
name="projects.project_list_page", name="projects.project_list_page",
) )
def project_list_page( def project_list_page(
request: Request, uow: UnitOfWork = Depends(get_unit_of_work) request: Request,
_: User = Depends(require_any_role(*READ_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
projects = uow.projects.list(with_children=True) projects = _require_project_repo(uow).list(with_children=True)
for project in projects: for project in projects:
setattr(project, "scenario_count", len(project.scenarios)) setattr(project, "scenario_count", len(project.scenarios))
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -73,7 +94,9 @@ def project_list_page(
include_in_schema=False, include_in_schema=False,
name="projects.create_project_form", name="projects.create_project_form",
) )
def create_project_form(request: Request) -> HTMLResponse: def create_project_form(
request: Request, _: User = Depends(require_roles(*MANAGE_ROLES))
) -> HTMLResponse:
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"projects/form.html", "projects/form.html",
@@ -93,6 +116,7 @@ def create_project_form(request: Request) -> HTMLResponse:
) )
def create_project_submit( def create_project_submit(
request: Request, request: Request,
_: User = Depends(require_roles(*MANAGE_ROLES)),
name: str = Form(...), name: str = Form(...),
location: str | None = Form(None), location: str | None = Form(None),
operation_type: str = Form(...), operation_type: str = Form(...),
@@ -128,7 +152,7 @@ def create_project_submit(
description=_normalise(description), description=_normalise(description),
) )
try: try:
uow.projects.create(project) _require_project_repo(uow).create(project)
except EntityConflictError as exc: except EntityConflictError as exc:
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
@@ -150,29 +174,18 @@ def create_project_submit(
@router.get("/{project_id}", response_model=ProjectRead) @router.get("/{project_id}", response_model=ProjectRead)
def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead: def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead:
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
return _to_read_model(project) return _to_read_model(project)
@router.put("/{project_id}", response_model=ProjectRead) @router.put("/{project_id}", response_model=ProjectRead)
def update_project( def update_project(
project_id: int,
payload: ProjectUpdate, payload: ProjectUpdate,
project: Project = Depends(
require_project_resource(require_manage=True)
),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
) -> ProjectRead: ) -> ProjectRead:
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(project, field, value) setattr(project, field, value)
@@ -182,13 +195,11 @@ def update_project(
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None: def delete_project(
try: project: Project = Depends(require_project_resource(require_manage=True)),
uow.projects.delete(project_id) uow: UnitOfWork = Depends(get_unit_of_work),
except EntityNotFoundError as exc: ) -> None:
raise HTTPException( _require_project_repo(uow).delete(project.id)
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
@router.get( @router.get(
@@ -198,14 +209,11 @@ def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work))
name="projects.view_project", name="projects.view_project",
) )
def view_project( def view_project(
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) request: Request,
project: Project = Depends(require_project_resource()),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
try: project = _require_project_repo(uow).get(project.id, with_children=True)
project = uow.projects.get(project_id, with_children=True)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
scenarios = sorted(project.scenarios, key=lambda s: s.created_at) scenarios = sorted(project.scenarios, key=lambda s: s.created_at)
scenario_stats = { scenario_stats = {
@@ -236,15 +244,11 @@ def view_project(
name="projects.edit_project_form", name="projects.edit_project_form",
) )
def edit_project_form( def edit_project_form(
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) request: Request,
project: Project = Depends(
require_project_resource(require_manage=True)
),
) -> HTMLResponse: ) -> HTMLResponse:
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"projects/form.html", "projects/form.html",
@@ -252,10 +256,10 @@ def edit_project_form(
"project": project, "project": project,
"operation_types": _operation_type_choices(), "operation_types": _operation_type_choices(),
"form_action": request.url_for( "form_action": request.url_for(
"projects.edit_project_submit", project_id=project_id "projects.edit_project_submit", project_id=project.id
), ),
"cancel_url": request.url_for( "cancel_url": request.url_for(
"projects.view_project", project_id=project_id "projects.view_project", project_id=project.id
), ),
}, },
) )
@@ -267,21 +271,16 @@ def edit_project_form(
name="projects.edit_project_submit", name="projects.edit_project_submit",
) )
def edit_project_submit( def edit_project_submit(
project_id: int,
request: Request, request: Request,
project: Project = Depends(
require_project_resource(require_manage=True)
),
name: str = Form(...), name: str = Form(...),
location: str | None = Form(None), location: str | None = Form(None),
operation_type: str | None = Form(None), operation_type: str | None = Form(None),
description: str | None = Form(None), description: str | None = Form(None),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
): ):
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
def _normalise(value: str | None) -> str | None: def _normalise(value: str | None) -> str | None:
if value is None: if value is None:
return None return None
@@ -301,10 +300,10 @@ def edit_project_submit(
"project": project, "project": project,
"operation_types": _operation_type_choices(), "operation_types": _operation_type_choices(),
"form_action": request.url_for( "form_action": request.url_for(
"projects.edit_project_submit", project_id=project_id "projects.edit_project_submit", project_id=project.id
), ),
"cancel_url": request.url_for( "cancel_url": request.url_for(
"projects.view_project", project_id=project_id "projects.view_project", project_id=project.id
), ),
"error": "Invalid operation type.", "error": "Invalid operation type.",
}, },
@@ -315,6 +314,6 @@ def edit_project_submit(
uow.flush() uow.flush()
return RedirectResponse( return RedirectResponse(
request.url_for("projects.view_project", project_id=project_id), request.url_for("projects.view_project", project_id=project.id),
status_code=status.HTTP_303_SEE_OTHER, status_code=status.HTTP_303_SEE_OTHER,
) )

View File

@@ -7,8 +7,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from dependencies import get_unit_of_work from dependencies import (
from models import ResourceType, Scenario, ScenarioStatus get_unit_of_work,
require_any_role,
require_roles,
require_scenario_resource,
)
from models import ResourceType, Scenario, ScenarioStatus, User
from schemas.scenario import ( from schemas.scenario import (
ScenarioComparisonRequest, ScenarioComparisonRequest,
ScenarioComparisonResponse, ScenarioComparisonResponse,
@@ -26,6 +31,9 @@ from services.unit_of_work import UnitOfWork
router = APIRouter(tags=["Scenarios"]) router = APIRouter(tags=["Scenarios"])
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
MANAGE_ROLES = ("project_manager", "admin")
def _to_read_model(scenario: Scenario) -> ScenarioRead: def _to_read_model(scenario: Scenario) -> ScenarioRead:
return ScenarioRead.model_validate(scenario) return ScenarioRead.model_validate(scenario)
@@ -44,20 +52,36 @@ def _scenario_status_choices() -> list[tuple[str, str]]:
] ]
def _require_project_repo(uow: UnitOfWork):
if not uow.projects:
raise RuntimeError("Project repository not initialised")
return uow.projects
def _require_scenario_repo(uow: UnitOfWork):
if not uow.scenarios:
raise RuntimeError("Scenario repository not initialised")
return uow.scenarios
@router.get( @router.get(
"/projects/{project_id}/scenarios", "/projects/{project_id}/scenarios",
response_model=List[ScenarioRead], response_model=List[ScenarioRead],
) )
def list_scenarios_for_project( def list_scenarios_for_project(
project_id: int, uow: UnitOfWork = Depends(get_unit_of_work) project_id: int,
_: User = Depends(require_any_role(*READ_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> List[ScenarioRead]: ) -> List[ScenarioRead]:
project_repo = _require_project_repo(uow)
scenario_repo = _require_scenario_repo(uow)
try: try:
uow.projects.get(project_id) project_repo.get(project_id)
except EntityNotFoundError as exc: except EntityNotFoundError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
scenarios = uow.scenarios.list_for_project(project_id) scenarios = scenario_repo.list_for_project(project_id)
return [_to_read_model(scenario) for scenario in scenarios] return [_to_read_model(scenario) for scenario in scenarios]
@@ -69,10 +93,11 @@ def list_scenarios_for_project(
def compare_scenarios( def compare_scenarios(
project_id: int, project_id: int,
payload: ScenarioComparisonRequest, payload: ScenarioComparisonRequest,
_: User = Depends(require_any_role(*READ_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioComparisonResponse: ) -> ScenarioComparisonResponse:
try: try:
uow.projects.get(project_id) _require_project_repo(uow).get(project_id)
except EntityNotFoundError as exc: except EntityNotFoundError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
@@ -116,10 +141,13 @@ def compare_scenarios(
def create_scenario_for_project( def create_scenario_for_project(
project_id: int, project_id: int,
payload: ScenarioCreate, payload: ScenarioCreate,
_: User = Depends(require_roles(*MANAGE_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead: ) -> ScenarioRead:
project_repo = _require_project_repo(uow)
scenario_repo = _require_scenario_repo(uow)
try: try:
uow.projects.get(project_id) project_repo.get(project_id)
except EntityNotFoundError as exc: except EntityNotFoundError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
@@ -127,7 +155,7 @@ def create_scenario_for_project(
scenario = Scenario(project_id=project_id, **payload.model_dump()) scenario = Scenario(project_id=project_id, **payload.model_dump())
try: try:
created = uow.scenarios.create(scenario) created = scenario_repo.create(scenario)
except EntityConflictError as exc: except EntityConflictError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
@@ -136,28 +164,19 @@ def create_scenario_for_project(
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead) @router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
def get_scenario( def get_scenario(
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) scenario: Scenario = Depends(require_scenario_resource()),
) -> ScenarioRead: ) -> ScenarioRead:
try:
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _to_read_model(scenario) return _to_read_model(scenario)
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead) @router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
def update_scenario( def update_scenario(
scenario_id: int,
payload: ScenarioUpdate, payload: ScenarioUpdate,
scenario: Scenario = Depends(
require_scenario_resource(require_manage=True)
),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead: ) -> ScenarioRead:
try:
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(scenario, field, value) setattr(scenario, field, value)
@@ -168,13 +187,12 @@ def update_scenario(
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_scenario( def delete_scenario(
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) scenario: Scenario = Depends(
require_scenario_resource(require_manage=True)
),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> None: ) -> None:
try: _require_scenario_repo(uow).delete(scenario.id)
uow.scenarios.delete(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
def _normalise(value: str | None) -> str | None: def _normalise(value: str | None) -> str | None:
@@ -208,10 +226,13 @@ def _parse_discount_rate(value: str | None) -> float | None:
name="scenarios.create_scenario_form", name="scenarios.create_scenario_form",
) )
def create_scenario_form( def create_scenario_form(
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) project_id: int,
request: Request,
_: User = Depends(require_roles(*MANAGE_ROLES)),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
try: try:
project = uow.projects.get(project_id) project = _require_project_repo(uow).get(project_id)
except EntityNotFoundError as exc: except EntityNotFoundError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
@@ -243,6 +264,7 @@ def create_scenario_form(
def create_scenario_submit( def create_scenario_submit(
project_id: int, project_id: int,
request: Request, request: Request,
_: User = Depends(require_roles(*MANAGE_ROLES)),
name: str = Form(...), name: str = Form(...),
description: str | None = Form(None), description: str | None = Form(None),
status_value: str = Form(ScenarioStatus.DRAFT.value), status_value: str = Form(ScenarioStatus.DRAFT.value),
@@ -253,8 +275,10 @@ def create_scenario_submit(
primary_resource: str | None = Form(None), primary_resource: str | None = Form(None),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
): ):
project_repo = _require_project_repo(uow)
scenario_repo = _require_scenario_repo(uow)
try: try:
project = uow.projects.get(project_id) project = project_repo.get(project_id)
except EntityNotFoundError as exc: except EntityNotFoundError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
@@ -288,7 +312,7 @@ def create_scenario_submit(
) )
try: try:
uow.scenarios.create(scenario) scenario_repo.create(scenario)
except EntityConflictError as exc: except EntityConflictError as exc:
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
@@ -322,16 +346,13 @@ def create_scenario_submit(
name="scenarios.view_scenario", name="scenarios.view_scenario",
) )
def view_scenario( def view_scenario(
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) request: Request,
scenario: Scenario = Depends(
require_scenario_resource(with_children=True)
),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
try: project = _require_project_repo(uow).get(scenario.project_id)
scenario = uow.scenarios.get(scenario_id, with_children=True)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
project = uow.projects.get(scenario.project_id)
financial_inputs = sorted( financial_inputs = sorted(
scenario.financial_inputs, key=lambda item: item.created_at scenario.financial_inputs, key=lambda item: item.created_at
) )
@@ -366,16 +387,13 @@ def view_scenario(
name="scenarios.edit_scenario_form", name="scenarios.edit_scenario_form",
) )
def edit_scenario_form( def edit_scenario_form(
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) request: Request,
scenario: Scenario = Depends(
require_scenario_resource(require_manage=True)
),
uow: UnitOfWork = Depends(get_unit_of_work),
) -> HTMLResponse: ) -> HTMLResponse:
try: project = _require_project_repo(uow).get(scenario.project_id)
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
project = uow.projects.get(scenario.project_id)
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
@@ -386,10 +404,10 @@ def edit_scenario_form(
"scenario_statuses": _scenario_status_choices(), "scenario_statuses": _scenario_status_choices(),
"resource_types": _resource_type_choices(), "resource_types": _resource_type_choices(),
"form_action": request.url_for( "form_action": request.url_for(
"scenarios.edit_scenario_submit", scenario_id=scenario_id "scenarios.edit_scenario_submit", scenario_id=scenario.id
), ),
"cancel_url": request.url_for( "cancel_url": request.url_for(
"scenarios.view_scenario", scenario_id=scenario_id "scenarios.view_scenario", scenario_id=scenario.id
), ),
}, },
) )
@@ -401,8 +419,10 @@ def edit_scenario_form(
name="scenarios.edit_scenario_submit", name="scenarios.edit_scenario_submit",
) )
def edit_scenario_submit( def edit_scenario_submit(
scenario_id: int,
request: Request, request: Request,
scenario: Scenario = Depends(
require_scenario_resource(require_manage=True)
),
name: str = Form(...), name: str = Form(...),
description: str | None = Form(None), description: str | None = Form(None),
status_value: str = Form(ScenarioStatus.DRAFT.value), status_value: str = Form(ScenarioStatus.DRAFT.value),
@@ -413,14 +433,7 @@ def edit_scenario_submit(
primary_resource: str | None = Form(None), primary_resource: str | None = Form(None),
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
): ):
try: project = _require_project_repo(uow).get(scenario.project_id)
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
project = uow.projects.get(scenario.project_id)
scenario.name = name.strip() scenario.name = name.strip()
scenario.description = _normalise(description) scenario.description = _normalise(description)
@@ -447,6 +460,6 @@ def edit_scenario_submit(
uow.flush() uow.flush()
return RedirectResponse( return RedirectResponse(
request.url_for("scenarios.view_scenario", scenario_id=scenario_id), request.url_for("scenarios.view_scenario", scenario_id=scenario.id),
status_code=status.HTTP_303_SEE_OTHER, status_code=status.HTTP_303_SEE_OTHER,
) )

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
import logging
from scripts.initial_data import load_config, seed_initial_data
def main() -> int:
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
try:
config = load_config()
seed_initial_data(config)
except Exception as exc: # pragma: no cover - operational guard
logging.exception("Seeding failed: %s", exc)
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())

1
scripts/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Utility scripts for CalMiner maintenance tasks."""

183
scripts/initial_data.py Normal file
View File

@@ -0,0 +1,183 @@
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from typing import Callable, Iterable
from dotenv import load_dotenv
from models import Role, User
from services.repositories import DEFAULT_ROLE_DEFINITIONS, RoleRepository, UserRepository
from services.unit_of_work import UnitOfWork
@dataclass
class SeedConfig:
admin_email: str
admin_username: str
admin_password: str
admin_roles: tuple[str, ...]
force_reset: bool
@dataclass
class RoleSeedResult:
created: int
updated: int
total: int
@dataclass
class AdminSeedResult:
created_user: bool
updated_user: bool
password_rotated: bool
roles_granted: int
def parse_bool(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
def normalise_role_list(raw_value: str | None) -> tuple[str, ...]:
if not raw_value:
return ("admin",)
parts = [segment.strip() for segment in raw_value.split(",") if segment.strip()]
if "admin" not in parts:
parts.insert(0, "admin")
seen: set[str] = set()
ordered: list[str] = []
for role_name in parts:
if role_name not in seen:
ordered.append(role_name)
seen.add(role_name)
return tuple(ordered)
def load_config() -> SeedConfig:
load_dotenv()
admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local")
admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin")
admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!")
admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES"))
force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE"))
return SeedConfig(
admin_email=admin_email,
admin_username=admin_username,
admin_password=admin_password,
admin_roles=admin_roles,
force_reset=force_reset,
)
def ensure_default_roles(
role_repo: RoleRepository,
definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS,
) -> RoleSeedResult:
created = 0
updated = 0
total = 0
for definition in definitions:
total += 1
existing = role_repo.get_by_name(definition["name"])
if existing is None:
role_repo.create(Role(**definition))
created += 1
continue
changed = False
if existing.display_name != definition["display_name"]:
existing.display_name = definition["display_name"]
changed = True
if existing.description != definition["description"]:
existing.description = definition["description"]
changed = True
if changed:
updated += 1
role_repo.session.flush()
return RoleSeedResult(created=created, updated=updated, total=total)
def ensure_admin_user(
user_repo: UserRepository,
role_repo: RoleRepository,
config: SeedConfig,
) -> AdminSeedResult:
created_user = False
updated_user = False
password_rotated = False
roles_granted = 0
user = user_repo.get_by_email(config.admin_email, with_roles=True)
if user is None:
user = User(
email=config.admin_email,
username=config.admin_username,
password_hash=User.hash_password(config.admin_password),
is_active=True,
is_superuser=True,
)
user_repo.create(user)
created_user = True
else:
if user.username != config.admin_username:
user.username = config.admin_username
updated_user = True
if not user.is_active:
user.is_active = True
updated_user = True
if not user.is_superuser:
user.is_superuser = True
updated_user = True
if config.force_reset:
user.password_hash = User.hash_password(config.admin_password)
password_rotated = True
updated_user = True
user_repo.session.flush()
for role_name in config.admin_roles:
role = role_repo.get_by_name(role_name)
if role is None:
logging.warning("Role '%s' is not defined and will be skipped", role_name)
continue
already_assigned = any(assignment.role_id == role.id for assignment in user.role_assignments)
if already_assigned:
continue
user_repo.assign_role(user_id=user.id, role_id=role.id, granted_by=user.id)
roles_granted += 1
return AdminSeedResult(
created_user=created_user,
updated_user=updated_user,
password_rotated=password_rotated,
roles_granted=roles_granted,
)
def seed_initial_data(
config: SeedConfig,
*,
unit_of_work_factory: Callable[[], UnitOfWork] | None = None,
) -> None:
logging.info("Starting initial data seeding")
factory = unit_of_work_factory or UnitOfWork
with factory() as uow:
assert uow.roles is not None and uow.users is not None
role_result = ensure_default_roles(uow.roles)
admin_result = ensure_admin_user(uow.users, uow.roles, config)
logging.info(
"Roles processed: %s total, %s created, %s updated",
role_result.total,
role_result.created,
role_result.updated,
)
logging.info(
"Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s",
admin_result.created_user,
admin_result.updated_user,
admin_result.password_rotated,
admin_result.roles_granted,
)
logging.info("Initial data seeding completed successfully")

104
services/authorization.py Normal file
View File

@@ -0,0 +1,104 @@
from __future__ import annotations
from typing import Iterable
from models import Project, Role, Scenario, User
from services.exceptions import AuthorizationError, EntityNotFoundError
from services.repositories import ProjectRepository, ScenarioRepository
from services.unit_of_work import UnitOfWork
READ_ROLES: frozenset[str] = frozenset(
{"viewer", "analyst", "project_manager", "admin"}
)
MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"})
def _user_role_names(user: User) -> set[str]:
roles: Iterable[Role] = getattr(user, "roles", []) or []
return {role.name for role in roles}
def _require_project_repo(uow: UnitOfWork) -> ProjectRepository:
if not uow.projects:
raise RuntimeError("Project repository not initialised")
return uow.projects
def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository:
if not uow.scenarios:
raise RuntimeError("Scenario repository not initialised")
return uow.scenarios
def _assert_user_can_access(user: User, *, require_manage: bool) -> None:
if not user.is_active:
raise AuthorizationError("User account is disabled.")
if user.is_superuser:
return
allowed = MANAGE_ROLES if require_manage else READ_ROLES
if not _user_role_names(user) & allowed:
raise AuthorizationError(
"Insufficient role permissions for this action.")
def ensure_project_access(
uow: UnitOfWork,
*,
project_id: int,
user: User,
require_manage: bool = False,
) -> Project:
"""Resolve a project and ensure the user holds the required permissions."""
repo = _require_project_repo(uow)
project = repo.get(project_id)
_assert_user_can_access(user, require_manage=require_manage)
return project
def ensure_scenario_access(
uow: UnitOfWork,
*,
scenario_id: int,
user: User,
require_manage: bool = False,
with_children: bool = False,
) -> Scenario:
"""Resolve a scenario and ensure the user holds the required permissions."""
repo = _require_scenario_repo(uow)
scenario = repo.get(scenario_id, with_children=with_children)
_assert_user_can_access(user, require_manage=require_manage)
return scenario
def ensure_scenario_in_project(
uow: UnitOfWork,
*,
project_id: int,
scenario_id: int,
user: User,
require_manage: bool = False,
with_children: bool = False,
) -> Scenario:
"""Resolve a scenario ensuring it belongs to the project and the user may access it."""
project = ensure_project_access(
uow,
project_id=project_id,
user=user,
require_manage=require_manage,
)
scenario = ensure_scenario_access(
uow,
scenario_id=scenario_id,
user=user,
require_manage=require_manage,
with_children=with_children,
)
if scenario.project_id != project.id:
raise EntityNotFoundError(
f"Scenario {scenario_id} does not belong to project {project_id}."
)
return scenario

View File

@@ -12,6 +12,10 @@ class EntityConflictError(Exception):
"""Raised when attempting to create or update an entity that violates uniqueness.""" """Raised when attempting to create or update an entity that violates uniqueness."""
class AuthorizationError(Exception):
"""Raised when a user lacks permission to perform an action."""
@dataclass(eq=False) @dataclass(eq=False)
class ScenarioValidationError(Exception): class ScenarioValidationError(Exception):
"""Raised when scenarios fail comparison validation rules.""" """Raised when scenarios fail comparison validation rules."""

View File

@@ -57,6 +57,10 @@ class ProjectRepository:
raise EntityNotFoundError(f"Project {project_id} not found") raise EntityNotFoundError(f"Project {project_id} not found")
return project return project
def exists(self, project_id: int) -> bool:
stmt = select(Project.id).where(Project.id == project_id).limit(1)
return self.session.execute(stmt).scalar_one_or_none() is not None
def create(self, project: Project) -> Project: def create(self, project: Project) -> Project:
self.session.add(project) self.session.add(project)
try: try:
@@ -133,6 +137,10 @@ class ScenarioRepository:
raise EntityNotFoundError(f"Scenario {scenario_id} not found") raise EntityNotFoundError(f"Scenario {scenario_id} not found")
return scenario return scenario
def exists(self, scenario_id: int) -> bool:
stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
return self.session.execute(stmt).scalar_one_or_none() is not None
def create(self, scenario: Scenario) -> Scenario: def create(self, scenario: Scenario) -> Scenario:
self.session.add(scenario) self.session.add(scenario)
try: try:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI, Request
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@@ -11,12 +11,14 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from config.database import Base from config.database import Base
from dependencies import get_unit_of_work from dependencies import get_auth_session, get_unit_of_work
from models import User
from routes.auth import router as auth_router from routes.auth import router as auth_router
from routes.dashboard import router as dashboard_router from routes.dashboard import router as dashboard_router
from routes.projects import router as projects_router from routes.projects import router as projects_router
from routes.scenarios import router as scenarios_router from routes.scenarios import router as scenarios_router
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
from services.session import AuthSession, SessionTokens
@pytest.fixture() @pytest.fixture()
@@ -55,6 +57,28 @@ def app(session_factory: sessionmaker) -> FastAPI:
yield uow yield uow
application.dependency_overrides[get_unit_of_work] = _override_uow application.dependency_overrides[get_unit_of_work] = _override_uow
with UnitOfWork(session_factory=session_factory) as uow:
assert uow.users is not None
uow.ensure_default_roles()
user = User(
email="test-superuser@example.com",
username="test-superuser",
password_hash=User.hash_password("test-password"),
is_active=True,
is_superuser=True,
)
uow.users.create(user)
user = uow.users.get(user.id, with_roles=True)
def _override_auth_session(request: Request) -> AuthSession:
session = AuthSession(tokens=SessionTokens(
access_token="test", refresh_token="test"))
session.user = user
request.state.auth_session = session
return session
application.dependency_overrides[get_auth_session] = _override_auth_session
return application return application

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import logging
from collections.abc import Callable
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from config.database import Base
from scripts import initial_data
from scripts.initial_data import AdminSeedResult, RoleSeedResult, SeedConfig
from services.repositories import DEFAULT_ROLE_DEFINITIONS
from services.unit_of_work import UnitOfWork
@pytest.fixture()
def in_memory_session_factory() -> Callable[[], Session]:
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
Base.metadata.create_all(engine)
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
def _session_factory() -> Session:
return factory()
return _session_factory
@pytest.fixture()
def uow(in_memory_session_factory: Callable[[], Session]) -> UnitOfWork:
return UnitOfWork(session_factory=in_memory_session_factory)
def test_ensure_default_roles_idempotent(uow: UnitOfWork) -> None:
with uow as working:
assert working.roles is not None
result_first = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
assert result_first == RoleSeedResult(created=4, updated=0, total=4)
with uow as working:
assert working.roles is not None
result_second = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
assert result_second == RoleSeedResult(created=0, updated=0, total=4)
def test_ensure_admin_user_creates_and_assigns_roles(uow: UnitOfWork) -> None:
config = SeedConfig(
admin_email="admin@example.com",
admin_username="admin",
admin_password="secret",
admin_roles=("admin", "viewer"),
force_reset=False,
)
with uow as working:
assert working.roles is not None
assert working.users is not None
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
result = initial_data.ensure_admin_user(working.users, working.roles, config)
assert result == AdminSeedResult(
created_user=True,
updated_user=False,
password_rotated=False,
roles_granted=2,
)
with uow as working:
assert working.roles is not None
assert working.users is not None
result_again = initial_data.ensure_admin_user(working.users, working.roles, config)
assert result_again == AdminSeedResult(
created_user=False,
updated_user=False,
password_rotated=False,
roles_granted=0,
)
with uow as working:
assert working.users is not None
user = working.users.get_by_email("admin@example.com", with_roles=True)
assert user is not None
assert user.is_active is True
assert user.is_superuser is True
role_names = {role.name for role in user.roles}
assert role_names == {"admin", "viewer"}
def test_ensure_admin_user_force_reset_rotates_password(uow: UnitOfWork) -> None:
base_config = SeedConfig(
admin_email="admin@example.com",
admin_username="admin",
admin_password="first",
admin_roles=("admin",),
force_reset=False,
)
with uow as working:
assert working.roles is not None
assert working.users is not None
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
initial_data.ensure_admin_user(working.users, working.roles, base_config)
rotate_config = SeedConfig(
admin_email="admin@example.com",
admin_username="admin",
admin_password="second",
admin_roles=("admin",),
force_reset=True,
)
with uow as working:
assert working.users is not None
user_before = working.users.get_by_email("admin@example.com")
assert user_before is not None
old_hash = user_before.password_hash
with uow as working:
assert working.roles is not None
assert working.users is not None
result = initial_data.ensure_admin_user(working.users, working.roles, rotate_config)
assert result.password_rotated is True
with uow as working:
assert working.users is not None
user_after = working.users.get_by_email("admin@example.com")
assert user_after is not None
assert user_after.password_hash != old_hash
def test_seed_initial_data_logs_results(
caplog,
in_memory_session_factory: Callable[[], Session],
) -> None:
caplog.set_level(logging.INFO)
config = SeedConfig(
admin_email="seed@example.com",
admin_username="seed",
admin_password="seed-pass",
admin_roles=("admin",),
force_reset=False,
)
initial_data.seed_initial_data(
config,
unit_of_work_factory=lambda: UnitOfWork(session_factory=in_memory_session_factory),
)
assert "Starting initial data seeding" in caplog.text
assert "Initial data seeding completed successfully" in caplog.text
with UnitOfWork(session_factory=in_memory_session_factory) as check_uow:
assert check_uow.users is not None
assert check_uow.roles is not None
user = check_uow.users.get_by_email("seed@example.com")
assert user is not None
assert check_uow.roles.get_by_name("admin") is not None

View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import pytest
from models import MiningOperationType, Project, Scenario, ScenarioStatus, User
from services.authorization import (
ensure_project_access,
ensure_scenario_access,
ensure_scenario_in_project,
)
from services.exceptions import AuthorizationError, EntityNotFoundError
from services.unit_of_work import UnitOfWork
def _create_user_with_roles(
uow: UnitOfWork,
*,
email: str,
username: str,
roles: set[str],
) -> User:
assert uow.users is not None
assert uow.roles is not None
user = User(
email=email,
username=username,
password_hash=User.hash_password("secret"),
is_active=True,
)
uow.users.create(user)
uow.ensure_default_roles()
for role_name in roles:
role = uow.roles.get_by_name(role_name)
assert role is not None, f"Role {role_name} should exist"
uow.users.assign_role(user_id=user.id, role_id=role.id)
return uow.users.get(user.id, with_roles=True)
def _create_project(uow: UnitOfWork, name: str) -> Project:
assert uow.projects is not None
project = Project(
name=name,
location=None,
operation_type=MiningOperationType.OTHER,
description=None,
)
uow.projects.create(project)
return project
def _create_scenario(uow: UnitOfWork, project: Project, name: str) -> Scenario:
assert uow.scenarios is not None
scenario = Scenario(
project_id=project.id,
name=name,
description=None,
status=ScenarioStatus.DRAFT,
)
uow.scenarios.create(scenario)
return scenario
def test_ensure_project_access_allows_view_roles(unit_of_work_factory) -> None:
with unit_of_work_factory() as uow:
project = _create_project(uow, "Project A")
user = _create_user_with_roles(
uow,
email="viewer@example.com",
username="viewer",
roles={"viewer"},
)
resolved = ensure_project_access(
uow,
project_id=project.id,
user=user,
)
assert resolved.id == project.id
with pytest.raises(AuthorizationError):
ensure_project_access(
uow,
project_id=project.id,
user=user,
require_manage=True,
)
def test_ensure_project_access_allows_manage_roles(unit_of_work_factory) -> None:
with unit_of_work_factory() as uow:
project = _create_project(uow, "Project B")
user = _create_user_with_roles(
uow,
email="manager@example.com",
username="manager",
roles={"project_manager"},
)
resolved = ensure_project_access(
uow,
project_id=project.id,
user=user,
require_manage=True,
)
assert resolved.id == project.id
def test_ensure_scenario_access(unit_of_work_factory) -> None:
with unit_of_work_factory() as uow:
project = _create_project(uow, "Project C")
scenario = _create_scenario(uow, project, "Scenario C1")
user = _create_user_with_roles(
uow,
email="analyst@example.com",
username="analyst",
roles={"analyst"},
)
resolved = ensure_scenario_access(
uow,
scenario_id=scenario.id,
user=user,
)
assert resolved.id == scenario.id
with pytest.raises(AuthorizationError):
ensure_scenario_access(
uow,
scenario_id=scenario.id,
user=user,
require_manage=True,
)
def test_ensure_scenario_in_project_validates_membership(unit_of_work_factory) -> None:
with unit_of_work_factory() as uow:
project_one = _create_project(uow, "Project D")
project_two = _create_project(uow, "Project E")
scenario = _create_scenario(uow, project_one, "Scenario D1")
user = _create_user_with_roles(
uow,
email="manager2@example.com",
username="manager2",
roles={"project_manager"},
)
resolved = ensure_scenario_in_project(
uow,
project_id=project_one.id,
scenario_id=scenario.id,
user=user,
require_manage=True,
)
assert resolved.id == scenario.id
with pytest.raises(EntityNotFoundError):
ensure_scenario_in_project(
uow,
project_id=project_two.id,
scenario_id=scenario.id,
user=user,
)

View File

@@ -6,7 +6,7 @@ from typing import cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI, Request
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import create_engine from sqlalchemy import create_engine
@@ -15,13 +15,14 @@ from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from config.database import Base from config.database import Base
from dependencies import get_unit_of_work from dependencies import get_auth_session, get_unit_of_work
from models import ( from models import (
MiningOperationType, MiningOperationType,
Project, Project,
ResourceType, ResourceType,
Scenario, Scenario,
ScenarioStatus, ScenarioStatus,
User,
) )
from schemas.scenario import ( from schemas.scenario import (
ScenarioComparisonRequest, ScenarioComparisonRequest,
@@ -30,6 +31,7 @@ from schemas.scenario import (
from services.exceptions import ScenarioValidationError from services.exceptions import ScenarioValidationError
from services.scenario_validation import ScenarioComparisonValidator from services.scenario_validation import ScenarioComparisonValidator
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
from services.session import AuthSession, SessionTokens
from routes.scenarios import router as scenarios_router from routes.scenarios import router as scenarios_router
@@ -159,6 +161,28 @@ def api_client(session_factory) -> Iterator[TestClient]:
yield uow yield uow
app.dependency_overrides[get_unit_of_work] = _override_uow app.dependency_overrides[get_unit_of_work] = _override_uow
with UnitOfWork(session_factory=session_factory) as uow:
assert uow.users is not None
uow.ensure_default_roles()
user = User(
email="test-scenarios@example.com",
username="scenario-tester",
password_hash=User.hash_password("password"),
is_active=True,
is_superuser=True,
)
uow.users.create(user)
user = uow.users.get(user.id, with_roles=True)
def _override_auth_session(request: Request) -> AuthSession:
session = AuthSession(tokens=SessionTokens(
access_token="test", refresh_token="test"))
session.user = user
request.state.auth_session = session
return session
app.dependency_overrides[get_auth_session] = _override_auth_session
client = TestClient(app) client = TestClient(app)
try: try:
yield client yield client
@@ -171,6 +195,8 @@ def _create_project_with_scenarios(
scenario_overrides: list[dict[str, object]], scenario_overrides: list[dict[str, object]],
) -> tuple[int, list[int]]: ) -> tuple[int, list[int]]:
with UnitOfWork(session_factory=session_factory) as uow: with UnitOfWork(session_factory=session_factory) as uow:
assert uow.projects is not None
assert uow.scenarios is not None
project_name = f"Project {uuid4()}" project_name = f"Project {uuid4()}"
project = Project(name=project_name, project = Project(name=project_name,
operation_type=MiningOperationType.OPEN_PIT) operation_type=MiningOperationType.OPEN_PIT)