from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, List, cast from uuid import uuid4 import pytest from backend.app.db.models import TrainSchedule, User from backend.app.db.unit_of_work import SqlAlchemyUnitOfWork from backend.app.models import ( StationCreate, TrackCreate, TrainCreate, TrainScheduleCreate, UserCreate, ) from backend.app.repositories import ( StationRepository, TrackRepository, TrainRepository, TrainScheduleRepository, UserRepository, ) from sqlalchemy.orm import Session @dataclass class DummySession: added: List[Any] = field(default_factory=list) scalars_result: List[Any] = field(default_factory=list) scalar_result: Any = None statements: List[Any] = field(default_factory=list) committed: bool = False rolled_back: bool = False closed: bool = False def add(self, instance: Any) -> None: self.added.append(instance) def add_all(self, instances: list[Any]) -> None: self.added.extend(instances) def scalars(self, statement: Any) -> list[Any]: self.statements.append(statement) return list(self.scalars_result) def scalar(self, statement: Any) -> Any: self.statements.append(statement) return self.scalar_result def flush(self, _objects: list[Any] | None = None) -> None: # pragma: no cover - optional return None def commit(self) -> None: # pragma: no cover - optional self.committed = True def rollback(self) -> None: # pragma: no cover - optional self.rolled_back = True def close(self) -> None: # pragma: no cover - optional self.closed = True def test_station_repository_create_generates_geometry() -> None: session = DummySession() repo = StationRepository(session) # type: ignore[arg-type] station = repo.create( StationCreate( name="Central", latitude=52.52, longitude=13.405, osm_id="123", code="BER", elevation_m=34.5, ) ) assert station.name == "Central" assert session.added and session.added[0] is station assert getattr(station.location, "srid", None) == 4326 assert "POINT" in str(station.location) def test_track_repository_requires_geometry() -> None: session = DummySession() repo = TrackRepository(session) # type: ignore[arg-type] with pytest.raises(ValueError): repo.create( TrackCreate( start_station_id="00000000-0000-0000-0000-000000000001", end_station_id="00000000-0000-0000-0000-000000000002", coordinates=[(52.0, 13.0)], ) ) def test_track_repository_create_builds_linestring() -> None: session = DummySession() repo = TrackRepository(session) # type: ignore[arg-type] track = repo.create( TrackCreate( name="Main Line", start_station_id="00000000-0000-0000-0000-000000000001", end_station_id="00000000-0000-0000-0000-000000000002", coordinates=[(52.0, 13.0), (53.0, 14.0)], length_meters=1000.5, max_speed_kph=160, is_bidirectional=False, status="operational", ) ) assert session.added and session.added[0] is track assert track.status == "operational" geom_repr = str(track.track_geometry) assert "LINESTRING" in geom_repr assert "13.0 52.0" in geom_repr def test_train_repository_create_supports_optional_ids() -> None: session = DummySession() repo = TrainRepository(session) # type: ignore[arg-type] train = repo.create( TrainCreate( designation="ICE 123", capacity=400, max_speed_kph=300, operator_id=None, home_station_id="00000000-0000-0000-0000-000000000001", consist="locomotive+cars", ) ) assert session.added and session.added[0] is train assert train.designation == "ICE 123" assert str(train.home_station_id).endswith("1") assert train.operator_id is None def test_user_repository_create_persists_user() -> None: session = DummySession() repo = UserRepository(session) # type: ignore[arg-type] user = repo.create( UserCreate( username="demo", password_hash="hashed", email="demo@example.com", full_name="Demo Engineer", role="admin", ) ) assert session.added and session.added[0] is user assert user.username == "demo" assert user.role == "admin" def test_user_repository_get_by_username_is_case_insensitive() -> None: existing = User(username="Demo", password_hash="hashed", role="player") session = DummySession(scalar_result=existing) repo = UserRepository(session) # type: ignore[arg-type] result = repo.get_by_username("demo") assert result is existing assert session.statements def test_train_schedule_repository_create_converts_identifiers() -> None: session = DummySession() repo = TrainScheduleRepository(session) # type: ignore[arg-type] train_id = uuid4() station_id = uuid4() schedule = repo.create( TrainScheduleCreate( train_id=str(train_id), station_id=str(station_id), sequence_index=1, scheduled_arrival=datetime.now(timezone.utc), dwell_seconds=90, ) ) assert session.added and session.added[0] is schedule assert schedule.train_id == train_id assert schedule.station_id == station_id def test_train_schedule_repository_list_for_train_orders_results() -> None: train_id = uuid4() schedules = [ TrainSchedule(train_id=train_id, station_id=uuid4(), sequence_index=2), TrainSchedule(train_id=train_id, station_id=uuid4(), sequence_index=1), ] session = DummySession(scalars_result=schedules) repo = TrainScheduleRepository(session) # type: ignore[arg-type] result = repo.list_for_train(train_id) assert result == schedules statement = session.statements[-1] assert getattr(statement, "_order_by_clauses", ()) def test_unit_of_work_commits_and_closes_session() -> None: session = DummySession() uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session)) with uow as active: active.users.create( UserCreate(username="demo", password_hash="hashed") ) active.commit() assert session.committed assert session.closed def test_unit_of_work_rolls_back_on_exception() -> None: session = DummySession() uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session)) with pytest.raises(RuntimeError): with uow: raise RuntimeError("boom") assert session.rolled_back assert session.closed