Files
rail-game/backend/scripts/tracks_load.py

251 lines
7.5 KiB
Python

"""CLI for loading normalized track JSON into the database."""

from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import dataclass
from itertools import pairwise
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence

from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape

from backend.app.db.session import SessionLocal
from backend.app.models import TrackCreate
from backend.app.repositories import StationRepository, TrackRepository
@dataclass(slots=True)
class ParsedTrack:
coordinates: list[tuple[float, float]]
name: str | None = None
length_meters: float | None = None
max_speed_kph: float | None = None
status: str = "operational"
is_bidirectional: bool = True
@dataclass(slots=True)
class StationRef:
id: str
latitude: float
longitude: float
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the track loader.

    The positional ``input`` argument is converted to a ``Path``. The
    ``--commit`` and ``--no-commit`` switches write to the same ``commit``
    destination, which defaults to True; if both appear, the last one wins.
    """
    cli = argparse.ArgumentParser(description="Load normalized track data into PostGIS")
    cli.add_argument(
        "input",
        type=Path,
        help="Path to the normalized track JSON file produced by tracks_import.py",
    )
    commit_switches = (
        ("--commit", "store_true", "Commit the transaction after loading (default)."),
        (
            "--no-commit",
            "store_false",
            "Rollback the transaction after loading (useful for dry runs).",
        ),
    )
    for flag, action_name, help_text in commit_switches:
        # store_false already defaults to True, so passing it explicitly is a no-op.
        cli.add_argument(
            flag, dest="commit", action=action_name, default=True, help=help_text
        )
    return cli
def main(argv: list[str] | None = None) -> int:
    """Run the track-loading CLI.

    Reads the JSON file named on the command line, validates its ``tracks``
    array and hands the parsed tracks to ``load_tracks``.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        0 on success. Any input problem (missing or unreadable file, invalid
        JSON, malformed payload) is reported via ``parser.error``, which
        exits with status 2.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    input_path: Path = args.input

    if not input_path.exists():
        parser.error(f"Input file {input_path} does not exist")
    if not input_path.is_file():
        # A directory passes exists() but would fail on open().
        parser.error(f"Input path {input_path} is not a file")

    try:
        with input_path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except json.JSONDecodeError as exc:
        parser.error(f"Input file {input_path} is not valid JSON: {exc}")
    except (OSError, UnicodeDecodeError) as exc:
        parser.error(f"Could not read input file {input_path}: {exc}")

    # A top-level list or scalar has no .get(); report it instead of crashing.
    if not isinstance(payload, dict):
        parser.error("Invalid payload: top-level JSON value must be an object")
    track_entries = payload.get("tracks") or []
    if not isinstance(track_entries, list):
        parser.error("Invalid payload: 'tracks' must be a list")

    try:
        tracks = _parse_track_entries(track_entries)
    except (TypeError, ValueError) as exc:
        # TypeError covers data such as null coordinate values reaching float().
        parser.error(str(exc))

    created = load_tracks(tracks, commit=args.commit)
    if args.commit:
        print(f"Loaded {created} tracks from {args.input}")
    else:
        # load_tracks rolls the transaction back when commit is False.
        print(f"Dry run: {created} tracks from {args.input} were rolled back")
    return 0
def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
    """Validate raw track objects and convert them to ``ParsedTrack`` records.

    Args:
        entries: Track objects from the ``tracks`` array of the input JSON.
            Each must carry a ``coordinates`` sequence of at least two
            ``[lat, lon]`` pairs. ``name``, ``lengthMeters``, ``maxSpeedKph``,
            ``status`` and ``isBidirectional`` are optional.

    Returns:
        One ``ParsedTrack`` per entry, in input order.

    Raises:
        ValueError: If an entry is not a mapping, has fewer than two
            coordinate pairs, or contains a pair that is not exactly two
            finite numbers.
    """
    parsed: list[ParsedTrack] = []
    for entry in entries:
        if not isinstance(entry, Mapping):
            raise ValueError(f"Invalid track entry {entry!r}: expected a JSON object")
        coordinates = entry.get("coordinates")
        # str/bytes satisfy the Sequence ABC but are never a list of points.
        if (
            not isinstance(coordinates, Sequence)
            or isinstance(coordinates, (str, bytes))
            or len(coordinates) < 2
        ):
            raise ValueError(
                "Invalid track entry: 'coordinates' must contain at least two points")
        processed_coordinates = [_parse_coordinate_pair(pair) for pair in coordinates]

        name = entry.get("name")
        status = entry.get("status", "operational")
        parsed.append(
            ParsedTrack(
                coordinates=processed_coordinates,
                name=str(name) if name else None,
                length_meters=_safe_float(entry.get("lengthMeters")),
                max_speed_kph=_safe_float(entry.get("maxSpeedKph")),
                status=str(status) if status else "operational",
                is_bidirectional=bool(entry.get("isBidirectional", True)),
            )
        )
    return parsed


