594 lines
20 KiB
Python
594 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from datetime import datetime
|
|
from typing import Mapping, Sequence
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
|
|
from models import (
|
|
FinancialInput,
|
|
Project,
|
|
ResourceType,
|
|
Role,
|
|
Scenario,
|
|
ScenarioStatus,
|
|
SimulationParameter,
|
|
User,
|
|
UserRole,
|
|
)
|
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
|
from services.export_query import ProjectExportFilters, ScenarioExportFilters
|
|
|
|
|
|
class ProjectRepository:
    """Data-access object for :class:`Project` rows.

    Every method works through the SQLAlchemy session given to the
    constructor. Nothing here commits, so the caller owns the transaction.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_children: bool = False) -> Sequence[Project]:
        """Return all projects ordered by ``created_at`` (oldest first).

        With ``with_children`` the ``scenarios`` relationship is eager-loaded
        via ``selectinload``.
        """
        query = select(Project)
        if with_children:
            query = query.options(selectinload(Project.scenarios))
        query = query.order_by(Project.created_at)
        return self.session.execute(query).scalars().all()

    def count(self) -> int:
        """Return the total number of projects."""
        total_query = select(func.count(Project.id))
        total = self.session.execute(total_query).scalar_one()
        return total

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return up to ``limit`` projects, most recently updated first."""
        newest_first = Project.updated_at.desc()
        query = select(Project).order_by(newest_first).limit(limit)
        return self.session.execute(query).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Fetch one project by primary key.

        Args:
            project_id: Primary key of the project.
            with_children: Also joined-eager-load ``scenarios``.

        Raises:
            EntityNotFoundError: If no project has ``project_id``.
        """
        query = select(Project).where(Project.id == project_id)
        if with_children:
            # A joined eager load of a collection can repeat the parent row,
            # so rows are de-duplicated before the scalar is taken.
            query = query.options(joinedload(Project.scenarios))
            found = self.session.execute(query).unique().scalar_one_or_none()
        else:
            found = self.session.execute(query).scalar_one_or_none()
        if found is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return found

    def exists(self, project_id: int) -> bool:
        """Return whether a project with ``project_id`` exists."""
        probe = select(Project.id).where(Project.id == project_id).limit(1)
        hit = self.session.execute(probe).scalar_one_or_none()
        return hit is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        session = self.session
        session.add(project)
        try:
            session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        else:
            return project

    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
        """Look up projects by case-insensitive name.

        Empty, ``None`` and whitespace-only entries in ``names`` are ignored.
        The rest are stripped and lower-cased, then compared with
        ``lower(Project.name)``.

        Returns:
            A dict keyed by each matching project's lower-cased stored name.
            It is empty when no usable names were given.
        """
        wanted = set()
        for raw in names:
            if raw and raw.strip():
                wanted.add(raw.strip().lower())
        if not wanted:
            return {}
        query = select(Project).where(func.lower(Project.name).in_(wanted))
        matches = self.session.execute(query).scalars().all()
        return {match.name.lower(): match for match in matches}

    def filtered_for_export(
        self,
        filters: ProjectExportFilters | None = None,
        *,
        include_scenarios: bool = False,
    ) -> Sequence[Project]:
        """Return projects matching ``filters``, ordered by name and then id.

        Each populated filter field adds one AND-ed condition. ``None`` (or a
        falsy ``filters`` object) exports every project. Name and location
        lists are compared with the lower-cased column values. The
        ``*_from`` / ``*_to`` timestamp bounds are inclusive.

        Args:
            filters: Optional export filter set.
            include_scenarios: Eager-load ``scenarios`` via ``selectinload``.
        """
        query = select(Project)
        if include_scenarios:
            query = query.options(selectinload(Project.scenarios))

        if filters:
            conditions = []

            ids = filters.normalised_ids()
            if ids:
                conditions.append(Project.id.in_(ids))

            lowered_names = filters.normalised_names()
            if lowered_names:
                conditions.append(func.lower(Project.name).in_(lowered_names))

            pattern = filters.name_search_pattern()
            if pattern:
                conditions.append(Project.name.ilike(pattern))

            lowered_locations = filters.normalised_locations()
            if lowered_locations:
                conditions.append(
                    func.lower(Project.location).in_(lowered_locations))

            if filters.operation_types:
                conditions.append(
                    Project.operation_type.in_(filters.operation_types))

            if filters.created_from:
                conditions.append(Project.created_at >= filters.created_from)
            if filters.created_to:
                conditions.append(Project.created_at <= filters.created_to)
            if filters.updated_from:
                conditions.append(Project.updated_at >= filters.updated_from)
            if filters.updated_to:
                conditions.append(Project.updated_at <= filters.updated_to)

            for condition in conditions:
                query = query.where(condition)

        ordered = query.order_by(Project.name, Project.id)
        return self.session.execute(ordered).scalars().all()

    def delete(self, project_id: int) -> None:
        """Mark the project for deletion in the session.

        Raises:
            EntityNotFoundError: If no project has ``project_id``.
        """
        doomed = self.get(project_id)
        self.session.delete(doomed)
|
|
|
|
|
|
class ScenarioRepository:
    """Persistence operations for Scenario entities.

    All methods work through the injected SQLAlchemy session and never
    commit, so the caller owns the transaction.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios of ``project_id`` ordered by ``created_at``."""
        stmt = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the total number of scenarios."""
        stmt = select(func.count(Scenario.id))
        return self.session.execute(stmt).scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return how many scenarios currently have ``status``."""
        stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
        return self.session.execute(stmt).scalar_one()

    def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]:
        """Return up to ``limit`` scenarios, most recently updated first.

        With ``with_project`` the ``project`` relationship is joined-eager-loaded.
        """
        stmt = select(Scenario).order_by(
            Scenario.updated_at.desc()).limit(limit)
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        return self.session.execute(stmt).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, most recently updated first.

        Args:
            status: Status to filter on.
            limit: Maximum number of rows. ``None`` means unlimited.
            with_project: Joined-eager-load the ``project`` relationship.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.status == status)
            .order_by(Scenario.updated_at.desc())
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Fetch one scenario by primary key.

        Args:
            scenario_id: Primary key of the scenario.
            with_children: Eager-load the ``financial_inputs`` and
                ``simulation_parameters`` relationships.

        Raises:
            EntityNotFoundError: If no scenario has ``scenario_id``.
        """
        stmt = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            # selectinload runs one extra SELECT ... IN per relationship.
            # Joined-loading both sibling collections in a single query would
            # return a cartesian product (inputs x parameters rows) that then
            # had to be de-duplicated with Result.unique().
            stmt = stmt.options(
                selectinload(Scenario.financial_inputs),
                selectinload(Scenario.simulation_parameters),
            )
        scenario = self.session.execute(stmt).scalar_one_or_none()
        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return whether a scenario with ``scenario_id`` exists."""
        stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        return scenario

    def find_by_project_and_names(
        self,
        project_id: int,
        names: Iterable[str],
    ) -> Mapping[str, Scenario]:
        """Look up scenarios of ``project_id`` by case-insensitive name.

        Empty, ``None`` and whitespace-only entries in ``names`` are ignored.
        The rest are stripped and lower-cased before the comparison.

        Returns:
            A dict keyed by each matching scenario's lower-cased stored name.
            It is empty when no usable names were given.
        """
        normalised = {name.strip().lower()
                      for name in names if name and name.strip()}
        if not normalised:
            return {}
        stmt = (
            select(Scenario)
            .where(
                Scenario.project_id == project_id,
                func.lower(Scenario.name).in_(normalised),
            )
        )
        records = self.session.execute(stmt).scalars().all()
        return {scenario.name.lower(): scenario for scenario in records}

    def filtered_for_export(
        self,
        filters: ScenarioExportFilters | None = None,
        *,
        include_project: bool = True,
    ) -> Sequence[Scenario]:
        """Return scenarios matching ``filters``, ordered by name and then id.

        Each populated filter field adds one AND-ed condition. All
        ``*_from`` / ``*_to`` bounds are inclusive. Project names are matched
        through a ``Project.id`` subquery on ``lower(Project.name)``.
        Currencies are compared with ``upper(Scenario.currency)``.

        Args:
            filters: Optional export filter set. ``None`` exports everything.
            include_project: Joined-eager-load the ``project`` relationship.
        """
        stmt = select(Scenario)
        if include_project:
            stmt = stmt.options(joinedload(Scenario.project))

        if filters:
            scenario_ids = filters.normalised_ids()
            if scenario_ids:
                stmt = stmt.where(Scenario.id.in_(scenario_ids))

            project_ids = filters.normalised_project_ids()
            if project_ids:
                stmt = stmt.where(Scenario.project_id.in_(project_ids))

            project_names = filters.normalised_project_names()
            if project_names:
                project_id_select = select(Project.id).where(
                    func.lower(Project.name).in_(project_names)
                )
                stmt = stmt.where(Scenario.project_id.in_(project_id_select))

            name_pattern = filters.name_search_pattern()
            if name_pattern:
                stmt = stmt.where(Scenario.name.ilike(name_pattern))

            if filters.statuses:
                stmt = stmt.where(Scenario.status.in_(filters.statuses))

            if filters.start_date_from:
                stmt = stmt.where(Scenario.start_date >=
                                  filters.start_date_from)
            if filters.start_date_to:
                stmt = stmt.where(Scenario.start_date <= filters.start_date_to)
            if filters.end_date_from:
                stmt = stmt.where(Scenario.end_date >= filters.end_date_from)
            if filters.end_date_to:
                stmt = stmt.where(Scenario.end_date <= filters.end_date_to)

            if filters.created_from:
                stmt = stmt.where(Scenario.created_at >= filters.created_from)
            if filters.created_to:
                stmt = stmt.where(Scenario.created_at <= filters.created_to)
            if filters.updated_from:
                stmt = stmt.where(Scenario.updated_at >= filters.updated_from)
            if filters.updated_to:
                stmt = stmt.where(Scenario.updated_at <= filters.updated_to)

            currencies = filters.normalised_currencies()
            if currencies:
                stmt = stmt.where(func.upper(
                    Scenario.currency).in_(currencies))

            if filters.primary_resources:
                stmt = stmt.where(Scenario.primary_resource.in_(
                    filters.primary_resources))

        stmt = stmt.order_by(Scenario.name, Scenario.id)
        return self.session.execute(stmt).scalars().all()

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario for deletion in the session.

        Raises:
            EntityNotFoundError: If no scenario has ``scenario_id``.
        """
        scenario = self.get(scenario_id)
        self.session.delete(scenario)
