from __future__ import annotations

from collections.abc import Iterable
from datetime import datetime
from typing import Mapping, Sequence

from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload

from models import (
    FinancialInput,
    Project,
    Role,
    Scenario,
    ScenarioStatus,
    SimulationParameter,
    User,
    UserRole,
)
from services.exceptions import EntityConflictError, EntityNotFoundError


class ProjectRepository:
    """Persistence operations for Project entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_children: bool = False) -> Sequence[Project]:
        """Return every project ordered by ``created_at``.

        When ``with_children`` is true, ``Project.scenarios`` is loaded via
        ``selectinload``.
        """
        stmt = select(Project).order_by(Project.created_at)
        if with_children:
            stmt = stmt.options(selectinload(Project.scenarios))
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the number of project rows."""
        stmt = select(func.count(Project.id))
        return self.session.execute(stmt).scalar_one()

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return up to ``limit`` projects, most recently updated first."""
        stmt = (
            select(Project)
            .order_by(Project.updated_at.desc())
            .limit(limit)
        )
        return self.session.execute(stmt).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Return the project with ``project_id``.

        Raises:
            EntityNotFoundError: if no such project exists.
        """
        stmt = select(Project).where(Project.id == project_id)
        if with_children:
            stmt = stmt.options(joinedload(Project.scenarios))
        result = self.session.execute(stmt)
        if with_children:
            # Joined eager loading of a collection repeats the parent row;
            # unique() collapses the duplicates before fetching.
            result = result.unique()
        project = result.scalar_one_or_none()
        if project is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return project

    def exists(self, project_id: int) -> bool:
        """Return True if a project with ``project_id`` exists."""
        stmt = select(Project.id).where(Project.id == project_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            # NOTE(review): the session is not rolled back here; presumably
            # the caller owns the transaction - confirm.
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        return project

    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
        """Look up projects by name, case-insensitively.

        Blank or empty names are ignored. The result maps each matched
        project's lower-cased name to the project; an empty mapping is
        returned (without querying) when no usable names are supplied.
        """
        normalised = {
            name.strip().lower() for name in names if name and name.strip()
        }
        if not normalised:
            return {}
        stmt = select(Project).where(func.lower(Project.name).in_(normalised))
        records = self.session.execute(stmt).scalars().all()
        return {project.name.lower(): project for project in records}

    def delete(self, project_id: int) -> None:
        """Mark the project for deletion (no flush is issued here).

        Raises:
            EntityNotFoundError: if no such project exists.
        """
        project = self.get(project_id)
        self.session.delete(project)


class ScenarioRepository:
    """Persistence operations for Scenario entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios of ``project_id`` ordered by ``created_at``."""
        stmt = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the number of scenario rows."""
        stmt = select(func.count(Scenario.id))
        return self.session.execute(stmt).scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return the number of scenarios whose status equals ``status``."""
        stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
        return self.session.execute(stmt).scalar_one()

    def recent(
        self, limit: int = 5, *, with_project: bool = False
    ) -> Sequence[Scenario]:
        """Return up to ``limit`` scenarios, most recently updated first.

        When ``with_project`` is true, ``Scenario.project`` is joined-loaded.
        """
        stmt = (
            select(Scenario)
            .order_by(Scenario.updated_at.desc())
            .limit(limit)
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        return self.session.execute(stmt).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, most recently updated first.

        ``limit`` caps the number of rows when given; ``with_project``
        joined-loads ``Scenario.project``.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.status == status)
            .order_by(Scenario.updated_at.desc())
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Return the scenario with ``scenario_id``.

        When ``with_children`` is true, ``financial_inputs`` and
        ``simulation_parameters`` are joined-loaded.

        Raises:
            EntityNotFoundError: if no such scenario exists.
        """
        stmt = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            stmt = stmt.options(
                joinedload(Scenario.financial_inputs),
                joinedload(Scenario.simulation_parameters),
            )
        result = self.session.execute(stmt)
        if with_children:
            # Joined collection loads duplicate the parent row; de-duplicate.
            result = result.unique()
        scenario = result.scalar_one_or_none()
        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return True if a scenario with ``scenario_id`` exists."""
        stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        return scenario

    def find_by_project_and_names(
        self,
        project_id: int,
        names: Iterable[str],
    ) -> Mapping[str, Scenario]:
        """Look up scenarios of one project by name, case-insensitively.

        Blank or empty names are ignored. The result maps each matched
        scenario's lower-cased name to the scenario; an empty mapping is
        returned (without querying) when no usable names are supplied.
        """
        normalised = {
            name.strip().lower() for name in names if name and name.strip()
        }
        if not normalised:
            return {}
        stmt = select(Scenario).where(
            Scenario.project_id == project_id,
            func.lower(Scenario.name).in_(normalised),
        )
        records = self.session.execute(stmt).scalars().all()
        return {scenario.name.lower(): scenario for scenario in records}

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario for deletion (no flush is issued here).

        Raises:
            EntityNotFoundError: if no such scenario exists.
        """
        scenario = self.get(scenario_id)
        self.session.delete(scenario)


class FinancialInputRepository:
    """Persistence operations for FinancialInput entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the inputs of ``scenario_id`` ordered by ``created_at``."""
        stmt = (
            select(FinancialInput)
            .where(FinancialInput.scenario_id == scenario_id)
            .order_by(FinancialInput.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def bulk_upsert(
        self, inputs: Iterable[FinancialInput]
    ) -> Sequence[FinancialInput]:
        """Add all ``inputs`` to the session and flush once.

        Despite the name, no explicit merge/upsert statement is issued here:
        the entities are handed to ``Session.add_all`` and flushed.

        Returns:
            The entities as a list, in the order they were supplied.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        # Materialise first so a one-shot iterable can be both added and returned.
        entities = list(inputs)
        self.session.add_all(entities)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        return entities

    def delete(self, input_id: int) -> None:
        """Mark the financial input for deletion (no flush is issued here).

        Raises:
            EntityNotFoundError: if no such financial input exists.
        """
        stmt = select(FinancialInput).where(FinancialInput.id == input_id)
        entity = self.session.execute(stmt).scalar_one_or_none()
        if entity is None:
            raise EntityNotFoundError(f"Financial input {input_id} not found")
        self.session.delete(entity)

    def latest_created_at(self) -> datetime | None:
        """Return the newest ``created_at`` value, or None if there are no rows."""
        stmt = (
            select(FinancialInput.created_at)
            .order_by(FinancialInput.created_at.desc())
            .limit(1)
        )
        return self.session.execute(stmt).scalar_one_or_none()


class SimulationParameterRepository:
    """Persistence operations for SimulationParameter entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the parameters of ``scenario_id`` ordered by ``created_at``."""
        stmt = (
            select(SimulationParameter)
            .where(SimulationParameter.scenario_id == scenario_id)
            .order_by(SimulationParameter.created_at)
        )
        return self.session.execute(stmt).scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Add all ``parameters`` to the session and flush once.

        Despite the name, no explicit merge/upsert statement is issued here:
        the entities are handed to ``Session.add_all`` and flushed.

        Returns:
            The entities as a list, in the order they were supplied.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        # Materialise first so a one-shot iterable can be both added and returned.
        entities = list(parameters)
        self.session.add_all(entities)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from exc
        return entities

    def delete(self, parameter_id: int) -> None:
        """Mark the simulation parameter for deletion (no flush is issued here).

        Raises:
            EntityNotFoundError: if no such simulation parameter exists.
        """
        stmt = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id)
        entity = self.session.execute(stmt).scalar_one_or_none()
        if entity is None:
            raise EntityNotFoundError(
                f"Simulation parameter {parameter_id} not found")
        self.session.delete(entity)


