feat: Enhance track model and import functionality
- Added new fields to TrackModel: status, is_bidirectional, and coordinates.
- Updated network service to handle new track attributes and geometry extraction.
- Introduced CLI scripts for importing and loading tracks from OpenStreetMap.
- Implemented normalization of track elements to ensure valid geometries.
- Enhanced tests for track model, network service, and import/load scripts.
- Updated frontend to accommodate new track attributes and improve route computation.
- Documented OSM ingestion process in architecture and runtime views.
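For reference, a minimal sketch of how the normalization step described above is expected to behave, based on the normalize_track_elements function added in this commit. The sample Overpass element and its values are illustrative, and the snippet assumes the repository root is on PYTHONPATH:

from backend.scripts.tracks_import import normalize_track_elements

# One Overpass "way" element as returned with geometry output (illustrative values).
sample_way = {
    "type": "way",
    "id": 123456,
    "tags": {"railway": "rail", "maxspeed": "80 km/h", "oneway": "no"},
    "geometry": [
        {"lat": 52.5200, "lon": 13.4050},
        {"lat": 52.5210, "lon": 13.4070},
    ],
}

tracks = normalize_track_elements([sample_way])
# Each entry is a TrackCreate-compatible dict with osmId, lengthMeters,
# maxSpeedKph, status, isBidirectional and the [lat, lon] coordinate list.
print(tracks[0]["status"], round(tracks[0]["lengthMeters"], 1))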
backend/scripts/tracks_import.py (new file, 234 lines)
@@ -0,0 +1,234 @@
"""CLI utility to export rail track geometries from OpenStreetMap."""

from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import quote_plus

from backend.app.core.osm_config import (
    DEFAULT_REGIONS,
    TRACK_TAG_FILTERS,
    compile_overpass_filters,
)

OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"


def build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Export OSM rail track ways for ingestion",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/osm_tracks.json"),
        help=(
            "Destination file for the exported track geometries "
            "(default: data/osm_tracks.json)"
        ),
    )
    parser.add_argument(
        "--region",
        choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
        default="all",
        help="Region name to export (default: all)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Do not fetch data; print the Overpass payload only",
    )
    return parser


def build_overpass_query(region_name: str) -> str:
    if region_name == "all":
        regions = DEFAULT_REGIONS
    else:
        regions = tuple(
            region for region in DEFAULT_REGIONS if region.name == region_name
        )
        if not regions:
            available = ", ".join(region.name for region in DEFAULT_REGIONS)
            msg = f"Unknown region {region_name}. Available regions: [{available}]"
            raise ValueError(msg)

    filters = compile_overpass_filters(TRACK_TAG_FILTERS)

    parts = ["[out:json][timeout:120];", "("]
    for region in regions:
        parts.append(f"  way{filters}\n    ({region.to_overpass_arg()});")
    parts.append(")")
    parts.append("; out body geom; >; out skel qt;")
    return "\n".join(parts)


def perform_request(query: str) -> dict[str, Any]:
    import urllib.request

    payload = f"data={quote_plus(query)}".encode("utf-8")
    request = urllib.request.Request(
        OVERPASS_ENDPOINT,
        data=payload,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    with urllib.request.urlopen(request, timeout=180) as response:
        body = response.read()
    return json.loads(body)


def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert Overpass way elements into TrackCreate-compatible payloads."""

    tracks: list[dict[str, Any]] = []
    for element in elements:
        if element.get("type") != "way":
            continue

        # Discard ways with incomplete geometry; fewer than two valid points
        # cannot form a track segment.
        raw_geometry = element.get("geometry") or []
        coordinates: list[list[float]] = []
        for node in raw_geometry:
            lat = node.get("lat")
            lon = node.get("lon")
            if lat is None or lon is None:
                coordinates = []
                break
            coordinates.append([float(lat), float(lon)])

        if len(coordinates) < 2:
            continue

        tags: dict[str, Any] = element.get("tags", {})
        name = tags.get("name")
        maxspeed = _parse_maxspeed(tags.get("maxspeed"))
        status = _derive_status(tags.get("railway"))
        is_bidirectional = not _is_oneway(tags.get("oneway"))

        length_meters = _polyline_length(coordinates)

        tracks.append(
            {
                "osmId": str(element.get("id")),
                "name": str(name) if name else None,
                "lengthMeters": length_meters,
                "maxSpeedKph": maxspeed,
                "status": status,
                "isBidirectional": is_bidirectional,
                "coordinates": coordinates,
            }
        )

    return tracks


def _parse_maxspeed(value: Any) -> float | None:
    if value is None:
        return None

    # Overpass may return values such as "80" or "80 km/h" or "signals".
    if isinstance(value, (int, float)):
        return float(value)

    text = str(value).strip()
    number = ""
    for char in text:
        if char.isdigit() or char == ".":
            number += char
        elif number:
            break
    try:
        return float(number) if number else None
    except ValueError:
        return None


def _derive_status(value: Any) -> str:
    tag = str(value or "").lower()
    if tag in {"abandoned", "disused"}:
        return tag
    if tag in {"construction", "proposed"}:
        return "construction"
    return "operational"


def _is_oneway(value: Any) -> bool:
    if value is None:
        return False
    normalized = str(value).strip().lower()
    return normalized in {"yes", "true", "1"}


def _polyline_length(points: list[list[float]]) -> float:
    if len(points) < 2:
        return 0.0

    total = 0.0
    for index in range(len(points) - 1):
        total += _haversine(points[index], points[index + 1])
    return total


def _haversine(a: list[float], b: list[float]) -> float:
    """Return distance in meters between two [lat, lon] coordinates."""

    lat1, lon1 = a
    lat2, lon2 = b
    radius = 6_371_000

    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)

    sin_dphi = math.sin(delta_phi / 2)
    sin_dlambda = math.sin(delta_lambda / 2)
    root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
    distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
    return distance


def main(argv: list[str] | None = None) -> int:
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    query = build_overpass_query(args.region)

    if args.dry_run:
        print(query)
        return 0

    output_path: Path = args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)

    data = perform_request(query)
    raw_elements = data.get("elements", [])
    tracks = normalize_track_elements(raw_elements)

    payload = {
        "metadata": {
            "endpoint": OVERPASS_ENDPOINT,
            "region": args.region,
            "filters": TRACK_TAG_FILTERS,
            "regions": [asdict(region) for region in DEFAULT_REGIONS],
            "raw_count": len(raw_elements),
            "track_count": len(tracks),
        },
        "tracks": tracks,
    }

    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)

    print(
        f"Normalized {len(tracks)} tracks from {len(raw_elements)} elements into {output_path}"
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
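As a usage note, the exporter can be exercised without calling the Overpass API via its dry-run flag; a small sketch, assuming the backend package is importable from the repository root:

from backend.scripts import tracks_import

# --dry-run prints the generated Overpass QL payload and skips the HTTP request.
exit_code = tracks_import.main(["--dry-run", "--region", "all"])
assert exit_code == 0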
backend/scripts/tracks_load.py (new file, 243 lines)
@@ -0,0 +1,243 @@
"""CLI for loading normalized track JSON into the database."""

from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence

from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape

from backend.app.db.session import SessionLocal
from backend.app.models import TrackCreate
from backend.app.repositories import StationRepository, TrackRepository


@dataclass(slots=True)
class ParsedTrack:
    coordinates: list[tuple[float, float]]
    name: str | None = None
    length_meters: float | None = None
    max_speed_kph: float | None = None
    status: str = "operational"
    is_bidirectional: bool = True


@dataclass(slots=True)
class StationRef:
    id: str
    latitude: float
    longitude: float


def build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Load normalized track data into PostGIS",
    )
    parser.add_argument(
        "input",
        type=Path,
        help="Path to the normalized track JSON file produced by tracks_import.py",
    )
    # BooleanOptionalAction (Python 3.9+) generates both --commit and --no-commit
    # from a single option definition.
    parser.add_argument(
        "--commit",
        dest="commit",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Commit the transaction (default: commit). Use --no-commit for dry runs.",
    )
    return parser


def main(argv: list[str] | None = None) -> int:
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    if not args.input.exists():
        parser.error(f"Input file {args.input} does not exist")

    with args.input.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    track_entries = payload.get("tracks") or []
    if not isinstance(track_entries, list):
        parser.error("Invalid payload: 'tracks' must be a list")

    try:
        tracks = _parse_track_entries(track_entries)
    except ValueError as exc:
        parser.error(str(exc))

    created = load_tracks(tracks, commit=args.commit)
    print(f"Loaded {created} tracks from {args.input}")
    return 0


def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
    parsed: list[ParsedTrack] = []
    for entry in entries:
        coordinates = entry.get("coordinates")
        if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
            raise ValueError(
                "Invalid track entry: 'coordinates' must contain at least two points"
            )

        processed_coordinates: list[tuple[float, float]] = []
        for pair in coordinates:
            if not isinstance(pair, Sequence) or len(pair) != 2:
                raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
            lat, lon = pair
            processed_coordinates.append((float(lat), float(lon)))

        name = entry.get("name")
        length = _safe_float(entry.get("lengthMeters"))
        max_speed = _safe_float(entry.get("maxSpeedKph"))
        status = entry.get("status", "operational")
        is_bidirectional = entry.get("isBidirectional", True)

        parsed.append(
            ParsedTrack(
                coordinates=processed_coordinates,
                name=str(name) if name else None,
                length_meters=length,
                max_speed_kph=max_speed,
                status=str(status) if status else "operational",
                is_bidirectional=bool(is_bidirectional),
            )
        )
    return parsed


def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
    created = 0
    with SessionLocal() as session:
        station_repo = StationRepository(session)
        track_repo = TrackRepository(session)

        station_index = _build_station_index(station_repo.list_active())
        existing_pairs = {
            (str(track.start_station_id), str(track.end_station_id))
            for track in track_repo.list_all()
        }

        for track_data in tracks:
            # Match each endpoint to the nearest active station; skip tracks
            # without a match or that already exist between the same pair.
            start_station = _nearest_station(track_data.coordinates[0], station_index)
            end_station = _nearest_station(track_data.coordinates[-1], station_index)

            if not start_station or not end_station:
                continue

            pair = (start_station.id, end_station.id)
            if pair in existing_pairs:
                continue

            length = track_data.length_meters or _polyline_length(track_data.coordinates)
            create_schema = TrackCreate(
                name=track_data.name,
                start_station_id=start_station.id,
                end_station_id=end_station.id,
                coordinates=track_data.coordinates,
                length_meters=length,
                max_speed_kph=track_data.max_speed_kph,
                status=track_data.status,
                is_bidirectional=track_data.is_bidirectional,
            )

            track_repo.create(create_schema)
            existing_pairs.add(pair)
            created += 1

        if commit:
            session.commit()
        else:
            session.rollback()

    return created


def _nearest_station(
    coordinate: tuple[float, float], stations: Sequence[StationRef]
) -> StationRef | None:
    best_station: StationRef | None = None
    best_distance = math.inf
    for station in stations:
        distance = _haversine(coordinate, (station.latitude, station.longitude))
        if distance < best_distance:
            best_station = station
            best_distance = distance
    return best_station


def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
    index: list[StationRef] = []
    for station in stations:
        location = getattr(station, "location", None)
        if location is None:
            continue
        point = _to_point(location)
        if point is None:
            continue
        index.append(
            StationRef(
                id=str(station.id),
                latitude=float(point.y),
                longitude=float(point.x),
            )
        )
    return index


def _to_point(geometry: WKBElement | WKTElement | Any):
    try:
        point = to_shape(geometry)
        return point if getattr(point, "geom_type", None) == "Point" else None
    except Exception:  # pragma: no cover - defensive, should not happen with valid geometry
        return None


def _polyline_length(points: Sequence[tuple[float, float]]) -> float:
    if len(points) < 2:
        return 0.0

    total = 0.0
    for index in range(len(points) - 1):
        total += _haversine(points[index], points[index + 1])
    return total


def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
    """Return distance in meters between two (lat, lon) coordinates."""

    lat1, lon1 = a
    lat2, lon2 = b
    radius = 6_371_000

    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)

    sin_dphi = math.sin(delta_phi / 2)
    sin_dlambda = math.sin(delta_lambda / 2)
    root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
    distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
    return distance


def _safe_float(value: Any) -> float | None:
    if value is None or value == "":
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


if __name__ == "__main__":
    sys.exit(main())
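Similarly, the loader supports a rehearsal run that rolls back instead of committing (the --no-commit flag relies on the argparse.BooleanOptionalAction wiring shown above). A sketch that assumes a normalized export already exists at the importer's default output path and that the database connection is configured:

from backend.scripts import tracks_load

# Parses the export, matches each track's endpoints to the nearest active
# stations, and rolls the transaction back instead of committing.
exit_code = tracks_load.main(["data/osm_tracks.json", "--no-commit"])
assert exit_code == 0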