feat: implement repository and unit-of-work patterns for service layer operations
This commit is contained in:
@@ -4,3 +4,4 @@
|
||||
|
||||
- Captured current implementation status, requirements coverage, missing features, and prioritized roadmap in `calminer-docs/implementation_status.md` to guide future development.
|
||||
- Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation.
|
||||
- Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations.
|
||||
|
||||
9
services/exceptions.py
Normal file
9
services/exceptions.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Domain-level exceptions for service and repository layers."""
|
||||
|
||||
|
||||
class EntityNotFoundError(Exception):
|
||||
"""Raised when a requested entity cannot be located."""
|
||||
|
||||
|
||||
class EntityConflictError(Exception):
|
||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||
146
services/repositories.py
Normal file
146
services/repositories.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from models import FinancialInput, Project, Scenario, SimulationParameter
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
|
||||
|
||||
class ProjectRepository:
|
||||
"""Persistence operations for Project entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self) -> Sequence[Project]:
|
||||
stmt = select(Project).order_by(Project.created_at)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, project_id: int, *, with_children: bool = False) -> Project:
|
||||
stmt = select(Project).where(Project.id == project_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(joinedload(Project.scenarios))
|
||||
project = self.session.execute(stmt).scalar_one_or_none()
|
||||
if project is None:
|
||||
raise EntityNotFoundError(f"Project {project_id} not found")
|
||||
return project
|
||||
|
||||
def create(self, project: Project) -> Project:
|
||||
self.session.add(project)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - reliance on DB constraints
|
||||
raise EntityConflictError("Project violates uniqueness constraints") from exc
|
||||
return project
|
||||
|
||||
def delete(self, project_id: int) -> None:
|
||||
project = self.get(project_id)
|
||||
self.session.delete(project)
|
||||
|
||||
|
||||
class ScenarioRepository:
|
||||
"""Persistence operations for Scenario entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_project(self, project_id: int) -> Sequence[Scenario]:
|
||||
stmt = (
|
||||
select(Scenario)
|
||||
.where(Scenario.project_id == project_id)
|
||||
.order_by(Scenario.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
|
||||
stmt = select(Scenario).where(Scenario.id == scenario_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(
|
||||
joinedload(Scenario.financial_inputs),
|
||||
joinedload(Scenario.simulation_parameters),
|
||||
)
|
||||
scenario = self.session.execute(stmt).scalar_one_or_none()
|
||||
if scenario is None:
|
||||
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
||||
return scenario
|
||||
|
||||
def create(self, scenario: Scenario) -> Scenario:
|
||||
self.session.add(scenario)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Scenario violates constraints") from exc
|
||||
return scenario
|
||||
|
||||
def delete(self, scenario_id: int) -> None:
|
||||
scenario = self.get(scenario_id)
|
||||
self.session.delete(scenario)
|
||||
|
||||
|
||||
class FinancialInputRepository:
|
||||
"""Persistence operations for FinancialInput entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
|
||||
stmt = (
|
||||
select(FinancialInput)
|
||||
.where(FinancialInput.scenario_id == scenario_id)
|
||||
.order_by(FinancialInput.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
|
||||
entities = list(inputs)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Financial input violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, input_id: int) -> None:
|
||||
stmt = select(FinancialInput).where(FinancialInput.id == input_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(f"Financial input {input_id} not found")
|
||||
self.session.delete(entity)
|
||||
|
||||
|
||||
class SimulationParameterRepository:
|
||||
"""Persistence operations for SimulationParameter entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
|
||||
stmt = (
|
||||
select(SimulationParameter)
|
||||
.where(SimulationParameter.scenario_id == scenario_id)
|
||||
.order_by(SimulationParameter.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(
|
||||
self, parameters: Iterable[SimulationParameter]
|
||||
) -> Sequence[SimulationParameter]:
|
||||
entities = list(parameters)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Simulation parameter violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, parameter_id: int) -> None:
|
||||
stmt = select(SimulationParameter).where(SimulationParameter.id == parameter_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(f"Simulation parameter {parameter_id} not found")
|
||||
self.session.delete(entity)
|
||||
53
services/unit_of_work.py
Normal file
53
services/unit_of_work.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from services.repositories import (
|
||||
FinancialInputRepository,
|
||||
ProjectRepository,
|
||||
ScenarioRepository,
|
||||
SimulationParameterRepository,
|
||||
)
|
||||
|
||||
|
||||
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
"""Simple unit-of-work wrapper around SQLAlchemy sessions."""
|
||||
|
||||
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
|
||||
self._session_factory = session_factory
|
||||
self.session: Session | None = None
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
self.session = self._session_factory()
|
||||
self.projects = ProjectRepository(self.session)
|
||||
self.scenarios = ScenarioRepository(self.session)
|
||||
self.financial_inputs = FinancialInputRepository(self.session)
|
||||
self.simulation_parameters = SimulationParameterRepository(self.session)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
assert self.session is not None
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.flush()
|
||||
|
||||
def commit(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.commit()
|
||||
|
||||
def rollback(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.rollback()
|
||||
Reference in New Issue
Block a user