- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks from existing tracks between two stations.
- Added methods to check for existing combined tracks and to retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM IDs and track updates.
- Created migration scripts that add the combined tracks table and an OSM ID field to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
294 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
"""CLI for loading normalized track JSON into the database."""
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Iterable, Mapping, Sequence
|
|
|
|
from geoalchemy2.elements import WKBElement, WKTElement
|
|
from geoalchemy2.shape import to_shape
|
|
|
|
from backend.app.core.osm_config import TRACK_STATION_SNAP_RADIUS_METERS
|
|
from backend.app.db.session import SessionLocal
|
|
from backend.app.models import TrackCreate
|
|
from backend.app.repositories import StationRepository, TrackRepository
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ParsedTrack:
|
|
coordinates: list[tuple[float, float]]
|
|
osm_id: str | None = None
|
|
name: str | None = None
|
|
length_meters: float | None = None
|
|
max_speed_kph: float | None = None
|
|
status: str = "operational"
|
|
is_bidirectional: bool = True
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class StationRef:
|
|
id: str
|
|
latitude: float
|
|
longitude: float
|
|
|
|
|
|
def build_argument_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the track-loading CLI.

    The parser takes one positional ``input`` path (converted to
    :class:`~pathlib.Path`) and a ``--commit`` / ``--no-commit`` pair that
    both write to ``args.commit``. Committing is the default; when both
    switches are given, the last one on the command line wins.
    """
    cli = argparse.ArgumentParser(
        description="Load normalized track data into PostGIS",
    )
    cli.add_argument(
        "input",
        type=Path,
        help="Path to the normalized track JSON file produced by tracks_import.py",
    )

    # Both switches share the ``commit`` destination: ``--no-commit`` turns the
    # final commit into a rollback (dry run) and ``--commit`` turns it back on.
    commit_switches = (
        ("--commit", "store_true", "Commit the transaction after loading (default)."),
        (
            "--no-commit",
            "store_false",
            "Rollback the transaction after loading (useful for dry runs).",
        ),
    )
    for flag, action, help_text in commit_switches:
        cli.add_argument(flag, dest="commit", action=action, help=help_text)
    cli.set_defaults(commit=True)

    return cli
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Run the loader CLI and return a process exit status.

    Reads the JSON file named on the command line, validates its ``tracks``
    list, and inserts the tracks via ``load_tracks``. Input problems are
    reported through ``parser.error``, which prints usage to stderr and exits
    with status 2 instead of dumping a traceback. These problems are: a
    missing or unreadable file, malformed JSON, a payload of the wrong shape,
    and invalid track entries.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        ``0`` once loading has finished.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    if not args.input.exists():
        parser.error(f"Input file {args.input} does not exist")

    try:
        with args.input.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except OSError as exc:
        # e.g. the path is a directory or lacks read permission.
        parser.error(f"Could not read input file {args.input}: {exc}")
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        parser.error(f"Input file {args.input} is not valid UTF-8 JSON: {exc}")

    # A top-level list/number/string has no .get(); reject it explicitly.
    if not isinstance(payload, Mapping):
        parser.error("Invalid payload: top-level JSON value must be an object")

    track_entries = payload.get("tracks") or []
    if not isinstance(track_entries, list):
        parser.error("Invalid payload: 'tracks' must be a list")

    try:
        tracks = _parse_track_entries(track_entries)
    except (ValueError, TypeError) as exc:
        # TypeError covers values float() rejects by type (e.g. null coordinates).
        parser.error(str(exc))

    created = load_tracks(tracks, commit=args.commit)
    print(f"Loaded {created} tracks from {args.input}")
    return 0
|
|
|
|
|
|
def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
|
|
parsed: list[ParsedTrack] = []
|
|
for entry in entries:
|
|
coordinates = entry.get("coordinates")
|
|
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
|
|
raise ValueError(
|
|
"Invalid track entry: 'coordinates' must contain at least two points"
|
|
)
|
|
|
|
processed_coordinates: list[tuple[float, float]] = []
|
|
for pair in coordinates:
|
|
if not isinstance(pair, Sequence) or len(pair) != 2:
|
|
raise ValueError(
|
|
f"Invalid coordinate pair {pair!r} in track entry")
|
|
lat, lon = pair
|
|
processed_coordinates.append((float(lat), float(lon)))
|
|
|
|
name = entry.get("name")
|
|
length = _safe_float(entry.get("lengthMeters"))
|
|
max_speed = _safe_float(entry.get("maxSpeedKph"))
|
|
status = entry.get("status", "operational")
|
|
is_bidirectional = entry.get("isBidirectional", True)
|
|
osm_id = entry.get("osmId")
|
|
|
|
parsed.append(
|
|
ParsedTrack(
|
|
coordinates=processed_coordinates,
|
|
osm_id=str(osm_id) if osm_id else None,
|
|
name=str(name) if name else None,
|
|
length_meters=length,
|
|
max_speed_kph=max_speed,
|
|
status=str(status) if status else "operational",
|
|
is_bidirectional=bool(is_bidirectional),
|
|
)
|
|
)
|
|
return parsed
|
|
|
|
|
|
def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
    """Insert parsed tracks between their nearest stations and return how many were created.

    Both endpoints of every track are snapped to the closest station returned
    by ``StationRepository.list_active()`` within
    ``TRACK_STATION_SNAP_RADIUS_METERS``. A track is skipped, with a message on
    stdout, in any of these cases:

    * its OSM id already exists;
    * either endpoint has no station in range;
    * both endpoints snap to the same station;
    * a track for the same ordered (start, end) station pair already exists
      or was created earlier in this run. The reverse direction is not
      treated as a duplicate.

    Args:
        tracks: Parsed track records to load.
        commit: Commit the session when ``True``; roll it back otherwise (dry run).

    Returns:
        Number of tracks passed to ``TrackRepository.create``.
    """
    created = 0
    with SessionLocal() as session:
        station_repo = StationRepository(session)
        track_repo = TrackRepository(session)

        station_index = _build_station_index(station_repo.list_active())
        # Use string ids on both sides so database rows compare equal to StationRef ids.
        existing_pairs = {
            (str(track.start_station_id), str(track.end_station_id))
            for track in track_repo.list_all()
        }

        for track_data in tracks:
            # Skip if track with this OSM ID already exists
            if track_data.osm_id and track_repo.exists_by_osm_id(track_data.osm_id):
                print(f"Skipping track {track_data.osm_id} - already exists by OSM ID")
                continue

            start_station = _nearest_station(
                track_data.coordinates[0],
                station_index,
                TRACK_STATION_SNAP_RADIUS_METERS,
            )
            end_station = _nearest_station(
                track_data.coordinates[-1],
                station_index,
                TRACK_STATION_SNAP_RADIUS_METERS,
            )

            if not start_station or not end_station:
                print(f"Skipping track {track_data.osm_id} - no start/end stations found")
                continue

            if start_station.id == end_station.id:
                print(
                    f"Skipping track {track_data.osm_id} - start and end stations are the same"
                )
                continue

            pair = (start_station.id, end_station.id)
            if pair in existing_pairs:
                print(
                    f"Skipping track {track_data.osm_id} - station pair {pair} already exists"
                )
                continue

            create_schema = TrackCreate(
                osm_id=track_data.osm_id,
                name=track_data.name,
                start_station_id=start_station.id,
                end_station_id=end_station.id,
                coordinates=track_data.coordinates,
                length_meters=_resolve_length_meters(track_data),
                max_speed_kph=_resolve_max_speed_kph(track_data),
                status=track_data.status,
                is_bidirectional=track_data.is_bidirectional,
            )

            track_repo.create(create_schema)
            # Remember the pair so later entries in this same input are deduplicated too.
            existing_pairs.add(pair)
            created += 1

        if commit:
            session.commit()
        else:
            session.rollback()

    return created


def _resolve_length_meters(track: ParsedTrack) -> float:
    """Return the payload length if it is a positive finite number, else the polyline length."""
    reported = track.length_meters
    # None, zero, negative, NaN and inf are all unusable as a stored length.
    if reported is not None and math.isfinite(reported) and reported > 0:
        return reported
    return _polyline_length(track.coordinates)


def _resolve_max_speed_kph(track: ParsedTrack) -> int | None:
    """Return the payload speed rounded to an int, or None when missing or non-finite."""
    speed = track.max_speed_kph
    # round() raises on NaN/inf, so treat those as "no speed limit known".
    if speed is None or not math.isfinite(speed):
        return None
    return int(round(speed))
|
|
|
|
|
|
def _nearest_station(
|
|
coordinate: tuple[float, float],
|
|
stations: Sequence[StationRef],
|
|
max_distance_meters: float,
|
|
) -> StationRef | None:
|
|
best_station: StationRef | None = None
|
|
best_distance = math.inf
|
|
for station in stations:
|
|
distance = _haversine(
|
|
coordinate, (station.latitude, station.longitude))
|
|
if distance < best_distance:
|
|
best_station = station
|
|
best_distance = distance
|
|
if best_distance <= max_distance_meters:
|
|
return best_station
|
|
return None
|
|
|
|
|
|
def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
|
|
index: list[StationRef] = []
|
|
for station in stations:
|
|
location = getattr(station, "location", None)
|
|
if location is None:
|
|
continue
|
|
point = _to_point(location)
|
|
if point is None:
|
|
continue
|
|
latitude = getattr(point, "y", None)
|
|
longitude = getattr(point, "x", None)
|
|
if latitude is None or longitude is None:
|
|
continue
|
|
index.append(
|
|
StationRef(
|
|
id=str(station.id),
|
|
latitude=float(latitude),
|
|
longitude=float(longitude),
|
|
)
|
|
)
|
|
return index
|
|
|
|
|
|
def _to_point(geometry: WKBElement | WKTElement | Any):
    """Decode *geometry* with ``to_shape`` and return it only if it is a Point.

    Returns ``None`` for any geometry whose ``geom_type`` is not ``"Point"``.
    Any exception raised while decoding or inspecting the shape is swallowed
    as a defensive fallback and also yields ``None``.
    """
    try:
        shape = to_shape(geometry)
        is_point = getattr(shape, "geom_type", None) == "Point"
    except Exception:  # pragma: no cover - defensive, should not happen with valid geometry
        return None
    return shape if is_point else None
|
|
|
|
|
|
def _polyline_length(points: Sequence[tuple[float, float]]) -> float:
|
|
if len(points) < 2:
|
|
return 0.0
|
|
|
|
total = 0.0
|
|
for index in range(len(points) - 1):
|
|
total += _haversine(points[index], points[index + 1])
|
|
return total
|
|
|
|
|
|
def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
|
|
lat1, lon1 = a
|
|
lat2, lon2 = b
|
|
radius = 6_371_000
|
|
|
|
phi1 = math.radians(lat1)
|
|
phi2 = math.radians(lat2)
|
|
delta_phi = math.radians(lat2 - lat1)
|
|
delta_lambda = math.radians(lon2 - lon1)
|
|
|
|
sin_dphi = math.sin(delta_phi / 2)
|
|
sin_dlambda = math.sin(delta_lambda / 2)
|
|
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
|
|
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
|
|
return distance
|
|
|
|
|
|
def _safe_float(value: Any) -> float | None:
|
|
if value is None or value == "":
|
|
return None
|
|
try:
|
|
return float(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|