from __future__ import annotations from collections.abc import Iterable from datetime import datetime from typing import Sequence from sqlalchemy import select, func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload, selectinload from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter from services.exceptions import EntityConflictError, EntityNotFoundError class ProjectRepository: """Persistence operations for Project entities.""" def __init__(self, session: Session) -> None: self.session = session def list(self, *, with_children: bool = False) -> Sequence[Project]: stmt = select(Project).order_by(Project.created_at) if with_children: stmt = stmt.options(selectinload(Project.scenarios)) return self.session.execute(stmt).scalars().all() def count(self) -> int: stmt = select(func.count(Project.id)) return self.session.execute(stmt).scalar_one() def recent(self, limit: int = 5) -> Sequence[Project]: stmt = ( select(Project) .order_by(Project.updated_at.desc()) .limit(limit) ) return self.session.execute(stmt).scalars().all() def get(self, project_id: int, *, with_children: bool = False) -> Project: stmt = select(Project).where(Project.id == project_id) if with_children: stmt = stmt.options(joinedload(Project.scenarios)) result = self.session.execute(stmt) if with_children: result = result.unique() project = result.scalar_one_or_none() if project is None: raise EntityNotFoundError(f"Project {project_id} not found") return project def create(self, project: Project) -> Project: self.session.add(project) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - reliance on DB constraints raise EntityConflictError( "Project violates uniqueness constraints") from exc return project def delete(self, project_id: int) -> None: project = self.get(project_id) self.session.delete(project) class ScenarioRepository: """Persistence operations for Scenario entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_project(self, project_id: int) -> Sequence[Scenario]: stmt = ( select(Scenario) .where(Scenario.project_id == project_id) .order_by(Scenario.created_at) ) return self.session.execute(stmt).scalars().all() def count(self) -> int: stmt = select(func.count(Scenario.id)) return self.session.execute(stmt).scalar_one() def count_by_status(self, status: ScenarioStatus) -> int: stmt = select(func.count(Scenario.id)).where(Scenario.status == status) return self.session.execute(stmt).scalar_one() def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]: stmt = select(Scenario).order_by( Scenario.updated_at.desc()).limit(limit) if with_project: stmt = stmt.options(joinedload(Scenario.project)) return self.session.execute(stmt).scalars().all() def list_by_status( self, status: ScenarioStatus, *, limit: int | None = None, with_project: bool = False, ) -> Sequence[Scenario]: stmt = ( select(Scenario) .where(Scenario.status == status) .order_by(Scenario.updated_at.desc()) ) if with_project: stmt = stmt.options(joinedload(Scenario.project)) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario: stmt = select(Scenario).where(Scenario.id == scenario_id) if with_children: stmt = stmt.options( joinedload(Scenario.financial_inputs), joinedload(Scenario.simulation_parameters), ) result = self.session.execute(stmt) if with_children: result = result.unique() scenario = result.scalar_one_or_none() if scenario is None: raise EntityNotFoundError(f"Scenario {scenario_id} not found") return scenario def create(self, scenario: Scenario) -> Scenario: self.session.add(scenario) try: self.session.flush() except IntegrityError as exc: # pragma: no cover raise EntityConflictError("Scenario violates constraints") from exc return scenario def delete(self, scenario_id: int) -> None: scenario = self.get(scenario_id) self.session.delete(scenario) class FinancialInputRepository: """Persistence operations for FinancialInput entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]: stmt = ( select(FinancialInput) .where(FinancialInput.scenario_id == scenario_id) .order_by(FinancialInput.created_at) ) return self.session.execute(stmt).scalars().all() def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]: entities = list(inputs) self.session.add_all(entities) try: self.session.flush() except IntegrityError as exc: # pragma: no cover raise EntityConflictError( "Financial input violates constraints") from exc return entities def delete(self, input_id: int) -> None: stmt = select(FinancialInput).where(FinancialInput.id == input_id) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError(f"Financial input {input_id} not found") self.session.delete(entity) def latest_created_at(self) -> datetime | None: stmt = ( select(FinancialInput.created_at) .order_by(FinancialInput.created_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() class SimulationParameterRepository: """Persistence operations for SimulationParameter entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]: stmt = ( select(SimulationParameter) .where(SimulationParameter.scenario_id == scenario_id) .order_by(SimulationParameter.created_at) ) return self.session.execute(stmt).scalars().all() def bulk_upsert( self, parameters: Iterable[SimulationParameter] ) -> Sequence[SimulationParameter]: entities = list(parameters) self.session.add_all(entities) try: self.session.flush() except IntegrityError as exc: # pragma: no cover raise EntityConflictError( "Simulation parameter violates constraints") from exc return entities def delete(self, parameter_id: int) -> None: stmt = select(SimulationParameter).where( SimulationParameter.id == parameter_id) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Simulation parameter {parameter_id} not found") self.session.delete(entity)