from __future__ import annotations """CLI for loading normalized track JSON into the database.""" import argparse import json import math import sys from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable, Mapping, Sequence from geoalchemy2.elements import WKBElement, WKTElement from geoalchemy2.shape import to_shape from backend.app.core.osm_config import TRACK_STATION_SNAP_RADIUS_METERS from backend.app.db.session import SessionLocal from backend.app.models import TrackCreate from backend.app.repositories import StationRepository, TrackRepository @dataclass(slots=True) class ParsedTrack: coordinates: list[tuple[float, float]] name: str | None = None length_meters: float | None = None max_speed_kph: float | None = None status: str = "operational" is_bidirectional: bool = True @dataclass(slots=True) class StationRef: id: str latitude: float longitude: float def build_argument_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Load normalized track data into PostGIS", ) parser.add_argument( "input", type=Path, help="Path to the normalized track JSON file produced by tracks_import.py", ) parser.add_argument( "--commit", dest="commit", action="store_true", default=True, help="Commit the transaction after loading (default).", ) parser.add_argument( "--no-commit", dest="commit", action="store_false", help="Rollback the transaction after loading (useful for dry runs).", ) return parser def main(argv: list[str] | None = None) -> int: parser = build_argument_parser() args = parser.parse_args(argv) if not args.input.exists(): parser.error(f"Input file {args.input} does not exist") with args.input.open("r", encoding="utf-8") as handle: payload = json.load(handle) track_entries = payload.get("tracks") or [] if not isinstance(track_entries, list): parser.error("Invalid payload: 'tracks' must be a list") try: tracks = _parse_track_entries(track_entries) except ValueError as exc: parser.error(str(exc)) created = load_tracks(tracks, commit=args.commit) print(f"Loaded {created} tracks from {args.input}") return 0 def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]: parsed: list[ParsedTrack] = [] for entry in entries: coordinates = entry.get("coordinates") if not isinstance(coordinates, Sequence) or len(coordinates) < 2: raise ValueError( "Invalid track entry: 'coordinates' must contain at least two points") processed_coordinates: list[tuple[float, float]] = [] for pair in coordinates: if not isinstance(pair, Sequence) or len(pair) != 2: raise ValueError( f"Invalid coordinate pair {pair!r} in track entry") lat, lon = pair processed_coordinates.append((float(lat), float(lon))) name = entry.get("name") length = _safe_float(entry.get("lengthMeters")) max_speed = _safe_float(entry.get("maxSpeedKph")) status = entry.get("status", "operational") is_bidirectional = entry.get("isBidirectional", True) parsed.append( ParsedTrack( coordinates=processed_coordinates, name=str(name) if name else None, length_meters=length, max_speed_kph=max_speed, status=str(status) if status else "operational", is_bidirectional=bool(is_bidirectional), ) ) return parsed def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int: created = 0 with SessionLocal() as session: station_repo = StationRepository(session) track_repo = TrackRepository(session) station_index = _build_station_index(station_repo.list_active()) existing_pairs = { (str(track.start_station_id), str(track.end_station_id)) for track in track_repo.list_all() } for track_data in tracks: start_station = _nearest_station( track_data.coordinates[0], station_index, TRACK_STATION_SNAP_RADIUS_METERS, ) end_station = _nearest_station( track_data.coordinates[-1], station_index, TRACK_STATION_SNAP_RADIUS_METERS, ) if not start_station or not end_station: continue if start_station.id == end_station.id: continue pair = (start_station.id, end_station.id) if pair in existing_pairs: continue length = track_data.length_meters or _polyline_length( track_data.coordinates) max_speed = ( int(round(track_data.max_speed_kph)) if track_data.max_speed_kph is not None else None ) create_schema = TrackCreate( name=track_data.name, start_station_id=start_station.id, end_station_id=end_station.id, coordinates=track_data.coordinates, length_meters=length, max_speed_kph=max_speed, status=track_data.status, is_bidirectional=track_data.is_bidirectional, ) track_repo.create(create_schema) existing_pairs.add(pair) created += 1 if commit: session.commit() else: session.rollback() return created def _nearest_station( coordinate: tuple[float, float], stations: Sequence[StationRef], max_distance_meters: float, ) -> StationRef | None: best_station: StationRef | None = None best_distance = math.inf for station in stations: distance = _haversine( coordinate, (station.latitude, station.longitude)) if distance < best_distance: best_station = station best_distance = distance if best_distance <= max_distance_meters: return best_station return None def _build_station_index(stations: Iterable[Any]) -> list[StationRef]: index: list[StationRef] = [] for station in stations: location = getattr(station, "location", None) if location is None: continue point = _to_point(location) if point is None: continue latitude = getattr(point, "y", None) longitude = getattr(point, "x", None) if latitude is None or longitude is None: continue index.append( StationRef( id=str(station.id), latitude=float(latitude), longitude=float(longitude), ) ) return index def _to_point(geometry: WKBElement | WKTElement | Any): try: point = to_shape(geometry) return point if getattr(point, "geom_type", None) == "Point" else None except Exception: # pragma: no cover - defensive, should not happen with valid geometry return None def _polyline_length(points: Sequence[tuple[float, float]]) -> float: if len(points) < 2: return 0.0 total = 0.0 for index in range(len(points) - 1): total += _haversine(points[index], points[index + 1]) return total def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float: lat1, lon1 = a lat2, lon2 = b radius = 6_371_000 phi1 = math.radians(lat1) phi2 = math.radians(lat2) delta_phi = math.radians(lat2 - lat1) delta_lambda = math.radians(lon2 - lon1) sin_dphi = math.sin(delta_phi / 2) sin_dlambda = math.sin(delta_lambda / 2) root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2 distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root)) return distance def _safe_float(value: Any) -> float | None: if value is None or value == "": return None try: return float(value) except (TypeError, ValueError): return None if __name__ == "__main__": sys.exit(main())