diff --git a/.env.example b/.env.example index 7393cde..7329fa2 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,3 @@ DATABASE_PASSWORD= DATABASE_NAME=calminer # Optional: set a schema (comma-separated for multiple entries) # DATABASE_SCHEMA=public - -# Legacy fallback (still supported, but granular settings are preferred) -# DATABASE_URL=postgresql://:@localhost:5432/calminer \ No newline at end of file diff --git a/changelog.md b/changelog.md index 7606c0d..ed4c7fc 100644 --- a/changelog.md +++ b/changelog.md @@ -18,3 +18,10 @@ - Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. - Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios. - Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows. +- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour. +- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning. +- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns. + +## 2025-11-10 + +- Extended authorization helper layer with project/scenario ownership lookups, integrated them into FastAPI dependencies, refreshed pytest fixtures to keep the suite authenticated, and documented the new patterns across RBAC plan and security guides. diff --git a/dependencies.py b/dependencies.py index a043f55..b91c05d 100644 --- a/dependencies.py +++ b/dependencies.py @@ -1,11 +1,17 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import Callable, Iterable, Generator from fastapi import Depends, HTTPException, Request, status from config.settings import Settings, get_settings -from models import User +from models import Project, Role, Scenario, User +from services.authorization import ( + ensure_project_access as ensure_project_access_helper, + ensure_scenario_access as ensure_scenario_access_helper, + ensure_scenario_in_project as ensure_scenario_in_project_helper, +) +from services.exceptions import AuthorizationError, EntityNotFoundError from services.security import JWTSettings from services.session import ( AuthSession, @@ -90,9 +96,150 @@ def require_current_user( ) -> User: """Ensure that a request is authenticated and return the user context.""" - if session.user is None: + if session.user is None or session.tokens.is_empty: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required.", ) return session.user + + +def require_authenticated_user( + user: User = Depends(require_current_user), +) -> User: + """Ensure the current user account is active.""" + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account is disabled.", + ) + return user + + +def _user_role_names(user: User) -> set[str]: + roles: Iterable[Role] = getattr(user, "roles", []) or [] + return {role.name for role in roles} + + +def require_roles(*roles: str) -> Callable[[User], User]: + """Dependency factory enforcing membership in one of the given roles.""" + + required = tuple(role.strip() for role in roles if role.strip()) + if not required: + raise ValueError("require_roles requires at least one role name") + + def _dependency(user: User = Depends(require_authenticated_user)) -> User: + if user.is_superuser: + return user + + role_names = _user_role_names(user) + if not any(role in role_names for role in required): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions for this action.", + ) + return user + + return _dependency + + +def require_any_role(*roles: str) -> Callable[[User], User]: + """Alias of require_roles for readability in some contexts.""" + + return require_roles(*roles) + + +def require_project_resource(*, require_manage: bool = False) -> Callable[[int], Project]: + """Dependency factory that resolves a project with authorization checks.""" + + def _dependency( + project_id: int, + user: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), + ) -> Project: + try: + return ensure_project_access_helper( + uow, + project_id=project_id, + user=user, + require_manage=require_manage, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency + + +def require_scenario_resource( + *, require_manage: bool = False, with_children: bool = False +) -> Callable[[int], Scenario]: + """Dependency factory that resolves a scenario with authorization checks.""" + + def _dependency( + scenario_id: int, + user: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), + ) -> Scenario: + try: + return ensure_scenario_access_helper( + uow, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency + + +def require_project_scenario_resource( + *, require_manage: bool = False, with_children: bool = False +) -> Callable[[int, int], Scenario]: + """Dependency factory ensuring a scenario belongs to the given project and is accessible.""" + + def _dependency( + project_id: int, + scenario_id: int, + user: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), + ) -> Scenario: + try: + return ensure_scenario_in_project_helper( + uow, + project_id=project_id, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency diff --git a/routes/dashboard.py b/routes/dashboard.py index a8f32cf..dad24b6 100644 --- a/routes/dashboard.py +++ b/routes/dashboard.py @@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates -from dependencies import get_unit_of_work +from dependencies import get_unit_of_work, require_authenticated_user +from models import User from models import ScenarioStatus from services.unit_of_work import UnitOfWork @@ -27,6 +28,8 @@ def _format_timestamp_with_time(moment: datetime | None) -> str | None: def _load_metrics(uow: UnitOfWork) -> dict[str, object]: + if not uow.projects or not uow.scenarios or not uow.financial_inputs: + raise RuntimeError("UnitOfWork repositories not initialised") total_projects = uow.projects.count() active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE) pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT) @@ -40,11 +43,15 @@ def _load_metrics(uow: UnitOfWork) -> dict[str, object]: def _load_recent_projects(uow: UnitOfWork) -> list: + if not uow.projects: + raise RuntimeError("Project repository not initialised") return list(uow.projects.recent(limit=5)) def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]: updates: list[dict[str, object]] = [] + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") scenarios = uow.scenarios.recent(limit=5, with_project=True) for scenario in scenarios: project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}" @@ -65,6 +72,9 @@ def _load_scenario_alerts( ) -> list[dict[str, object]]: alerts: list[dict[str, object]] = [] + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + drafts = uow.scenarios.list_by_status( ScenarioStatus.DRAFT, limit=3, with_project=True ) @@ -102,6 +112,7 @@ def _load_scenario_alerts( @router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home") def dashboard_home( request: Request, + _: User = Depends(require_authenticated_user), uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: context = { diff --git a/routes/projects.py b/routes/projects.py index a449e02..f105881 100644 --- a/routes/projects.py +++ b/routes/projects.py @@ -6,8 +6,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from dependencies import get_unit_of_work -from models import MiningOperationType, Project, ScenarioStatus +from dependencies import ( + get_unit_of_work, + require_any_role, + require_project_resource, + require_roles, +) +from models import MiningOperationType, Project, ScenarioStatus, User from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate from services.exceptions import EntityConflictError, EntityNotFoundError from services.unit_of_work import UnitOfWork @@ -15,11 +20,20 @@ from services.unit_of_work import UnitOfWork router = APIRouter(prefix="/projects", tags=["Projects"]) templates = Jinja2Templates(directory="templates") +READ_ROLES = ("viewer", "analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") + def _to_read_model(project: Project) -> ProjectRead: return ProjectRead.model_validate(project) +def _require_project_repo(uow: UnitOfWork): + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + def _operation_type_choices() -> list[tuple[str, str]]: return [ (op.value, op.name.replace("_", " ").title()) for op in MiningOperationType @@ -27,18 +41,23 @@ def _operation_type_choices() -> list[tuple[str, str]]: @router.get("", response_model=List[ProjectRead]) -def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]: - projects = uow.projects.list() +def list_projects( + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> List[ProjectRead]: + projects = _require_project_repo(uow).list() return [_to_read_model(project) for project in projects] @router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED) def create_project( - payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work) + payload: ProjectCreate, + _: User = Depends(require_roles(*MANAGE_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> ProjectRead: project = Project(**payload.model_dump()) try: - created = uow.projects.create(project) + created = _require_project_repo(uow).create(project) except EntityConflictError as exc: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=str(exc) @@ -53,9 +72,11 @@ def create_project( name="projects.project_list_page", ) def project_list_page( - request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + request: Request, + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: - projects = uow.projects.list(with_children=True) + projects = _require_project_repo(uow).list(with_children=True) for project in projects: setattr(project, "scenario_count", len(project.scenarios)) return templates.TemplateResponse( @@ -73,7 +94,9 @@ def project_list_page( include_in_schema=False, name="projects.create_project_form", ) -def create_project_form(request: Request) -> HTMLResponse: +def create_project_form( + request: Request, _: User = Depends(require_roles(*MANAGE_ROLES)) +) -> HTMLResponse: return templates.TemplateResponse( request, "projects/form.html", @@ -93,6 +116,7 @@ def create_project_form(request: Request) -> HTMLResponse: ) def create_project_submit( request: Request, + _: User = Depends(require_roles(*MANAGE_ROLES)), name: str = Form(...), location: str | None = Form(None), operation_type: str = Form(...), @@ -128,7 +152,7 @@ def create_project_submit( description=_normalise(description), ) try: - uow.projects.create(project) + _require_project_repo(uow).create(project) except EntityConflictError as exc: return templates.TemplateResponse( request, @@ -150,29 +174,18 @@ def create_project_submit( @router.get("/{project_id}", response_model=ProjectRead) -def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead: - try: - project = uow.projects.get(project_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc +def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead: return _to_read_model(project) @router.put("/{project_id}", response_model=ProjectRead) def update_project( - project_id: int, payload: ProjectUpdate, + project: Project = Depends( + require_project_resource(require_manage=True) + ), uow: UnitOfWork = Depends(get_unit_of_work), ) -> ProjectRead: - try: - project = uow.projects.get(project_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - update_data = payload.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(project, field, value) @@ -182,13 +195,11 @@ def update_project( @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None: - try: - uow.projects.delete(project_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc +def delete_project( + project: Project = Depends(require_project_resource(require_manage=True)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> None: + _require_project_repo(uow).delete(project.id) @router.get( @@ -198,14 +209,11 @@ def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) name="projects.view_project", ) def view_project( - project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + request: Request, + project: Project = Depends(require_project_resource()), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: - try: - project = uow.projects.get(project_id, with_children=True) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc + project = _require_project_repo(uow).get(project.id, with_children=True) scenarios = sorted(project.scenarios, key=lambda s: s.created_at) scenario_stats = { @@ -236,15 +244,11 @@ def view_project( name="projects.edit_project_form", ) def edit_project_form( - project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + request: Request, + project: Project = Depends( + require_project_resource(require_manage=True) + ), ) -> HTMLResponse: - try: - project = uow.projects.get(project_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - return templates.TemplateResponse( request, "projects/form.html", @@ -252,10 +256,10 @@ def edit_project_form( "project": project, "operation_types": _operation_type_choices(), "form_action": request.url_for( - "projects.edit_project_submit", project_id=project_id + "projects.edit_project_submit", project_id=project.id ), "cancel_url": request.url_for( - "projects.view_project", project_id=project_id + "projects.view_project", project_id=project.id ), }, ) @@ -267,21 +271,16 @@ def edit_project_form( name="projects.edit_project_submit", ) def edit_project_submit( - project_id: int, request: Request, + project: Project = Depends( + require_project_resource(require_manage=True) + ), name: str = Form(...), location: str | None = Form(None), operation_type: str | None = Form(None), description: str | None = Form(None), uow: UnitOfWork = Depends(get_unit_of_work), ): - try: - project = uow.projects.get(project_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - def _normalise(value: str | None) -> str | None: if value is None: return None @@ -301,10 +300,10 @@ def edit_project_submit( "project": project, "operation_types": _operation_type_choices(), "form_action": request.url_for( - "projects.edit_project_submit", project_id=project_id + "projects.edit_project_submit", project_id=project.id ), "cancel_url": request.url_for( - "projects.view_project", project_id=project_id + "projects.view_project", project_id=project.id ), "error": "Invalid operation type.", }, @@ -315,6 +314,6 @@ def edit_project_submit( uow.flush() return RedirectResponse( - request.url_for("projects.view_project", project_id=project_id), + request.url_for("projects.view_project", project_id=project.id), status_code=status.HTTP_303_SEE_OTHER, ) diff --git a/routes/scenarios.py b/routes/scenarios.py index aa4803d..89d5060 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -7,8 +7,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from dependencies import get_unit_of_work -from models import ResourceType, Scenario, ScenarioStatus +from dependencies import ( + get_unit_of_work, + require_any_role, + require_roles, + require_scenario_resource, +) +from models import ResourceType, Scenario, ScenarioStatus, User from schemas.scenario import ( ScenarioComparisonRequest, ScenarioComparisonResponse, @@ -26,6 +31,9 @@ from services.unit_of_work import UnitOfWork router = APIRouter(tags=["Scenarios"]) templates = Jinja2Templates(directory="templates") +READ_ROLES = ("viewer", "analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") + def _to_read_model(scenario: Scenario) -> ScenarioRead: return ScenarioRead.model_validate(scenario) @@ -44,20 +52,36 @@ def _scenario_status_choices() -> list[tuple[str, str]]: ] +def _require_project_repo(uow: UnitOfWork): + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + +def _require_scenario_repo(uow: UnitOfWork): + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + return uow.scenarios + + @router.get( "/projects/{project_id}/scenarios", response_model=List[ScenarioRead], ) def list_scenarios_for_project( - project_id: int, uow: UnitOfWork = Depends(get_unit_of_work) + project_id: int, + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> List[ScenarioRead]: + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) try: - uow.projects.get(project_id) + project_repo.get(project_id) except EntityNotFoundError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc - scenarios = uow.scenarios.list_for_project(project_id) + scenarios = scenario_repo.list_for_project(project_id) return [_to_read_model(scenario) for scenario in scenarios] @@ -69,10 +93,11 @@ def list_scenarios_for_project( def compare_scenarios( project_id: int, payload: ScenarioComparisonRequest, + _: User = Depends(require_any_role(*READ_ROLES)), uow: UnitOfWork = Depends(get_unit_of_work), ) -> ScenarioComparisonResponse: try: - uow.projects.get(project_id) + _require_project_repo(uow).get(project_id) except EntityNotFoundError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) @@ -116,10 +141,13 @@ def compare_scenarios( def create_scenario_for_project( project_id: int, payload: ScenarioCreate, + _: User = Depends(require_roles(*MANAGE_ROLES)), uow: UnitOfWork = Depends(get_unit_of_work), ) -> ScenarioRead: + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) try: - uow.projects.get(project_id) + project_repo.get(project_id) except EntityNotFoundError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc @@ -127,7 +155,7 @@ def create_scenario_for_project( scenario = Scenario(project_id=project_id, **payload.model_dump()) try: - created = uow.scenarios.create(scenario) + created = scenario_repo.create(scenario) except EntityConflictError as exc: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc @@ -136,28 +164,19 @@ def create_scenario_for_project( @router.get("/scenarios/{scenario_id}", response_model=ScenarioRead) def get_scenario( - scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) + scenario: Scenario = Depends(require_scenario_resource()), ) -> ScenarioRead: - try: - scenario = uow.scenarios.get(scenario_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc return _to_read_model(scenario) @router.put("/scenarios/{scenario_id}", response_model=ScenarioRead) def update_scenario( - scenario_id: int, payload: ScenarioUpdate, + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), uow: UnitOfWork = Depends(get_unit_of_work), ) -> ScenarioRead: - try: - scenario = uow.scenarios.get(scenario_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc - update_data = payload.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(scenario, field, value) @@ -168,13 +187,12 @@ def update_scenario( @router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_scenario( - scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> None: - try: - uow.scenarios.delete(scenario_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + _require_scenario_repo(uow).delete(scenario.id) def _normalise(value: str | None) -> str | None: @@ -208,10 +226,13 @@ def _parse_discount_rate(value: str | None) -> float | None: name="scenarios.create_scenario_form", ) def create_scenario_form( - project_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + project_id: int, + request: Request, + _: User = Depends(require_roles(*MANAGE_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: try: - project = uow.projects.get(project_id) + project = _require_project_repo(uow).get(project_id) except EntityNotFoundError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) @@ -243,6 +264,7 @@ def create_scenario_form( def create_scenario_submit( project_id: int, request: Request, + _: User = Depends(require_roles(*MANAGE_ROLES)), name: str = Form(...), description: str | None = Form(None), status_value: str = Form(ScenarioStatus.DRAFT.value), @@ -253,8 +275,10 @@ def create_scenario_submit( primary_resource: str | None = Form(None), uow: UnitOfWork = Depends(get_unit_of_work), ): + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) try: - project = uow.projects.get(project_id) + project = project_repo.get(project_id) except EntityNotFoundError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) @@ -288,7 +312,7 @@ def create_scenario_submit( ) try: - uow.scenarios.create(scenario) + scenario_repo.create(scenario) except EntityConflictError as exc: return templates.TemplateResponse( request, @@ -322,16 +346,13 @@ def create_scenario_submit( name="scenarios.view_scenario", ) def view_scenario( - scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + request: Request, + scenario: Scenario = Depends( + require_scenario_resource(with_children=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: - try: - scenario = uow.scenarios.get(scenario_id, with_children=True) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - - project = uow.projects.get(scenario.project_id) + project = _require_project_repo(uow).get(scenario.project_id) financial_inputs = sorted( scenario.financial_inputs, key=lambda item: item.created_at ) @@ -366,16 +387,13 @@ def view_scenario( name="scenarios.edit_scenario_form", ) def edit_scenario_form( - scenario_id: int, request: Request, uow: UnitOfWork = Depends(get_unit_of_work) + request: Request, + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), ) -> HTMLResponse: - try: - scenario = uow.scenarios.get(scenario_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - - project = uow.projects.get(scenario.project_id) + project = _require_project_repo(uow).get(scenario.project_id) return templates.TemplateResponse( request, @@ -386,10 +404,10 @@ def edit_scenario_form( "scenario_statuses": _scenario_status_choices(), "resource_types": _resource_type_choices(), "form_action": request.url_for( - "scenarios.edit_scenario_submit", scenario_id=scenario_id + "scenarios.edit_scenario_submit", scenario_id=scenario.id ), "cancel_url": request.url_for( - "scenarios.view_scenario", scenario_id=scenario_id + "scenarios.view_scenario", scenario_id=scenario.id ), }, ) @@ -401,8 +419,10 @@ def edit_scenario_form( name="scenarios.edit_scenario_submit", ) def edit_scenario_submit( - scenario_id: int, request: Request, + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), name: str = Form(...), description: str | None = Form(None), status_value: str = Form(ScenarioStatus.DRAFT.value), @@ -413,14 +433,7 @@ def edit_scenario_submit( primary_resource: str | None = Form(None), uow: UnitOfWork = Depends(get_unit_of_work), ): - try: - scenario = uow.scenarios.get(scenario_id) - except EntityNotFoundError as exc: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - - project = uow.projects.get(scenario.project_id) + project = _require_project_repo(uow).get(scenario.project_id) scenario.name = name.strip() scenario.description = _normalise(description) @@ -447,6 +460,6 @@ def edit_scenario_submit( uow.flush() return RedirectResponse( - request.url_for("scenarios.view_scenario", scenario_id=scenario_id), + request.url_for("scenarios.view_scenario", scenario_id=scenario.id), status_code=status.HTTP_303_SEE_OTHER, ) diff --git a/scripts/00_initial_data.py b/scripts/00_initial_data.py new file mode 100644 index 0000000..e189001 --- /dev/null +++ b/scripts/00_initial_data.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import logging + +from scripts.initial_data import load_config, seed_initial_data + + +def main() -> int: + logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + try: + config = load_config() + seed_initial_data(config) + except Exception as exc: # pragma: no cover - operational guard + logging.exception("Seeding failed: %s", exc) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..395066d --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1 @@ +"""Utility scripts for CalMiner maintenance tasks.""" diff --git a/scripts/initial_data.py b/scripts/initial_data.py new file mode 100644 index 0000000..bd877b7 --- /dev/null +++ b/scripts/initial_data.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Callable, Iterable + +from dotenv import load_dotenv + +from models import Role, User +from services.repositories import DEFAULT_ROLE_DEFINITIONS, RoleRepository, UserRepository +from services.unit_of_work import UnitOfWork + + +@dataclass +class SeedConfig: + admin_email: str + admin_username: str + admin_password: str + admin_roles: tuple[str, ...] + force_reset: bool + + +@dataclass +class RoleSeedResult: + created: int + updated: int + total: int + + +@dataclass +class AdminSeedResult: + created_user: bool + updated_user: bool + password_rotated: bool + roles_granted: int + + +def parse_bool(value: str | None) -> bool: + if value is None: + return False + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def normalise_role_list(raw_value: str | None) -> tuple[str, ...]: + if not raw_value: + return ("admin",) + parts = [segment.strip() for segment in raw_value.split(",") if segment.strip()] + if "admin" not in parts: + parts.insert(0, "admin") + seen: set[str] = set() + ordered: list[str] = [] + for role_name in parts: + if role_name not in seen: + ordered.append(role_name) + seen.add(role_name) + return tuple(ordered) + + +def load_config() -> SeedConfig: + load_dotenv() + admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local") + admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin") + admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!") + admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES")) + force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE")) + return SeedConfig( + admin_email=admin_email, + admin_username=admin_username, + admin_password=admin_password, + admin_roles=admin_roles, + force_reset=force_reset, + ) + + +def ensure_default_roles( + role_repo: RoleRepository, + definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS, +) -> RoleSeedResult: + created = 0 + updated = 0 + total = 0 + for definition in definitions: + total += 1 + existing = role_repo.get_by_name(definition["name"]) + if existing is None: + role_repo.create(Role(**definition)) + created += 1 + continue + changed = False + if existing.display_name != definition["display_name"]: + existing.display_name = definition["display_name"] + changed = True + if existing.description != definition["description"]: + existing.description = definition["description"] + changed = True + if changed: + updated += 1 + role_repo.session.flush() + return RoleSeedResult(created=created, updated=updated, total=total) + + +def ensure_admin_user( + user_repo: UserRepository, + role_repo: RoleRepository, + config: SeedConfig, +) -> AdminSeedResult: + created_user = False + updated_user = False + password_rotated = False + roles_granted = 0 + + user = user_repo.get_by_email(config.admin_email, with_roles=True) + if user is None: + user = User( + email=config.admin_email, + username=config.admin_username, + password_hash=User.hash_password(config.admin_password), + is_active=True, + is_superuser=True, + ) + user_repo.create(user) + created_user = True + else: + if user.username != config.admin_username: + user.username = config.admin_username + updated_user = True + if not user.is_active: + user.is_active = True + updated_user = True + if not user.is_superuser: + user.is_superuser = True + updated_user = True + if config.force_reset: + user.password_hash = User.hash_password(config.admin_password) + password_rotated = True + updated_user = True + user_repo.session.flush() + + for role_name in config.admin_roles: + role = role_repo.get_by_name(role_name) + if role is None: + logging.warning("Role '%s' is not defined and will be skipped", role_name) + continue + already_assigned = any(assignment.role_id == role.id for assignment in user.role_assignments) + if already_assigned: + continue + user_repo.assign_role(user_id=user.id, role_id=role.id, granted_by=user.id) + roles_granted += 1 + + return AdminSeedResult( + created_user=created_user, + updated_user=updated_user, + password_rotated=password_rotated, + roles_granted=roles_granted, + ) + + +def seed_initial_data( + config: SeedConfig, + *, + unit_of_work_factory: Callable[[], UnitOfWork] | None = None, +) -> None: + logging.info("Starting initial data seeding") + factory = unit_of_work_factory or UnitOfWork + with factory() as uow: + assert uow.roles is not None and uow.users is not None + role_result = ensure_default_roles(uow.roles) + admin_result = ensure_admin_user(uow.users, uow.roles, config) + logging.info( + "Roles processed: %s total, %s created, %s updated", + role_result.total, + role_result.created, + role_result.updated, + ) + logging.info( + "Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s", + admin_result.created_user, + admin_result.updated_user, + admin_result.password_rotated, + admin_result.roles_granted, + ) + logging.info("Initial data seeding completed successfully") \ No newline at end of file diff --git a/services/authorization.py b/services/authorization.py new file mode 100644 index 0000000..3a19a39 --- /dev/null +++ b/services/authorization.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import Iterable + +from models import Project, Role, Scenario, User +from services.exceptions import AuthorizationError, EntityNotFoundError +from services.repositories import ProjectRepository, ScenarioRepository +from services.unit_of_work import UnitOfWork + +READ_ROLES: frozenset[str] = frozenset( + {"viewer", "analyst", "project_manager", "admin"} +) +MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"}) + + +def _user_role_names(user: User) -> set[str]: + roles: Iterable[Role] = getattr(user, "roles", []) or [] + return {role.name for role in roles} + + +def _require_project_repo(uow: UnitOfWork) -> ProjectRepository: + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + +def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository: + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + return uow.scenarios + + +def _assert_user_can_access(user: User, *, require_manage: bool) -> None: + if not user.is_active: + raise AuthorizationError("User account is disabled.") + if user.is_superuser: + return + + allowed = MANAGE_ROLES if require_manage else READ_ROLES + if not _user_role_names(user) & allowed: + raise AuthorizationError( + "Insufficient role permissions for this action.") + + +def ensure_project_access( + uow: UnitOfWork, + *, + project_id: int, + user: User, + require_manage: bool = False, +) -> Project: + """Resolve a project and ensure the user holds the required permissions.""" + + repo = _require_project_repo(uow) + project = repo.get(project_id) + _assert_user_can_access(user, require_manage=require_manage) + return project + + +def ensure_scenario_access( + uow: UnitOfWork, + *, + scenario_id: int, + user: User, + require_manage: bool = False, + with_children: bool = False, +) -> Scenario: + """Resolve a scenario and ensure the user holds the required permissions.""" + + repo = _require_scenario_repo(uow) + scenario = repo.get(scenario_id, with_children=with_children) + _assert_user_can_access(user, require_manage=require_manage) + return scenario + + +def ensure_scenario_in_project( + uow: UnitOfWork, + *, + project_id: int, + scenario_id: int, + user: User, + require_manage: bool = False, + with_children: bool = False, +) -> Scenario: + """Resolve a scenario ensuring it belongs to the project and the user may access it.""" + + project = ensure_project_access( + uow, + project_id=project_id, + user=user, + require_manage=require_manage, + ) + scenario = ensure_scenario_access( + uow, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + if scenario.project_id != project.id: + raise EntityNotFoundError( + f"Scenario {scenario_id} does not belong to project {project_id}." + ) + return scenario diff --git a/services/exceptions.py b/services/exceptions.py index bfd1ce4..786fd43 100644 --- a/services/exceptions.py +++ b/services/exceptions.py @@ -12,6 +12,10 @@ class EntityConflictError(Exception): """Raised when attempting to create or update an entity that violates uniqueness.""" +class AuthorizationError(Exception): + """Raised when a user lacks permission to perform an action.""" + + @dataclass(eq=False) class ScenarioValidationError(Exception): """Raised when scenarios fail comparison validation rules.""" diff --git a/services/repositories.py b/services/repositories.py index 1f82691..170e33e 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -57,6 +57,10 @@ class ProjectRepository: raise EntityNotFoundError(f"Project {project_id} not found") return project + def exists(self, project_id: int) -> bool: + stmt = select(Project.id).where(Project.id == project_id).limit(1) + return self.session.execute(stmt).scalar_one_or_none() is not None + def create(self, project: Project) -> Project: self.session.add(project) try: @@ -133,6 +137,10 @@ class ScenarioRepository: raise EntityNotFoundError(f"Scenario {scenario_id} not found") return scenario + def exists(self, scenario_id: int) -> bool: + stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1) + return self.session.execute(stmt).scalar_one_or_none() is not None + def create(self, scenario: Scenario) -> Scenario: self.session.add(scenario) try: diff --git a/tests/conftest.py b/tests/conftest.py index 6ec51e5..10f8135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Iterator import pytest -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.engine import Engine @@ -11,12 +11,14 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from config.database import Base -from dependencies import get_unit_of_work +from dependencies import get_auth_session, get_unit_of_work +from models import User from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router from services.unit_of_work import UnitOfWork +from services.session import AuthSession, SessionTokens @pytest.fixture() @@ -55,6 +57,28 @@ def app(session_factory: sessionmaker) -> FastAPI: yield uow application.dependency_overrides[get_unit_of_work] = _override_uow + + with UnitOfWork(session_factory=session_factory) as uow: + assert uow.users is not None + uow.ensure_default_roles() + user = User( + email="test-superuser@example.com", + username="test-superuser", + password_hash=User.hash_password("test-password"), + is_active=True, + is_superuser=True, + ) + uow.users.create(user) + user = uow.users.get(user.id, with_roles=True) + + def _override_auth_session(request: Request) -> AuthSession: + session = AuthSession(tokens=SessionTokens( + access_token="test", refresh_token="test")) + session.user = user + request.state.auth_session = session + return session + + application.dependency_overrides[get_auth_session] = _override_auth_session return application diff --git a/tests/scripts/test_initial_data_seed.py b/tests/scripts/test_initial_data_seed.py new file mode 100644 index 0000000..ec66ea6 --- /dev/null +++ b/tests/scripts/test_initial_data_seed.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from config.database import Base +from scripts import initial_data +from scripts.initial_data import AdminSeedResult, RoleSeedResult, SeedConfig +from services.repositories import DEFAULT_ROLE_DEFINITIONS +from services.unit_of_work import UnitOfWork + + +@pytest.fixture() +def in_memory_session_factory() -> Callable[[], Session]: + engine = create_engine("sqlite+pysqlite:///:memory:", future=True) + Base.metadata.create_all(engine) + factory = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True) + + def _session_factory() -> Session: + return factory() + + return _session_factory + + +@pytest.fixture() +def uow(in_memory_session_factory: Callable[[], Session]) -> UnitOfWork: + return UnitOfWork(session_factory=in_memory_session_factory) + + +def test_ensure_default_roles_idempotent(uow: UnitOfWork) -> None: + with uow as working: + assert working.roles is not None + result_first = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS) + assert result_first == RoleSeedResult(created=4, updated=0, total=4) + + with uow as working: + assert working.roles is not None + result_second = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS) + assert result_second == RoleSeedResult(created=0, updated=0, total=4) + + +def test_ensure_admin_user_creates_and_assigns_roles(uow: UnitOfWork) -> None: + config = SeedConfig( + admin_email="admin@example.com", + admin_username="admin", + admin_password="secret", + admin_roles=("admin", "viewer"), + force_reset=False, + ) + + with uow as working: + assert working.roles is not None + assert working.users is not None + initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS) + result = initial_data.ensure_admin_user(working.users, working.roles, config) + assert result == AdminSeedResult( + created_user=True, + updated_user=False, + password_rotated=False, + roles_granted=2, + ) + + with uow as working: + assert working.roles is not None + assert working.users is not None + result_again = initial_data.ensure_admin_user(working.users, working.roles, config) + assert result_again == AdminSeedResult( + created_user=False, + updated_user=False, + password_rotated=False, + roles_granted=0, + ) + + with uow as working: + assert working.users is not None + user = working.users.get_by_email("admin@example.com", with_roles=True) + assert user is not None + assert user.is_active is True + assert user.is_superuser is True + role_names = {role.name for role in user.roles} + assert role_names == {"admin", "viewer"} + + +def test_ensure_admin_user_force_reset_rotates_password(uow: UnitOfWork) -> None: + base_config = SeedConfig( + admin_email="admin@example.com", + admin_username="admin", + admin_password="first", + admin_roles=("admin",), + force_reset=False, + ) + + with uow as working: + assert working.roles is not None + assert working.users is not None + initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS) + initial_data.ensure_admin_user(working.users, working.roles, base_config) + + rotate_config = SeedConfig( + admin_email="admin@example.com", + admin_username="admin", + admin_password="second", + admin_roles=("admin",), + force_reset=True, + ) + + with uow as working: + assert working.users is not None + user_before = working.users.get_by_email("admin@example.com") + assert user_before is not None + old_hash = user_before.password_hash + + with uow as working: + assert working.roles is not None + assert working.users is not None + result = initial_data.ensure_admin_user(working.users, working.roles, rotate_config) + assert result.password_rotated is True + + with uow as working: + assert working.users is not None + user_after = working.users.get_by_email("admin@example.com") + assert user_after is not None + assert user_after.password_hash != old_hash + + +def test_seed_initial_data_logs_results( + caplog, + in_memory_session_factory: Callable[[], Session], +) -> None: + caplog.set_level(logging.INFO) + config = SeedConfig( + admin_email="seed@example.com", + admin_username="seed", + admin_password="seed-pass", + admin_roles=("admin",), + force_reset=False, + ) + + initial_data.seed_initial_data( + config, + unit_of_work_factory=lambda: UnitOfWork(session_factory=in_memory_session_factory), + ) + + assert "Starting initial data seeding" in caplog.text + assert "Initial data seeding completed successfully" in caplog.text + + with UnitOfWork(session_factory=in_memory_session_factory) as check_uow: + assert check_uow.users is not None + assert check_uow.roles is not None + user = check_uow.users.get_by_email("seed@example.com") + assert user is not None + assert check_uow.roles.get_by_name("admin") is not None diff --git a/tests/test_authorization_helpers.py b/tests/test_authorization_helpers.py new file mode 100644 index 0000000..a408578 --- /dev/null +++ b/tests/test_authorization_helpers.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import pytest + +from models import MiningOperationType, Project, Scenario, ScenarioStatus, User +from services.authorization import ( + ensure_project_access, + ensure_scenario_access, + ensure_scenario_in_project, +) +from services.exceptions import AuthorizationError, EntityNotFoundError +from services.unit_of_work import UnitOfWork + + +def _create_user_with_roles( + uow: UnitOfWork, + *, + email: str, + username: str, + roles: set[str], +) -> User: + assert uow.users is not None + assert uow.roles is not None + + user = User( + email=email, + username=username, + password_hash=User.hash_password("secret"), + is_active=True, + ) + uow.users.create(user) + uow.ensure_default_roles() + + for role_name in roles: + role = uow.roles.get_by_name(role_name) + assert role is not None, f"Role {role_name} should exist" + uow.users.assign_role(user_id=user.id, role_id=role.id) + + return uow.users.get(user.id, with_roles=True) + + +def _create_project(uow: UnitOfWork, name: str) -> Project: + assert uow.projects is not None + project = Project( + name=name, + location=None, + operation_type=MiningOperationType.OTHER, + description=None, + ) + uow.projects.create(project) + return project + + +def _create_scenario(uow: UnitOfWork, project: Project, name: str) -> Scenario: + assert uow.scenarios is not None + scenario = Scenario( + project_id=project.id, + name=name, + description=None, + status=ScenarioStatus.DRAFT, + ) + uow.scenarios.create(scenario) + return scenario + + +def test_ensure_project_access_allows_view_roles(unit_of_work_factory) -> None: + with unit_of_work_factory() as uow: + project = _create_project(uow, "Project A") + user = _create_user_with_roles( + uow, + email="viewer@example.com", + username="viewer", + roles={"viewer"}, + ) + + resolved = ensure_project_access( + uow, + project_id=project.id, + user=user, + ) + assert resolved.id == project.id + + with pytest.raises(AuthorizationError): + ensure_project_access( + uow, + project_id=project.id, + user=user, + require_manage=True, + ) + + +def test_ensure_project_access_allows_manage_roles(unit_of_work_factory) -> None: + with unit_of_work_factory() as uow: + project = _create_project(uow, "Project B") + user = _create_user_with_roles( + uow, + email="manager@example.com", + username="manager", + roles={"project_manager"}, + ) + + resolved = ensure_project_access( + uow, + project_id=project.id, + user=user, + require_manage=True, + ) + assert resolved.id == project.id + + +def test_ensure_scenario_access(unit_of_work_factory) -> None: + with unit_of_work_factory() as uow: + project = _create_project(uow, "Project C") + scenario = _create_scenario(uow, project, "Scenario C1") + user = _create_user_with_roles( + uow, + email="analyst@example.com", + username="analyst", + roles={"analyst"}, + ) + + resolved = ensure_scenario_access( + uow, + scenario_id=scenario.id, + user=user, + ) + assert resolved.id == scenario.id + + with pytest.raises(AuthorizationError): + ensure_scenario_access( + uow, + scenario_id=scenario.id, + user=user, + require_manage=True, + ) + + +def test_ensure_scenario_in_project_validates_membership(unit_of_work_factory) -> None: + with unit_of_work_factory() as uow: + project_one = _create_project(uow, "Project D") + project_two = _create_project(uow, "Project E") + scenario = _create_scenario(uow, project_one, "Scenario D1") + user = _create_user_with_roles( + uow, + email="manager2@example.com", + username="manager2", + roles={"project_manager"}, + ) + + resolved = ensure_scenario_in_project( + uow, + project_id=project_one.id, + scenario_id=scenario.id, + user=user, + require_manage=True, + ) + assert resolved.id == scenario.id + + with pytest.raises(EntityNotFoundError): + ensure_scenario_in_project( + uow, + project_id=project_two.id, + scenario_id=scenario.id, + user=user, + ) diff --git a/tests/test_scenario_validation.py b/tests/test_scenario_validation.py index bb26203..a9aaf07 100644 --- a/tests/test_scenario_validation.py +++ b/tests/test_scenario_validation.py @@ -6,7 +6,7 @@ from typing import cast from uuid import uuid4 import pytest -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.testclient import TestClient from pydantic import ValidationError from sqlalchemy import create_engine @@ -15,13 +15,14 @@ from sqlalchemy.engine import Engine from sqlalchemy.pool import StaticPool from config.database import Base -from dependencies import get_unit_of_work +from dependencies import get_auth_session, get_unit_of_work from models import ( MiningOperationType, Project, ResourceType, Scenario, ScenarioStatus, + User, ) from schemas.scenario import ( ScenarioComparisonRequest, @@ -30,6 +31,7 @@ from schemas.scenario import ( from services.exceptions import ScenarioValidationError from services.scenario_validation import ScenarioComparisonValidator from services.unit_of_work import UnitOfWork +from services.session import AuthSession, SessionTokens from routes.scenarios import router as scenarios_router @@ -159,6 +161,28 @@ def api_client(session_factory) -> Iterator[TestClient]: yield uow app.dependency_overrides[get_unit_of_work] = _override_uow + + with UnitOfWork(session_factory=session_factory) as uow: + assert uow.users is not None + uow.ensure_default_roles() + user = User( + email="test-scenarios@example.com", + username="scenario-tester", + password_hash=User.hash_password("password"), + is_active=True, + is_superuser=True, + ) + uow.users.create(user) + user = uow.users.get(user.id, with_roles=True) + + def _override_auth_session(request: Request) -> AuthSession: + session = AuthSession(tokens=SessionTokens( + access_token="test", refresh_token="test")) + session.user = user + request.state.auth_session = session + return session + + app.dependency_overrides[get_auth_session] = _override_auth_session client = TestClient(app) try: yield client @@ -171,6 +195,8 @@ def _create_project_with_scenarios( scenario_overrides: list[dict[str, object]], ) -> tuple[int, list[int]]: with UnitOfWork(session_factory=session_factory) as uow: + assert uow.projects is not None + assert uow.scenarios is not None project_name = f"Project {uuid4()}" project = Project(name=project_name, operation_type=MiningOperationType.OPEN_PIT)