class RoleRepository:
    """Persistence operations for Role entities."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self) -> Sequence[Role]:
        """Return every role ordered by ``name``."""
        stmt = select(Role).order_by(Role.name)
        return self.session.execute(stmt).scalars().all()

    def get(self, role_id: int) -> Role:
        """Return the role with ``role_id``.

        Raises:
            EntityNotFoundError: if no such role exists.
        """
        stmt = select(Role).where(Role.id == role_id)
        role = self.session.execute(stmt).scalar_one_or_none()
        if role is None:
            raise EntityNotFoundError(f"Role {role_id} not found")
        return role

    def get_by_name(self, name: str) -> Role | None:
        """Return the role whose name equals ``name``, or None."""
        stmt = select(Role).where(Role.name == name)
        return self.session.execute(stmt).scalar_one_or_none()

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush it.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        self.session.add(role)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints") from exc
        return role


class UserRepository:
    """Persistence operations for User entities and their role assignments."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return every user ordered by ``created_at``.

        When ``with_roles`` is true, ``User.roles`` is loaded via
        ``selectinload``.
        """
        stmt = select(User).order_by(User.created_at)
        if with_roles:
            stmt = stmt.options(selectinload(User.roles))
        return self.session.execute(stmt).scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Attach role eager-loading options to ``stmt`` when requested."""
        if with_roles:
            stmt = stmt.options(
                joinedload(User.role_assignments).joinedload(UserRole.role),
                selectinload(User.roles),
            )
        return stmt

    def _fetch_one(self, criterion, *, with_roles: bool) -> User | None:
        """Return the single user matching ``criterion``, or None.

        Shared query pipeline for ``get``, ``get_by_email`` and
        ``get_by_username``.
        """
        # populate_existing asks SQLAlchemy to refresh objects that are
        # already present in the session with the freshly fetched row data.
        stmt = (
            select(User)
            .where(criterion)
            .execution_options(populate_existing=True)
        )
        stmt = self._apply_role_option(stmt, with_roles)
        result = self.session.execute(stmt)
        if with_roles:
            # The joined role_assignments load duplicates the user row.
            result = result.unique()
        return result.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Return the user with ``user_id``.

        Raises:
            EntityNotFoundError: if no such user exists.
        """
        user = self._fetch_one(User.id == user_id, with_roles=with_roles)
        if user is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return user

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose email equals ``email``, or None."""
        return self._fetch_one(User.email == email, with_roles=with_roles)

    def get_by_username(
        self, username: str, *, with_roles: bool = False
    ) -> User | None:
        """Return the user whose username equals ``username``, or None."""
        return self._fetch_one(User.username == username, with_roles=with_roles)

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush it.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        self.session.add(user)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints") from exc
        return user

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Return the user/role assignment, creating it if it is missing.

        An existing assignment is returned unchanged (``granted_by`` is not
        updated); otherwise a new ``UserRole`` is added and flushed.

        Raises:
            EntityConflictError: if the flush raises ``IntegrityError``.
        """
        stmt = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        assignment = self.session.execute(stmt).scalar_one_or_none()
        if assignment is not None:
            return assignment

        assignment = UserRole(
            user_id=user_id,
            role_id=role_id,
            granted_by=granted_by,
        )
        self.session.add(assignment)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints") from exc
        return assignment

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Delete the user/role assignment and flush.

        Raises:
            EntityNotFoundError: if the role is not assigned to the user.
        """
        stmt = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        assignment = self.session.execute(stmt).scalar_one_or_none()
        if assignment is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(assignment)
        self.session.flush()


