"""CLI utility to export rail track geometries from OpenStreetMap."""

from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable, Mapping
from urllib.parse import quote_plus

from backend.app.core.osm_config import (
    DEFAULT_REGIONS,
    TRACK_ALLOWED_RAILWAY_TYPES,
    TRACK_EXCLUDED_SERVICE_TAGS,
    TRACK_EXCLUDED_USAGE_TAGS,
    TRACK_MIN_LENGTH_METERS,
    TRACK_TAG_FILTERS,
    compile_overpass_filters,
)

OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"

# Conversion factors for the non-metric units that OSM ``maxspeed`` values
# may carry as a suffix (a bare number means km/h).
_MPH_TO_KPH = 1.609344
_KNOTS_TO_KPH = 1.852

# ``oneway`` values that mark a way as traversable in a single direction.
# "-1" / "reverse" mean one-way against the drawing direction of the way,
# which is still not bidirectional.
_ONEWAY_VALUES = frozenset({"yes", "true", "1", "-1", "reverse"})


def build_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the export tool.

    Returns:
        A parser exposing ``--output`` (destination path), ``--region``
        (one of the names in ``DEFAULT_REGIONS`` or ``all``) and
        ``--dry-run`` (print the query instead of fetching).
    """
    parser = argparse.ArgumentParser(
        description="Export OSM rail track ways for ingestion",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/osm_tracks.json"),
        help=(
            "Destination file for the exported track geometries "
            "(default: data/osm_tracks.json)"
        ),
    )
    parser.add_argument(
        "--region",
        choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
        default="all",
        help="Region name to export (default: all)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Do not fetch data; print the Overpass payload only",
    )
    return parser


def build_overpass_query(region_name: str) -> str:
    """Assemble the Overpass QL query for one region or for all of them.

    Args:
        region_name: Name of an entry in ``DEFAULT_REGIONS``, or ``"all"``
            to query every configured region.

    Returns:
        The query text: one ``way`` statement per selected region inside a
        union, followed by the output statements.

    Raises:
        ValueError: If ``region_name`` matches no configured region.
    """
    if region_name == "all":
        regions = DEFAULT_REGIONS
    else:
        regions = tuple(
            region for region in DEFAULT_REGIONS if region.name == region_name
        )
        if not regions:
            available = ", ".join(region.name for region in DEFAULT_REGIONS)
            msg = f"Unknown region {region_name}. Available regions: [{available}]"
            raise ValueError(msg)

    filters = compile_overpass_filters(TRACK_TAG_FILTERS)

    parts = ["[out:json][timeout:120];", "("]
    for region in regions:
        parts.append(f" way{filters}\n ({region.to_overpass_arg()});")
    parts.append(")")
    parts.append("; out body geom; >; out skel qt;")
    return "\n".join(parts)


def perform_request(query: str) -> dict[str, Any]:
    """POST ``query`` to the Overpass endpoint and decode the JSON reply.

    Args:
        query: Overpass QL text, sent form-encoded as the ``data`` field.

    Returns:
        The decoded JSON response.

    Raises:
        urllib.error.URLError: If the HTTP request fails.
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    import urllib.request

    request_body = f"data={quote_plus(query)}".encode("utf-8")
    request = urllib.request.Request(
        OVERPASS_ENDPOINT,
        data=request_body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    with urllib.request.urlopen(request, timeout=180) as response:
        response_body = response.read()
    return json.loads(response_body)


def _extract_coordinates(raw_geometry: Iterable[Mapping[str, Any]]) -> list[list[float]]:
    """Return ``[lat, lon]`` pairs, or ``[]`` if any node lacks a coordinate."""
    coordinates: list[list[float]] = []
    for node in raw_geometry:
        lat = node.get("lat")
        lon = node.get("lon")
        if lat is None or lon is None:
            # A single incomplete node invalidates the whole geometry.
            return []
        coordinates.append([float(lat), float(lon)])
    return coordinates


def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert Overpass way elements into TrackCreate-compatible payloads.

    Elements that are not ways, that have fewer than two usable coordinates,
    or that are rejected by ``_should_include_track`` are skipped.

    Args:
        elements: Items of the ``elements`` array of an Overpass response.

    Returns:
        One dict per accepted way with the keys ``osmId``, ``name``,
        ``lengthMeters``, ``maxSpeedKph``, ``status``, ``isBidirectional``
        and ``coordinates`` (a list of ``[lat, lon]`` pairs).
    """
    tracks: list[dict[str, Any]] = []
    for element in elements:
        if element.get("type") != "way":
            continue

        coordinates = _extract_coordinates(element.get("geometry") or [])
        if len(coordinates) < 2:
            continue

        # ``or {}`` also covers an explicit null, not just a missing key.
        tags: dict[str, Any] = element.get("tags") or {}
        length_meters = _polyline_length(coordinates)
        if not _should_include_track(tags, length_meters):
            continue

        name = tags.get("name")
        tracks.append(
            {
                "osmId": str(element.get("id")),
                "name": str(name) if name else None,
                "lengthMeters": length_meters,
                "maxSpeedKph": _parse_maxspeed(tags.get("maxspeed")),
                "status": _derive_status(tags.get("railway")),
                "isBidirectional": not _is_oneway(tags.get("oneway")),
                "coordinates": coordinates,
            }
        )
    return tracks


def _parse_maxspeed(value: Any) -> float | None:
    """Parse a ``maxspeed`` tag value into kilometres per hour.

    The first run of digits and dots in the text is taken as the number.
    A unit suffix of ``mph`` or ``knots`` directly after it is converted to
    km/h; anything else is treated as km/h.

    Args:
        value: Raw tag value; numbers are returned as-is (as ``float``).

    Returns:
        The speed in km/h, or ``None`` when no number can be extracted.
    """
    if value is None:
        return None

    # Overpass may return values such as "80" or "80 km/h" or "signals".
    if isinstance(value, (int, float)):
        return float(value)

    text = str(value).strip().lower()
    digits: list[str] = []
    number_end = len(text)
    for position, char in enumerate(text):
        if char.isdigit() or char == ".":
            digits.append(char)
        elif digits:
            number_end = position
            break

    if not digits:
        return None
    try:
        speed = float("".join(digits))
    except ValueError:
        # e.g. "1.2.3" or digit-like characters float() does not accept.
        return None

    unit = text[number_end:].strip()
    if unit.startswith("mph"):
        return speed * _MPH_TO_KPH
    if unit.startswith("knot"):
        return speed * _KNOTS_TO_KPH
    return speed


def _derive_status(value: Any) -> str:
    """Map a ``railway`` tag value to a track status string.

    ``abandoned`` and ``disused`` are returned unchanged, ``construction``
    and ``proposed`` both become ``"construction"``, and every other value
    (including a missing tag) becomes ``"operational"``.
    """
    tag = str(value or "").lower()
    if tag in {"abandoned", "disused"}:
        return tag
    if tag in {"construction", "proposed"}:
        return "construction"
    return "operational"


def _should_include_track(tags: Mapping[str, Any], length_meters: float) -> bool:
    """Decide whether a way passes the configured track filters.

    A way is kept only if its ``railway`` value is in
    ``TRACK_ALLOWED_RAILWAY_TYPES``, it is at least
    ``TRACK_MIN_LENGTH_METERS`` long, and neither its ``service`` nor its
    ``usage`` value is in the corresponding exclusion set.
    """
    railway = str(tags.get("railway", "")).lower()
    if railway not in TRACK_ALLOWED_RAILWAY_TYPES:
        return False

    if length_meters < TRACK_MIN_LENGTH_METERS:
        return False

    service = str(tags.get("service", "")).lower()
    if service and service in TRACK_EXCLUDED_SERVICE_TAGS:
        return False

    usage = str(tags.get("usage", "")).lower()
    if usage and usage in TRACK_EXCLUDED_USAGE_TAGS:
        return False

    return True


def _is_oneway(value: Any) -> bool:
    """Return ``True`` when a ``oneway`` tag value marks a one-way track.

    Both forward (``yes``/``true``/``1``) and reverse (``-1``/``reverse``)
    one-way values count; a missing tag or any other value does not.
    """
    if value is None:
        return False
    return str(value).strip().lower() in _ONEWAY_VALUES


def _polyline_length(points: list[list[float]]) -> float:
    """Return the summed segment length in meters of a ``[lat, lon]`` polyline."""
    if len(points) < 2:
        return 0.0

    total = 0.0
    for start, end in zip(points, points[1:]):
        total += _haversine(start, end)
    return total


def _haversine(a: list[float], b: list[float]) -> float:
    """Return distance in meters between two [lat, lon] coordinates."""
    lat1, lon1 = a
    lat2, lon2 = b
    radius = 6_371_000

    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)

    sin_dphi = math.sin(delta_phi / 2)
    sin_dlambda = math.sin(delta_lambda / 2)
    root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
    # Rounding can push the term marginally outside [0, 1] for (near-)
    # antipodal points, which would make sqrt(1 - root) raise.
    root = min(1.0, max(0.0, root))

    return 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))


def main(argv: list[str] | None = None) -> int:
    """Run the export: build the query, fetch, normalize and write JSON.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        Process exit code (``0`` on success). With ``--dry-run`` the query
        is printed and nothing is fetched or written.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    query = build_overpass_query(args.region)
    if args.dry_run:
        print(query)
        return 0

    output_path: Path = args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)

    data = perform_request(query)
    raw_elements = data.get("elements", [])
    tracks = normalize_track_elements(raw_elements)

    payload = {
        "metadata": {
            "endpoint": OVERPASS_ENDPOINT,
            "region": args.region,
            "filters": TRACK_TAG_FILTERS,
            "regions": [asdict(region) for region in DEFAULT_REGIONS],
            "raw_count": len(raw_elements),
            "track_count": len(tracks),
        },
        "tracks": tracks,
    }

    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)

    print(
        f"Normalized {len(tracks)} tracks from {len(raw_elements)} elements into {output_path}"
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())