This commit is contained in:
@@ -148,7 +148,11 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]
|
|||||||
if geometry is not None and LineString is not None
|
if geometry is not None and LineString is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
if LineString is not None and shape is not None and isinstance(shape, LineString):
|
if (
|
||||||
|
LineString is not None
|
||||||
|
and shape is not None
|
||||||
|
and isinstance(shape, LineString)
|
||||||
|
):
|
||||||
coords_list: list[tuple[float, float]] = []
|
coords_list: list[tuple[float, float]] = []
|
||||||
for coord in shape.coords:
|
for coord in shape.coords:
|
||||||
lon = float(coord[0])
|
lon = float(coord[0])
|
||||||
|
|||||||
@@ -98,14 +98,12 @@ def normalize_station_elements(
|
|||||||
if not name:
|
if not name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raw_code = tags.get("ref") or tags.get(
|
raw_code = tags.get("ref") or tags.get("railway:ref") or tags.get("local_ref")
|
||||||
"railway:ref") or tags.get("local_ref")
|
|
||||||
code = str(raw_code) if raw_code is not None else None
|
code = str(raw_code) if raw_code is not None else None
|
||||||
|
|
||||||
elevation_tag = tags.get("ele") or tags.get("elevation")
|
elevation_tag = tags.get("ele") or tags.get("elevation")
|
||||||
try:
|
try:
|
||||||
elevation = float(
|
elevation = float(elevation_tag) if elevation_tag is not None else None
|
||||||
elevation_tag) if elevation_tag is not None else None
|
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
elevation = None
|
elevation = None
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ def build_overpass_query(region_name: str) -> str:
|
|||||||
regions = DEFAULT_REGIONS
|
regions = DEFAULT_REGIONS
|
||||||
else:
|
else:
|
||||||
regions = tuple(
|
regions = tuple(
|
||||||
region for region in DEFAULT_REGIONS if region.name == region_name)
|
region for region in DEFAULT_REGIONS if region.name == region_name
|
||||||
|
)
|
||||||
if not regions:
|
if not regions:
|
||||||
available = ", ".join(region.name for region in DEFAULT_REGIONS)
|
available = ", ".join(region.name for region in DEFAULT_REGIONS)
|
||||||
msg = f"Unknown region {region_name}. Available regions: [{available}]"
|
msg = f"Unknown region {region_name}. Available regions: [{available}]"
|
||||||
@@ -86,7 +87,9 @@ def perform_request(query: str) -> dict[str, Any]:
|
|||||||
return json.loads(payload)
|
return json.loads(payload)
|
||||||
|
|
||||||
|
|
||||||
def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
|
def normalize_track_elements(
|
||||||
|
elements: Iterable[dict[str, Any]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""Convert Overpass way elements into TrackCreate-compatible payloads."""
|
"""Convert Overpass way elements into TrackCreate-compatible payloads."""
|
||||||
|
|
||||||
tracks: list[dict[str, Any]] = []
|
tracks: list[dict[str, Any]] = []
|
||||||
|
|||||||
@@ -91,13 +91,13 @@ def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTra
|
|||||||
coordinates = entry.get("coordinates")
|
coordinates = entry.get("coordinates")
|
||||||
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
|
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid track entry: 'coordinates' must contain at least two points")
|
"Invalid track entry: 'coordinates' must contain at least two points"
|
||||||
|
)
|
||||||
|
|
||||||
processed_coordinates: list[tuple[float, float]] = []
|
processed_coordinates: list[tuple[float, float]] = []
|
||||||
for pair in coordinates:
|
for pair in coordinates:
|
||||||
if not isinstance(pair, Sequence) or len(pair) != 2:
|
if not isinstance(pair, Sequence) or len(pair) != 2:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
|
||||||
f"Invalid coordinate pair {pair!r} in track entry")
|
|
||||||
lat, lon = pair
|
lat, lon = pair
|
||||||
processed_coordinates.append((float(lat), float(lon)))
|
processed_coordinates.append((float(lat), float(lon)))
|
||||||
|
|
||||||
@@ -155,7 +155,8 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
length = track_data.length_meters or _polyline_length(
|
length = track_data.length_meters or _polyline_length(
|
||||||
track_data.coordinates)
|
track_data.coordinates
|
||||||
|
)
|
||||||
max_speed = (
|
max_speed = (
|
||||||
int(round(track_data.max_speed_kph))
|
int(round(track_data.max_speed_kph))
|
||||||
if track_data.max_speed_kph is not None
|
if track_data.max_speed_kph is not None
|
||||||
@@ -192,8 +193,7 @@ def _nearest_station(
|
|||||||
best_station: StationRef | None = None
|
best_station: StationRef | None = None
|
||||||
best_distance = math.inf
|
best_distance = math.inf
|
||||||
for station in stations:
|
for station in stations:
|
||||||
distance = _haversine(
|
distance = _haversine(coordinate, (station.latitude, station.longitude))
|
||||||
coordinate, (station.latitude, station.longitude))
|
|
||||||
if distance < best_distance:
|
if distance < best_distance:
|
||||||
best_station = station
|
best_station = station
|
||||||
best_distance = distance
|
best_distance = distance
|
||||||
@@ -229,7 +229,9 @@ def _to_point(geometry: WKBElement | WKTElement | Any):
|
|||||||
try:
|
try:
|
||||||
point = to_shape(geometry)
|
point = to_shape(geometry)
|
||||||
return point if getattr(point, "geom_type", None) == "Point" else None
|
return point if getattr(point, "geom_type", None) == "Point" else None
|
||||||
except Exception: # pragma: no cover - defensive, should not happen with valid geometry
|
except (
|
||||||
|
Exception
|
||||||
|
): # pragma: no cover - defensive, should not happen with valid geometry
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -50,8 +50,7 @@ def test_network_snapshot_prefers_repository_data(
|
|||||||
track = sample_entities["track"]
|
track = sample_entities["track"]
|
||||||
train = sample_entities["train"]
|
train = sample_entities["train"]
|
||||||
|
|
||||||
monkeypatch.setattr(StationRepository, "list_active",
|
monkeypatch.setattr(StationRepository, "list_active", lambda self: [station])
|
||||||
lambda self: [station])
|
|
||||||
monkeypatch.setattr(TrackRepository, "list_all", lambda self: [track])
|
monkeypatch.setattr(TrackRepository, "list_all", lambda self: [track])
|
||||||
monkeypatch.setattr(TrainRepository, "list_all", lambda self: [train])
|
monkeypatch.setattr(TrainRepository, "list_all", lambda self: [train])
|
||||||
|
|
||||||
@@ -59,8 +58,7 @@ def test_network_snapshot_prefers_repository_data(
|
|||||||
|
|
||||||
assert snapshot["stations"]
|
assert snapshot["stations"]
|
||||||
assert snapshot["stations"][0]["name"] == station.name
|
assert snapshot["stations"][0]["name"] == station.name
|
||||||
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(
|
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(track.length_meters)
|
||||||
track.length_meters)
|
|
||||||
assert snapshot["trains"][0]["designation"] == train.designation
|
assert snapshot["trains"][0]["designation"] == train.designation
|
||||||
assert snapshot["trains"][0]["operatingTrackIds"] == []
|
assert snapshot["trains"][0]["operatingTrackIds"] == []
|
||||||
|
|
||||||
@@ -76,5 +74,4 @@ def test_network_snapshot_falls_back_when_repositories_empty(
|
|||||||
|
|
||||||
assert snapshot["stations"]
|
assert snapshot["stations"]
|
||||||
assert snapshot["trains"]
|
assert snapshot["trains"]
|
||||||
assert any(station["name"] ==
|
assert any(station["name"] == "Central" for station in snapshot["stations"])
|
||||||
"Central" for station in snapshot["stations"])
|
|
||||||
|
|||||||
@@ -58,7 +58,9 @@ def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None:
|
|||||||
assert labels == ["Load stations", "Load tracks"]
|
assert labels == ["Load stations", "Load tracks"]
|
||||||
|
|
||||||
|
|
||||||
def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
def test_main_dry_run_lists_plan(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]
|
||||||
|
) -> None:
|
||||||
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
|
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
|
||||||
raise AssertionError("runner should not be invoked during dry run")
|
raise AssertionError("runner should not be invoked during dry run")
|
||||||
|
|
||||||
@@ -76,7 +78,9 @@ def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
|||||||
assert "Load tracks" in captured
|
assert "Load tracks" in captured
|
||||||
|
|
||||||
|
|
||||||
def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
def test_main_executes_stages_in_order(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
calls: list[str] = []
|
calls: list[str] = []
|
||||||
|
|
||||||
def make_import(name: str):
|
def make_import(name: str):
|
||||||
@@ -98,14 +102,12 @@ def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path
|
|||||||
|
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
monkeypatch.setattr(osm_refresh.stations_import, "main",
|
monkeypatch.setattr(
|
||||||
make_import("stations_import"))
|
osm_refresh.stations_import, "main", make_import("stations_import")
|
||||||
monkeypatch.setattr(osm_refresh.tracks_import, "main",
|
)
|
||||||
make_import("tracks_import"))
|
monkeypatch.setattr(osm_refresh.tracks_import, "main", make_import("tracks_import"))
|
||||||
monkeypatch.setattr(osm_refresh.stations_load, "main",
|
monkeypatch.setattr(osm_refresh.stations_load, "main", make_load("stations_load"))
|
||||||
make_load("stations_load"))
|
monkeypatch.setattr(osm_refresh.tracks_load, "main", make_load("tracks_load"))
|
||||||
monkeypatch.setattr(osm_refresh.tracks_load, "main",
|
|
||||||
make_load("tracks_load"))
|
|
||||||
|
|
||||||
exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])
|
exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])
|
||||||
|
|
||||||
@@ -118,7 +120,9 @@ def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
def test_main_skip_import_flags(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
station_json = tmp_path / "stations.json"
|
station_json = tmp_path / "stations.json"
|
||||||
station_json.write_text("{}", encoding="utf-8")
|
station_json.write_text("{}", encoding="utf-8")
|
||||||
track_json = tmp_path / "tracks.json"
|
track_json = tmp_path / "tracks.json"
|
||||||
@@ -139,8 +143,7 @@ def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path)
|
|||||||
|
|
||||||
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
|
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
|
||||||
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
|
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
|
||||||
monkeypatch.setattr(osm_refresh.stations_load,
|
monkeypatch.setattr(osm_refresh.stations_load, "main", record("stations_load"))
|
||||||
"main", record("stations_load"))
|
|
||||||
monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load"))
|
monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load"))
|
||||||
|
|
||||||
exit_code = osm_refresh.main(
|
exit_code = osm_refresh.main(
|
||||||
|
|||||||
@@ -103,10 +103,12 @@ def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
track_repo_instance = DummyTrackRepository(session_instance)
|
track_repo_instance = DummyTrackRepository(session_instance)
|
||||||
|
|
||||||
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
||||||
monkeypatch.setattr(tracks_load, "StationRepository",
|
monkeypatch.setattr(
|
||||||
lambda session: station_repo_instance)
|
tracks_load, "StationRepository", lambda session: station_repo_instance
|
||||||
monkeypatch.setattr(tracks_load, "TrackRepository",
|
)
|
||||||
lambda session: track_repo_instance)
|
monkeypatch.setattr(
|
||||||
|
tracks_load, "TrackRepository", lambda session: track_repo_instance
|
||||||
|
)
|
||||||
|
|
||||||
parsed = tracks_load._parse_track_entries(
|
parsed = tracks_load._parse_track_entries(
|
||||||
[
|
[
|
||||||
@@ -137,20 +139,26 @@ def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> No
|
|||||||
DummyStation(id="station-b", location=_point(52.6, 13.5)),
|
DummyStation(id="station-b", location=_point(52.6, 13.5)),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
existing_track = type("ExistingTrack", (), {
|
existing_track = type(
|
||||||
"start_station_id": "station-a",
|
"ExistingTrack",
|
||||||
"end_station_id": "station-b",
|
(),
|
||||||
})
|
{
|
||||||
|
"start_station_id": "station-a",
|
||||||
|
"end_station_id": "station-b",
|
||||||
|
},
|
||||||
|
)
|
||||||
track_repo_instance = DummyTrackRepository(
|
track_repo_instance = DummyTrackRepository(
|
||||||
session_instance,
|
session_instance,
|
||||||
existing=[existing_track],
|
existing=[existing_track],
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
||||||
monkeypatch.setattr(tracks_load, "StationRepository",
|
monkeypatch.setattr(
|
||||||
lambda session: station_repo_instance)
|
tracks_load, "StationRepository", lambda session: station_repo_instance
|
||||||
monkeypatch.setattr(tracks_load, "TrackRepository",
|
)
|
||||||
lambda session: track_repo_instance)
|
monkeypatch.setattr(
|
||||||
|
tracks_load, "TrackRepository", lambda session: track_repo_instance
|
||||||
|
)
|
||||||
|
|
||||||
parsed = tracks_load._parse_track_entries(
|
parsed = tracks_load._parse_track_entries(
|
||||||
[
|
[
|
||||||
@@ -168,7 +176,9 @@ def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> No
|
|||||||
assert not track_repo_instance.created
|
assert not track_repo_instance.created
|
||||||
|
|
||||||
|
|
||||||
def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_load_tracks_skips_when_station_too_far(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
session_instance = DummySession()
|
session_instance = DummySession()
|
||||||
station_repo_instance = DummyStationRepository(
|
station_repo_instance = DummyStationRepository(
|
||||||
session_instance,
|
session_instance,
|
||||||
@@ -179,10 +189,12 @@ def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch)
|
|||||||
track_repo_instance = DummyTrackRepository(session_instance)
|
track_repo_instance = DummyTrackRepository(session_instance)
|
||||||
|
|
||||||
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
|
||||||
monkeypatch.setattr(tracks_load, "StationRepository",
|
monkeypatch.setattr(
|
||||||
lambda session: station_repo_instance)
|
tracks_load, "StationRepository", lambda session: station_repo_instance
|
||||||
monkeypatch.setattr(tracks_load, "TrackRepository",
|
)
|
||||||
lambda session: track_repo_instance)
|
monkeypatch.setattr(
|
||||||
|
tracks_load, "TrackRepository", lambda session: track_repo_instance
|
||||||
|
)
|
||||||
|
|
||||||
parsed = tracks_load._parse_track_entries(
|
parsed = tracks_load._parse_track_entries(
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user