- Implemented role-based access control for project and scenario routes.
- Added authorization checks to ensure users have the appropriate roles for viewing and managing projects and scenarios.
- Introduced utility functions that enforce project and scenario access based on user roles.
- Refactored project and scenario routes to use the new authorization helpers.
- Created an initial data-seeding script that sets up the default roles and an admin user.
- Added tests for the authorization helpers and the initial data-seeding functionality.
- Updated exception handling to include authorization errors.
448 lines · 15 KiB · Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from datetime import datetime
|
|
from typing import Sequence
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
|
|
from models import (
|
|
FinancialInput,
|
|
Project,
|
|
Role,
|
|
Scenario,
|
|
ScenarioStatus,
|
|
SimulationParameter,
|
|
User,
|
|
UserRole,
|
|
)
|
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
|
|
|
|
|
class ProjectRepository:
    """Data-access helpers for :class:`Project` rows bound to one ORM session.

    Changes are staged on ``self.session`` (and flushed where noted); the
    caller owns commit and rollback.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def _flush_or_conflict(self, message: str) -> None:
        """Flush pending changes, re-raising IntegrityError as EntityConflictError."""
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(message) from exc

    def list(self, *, with_children: bool = False) -> Sequence[Project]:
        """Return every project ordered by ``created_at`` (ascending).

        Args:
            with_children: when true, eager-load each project's ``scenarios``
                collection via ``selectinload``.
        """
        query = select(Project).order_by(Project.created_at)
        if with_children:
            query = query.options(selectinload(Project.scenarios))
        rows = self.session.execute(query)
        return rows.scalars().all()

    def count(self) -> int:
        """Return the total number of projects."""
        total = self.session.execute(select(func.count(Project.id)))
        return total.scalar_one()

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return at most ``limit`` projects, most recently updated first."""
        newest_first = Project.updated_at.desc()
        query = select(Project).order_by(newest_first).limit(limit)
        return self.session.execute(query).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Fetch a single project by primary key.

        Args:
            project_id: primary key to look up.
            with_children: when true, eager-load ``scenarios`` with a JOIN.

        Raises:
            EntityNotFoundError: if no project has ``project_id``.
        """
        query = select(Project).where(Project.id == project_id)
        if not with_children:
            found = self.session.execute(query).scalar_one_or_none()
        else:
            # A joined eager load of a collection repeats the parent row once
            # per child, so the result has to be de-duplicated with unique().
            query = query.options(joinedload(Project.scenarios))
            found = self.session.execute(query).unique().scalar_one_or_none()
        if found is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return found

    def exists(self, project_id: int) -> bool:
        """Return True when a project with ``project_id`` is present."""
        probe = select(Project.id).where(Project.id == project_id).limit(1)
        hit = self.session.execute(probe).scalar_one_or_none()
        return hit is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it immediately.

        Flushing here makes constraint violations surface at this call.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        self.session.add(project)
        self._flush_or_conflict("Project violates uniqueness constraints")
        return project

    def delete(self, project_id: int) -> None:
        """Mark the project with ``project_id`` for deletion (no flush here).

        Raises:
            EntityNotFoundError: if no project has ``project_id``.
        """
        self.session.delete(self.get(project_id))
|
|
|
|
|
|
class ScenarioRepository:
    """Persistence operations for Scenario entities.

    Queries run through ``self.session``; ``create`` flushes but nothing here
    commits, so transaction control stays with the caller.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios belonging to ``project_id``, ordered by ``created_at``."""
        stmt = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the total number of scenarios."""
        stmt = select(func.count(Scenario.id))
        return self.session.execute(stmt).scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return how many scenarios currently have ``status``."""
        stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
        return self.session.execute(stmt).scalar_one()

    def recent(
        self, limit: int = 5, *, with_project: bool = False
    ) -> Sequence[Scenario]:
        """Return at most ``limit`` scenarios, most recently updated first.

        Args:
            limit: maximum number of scenarios to return.
            with_project: when true, eager-load the parent ``project`` with a
                JOIN in the same statement.
        """
        stmt = select(Scenario).order_by(Scenario.updated_at.desc()).limit(limit)
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        return self.session.execute(stmt).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, most recently updated first.

        Args:
            status: status to filter on.
            limit: optional row cap; ``None`` returns every match.
            with_project: when true, eager-load the parent ``project`` with a
                JOIN in the same statement.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.status == status)
            .order_by(Scenario.updated_at.desc())
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Fetch a single scenario by primary key.

        Args:
            scenario_id: primary key to look up.
            with_children: when true, eager-load ``financial_inputs`` and
                ``simulation_parameters``.

        Raises:
            EntityNotFoundError: if no scenario has ``scenario_id``.
        """
        stmt = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            # Load each collection with its own SELECT ... IN query.
            # Joined-loading two sibling collections in one statement returns
            # one row per (financial_input, simulation_parameter) pair, which
            # is a cartesian product that grows with both collection sizes.
            # selectinload also yields one row per scenario, so no unique()
            # call is needed.
            stmt = stmt.options(
                selectinload(Scenario.financial_inputs),
                selectinload(Scenario.simulation_parameters),
            )
        scenario = self.session.execute(stmt).scalar_one_or_none()
        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return True when a scenario with ``scenario_id`` is present."""
        stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it immediately.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        return scenario

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario with ``scenario_id`` for deletion (no flush here).

        Raises:
            EntityNotFoundError: if no scenario has ``scenario_id``.
        """
        scenario = self.get(scenario_id)
        self.session.delete(scenario)
