Files
calminer/services/repositories.py

594 lines
20 KiB
Python

from __future__ import annotations
from collections.abc import Iterable
from datetime import datetime
from typing import Mapping, Sequence
from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload
from models import (
FinancialInput,
Project,
ResourceType,
Role,
Scenario,
ScenarioStatus,
SimulationParameter,
User,
UserRole,
)
from services.exceptions import EntityConflictError, EntityNotFoundError
from services.export_query import ProjectExportFilters, ScenarioExportFilters
class ProjectRepository:
    """Session-bound data access for :class:`Project` rows.

    All reads and writes go through ``self.session``. Writes are flushed
    but never committed here; transaction control stays with the caller.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_children: bool = False) -> Sequence[Project]:
        """Return every project, oldest ``created_at`` first.

        ``with_children`` adds a ``selectinload`` of ``Project.scenarios``.
        """
        query = select(Project).order_by(Project.created_at)
        if with_children:
            query = query.options(selectinload(Project.scenarios))
        return self.session.execute(query).scalars().all()

    def count(self) -> int:
        """Return how many projects exist."""
        return self.session.execute(select(func.count(Project.id))).scalar_one()

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return at most ``limit`` projects, newest ``updated_at`` first."""
        newest_first = Project.updated_at.desc()
        query = select(Project).order_by(newest_first).limit(limit)
        return self.session.execute(query).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Return the project whose primary key is ``project_id``.

        ``with_children`` joined-eager-loads ``Project.scenarios``.

        Raises:
            EntityNotFoundError: if no such project exists.
        """
        query = select(Project).where(Project.id == project_id)
        if not with_children:
            found = self.session.execute(query).scalar_one_or_none()
        else:
            query = query.options(joinedload(Project.scenarios))
            # A joined collection repeats the parent row once per child, so
            # the result must be de-duplicated before taking a single scalar.
            found = self.session.execute(query).unique().scalar_one_or_none()
        if found is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return found

    def exists(self, project_id: int) -> bool:
        """Return ``True`` when a project with ``project_id`` is present."""
        probe = select(Project.id).where(Project.id == project_id).limit(1)
        first_id = self.session.execute(probe).scalar_one_or_none()
        return first_id is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Flushing makes database constraint violations surface here.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        else:
            return project

    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
        """Look up projects by name, ignoring case and surrounding whitespace.

        Blank or empty names are skipped. The returned mapping is keyed by
        the lower-cased stored project name.
        """
        wanted = set()
        for raw in names:
            if raw and raw.strip():
                wanted.add(raw.strip().lower())
        if not wanted:
            return {}
        query = select(Project).where(func.lower(Project.name).in_(wanted))
        matches = self.session.execute(query).scalars().all()
        by_name: dict[str, Project] = {}
        for match in matches:
            by_name[match.name.lower()] = match
        return by_name

    def filtered_for_export(
        self,
        filters: ProjectExportFilters | None = None,
        *,
        include_scenarios: bool = False,
    ) -> Sequence[Project]:
        """Return projects selected by ``filters``, ordered by name then id.

        Each filter field only contributes a condition when it is truthy.
        All conditions are combined with AND. ``include_scenarios`` adds a
        ``selectinload`` of ``Project.scenarios``.
        """
        query = select(Project)
        if include_scenarios:
            query = query.options(selectinload(Project.scenarios))

        criteria = []
        if filters:
            ids = filters.normalised_ids()
            if ids:
                criteria.append(Project.id.in_(ids))
            names = filters.normalised_names()
            if names:
                criteria.append(func.lower(Project.name).in_(names))
            pattern = filters.name_search_pattern()
            if pattern:
                criteria.append(Project.name.ilike(pattern))
            locations = filters.normalised_locations()
            if locations:
                criteria.append(func.lower(Project.location).in_(locations))
            if filters.operation_types:
                criteria.append(
                    Project.operation_type.in_(filters.operation_types))
            if filters.created_from:
                criteria.append(Project.created_at >= filters.created_from)
            if filters.created_to:
                criteria.append(Project.created_at <= filters.created_to)
            if filters.updated_from:
                criteria.append(Project.updated_at >= filters.updated_from)
            if filters.updated_to:
                criteria.append(Project.updated_at <= filters.updated_to)
        if criteria:
            query = query.where(*criteria)

        query = query.order_by(Project.name, Project.id)
        return self.session.execute(query).scalars().all()

    def delete(self, project_id: int) -> None:
        """Mark the project with ``project_id`` for deletion.

        Raises:
            EntityNotFoundError: propagated from :meth:`get` when it is missing.
        """
        self.session.delete(self.get(project_id))
class ScenarioRepository:
    """Persistence operations for Scenario entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios belonging to ``project_id``, oldest first."""
        stmt = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the total number of scenarios."""
        stmt = select(func.count(Scenario.id))
        return self.session.execute(stmt).scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return how many scenarios have the given ``status``."""
        stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
        return self.session.execute(stmt).scalar_one()

    def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]:
        """Return up to ``limit`` scenarios, newest ``updated_at`` first.

        ``with_project`` joined-eager-loads the owning ``Scenario.project``.
        """
        stmt = select(Scenario).order_by(
            Scenario.updated_at.desc()).limit(limit)
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        return self.session.execute(stmt).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, newest ``updated_at`` first.

        ``limit`` caps the row count when given. ``with_project``
        joined-eager-loads the owning ``Scenario.project``.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.status == status)
            .order_by(Scenario.updated_at.desc())
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Return the scenario whose primary key is ``scenario_id``.

        ``with_children`` eagerly loads ``financial_inputs`` and
        ``simulation_parameters``.

        Raises:
            EntityNotFoundError: if no such scenario exists.
        """
        stmt = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            # Two sibling collections: joinedload on both would produce a
            # cartesian product (inputs x parameters rows) for a single
            # scenario. selectinload fetches each collection with its own
            # query and needs no result de-duplication.
            stmt = stmt.options(
                selectinload(Scenario.financial_inputs),
                selectinload(Scenario.simulation_parameters),
            )
        scenario = self.session.execute(stmt).scalar_one_or_none()
        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return ``True`` when a scenario with ``scenario_id`` is present."""
        stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        return scenario

    def find_by_project_and_names(
        self,
        project_id: int,
        names: Iterable[str],
    ) -> Mapping[str, Scenario]:
        """Look up scenarios of ``project_id`` by case-insensitive name.

        Blank names are skipped. The mapping is keyed by the lower-cased
        stored scenario name.
        """
        normalised = {name.strip().lower()
                      for name in names if name and name.strip()}
        if not normalised:
            return {}
        stmt = (
            select(Scenario)
            .where(
                Scenario.project_id == project_id,
                func.lower(Scenario.name).in_(normalised),
            )
        )
        records = self.session.execute(stmt).scalars().all()
        return {scenario.name.lower(): scenario for scenario in records}

    def filtered_for_export(
        self,
        filters: ScenarioExportFilters | None = None,
        *,
        include_project: bool = True,
    ) -> Sequence[Scenario]:
        """Return scenarios selected by ``filters``, ordered by name then id.

        Each filter field only contributes a condition when it is truthy.
        All conditions are combined with AND. ``include_project``
        joined-eager-loads ``Scenario.project``.
        """
        stmt = select(Scenario)
        if include_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if filters:
            scenario_ids = filters.normalised_ids()
            if scenario_ids:
                stmt = stmt.where(Scenario.id.in_(scenario_ids))
            project_ids = filters.normalised_project_ids()
            if project_ids:
                stmt = stmt.where(Scenario.project_id.in_(project_ids))
            project_names = filters.normalised_project_names()
            if project_names:
                # Resolve project names to ids with a correlated sub-select.
                project_id_select = select(Project.id).where(
                    func.lower(Project.name).in_(project_names)
                )
                stmt = stmt.where(Scenario.project_id.in_(project_id_select))
            name_pattern = filters.name_search_pattern()
            if name_pattern:
                stmt = stmt.where(Scenario.name.ilike(name_pattern))
            if filters.statuses:
                stmt = stmt.where(Scenario.status.in_(filters.statuses))
            if filters.start_date_from:
                stmt = stmt.where(Scenario.start_date >=
                                  filters.start_date_from)
            if filters.start_date_to:
                stmt = stmt.where(Scenario.start_date <= filters.start_date_to)
            if filters.end_date_from:
                stmt = stmt.where(Scenario.end_date >= filters.end_date_from)
            if filters.end_date_to:
                stmt = stmt.where(Scenario.end_date <= filters.end_date_to)
            if filters.created_from:
                stmt = stmt.where(Scenario.created_at >= filters.created_from)
            if filters.created_to:
                stmt = stmt.where(Scenario.created_at <= filters.created_to)
            if filters.updated_from:
                stmt = stmt.where(Scenario.updated_at >= filters.updated_from)
            if filters.updated_to:
                stmt = stmt.where(Scenario.updated_at <= filters.updated_to)
            currencies = filters.normalised_currencies()
            if currencies:
                stmt = stmt.where(func.upper(
                    Scenario.currency).in_(currencies))
            if filters.primary_resources:
                stmt = stmt.where(Scenario.primary_resource.in_(
                    filters.primary_resources))
        stmt = stmt.order_by(Scenario.name, Scenario.id)
        return self.session.execute(stmt).scalars().all()

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario with ``scenario_id`` for deletion.

        Raises:
            EntityNotFoundError: propagated from :meth:`get` when it is missing.
        """
        scenario = self.get(scenario_id)
        self.session.delete(scenario)
