"""Unit of work that shares one SQLAlchemy session across all repositories."""

from __future__ import annotations

from collections.abc import Callable
from typing import Optional

from sqlalchemy.orm import Session

from backend.app.db.session import SessionLocal
from backend.app.repositories import (
    StationRepository,
    TrainRepository,
    TrainScheduleRepository,
    TrackRepository,
    UserRepository,
)


class SqlAlchemyUnitOfWork:
    """Coordinate transactional work across repositories.

    Intended for use as a context manager:

    * ``__enter__`` opens a session from ``session_factory`` and binds every
      repository attribute (``users``, ``stations``, ``tracks``, ``trains``,
      ``train_schedules``) to that one session.
    * ``__exit__`` rolls back if the ``with`` body raised; otherwise it
      commits. It always closes the session afterwards.

    ``commit()`` may also be called explicitly inside the ``with`` body.
    Any work done after such a call is still committed on a clean exit.
    """

    def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
        """Store the session factory. No session is opened until ``__enter__``.

        Args:
            session_factory: Zero-argument callable returning a new ``Session``.
        """
        self._session_factory = session_factory
        self.session: Optional[Session] = None
        # True after a successful commit() in the current scope. Reset by
        # rollback() and when the scope closes.
        self._committed = False
        # Declared here for type checkers; bound to the live session in __enter__.
        self.users: UserRepository
        self.stations: StationRepository
        self.tracks: TrackRepository
        self.trains: TrainRepository
        self.train_schedules: TrainScheduleRepository

    def __enter__(self) -> "SqlAlchemyUnitOfWork":
        """Open a session and bind all repositories to it.

        Raises:
            RuntimeError: If this unit of work is already active. Re-entering
                would replace the open session without ever closing it.
        """
        if self.session is not None:
            raise RuntimeError("Unit of work is already active")

        session = self._session_factory()
        try:
            self.users = UserRepository(session)
            self.stations = StationRepository(session)
            self.tracks = TrackRepository(session)
            self.trains = TrainRepository(session)
            self.train_schedules = TrainScheduleRepository(session)
        except BaseException:
            # Don't leak the freshly opened session if setup fails part-way.
            session.close()
            raise

        self.session = session
        self._committed = False
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        _tb: object,
    ) -> None:
        """Roll back on error, otherwise commit, then always close the session.

        Returns ``None``, so any exception raised in the ``with`` body
        propagates to the caller.
        """
        try:
            if exc_type is not None:
                self.rollback()
            else:
                # Commit unconditionally on a clean exit. Skipping the commit
                # just because commit() was called earlier would let close()
                # discard any work done after that call. NOTE(review):
                # committing a session with nothing pending is expected to emit
                # no SQL, but session commit events still fire.
                self.commit()
        finally:
            if self.session is not None:
                self.session.close()
            self.session = None
            self._committed = False

    def commit(self) -> None:
        """Commit the active session's transaction.

        Raises:
            RuntimeError: If called outside an active ``with`` block.
        """
        if self.session is None:
            raise RuntimeError("Unit of work is not active")
        self.session.commit()
        self._committed = True

    def rollback(self) -> None:
        """Roll back the active session's transaction.

        Does nothing when no session is active.
        """
        if self.session is None:
            return
        self.session.rollback()
        self._committed = False