|
|
|
|
|
|
class FinancialInputRepository:
    """Data-access helpers for :class:`FinancialInput` rows on one ORM session."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the inputs attached to ``scenario_id``, ordered by ``created_at``."""
        belongs_to_scenario = FinancialInput.scenario_id == scenario_id
        query = (
            select(FinancialInput)
            .where(belongs_to_scenario)
            .order_by(FinancialInput.created_at)
        )
        return self.session.execute(query).scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Add every input to the session and flush them together.

        The iterable is materialized into a list, which is returned.
        NOTE(review): this relies on ``Session.add_all``, not a key-based
        merge. Objects new to the session are inserted and already-persistent
        ones are updated on flush.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        batch = [*inputs]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        return batch

    def delete(self, input_id: int) -> None:
        """Mark the input with ``input_id`` for deletion (no flush here).

        Raises:
            EntityNotFoundError: if no such input exists.
        """
        target = self.session.execute(
            select(FinancialInput).where(FinancialInput.id == input_id)
        ).scalar_one_or_none()
        if target is None:
            raise EntityNotFoundError(f"Financial input {input_id} not found")
        self.session.delete(target)

    def latest_created_at(self) -> datetime | None:
        """Return the first ``created_at`` in descending order, or None when there are no rows."""
        newest = (
            select(FinancialInput.created_at)
            .order_by(FinancialInput.created_at.desc())
            .limit(1)
        )
        return self.session.execute(newest).scalar_one_or_none()
|
|
|
|
|
|
class SimulationParameterRepository:
    """Data-access helpers for :class:`SimulationParameter` rows on one ORM session."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the parameters attached to ``scenario_id``, ordered by ``created_at``."""
        of_scenario = SimulationParameter.scenario_id == scenario_id
        query = (
            select(SimulationParameter)
            .where(of_scenario)
            .order_by(SimulationParameter.created_at)
        )
        return self.session.execute(query).scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Add every parameter to the session and flush them together.

        The iterable is materialized into a list, which is returned.
        NOTE(review): this uses ``Session.add_all``, not a key-based merge.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        batch = [*parameters]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints"
            ) from exc
        return batch

    def delete(self, parameter_id: int) -> None:
        """Mark the parameter with ``parameter_id`` for deletion (no flush here).

        Raises:
            EntityNotFoundError: if no such parameter exists.
        """
        by_id = SimulationParameter.id == parameter_id
        target = self.session.execute(
            select(SimulationParameter).where(by_id)
        ).scalar_one_or_none()
        if target is None:
            raise EntityNotFoundError(
                f"Simulation parameter {parameter_id} not found"
            )
        self.session.delete(target)
