Files
rail-game/backend/scripts/stations_import.py
zwitschi c35049cd54
Some checks failed
Backend CI / lint-and-test (push) Failing after 1m54s
fix: formatting (black)
2025-10-11 21:58:32 +02:00

170 lines
4.9 KiB
Python

"""CLI utility to import station data from OpenStreetMap."""

# The docstring has to be the first statement of the module to become
# ``__doc__``; a ``from __future__`` import is still legal directly after it.
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import quote_plus
from backend.app.core.osm_config import (
DEFAULT_REGIONS,
STATION_TAG_FILTERS,
compile_overpass_filters,
)
# Overpass API interpreter URL that perform_request() posts the query to.
OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the station export script.

    Registered options:
        --output: path of the JSON file to write.
        --region: a name from ``DEFAULT_REGIONS`` or ``all``.
        --dry-run: print the Overpass query instead of fetching data.
    """
    region_choices = [*(region.name for region in DEFAULT_REGIONS), "all"]

    # Declared as data so that every option is registered the same way;
    # the order here is the order shown in ``--help``.
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "--output",
            {
                "type": Path,
                "default": Path("data/osm_stations.json"),
                "help": "Destination file for the exported station nodes (default: data/osm_stations.json)",
            },
        ),
        (
            "--region",
            {
                "choices": region_choices,
                "default": "all",
                "help": "Region name to export (default: all)",
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Do not fetch data; print the Overpass payload only",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Export OSM station nodes for ingestion"
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli
def build_overpass_query(region_name: str) -> str:
    """Assemble the Overpass QL query that selects station nodes.

    Args:
        region_name: ``"all"`` to query every entry of ``DEFAULT_REGIONS``,
            otherwise the name of a single configured region.

    Returns:
        The query text, one clause per line.

    Raises:
        ValueError: if ``region_name`` matches no configured region.
    """
    if region_name == "all":
        selected = DEFAULT_REGIONS
    else:
        selected = tuple(
            filter(lambda candidate: candidate.name == region_name, DEFAULT_REGIONS)
        )
        if not selected:
            known_names = [region.name for region in DEFAULT_REGIONS]
            raise ValueError(
                f"Unknown region {region_name}. Available regions: {known_names}"
            )

    tag_filters = compile_overpass_filters(STATION_TAG_FILTERS)
    # One node statement per region, wrapped in a union "( ... )".
    node_clauses = [
        f" node{tag_filters}\n ({area.to_overpass_arg()});" for area in selected
    ]
    query_lines = [
        "[out:json][timeout:90];",
        "(",
        *node_clauses,
        ")",
        "; out body; >; out skel qt;",
    ]
    return "\n".join(query_lines)
def perform_request(query: str) -> dict[str, Any]:
    """POST ``query`` to ``OVERPASS_ENDPOINT`` and return the decoded JSON reply.

    The query is sent URL-encoded as the form field ``data``; the request is
    opened with a 120-second timeout.
    """
    from urllib.request import Request, urlopen

    form_body = ("data=" + quote_plus(query)).encode("utf-8")
    http_request = Request(
        OVERPASS_ENDPOINT,
        data=form_body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    with urlopen(http_request, timeout=120) as reply:
        raw_reply = reply.read()
    return json.loads(raw_reply)
def _parse_elevation(raw_value: Any) -> float | None:
    """Return ``raw_value`` as a finite float, or ``None`` when it is unusable."""
    import math

    if raw_value is None:
        return None
    try:
        elevation = float(raw_value)
    except (TypeError, ValueError):
        return None
    # float() happily parses "nan"/"inf"; json.dump would then write NaN or
    # Infinity, which are not valid JSON, so treat them as "no elevation".
    return elevation if math.isfinite(elevation) else None


def normalize_station_elements(
    elements: Iterable[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert raw Overpass nodes into StationCreate-compatible payloads.

    Elements are skipped when they are not of type ``node`` or when they lack
    an ``id``, a ``lat``/``lon`` pair, or a non-empty ``name`` tag.

    Args:
        elements: Overpass result elements (dicts with ``type``, ``id``,
            ``lat``, ``lon`` and an optional ``tags`` mapping).

    Returns:
        One dict per accepted node with the keys ``osm_id``, ``name``,
        ``latitude``, ``longitude``, ``code``, ``elevation_m`` and
        ``is_active``.
    """
    stations: list[dict[str, Any]] = []
    for element in elements:
        if element.get("type") != "node":
            continue

        osm_id = element.get("id")
        latitude = element.get("lat")
        longitude = element.get("lon")
        # Without an id the record would be exported as osm_id "None".
        if osm_id is None or latitude is None or longitude is None:
            continue

        # ``or {}`` also covers an explicit ``"tags": None``.
        tags: dict[str, Any] = element.get("tags") or {}
        name = tags.get("name")
        if not name:
            continue

        # First non-empty reference tag wins.
        raw_code = tags.get("ref") or tags.get("railway:ref") or tags.get("local_ref")
        code = str(raw_code) if raw_code is not None else None

        elevation = _parse_elevation(tags.get("ele") or tags.get("elevation"))

        # A station is inactive when tagged disused=yes/true or when its
        # railway tag itself is "abandoned" or "disused".
        disused = str(tags.get("disused", "no")).lower() in {"yes", "true"}
        abandoned = str(tags.get("railway", "")).lower() in {"abandoned", "disused"}

        stations.append(
            {
                "osm_id": str(osm_id),
                "name": str(name),
                "latitude": float(latitude),
                "longitude": float(longitude),
                "code": code,
                "elevation_m": elevation,
                "is_active": not (disused or abandoned),
            }
        )
    return stations
def main(argv: list[str] | None = None) -> int:
    """Run the export: build the query, fetch it, normalize and save the result.

    Args:
        argv: command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` on completion. With ``--dry-run`` only the query is printed and
        nothing is fetched or written.
    """
    options = build_argument_parser().parse_args(argv)
    overpass_query = build_overpass_query(options.region)

    if options.dry_run:
        print(overpass_query)
        return 0

    # The output directory is created before the request is made.
    destination: Path = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)

    response = perform_request(overpass_query)
    raw_elements = response.get("elements", [])
    stations = normalize_station_elements(raw_elements)
    raw_total = len(raw_elements)
    station_total = len(stations)

    metadata = {
        "endpoint": OVERPASS_ENDPOINT,
        "region": options.region,
        "filters": STATION_TAG_FILTERS,
        "regions": [asdict(region) for region in DEFAULT_REGIONS],
        "raw_count": raw_total,
        "station_count": station_total,
    }
    with destination.open("w", encoding="utf-8") as handle:
        json.dump({"metadata": metadata, "stations": stations}, handle, indent=2)

    print(
        f"Normalized {station_total} stations from {raw_total} elements into {destination}"
    )
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())