157 lines
5.6 KiB
Python
157 lines
5.6 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Sequence
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
|
|
from models import FinancialInput, Project, Scenario, SimulationParameter
|
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
|
|
|
|
|
class ProjectRepository:
    """Persistence operations for Project entities.

    All methods work through the SQLAlchemy ``Session`` supplied at
    construction time. Nothing here commits; ``create`` only flushes.
    """

    def __init__(self, session: Session) -> None:
        """Store the session that every operation will run against."""
        self.session = session

    def list(self, *, with_children: bool = False) -> Sequence[Project]:
        """Return all projects ordered by ``created_at``.

        Args:
            with_children: When true, eagerly load ``Project.scenarios``
                using ``selectinload``.
        """
        stmt = select(Project).order_by(Project.created_at)
        if with_children:
            stmt = stmt.options(selectinload(Project.scenarios))
        return self.session.execute(stmt).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Return the project whose id equals ``project_id``.

        Args:
            project_id: Primary key value to look up.
            with_children: When true, eagerly load ``Project.scenarios``
                using ``joinedload``.

        Raises:
            EntityNotFoundError: If no project matches ``project_id``.
        """
        stmt = select(Project).where(Project.id == project_id)
        if with_children:
            stmt = stmt.options(joinedload(Project.scenarios))
        result = self.session.execute(stmt)
        if with_children:
            # A joined eager load against a collection repeats the parent
            # row once per child; SQLAlchemy refuses to hand out scalars
            # from such a Result until unique() has been applied.
            result = result.unique()
        project = result.scalar_one_or_none()
        if project is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return project

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Returns:
            The same ``project`` instance that was passed in.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        return project

    def delete(self, project_id: int) -> None:
        """Mark the project with ``project_id`` for deletion in the session.

        Raises:
            EntityNotFoundError: If no project matches ``project_id``
                (raised by ``get``).
        """
        project = self.get(project_id)
        self.session.delete(project)
|
|
|
|
|
|
class ScenarioRepository:
    """Data-access helper for Scenario rows.

    Every call goes through the session handed to the constructor; no
    method in this class commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session used for all queries and writes."""
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios of one project, ordered by ``created_at``."""
        query = select(Scenario).where(Scenario.project_id == project_id)
        ordered = query.order_by(Scenario.created_at)
        rows = self.session.execute(ordered)
        return rows.scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Fetch a single scenario by id.

        Args:
            scenario_id: Primary key value to look up.
            with_children: When true, ``financial_inputs`` and
                ``simulation_parameters`` are joined-loaded and the result
                is de-duplicated with ``unique()`` before it is read.

        Raises:
            EntityNotFoundError: If no scenario matches ``scenario_id``.
        """
        query = select(Scenario).where(Scenario.id == scenario_id)
        if not with_children:
            found = self.session.execute(query).scalar_one_or_none()
        else:
            eager = query.options(
                joinedload(Scenario.financial_inputs),
                joinedload(Scenario.simulation_parameters),
            )
            found = self.session.execute(eager).unique().scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Scenario {scenario_id} not found")

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session, flush, and hand it back.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as err:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from err
        else:
            return scenario

    def delete(self, scenario_id: int) -> None:
        """Look the scenario up via ``get`` and mark it for deletion.

        Raises:
            EntityNotFoundError: If no scenario matches ``scenario_id``.
        """
        self.session.delete(self.get(scenario_id))
|
|
|
|
|
|
class FinancialInputRepository:
    """Data-access helper for FinancialInput rows.

    Operates on the session passed to the constructor and never commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session used for all queries and writes."""
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return one scenario's financial inputs, ordered by ``created_at``."""
        query = select(FinancialInput).where(
            FinancialInput.scenario_id == scenario_id)
        ordered = query.order_by(FinancialInput.created_at)
        rows = self.session.execute(ordered)
        return rows.scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Add every item of ``inputs`` to the session and flush once.

        The iterable is materialised into a list first, and that list is
        what gets returned. As written, this only calls ``add_all`` and
        ``flush``; no separate update path is implemented here.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        pending = list(inputs)
        self.session.add_all(pending)
        try:
            self.session.flush()
        except IntegrityError as err:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from err
        else:
            return pending

    def delete(self, input_id: int) -> None:
        """Mark the financial input with ``input_id`` for deletion.

        Raises:
            EntityNotFoundError: If no financial input matches ``input_id``.
        """
        lookup = select(FinancialInput).where(FinancialInput.id == input_id)
        target = self.session.execute(lookup).scalar_one_or_none()
        if target is not None:
            self.session.delete(target)
            return
        raise EntityNotFoundError(f"Financial input {input_id} not found")
|
|
|
|
|
|
class SimulationParameterRepository:
    """Data-access helper for SimulationParameter rows.

    Operates on the session passed to the constructor and never commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session used for all queries and writes."""
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return one scenario's parameters, ordered by ``created_at``."""
        query = select(SimulationParameter).where(
            SimulationParameter.scenario_id == scenario_id)
        ordered = query.order_by(SimulationParameter.created_at)
        rows = self.session.execute(ordered)
        return rows.scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Add every item of ``parameters`` to the session and flush once.

        The iterable is materialised into a list first, and that list is
        what gets returned. As written, this only calls ``add_all`` and
        ``flush``; no separate update path is implemented here.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        pending = list(parameters)
        self.session.add_all(pending)
        try:
            self.session.flush()
        except IntegrityError as err:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from err
        else:
            return pending

    def delete(self, parameter_id: int) -> None:
        """Mark the simulation parameter with ``parameter_id`` for deletion.

        Raises:
            EntityNotFoundError: If no parameter matches ``parameter_id``.
        """
        lookup = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id)
        target = self.session.execute(lookup).scalar_one_or_none()
        if target is not None:
            self.session.delete(target)
            return
        raise EntityNotFoundError(
            f"Simulation parameter {parameter_id} not found")
|