# Source: rail-game/backend/scripts/stations_load.py (Python, 94 lines, 2.6 KiB as captured)

"""CLI for loading normalized station JSON into the database."""

from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Iterable, Mapping
from backend.app.db.session import SessionLocal
from backend.app.models import StationCreate
from backend.app.repositories import StationRepository
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the station loader.

    The parser accepts one positional ``input`` path and the paired
    ``--commit`` / ``--no-commit`` flags. Both flags write to the same
    ``commit`` attribute, which is ``True`` unless ``--no-commit`` is given.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Load normalized station data into PostGIS")

    cli.add_argument(
        "input",
        help="Path to the normalized station JSON file produced by stations_import.py",
        type=Path,
    )

    # The two flags are an on/off pair sharing one namespace attribute.
    shared_dest = "commit"
    cli.add_argument(
        "--commit",
        action="store_true",
        default=True,
        dest=shared_dest,
        help="Commit the transaction after loading (default).",
    )
    cli.add_argument(
        "--no-commit",
        action="store_false",
        dest=shared_dest,
        help="Rollback the transaction after loading (useful for dry runs).",
    )

    return cli
def main(argv: list[str] | None = None) -> int:
    """Load a normalized station JSON file into the database.

    Args:
        argv: Command-line arguments to parse. ``None`` lets argparse fall
            back to ``sys.argv[1:]``.

    Returns:
        ``0`` once the parsed stations have been passed to ``load_stations``.

    Raises:
        SystemExit: Raised through ``parser.error`` when the input file is
            missing or unreadable, is not valid UTF-8 JSON, does not hold a
            JSON object at the top level, has a non-list ``stations`` value,
            or contains an entry that cannot be parsed.
    """
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    if not args.input.exists():
        parser.error(f"Input file {args.input} does not exist")

    # Report unreadable or malformed input as a usage error instead of
    # letting a raw traceback escape from the CLI.
    try:
        with args.input.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except OSError as exc:
        parser.error(f"Could not read input file {args.input}: {exc}")
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        parser.error(f"Input file {args.input} is not valid JSON: {exc}")

    # A top-level array or scalar has no ``.get``; reject it explicitly.
    if not isinstance(payload, dict):
        parser.error("Invalid payload: top-level JSON value must be an object")

    # A missing or null 'stations' key is treated as "no stations".
    stations_data = payload.get("stations") or []
    if not isinstance(stations_data, list):
        parser.error("Invalid payload: 'stations' must be a list")

    try:
        station_creates = _parse_station_entries(stations_data)
    except ValueError as exc:
        parser.error(str(exc))

    created = load_stations(station_creates, commit=args.commit)
    print(f"Loaded {created} stations from {args.input}")
    return 0
def _parse_station_entries(entries: Iterable[Mapping[str, Any]]) -> list[StationCreate]:
    """Convert raw station mappings into ``StationCreate`` schemas.

    Args:
        entries: Raw station records; each one is expanded as keyword
            arguments to ``StationCreate``.

    Returns:
        The constructed schemas, in input order.

    Raises:
        ValueError: If building a schema from any entry fails. The message
            names the offending entry and the original exception is chained.
    """
    schemas: list[StationCreate] = []
    for raw in entries:
        try:
            schema = StationCreate(**raw)
        except Exception as err:  # pragma: no cover - validated in tests
            raise ValueError(f"Invalid station entry {raw}: {err}") from err
        schemas.append(schema)
    return schemas
def load_stations(stations: Iterable[StationCreate], commit: bool = True) -> int:
    """Pass each station schema to the repository inside one session.

    Args:
        stations: Schemas to hand to ``StationRepository.create``.
        commit: When true the session is committed after the loop;
            otherwise it is rolled back.

    Returns:
        The number of schemas passed to the repository, whether or not the
        session was committed.
    """
    total = 0
    with SessionLocal() as db:
        repository = StationRepository(db)
        # enumerate from 1 so ``total`` ends up as the count of items seen.
        for total, schema in enumerate(stations, start=1):
            repository.create(schema)

        finish = db.commit if commit else db.rollback
        finish()
    return total
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    exit_status = main()
    sys.exit(exit_status)