feat: Enhance track model and import functionality

- Added new fields to TrackModel: status, is_bidirectional, and coordinates.
- Updated network service to handle new track attributes and geometry extraction.
- Introduced CLI scripts for importing and loading tracks from OpenStreetMap.
- Implemented normalization of track elements to ensure valid geometries.
- Enhanced tests for track model, network service, and import/load scripts.
- Updated frontend to accommodate new track attributes and improve route computation.
- Documented OSM ingestion process in architecture and runtime views.
This commit is contained in:
2025-10-11 19:54:10 +02:00
parent 090dca29c2
commit c2927f2f60
18 changed files with 968 additions and 52 deletions

View File

@@ -75,6 +75,18 @@ STATION_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
}
# Tags that describe rail infrastructure usable for train routing.
# Each key is an OSM tag name; the tuple lists the tag values accepted for it.
TRACK_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
    "railway": (
        "rail",
        "light_rail",
        "subway",
        "tram",
        "narrow_gauge",
    ),
}
def compile_overpass_filters(filters: Mapping[str, Iterable[str]]) -> str:
"""Build an Overpass boolean expression that matches the provided filters."""
@@ -89,5 +101,6 @@ __all__ = [
"BoundingBox",
"DEFAULT_REGIONS",
"STATION_TAG_FILTERS",
"TRACK_TAG_FILTERS",
"compile_overpass_filters",
]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime
from typing import Generic, Sequence, TypeVar
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
def to_camel(string: str) -> str:
@@ -53,6 +53,9 @@ class TrackModel(IdentifiedModel[str]):
end_station_id: str
length_meters: float
max_speed_kph: float
status: str | None = None
is_bidirectional: bool = True
coordinates: list[tuple[float, float]] = Field(default_factory=list)
class TrainModel(IdentifiedModel[str]):

View File

