63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from typing import Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.app.db.session import SessionLocal
|
|
from backend.app.repositories import (
|
|
StationRepository,
|
|
TrackRepository,
|
|
TrainRepository,
|
|
TrainScheduleRepository,
|
|
UserRepository,
|
|
)
|
|
|
|
|
|
class SqlAlchemyUnitOfWork:
    """Coordinate transactional work across repositories.

    Intended to be used as a context manager::

        with SqlAlchemyUnitOfWork() as uow:
            uow.users...
            uow.commit()  # optional; a clean exit commits automatically

    On entry a session is obtained from ``session_factory`` and one
    repository per aggregate is bound to it. On exit the work is rolled
    back if the ``with`` body raised, committed if it finished cleanly
    without an explicit ``commit()``, and the session is always closed.
    """

    def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
        """Store the session factory; no session is opened until ``__enter__``.

        Args:
            session_factory: Zero-argument callable returning a new session.
        """
        self._session_factory = session_factory
        self.session: Optional[Session] = None
        # True once commit() has succeeded inside the current ``with`` block.
        self._committed = False
        # Declared here for type checkers only; the repositories are
        # actually assigned in __enter__, once a session exists.
        self.users: UserRepository
        self.stations: StationRepository
        self.tracks: TrackRepository
        self.trains: TrainRepository
        self.train_schedules: TrainScheduleRepository

    def __enter__(self) -> "SqlAlchemyUnitOfWork":
        """Open a session, bind every repository to it and return ``self``."""
        session = self._session_factory()
        try:
            self.users = UserRepository(session)
            self.stations = StationRepository(session)
            self.tracks = TrackRepository(session)
            self.trains = TrainRepository(session)
            self.train_schedules = TrainScheduleRepository(session)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so the freshly
            # created session has to be released here or it would leak.
            session.close()
            raise
        self.session = session
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        _tb: object,
    ) -> None:
        """Finish the unit of work and release the session.

        Rolls back when the ``with`` body raised; otherwise commits unless
        ``commit()`` already succeeded. The session is closed and the
        instance reset in every case. Returns ``None``, so an exception
        from the ``with`` body is never suppressed.
        """
        try:
            # Test exc_type rather than the truthiness of ``exc``: an
            # exception class may define __bool__/__len__ and be falsy.
            if exc_type is not None:
                self.rollback()
            elif not self._committed:
                # NOTE: once commit() has been called explicitly, later
                # changes made in the same block are not auto-committed.
                self.commit()
        finally:
            if self.session is not None:
                self.session.close()
            self.session = None
            self._committed = False

    def commit(self) -> None:
        """Commit the active session.

        If the underlying commit fails, the session is rolled back before
        the error is re-raised, so it is not left in a failed state.

        Raises:
            RuntimeError: If called outside an active ``with`` block.
        """
        if self.session is None:
            raise RuntimeError("Unit of work is not active")
        try:
            self.session.commit()
        except BaseException:
            self.session.rollback()
            raise
        self._committed = True

    def rollback(self) -> None:
        """Roll back the active session; a no-op when no session is open."""
        if self.session is None:
            return
        self.session.rollback()
        self._committed = False
|