"""Unit tests for the station, track and train repositories.

Each repository is driven against ``DummySession``, an in-memory stand-in that
only records the objects handed to it, so no database is required.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, List

import pytest

from backend.app.models import StationCreate, TrackCreate, TrainCreate
from backend.app.repositories import StationRepository, TrackRepository, TrainRepository

# Fixed station identifiers shared by the track and train payloads below.
_START_STATION_ID = "00000000-0000-0000-0000-000000000001"
_END_STATION_ID = "00000000-0000-0000-0000-000000000002"


@dataclass
class DummySession:
    """Session double that records added objects instead of persisting them.

    Provides ``add``/``add_all`` (which record into ``added``) plus inert
    ``scalars``/``flush`` stubs.
    """

    # Every instance passed to ``add``/``add_all``, in call order.
    added: List[Any] = field(default_factory=list)

    def add(self, instance: Any) -> None:
        """Record a single instance."""
        self.added.append(instance)

    def add_all(self, instances: list[Any]) -> None:
        """Record several instances, preserving their order."""
        self.added.extend(instances)

    def scalars(self, _statement: Any) -> list[Any]:  # pragma: no cover - not used here
        """Ignore the statement and return no rows."""
        return []

    def flush(self, _objects: list[Any] | None = None) -> None:  # pragma: no cover - optional
        """Do nothing; there is no backing store to flush to."""
        return None


def test_station_repository_create_generates_geometry() -> None:
    """Creating a station adds it to the session with an SRID 4326 point geometry."""
    session = DummySession()
    repo = StationRepository(session)  # type: ignore[arg-type]

    station = repo.create(
        StationCreate(
            name="Central",
            latitude=52.52,
            longitude=13.405,
            osm_id="123",
            code="BER",
            elevation_m=34.5,
        )
    )

    assert station.name == "Central"
    assert session.added and session.added[0] is station
    assert getattr(station.location, "srid", None) == 4326
    assert "POINT" in str(station.location)


def test_track_repository_requires_geometry() -> None:
    """A track defined by a single coordinate is rejected with ValueError."""
    session = DummySession()
    repo = TrackRepository(session)  # type: ignore[arg-type]

    # NOTE(review): the payload is built inside ``pytest.raises``, so this test
    # also passes if ``TrackCreate`` itself raises a ValueError (e.g. a pydantic
    # ValidationError, which subclasses ValueError) rather than the repository.
    # Confirm which layer is meant to enforce the minimum-coordinate rule.
    with pytest.raises(ValueError):
        repo.create(
            TrackCreate(
                start_station_id=_START_STATION_ID,
                end_station_id=_END_STATION_ID,
                coordinates=[(52.0, 13.0)],
            )
        )


def test_track_repository_create_builds_linestring() -> None:
    """Creating a track adds it to the session with a LINESTRING geometry."""
    session = DummySession()
    repo = TrackRepository(session)  # type: ignore[arg-type]

    track = repo.create(
        TrackCreate(
            name="Main Line",
            start_station_id=_START_STATION_ID,
            end_station_id=_END_STATION_ID,
            coordinates=[(52.0, 13.0), (53.0, 14.0)],
            length_meters=1000.5,
            max_speed_kph=160,
            is_bidirectional=False,
            status="operational",
        )
    )

    assert session.added and session.added[0] is track
    assert track.status == "operational"
    geom_repr = str(track.track_geometry)
    assert "LINESTRING" in geom_repr
    # Each input pair is expected to appear swapped in the geometry text;
    # check every vertex, not just the first, so a partial swap is caught.
    assert "13.0 52.0" in geom_repr
    assert "14.0 53.0" in geom_repr


def test_train_repository_create_supports_optional_ids() -> None:
    """A train can be created with ``operator_id=None`` and a home station set."""
    session = DummySession()
    repo = TrainRepository(session)  # type: ignore[arg-type]

    train = repo.create(
        TrainCreate(
            designation="ICE 123",
            capacity=400,
            max_speed_kph=300,
            operator_id=None,
            home_station_id=_START_STATION_ID,
            consist="locomotive+cars",
        )
    )

    assert session.added and session.added[0] is train
    assert train.designation == "ICE 123"
    # Compare the full identifier: str() gives the same hyphenated text for both
    # a plain string and a uuid.UUID, whereas a suffix check accepts almost any ID.
    assert str(train.home_station_id) == _START_STATION_ID
    assert train.operator_id is None