feat: implement repository and unit-of-work patterns for service layer operations
This commit is contained in:
146
services/repositories.py
Normal file
146
services/repositories.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from models import FinancialInput, Project, Scenario, SimulationParameter
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
|
||||
|
||||
class ProjectRepository:
|
||||
"""Persistence operations for Project entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self) -> Sequence[Project]:
|
||||
stmt = select(Project).order_by(Project.created_at)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, project_id: int, *, with_children: bool = False) -> Project:
|
||||
stmt = select(Project).where(Project.id == project_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(joinedload(Project.scenarios))
|
||||
project = self.session.execute(stmt).scalar_one_or_none()
|
||||
if project is None:
|
||||
raise EntityNotFoundError(f"Project {project_id} not found")
|
||||
return project
|
||||
|
||||
def create(self, project: Project) -> Project:
|
||||
self.session.add(project)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - reliance on DB constraints
|
||||
raise EntityConflictError("Project violates uniqueness constraints") from exc
|
||||
return project
|
||||
|
||||
def delete(self, project_id: int) -> None:
|
||||
project = self.get(project_id)
|
||||
self.session.delete(project)
|
||||
|
||||
|
||||
class ScenarioRepository:
|
||||
"""Persistence operations for Scenario entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_project(self, project_id: int) -> Sequence[Scenario]:
|
||||
stmt = (
|
||||
select(Scenario)
|
||||
.where(Scenario.project_id == project_id)
|
||||
.order_by(Scenario.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
|
||||
stmt = select(Scenario).where(Scenario.id == scenario_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(
|
||||
joinedload(Scenario.financial_inputs),
|
||||
joinedload(Scenario.simulation_parameters),
|
||||
)
|
||||
scenario = self.session.execute(stmt).scalar_one_or_none()
|
||||
if scenario is None:
|
||||
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
||||
return scenario
|
||||
|
||||
def create(self, scenario: Scenario) -> Scenario:
|
||||
self.session.add(scenario)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Scenario violates constraints") from exc
|
||||
return scenario
|
||||
|
||||
def delete(self, scenario_id: int) -> None:
|
||||
scenario = self.get(scenario_id)
|
||||
self.session.delete(scenario)
|
||||
|
||||
|
||||
class FinancialInputRepository:
|
||||
"""Persistence operations for FinancialInput entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
|
||||
stmt = (
|
||||
select(FinancialInput)
|
||||
.where(FinancialInput.scenario_id == scenario_id)
|
||||
.order_by(FinancialInput.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
|
||||
entities = list(inputs)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Financial input violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, input_id: int) -> None:
|
||||
stmt = select(FinancialInput).where(FinancialInput.id == input_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(f"Financial input {input_id} not found")
|
||||
self.session.delete(entity)
|
||||
|
||||
|
||||
class SimulationParameterRepository:
|
||||
"""Persistence operations for SimulationParameter entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
|
||||
stmt = (
|
||||
select(SimulationParameter)
|
||||
.where(SimulationParameter.scenario_id == scenario_id)
|
||||
.order_by(SimulationParameter.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(
|
||||
self, parameters: Iterable[SimulationParameter]
|
||||
) -> Sequence[SimulationParameter]:
|
||||
entities = list(parameters)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Simulation parameter violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, parameter_id: int) -> None:
|
||||
stmt = select(SimulationParameter).where(SimulationParameter.id == parameter_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(f"Simulation parameter {parameter_id} not found")
|
||||
self.session.delete(entity)
|
||||
Reference in New Issue
Block a user