|
|
|
|
|
|
class FinancialInputRepository:
    """Data-access object for :class:`FinancialInput` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the financial inputs of ``scenario_id``, oldest first."""
        belongs_to = FinancialInput.scenario_id == scenario_id
        query = select(FinancialInput).where(belongs_to)
        query = query.order_by(FinancialInput.created_at)
        return self.session.execute(query).scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Stage every entity from ``inputs`` and flush them together.

        Despite the name, no merge is performed. The entities are handed to
        ``Session.add_all`` as they are.

        Returns:
            The staged entities as a list, in input order.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        batch = [*inputs]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        return batch

    def delete(self, input_id: int) -> None:
        """Mark the financial input ``input_id`` for deletion.

        Raises:
            EntityNotFoundError: If no financial input has ``input_id``.
        """
        lookup = select(FinancialInput).where(FinancialInput.id == input_id)
        target = self.session.execute(lookup).scalar_one_or_none()
        if target is not None:
            self.session.delete(target)
            return
        raise EntityNotFoundError(f"Financial input {input_id} not found")

    def latest_created_at(self) -> datetime | None:
        """Return the newest ``created_at`` value, or ``None`` if the table is empty."""
        newest_first = select(FinancialInput.created_at).order_by(
            FinancialInput.created_at.desc())
        return self.session.execute(newest_first.limit(1)).scalar_one_or_none()