@@ -8,9 +8,10 @@ from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
try: # pragma: no cover - optional dependency guard
from shapely.geometry import Point # type: ignore
from shapely.geometry import LineString, Point # type: ignore
except ImportError: # pragma: no cover - allow running without shapely at import time
Point = None # type: ignore[assignment]
LineString = None # type: ignore[assignment]
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
@@ -51,6 +52,12 @@ def _fallback_snapshot() -> dict[str, list[dict[str, object]]]:
end_station_id="station-2",
length_meters=289000.0,
max_speed_kph=230.0,
status="operational",
is_bidirectional=True,
coordinates=[
(stations[0].latitude, stations[0].longitude),
(stations[1].latitude, stations[1].longitude),
],
created_at=now,
updated_at=now,
)
@@ -134,6 +141,20 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]
track_models: list[TrackModel] = []
for track in tracks_entities:
coordinates: list[tuple[float, float]] = []
geometry = track.track_geometry
shape = (
to_shape(cast(WKBElement | WKTElement, geometry))
if geometry is not None and LineString is not None
else None
)
if LineString is not None and shape is not None and isinstance(shape, LineString):
coords_list: list[tuple[float, float]] = []
for coord in shape.coords:
lon = float(coord[0])
lat = float(coord[1])
coords_list.append((lat, lon))
coordinates = coords_list
track_models.append(
TrackModel(
id=str(track.id),
@@ -141,6 +162,9 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]
end_station_id=str(track.end_station_id),
length_meters=_to_float(track.length_meters),
max_speed_kph=_to_float(track.max_speed_kph),
status=track.status,
is_bidirectional=track.is_bidirectional,
coordinates=coordinates,
created_at=cast(datetime, track.created_at),
updated_at=cast(datetime, track.updated_at),
)

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
"""CLI utility to export rail track geometries from OpenStreetMap."""
import argparse
import json
import math
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import quote_plus
from backend.app.core.osm_config import (
DEFAULT_REGIONS,
TRACK_TAG_FILTERS,
compile_overpass_filters,
)
OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the track export script.

    The parser accepts ``--output`` (destination JSON path), ``--region``
    (one of the names in ``DEFAULT_REGIONS`` or ``all``) and ``--dry-run``
    (print the query instead of fetching data).
    """
    region_names = [region.name for region in DEFAULT_REGIONS]
    default_output = Path("data/osm_tracks.json")

    cli = argparse.ArgumentParser(
        description="Export OSM rail track ways for ingestion",
    )
    cli.add_argument(
        "--output",
        type=Path,
        default=default_output,
        help=(
            "Destination file for the exported track geometries "
            "(default: data/osm_tracks.json)"
        ),
    )
    cli.add_argument(
        "--region",
        choices=[*region_names, "all"],
        default="all",
        help="Region name to export (default: all)",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Do not fetch data; print the Overpass payload only",
    )
    return cli
def build_overpass_query(region_name: str) -> str:
    """Assemble the Overpass QL query that selects rail ways for a region.

    Args:
        region_name: Name of an entry in ``DEFAULT_REGIONS``, or ``"all"`` to
            query every default region.

    Returns:
        The query text, one statement per line.

    Raises:
        ValueError: If ``region_name`` matches no default region.
    """
    if region_name == "all":
        selected = DEFAULT_REGIONS
    else:
        selected = tuple(
            candidate
            for candidate in DEFAULT_REGIONS
            if candidate.name == region_name
        )
        if not selected:
            available = ", ".join(region.name for region in DEFAULT_REGIONS)
            raise ValueError(
                f"Unknown region {region_name}. Available regions: [{available}]"
            )

    filters = compile_overpass_filters(TRACK_TAG_FILTERS)
    # One "way" clause per region, all wrapped in a single union block.
    way_clauses = [
        f" way{filters}\n ({region.to_overpass_arg()});" for region in selected
    ]
    statements = [
        "[out:json][timeout:120];",
        "(",
        *way_clauses,
        ")",
        "; out body geom; >; out skel qt;",
    ]
    return "\n".join(statements)
def perform_request(query: str) -> dict[str, Any]:
    """Send ``query`` to the Overpass endpoint and return the decoded JSON body.

    The query is submitted as a form-encoded ``data`` field; the HTTP call
    uses a 180 second timeout.
    """
    import urllib.request

    form_body = f"data={quote_plus(query)}".encode("utf-8")
    http_request = urllib.request.Request(
        OVERPASS_ENDPOINT,
        data=form_body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    with urllib.request.urlopen(http_request, timeout=180) as response:
        raw_body = response.read()
    return json.loads(raw_body)
def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert Overpass way elements into TrackCreate-compatible payloads.

    Elements that are not ways, ways with a node lacking ``lat`` or ``lon``,
    and ways with fewer than two nodes are dropped.
    """
    normalized: list[dict[str, Any]] = []

    for element in elements:
        if element.get("type") != "way":
            continue

        points: list[list[float]] = []
        for node in element.get("geometry") or []:
            lat, lon = node.get("lat"), node.get("lon")
            if lat is None or lon is None:
                # One incomplete node invalidates the whole way.
                points.clear()
                break
            points.append([float(lat), float(lon)])

        if len(points) < 2:
            continue

        tags: dict[str, Any] = element.get("tags", {})
        name = tags.get("name")
        normalized.append(
            {
                "osmId": str(element.get("id")),
                "name": str(name) if name else None,
                "lengthMeters": _polyline_length(points),
                "maxSpeedKph": _parse_maxspeed(tags.get("maxspeed")),
                "status": _derive_status(tags.get("railway")),
                "isBidirectional": not _is_oneway(tags.get("oneway")),
                "coordinates": points,
            }
        )

    return normalized
