Files
rail-game/backend/scripts/tracks_load.py
zwitschi c35049cd54
Some checks failed
Backend CI / lint-and-test (push) Failing after 1m54s
fix: formatting (black)
2025-10-11 21:58:32 +02:00

276 lines
8.2 KiB
Python

from __future__ import annotations
"""CLI for loading normalized track JSON into the database."""
import argparse
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
from backend.app.core.osm_config import TRACK_STATION_SNAP_RADIUS_METERS
from backend.app.db.session import SessionLocal
from backend.app.models import TrackCreate
from backend.app.repositories import StationRepository, TrackRepository
@dataclass(slots=True)
class ParsedTrack:
coordinates: list[tuple[float, float]]
name: str | None = None
length_meters: float | None = None
max_speed_kph: float | None = None
status: str = "operational"
is_bidirectional: bool = True
@dataclass(slots=True)
class StationRef:
id: str
latitude: float
longitude: float
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the track loader.

    The parser takes one positional ``input`` path and the paired
    ``--commit`` / ``--no-commit`` switches. Both switches write to the
    ``commit`` attribute, which is ``True`` unless ``--no-commit`` is the
    last of the two given.
    """
    cli = argparse.ArgumentParser(
        description="Load normalized track data into PostGIS",
    )
    cli.add_argument(
        "input",
        type=Path,
        help="Path to the normalized track JSON file produced by tracks_import.py",
    )

    # Both switches share one destination, so the last one on the command
    # line decides the final value.
    commit_switches = (
        (
            "--commit",
            "store_true",
            "Commit the transaction after loading (default).",
        ),
        (
            "--no-commit",
            "store_false",
            "Rollback the transaction after loading (useful for dry runs).",
        ),
    )
    for flag, action, help_text in commit_switches:
        cli.add_argument(
            flag,
            dest="commit",
            action=action,
            default=True,
            help=help_text,
        )
    return cli
def main(argv: list[str] | None = None) -> int:
    """Run the track loader CLI.

    Args:
        argv: Arguments to parse. ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` once the tracks have been processed. Every input problem
        (missing or unreadable file, malformed JSON, wrong payload shape,
        invalid track entry) is reported through ``parser.error``, which
        prints a usage message and exits instead of returning.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    if not args.input.exists():
        parser.error(f"Input file {args.input} does not exist")

    # Report unreadable or malformed input as a usage error rather than
    # letting a raw traceback escape from the CLI.
    try:
        with args.input.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except OSError as exc:
        parser.error(f"Could not read input file {args.input}: {exc}")
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        parser.error(f"Input file {args.input} is not valid JSON: {exc}")

    # A top-level list or scalar has no ``.get``; reject it explicitly.
    if not isinstance(payload, Mapping):
        parser.error("Invalid payload: expected a JSON object with a 'tracks' list")

    track_entries = payload.get("tracks") or []
    if not isinstance(track_entries, list):
        parser.error("Invalid payload: 'tracks' must be a list")

    try:
        tracks = _parse_track_entries(track_entries)
    except ValueError as exc:
        parser.error(str(exc))

    created = load_tracks(tracks, commit=args.commit)
    print(f"Loaded {created} tracks from {args.input}")
    return 0
def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
parsed: list[ParsedTrack] = []
for entry in entries:
coordinates = entry.get("coordinates")
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
raise ValueError(
"Invalid track entry: 'coordinates' must contain at least two points"
)
processed_coordinates: list[tuple[float, float]] = []
for pair in coordinates:
if not isinstance(pair, Sequence) or len(pair) != 2:
raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
lat, lon = pair
processed_coordinates.append((float(lat), float(lon)))
name = entry.get("name")
length = _safe_float(entry.get("lengthMeters"))
max_speed = _safe_float(entry.get("maxSpeedKph"))
status = entry.get("status", "operational")
is_bidirectional = entry.get("isBidirectional", True)
parsed.append(
ParsedTrack(
coordinates=processed_coordinates,
name=str(name) if name else None,
length_meters=length,
max_speed_kph=max_speed,
status=str(status) if status else "operational",
is_bidirectional=bool(is_bidirectional),
)
)
return parsed
def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
    """Insert parsed tracks, snapping each endpoint to a nearby station.

    A track is skipped when either endpoint has no station within
    ``TRACK_STATION_SNAP_RADIUS_METERS``, when both endpoints snap to the
    same station, or when a track for that (start, end) station pair is
    already known.

    Args:
        tracks: Parsed tracks to load.
        commit: Commit the session when true, otherwise roll it back.

    Returns:
        Number of tracks handed to the repository for creation.
    """
    inserted = 0
    with SessionLocal() as db:
        stations = StationRepository(db)
        track_store = TrackRepository(db)

        snap_targets = _build_station_index(stations.list_active())
        # Ordered (start, end) id pairs that already have a track.
        known_pairs = {
            (str(existing.start_station_id), str(existing.end_station_id))
            for existing in track_store.list_all()
        }

        for candidate in tracks:
            origin = _nearest_station(
                candidate.coordinates[0],
                snap_targets,
                TRACK_STATION_SNAP_RADIUS_METERS,
            )
            terminus = _nearest_station(
                candidate.coordinates[-1],
                snap_targets,
                TRACK_STATION_SNAP_RADIUS_METERS,
            )
            if origin is None or terminus is None:
                continue
            if origin.id == terminus.id:
                continue

            station_pair = (origin.id, terminus.id)
            if station_pair in known_pairs:
                continue

            # A missing (or zero) length falls back to the polyline length.
            length = candidate.length_meters
            if not length:
                length = _polyline_length(candidate.coordinates)

            if candidate.max_speed_kph is None:
                speed_limit = None
            else:
                speed_limit = int(round(candidate.max_speed_kph))

            track_store.create(
                TrackCreate(
                    name=candidate.name,
                    start_station_id=origin.id,
                    end_station_id=terminus.id,
                    coordinates=candidate.coordinates,
                    length_meters=length,
                    max_speed_kph=speed_limit,
                    status=candidate.status,
                    is_bidirectional=candidate.is_bidirectional,
                )
            )
            known_pairs.add(station_pair)
            inserted += 1

        if commit:
            session_action = db.commit
        else:
            session_action = db.rollback
        session_action()
    return inserted
def _nearest_station(
coordinate: tuple[float, float],
stations: Sequence[StationRef],
max_distance_meters: float,
) -> StationRef | None:
best_station: StationRef | None = None
best_distance = math.inf
for station in stations:
distance = _haversine(coordinate, (station.latitude, station.longitude))
if distance < best_distance:
best_station = station
best_distance = distance
if best_distance <= max_distance_meters:
return best_station
return None
def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
index: list[StationRef] = []
for station in stations:
location = getattr(station, "location", None)
if location is None:
continue
point = _to_point(location)
if point is None:
continue
latitude = getattr(point, "y", None)
longitude = getattr(point, "x", None)
if latitude is None or longitude is None:
continue
index.append(
StationRef(
id=str(station.id),
latitude=float(latitude),
longitude=float(longitude),
)
)
return index
def _to_point(geometry: WKBElement | WKTElement | Any):
try:
point = to_shape(geometry)
return point if getattr(point, "geom_type", None) == "Point" else None
except (
Exception
): # pragma: no cover - defensive, should not happen with valid geometry
return None
def _polyline_length(points: Sequence[tuple[float, float]]) -> float:
if len(points) < 2:
return 0.0
total = 0.0
for index in range(len(points) - 1):
total += _haversine(points[index], points[index + 1])
return total
def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def _safe_float(value: Any) -> float | None:
if value is None or value == "":
return None
try:
return float(value)
except (TypeError, ValueError):
return None
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)