Files
calminer/services/repositories.py

155 lines
5.5 KiB
Python

from __future__ import annotations
from collections.abc import Iterable
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from models import FinancialInput, Project, Scenario, SimulationParameter
from services.exceptions import EntityConflictError, EntityNotFoundError
class ProjectRepository:
    """Persistence operations for Project entities.

    Every method works through the SQLAlchemy ``Session`` supplied at
    construction time. Nothing here commits; ``create`` flushes, the other
    methods only read or register changes on the session.
    """

    def __init__(self, session: Session) -> None:
        """Store the session used by all repository operations."""
        self.session = session

    def list(self) -> Sequence[Project]:
        """Return all projects ordered by ``Project.created_at``."""
        stmt = select(Project).order_by(Project.created_at)
        return self.session.execute(stmt).scalars().all()

    def get(self, project_id: int, *, with_children: bool = False) -> Project:
        """Return the project whose id equals ``project_id``.

        Args:
            project_id: Primary key value to look up.
            with_children: When true, ``Project.scenarios`` is loaded in the
                same query through a joined eager load.

        Raises:
            EntityNotFoundError: If no row matches ``project_id``.
        """
        stmt = select(Project).where(Project.id == project_id)
        if with_children:
            stmt = stmt.options(joinedload(Project.scenarios))
        result = self.session.execute(stmt)
        if with_children:
            # A joined eager load against a collection repeats the parent row
            # once per child; SQLAlchemy refuses to return scalars from such a
            # result until unique() has been applied.
            result = result.unique()
        project = result.scalar_one_or_none()
        if project is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return project

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Returns:
            The same ``project`` instance that was passed in.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        return project

    def delete(self, project_id: int) -> None:
        """Mark the project with ``project_id`` for deletion.

        No flush or commit is issued here.

        Raises:
            EntityNotFoundError: If no row matches ``project_id``.
        """
        project = self.get(project_id)
        self.session.delete(project)
class ScenarioRepository:
    """Persistence operations for Scenario entities.

    All work goes through the SQLAlchemy ``Session`` handed to the
    constructor; no method here commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session that backs this repository."""
        self.session = session

    def list_for_project(self, project_id: int) -> Sequence[Scenario]:
        """Return the scenarios of ``project_id`` ordered by ``created_at``."""
        query = select(Scenario).where(Scenario.project_id == project_id)
        query = query.order_by(Scenario.created_at)
        rows = self.session.execute(query)
        return rows.scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Return the scenario whose id equals ``scenario_id``.

        Args:
            scenario_id: Primary key value to look up.
            with_children: When true, ``financial_inputs`` and
                ``simulation_parameters`` are joined-eager-loaded and the
                result is de-duplicated with ``unique()`` before reading.

        Raises:
            EntityNotFoundError: If no row matches ``scenario_id``.
        """
        query = select(Scenario).where(Scenario.id == scenario_id)
        if not with_children:
            found = self.session.execute(query).scalar_one_or_none()
        else:
            eager = query.options(
                joinedload(Scenario.financial_inputs),
                joinedload(Scenario.simulation_parameters),
            )
            # unique() collapses the duplicate parent rows produced by the
            # collection joins.
            found = self.session.execute(eager).unique().scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Scenario {scenario_id} not found")

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session, flush, and hand it back.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        else:
            return scenario

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario with ``scenario_id`` for deletion (no flush).

        Raises:
            EntityNotFoundError: If no row matches ``scenario_id``.
        """
        self.session.delete(self.get(scenario_id))
class FinancialInputRepository:
    """Persistence operations for FinancialInput entities.

    All work goes through the SQLAlchemy ``Session`` handed to the
    constructor; no method here commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session that backs this repository."""
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the inputs of ``scenario_id`` ordered by ``created_at``."""
        query = select(FinancialInput).where(
            FinancialInput.scenario_id == scenario_id
        )
        query = query.order_by(FinancialInput.created_at)
        rows = self.session.execute(query)
        return rows.scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Add every item of ``inputs`` to the session and flush once.

        The iterable is materialised into a list first, and that list is
        what gets returned. Despite the name, this method only calls
        ``add_all`` followed by ``flush``; no explicit merge is performed.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        batch = [*inputs]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        else:
            return batch

    def delete(self, input_id: int) -> None:
        """Mark the financial input with ``input_id`` for deletion (no flush).

        Raises:
            EntityNotFoundError: If no row matches ``input_id``.
        """
        lookup = select(FinancialInput).where(FinancialInput.id == input_id)
        found = self.session.execute(lookup).scalar_one_or_none()
        if found is not None:
            self.session.delete(found)
            return
        raise EntityNotFoundError(f"Financial input {input_id} not found")
class SimulationParameterRepository:
    """Persistence operations for SimulationParameter entities.

    All work goes through the SQLAlchemy ``Session`` handed to the
    constructor; no method here commits.
    """

    def __init__(self, session: Session) -> None:
        """Remember the session that backs this repository."""
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the parameters of ``scenario_id`` ordered by ``created_at``."""
        query = select(SimulationParameter).where(
            SimulationParameter.scenario_id == scenario_id
        )
        query = query.order_by(SimulationParameter.created_at)
        rows = self.session.execute(query)
        return rows.scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Add every item of ``parameters`` to the session and flush once.

        The iterable is materialised into a list first, and that list is
        what gets returned. Despite the name, this method only calls
        ``add_all`` followed by ``flush``; no explicit merge is performed.

        Raises:
            EntityConflictError: If the flush raises ``IntegrityError``.
        """
        batch = [*parameters]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from exc
        else:
            return batch

    def delete(self, parameter_id: int) -> None:
        """Mark the parameter with ``parameter_id`` for deletion (no flush).

        Raises:
            EntityNotFoundError: If no row matches ``parameter_id``.
        """
        lookup = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id)
        found = self.session.execute(lookup).scalar_one_or_none()
        if found is not None:
            self.session.delete(found)
            return
        raise EntityNotFoundError(
            f"Simulation parameter {parameter_id} not found")