def _parse_maxspeed(value: Any) -> float | None:
if value is None:
return None
# Overpass may return values such as "80" or "80 km/h" or "signals".
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
number = ""
for char in text:
if char.isdigit() or char == ".":
number += char
elif number:
break
try:
return float(number) if number else None
except ValueError:
return None
def _derive_status(value: Any) -> str:
tag = str(value or "").lower()
if tag in {"abandoned", "disused"}:
return tag
if tag in {"construction", "proposed"}:
return "construction"
return "operational"
def _is_oneway(value: Any) -> bool:
if value is None:
return False
normalized = str(value).strip().lower()
return normalized in {"yes", "true", "1"}
def _polyline_length(points: list[list[float]]) -> float:
if len(points) < 2:
return 0.0
total = 0.0
for index in range(len(points) - 1):
total += _haversine(points[index], points[index + 1])
return total
def _haversine(a: list[float], b: list[float]) -> float:
"""Return distance in meters between two [lat, lon] coordinates."""
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def main(argv: list[str] | None = None) -> int:
    """Run the export: build the query, fetch, normalize and write the JSON file.

    With ``--dry-run`` the query is printed and nothing is fetched or written.
    Returns the process exit code (always 0 on normal completion).
    """
    args = build_argument_parser().parse_args(argv)
    query = build_overpass_query(args.region)

    if args.dry_run:
        print(query)
        return 0

    output_path: Path = args.output
    # Create the destination directory before the (slow) network request.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    response = perform_request(query)
    raw_elements = response.get("elements", [])
    tracks = normalize_track_elements(raw_elements)

    metadata = {
        "endpoint": OVERPASS_ENDPOINT,
        "region": args.region,
        "filters": TRACK_TAG_FILTERS,
        "regions": [asdict(region) for region in DEFAULT_REGIONS],
        "raw_count": len(raw_elements),
        "track_count": len(tracks),
    }
    document = {"metadata": metadata, "tracks": tracks}

    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(document, handle, indent=2)

    print(
        f"Normalized {len(tracks)} tracks from {len(raw_elements)} elements into {output_path}"
    )
    return 0
# Script entry point: run the exporter and use its return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
"""CLI for loading normalized track JSON into the database."""
import argparse
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
from backend.app.db.session import SessionLocal
from backend.app.models import TrackCreate
from backend.app.repositories import StationRepository, TrackRepository
@dataclass(slots=True)
class ParsedTrack:
    """A track entry read from the normalized JSON file, prior to station matching."""

    # Ordered (latitude, longitude) pairs describing the track polyline.
    coordinates: list[tuple[float, float]]
    # Optional display name.
    name: str | None = None
    # Length from the input file; None when absent or not numeric.
    length_meters: float | None = None
    # Speed limit from the input file; None when absent or not numeric.
    max_speed_kph: float | None = None
    # Status string; defaults to "operational".
    status: str = "operational"
    # Whether the track may be used in both directions.
    is_bidirectional: bool = True
@dataclass(slots=True)
class StationRef:
    """Station identifier plus its position, used for nearest-station lookups."""

    # Station id, stored as a string.
    id: str
    # Latitude of the station point.
    latitude: float
    # Longitude of the station point.
    longitude: float
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the track loading script.

    The parser accepts a positional ``input`` path and a boolean
    ``--commit`` / ``--no-commit`` switch stored as ``args.commit``
    (default ``True``).
    """
    parser = argparse.ArgumentParser(
        description="Load normalized track data into PostGIS",
    )
    parser.add_argument(
        "input",
        type=Path,
        help="Path to the normalized track JSON file produced by tracks_import.py",
    )
    # argparse does not understand the click-style "--commit/--no-commit"
    # spelling: it would register one value-taking option with that literal
    # name. BooleanOptionalAction generates the real flag pair.
    parser.add_argument(
        "--commit",
        dest="commit",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Commit the transaction (default: commit). Use --no-commit for dry runs.",
    )
    return parser
def main(argv: list[str] | None = None) -> int:
    """Read the normalized track JSON file and load its tracks into the database.

    A missing input file, a non-list ``tracks`` value or an invalid entry is
    reported through ``parser.error``, which exits the process. Returns 0 on
    normal completion.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    source: Path = args.input
    if not source.exists():
        parser.error(f"Input file {args.input} does not exist")

    with source.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    # A missing or empty "tracks" key is treated as an empty list.
    raw_tracks = payload.get("tracks") or []
    if not isinstance(raw_tracks, list):
        parser.error("Invalid payload: 'tracks' must be a list")

    try:
        tracks = _parse_track_entries(raw_tracks)
    except ValueError as exc:
        parser.error(str(exc))

    created = load_tracks(tracks, commit=args.commit)
    print(f"Loaded {created} tracks from {args.input}")
    return 0
