"""Tests for ``backend.scripts.tracks_load`` parsing and loading of tracks."""

from __future__ import annotations

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List

import pytest
from geoalchemy2.shape import from_shape
from shapely.geometry import Point

from backend.scripts import tracks_load


def test_parse_track_entries_returns_models() -> None:
    """A well-formed entry is parsed into a model with snake_case attributes."""
    entries = [
        {
            "name": "Connector",
            "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            "lengthMeters": 1500,
            "maxSpeedKph": 120,
            "status": "operational",
            "isBidirectional": True,
        }
    ]

    parsed = tracks_load._parse_track_entries(entries)

    assert len(parsed) == 1
    model = parsed[0]
    assert model.name == "Connector"
    # Coordinate pairs come back as tuples, not the input lists.
    assert model.coordinates[0] == (52.5, 13.4)
    assert model.length_meters == 1500
    assert model.max_speed_kph == 120


def test_parse_track_entries_invalid_raises_value_error() -> None:
    """An entry with no name and a single coordinate is rejected."""
    entries = [
        {
            "coordinates": [[52.5, 13.4]],
        }
    ]

    with pytest.raises(ValueError):
        tracks_load._parse_track_entries(entries)


@dataclass
class DummySession:
    """Fake for the object returned by ``tracks_load.SessionLocal``.

    It works as a context manager and records whether ``commit`` or
    ``rollback`` was called.
    """

    committed: bool = False  # set to True by commit()
    rolled_back: bool = False  # set to True by rollback()

    def __enter__(self) -> "DummySession":
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        # Returning None (falsy) lets any exception from the with-body propagate.
        pass

    def commit(self) -> None:
        self.committed = True

    def rollback(self) -> None:
        self.rolled_back = True


@dataclass
class DummyStation:
    """Station record exposing only the attributes the loader reads."""

    id: str
    location: object  # geoalchemy2 element built by _point()


@dataclass
class DummyStationRepository:
    """Fake station repository whose ``list_active`` returns a fixed list."""

    session: DummySession
    stations: List[DummyStation]

    def list_active(self) -> List[DummyStation]:
        return self.stations


@dataclass
class DummyTrackRepository:
    """Fake track repository.

    ``list_all`` returns ``existing``, and ``create`` records each payload
    in ``created``.
    """

    session: DummySession
    created: List[object] = field(default_factory=list)
    existing: List[object] = field(default_factory=list)

    def list_all(self) -> List[object]:
        return self.existing

    def create(self, data: object) -> None:  # pragma: no cover - simple delegation
        self.created.append(data)


def _point(lat: float, lon: float) -> object:
    """Build an SRID 4326 geometry element for the given latitude/longitude."""
    # shapely's Point takes (x, y), i.e. (longitude, latitude).
    return from_shape(Point(lon, lat), srid=4326)


def _connector_stations(session: DummySession) -> DummyStationRepository:
    """Return a station repository with stations at both Connector endpoints."""
    return DummyStationRepository(
        session,
        stations=[
            DummyStation(id="station-a", location=_point(52.5, 13.4)),
            DummyStation(id="station-b", location=_point(52.6, 13.5)),
        ],
    )


def _install_fakes(
    monkeypatch: pytest.MonkeyPatch,
    session_instance: DummySession,
    station_repo_instance: DummyStationRepository,
    track_repo_instance: DummyTrackRepository,
) -> None:
    """Point tracks_load's session factory and repositories at the given fakes."""
    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
    # The lambda parameter stays named ``session`` in case tracks_load passes
    # the session to the repository constructors by keyword.
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo_instance
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo_instance
    )


def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A track between two known stations is created and the session committed."""
    session_instance = DummySession()
    station_repo_instance = _connector_stations(session_instance)
    track_repo_instance = DummyTrackRepository(session_instance)
    _install_fakes(
        monkeypatch, session_instance, station_repo_instance, track_repo_instance
    )

    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=True)

    assert created == 1
    assert session_instance.committed is True
    assert len(track_repo_instance.created) == 1
    track = track_repo_instance.created[0]
    assert track.start_station_id == "station-a"
    assert track.end_station_id == "station-b"
    assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)]


def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    """A station pair that already has a track is skipped.

    With ``commit=False`` the session is rolled back.
    """
    session_instance = DummySession()
    station_repo_instance = _connector_stations(session_instance)
    # A plain attribute record standing in for an already persisted track.
    existing_track = SimpleNamespace(
        start_station_id="station-a",
        end_station_id="station-b",
    )
    track_repo_instance = DummyTrackRepository(
        session_instance,
        existing=[existing_track],
    )
    _install_fakes(
        monkeypatch, session_instance, station_repo_instance, track_repo_instance
    )

    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=False)

    assert created == 0
    assert session_instance.rolled_back is True
    assert not track_repo_instance.created


def test_load_tracks_skips_when_station_too_far(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A track whose endpoints lie far from every station is not created.

    The session is still committed because ``commit=True``.
    """
    session_instance = DummySession()
    station_repo_instance = DummyStationRepository(
        session_instance,
        stations=[
            DummyStation(id="remote-station", location=_point(53.5, 14.5)),
        ],
    )
    track_repo_instance = DummyTrackRepository(session_instance)
    _install_fakes(
        monkeypatch, session_instance, station_repo_instance, track_repo_instance
    )

    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Isolated Segment",
                "coordinates": [[52.5, 13.4], [52.51, 13.41]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=True)

    assert created == 0
    assert session_instance.committed is True
    assert not track_repo_instance.created