- Updated documentation to include the OSM Track Harvesting Policy, with details on railway types, service filters, usage filters, and geometry guardrails.
- Introduced a new script, `init_demo_db.py`, that automates database setup: environment checks, running migrations, and loading OSM fixtures for demo data.
201 lines
5.7 KiB
Python
201 lines
5.7 KiB
Python
from __future__ import annotations

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List

import pytest
from geoalchemy2.shape import from_shape
from shapely.geometry import Point

from backend.scripts import tracks_load
|
|
|
|
|
|
def test_parse_track_entries_returns_models() -> None:
    """A well-formed camelCase entry is parsed into one snake_case model."""
    entries = [
        {
            "name": "Connector",
            "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            "lengthMeters": 1500,
            "maxSpeedKph": 120,
            "status": "operational",
            "isBidirectional": True,
        }
    ]

    parsed = tracks_load._parse_track_entries(entries)

    # One model per input entry; indexing parsed[0] alone would hide extras.
    assert len(parsed) == 1
    track = parsed[0]
    assert track.name == "Connector"
    # Check every coordinate pair is converted to a tuple, not only the first.
    assert list(track.coordinates) == [(52.5, 13.4), (52.6, 13.5)]
    assert track.length_meters == 1500
    assert track.max_speed_kph == 120
|
|
|
|
|
|
def test_parse_track_entries_invalid_raises_value_error() -> None:
    """A malformed track entry makes the parser raise ``ValueError``."""
    # NOTE(review): this entry is invalid in two ways (no "name" and only a
    # single coordinate); the test does not pin down which one is rejected.
    malformed = [{"coordinates": [[52.5, 13.4]]}]

    with pytest.raises(ValueError):
        tracks_load._parse_track_entries(malformed)
|
|
|
|
|
|
@dataclass
class DummySession:
    """Stand-in for the session returned by ``tracks_load.SessionLocal``.

    It only records whether ``commit`` or ``rollback`` was called, so tests can
    assert on the transaction outcome. It also works as a context manager.
    """

    # Set to True by commit().
    committed: bool = False
    # Set to True by rollback().
    rolled_back: bool = False

    def __enter__(self) -> "DummySession":
        # Entering the context yields the session itself.
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        # Nothing to release; returning None lets any exception propagate.
        return None

    def rollback(self) -> None:
        """Record that the transaction was rolled back."""
        self.rolled_back = True

    def commit(self) -> None:
        """Record that the transaction was committed."""
        self.committed = True
|
|
|
|
|
|
@dataclass
class DummyStation:
    """Minimal station record with an identifier and a location geometry.

    NOTE(review): presumably these are the only attributes ``load_tracks``
    reads from a station; confirm against the real station model.
    """

    # Identifier expected to appear as a created track's start/end station id.
    id: str
    # Geometry value; the tests in this module build it with ``_point``.
    location: object
|
|
|
|
|
|
@dataclass
class DummyStationRepository:
    """Fake station repository that serves a fixed list of stations."""

    # Session the fake was constructed with; the fake itself never uses it.
    session: DummySession
    # Stations handed back unchanged by list_active().
    stations: List[DummyStation]

    def list_active(self) -> List[DummyStation]:
        """Return the configured stations (the same list object, not a copy)."""
        active = self.stations
        return active
|
|
|
|
|
|
@dataclass
class DummyTrackRepository:
    """Fake track repository that records what it is asked to create.

    ``existing`` seeds the value returned by ``list_all``, for example to
    simulate station pairs that already have a track. ``created`` collects
    every payload passed to ``create``, in call order.
    """

    # Session the fake was constructed with; the fake itself never uses it.
    session: DummySession
    # Payloads received by create(). default_factory gives each instance its own list.
    created: List[object] = field(default_factory=list)
    # Tracks reported as already persisted by list_all().
    existing: List[object] = field(default_factory=list)

    def list_all(self) -> List[object]:
        """Return the pre-seeded existing tracks (same list object)."""
        return self.existing

    def create(self, data: object) -> None:
        """Record *data* as a newly created track."""
        # Intentionally no ``pragma: no cover``: this method is exercised by
        # test_load_tracks_creates_entries and should count toward coverage.
        self.created.append(data)