def _parse_coordinate_pair(pair: Any) -> tuple[float, float]:
    """Convert one ``[lat, lon]`` pair into a tuple of finite floats.

    Raises:
        ValueError: If *pair* is not a two-element, non-string sequence whose
            items ``float()`` accepts and which are finite.
    """
    # Without the str check, "12" would be accepted as the pair (1.0, 2.0).
    if not isinstance(pair, Sequence) or isinstance(pair, (str, bytes)) or len(pair) != 2:
        raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
    try:
        lat, lon = float(pair[0]), float(pair[1])
    except (TypeError, ValueError, OverflowError) as exc:
        # e.g. null, nested lists, non-numeric strings or huge integers in the JSON.
        raise ValueError(f"Invalid coordinate pair {pair!r} in track entry") from exc
    # json.load accepts NaN/Infinity tokens; they are not usable positions.
    if not (math.isfinite(lat) and math.isfinite(lon)):
        raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
    return lat, lon
def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
    """Insert tracks, attaching each end to its nearest active station.

    A track is skipped when either end cannot be matched to a station, or
    when a track with the same (start station, end station) pair already
    exists in the database or was inserted earlier in this run.

    Args:
        tracks: Parsed tracks to insert, processed in order.
        commit: Commit the session when True, roll it back otherwise.

    Returns:
        Number of tracks handed to ``TrackRepository.create`` (also counted
        when the transaction is rolled back).

    NOTE(review): ``_nearest_station`` applies no maximum snapping distance,
    and nothing here rejects a track whose two ends snap to the same
    station — confirm both are intended.
    """
    inserted = 0
    with SessionLocal() as session:
        stations = StationRepository(session)
        track_store = TrackRepository(session)
        station_refs = _build_station_index(stations.list_active())

        seen_pairs: set[tuple[str, str]] = set()
        for existing in track_store.list_all():
            seen_pairs.add((str(existing.start_station_id), str(existing.end_station_id)))

        for track in tracks:
            origin = _nearest_station(track.coordinates[0], station_refs)
            destination = _nearest_station(track.coordinates[-1], station_refs)
            if origin is None or destination is None:
                continue
            station_pair = (origin.id, destination.id)
            if station_pair in seen_pairs:
                continue
            track_store.create(
                TrackCreate(
                    name=track.name,
                    start_station_id=origin.id,
                    end_station_id=destination.id,
                    coordinates=track.coordinates,
                    # A missing or zero length falls back to the polyline length.
                    length_meters=track.length_meters
                    or _polyline_length(track.coordinates),
                    max_speed_kph=track.max_speed_kph,
                    status=track.status,
                    is_bidirectional=track.is_bidirectional,
                )
            )
            seen_pairs.add(station_pair)
            inserted += 1

        if commit:
            session.commit()
        else:
            session.rollback()
    return inserted
def _nearest_station(
    coordinate: tuple[float, float], stations: Sequence[StationRef]
) -> StationRef | None:
    """Return the station closest to *coordinate* by great-circle distance.

    Among equally distant stations the earliest one wins. Returns None when
    *stations* is empty or no distance compares below infinity (for example
    when the distances are NaN).
    """
    closest: StationRef | None = None
    closest_distance = math.inf
    candidates = (
        (_haversine(coordinate, (station.latitude, station.longitude)), station)
        for station in stations
    )
    for distance, station in candidates:
        # Strict '<' keeps the first of tied stations and never accepts NaN.
        if distance < closest_distance:
            closest, closest_distance = station, distance
    return closest
def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
    """Snapshot stations that have a usable point location as ``StationRef``s.

    Stations whose ``location`` attribute is missing/None, or whose geometry
    ``_to_point`` cannot turn into a point, are left out.
    """
    refs: list[StationRef] = []
    for station in stations:
        geometry = getattr(station, "location", None)
        point = None if geometry is None else _to_point(geometry)
        if point is None:
            continue
        # Point x/y map to longitude/latitude respectively.
        refs.append(
            StationRef(
                id=str(station.id),
                latitude=float(point.y),
                longitude=float(point.x),
            )
        )
    return refs
def _to_point(geometry: WKBElement | WKTElement | Any):
    """Decode *geometry* with geoalchemy2's ``to_shape``, keeping only points.

    Returns the decoded object when its ``geom_type`` is ``"Point"``, and
    None for any other geometry type or when decoding raises.
    """
    try:
        shape = to_shape(geometry)
        is_point = getattr(shape, "geom_type", None) == "Point"
    except Exception:  # pragma: no cover - defensive, should not happen with valid geometry
        # Best effort: the caller simply skips stations it cannot decode.
        return None
    return shape if is_point else None
def _polyline_length(points: Iterable[tuple[float, float]]) -> float:
    """Return the length of the polyline through *points*.

    Sums the ``_haversine`` distance of each consecutive segment. Fewer than
    two points yields 0.0. Any iterable of ``(lat, lon)`` points is accepted,
    not only sequences.
    """
    # fsum avoids accumulating rounding error across many short segments.
    return math.fsum(_haversine(start, end) for start, end in pairwise(points))
def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def _safe_float(value: Any) -> float | None:
if value is None or value == "":
return None
try:
return float(value)
except (TypeError, ValueError):
return None
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())