def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
parsed: list[ParsedTrack] = []
for entry in entries:
coordinates = entry.get("coordinates")
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
raise ValueError(
"Invalid track entry: 'coordinates' must contain at least two points")
processed_coordinates: list[tuple[float, float]] = []
for pair in coordinates:
if not isinstance(pair, Sequence) or len(pair) != 2:
raise ValueError(
f"Invalid coordinate pair {pair!r} in track entry")
lat, lon = pair
processed_coordinates.append((float(lat), float(lon)))
name = entry.get("name")
length = _safe_float(entry.get("lengthMeters"))
max_speed = _safe_float(entry.get("maxSpeedKph"))
status = entry.get("status", "operational")
is_bidirectional = entry.get("isBidirectional", True)
parsed.append(
ParsedTrack(
coordinates=processed_coordinates,
name=str(name) if name else None,
length_meters=length,
max_speed_kph=max_speed,
status=str(status) if status else "operational",
is_bidirectional=bool(is_bidirectional),
)
)
return parsed
def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
    """Create database tracks for the parsed entries and return how many were created.

    Each track is attached to the stations nearest its first and last
    coordinate. Entries are skipped when no station is available or when a
    track for the same (start, end) station pair already exists or was
    created earlier in this run. The session is committed when ``commit`` is
    true and rolled back otherwise.
    """
    created = 0
    with SessionLocal() as session:
        station_repo = StationRepository(session)
        track_repo = TrackRepository(session)

        station_index = _build_station_index(station_repo.list_active())
        known_pairs: set[tuple[str, str]] = set()
        for existing in track_repo.list_all():
            known_pairs.add(
                (str(existing.start_station_id), str(existing.end_station_id))
            )

        for candidate in tracks:
            # NOTE(review): the nearest station is used regardless of how far
            # away it is, and both ends may resolve to the same station.
            origin = _nearest_station(candidate.coordinates[0], station_index)
            terminus = _nearest_station(candidate.coordinates[-1], station_index)
            if not (origin and terminus):
                continue

            endpoints = (origin.id, terminus.id)
            if endpoints in known_pairs:
                continue

            # Fall back to the computed polyline length when none was supplied.
            length = candidate.length_meters or _polyline_length(
                candidate.coordinates
            )
            track_repo.create(
                TrackCreate(
                    name=candidate.name,
                    start_station_id=origin.id,
                    end_station_id=terminus.id,
                    coordinates=candidate.coordinates,
                    length_meters=length,
                    max_speed_kph=candidate.max_speed_kph,
                    status=candidate.status,
                    is_bidirectional=candidate.is_bidirectional,
                )
            )
            known_pairs.add(endpoints)
            created += 1

        if commit:
            session.commit()
        else:
            session.rollback()

    return created
def _nearest_station(
coordinate: tuple[float, float], stations: Sequence[StationRef]
) -> StationRef | None:
best_station: StationRef | None = None
best_distance = math.inf
for station in stations:
distance = _haversine(
coordinate, (station.latitude, station.longitude))
if distance < best_distance:
best_station = station
best_distance = distance
return best_station
def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
index: list[StationRef] = []
for station in stations:
location = getattr(station, "location", None)
if location is None:
continue
point = _to_point(location)
if point is None:
continue
index.append(
StationRef(
id=str(station.id),
latitude=float(point.y),
longitude=float(point.x),
)
)
return index
def _to_point(geometry: WKBElement | WKTElement | Any):
    """Convert a geometry element to a shape, returning it only if it is a Point.

    Any conversion failure, or a shape whose ``geom_type`` is not
    ``"Point"``, yields ``None``.
    """
    try:
        shape = to_shape(geometry)
        is_point = getattr(shape, "geom_type", None) == "Point"
    except Exception:  # pragma: no cover - defensive, should not happen with valid geometry
        return None
    return shape if is_point else None
def _polyline_length(points: Sequence[tuple[float, float]]) -> float:
if len(points) < 2:
return 0.0
total = 0.0
for index in range(len(points) - 1):
total += _haversine(points[index], points[index + 1])
return total
def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def _safe_float(value: Any) -> float | None:
if value is None or value == "":
return None
try:
return float(value)
except (TypeError, ValueError):
return None
# Script entry point: run the loader and use its return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -29,11 +29,15 @@ def test_track_model_properties() -> None:
end_station_id="station-2",
length_meters=1500.0,
max_speed_kph=120.0,
status="operational",
is_bidirectional=True,
coordinates=[(52.52, 13.405), (52.6, 13.5)],
created_at=timestamp,
updated_at=timestamp,
)
assert track.length_meters > 0
assert track.start_station_id != track.end_station_id
assert len(track.coordinates) == 2
def test_train_model_operating_tracks() -> None:

View File

@@ -26,6 +26,9 @@ def sample_entities() -> dict[str, SimpleNamespace]:
end_station_id=station.id,
length_meters=1234.5,
max_speed_kph=160,
status="operational",
is_bidirectional=True,
track_geometry=None,
created_at=timestamp,
updated_at=timestamp,
)
@@ -47,7 +50,8 @@ def test_network_snapshot_prefers_repository_data(
track = sample_entities["track"]
train = sample_entities["train"]
monkeypatch.setattr(StationRepository, "list_active", lambda self: [station])
monkeypatch.setattr(StationRepository, "list_active",
lambda self: [station])
monkeypatch.setattr(TrackRepository, "list_all", lambda self: [track])
monkeypatch.setattr(TrainRepository, "list_all", lambda self: [train])
@@ -55,7 +59,8 @@ def test_network_snapshot_prefers_repository_data(
assert snapshot["stations"]
assert snapshot["stations"][0]["name"] == station.name
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(track.length_meters)
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(
track.length_meters)
assert snapshot["trains"][0]["designation"] == train.designation
assert snapshot["trains"][0]["operatingTrackIds"] == []
@@ -71,4 +76,5 @@ def test_network_snapshot_falls_back_when_repositories_empty(
assert snapshot["stations"]
assert snapshot["trains"]
assert any(station["name"] == "Central" for station in snapshot["stations"])
assert any(station["name"] ==
"Central" for station in snapshot["stations"])

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from backend.scripts import tracks_import
def test_normalize_track_elements_excludes_invalid_geometries() -> None:
    """Only the way with two complete nodes survives normalization."""
    valid_way = {
        "type": "way",
        "id": 123,
        "geometry": [
            {"lat": 52.5, "lon": 13.4},
            {"lat": 52.6, "lon": 13.5},
        ],
        "tags": {
            "name": "Main Line",
            "railway": "rail",
            "maxspeed": "120",
        },
    }
    way_missing_longitude = {
        "type": "way",
        "id": 456,
        "geometry": [
            {"lat": 51.0},
        ],
        "tags": {"railway": "rail"},
    }
    stray_node = {
        "type": "node",
        "id": 789,
    }

    tracks = tracks_import.normalize_track_elements(
        [valid_way, way_missing_longitude, stray_node]
    )

    assert len(tracks) == 1
    track = tracks[0]
    assert track["osmId"] == "123"
    assert track["name"] == "Main Line"
    assert track["maxSpeedKph"] == 120.0
    assert track["status"] == "operational"
    assert track["isBidirectional"] is True
    assert track["coordinates"] == [[52.5, 13.4], [52.6, 13.5]]
    assert track["lengthMeters"] > 0
def test_normalize_track_elements_marks_oneway_and_status() -> None:
    """A disused one-way way keeps its status and is not bidirectional."""
    disused_oneway = {
        "type": "way",
        "id": 42,
        "geometry": [
            {"lat": 48.1, "lon": 11.5},
            {"lat": 48.2, "lon": 11.6},
        ],
        "tags": {
            "railway": "disused",
            "oneway": "yes",
        },
    }

    tracks = tracks_import.normalize_track_elements([disused_oneway])

    assert len(tracks) == 1
    track = tracks[0]
    assert track["status"] == "disused"
    assert track["isBidirectional"] is False

View File

@@ -0,0 +1,168 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List
import pytest
from geoalchemy2.shape import from_shape
from shapely.geometry import Point
from backend.scripts import tracks_load
def test_parse_track_entries_returns_models() -> None:
    """A complete entry is parsed into a track carrying the same values."""
    connector_entry = {
        "name": "Connector",
        "coordinates": [[52.5, 13.4], [52.6, 13.5]],
        "lengthMeters": 1500,
        "maxSpeedKph": 120,
        "status": "operational",
        "isBidirectional": True,
    }

    parsed = tracks_load._parse_track_entries([connector_entry])

    track = parsed[0]
    assert track.name == "Connector"
    assert track.coordinates[0] == (52.5, 13.4)
    assert track.length_meters == 1500
    assert track.max_speed_kph == 120
def test_parse_track_entries_invalid_raises_value_error() -> None:
    """An entry with a single coordinate is rejected with ValueError."""
    single_point_entry = {"coordinates": [[52.5, 13.4]]}

    with pytest.raises(ValueError):
        tracks_load._parse_track_entries([single_point_entry])
@dataclass
class DummySession:
    """Stand-in session that records whether commit or rollback was called."""

    committed: bool = False
    rolled_back: bool = False

    def commit(self) -> None:
        """Record that the session was committed."""
        self.committed = True

    def rollback(self) -> None:
        """Record that the session was rolled back."""
        self.rolled_back = True

    def __enter__(self) -> "DummySession":
        """Return the session itself as the context value."""
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        """Do nothing on exit; exceptions are not suppressed."""
        return None
@dataclass
class DummyStation:
    """Minimal station record exposing the ``id`` and ``location`` attributes."""

    # Station identifier.
    id: str
    # Geometry element describing the station position.
    location: object
@dataclass
class DummyStationRepository:
    """Station repository double that serves a fixed list of stations."""

    session: DummySession
    stations: List[DummyStation]

    def list_active(self) -> List[DummyStation]:
        """Return the configured stations unchanged."""
        configured = self.stations
        return configured
@dataclass
class DummyTrackRepository:
    """Track repository double that records created tracks in memory."""

    session: DummySession
    # Payloads passed to ``create``, in call order.
    created: list = field(default_factory=list)
    # Tracks reported by ``list_all``.
    existing: list = field(default_factory=list)

    def create(self, data):  # pragma: no cover - simple delegation
        """Remember ``data`` instead of persisting it."""
        self.created.append(data)

    def list_all(self):
        """Return the pre-seeded existing tracks."""
        return self.existing
def _point(lat: float, lon: float) -> object:
    """Build an SRID 4326 geometry element for the given latitude/longitude."""
    # The Point constructor takes x (longitude) first, then y (latitude).
    shape = Point(lon, lat)
    return from_shape(shape, srid=4326)
def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
    """A new track between two known stations is created and committed."""
    session = DummySession()
    stations = [
        DummyStation(id="station-a", location=_point(52.5, 13.4)),
        DummyStation(id="station-b", location=_point(52.6, 13.5)),
    ]
    station_repo = DummyStationRepository(session, stations=stations)
    track_repo = DummyTrackRepository(session)

    # Swap the real session factory and repositories for the in-memory doubles.
    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo
    )

    connector_entry = {
        "name": "Connector",
        "coordinates": [[52.5, 13.4], [52.6, 13.5]],
    }
    parsed = tracks_load._parse_track_entries([connector_entry])

    created = tracks_load.load_tracks(parsed, commit=True)

    assert created == 1
    assert session.committed is True
    assert track_repo.created
    track = track_repo.created[0]
    assert track.start_station_id == "station-a"
    assert track.end_station_id == "station-b"
    assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)]
def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    """A track whose station pair already exists is skipped and rolled back."""
    session = DummySession()
    stations = [
        DummyStation(id="station-a", location=_point(52.5, 13.4)),
        DummyStation(id="station-b", location=_point(52.6, 13.5)),
    ]
    station_repo = DummyStationRepository(session, stations=stations)

    # A bare class is enough: only its two id attributes are read.
    existing_track = type(
        "ExistingTrack",
        (),
        {
            "start_station_id": "station-a",
            "end_station_id": "station-b",
        },
    )
    track_repo = DummyTrackRepository(session, existing=[existing_track])

    # Swap the real session factory and repositories for the in-memory doubles.
    monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session)
    monkeypatch.setattr(
        tracks_load, "StationRepository", lambda session: station_repo
    )
    monkeypatch.setattr(
        tracks_load, "TrackRepository", lambda session: track_repo
    )

    connector_entry = {
        "name": "Connector",
        "coordinates": [[52.5, 13.4], [52.6, 13.5]],
    }
    parsed = tracks_load._parse_track_entries([connector_entry])

    created = tracks_load.load_tracks(parsed, commit=False)

    assert created == 0
    assert session.rolled_back is True
    assert not track_repo.created