|
|
|
|
|
|
def _point(lat: float, lon: float) -> object:
    """Build a GeoAlchemy2 geometry (SRID 4326) from a latitude/longitude pair.

    Shapely orders coordinates as (x, y), that is (longitude, latitude), so the
    arguments are swapped when the point is constructed.
    """
    geometry = Point(lon, lat)
    return from_shape(geometry, srid=4326)
|
|
|
|
|
|
def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A track whose endpoints coincide with two stations is created and linked.

    The session factory and repositories in ``tracks_load`` are replaced with
    the in-memory fakes from this module. The test then checks that the
    created track references the stations at its first and last coordinates.
    """
    session_instance = DummySession()
    station_repo_instance = DummyStationRepository(
        session_instance,
        stations=[
            DummyStation(id="station-a", location=_point(52.5, 13.4)),
            DummyStation(id="station-b", location=_point(52.6, 13.5)),
        ],
    )
    track_repo_instance = DummyTrackRepository(session_instance)

    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo_instance
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo_instance
    )

    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=True)

    assert created == 1
    assert session_instance.committed is True
    # The reported count must match what reached the repository; a bare
    # truthiness check would also pass on duplicate inserts.
    assert len(track_repo_instance.created) == 1
    track = track_repo_instance.created[0]
    assert track.start_station_id == "station-a"
    assert track.end_station_id == "station-b"
    assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)]
|
|
|
|
|
|
def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    """A station pair that already has a track is not created again.

    This also exercises ``commit=False``: the session must be rolled back and
    never committed.
    """
    session_instance = DummySession()
    station_repo_instance = DummyStationRepository(
        session_instance,
        stations=[
            DummyStation(id="station-a", location=_point(52.5, 13.4)),
            DummyStation(id="station-b", location=_point(52.6, 13.5)),
        ],
    )
    # A real object instance (not a bare class object) standing in for an
    # already-persisted track between the two stations.
    existing_track = SimpleNamespace(
        start_station_id="station-a",
        end_station_id="station-b",
    )
    track_repo_instance = DummyTrackRepository(
        session_instance,
        existing=[existing_track],
    )

    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo_instance
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo_instance
    )

    parsed = tracks_load._parse_track_entries(
        [
            {
                "name": "Connector",
                "coordinates": [[52.5, 13.4], [52.6, 13.5]],
            }
        ]
    )

    created = tracks_load.load_tracks(parsed, commit=False)

    assert created == 0
    assert session_instance.rolled_back is True
    # commit=False must only roll back, never commit.
    assert session_instance.committed is False
    assert not track_repo_instance.created
|
|
|
|
|
|
def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch) -> None:
    """No track is created when the only station is far from the segment.

    The segment lies around (52.5, 13.4), while the sole station sits at
    (53.5, 14.5), roughly one degree away on each axis.
    """
    fake_session = DummySession()
    remote = DummyStation(id="remote-station", location=_point(53.5, 14.5))
    station_repo = DummyStationRepository(fake_session, stations=[remote])
    track_repo = DummyTrackRepository(fake_session)

    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: fake_session)
    monkeypatch.setattr(tracks_load, "StationRepository", lambda session: station_repo)
    monkeypatch.setattr(tracks_load, "TrackRepository", lambda session: track_repo)

    isolated = tracks_load._parse_track_entries(
        [{"name": "Isolated Segment", "coordinates": [[52.5, 13.4], [52.51, 13.41]]}]
    )

    assert tracks_load.load_tracks(isolated, commit=True) == 0
    # With commit=True the (empty) transaction is still committed.
    assert fake_session.committed is True
    assert not track_repo.created
|