feat: enhance project and scenario management with role-based access control
- Implemented role-based access control for project and scenario routes. - Added authorization checks to ensure users have appropriate roles for viewing and managing projects and scenarios. - Introduced utility functions for ensuring project and scenario access based on user roles. - Refactored project and scenario routes to utilize new authorization helpers. - Created initial data seeding script to set up default roles and an admin user. - Added tests for authorization helpers and initial data seeding functionality. - Updated exception handling to include authorization errors.
This commit is contained in:
@@ -9,6 +9,3 @@ DATABASE_PASSWORD=<password>
|
|||||||
DATABASE_NAME=calminer
|
DATABASE_NAME=calminer
|
||||||
# Optional: set a schema (comma-separated for multiple entries)
|
# Optional: set a schema (comma-separated for multiple entries)
|
||||||
# DATABASE_SCHEMA=public
|
# DATABASE_SCHEMA=public
|
||||||
|
|
||||||
# Legacy fallback (still supported, but granular settings are preferred)
|
|
||||||
# DATABASE_URL=postgresql://<user>:<password>@localhost:5432/calminer
|
|
||||||
@@ -18,3 +18,10 @@
|
|||||||
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
||||||
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
|
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
|
||||||
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
|
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
|
||||||
|
- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour.
|
||||||
|
- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning.
|
||||||
|
- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns.
|
||||||
|
|
||||||
|
## 2025-11-10
|
||||||
|
|
||||||
|
- Extended authorization helper layer with project/scenario ownership lookups, integrated them into FastAPI dependencies, refreshed pytest fixtures to keep the suite authenticated, and documented the new patterns across RBAC plan and security guides.
|
||||||
|
|||||||
153
dependencies.py
153
dependencies.py
@@ -1,11 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Callable, Iterable, Generator
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
|
||||||
from config.settings import Settings, get_settings
|
from config.settings import Settings, get_settings
|
||||||
from models import User
|
from models import Project, Role, Scenario, User
|
||||||
|
from services.authorization import (
|
||||||
|
ensure_project_access as ensure_project_access_helper,
|
||||||
|
ensure_scenario_access as ensure_scenario_access_helper,
|
||||||
|
ensure_scenario_in_project as ensure_scenario_in_project_helper,
|
||||||
|
)
|
||||||
|
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||||
from services.security import JWTSettings
|
from services.security import JWTSettings
|
||||||
from services.session import (
|
from services.session import (
|
||||||
AuthSession,
|
AuthSession,
|
||||||
@@ -90,9 +96,150 @@ def require_current_user(
|
|||||||
) -> User:
|
) -> User:
|
||||||
"""Ensure that a request is authenticated and return the user context."""
|
"""Ensure that a request is authenticated and return the user context."""
|
||||||
|
|
||||||
if session.user is None:
|
if session.user is None or session.tokens.is_empty:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Authentication required.",
|
detail="Authentication required.",
|
||||||
)
|
)
|
||||||
return session.user
|
return session.user
|
||||||
|
|
||||||
|
|
||||||
|
def require_authenticated_user(
|
||||||
|
user: User = Depends(require_current_user),
|
||||||
|
) -> User:
|
||||||
|
"""Ensure the current user account is active."""
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="User account is disabled.",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _user_role_names(user: User) -> set[str]:
|
||||||
|
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||||
|
return {role.name for role in roles}
|
||||||
|
|
||||||
|
|
||||||
|
def require_roles(*roles: str) -> Callable[[User], User]:
|
||||||
|
"""Dependency factory enforcing membership in one of the given roles."""
|
||||||
|
|
||||||
|
required = tuple(role.strip() for role in roles if role.strip())
|
||||||
|
if not required:
|
||||||
|
raise ValueError("require_roles requires at least one role name")
|
||||||
|
|
||||||
|
def _dependency(user: User = Depends(require_authenticated_user)) -> User:
|
||||||
|
if user.is_superuser:
|
||||||
|
return user
|
||||||
|
|
||||||
|
role_names = _user_role_names(user)
|
||||||
|
if not any(role in role_names for role in required):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Insufficient permissions for this action.",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
return _dependency
|
||||||
|
|
||||||
|
|
||||||
|
def require_any_role(*roles: str) -> Callable[[User], User]:
|
||||||
|
"""Alias of require_roles for readability in some contexts."""
|
||||||
|
|
||||||
|
return require_roles(*roles)
|
||||||
|
|
||||||
|
|
||||||
|
def require_project_resource(*, require_manage: bool = False) -> Callable[[int], Project]:
|
||||||
|
"""Dependency factory that resolves a project with authorization checks."""
|
||||||
|
|
||||||
|
def _dependency(
|
||||||
|
project_id: int,
|
||||||
|
user: User = Depends(require_authenticated_user),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
) -> Project:
|
||||||
|
try:
|
||||||
|
return ensure_project_access_helper(
|
||||||
|
uow,
|
||||||
|
project_id=project_id,
|
||||||
|
user=user,
|
||||||
|
require_manage=require_manage,
|
||||||
|
)
|
||||||
|
except EntityNotFoundError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
except AuthorizationError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return _dependency
|
||||||
|
|
||||||
|
|
||||||
|
def require_scenario_resource(
|
||||||
|
*, require_manage: bool = False, with_children: bool = False
|
||||||
|
) -> Callable[[int], Scenario]:
|
||||||
|
"""Dependency factory that resolves a scenario with authorization checks."""
|
||||||
|
|
||||||
|
def _dependency(
|
||||||
|
scenario_id: int,
|
||||||
|
user: User = Depends(require_authenticated_user),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
) -> Scenario:
|
||||||
|
try:
|
||||||
|
return ensure_scenario_access_helper(
|
||||||
|
uow,
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
user=user,
|
||||||
|
require_manage=require_manage,
|
||||||
|
with_children=with_children,
|
||||||
|
)
|
||||||
|
except EntityNotFoundError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
except AuthorizationError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return _dependency
|
||||||
|
|
||||||
|
|
||||||
|
def require_project_scenario_resource(
|
||||||
|
*, require_manage: bool = False, with_children: bool = False
|
||||||
|
) -> Callable[[int, int], Scenario]:
|
||||||
|
"""Dependency factory ensuring a scenario belongs to the given project and is accessible."""
|
||||||
|
|
||||||
|
def _dependency(
|
||||||
|
project_id: int,
|
||||||
|
scenario_id: int,
|
||||||
|
user: User = Depends(require_authenticated_user),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
) -> Scenario:
|
||||||
|
try:
|
||||||
|
return ensure_scenario_in_project_helper(
|
||||||
|
uow,
|
||||||
|
project_id=project_id,
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
user=user,
|
||||||
|
require_manage=require_manage,
|
||||||
|
with_children=with_children,
|
||||||
|
)
|
||||||
|
except EntityNotFoundError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
except AuthorizationError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(exc),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return _dependency
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, Request
|
|||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import get_unit_of_work, require_authenticated_user
|
||||||
|
from models import User
|
||||||
from models import ScenarioStatus
|
from models import ScenarioStatus
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
@@ -27,6 +28,8 @@ def _format_timestamp_with_time(moment: datetime | None) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
||||||
|
if not uow.projects or not uow.scenarios or not uow.financial_inputs:
|
||||||
|
raise RuntimeError("UnitOfWork repositories not initialised")
|
||||||
total_projects = uow.projects.count()
|
total_projects = uow.projects.count()
|
||||||
active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE)
|
active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE)
|
||||||
pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT)
|
pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT)
|
||||||
@@ -40,11 +43,15 @@ def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
|||||||
|
|
||||||
|
|
||||||
def _load_recent_projects(uow: UnitOfWork) -> list:
|
def _load_recent_projects(uow: UnitOfWork) -> list:
|
||||||
|
if not uow.projects:
|
||||||
|
raise RuntimeError("Project repository not initialised")
|
||||||
return list(uow.projects.recent(limit=5))
|
return list(uow.projects.recent(limit=5))
|
||||||
|
|
||||||
|
|
||||||
def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]:
|
def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]:
|
||||||
updates: list[dict[str, object]] = []
|
updates: list[dict[str, object]] = []
|
||||||
|
if not uow.scenarios:
|
||||||
|
raise RuntimeError("Scenario repository not initialised")
|
||||||
scenarios = uow.scenarios.recent(limit=5, with_project=True)
|
scenarios = uow.scenarios.recent(limit=5, with_project=True)
|
||||||
for scenario in scenarios:
|
for scenario in scenarios:
|
||||||
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
||||||
@@ -65,6 +72,9 @@ def _load_scenario_alerts(
|
|||||||
) -> list[dict[str, object]]:
|
) -> list[dict[str, object]]:
|
||||||
alerts: list[dict[str, object]] = []
|
alerts: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
if not uow.scenarios:
|
||||||
|
raise RuntimeError("Scenario repository not initialised")
|
||||||
|
|
||||||
drafts = uow.scenarios.list_by_status(
|
drafts = uow.scenarios.list_by_status(
|
||||||
ScenarioStatus.DRAFT, limit=3, with_project=True
|
ScenarioStatus.DRAFT, limit=3, with_project=True
|
||||||
)
|
)
|
||||||
@@ -102,6 +112,7 @@ def _load_scenario_alerts(
|
|||||||
@router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home")
|
@router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home")
|
||||||
def dashboard_home(
|
def dashboard_home(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
_: User = Depends(require_authenticated_user),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
context = {
|
context = {
|
||||||
|
|||||||
@@ -6,8 +6,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
|||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import (
|
||||||
from models import MiningOperationType, Project, ScenarioStatus
|
get_unit_of_work,
|
||||||
|
require_any_role,
|
||||||
|
require_project_resource,
|
||||||
|
require_roles,
|
||||||
|
)
|
||||||
|
from models import MiningOperationType, Project, ScenarioStatus, User
|
||||||
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
|
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
|
||||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
@@ -15,11 +20,20 @@ from services.unit_of_work import UnitOfWork
|
|||||||
router = APIRouter(prefix="/projects", tags=["Projects"])
|
router = APIRouter(prefix="/projects", tags=["Projects"])
|
||||||
templates = Jinja2Templates(directory="templates")
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||||
|
MANAGE_ROLES = ("project_manager", "admin")
|
||||||
|
|
||||||
|
|
||||||
def _to_read_model(project: Project) -> ProjectRead:
|
def _to_read_model(project: Project) -> ProjectRead:
|
||||||
return ProjectRead.model_validate(project)
|
return ProjectRead.model_validate(project)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_project_repo(uow: UnitOfWork):
|
||||||
|
if not uow.projects:
|
||||||
|
raise RuntimeError("Project repository not initialised")
|
||||||
|
return uow.projects
|
||||||
|
|
||||||
|
|
||||||
def _operation_type_choices() -> list[tuple[str, str]]:
|
def _operation_type_choices() -> list[tuple[str, str]]:
|
||||||
return [
|
return [
|
||||||
(op.value, op.name.replace("_", " ").title()) for op in MiningOperationType
|
(op.value, op.name.replace("_", " ").title()) for op in MiningOperationType
|
||||||
@@ -27,18 +41,23 @@ def _operation_type_choices() -> list[tuple[str, str]]:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[ProjectRead])
|
@router.get("", response_model=List[ProjectRead])
|
||||||
def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]:
|
def list_projects(
|
||||||
projects = uow.projects.list()
|
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
) -> List[ProjectRead]:
|
||||||
|
projects = _require_project_repo(uow).list()
|
||||||
return [_to_read_model(project) for project in projects]
|
return [_to_read_model(project) for project in projects]
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
|
||||||
def create_project(
|
def create_project(
|
||||||
payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work)
|
payload: ProjectCreate,
|
||||||
|
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> ProjectRead:
|
) -> ProjectRead:
|
||||||
project = Project(**payload.model_dump())
|
project = Project(**payload.model_dump())
|
||||||
try:
|
try:
|
||||||
created = uow.projects.create(project)
|
created = _require_project_repo(uow).create(project)
|
||||||
except EntityConflictError as exc:
|
except EntityConflictError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)
|
status_code=status.HTTP_409_CONFLICT, detail=str(exc)
|
||||||
@@ -53,9 +72,11 @@ def create_project(
|
|||||||
name="projects.project_list_page",
|
name="projects.project_list_page",
|
||||||
)
|
)
|
||||||
def project_list_page(
|
def project_list_page(
|
||||||
request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
request: Request,
|
||||||
|
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
projects = uow.projects.list(with_children=True)
|
projects = _require_project_repo(uow).list(with_children=True)
|
||||||
for project in projects:
|
for project in projects:
|
||||||
setattr(project, "scenario_count", len(project.scenarios))
|
setattr(project, "scenario_count", len(project.scenarios))
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
@@ -73,7 +94,9 @@ def project_list_page(
|
|||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
name="projects.create_project_form",
|
name="projects.create_project_form",
|
||||||
)
|
)
|
||||||
def create_project_form(request: Request) -> HTMLResponse:
|
def create_project_form(
|
||||||
|
request: Request, _: User = Depends(require_roles(*MANAGE_ROLES))
|
||||||
|
) -> HTMLResponse:
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"projects/form.html",
|
"projects/form.html",
|
||||||
@@ -93,6 +116,7 @@ def create_project_form(request: Request) -> HTMLResponse:
|
|||||||
)
|
)
|
||||||
def create_project_submit(
|
def create_project_submit(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||||
name: str = Form(...),
|
name: str = Form(...),
|
||||||
location: str | None = Form(None),
|
location: str | None = Form(None),
|
||||||
operation_type: str = Form(...),
|
operation_type: str = Form(...),
|
||||||
@@ -128,7 +152,7 @@ def create_project_submit(
|
|||||||
description=_normalise(description),
|
description=_normalise(description),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
uow.projects.create(project)
|
_require_project_repo(uow).create(project)
|
||||||
except EntityConflictError as exc:
|
except EntityConflictError as exc:
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@@ -150,29 +174,18 @@ def create_project_submit(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{project_id}", response_model=ProjectRead)
|
@router.get("/{project_id}", response_model=ProjectRead)
|
||||||
def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead:
|
def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead:
|
||||||
try:
|
|
||||||
project = uow.projects.get(project_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
return _to_read_model(project)
|
return _to_read_model(project)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{project_id}", response_model=ProjectRead)
|
@router.put("/{project_id}", response_model=ProjectRead)
|
||||||
def update_project(
|
def update_project(
|
||||||
project_id: int,
|
|
||||||
payload: ProjectUpdate,
|
payload: ProjectUpdate,
|
||||||
|
project: Project = Depends(
|
||||||
|
require_project_resource(require_manage=True)
|
||||||
|
),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> ProjectRead:
|
) -> ProjectRead:
|
||||||
try:
|
|
||||||
project = uow.projects.get(project_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(project, field, value)
|
setattr(project, field, value)
|
||||||
@@ -182,13 +195,11 @@ def update_project(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None:
|
def delete_project(
|
||||||
try:
|
project: Project = Depends(require_project_resource(require_manage=True)),
|
||||||
uow.projects.delete(project_id)
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
except EntityNotFoundError as exc:
|
) -> None:
|
||||||
raise HTTPException(
|
_require_project_repo(uow).delete(project.id)
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -198,14 +209,11 @@ def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work))
|
|||||||
name="projects.view_project",
|
name="projects.view_project",
|
||||||
)
|
)
|
||||||
def view_project(
|
def view_project(
|
||||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
request: Request,
|
||||||
|
project: Project = Depends(require_project_resource()),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
try:
|
project = _require_project_repo(uow).get(project.id, with_children=True)
|
||||||
project = uow.projects.get(project_id, with_children=True)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
scenarios = sorted(project.scenarios, key=lambda s: s.created_at)
|
scenarios = sorted(project.scenarios, key=lambda s: s.created_at)
|
||||||
scenario_stats = {
|
scenario_stats = {
|
||||||
@@ -236,15 +244,11 @@ def view_project(
|
|||||||
name="projects.edit_project_form",
|
name="projects.edit_project_form",
|
||||||
)
|
)
|
||||||
def edit_project_form(
|
def edit_project_form(
|
||||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
request: Request,
|
||||||
|
project: Project = Depends(
|
||||||
|
require_project_resource(require_manage=True)
|
||||||
|
),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
try:
|
|
||||||
project = uow.projects.get(project_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"projects/form.html",
|
"projects/form.html",
|
||||||
@@ -252,10 +256,10 @@ def edit_project_form(
|
|||||||
"project": project,
|
"project": project,
|
||||||
"operation_types": _operation_type_choices(),
|
"operation_types": _operation_type_choices(),
|
||||||
"form_action": request.url_for(
|
"form_action": request.url_for(
|
||||||
"projects.edit_project_submit", project_id=project_id
|
"projects.edit_project_submit", project_id=project.id
|
||||||
),
|
),
|
||||||
"cancel_url": request.url_for(
|
"cancel_url": request.url_for(
|
||||||
"projects.view_project", project_id=project_id
|
"projects.view_project", project_id=project.id
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -267,21 +271,16 @@ def edit_project_form(
|
|||||||
name="projects.edit_project_submit",
|
name="projects.edit_project_submit",
|
||||||
)
|
)
|
||||||
def edit_project_submit(
|
def edit_project_submit(
|
||||||
project_id: int,
|
|
||||||
request: Request,
|
request: Request,
|
||||||
|
project: Project = Depends(
|
||||||
|
require_project_resource(require_manage=True)
|
||||||
|
),
|
||||||
name: str = Form(...),
|
name: str = Form(...),
|
||||||
location: str | None = Form(None),
|
location: str | None = Form(None),
|
||||||
operation_type: str | None = Form(None),
|
operation_type: str | None = Form(None),
|
||||||
description: str | None = Form(None),
|
description: str | None = Form(None),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
project = uow.projects.get(project_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
def _normalise(value: str | None) -> str | None:
|
def _normalise(value: str | None) -> str | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
@@ -301,10 +300,10 @@ def edit_project_submit(
|
|||||||
"project": project,
|
"project": project,
|
||||||
"operation_types": _operation_type_choices(),
|
"operation_types": _operation_type_choices(),
|
||||||
"form_action": request.url_for(
|
"form_action": request.url_for(
|
||||||
"projects.edit_project_submit", project_id=project_id
|
"projects.edit_project_submit", project_id=project.id
|
||||||
),
|
),
|
||||||
"cancel_url": request.url_for(
|
"cancel_url": request.url_for(
|
||||||
"projects.view_project", project_id=project_id
|
"projects.view_project", project_id=project.id
|
||||||
),
|
),
|
||||||
"error": "Invalid operation type.",
|
"error": "Invalid operation type.",
|
||||||
},
|
},
|
||||||
@@ -315,6 +314,6 @@ def edit_project_submit(
|
|||||||
uow.flush()
|
uow.flush()
|
||||||
|
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
request.url_for("projects.view_project", project_id=project_id),
|
request.url_for("projects.view_project", project_id=project.id),
|
||||||
status_code=status.HTTP_303_SEE_OTHER,
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,8 +7,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
|||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import (
|
||||||
from models import ResourceType, Scenario, ScenarioStatus
|
get_unit_of_work,
|
||||||
|
require_any_role,
|
||||||
|
require_roles,
|
||||||
|
require_scenario_resource,
|
||||||
|
)
|
||||||
|
from models import ResourceType, Scenario, ScenarioStatus, User
|
||||||
from schemas.scenario import (
|
from schemas.scenario import (
|
||||||
ScenarioComparisonRequest,
|
ScenarioComparisonRequest,
|
||||||
ScenarioComparisonResponse,
|
ScenarioComparisonResponse,
|
||||||
@@ -26,6 +31,9 @@ from services.unit_of_work import UnitOfWork
|
|||||||
router = APIRouter(tags=["Scenarios"])
|
router = APIRouter(tags=["Scenarios"])
|
||||||
templates = Jinja2Templates(directory="templates")
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||||
|
MANAGE_ROLES = ("project_manager", "admin")
|
||||||
|
|
||||||
|
|
||||||
def _to_read_model(scenario: Scenario) -> ScenarioRead:
|
def _to_read_model(scenario: Scenario) -> ScenarioRead:
|
||||||
return ScenarioRead.model_validate(scenario)
|
return ScenarioRead.model_validate(scenario)
|
||||||
@@ -44,20 +52,36 @@ def _scenario_status_choices() -> list[tuple[str, str]]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _require_project_repo(uow: UnitOfWork):
|
||||||
|
if not uow.projects:
|
||||||
|
raise RuntimeError("Project repository not initialised")
|
||||||
|
return uow.projects
|
||||||
|
|
||||||
|
|
||||||
|
def _require_scenario_repo(uow: UnitOfWork):
|
||||||
|
if not uow.scenarios:
|
||||||
|
raise RuntimeError("Scenario repository not initialised")
|
||||||
|
return uow.scenarios
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/projects/{project_id}/scenarios",
|
"/projects/{project_id}/scenarios",
|
||||||
response_model=List[ScenarioRead],
|
response_model=List[ScenarioRead],
|
||||||
)
|
)
|
||||||
def list_scenarios_for_project(
|
def list_scenarios_for_project(
|
||||||
project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
project_id: int,
|
||||||
|
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> List[ScenarioRead]:
|
) -> List[ScenarioRead]:
|
||||||
|
project_repo = _require_project_repo(uow)
|
||||||
|
scenario_repo = _require_scenario_repo(uow)
|
||||||
try:
|
try:
|
||||||
uow.projects.get(project_id)
|
project_repo.get(project_id)
|
||||||
except EntityNotFoundError as exc:
|
except EntityNotFoundError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
|
||||||
scenarios = uow.scenarios.list_for_project(project_id)
|
scenarios = scenario_repo.list_for_project(project_id)
|
||||||
return [_to_read_model(scenario) for scenario in scenarios]
|
return [_to_read_model(scenario) for scenario in scenarios]
|
||||||
|
|
||||||
|
|
||||||
@@ -69,10 +93,11 @@ def list_scenarios_for_project(
|
|||||||
def compare_scenarios(
|
def compare_scenarios(
|
||||||
project_id: int,
|
project_id: int,
|
||||||
payload: ScenarioComparisonRequest,
|
payload: ScenarioComparisonRequest,
|
||||||
|
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> ScenarioComparisonResponse:
|
) -> ScenarioComparisonResponse:
|
||||||
try:
|
try:
|
||||||
uow.projects.get(project_id)
|
_require_project_repo(uow).get(project_id)
|
||||||
except EntityNotFoundError as exc:
|
except EntityNotFoundError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||||
@@ -116,10 +141,13 @@ def compare_scenarios(
|
|||||||
def create_scenario_for_project(
|
def create_scenario_for_project(
|
||||||
project_id: int,
|
project_id: int,
|
||||||
payload: ScenarioCreate,
|
payload: ScenarioCreate,
|
||||||
|
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> ScenarioRead:
|
) -> ScenarioRead:
|
||||||
|
project_repo = _require_project_repo(uow)
|
||||||
|
scenario_repo = _require_scenario_repo(uow)
|
||||||
try:
|
try:
|
||||||
uow.projects.get(project_id)
|
project_repo.get(project_id)
|
||||||
except EntityNotFoundError as exc:
|
except EntityNotFoundError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
@@ -127,7 +155,7 @@ def create_scenario_for_project(
|
|||||||
scenario = Scenario(project_id=project_id, **payload.model_dump())
|
scenario = Scenario(project_id=project_id, **payload.model_dump())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
created = uow.scenarios.create(scenario)
|
created = scenario_repo.create(scenario)
|
||||||
except EntityConflictError as exc:
|
except EntityConflictError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||||
@@ -136,28 +164,19 @@ def create_scenario_for_project(
|
|||||||
|
|
||||||
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||||
def get_scenario(
|
def get_scenario(
|
||||||
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
scenario: Scenario = Depends(require_scenario_resource()),
|
||||||
) -> ScenarioRead:
|
) -> ScenarioRead:
|
||||||
try:
|
|
||||||
scenario = uow.scenarios.get(scenario_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
|
||||||
return _to_read_model(scenario)
|
return _to_read_model(scenario)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||||
def update_scenario(
|
def update_scenario(
|
||||||
scenario_id: int,
|
|
||||||
payload: ScenarioUpdate,
|
payload: ScenarioUpdate,
|
||||||
|
scenario: Scenario = Depends(
|
||||||
|
require_scenario_resource(require_manage=True)
|
||||||
|
),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> ScenarioRead:
|
) -> ScenarioRead:
|
||||||
try:
|
|
||||||
scenario = uow.scenarios.get(scenario_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(scenario, field, value)
|
setattr(scenario, field, value)
|
||||||
@@ -168,13 +187,12 @@ def update_scenario(
|
|||||||
|
|
||||||
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
def delete_scenario(
|
def delete_scenario(
|
||||||
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
|
scenario: Scenario = Depends(
|
||||||
|
require_scenario_resource(require_manage=True)
|
||||||
|
),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
_require_scenario_repo(uow).delete(scenario.id)
|
||||||
uow.scenarios.delete(scenario_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _normalise(value: str | None) -> str | None:
|
def _normalise(value: str | None) -> str | None:
|
||||||
@@ -208,10 +226,13 @@ def _parse_discount_rate(value: str | None) -> float | None:
|
|||||||
name="scenarios.create_scenario_form",
|
name="scenarios.create_scenario_form",
|
||||||
)
|
)
|
||||||
def create_scenario_form(
|
def create_scenario_form(
|
||||||
project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
project_id: int,
|
||||||
|
request: Request,
|
||||||
|
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
try:
|
try:
|
||||||
project = uow.projects.get(project_id)
|
project = _require_project_repo(uow).get(project_id)
|
||||||
except EntityNotFoundError as exc:
|
except EntityNotFoundError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||||
@@ -243,6 +264,7 @@ def create_scenario_form(
|
|||||||
def create_scenario_submit(
|
def create_scenario_submit(
|
||||||
project_id: int,
|
project_id: int,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||||
name: str = Form(...),
|
name: str = Form(...),
|
||||||
description: str | None = Form(None),
|
description: str | None = Form(None),
|
||||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||||
@@ -253,8 +275,10 @@ def create_scenario_submit(
|
|||||||
primary_resource: str | None = Form(None),
|
primary_resource: str | None = Form(None),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
):
|
):
|
||||||
|
project_repo = _require_project_repo(uow)
|
||||||
|
scenario_repo = _require_scenario_repo(uow)
|
||||||
try:
|
try:
|
||||||
project = uow.projects.get(project_id)
|
project = project_repo.get(project_id)
|
||||||
except EntityNotFoundError as exc:
|
except EntityNotFoundError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||||
@@ -288,7 +312,7 @@ def create_scenario_submit(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
uow.scenarios.create(scenario)
|
scenario_repo.create(scenario)
|
||||||
except EntityConflictError as exc:
|
except EntityConflictError as exc:
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@@ -322,16 +346,13 @@ def create_scenario_submit(
|
|||||||
name="scenarios.view_scenario",
|
name="scenarios.view_scenario",
|
||||||
)
|
)
|
||||||
def view_scenario(
|
def view_scenario(
|
||||||
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
request: Request,
|
||||||
|
scenario: Scenario = Depends(
|
||||||
|
require_scenario_resource(with_children=True)
|
||||||
|
),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
try:
|
project = _require_project_repo(uow).get(scenario.project_id)
|
||||||
scenario = uow.scenarios.get(scenario_id, with_children=True)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
project = uow.projects.get(scenario.project_id)
|
|
||||||
financial_inputs = sorted(
|
financial_inputs = sorted(
|
||||||
scenario.financial_inputs, key=lambda item: item.created_at
|
scenario.financial_inputs, key=lambda item: item.created_at
|
||||||
)
|
)
|
||||||
@@ -366,16 +387,13 @@ def view_scenario(
|
|||||||
name="scenarios.edit_scenario_form",
|
name="scenarios.edit_scenario_form",
|
||||||
)
|
)
|
||||||
def edit_scenario_form(
|
def edit_scenario_form(
|
||||||
scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work)
|
request: Request,
|
||||||
|
scenario: Scenario = Depends(
|
||||||
|
require_scenario_resource(require_manage=True)
|
||||||
|
),
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
try:
|
project = _require_project_repo(uow).get(scenario.project_id)
|
||||||
scenario = uow.scenarios.get(scenario_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
project = uow.projects.get(scenario.project_id)
|
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@@ -386,10 +404,10 @@ def edit_scenario_form(
|
|||||||
"scenario_statuses": _scenario_status_choices(),
|
"scenario_statuses": _scenario_status_choices(),
|
||||||
"resource_types": _resource_type_choices(),
|
"resource_types": _resource_type_choices(),
|
||||||
"form_action": request.url_for(
|
"form_action": request.url_for(
|
||||||
"scenarios.edit_scenario_submit", scenario_id=scenario_id
|
"scenarios.edit_scenario_submit", scenario_id=scenario.id
|
||||||
),
|
),
|
||||||
"cancel_url": request.url_for(
|
"cancel_url": request.url_for(
|
||||||
"scenarios.view_scenario", scenario_id=scenario_id
|
"scenarios.view_scenario", scenario_id=scenario.id
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -401,8 +419,10 @@ def edit_scenario_form(
|
|||||||
name="scenarios.edit_scenario_submit",
|
name="scenarios.edit_scenario_submit",
|
||||||
)
|
)
|
||||||
def edit_scenario_submit(
|
def edit_scenario_submit(
|
||||||
scenario_id: int,
|
|
||||||
request: Request,
|
request: Request,
|
||||||
|
scenario: Scenario = Depends(
|
||||||
|
require_scenario_resource(require_manage=True)
|
||||||
|
),
|
||||||
name: str = Form(...),
|
name: str = Form(...),
|
||||||
description: str | None = Form(None),
|
description: str | None = Form(None),
|
||||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||||
@@ -413,14 +433,7 @@ def edit_scenario_submit(
|
|||||||
primary_resource: str | None = Form(None),
|
primary_resource: str | None = Form(None),
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
):
|
):
|
||||||
try:
|
project = _require_project_repo(uow).get(scenario.project_id)
|
||||||
scenario = uow.scenarios.get(scenario_id)
|
|
||||||
except EntityNotFoundError as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
project = uow.projects.get(scenario.project_id)
|
|
||||||
|
|
||||||
scenario.name = name.strip()
|
scenario.name = name.strip()
|
||||||
scenario.description = _normalise(description)
|
scenario.description = _normalise(description)
|
||||||
@@ -447,6 +460,6 @@ def edit_scenario_submit(
|
|||||||
uow.flush()
|
uow.flush()
|
||||||
|
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
request.url_for("scenarios.view_scenario", scenario_id=scenario_id),
|
request.url_for("scenarios.view_scenario", scenario_id=scenario.id),
|
||||||
status_code=status.HTTP_303_SEE_OTHER,
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
)
|
)
|
||||||
|
|||||||
20
scripts/00_initial_data.py
Normal file
20
scripts/00_initial_data.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from scripts.initial_data import load_config, seed_initial_data
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||||
|
try:
|
||||||
|
config = load_config()
|
||||||
|
seed_initial_data(config)
|
||||||
|
except Exception as exc: # pragma: no cover - operational guard
|
||||||
|
logging.exception("Seeding failed: %s", exc)
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
1
scripts/__init__.py
Normal file
1
scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Utility scripts for CalMiner maintenance tasks."""
|
||||||
183
scripts/initial_data.py
Normal file
183
scripts/initial_data.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Iterable
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from models import Role, User
|
||||||
|
from services.repositories import DEFAULT_ROLE_DEFINITIONS, RoleRepository, UserRepository
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SeedConfig:
|
||||||
|
admin_email: str
|
||||||
|
admin_username: str
|
||||||
|
admin_password: str
|
||||||
|
admin_roles: tuple[str, ...]
|
||||||
|
force_reset: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoleSeedResult:
|
||||||
|
created: int
|
||||||
|
updated: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminSeedResult:
|
||||||
|
created_user: bool
|
||||||
|
updated_user: bool
|
||||||
|
password_rotated: bool
|
||||||
|
roles_granted: int
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bool(value: str | None) -> bool:
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_role_list(raw_value: str | None) -> tuple[str, ...]:
|
||||||
|
if not raw_value:
|
||||||
|
return ("admin",)
|
||||||
|
parts = [segment.strip() for segment in raw_value.split(",") if segment.strip()]
|
||||||
|
if "admin" not in parts:
|
||||||
|
parts.insert(0, "admin")
|
||||||
|
seen: set[str] = set()
|
||||||
|
ordered: list[str] = []
|
||||||
|
for role_name in parts:
|
||||||
|
if role_name not in seen:
|
||||||
|
ordered.append(role_name)
|
||||||
|
seen.add(role_name)
|
||||||
|
return tuple(ordered)
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> SeedConfig:
|
||||||
|
load_dotenv()
|
||||||
|
admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local")
|
||||||
|
admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin")
|
||||||
|
admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!")
|
||||||
|
admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES"))
|
||||||
|
force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE"))
|
||||||
|
return SeedConfig(
|
||||||
|
admin_email=admin_email,
|
||||||
|
admin_username=admin_username,
|
||||||
|
admin_password=admin_password,
|
||||||
|
admin_roles=admin_roles,
|
||||||
|
force_reset=force_reset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_default_roles(
|
||||||
|
role_repo: RoleRepository,
|
||||||
|
definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS,
|
||||||
|
) -> RoleSeedResult:
|
||||||
|
created = 0
|
||||||
|
updated = 0
|
||||||
|
total = 0
|
||||||
|
for definition in definitions:
|
||||||
|
total += 1
|
||||||
|
existing = role_repo.get_by_name(definition["name"])
|
||||||
|
if existing is None:
|
||||||
|
role_repo.create(Role(**definition))
|
||||||
|
created += 1
|
||||||
|
continue
|
||||||
|
changed = False
|
||||||
|
if existing.display_name != definition["display_name"]:
|
||||||
|
existing.display_name = definition["display_name"]
|
||||||
|
changed = True
|
||||||
|
if existing.description != definition["description"]:
|
||||||
|
existing.description = definition["description"]
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
updated += 1
|
||||||
|
role_repo.session.flush()
|
||||||
|
return RoleSeedResult(created=created, updated=updated, total=total)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_admin_user(
|
||||||
|
user_repo: UserRepository,
|
||||||
|
role_repo: RoleRepository,
|
||||||
|
config: SeedConfig,
|
||||||
|
) -> AdminSeedResult:
|
||||||
|
created_user = False
|
||||||
|
updated_user = False
|
||||||
|
password_rotated = False
|
||||||
|
roles_granted = 0
|
||||||
|
|
||||||
|
user = user_repo.get_by_email(config.admin_email, with_roles=True)
|
||||||
|
if user is None:
|
||||||
|
user = User(
|
||||||
|
email=config.admin_email,
|
||||||
|
username=config.admin_username,
|
||||||
|
password_hash=User.hash_password(config.admin_password),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=True,
|
||||||
|
)
|
||||||
|
user_repo.create(user)
|
||||||
|
created_user = True
|
||||||
|
else:
|
||||||
|
if user.username != config.admin_username:
|
||||||
|
user.username = config.admin_username
|
||||||
|
updated_user = True
|
||||||
|
if not user.is_active:
|
||||||
|
user.is_active = True
|
||||||
|
updated_user = True
|
||||||
|
if not user.is_superuser:
|
||||||
|
user.is_superuser = True
|
||||||
|
updated_user = True
|
||||||
|
if config.force_reset:
|
||||||
|
user.password_hash = User.hash_password(config.admin_password)
|
||||||
|
password_rotated = True
|
||||||
|
updated_user = True
|
||||||
|
user_repo.session.flush()
|
||||||
|
|
||||||
|
for role_name in config.admin_roles:
|
||||||
|
role = role_repo.get_by_name(role_name)
|
||||||
|
if role is None:
|
||||||
|
logging.warning("Role '%s' is not defined and will be skipped", role_name)
|
||||||
|
continue
|
||||||
|
already_assigned = any(assignment.role_id == role.id for assignment in user.role_assignments)
|
||||||
|
if already_assigned:
|
||||||
|
continue
|
||||||
|
user_repo.assign_role(user_id=user.id, role_id=role.id, granted_by=user.id)
|
||||||
|
roles_granted += 1
|
||||||
|
|
||||||
|
return AdminSeedResult(
|
||||||
|
created_user=created_user,
|
||||||
|
updated_user=updated_user,
|
||||||
|
password_rotated=password_rotated,
|
||||||
|
roles_granted=roles_granted,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def seed_initial_data(
|
||||||
|
config: SeedConfig,
|
||||||
|
*,
|
||||||
|
unit_of_work_factory: Callable[[], UnitOfWork] | None = None,
|
||||||
|
) -> None:
|
||||||
|
logging.info("Starting initial data seeding")
|
||||||
|
factory = unit_of_work_factory or UnitOfWork
|
||||||
|
with factory() as uow:
|
||||||
|
assert uow.roles is not None and uow.users is not None
|
||||||
|
role_result = ensure_default_roles(uow.roles)
|
||||||
|
admin_result = ensure_admin_user(uow.users, uow.roles, config)
|
||||||
|
logging.info(
|
||||||
|
"Roles processed: %s total, %s created, %s updated",
|
||||||
|
role_result.total,
|
||||||
|
role_result.created,
|
||||||
|
role_result.updated,
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s",
|
||||||
|
admin_result.created_user,
|
||||||
|
admin_result.updated_user,
|
||||||
|
admin_result.password_rotated,
|
||||||
|
admin_result.roles_granted,
|
||||||
|
)
|
||||||
|
logging.info("Initial data seeding completed successfully")
|
||||||
104
services/authorization.py
Normal file
104
services/authorization.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from models import Project, Role, Scenario, User
|
||||||
|
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||||
|
from services.repositories import ProjectRepository, ScenarioRepository
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
READ_ROLES: frozenset[str] = frozenset(
|
||||||
|
{"viewer", "analyst", "project_manager", "admin"}
|
||||||
|
)
|
||||||
|
MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"})
|
||||||
|
|
||||||
|
|
||||||
|
def _user_role_names(user: User) -> set[str]:
|
||||||
|
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||||
|
return {role.name for role in roles}
|
||||||
|
|
||||||
|
|
||||||
|
def _require_project_repo(uow: UnitOfWork) -> ProjectRepository:
|
||||||
|
if not uow.projects:
|
||||||
|
raise RuntimeError("Project repository not initialised")
|
||||||
|
return uow.projects
|
||||||
|
|
||||||
|
|
||||||
|
def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository:
|
||||||
|
if not uow.scenarios:
|
||||||
|
raise RuntimeError("Scenario repository not initialised")
|
||||||
|
return uow.scenarios
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_user_can_access(user: User, *, require_manage: bool) -> None:
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthorizationError("User account is disabled.")
|
||||||
|
if user.is_superuser:
|
||||||
|
return
|
||||||
|
|
||||||
|
allowed = MANAGE_ROLES if require_manage else READ_ROLES
|
||||||
|
if not _user_role_names(user) & allowed:
|
||||||
|
raise AuthorizationError(
|
||||||
|
"Insufficient role permissions for this action.")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_project_access(
|
||||||
|
uow: UnitOfWork,
|
||||||
|
*,
|
||||||
|
project_id: int,
|
||||||
|
user: User,
|
||||||
|
require_manage: bool = False,
|
||||||
|
) -> Project:
|
||||||
|
"""Resolve a project and ensure the user holds the required permissions."""
|
||||||
|
|
||||||
|
repo = _require_project_repo(uow)
|
||||||
|
project = repo.get(project_id)
|
||||||
|
_assert_user_can_access(user, require_manage=require_manage)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_scenario_access(
|
||||||
|
uow: UnitOfWork,
|
||||||
|
*,
|
||||||
|
scenario_id: int,
|
||||||
|
user: User,
|
||||||
|
require_manage: bool = False,
|
||||||
|
with_children: bool = False,
|
||||||
|
) -> Scenario:
|
||||||
|
"""Resolve a scenario and ensure the user holds the required permissions."""
|
||||||
|
|
||||||
|
repo = _require_scenario_repo(uow)
|
||||||
|
scenario = repo.get(scenario_id, with_children=with_children)
|
||||||
|
_assert_user_can_access(user, require_manage=require_manage)
|
||||||
|
return scenario
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_scenario_in_project(
|
||||||
|
uow: UnitOfWork,
|
||||||
|
*,
|
||||||
|
project_id: int,
|
||||||
|
scenario_id: int,
|
||||||
|
user: User,
|
||||||
|
require_manage: bool = False,
|
||||||
|
with_children: bool = False,
|
||||||
|
) -> Scenario:
|
||||||
|
"""Resolve a scenario ensuring it belongs to the project and the user may access it."""
|
||||||
|
|
||||||
|
project = ensure_project_access(
|
||||||
|
uow,
|
||||||
|
project_id=project_id,
|
||||||
|
user=user,
|
||||||
|
require_manage=require_manage,
|
||||||
|
)
|
||||||
|
scenario = ensure_scenario_access(
|
||||||
|
uow,
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
user=user,
|
||||||
|
require_manage=require_manage,
|
||||||
|
with_children=with_children,
|
||||||
|
)
|
||||||
|
if scenario.project_id != project.id:
|
||||||
|
raise EntityNotFoundError(
|
||||||
|
f"Scenario {scenario_id} does not belong to project {project_id}."
|
||||||
|
)
|
||||||
|
return scenario
|
||||||
@@ -12,6 +12,10 @@ class EntityConflictError(Exception):
|
|||||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(Exception):
|
||||||
|
"""Raised when a user lacks permission to perform an action."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class ScenarioValidationError(Exception):
|
class ScenarioValidationError(Exception):
|
||||||
"""Raised when scenarios fail comparison validation rules."""
|
"""Raised when scenarios fail comparison validation rules."""
|
||||||
|
|||||||
@@ -57,6 +57,10 @@ class ProjectRepository:
|
|||||||
raise EntityNotFoundError(f"Project {project_id} not found")
|
raise EntityNotFoundError(f"Project {project_id} not found")
|
||||||
return project
|
return project
|
||||||
|
|
||||||
|
def exists(self, project_id: int) -> bool:
|
||||||
|
stmt = select(Project.id).where(Project.id == project_id).limit(1)
|
||||||
|
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||||
|
|
||||||
def create(self, project: Project) -> Project:
|
def create(self, project: Project) -> Project:
|
||||||
self.session.add(project)
|
self.session.add(project)
|
||||||
try:
|
try:
|
||||||
@@ -133,6 +137,10 @@ class ScenarioRepository:
|
|||||||
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
||||||
return scenario
|
return scenario
|
||||||
|
|
||||||
|
def exists(self, scenario_id: int) -> bool:
|
||||||
|
stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
|
||||||
|
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||||
|
|
||||||
def create(self, scenario: Scenario) -> Scenario:
|
def create(self, scenario: Scenario) -> Scenario:
|
||||||
self.session.add(scenario)
|
self.session.add(scenario)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Callable, Iterator
|
from collections.abc import Callable, Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@@ -11,12 +11,14 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
from config.database import Base
|
from config.database import Base
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import get_auth_session, get_unit_of_work
|
||||||
|
from models import User
|
||||||
from routes.auth import router as auth_router
|
from routes.auth import router as auth_router
|
||||||
from routes.dashboard import router as dashboard_router
|
from routes.dashboard import router as dashboard_router
|
||||||
from routes.projects import router as projects_router
|
from routes.projects import router as projects_router
|
||||||
from routes.scenarios import router as scenarios_router
|
from routes.scenarios import router as scenarios_router
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
from services.session import AuthSession, SessionTokens
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@@ -55,6 +57,28 @@ def app(session_factory: sessionmaker) -> FastAPI:
|
|||||||
yield uow
|
yield uow
|
||||||
|
|
||||||
application.dependency_overrides[get_unit_of_work] = _override_uow
|
application.dependency_overrides[get_unit_of_work] = _override_uow
|
||||||
|
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
assert uow.users is not None
|
||||||
|
uow.ensure_default_roles()
|
||||||
|
user = User(
|
||||||
|
email="test-superuser@example.com",
|
||||||
|
username="test-superuser",
|
||||||
|
password_hash=User.hash_password("test-password"),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=True,
|
||||||
|
)
|
||||||
|
uow.users.create(user)
|
||||||
|
user = uow.users.get(user.id, with_roles=True)
|
||||||
|
|
||||||
|
def _override_auth_session(request: Request) -> AuthSession:
|
||||||
|
session = AuthSession(tokens=SessionTokens(
|
||||||
|
access_token="test", refresh_token="test"))
|
||||||
|
session.user = user
|
||||||
|
request.state.auth_session = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
application.dependency_overrides[get_auth_session] = _override_auth_session
|
||||||
return application
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
156
tests/scripts/test_initial_data_seed.py
Normal file
156
tests/scripts/test_initial_data_seed.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from config.database import Base
|
||||||
|
from scripts import initial_data
|
||||||
|
from scripts.initial_data import AdminSeedResult, RoleSeedResult, SeedConfig
|
||||||
|
from services.repositories import DEFAULT_ROLE_DEFINITIONS
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def in_memory_session_factory() -> Callable[[], Session]:
|
||||||
|
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
|
||||||
|
|
||||||
|
def _session_factory() -> Session:
|
||||||
|
return factory()
|
||||||
|
|
||||||
|
return _session_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def uow(in_memory_session_factory: Callable[[], Session]) -> UnitOfWork:
|
||||||
|
return UnitOfWork(session_factory=in_memory_session_factory)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_default_roles_idempotent(uow: UnitOfWork) -> None:
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
result_first = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||||
|
assert result_first == RoleSeedResult(created=4, updated=0, total=4)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
result_second = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||||
|
assert result_second == RoleSeedResult(created=0, updated=0, total=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_admin_user_creates_and_assigns_roles(uow: UnitOfWork) -> None:
|
||||||
|
config = SeedConfig(
|
||||||
|
admin_email="admin@example.com",
|
||||||
|
admin_username="admin",
|
||||||
|
admin_password="secret",
|
||||||
|
admin_roles=("admin", "viewer"),
|
||||||
|
force_reset=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
assert working.users is not None
|
||||||
|
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||||
|
result = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||||
|
assert result == AdminSeedResult(
|
||||||
|
created_user=True,
|
||||||
|
updated_user=False,
|
||||||
|
password_rotated=False,
|
||||||
|
roles_granted=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
assert working.users is not None
|
||||||
|
result_again = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||||
|
assert result_again == AdminSeedResult(
|
||||||
|
created_user=False,
|
||||||
|
updated_user=False,
|
||||||
|
password_rotated=False,
|
||||||
|
roles_granted=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.users is not None
|
||||||
|
user = working.users.get_by_email("admin@example.com", with_roles=True)
|
||||||
|
assert user is not None
|
||||||
|
assert user.is_active is True
|
||||||
|
assert user.is_superuser is True
|
||||||
|
role_names = {role.name for role in user.roles}
|
||||||
|
assert role_names == {"admin", "viewer"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_admin_user_force_reset_rotates_password(uow: UnitOfWork) -> None:
|
||||||
|
base_config = SeedConfig(
|
||||||
|
admin_email="admin@example.com",
|
||||||
|
admin_username="admin",
|
||||||
|
admin_password="first",
|
||||||
|
admin_roles=("admin",),
|
||||||
|
force_reset=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
assert working.users is not None
|
||||||
|
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||||
|
initial_data.ensure_admin_user(working.users, working.roles, base_config)
|
||||||
|
|
||||||
|
rotate_config = SeedConfig(
|
||||||
|
admin_email="admin@example.com",
|
||||||
|
admin_username="admin",
|
||||||
|
admin_password="second",
|
||||||
|
admin_roles=("admin",),
|
||||||
|
force_reset=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.users is not None
|
||||||
|
user_before = working.users.get_by_email("admin@example.com")
|
||||||
|
assert user_before is not None
|
||||||
|
old_hash = user_before.password_hash
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.roles is not None
|
||||||
|
assert working.users is not None
|
||||||
|
result = initial_data.ensure_admin_user(working.users, working.roles, rotate_config)
|
||||||
|
assert result.password_rotated is True
|
||||||
|
|
||||||
|
with uow as working:
|
||||||
|
assert working.users is not None
|
||||||
|
user_after = working.users.get_by_email("admin@example.com")
|
||||||
|
assert user_after is not None
|
||||||
|
assert user_after.password_hash != old_hash
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_initial_data_logs_results(
|
||||||
|
caplog,
|
||||||
|
in_memory_session_factory: Callable[[], Session],
|
||||||
|
) -> None:
|
||||||
|
caplog.set_level(logging.INFO)
|
||||||
|
config = SeedConfig(
|
||||||
|
admin_email="seed@example.com",
|
||||||
|
admin_username="seed",
|
||||||
|
admin_password="seed-pass",
|
||||||
|
admin_roles=("admin",),
|
||||||
|
force_reset=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_data.seed_initial_data(
|
||||||
|
config,
|
||||||
|
unit_of_work_factory=lambda: UnitOfWork(session_factory=in_memory_session_factory),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Starting initial data seeding" in caplog.text
|
||||||
|
assert "Initial data seeding completed successfully" in caplog.text
|
||||||
|
|
||||||
|
with UnitOfWork(session_factory=in_memory_session_factory) as check_uow:
|
||||||
|
assert check_uow.users is not None
|
||||||
|
assert check_uow.roles is not None
|
||||||
|
user = check_uow.users.get_by_email("seed@example.com")
|
||||||
|
assert user is not None
|
||||||
|
assert check_uow.roles.get_by_name("admin") is not None
|
||||||
165
tests/test_authorization_helpers.py
Normal file
165
tests/test_authorization_helpers.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from models import MiningOperationType, Project, Scenario, ScenarioStatus, User
|
||||||
|
from services.authorization import (
|
||||||
|
ensure_project_access,
|
||||||
|
ensure_scenario_access,
|
||||||
|
ensure_scenario_in_project,
|
||||||
|
)
|
||||||
|
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
|
def _create_user_with_roles(
|
||||||
|
uow: UnitOfWork,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
username: str,
|
||||||
|
roles: set[str],
|
||||||
|
) -> User:
|
||||||
|
assert uow.users is not None
|
||||||
|
assert uow.roles is not None
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
username=username,
|
||||||
|
password_hash=User.hash_password("secret"),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
uow.users.create(user)
|
||||||
|
uow.ensure_default_roles()
|
||||||
|
|
||||||
|
for role_name in roles:
|
||||||
|
role = uow.roles.get_by_name(role_name)
|
||||||
|
assert role is not None, f"Role {role_name} should exist"
|
||||||
|
uow.users.assign_role(user_id=user.id, role_id=role.id)
|
||||||
|
|
||||||
|
return uow.users.get(user.id, with_roles=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_project(uow: UnitOfWork, name: str) -> Project:
|
||||||
|
assert uow.projects is not None
|
||||||
|
project = Project(
|
||||||
|
name=name,
|
||||||
|
location=None,
|
||||||
|
operation_type=MiningOperationType.OTHER,
|
||||||
|
description=None,
|
||||||
|
)
|
||||||
|
uow.projects.create(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
def _create_scenario(uow: UnitOfWork, project: Project, name: str) -> Scenario:
|
||||||
|
assert uow.scenarios is not None
|
||||||
|
scenario = Scenario(
|
||||||
|
project_id=project.id,
|
||||||
|
name=name,
|
||||||
|
description=None,
|
||||||
|
status=ScenarioStatus.DRAFT,
|
||||||
|
)
|
||||||
|
uow.scenarios.create(scenario)
|
||||||
|
return scenario
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_project_access_allows_view_roles(unit_of_work_factory) -> None:
|
||||||
|
with unit_of_work_factory() as uow:
|
||||||
|
project = _create_project(uow, "Project A")
|
||||||
|
user = _create_user_with_roles(
|
||||||
|
uow,
|
||||||
|
email="viewer@example.com",
|
||||||
|
username="viewer",
|
||||||
|
roles={"viewer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = ensure_project_access(
|
||||||
|
uow,
|
||||||
|
project_id=project.id,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
assert resolved.id == project.id
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationError):
|
||||||
|
ensure_project_access(
|
||||||
|
uow,
|
||||||
|
project_id=project.id,
|
||||||
|
user=user,
|
||||||
|
require_manage=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_project_access_allows_manage_roles(unit_of_work_factory) -> None:
|
||||||
|
with unit_of_work_factory() as uow:
|
||||||
|
project = _create_project(uow, "Project B")
|
||||||
|
user = _create_user_with_roles(
|
||||||
|
uow,
|
||||||
|
email="manager@example.com",
|
||||||
|
username="manager",
|
||||||
|
roles={"project_manager"},
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = ensure_project_access(
|
||||||
|
uow,
|
||||||
|
project_id=project.id,
|
||||||
|
user=user,
|
||||||
|
require_manage=True,
|
||||||
|
)
|
||||||
|
assert resolved.id == project.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_scenario_access(unit_of_work_factory) -> None:
|
||||||
|
with unit_of_work_factory() as uow:
|
||||||
|
project = _create_project(uow, "Project C")
|
||||||
|
scenario = _create_scenario(uow, project, "Scenario C1")
|
||||||
|
user = _create_user_with_roles(
|
||||||
|
uow,
|
||||||
|
email="analyst@example.com",
|
||||||
|
username="analyst",
|
||||||
|
roles={"analyst"},
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = ensure_scenario_access(
|
||||||
|
uow,
|
||||||
|
scenario_id=scenario.id,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
assert resolved.id == scenario.id
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationError):
|
||||||
|
ensure_scenario_access(
|
||||||
|
uow,
|
||||||
|
scenario_id=scenario.id,
|
||||||
|
user=user,
|
||||||
|
require_manage=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_scenario_in_project_validates_membership(unit_of_work_factory) -> None:
|
||||||
|
with unit_of_work_factory() as uow:
|
||||||
|
project_one = _create_project(uow, "Project D")
|
||||||
|
project_two = _create_project(uow, "Project E")
|
||||||
|
scenario = _create_scenario(uow, project_one, "Scenario D1")
|
||||||
|
user = _create_user_with_roles(
|
||||||
|
uow,
|
||||||
|
email="manager2@example.com",
|
||||||
|
username="manager2",
|
||||||
|
roles={"project_manager"},
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = ensure_scenario_in_project(
|
||||||
|
uow,
|
||||||
|
project_id=project_one.id,
|
||||||
|
scenario_id=scenario.id,
|
||||||
|
user=user,
|
||||||
|
require_manage=True,
|
||||||
|
)
|
||||||
|
assert resolved.id == scenario.id
|
||||||
|
|
||||||
|
with pytest.raises(EntityNotFoundError):
|
||||||
|
ensure_scenario_in_project(
|
||||||
|
uow,
|
||||||
|
project_id=project_two.id,
|
||||||
|
scenario_id=scenario.id,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
@@ -6,7 +6,7 @@ from typing import cast
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
@@ -15,13 +15,14 @@ from sqlalchemy.engine import Engine
|
|||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
from config.database import Base
|
from config.database import Base
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import get_auth_session, get_unit_of_work
|
||||||
from models import (
|
from models import (
|
||||||
MiningOperationType,
|
MiningOperationType,
|
||||||
Project,
|
Project,
|
||||||
ResourceType,
|
ResourceType,
|
||||||
Scenario,
|
Scenario,
|
||||||
ScenarioStatus,
|
ScenarioStatus,
|
||||||
|
User,
|
||||||
)
|
)
|
||||||
from schemas.scenario import (
|
from schemas.scenario import (
|
||||||
ScenarioComparisonRequest,
|
ScenarioComparisonRequest,
|
||||||
@@ -30,6 +31,7 @@ from schemas.scenario import (
|
|||||||
from services.exceptions import ScenarioValidationError
|
from services.exceptions import ScenarioValidationError
|
||||||
from services.scenario_validation import ScenarioComparisonValidator
|
from services.scenario_validation import ScenarioComparisonValidator
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
from services.session import AuthSession, SessionTokens
|
||||||
from routes.scenarios import router as scenarios_router
|
from routes.scenarios import router as scenarios_router
|
||||||
|
|
||||||
|
|
||||||
@@ -159,6 +161,28 @@ def api_client(session_factory) -> Iterator[TestClient]:
|
|||||||
yield uow
|
yield uow
|
||||||
|
|
||||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||||
|
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
assert uow.users is not None
|
||||||
|
uow.ensure_default_roles()
|
||||||
|
user = User(
|
||||||
|
email="test-scenarios@example.com",
|
||||||
|
username="scenario-tester",
|
||||||
|
password_hash=User.hash_password("password"),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=True,
|
||||||
|
)
|
||||||
|
uow.users.create(user)
|
||||||
|
user = uow.users.get(user.id, with_roles=True)
|
||||||
|
|
||||||
|
def _override_auth_session(request: Request) -> AuthSession:
|
||||||
|
session = AuthSession(tokens=SessionTokens(
|
||||||
|
access_token="test", refresh_token="test"))
|
||||||
|
session.user = user
|
||||||
|
request.state.auth_session = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_auth_session] = _override_auth_session
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
@@ -171,6 +195,8 @@ def _create_project_with_scenarios(
|
|||||||
scenario_overrides: list[dict[str, object]],
|
scenario_overrides: list[dict[str, object]],
|
||||||
) -> tuple[int, list[int]]:
|
) -> tuple[int, list[int]]:
|
||||||
with UnitOfWork(session_factory=session_factory) as uow:
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
assert uow.projects is not None
|
||||||
|
assert uow.scenarios is not None
|
||||||
project_name = f"Project {uuid4()}"
|
project_name = f"Project {uuid4()}"
|
||||||
project = Project(name=project_name,
|
project = Project(name=project_name,
|
||||||
operation_type=MiningOperationType.OPEN_PIT)
|
operation_type=MiningOperationType.OPEN_PIT)
|
||||||
|
|||||||
Reference in New Issue
Block a user