from __future__ import annotations from contextlib import AbstractContextManager from typing import Callable from sqlalchemy.orm import Session from config.database import SessionLocal from services.repositories import ( FinancialInputRepository, ProjectRepository, ScenarioRepository, SimulationParameterRepository, ) class UnitOfWork(AbstractContextManager["UnitOfWork"]): """Simple unit-of-work wrapper around SQLAlchemy sessions.""" def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: self._session_factory = session_factory self.session: Session | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() self.projects = ProjectRepository(self.session) self.scenarios = ScenarioRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository(self.session) return self def __exit__(self, exc_type, exc_value, traceback) -> None: assert self.session is not None if exc_type is None: self.session.commit() else: self.session.rollback() self.session.close() def flush(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.flush() def commit(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.commit() def rollback(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.rollback()