diff --git a/changelog.md b/changelog.md index 1f5de9f..169fa8a 100644 --- a/changelog.md +++ b/changelog.md @@ -4,3 +4,4 @@ - Captured current implementation status, requirements coverage, missing features, and prioritized roadmap in `calminer-docs/implementation_status.md` to guide future development. - Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation. +- Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations. diff --git a/services/exceptions.py b/services/exceptions.py new file mode 100644 index 0000000..6e58c29 --- /dev/null +++ b/services/exceptions.py @@ -0,0 +1,9 @@ +"""Domain-level exceptions for service and repository layers.""" + + +class EntityNotFoundError(Exception): + """Raised when a requested entity cannot be located.""" + + +class EntityConflictError(Exception): + """Raised when attempting to create or update an entity that violates uniqueness.""" diff --git a/services/repositories.py b/services/repositories.py new file mode 100644 index 0000000..bafff7b --- /dev/null +++ b/services/repositories.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Sequence + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, joinedload + +from models import FinancialInput, Project, Scenario, SimulationParameter +from services.exceptions import EntityConflictError, EntityNotFoundError + + +class ProjectRepository: + """Persistence operations for Project entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self) -> Sequence[Project]: + stmt = select(Project).order_by(Project.created_at) + return self.session.execute(stmt).scalars().all() + + def get(self, project_id: int, *, with_children: bool = False) -> Project: + stmt = select(Project).where(Project.id == project_id) + if with_children: + stmt = stmt.options(joinedload(Project.scenarios)) + project = self.session.execute(stmt).scalar_one_or_none() + if project is None: + raise EntityNotFoundError(f"Project {project_id} not found") + return project + + def create(self, project: Project) -> Project: + self.session.add(project) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - reliance on DB constraints + raise EntityConflictError("Project violates uniqueness constraints") from exc + return project + + def delete(self, project_id: int) -> None: + project = self.get(project_id) + self.session.delete(project) + + +class ScenarioRepository: + """Persistence operations for Scenario entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_project(self, project_id: int) -> Sequence[Scenario]: + stmt = ( + select(Scenario) + .where(Scenario.project_id == project_id) + .order_by(Scenario.created_at) + ) + return self.session.execute(stmt).scalars().all() + + def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario: + stmt = select(Scenario).where(Scenario.id == scenario_id) + if with_children: + stmt = stmt.options( + joinedload(Scenario.financial_inputs), + joinedload(Scenario.simulation_parameters), + ) + scenario = self.session.execute(stmt).scalar_one_or_none() + if scenario is None: + raise EntityNotFoundError(f"Scenario {scenario_id} not found") + return scenario + + def create(self, scenario: Scenario) -> Scenario: + self.session.add(scenario) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + raise EntityConflictError("Scenario violates constraints") from exc + return scenario + + def delete(self, scenario_id: int) -> None: + scenario = self.get(scenario_id) + self.session.delete(scenario) + + +class FinancialInputRepository: + """Persistence operations for FinancialInput entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]: + stmt = ( + select(FinancialInput) + .where(FinancialInput.scenario_id == scenario_id) + .order_by(FinancialInput.created_at) + ) + return self.session.execute(stmt).scalars().all() + + def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]: + entities = list(inputs) + self.session.add_all(entities) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + raise EntityConflictError("Financial input violates constraints") from exc + return entities + + def delete(self, input_id: int) -> None: + stmt = select(FinancialInput).where(FinancialInput.id == input_id) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError(f"Financial input {input_id} not found") + self.session.delete(entity) + + +class SimulationParameterRepository: + """Persistence operations for SimulationParameter entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]: + stmt = ( + select(SimulationParameter) + .where(SimulationParameter.scenario_id == scenario_id) + .order_by(SimulationParameter.created_at) + ) + return self.session.execute(stmt).scalars().all() + + def bulk_upsert( + self, parameters: Iterable[SimulationParameter] + ) -> Sequence[SimulationParameter]: + entities = list(parameters) + self.session.add_all(entities) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + raise EntityConflictError("Simulation parameter violates constraints") from exc + return entities + + def delete(self, parameter_id: int) -> None: + stmt = select(SimulationParameter).where(SimulationParameter.id == parameter_id) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError(f"Simulation parameter {parameter_id} not found") + self.session.delete(entity) diff --git a/services/unit_of_work.py b/services/unit_of_work.py new file mode 100644 index 0000000..6bb8cb8 --- /dev/null +++ b/services/unit_of_work.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from contextlib import AbstractContextManager +from typing import Callable + +from sqlalchemy.orm import Session + +from config.database import SessionLocal +from services.repositories import ( + FinancialInputRepository, + ProjectRepository, + ScenarioRepository, + SimulationParameterRepository, +) + + +class UnitOfWork(AbstractContextManager["UnitOfWork"]): + """Simple unit-of-work wrapper around SQLAlchemy sessions.""" + + def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: + self._session_factory = session_factory + self.session: Session | None = None + + def __enter__(self) -> "UnitOfWork": + self.session = self._session_factory() + self.projects = ProjectRepository(self.session) + self.scenarios = ScenarioRepository(self.session) + self.financial_inputs = FinancialInputRepository(self.session) + self.simulation_parameters = SimulationParameterRepository(self.session) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + assert self.session is not None + if exc_type is None: + self.session.commit() + else: + self.session.rollback() + self.session.close() + + def flush(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.flush() + + def commit(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.commit() + + def rollback(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.rollback()