|
|
|
|
|
|
class SimulationParameterRepository:
    """Data-access object for :class:`SimulationParameter` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the simulation parameters of ``scenario_id``, oldest first."""
        belongs_to = SimulationParameter.scenario_id == scenario_id
        query = select(SimulationParameter).where(belongs_to)
        query = query.order_by(SimulationParameter.created_at)
        return self.session.execute(query).scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Stage every entity from ``parameters`` and flush them together.

        Despite the name, no merge is performed. The entities are handed to
        ``Session.add_all`` as they are.

        Returns:
            The staged entities as a list, in input order.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        batch = [*parameters]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from exc
        return batch

    def delete(self, parameter_id: int) -> None:
        """Mark the simulation parameter ``parameter_id`` for deletion.

        Raises:
            EntityNotFoundError: If no simulation parameter has ``parameter_id``.
        """
        lookup = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id)
        target = self.session.execute(lookup).scalar_one_or_none()
        if target is not None:
            self.session.delete(target)
            return
        raise EntityNotFoundError(
            f"Simulation parameter {parameter_id} not found")
|
|
|
|
|
|
class RoleRepository:
    """Data-access object for :class:`Role` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self) -> Sequence[Role]:
        """Return all roles ordered by ``name``."""
        by_name = select(Role).order_by(Role.name)
        return self.session.execute(by_name).scalars().all()

    def get(self, role_id: int) -> Role:
        """Fetch one role by primary key.

        Raises:
            EntityNotFoundError: If no role has ``role_id``.
        """
        lookup = select(Role).where(Role.id == role_id)
        found = self.session.execute(lookup).scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Role {role_id} not found")

    def get_by_name(self, name: str) -> Role | None:
        """Return the role whose ``name`` equals ``name`` exactly, or ``None``."""
        lookup = select(Role).where(Role.name == name)
        return self.session.execute(lookup).scalar_one_or_none()

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush it.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        session = self.session
        session.add(role)
        try:
            session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints") from exc
        return role
|
|
|
|
|
|
class UserRepository:
    """Data-access object for :class:`User` rows and their ``UserRole`` links.

    Single-user lookups (by id, email or username) run with
    ``populate_existing`` so a ``User`` already in the session is
    overwritten with the freshly loaded row data.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return all users ordered by ``created_at`` (oldest first).

        With ``with_roles`` the ``roles`` relationship is eager-loaded via
        ``selectinload``.
        """
        query = select(User)
        if with_roles:
            query = query.options(selectinload(User.roles))
        return self.session.execute(query.order_by(User.created_at)).scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Return ``stmt`` with role eager-loading options when ``with_roles`` is true."""
        if not with_roles:
            return stmt
        return stmt.options(
            joinedload(User.role_assignments).joinedload(UserRole.role),
            selectinload(User.roles),
        )

    def _fetch_single(self, criterion, with_roles: bool) -> User | None:
        """Run a ``User`` query filtered by ``criterion`` and return one row or ``None``."""
        query = select(User).where(criterion).execution_options(
            populate_existing=True)
        query = self._apply_role_option(query, with_roles)
        result = self.session.execute(query)
        if with_roles:
            # The joined eager load of role_assignments can repeat the user
            # row, so rows are de-duplicated before the scalar is taken.
            result = result.unique()
        return result.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Fetch one user by primary key.

        Raises:
            EntityNotFoundError: If no user has ``user_id``.
        """
        found = self._fetch_single(User.id == user_id, with_roles)
        if found is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return found

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``email`` equals ``email`` exactly, or ``None``."""
        return self._fetch_single(User.email == email, with_roles)

    def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``username`` equals ``username`` exactly, or ``None``."""
        return self._fetch_single(User.username == username, with_roles)

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush it.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        session = self.session
        session.add(user)
        try:
            session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints") from exc
        return user

    def _assignment_lookup(self, user_id: int, role_id: int):
        """Build a SELECT for the ``UserRole`` row linking ``user_id`` and ``role_id``."""
        return select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Link ``role_id`` to ``user_id``, reusing an existing link if present.

        An existing assignment is returned unchanged; its ``granted_by`` is
        not updated. Otherwise a new ``UserRole`` is added and flushed.

        Raises:
            EntityConflictError: If the flush violates a database constraint.
        """
        lookup = self._assignment_lookup(user_id, role_id)
        current = self.session.execute(lookup).scalar_one_or_none()
        if current:
            return current

        link = UserRole(user_id=user_id, role_id=role_id, granted_by=granted_by)
        self.session.add(link)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints") from exc
        return link

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Delete the link between ``user_id`` and ``role_id``, then flush.

        Raises:
            EntityNotFoundError: If the user does not hold that role.
        """
        lookup = self._assignment_lookup(user_id, role_id)
        current = self.session.execute(lookup).scalar_one_or_none()
        if current is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(current)
        self.session.flush()
|
|
|
|
|
|
# Standard roles seeded by ``ensure_default_roles``. Each dict holds the
# keyword arguments passed to ``Role(...)`` for that role.
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = tuple(
    {"name": name, "display_name": display_name, "description": description}
    for name, display_name, description in (
        (
            "admin",
            "Administrator",
            "Full platform access with user management rights.",
        ),
        (
            "project_manager",
            "Project Manager",
            "Manage projects, scenarios, and associated data.",
        ),
        (
            "analyst",
            "Analyst",
            "Review dashboards and scenario outputs.",
        ),
        (
            "viewer",
            "Viewer",
            "Read-only access to assigned projects and reports.",
        ),
    )
)
|
|
|
|
|
|
def ensure_default_roles(
    role_repo: RoleRepository,
    *,
    definitions: Iterable[Mapping[str, str]] | None = None,
) -> list[Role]:
    """Ensure each standard role exists, creating any that are missing.

    Args:
        role_repo: Repository used to look roles up by name and create them.
        definitions: Role definitions to enforce. Each mapping is unpacked as
            keyword arguments to ``Role`` and must contain a ``"name"`` key.
            Defaults to ``DEFAULT_ROLE_DEFINITIONS``.

    Returns:
        One ``Role`` per definition, in the order the definitions are listed,
        whether it already existed or was just created. Roles in the database
        that are not part of ``definitions`` are not included.
    """
    if definitions is None:
        definitions = DEFAULT_ROLE_DEFINITIONS

    roles: list[Role] = []
    for definition in definitions:
        existing = role_repo.get_by_name(definition["name"])
        if existing is not None:
            roles.append(existing)
            continue
        roles.append(role_repo.create(Role(**definition)))
    return roles
|
|
|
|
|
|
def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Ensure an administrator user exists and holds the admin role.

    The user is looked up by ``email``.

    - A missing user is created as an active superuser with ``username`` and
      the value of ``User.hash_password(password)``.
    - An existing user is re-activated and made a superuser if needed. Its
      username and password are left unchanged.
    - The ``admin`` role is created from ``DEFAULT_ROLE_DEFINITIONS`` if
      absent, then assigned with the user recorded as the grantor.
      ``UserRepository.assign_role`` reuses an existing assignment.

    Returns:
        The admin ``User``, re-fetched with its roles so the in-session
        ``roles`` / ``role_assignments`` collections include the admin role.
    """
    user = user_repo.get_by_email(email, with_roles=True)
    if user is None:
        user = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        # create() flushes, so user.id is populated for the assignment below.
        user_repo.create(user)
    else:
        if not user.is_active:
            user.is_active = True
        if not user.is_superuser:
            user.is_superuser = True
        user_repo.session.flush()

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        # Reuse the canonical definition so this fallback cannot drift from
        # the seeded admin role.
        admin_definition = next(
            definition
            for definition in DEFAULT_ROLE_DEFINITIONS
            if definition["name"] == "admin"
        )
        admin_role = role_repo.create(Role(**admin_definition))

    user_repo.assign_role(
        user_id=user.id,
        role_id=admin_role.id,
        granted_by=user.id,
    )
    # assign_role inserts the UserRole by foreign key only and never touches
    # the role collections loaded above (with_roles=True). Re-fetching with
    # UserRepository.get, which runs a populate_existing query, refreshes them.
    return user_repo.get(user.id, with_roles=True)
|