# Source: rail-game/backend/tests/test_repositories.py
# File-viewer metadata: 237 lines, 6.8 KiB, Python

from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, List, cast
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from backend.app.db.models import TrainSchedule, User
from backend.app.db.unit_of_work import SqlAlchemyUnitOfWork
from backend.app.models import (
StationCreate,
TrackCreate,
TrainCreate,
TrainScheduleCreate,
UserCreate,
)
from backend.app.repositories import (
StationRepository,
TrackRepository,
TrainRepository,
TrainScheduleRepository,
UserRepository,
)
@dataclass
class DummySession:
added: List[Any] = field(default_factory=list)
scalars_result: List[Any] = field(default_factory=list)
scalar_result: Any = None
statements: List[Any] = field(default_factory=list)
committed: bool = False
rolled_back: bool = False
closed: bool = False
def add(self, instance: Any) -> None:
self.added.append(instance)
def add_all(self, instances: list[Any]) -> None:
self.added.extend(instances)
def scalars(self, statement: Any) -> list[Any]:
self.statements.append(statement)
return list(self.scalars_result)
def scalar(self, statement: Any) -> Any:
self.statements.append(statement)
return self.scalar_result
def flush(
self, _objects: list[Any] | None = None
) -> None: # pragma: no cover - optional
return None
def commit(self) -> None: # pragma: no cover - optional
self.committed = True
def rollback(self) -> None: # pragma: no cover - optional
self.rolled_back = True
def close(self) -> None: # pragma: no cover - optional
self.closed = True
def test_station_repository_create_generates_geometry() -> None:
    """Creating a station adds it to the session and gives it a point geometry."""
    session = DummySession()
    repo = StationRepository(session)  # type: ignore[arg-type]

    station = repo.create(
        StationCreate(
            name="Central",
            latitude=52.52,
            longitude=13.405,
            osm_id="123",
            code="BER",
            elevation_m=34.5,
        )
    )

    assert station.name == "Central"
    # The created object must be the one handed to the session.
    assert session.added and session.added[0] is station
    # The location is expected to carry SRID 4326 and render as a POINT.
    assert getattr(station.location, "srid", None) == 4326
    assert "POINT" in str(station.location)
def test_track_repository_requires_geometry() -> None:
    """A track created from a single coordinate pair is rejected."""
    session = DummySession()
    repo = TrackRepository(session)  # type: ignore[arg-type]

    # Only one coordinate is supplied; ``create`` is expected to raise.
    with pytest.raises(ValueError):
        repo.create(
            TrackCreate(
                start_station_id="00000000-0000-0000-0000-000000000001",
                end_station_id="00000000-0000-0000-0000-000000000002",
                coordinates=[(52.0, 13.0)],
            )
        )
def test_track_repository_create_builds_linestring() -> None:
    """Creating a track with two coordinates yields a LINESTRING geometry."""
    session = DummySession()
    repo = TrackRepository(session)  # type: ignore[arg-type]

    track = repo.create(
        TrackCreate(
            name="Main Line",
            start_station_id="00000000-0000-0000-0000-000000000001",
            end_station_id="00000000-0000-0000-0000-000000000002",
            coordinates=[(52.0, 13.0), (53.0, 14.0)],
            length_meters=1000.5,
            max_speed_kph=160,
            is_bidirectional=False,
            status="operational",
        )
    )

    assert session.added and session.added[0] is track
    assert track.status == "operational"

    geom_repr = str(track.track_geometry)
    assert "LINESTRING" in geom_repr
    # The input pair (52.0, 13.0) must appear with its values swapped.
    # NOTE(review): presumably (lat, lon) in, "lon lat" out -- confirm.
    assert "13.0 52.0" in geom_repr
