Files
rail-game/backend/tests/test_tracks_load.py
zwitschi 25ca7ab196 Add OSM Track Harvesting Policy and demo database initialization script
- Updated documentation to include OSM Track Harvesting Policy with details on railway types, service filters, usage filters, and geometry guardrails.
- Introduced a new script `init_demo_db.py` to automate the database setup process, including environment checks, running migrations, and loading OSM fixtures for demo data.
2025-10-11 21:37:25 +02:00

201 lines
5.7 KiB
Python

from __future__ import annotations

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List

import pytest
from geoalchemy2.shape import from_shape
from shapely.geometry import Point

from backend.scripts import tracks_load
def test_parse_track_entries_returns_models() -> None:
    """Parse one well-formed entry and check its fields land on the model.

    The camelCase keys of the raw fixture (``lengthMeters``, ``maxSpeedKph``)
    are read back as snake_case attributes on the parsed object.
    """
    entries = [
        {
            "name": "Connector",
            "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            "lengthMeters": 1500,
            "maxSpeedKph": 120,
            "status": "operational",
            "isBidirectional": True,
        }
    ]

    parsed = tracks_load._parse_track_entries(entries)

    # One input entry must yield exactly one model; indexing [0] alone would
    # not notice duplicated or spurious results.
    assert len(parsed) == 1
    track = parsed[0]
    assert track.name == "Connector"
    # The nested [lat, lon] lists come back as tuples; check the whole path,
    # not just the first vertex.
    assert list(track.coordinates) == [(52.5, 13.4), (52.6, 13.5)]
    assert track.length_meters == 1500
    assert track.max_speed_kph == 120
def test_parse_track_entries_invalid_raises_value_error() -> None:
    """Reject a malformed raw entry with ``ValueError``."""
    # This entry has no name and a single vertex. Either defect may be what
    # trips validation.
    # NOTE(review): consider splitting into two cases and pinning each with
    # pytest.raises(..., match=...) once the messages are settled.
    malformed = {"coordinates": [[52.5, 13.4]]}

    with pytest.raises(ValueError):
        tracks_load._parse_track_entries([malformed])
@dataclass
class DummySession:
    """In-memory stand-in for a DB session that records commit/rollback calls."""

    committed: bool = False  # flipped to True by commit()
    rolled_back: bool = False  # flipped to True by rollback()

    def __enter__(self) -> "DummySession":
        # Presumably entered via `with SessionLocal() as session:` in the
        # loader; hand back this same object so tests can inspect it.
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        # Nothing to release. Returning None means exceptions raised inside
        # the `with` body propagate unchanged.
        return None

    def commit(self) -> None:
        self.committed = True

    def rollback(self) -> None:
        self.rolled_back = True
@dataclass
class DummyStation:
    """Minimal station record for the loader tests.

    Only ``id`` and ``location`` are modelled; presumably those are the only
    attributes the loader reads from a station.
    """

    # Identifier the tests expect to see as a track's start/end station id.
    id: str
    # Geometry element; these tests build it with _point(lat, lon).
    location: object
@dataclass
class DummyStationRepository:
    """Fake station repository that serves a fixed list of stations."""

    session: DummySession  # held only for parity with the real constructor
    stations: List[DummyStation]

    def list_active(self) -> List[DummyStation]:
        # Every configured station counts as active. The list object itself
        # is handed back, not a copy.
        active = self.stations
        return active
@dataclass
class DummyTrackRepository:
    """Fake track repository that records created tracks instead of persisting them.

    ``existing`` seeds what ``list_all`` reports, so a test can simulate
    tracks that are already stored.
    """

    session: DummySession  # held only for parity with the real constructor
    # Payloads passed to create(), in call order.
    created: List[object] = field(default_factory=list)
    # Tracks reported as already stored; returned as-is by list_all().
    existing: List[object] = field(default_factory=list)

    def list_all(self) -> List[object]:
        return self.existing

    def create(self, data: object) -> None:
        # Record the payload so tests can assert on what the loader created.
        # This path is exercised, so it is not excluded from coverage.
        self.created.append(data)
def _point(lat: float, lon: float) -> object:
    """Build a geoalchemy2 point element in SRID 4326 from latitude/longitude.

    Shapely expects (x, y) = (lon, lat), hence the swapped argument order.
    """
    geometry = Point(lon, lat)
    return from_shape(geometry, srid=4326)
def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A track whose endpoints sit on two known stations is created and committed."""
    session_instance = DummySession()
    station_repo_instance = DummyStationRepository(
        session_instance,
        stations=[
            DummyStation(id="station-a", location=_point(52.5, 13.4)),
            DummyStation(id="station-b", location=_point(52.6, 13.5)),
        ],
    )
    track_repo_instance = DummyTrackRepository(session_instance)
    # Route the loader's session and repository factories to the in-memory fakes.
    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo_instance
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo_instance
    )

    # The segment's endpoints coincide exactly with station-a and station-b.
    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=True)

    assert created == 1
    assert session_instance.committed is True
    # The reported count must match what actually reached the repository.
    assert len(track_repo_instance.created) == 1
    track = track_repo_instance.created[0]
    assert track.start_station_id == "station-a"
    assert track.end_station_id == "station-b"
    assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)]
def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    """A station pair that already has a track is skipped; commit=False rolls back."""
    session_instance = DummySession()
    station_repo_instance = DummyStationRepository(
        session_instance,
        stations=[
            DummyStation(id="station-a", location=_point(52.5, 13.4)),
            DummyStation(id="station-b", location=_point(52.6, 13.5)),
        ],
    )
    # A plain attribute record stands in for an already-stored track row.
    # SimpleNamespace gives a real instance rather than a throwaway class object.
    existing_track = SimpleNamespace(
        start_station_id="station-a",
        end_station_id="station-b",
    )
    track_repo_instance = DummyTrackRepository(
        session_instance,
        existing=[existing_track],
    )
    # Route the loader's session and repository factories to the in-memory fakes.
    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo_instance
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo_instance
    )

    # Same endpoints (station-a -> station-b) as the existing track.
    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=False)

    assert created == 0
    # commit=False must roll back and never commit.
    assert session_instance.rolled_back is True
    assert session_instance.committed is False
    assert not track_repo_instance.created
def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch) -> None:
    """A segment with no station near its endpoints is skipped, yet the run still commits."""
    fake_session = DummySession()
    # The only station is about one degree away in both latitude and longitude.
    fake_stations = DummyStationRepository(
        fake_session,
        stations=[DummyStation(id="remote-station", location=_point(53.5, 14.5))],
    )
    fake_tracks = DummyTrackRepository(fake_session)

    # Swap the loader's session and repository factories for the fakes.
    replacements = {
        "SessionLocal": lambda: fake_session,
        "StationRepository": lambda session: fake_stations,
        "TrackRepository": lambda session: fake_tracks,
    }
    for attribute, factory in replacements.items():
        monkeypatch.setattr(tracks_load, attribute, factory)

    isolated = tracks_load._parse_track_entries(
        [{"name": "Isolated Segment", "coordinates": [[52.5, 13.4], [52.51, 13.41]]}]
    )

    assert tracks_load.load_tracks(isolated, commit=True) == 0
    assert fake_session.committed is True
    assert not fake_tracks.created