class FinancialInputRepository:
    """Persistence operations for FinancialInput entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the financial inputs of ``scenario_id``, oldest first."""
        stmt = (
            select(FinancialInput)
            .where(FinancialInput.scenario_id == scenario_id)
            .order_by(FinancialInput.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Add every entity in ``inputs`` to the session and flush once.

        Entities are passed to ``session.add_all``. The same list of
        objects is returned after the flush.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        entities = list(inputs)
        self.session.add_all(entities)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        return entities

    def delete(self, input_id: int) -> None:
        """Mark the financial input with ``input_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such input exists.
        """
        stmt = select(FinancialInput).where(FinancialInput.id == input_id)
        entity = self.session.execute(stmt).scalar_one_or_none()
        if entity is None:
            raise EntityNotFoundError(f"Financial input {input_id} not found")
        self.session.delete(entity)

    def latest_created_at(self) -> datetime | None:
        """Return the newest ``created_at`` across all financial inputs.

        Returns ``None`` when no input has a timestamp (including when the
        table is empty).
        """
        # MAX() skips NULLs. ``ORDER BY created_at DESC LIMIT 1`` would
        # return NULL on backends that sort NULLs first for descending order
        # (PostgreSQL's default), hiding real timestamps.
        stmt = select(func.max(FinancialInput.created_at))
        return self.session.execute(stmt).scalar_one_or_none()
class SimulationParameterRepository:
    """Session-bound data access for :class:`SimulationParameter` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the parameters attached to ``scenario_id``, oldest first."""
        belongs_to_scenario = SimulationParameter.scenario_id == scenario_id
        query = (
            select(SimulationParameter)
            .where(belongs_to_scenario)
            .order_by(SimulationParameter.created_at)
        )
        return self.session.execute(query).scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Stage every parameter via ``session.add_all`` and flush once.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        staged = list(parameters)
        self.session.add_all(staged)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from exc
        else:
            return staged

    def delete(self, parameter_id: int) -> None:
        """Mark the parameter with ``parameter_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such parameter exists.
        """
        query = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id)
        target = self.session.execute(query).scalar_one_or_none()
        if target is not None:
            self.session.delete(target)
            return
        raise EntityNotFoundError(
            f"Simulation parameter {parameter_id} not found")
class RoleRepository:
    """Session-bound data access for :class:`Role` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self) -> Sequence[Role]:
        """Return every role ordered by ``name``."""
        by_name = select(Role).order_by(Role.name)
        return self.session.execute(by_name).scalars().all()

    def get(self, role_id: int) -> Role:
        """Return the role whose primary key is ``role_id``.

        Raises:
            EntityNotFoundError: if no such role exists.
        """
        query = select(Role).where(Role.id == role_id)
        found = self.session.execute(query).scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Role {role_id} not found")

    def get_by_name(self, name: str) -> Role | None:
        """Return the role whose ``name`` equals ``name``, or ``None``."""
        return self.session.execute(
            select(Role).where(Role.name == name)
        ).scalar_one_or_none()

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(role)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints") from exc
        else:
            return role
class UserRepository:
    """Persistence operations for User entities and their role assignments."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return every user, oldest ``created_at`` first.

        ``with_roles`` adds a ``selectinload`` of ``User.roles``.
        """
        stmt = select(User).order_by(User.created_at)
        if with_roles:
            stmt = stmt.options(selectinload(User.roles))
        return self.session.execute(stmt).scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Attach role eager-load options to ``stmt`` when ``with_roles``.

        It joins ``role_assignments`` and each assignment's ``role``, and
        selectin-loads ``roles``.
        """
        if with_roles:
            stmt = stmt.options(
                joinedload(User.role_assignments).joinedload(UserRole.role),
                selectinload(User.roles),
            )
        return stmt

    def _find_one(self, criterion, *, with_roles: bool) -> User | None:
        """Return the single user matching ``criterion``, or ``None``.

        Shared by :meth:`get`, :meth:`get_by_email` and
        :meth:`get_by_username`. ``populate_existing`` makes instances
        already in the session take the freshly selected column values.
        """
        stmt = (
            select(User)
            .where(criterion)
            .execution_options(populate_existing=True)
        )
        stmt = self._apply_role_option(stmt, with_roles)
        result = self.session.execute(stmt)
        if with_roles:
            # The joined role_assignments collection repeats the user row
            # once per assignment; de-duplicate before taking one scalar.
            result = result.unique()
        return result.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Return the user whose primary key is ``user_id``.

        Raises:
            EntityNotFoundError: if no such user exists.
        """
        user = self._find_one(User.id == user_id, with_roles=with_roles)
        if user is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return user

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``email`` equals ``email``, or ``None``."""
        return self._find_one(User.email == email, with_roles=with_roles)

    def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose ``username`` equals ``username``, or ``None``."""
        return self._find_one(User.username == username, with_roles=with_roles)

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(user)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints") from exc
        return user

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Ensure ``user_id`` holds ``role_id`` and return the assignment row.

        If the pair already exists, the existing row is returned unchanged
        and ``granted_by`` is ignored. Otherwise a new ``UserRole`` is added
        and flushed.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        stmt = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        assignment = self.session.execute(stmt).scalar_one_or_none()
        if assignment is not None:
            return assignment
        assignment = UserRole(
            user_id=user_id,
            role_id=role_id,
            granted_by=granted_by,
        )
        self.session.add(assignment)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints") from exc
        return assignment

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Delete the ``user_id``/``role_id`` assignment and flush.

        Raises:
            EntityNotFoundError: if the user does not hold that role.
        """
        stmt = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        assignment = self.session.execute(stmt).scalar_one_or_none()
        if assignment is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(assignment)
        self.session.flush()
# Built-in roles seeded by ``ensure_default_roles``. Each dict is splatted
# into ``Role(**definition)``, so its keys must be Role constructor arguments.
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = tuple(
    {"name": name, "display_name": display_name, "description": description}
    for name, display_name, description in (
        (
            "admin",
            "Administrator",
            "Full platform access with user management rights.",
        ),
        (
            "project_manager",
            "Project Manager",
            "Manage projects, scenarios, and associated data.",
        ),
        (
            "analyst",
            "Analyst",
            "Review dashboards and scenario outputs.",
        ),
        (
            "viewer",
            "Viewer",
            "Read-only access to assigned projects and reports.",
        ),
    )
)
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
    """Create any missing built-in roles and return the full set.

    Each entry of ``DEFAULT_ROLE_DEFINITIONS`` is looked up by name through
    ``role_repo.get_by_name``. A found record is reused; otherwise a new
    ``Role`` is built from the definition and persisted with
    ``role_repo.create``.

    Returns:
        One role per definition, in ``DEFAULT_ROLE_DEFINITIONS`` order.
    """

    def _existing_or_created(definition: dict[str, str]) -> Role:
        # Look up by name first; only create when the lookup comes back empty.
        found = role_repo.get_by_name(definition["name"])
        return found if found else role_repo.create(Role(**definition))

    return [
        _existing_or_created(definition)
        for definition in DEFAULT_ROLE_DEFINITIONS
    ]
def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Ensure an administrator user exists and holds the admin role.

    The user is looked up by ``email``. If absent, a new active superuser
    is created with ``username`` and a hash of ``password``. If present,
    the user is re-activated and promoted to superuser when needed; the
    password and username are left untouched. The ``admin`` role is
    created from ``DEFAULT_ROLE_DEFINITIONS`` if missing, then assigned.
    For a newly created user the grant is recorded as self-granted.

    Returns:
        The admin ``User`` instance.

    Raises:
        EntityConflictError: propagated from the repositories when a
            create/flush violates a constraint.
    """
    user = user_repo.get_by_email(email, with_roles=True)
    if user is None:
        user = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        user_repo.create(user)
    else:
        if not user.is_active:
            user.is_active = True
        if not user.is_superuser:
            user.is_superuser = True
        user_repo.session.flush()

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        # Reuse the canonical definition so the fallback cannot drift from
        # the role seeded by ensure_default_roles.
        admin_definition = next(
            definition
            for definition in DEFAULT_ROLE_DEFINITIONS
            if definition["name"] == "admin"
        )
        admin_role = role_repo.create(Role(**admin_definition))

    user_repo.assign_role(
        user_id=user.id,
        role_id=admin_role.id,
        granted_by=user.id,
    )
    return user