|
|
|
|
|
|
class RoleRepository:
    """Data-access helpers for :class:`Role` rows on one ORM session."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self) -> Sequence[Role]:
        """Return every role ordered by ``name``."""
        by_name = select(Role).order_by(Role.name)
        return self.session.execute(by_name).scalars().all()

    def get(self, role_id: int) -> Role:
        """Fetch a role by primary key.

        Raises:
            EntityNotFoundError: if no role has ``role_id``.
        """
        found = self.session.execute(
            select(Role).where(Role.id == role_id)
        ).scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Role {role_id} not found")

    def get_by_name(self, name: str) -> Role | None:
        """Return the role called ``name``, or None when it does not exist."""
        query = select(Role).where(Role.name == name)
        return self.session.execute(query).scalar_one_or_none()

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush it immediately.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        self.session.add(role)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints"
            ) from exc
        return role
|
|
|
|
|
|
class UserRepository:
    """Data-access helpers for :class:`User` rows and their role assignments.

    Single-user lookups set the ``populate_existing`` execution option. An
    instance already in the session's identity map is therefore overwritten
    with freshly loaded state, including any role collections eager-loaded
    through ``with_roles``.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return every user ordered by ``created_at``.

        Args:
            with_roles: when true, eager-load ``User.roles`` via ``selectinload``.
        """
        query = select(User).order_by(User.created_at)
        if with_roles:
            query = query.options(selectinload(User.roles))
        return self.session.execute(query).scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Return ``stmt`` with role eager-loading options added when requested.

        ``role_assignments`` (and each assignment's ``role``) are joined in,
        and ``roles`` is loaded via ``selectinload``. The collection JOIN
        means results must go through ``unique()`` before scalar extraction.
        """
        if not with_roles:
            return stmt
        return stmt.options(
            joinedload(User.role_assignments).joinedload(UserRole.role),
            selectinload(User.roles),
        )

    def _select_user(self, criterion):
        """Build a single-user SELECT that refreshes identity-mapped state."""
        return select(User).where(criterion).execution_options(
            populate_existing=True)

    def _one_user_or_none(self, stmt, with_roles: bool) -> User | None:
        """Execute a single-user ``stmt`` and return the match or None."""
        result = self.session.execute(self._apply_role_option(stmt, with_roles))
        if with_roles:
            # The joined role_assignments collection repeats the user row.
            result = result.unique()
        return result.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Fetch a user by primary key.

        Raises:
            EntityNotFoundError: if no user has ``user_id``.
        """
        found = self._one_user_or_none(
            self._select_user(User.id == user_id), with_roles)
        if found is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return found

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``email`` matches exactly, or None."""
        return self._one_user_or_none(
            self._select_user(User.email == email), with_roles)

    def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``username`` matches exactly, or None."""
        return self._one_user_or_none(
            self._select_user(User.username == username), with_roles)

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush it immediately.

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        self.session.add(user)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints") from exc
        return user

    def _find_assignment(self, user_id: int, role_id: int) -> UserRole | None:
        """Return the ``UserRole`` linking ``user_id`` to ``role_id``, if any."""
        query = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        return self.session.execute(query).scalar_one_or_none()

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Grant ``role_id`` to ``user_id``; returns the existing link if present.

        Args:
            user_id: user receiving the role.
            role_id: role being granted.
            granted_by: id recorded as the grantor (``None`` if unknown).

        Raises:
            EntityConflictError: if the flush violates a database constraint.
        """
        current = self._find_assignment(user_id, role_id)
        if current:
            return current

        link = UserRole(
            user_id=user_id,
            role_id=role_id,
            granted_by=granted_by,
        )
        self.session.add(link)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints") from exc
        return link

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Remove the ``role_id`` assignment from ``user_id`` and flush.

        Raises:
            EntityNotFoundError: if the role is not assigned to the user.
        """
        link = self._find_assignment(user_id, role_id)
        if link is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(link)
        self.session.flush()
|
|
|
|
|
|
# Built-in roles seeded by ensure_default_roles. Each dict's keys are passed
# straight to ``Role(**definition)``; rows are (name, display_name, description).
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = tuple(
    {"name": name, "display_name": display_name, "description": description}
    for name, display_name, description in (
        ("admin", "Administrator",
         "Full platform access with user management rights."),
        ("project_manager", "Project Manager",
         "Manage projects, scenarios, and associated data."),
        ("analyst", "Analyst",
         "Review dashboards and scenario outputs."),
        ("viewer", "Viewer",
         "Read-only access to assigned projects and reports."),
    )
)
|
|
|
|
|
|
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
    """Ensure every role in ``DEFAULT_ROLE_DEFINITIONS`` exists, creating missing ones.

    Roles are matched by ``name``. Existing rows are returned unchanged: their
    display name and description are not synchronized with the definitions.

    Args:
        role_repo: repository bound to the session being seeded.

    Returns:
        One ``Role`` per entry of ``DEFAULT_ROLE_DEFINITIONS``, in that tuple's
        order. Other roles already in the database are not included.

    Raises:
        EntityConflictError: if creating a missing role violates a database
            constraint.
    """
    roles: list[Role] = []
    for definition in DEFAULT_ROLE_DEFINITIONS:
        existing = role_repo.get_by_name(definition["name"])
        if existing:
            roles.append(existing)
            continue
        roles.append(role_repo.create(Role(**definition)))
    return roles
|
|
|
|
|
|
def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Ensure an administrator user exists and holds the admin role.

    The user is looked up by ``email``:

    * If absent, a new active superuser is created with ``username`` and a
      hash of ``password``.
    * If present, the account is re-activated and promoted to superuser as
      needed. Its username and password are left untouched.

    The ``admin`` role is then assigned, with the user recorded as its own
    grantor. Assignment is idempotent: an existing link is reused.

    Args:
        user_repo: repository used to find, create and grant roles to users.
        role_repo: repository used to look up (or, as a fallback, create) the
            ``admin`` role.
        email: email address identifying the admin account.
        username: username used only when a new account is created.
        password: plain-text password hashed via ``User.hash_password`` only
            when a new account is created.

    Returns:
        The admin user, re-read with its role collections loaded so they
        include the admin assignment made here.

    Raises:
        EntityConflictError: if creating the user, role or assignment
            violates a database constraint (e.g. a uniqueness constraint).
    """
    user = user_repo.get_by_email(email, with_roles=True)
    if user is None:
        user = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        user_repo.create(user)
    else:
        if not user.is_active:
            user.is_active = True
        if not user.is_superuser:
            user.is_superuser = True
        user_repo.session.flush()

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        # Reuse the canonical definition so this fallback cannot drift from
        # the role that ensure_default_roles seeds.
        admin_definition = next(
            definition
            for definition in DEFAULT_ROLE_DEFINITIONS
            if definition["name"] == "admin"
        )
        admin_role = role_repo.create(Role(**admin_definition))

    user_repo.assign_role(
        user_id=user.id,
        role_id=admin_role.id,
        granted_by=user.id,
    )
    # assign_role creates the UserRole from foreign-key ids only. That does
    # not update role collections already loaded on ``user``; the
    # existing-user path loaded them via with_roles=True. Re-read through the
    # repository instead: its populate_existing option refreshes the same
    # identity-mapped instance, so callers see the admin role.
    return user_repo.get(user.id, with_roles=True)
|