# Keyword arguments for the Role records created by ensure_default_roles.
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
    {
        "name": "admin",
        "display_name": "Administrator",
        "description": "Full platform access with user management rights.",
    },
    {
        "name": "project_manager",
        "display_name": "Project Manager",
        "description": "Manage projects, scenarios, and associated data.",
    },
    {
        "name": "analyst",
        "display_name": "Analyst",
        "description": "Review dashboards and scenario outputs.",
    },
    {
        "name": "viewer",
        "display_name": "Viewer",
        "description": "Read-only access to assigned projects and reports.",
    },
)


def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
    """Ensure standard roles exist, creating missing ones.

    Returns the role records for ``DEFAULT_ROLE_DEFINITIONS`` only, in the
    order of that tuple (existing records are reused, missing ones created).
    """
    roles: list[Role] = []
    for definition in DEFAULT_ROLE_DEFINITIONS:
        role = role_repo.get_by_name(definition["name"])
        if role is None:
            role = role_repo.create(Role(**definition))
        roles.append(role)
    return roles


def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Ensure an administrator user exists and holds the admin role.

    The user is looked up by ``email``. A missing user is created active and
    superuser with the hashed ``password``; an existing user is only promoted
    to active/superuser (``username`` and ``password`` are not applied to it).
    The admin role is then assigned, recorded as granted by the user itself.
    """
    user = user_repo.get_by_email(email, with_roles=True)
    if user is None:
        user = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        user_repo.create(user)
    else:
        # Only assign when the flag is actually off.
        if not user.is_active:
            user.is_active = True
        if not user.is_superuser:
            user.is_superuser = True
        user_repo.session.flush()

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        admin_role = role_repo.create(
            Role(
                name="admin",
                display_name="Administrator",
                description="Full platform access with user management rights.",
            )
        )

    user_repo.assign_role(
        user_id=user.id,
        role_id=admin_role.id,
        granted_by=user.id,
    )
    return user