def test_train_repository_create_supports_optional_ids() -> None:
    """A train can be created with ``operator_id`` left as ``None``."""
    session = DummySession()
    repo = TrainRepository(session)  # type: ignore[arg-type]

    train = repo.create(
        TrainCreate(
            designation="ICE 123",
            capacity=400,
            max_speed_kph=300,
            operator_id=None,
            home_station_id="00000000-0000-0000-0000-000000000001",
            consist="locomotive+cars",
        )
    )

    assert session.added and session.added[0] is train
    assert train.designation == "ICE 123"
    # Compare via ``str`` so the check does not depend on the stored id type.
    assert str(train.home_station_id).endswith("1")
    assert train.operator_id is None
def test_user_repository_create_persists_user() -> None:
    """Creating a user adds it to the session with the supplied fields."""
    session = DummySession()
    repo = UserRepository(session)  # type: ignore[arg-type]

    user = repo.create(
        UserCreate(
            username="demo",
            password_hash="hashed",
            email="demo@example.com",
            full_name="Demo Engineer",
            role="admin",
        )
    )

    assert session.added and session.added[0] is user
    assert user.username == "demo"
    assert user.role == "admin"
def test_user_repository_get_by_username_is_case_insensitive() -> None:
    """Looking up ``"demo"`` returns the user the session yields for the query."""
    existing = User(username="Demo", password_hash="hashed", role="player")
    session = DummySession(scalar_result=existing)
    repo = UserRepository(session)  # type: ignore[arg-type]

    result = repo.get_by_username("demo")

    assert result is existing
    # NOTE(review): the fake session returns ``existing`` for any statement,
    # so this only proves a query was issued. It does not verify that the
    # comparison itself is case-insensitive.
    assert session.statements
def test_train_schedule_repository_create_converts_identifiers() -> None:
    """String ids passed to ``create`` come back equal to the original UUIDs."""
    session = DummySession()
    repo = TrainScheduleRepository(session)  # type: ignore[arg-type]
    train_id = uuid4()
    station_id = uuid4()

    # Ids are supplied as strings; the stored values must equal the UUIDs.
    schedule = repo.create(
        TrainScheduleCreate(
            train_id=str(train_id),
            station_id=str(station_id),
            sequence_index=1,
            scheduled_arrival=datetime.now(timezone.utc),
            dwell_seconds=90,
        )
    )

    assert session.added and session.added[0] is schedule
    assert schedule.train_id == train_id
    assert schedule.station_id == station_id
def test_train_schedule_repository_list_for_train_orders_results() -> None:
    """``list_for_train`` returns the session's rows and issues an ordered query."""
    train_id = uuid4()
    schedules = [
        TrainSchedule(train_id=train_id, station_id=uuid4(), sequence_index=2),
        TrainSchedule(train_id=train_id, station_id=uuid4(), sequence_index=1),
    ]
    session = DummySession(scalars_result=schedules)
    repo = TrainScheduleRepository(session)  # type: ignore[arg-type]

    result = repo.list_for_train(train_id)

    # The fake session does no sorting, so rows come back as seeded.
    assert result == schedules
    # Ordering is checked on the statement itself.
    # NOTE(review): ``_order_by_clauses`` is a private attribute of the
    # statement object and may change between library versions.
    statement = session.statements[-1]
    assert getattr(statement, "_order_by_clauses", ())
def test_unit_of_work_commits_and_closes_session() -> None:
    """An explicit ``commit`` inside the unit of work commits, and exit closes."""
    session = DummySession()
    uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session))

    with uow as active:
        active.users.create(UserCreate(username="demo", password_hash="hashed"))
        active.commit()

    # Checked after the block so that closing on exit is observable.
    assert session.committed
    assert session.closed
def test_unit_of_work_rolls_back_on_exception() -> None:
    """An exception inside the unit of work rolls back and closes the session."""
    session = DummySession()
    uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session))

    with pytest.raises(RuntimeError):
        with uow:
            raise RuntimeError("boom")

    # These checks sit outside ``pytest.raises``; inside it they would be
    # unreachable after the ``raise``.
    assert session.rolled_back
    assert session.closed
    # Nothing may have been committed on the failure path.
    assert not session.committed