Add unit tests for station service and enhance documentation
Some checks failed
Backend CI / lint-and-test (push) Failing after 37s
Some checks failed
Backend CI / lint-and-test (push) Failing after 37s
- Introduced unit tests for the station service, covering creation, updating, and archiving of stations. - Added detailed building block view documentation outlining the architecture of the Rail Game system. - Created runtime view documentation illustrating key user interactions and system behavior. - Developed concepts documentation detailing domain models, architectural patterns, and security considerations. - Updated architecture documentation to reference new detailed sections for building block and runtime views.
This commit is contained in:
@@ -3,8 +3,10 @@ from fastapi import APIRouter
|
||||
from backend.app.api.auth import router as auth_router
|
||||
from backend.app.api.health import router as health_router
|
||||
from backend.app.api.network import router as network_router
|
||||
from backend.app.api.stations import router as stations_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
router.include_router(auth_router)
|
||||
router.include_router(network_router)
|
||||
router.include_router(stations_router)
|
||||
|
||||
94
backend/app/api/stations.py
Normal file
94
backend/app/api/stations.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.api.deps import get_current_user, get_db
|
||||
from backend.app.models import StationCreate, StationModel, StationUpdate, UserPublic
|
||||
from backend.app.services.stations import (
|
||||
archive_station,
|
||||
create_station,
|
||||
get_station,
|
||||
list_stations,
|
||||
update_station,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/stations", tags=["stations"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[StationModel])
|
||||
def read_stations(
|
||||
include_inactive: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[StationModel]:
|
||||
return list_stations(db, include_inactive=include_inactive)
|
||||
|
||||
|
||||
@router.get("/{station_id}", response_model=StationModel)
|
||||
def read_station(
|
||||
station_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StationModel:
|
||||
try:
|
||||
return get_station(db, station_id)
|
||||
except LookupError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("", response_model=StationModel, status_code=status.HTTP_201_CREATED)
|
||||
def create_station_endpoint(
|
||||
payload: StationCreate,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StationModel:
|
||||
try:
|
||||
return create_station(db, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
|
||||
@router.put("/{station_id}", response_model=StationModel)
|
||||
def update_station_endpoint(
|
||||
station_id: str,
|
||||
payload: StationUpdate,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StationModel:
|
||||
try:
|
||||
return update_station(db, station_id, payload)
|
||||
except LookupError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/{station_id}/archive", response_model=StationModel)
|
||||
def archive_station_endpoint(
|
||||
station_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StationModel:
|
||||
try:
|
||||
return archive_station(db, station_id)
|
||||
except LookupError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
93
backend/app/core/osm_config.py
Normal file
93
backend/app/core/osm_config.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Geographic presets and tagging rules for OpenStreetMap imports."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Mapping, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BoundingBox:
|
||||
"""Geographic bounding box expressed as WGS84 coordinates."""
|
||||
|
||||
name: str
|
||||
north: float
|
||||
south: float
|
||||
east: float
|
||||
west: float
|
||||
description: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.north <= self.south:
|
||||
msg = f"north ({self.north}) must be greater than south ({self.south})"
|
||||
raise ValueError(msg)
|
||||
if self.east <= self.west:
|
||||
msg = f"east ({self.east}) must be greater than west ({self.west})"
|
||||
raise ValueError(msg)
|
||||
|
||||
def contains(self, latitude: float, longitude: float) -> bool:
|
||||
"""Return True when the given coordinate lies inside the bounding box."""
|
||||
|
||||
return (
|
||||
self.south <= latitude <= self.north and self.west <= longitude <= self.east
|
||||
)
|
||||
|
||||
def to_overpass_arg(self) -> str:
|
||||
"""Return the bbox string used for Overpass API queries."""
|
||||
|
||||
return f"{self.south},{self.west},{self.north},{self.east}"
|
||||
|
||||
|
||||
# Primary metropolitan areas we plan to support.
|
||||
DEFAULT_REGIONS: Tuple[BoundingBox, ...] = (
|
||||
BoundingBox(
|
||||
name="berlin_metropolitan",
|
||||
north=52.6755,
|
||||
south=52.3381,
|
||||
east=13.7611,
|
||||
west=13.0884,
|
||||
description="Berlin and surrounding rapid transit network",
|
||||
),
|
||||
BoundingBox(
|
||||
name="hamburg_metropolitan",
|
||||
north=53.7447,
|
||||
south=53.3950,
|
||||
east=10.3253,
|
||||
west=9.7270,
|
||||
description="Hamburg S-Bahn and harbor region",
|
||||
),
|
||||
BoundingBox(
|
||||
name="munich_metropolitan",
|
||||
north=48.2485,
|
||||
south=47.9960,
|
||||
east=11.7229,
|
||||
west=11.3600,
|
||||
description="Munich S-Bahn core and airport corridor",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Tags that identify passenger stations and stops.
|
||||
STATION_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
|
||||
"railway": ("station", "halt", "stop"),
|
||||
"public_transport": ("station", "stop_position", "platform"),
|
||||
"train": ("yes", "regional", "suburban"),
|
||||
}
|
||||
|
||||
|
||||
def compile_overpass_filters(filters: Mapping[str, Iterable[str]]) -> str:
|
||||
"""Build an Overpass boolean expression that matches the provided filters."""
|
||||
|
||||
parts: list[str] = []
|
||||
for key, values in filters.items():
|
||||
options = "|".join(sorted(set(values)))
|
||||
parts.append(f' ["{key}"~"^({options})$"]')
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BoundingBox",
|
||||
"DEFAULT_REGIONS",
|
||||
"STATION_TAG_FILTERS",
|
||||
"compile_overpass_filters",
|
||||
]
|
||||
@@ -10,6 +10,7 @@ from .auth import (
|
||||
from .base import (
|
||||
StationCreate,
|
||||
StationModel,
|
||||
StationUpdate,
|
||||
TrackCreate,
|
||||
TrackModel,
|
||||
TrainCreate,
|
||||
@@ -29,6 +30,7 @@ __all__ = [
|
||||
"UserPublic",
|
||||
"StationCreate",
|
||||
"StationModel",
|
||||
"StationUpdate",
|
||||
"TrackCreate",
|
||||
"TrackModel",
|
||||
"TrainScheduleCreate",
|
||||
|
||||
@@ -42,6 +42,10 @@ class StationModel(IdentifiedModel[str]):
|
||||
name: str
|
||||
latitude: float
|
||||
longitude: float
|
||||
code: str | None = None
|
||||
osm_id: str | None = None
|
||||
elevation_m: float | None = None
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class TrackModel(IdentifiedModel[str]):
|
||||
@@ -68,6 +72,16 @@ class StationCreate(CamelModel):
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class StationUpdate(CamelModel):
|
||||
name: str | None = None
|
||||
latitude: float | None = None
|
||||
longitude: float | None = None
|
||||
osm_id: str | None = None
|
||||
code: str | None = None
|
||||
elevation_m: float | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class TrackCreate(CamelModel):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
|
||||
195
backend/app/services/stations.py
Normal file
195
backend/app/services/stations.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Application services for station CRUD operations."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from geoalchemy2.elements import WKBElement, WKTElement
|
||||
from geoalchemy2.shape import to_shape
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import Station
|
||||
from backend.app.models import StationCreate, StationModel, StationUpdate
|
||||
from backend.app.repositories import StationRepository
|
||||
|
||||
try: # pragma: no cover - optional dependency guard
|
||||
from shapely.geometry import Point # type: ignore
|
||||
except ImportError: # pragma: no cover - allow running without shapely at import time
|
||||
Point = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def list_stations(
|
||||
session: Session, include_inactive: bool = False
|
||||
) -> list[StationModel]:
|
||||
repo = StationRepository(session)
|
||||
if include_inactive:
|
||||
stations = repo.list()
|
||||
else:
|
||||
stations = repo.list_active()
|
||||
return [_to_station_model(station) for station in stations]
|
||||
|
||||
|
||||
def get_station(session: Session, station_id: str) -> StationModel:
|
||||
repo = StationRepository(session)
|
||||
station = _resolve_station(repo, station_id)
|
||||
return _to_station_model(station)
|
||||
|
||||
|
||||
def create_station(session: Session, payload: StationCreate) -> StationModel:
|
||||
name = payload.name.strip()
|
||||
if not name:
|
||||
raise ValueError("Station name must not be empty")
|
||||
_validate_coordinates(payload.latitude, payload.longitude)
|
||||
|
||||
repo = StationRepository(session)
|
||||
station = repo.create(
|
||||
StationCreate(
|
||||
name=name,
|
||||
latitude=payload.latitude,
|
||||
longitude=payload.longitude,
|
||||
osm_id=_normalize_optional(payload.osm_id),
|
||||
code=_normalize_optional(payload.code),
|
||||
elevation_m=payload.elevation_m,
|
||||
is_active=payload.is_active,
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
session.refresh(station)
|
||||
session.commit()
|
||||
return _to_station_model(station)
|
||||
|
||||
|
||||
def update_station(
|
||||
session: Session, station_id: str, payload: StationUpdate
|
||||
) -> StationModel:
|
||||
repo = StationRepository(session)
|
||||
station = _resolve_station(repo, station_id)
|
||||
|
||||
if payload.name is not None:
|
||||
name = payload.name.strip()
|
||||
if not name:
|
||||
raise ValueError("Station name must not be empty")
|
||||
station.name = name
|
||||
|
||||
if payload.latitude is not None or payload.longitude is not None:
|
||||
if payload.latitude is None or payload.longitude is None:
|
||||
raise ValueError("Both latitude and longitude must be provided together")
|
||||
_validate_coordinates(payload.latitude, payload.longitude)
|
||||
station.location = repo._point(
|
||||
payload.latitude, payload.longitude
|
||||
) # type: ignore[assignment]
|
||||
|
||||
if payload.osm_id is not None:
|
||||
station.osm_id = _normalize_optional(payload.osm_id)
|
||||
|
||||
if payload.code is not None:
|
||||
station.code = _normalize_optional(payload.code)
|
||||
|
||||
if payload.elevation_m is not None:
|
||||
station.elevation_m = payload.elevation_m
|
||||
|
||||
if payload.is_active is not None:
|
||||
station.is_active = payload.is_active
|
||||
|
||||
session.flush()
|
||||
session.refresh(station)
|
||||
session.commit()
|
||||
return _to_station_model(station)
|
||||
|
||||
|
||||
def archive_station(session: Session, station_id: str) -> StationModel:
|
||||
repo = StationRepository(session)
|
||||
station = _resolve_station(repo, station_id)
|
||||
if station.is_active:
|
||||
station.is_active = False
|
||||
session.flush()
|
||||
session.refresh(station)
|
||||
session.commit()
|
||||
return _to_station_model(station)
|
||||
|
||||
|
||||
def _resolve_station(repo: StationRepository, station_id: str) -> Station:
|
||||
identifier = _parse_station_id(station_id)
|
||||
station = repo.get(identifier)
|
||||
if station is None:
|
||||
raise LookupError("Station not found")
|
||||
return station
|
||||
|
||||
|
||||
def _parse_station_id(station_id: str) -> UUID:
|
||||
try:
|
||||
return UUID(station_id)
|
||||
except (ValueError, TypeError) as exc: # pragma: no cover - simple validation
|
||||
raise ValueError("Invalid station identifier") from exc
|
||||
|
||||
|
||||
def _validate_coordinates(latitude: float, longitude: float) -> None:
|
||||
if not (-90.0 <= latitude <= 90.0):
|
||||
raise ValueError("Latitude must be between -90 and 90 degrees")
|
||||
if not (-180.0 <= longitude <= 180.0):
|
||||
raise ValueError("Longitude must be between -180 and 180 degrees")
|
||||
|
||||
|
||||
def _normalize_optional(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _to_station_model(station: Station) -> StationModel:
|
||||
latitude, longitude = _extract_coordinates(station.location)
|
||||
created_at = station.created_at or datetime.now(timezone.utc)
|
||||
updated_at = station.updated_at or created_at
|
||||
return StationModel(
|
||||
id=str(station.id),
|
||||
name=station.name,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
code=station.code,
|
||||
osm_id=station.osm_id,
|
||||
elevation_m=station.elevation_m,
|
||||
is_active=station.is_active,
|
||||
created_at=cast(datetime, created_at),
|
||||
updated_at=cast(datetime, updated_at),
|
||||
)
|
||||
|
||||
|
||||
def _extract_coordinates(location: object) -> tuple[float, float]:
|
||||
if location is None:
|
||||
raise ValueError("Station location is unavailable")
|
||||
|
||||
# Attempt to leverage GeoAlchemy's shapely integration first.
|
||||
try:
|
||||
geometry = to_shape(cast(WKBElement | WKTElement, location))
|
||||
if Point is not None and isinstance(geometry, Point):
|
||||
return float(geometry.y), float(geometry.x)
|
||||
except Exception: # pragma: no cover - fallback handles parsing
|
||||
pass
|
||||
|
||||
if isinstance(location, WKTElement):
|
||||
return _parse_wkt_point(location.data)
|
||||
|
||||
text = getattr(location, "desc", None)
|
||||
if isinstance(text, str):
|
||||
return _parse_wkt_point(text)
|
||||
|
||||
raise ValueError("Unable to read station geometry")
|
||||
|
||||
|
||||
def _parse_wkt_point(wkt: str) -> tuple[float, float]:
|
||||
marker = "POINT"
|
||||
if not wkt.upper().startswith(marker):
|
||||
raise ValueError("Unsupported geometry format")
|
||||
start = wkt.find("(")
|
||||
end = wkt.find(")", start)
|
||||
if start == -1 or end == -1:
|
||||
raise ValueError("Malformed POINT geometry")
|
||||
coordinates = wkt[start + 1 : end].strip().split()
|
||||
if len(coordinates) != 2:
|
||||
raise ValueError("POINT geometry must contain two coordinates")
|
||||
longitude, latitude = map(float, coordinates)
|
||||
_validate_coordinates(latitude, longitude)
|
||||
return latitude, longitude
|
||||
171
backend/scripts/stations_import.py
Normal file
171
backend/scripts/stations_import.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""CLI utility to import station data from OpenStreetMap."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from backend.app.core.osm_config import (
|
||||
DEFAULT_REGIONS,
|
||||
STATION_TAG_FILTERS,
|
||||
compile_overpass_filters,
|
||||
)
|
||||
|
||||
OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"
|
||||
|
||||
|
||||
def build_argument_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export OSM station nodes for ingestion"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("data/osm_stations.json"),
|
||||
help="Destination file for the exported station nodes (default: data/osm_stations.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--region",
|
||||
choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
|
||||
default="all",
|
||||
help="Region name to export (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Do not fetch data; print the Overpass payload only",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def build_overpass_query(region_name: str) -> str:
|
||||
if region_name == "all":
|
||||
regions = DEFAULT_REGIONS
|
||||
else:
|
||||
regions = tuple(
|
||||
region for region in DEFAULT_REGIONS if region.name == region_name
|
||||
)
|
||||
if not regions:
|
||||
msg = f"Unknown region {region_name}. Available regions: {[region.name for region in DEFAULT_REGIONS]}"
|
||||
raise ValueError(msg)
|
||||
|
||||
filters = compile_overpass_filters(STATION_TAG_FILTERS)
|
||||
|
||||
parts = ["[out:json][timeout:90];", "("]
|
||||
for region in regions:
|
||||
parts.append(f" node{filters}\n ({region.to_overpass_arg()});")
|
||||
parts.append(")")
|
||||
parts.append("; out body; >; out skel qt;")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def perform_request(query: str) -> dict[str, Any]:
|
||||
import urllib.request
|
||||
|
||||
payload = f"data={quote_plus(query)}".encode("utf-8")
|
||||
request = urllib.request.Request(
|
||||
OVERPASS_ENDPOINT,
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
with urllib.request.urlopen(request, timeout=120) as response:
|
||||
payload = response.read()
|
||||
return json.loads(payload)
|
||||
|
||||
|
||||
def normalize_station_elements(
|
||||
elements: Iterable[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert raw Overpass nodes into StationCreate-compatible payloads."""
|
||||
|
||||
stations: list[dict[str, Any]] = []
|
||||
for element in elements:
|
||||
if element.get("type") != "node":
|
||||
continue
|
||||
|
||||
latitude = element.get("lat")
|
||||
longitude = element.get("lon")
|
||||
if latitude is None or longitude is None:
|
||||
continue
|
||||
|
||||
tags: dict[str, Any] = element.get("tags", {})
|
||||
name = tags.get("name")
|
||||
if not name:
|
||||
continue
|
||||
|
||||
raw_code = tags.get("ref") or tags.get(
|
||||
"railway:ref") or tags.get("local_ref")
|
||||
code = str(raw_code) if raw_code is not None else None
|
||||
|
||||
elevation_tag = tags.get("ele") or tags.get("elevation")
|
||||
try:
|
||||
elevation = float(
|
||||
elevation_tag) if elevation_tag is not None else None
|
||||
except (TypeError, ValueError):
|
||||
elevation = None
|
||||
|
||||
disused = str(tags.get("disused", "no")).lower() in {"yes", "true"}
|
||||
railway_status = str(tags.get("railway", "")).lower()
|
||||
abandoned = railway_status in {"abandoned", "disused"}
|
||||
is_active = not (disused or abandoned)
|
||||
|
||||
stations.append(
|
||||
{
|
||||
"osm_id": str(element.get("id")),
|
||||
"name": str(name),
|
||||
"latitude": float(latitude),
|
||||
"longitude": float(longitude),
|
||||
"code": code,
|
||||
"elevation_m": elevation,
|
||||
"is_active": is_active,
|
||||
}
|
||||
)
|
||||
|
||||
return stations
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_argument_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
query = build_overpass_query(args.region)
|
||||
|
||||
if args.dry_run:
|
||||
print(query)
|
||||
return 0
|
||||
|
||||
output_path: Path = args.output
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = perform_request(query)
|
||||
raw_elements = data.get("elements", [])
|
||||
stations = normalize_station_elements(raw_elements)
|
||||
|
||||
payload = {
|
||||
"metadata": {
|
||||
"endpoint": OVERPASS_ENDPOINT,
|
||||
"region": args.region,
|
||||
"filters": STATION_TAG_FILTERS,
|
||||
"regions": [asdict(region) for region in DEFAULT_REGIONS],
|
||||
"raw_count": len(raw_elements),
|
||||
"station_count": len(stations),
|
||||
},
|
||||
"stations": stations,
|
||||
}
|
||||
|
||||
with output_path.open("w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2)
|
||||
|
||||
print(
|
||||
f"Normalized {len(stations)} stations from {len(raw_elements)} elements into {output_path}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
86
backend/scripts/stations_load.py
Normal file
86
backend/scripts/stations_load.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""CLI for loading normalized station JSON into the database."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
from backend.app.db.session import SessionLocal
|
||||
from backend.app.models import StationCreate
|
||||
from backend.app.repositories import StationRepository
|
||||
|
||||
|
||||
def build_argument_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Load normalized station data into PostGIS"
|
||||
)
|
||||
parser.add_argument(
|
||||
"input",
|
||||
type=Path,
|
||||
help="Path to the normalized station JSON file produced by stations_import.py",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit/--no-commit",
|
||||
dest="commit",
|
||||
default=True,
|
||||
help="Commit the transaction (default: commit). Use --no-commit for dry runs.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_argument_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if not args.input.exists():
|
||||
parser.error(f"Input file {args.input} does not exist")
|
||||
|
||||
with args.input.open("r", encoding="utf-8") as handle:
|
||||
payload = json.load(handle)
|
||||
|
||||
stations_data = payload.get("stations") or []
|
||||
if not isinstance(stations_data, list):
|
||||
parser.error("Invalid payload: 'stations' must be a list")
|
||||
|
||||
try:
|
||||
station_creates = _parse_station_entries(stations_data)
|
||||
except ValueError as exc:
|
||||
parser.error(str(exc))
|
||||
|
||||
created = load_stations(station_creates, commit=args.commit)
|
||||
|
||||
print(f"Loaded {created} stations from {args.input}")
|
||||
return 0
|
||||
|
||||
|
||||
def _parse_station_entries(entries: Iterable[Mapping[str, Any]]) -> list[StationCreate]:
|
||||
parsed: list[StationCreate] = []
|
||||
for entry in entries:
|
||||
try:
|
||||
parsed.append(StationCreate(**entry))
|
||||
except Exception as exc: # pragma: no cover - validated in tests
|
||||
raise ValueError(f"Invalid station entry {entry}: {exc}") from exc
|
||||
return parsed
|
||||
|
||||
|
||||
def load_stations(stations: Iterable[StationCreate], commit: bool = True) -> int:
|
||||
created = 0
|
||||
with SessionLocal() as session:
|
||||
repo = StationRepository(session)
|
||||
|
||||
for create_schema in stations:
|
||||
repo.create(create_schema)
|
||||
created += 1
|
||||
|
||||
if commit:
|
||||
session.commit()
|
||||
else:
|
||||
session.rollback()
|
||||
return created
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
28
backend/tests/test_osm_config.py
Normal file
28
backend/tests/test_osm_config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from backend.app.core.osm_config import (
|
||||
DEFAULT_REGIONS,
|
||||
STATION_TAG_FILTERS,
|
||||
BoundingBox,
|
||||
compile_overpass_filters,
|
||||
)
|
||||
|
||||
|
||||
def test_default_regions_are_valid() -> None:
|
||||
assert DEFAULT_REGIONS, "Expected at least one region definition"
|
||||
for bbox in DEFAULT_REGIONS:
|
||||
assert isinstance(bbox, BoundingBox)
|
||||
assert bbox.north > bbox.south
|
||||
assert bbox.east > bbox.west
|
||||
# Berlin coordinates should fall inside Berlin bounding box for sanity
|
||||
if bbox.name == "berlin_metropolitan":
|
||||
assert bbox.contains(52.5200, 13.4050)
|
||||
|
||||
|
||||
def test_station_tag_filters_compile_to_overpass_snippet() -> None:
|
||||
compiled = compile_overpass_filters(STATION_TAG_FILTERS)
|
||||
# Ensure each key is present with its values
|
||||
for key, values in STATION_TAG_FILTERS.items():
|
||||
assert key in compiled
|
||||
for value in values:
|
||||
assert value in compiled
|
||||
# The snippet should be multi-line to preserve readability
|
||||
assert "\n" in compiled
|
||||
137
backend/tests/test_stations_api.py
Normal file
137
backend/tests/test_stations_api.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.app.api import stations as stations_api
|
||||
from backend.app.main import app
|
||||
from backend.app.models import StationCreate, StationModel, StationUpdate
|
||||
|
||||
AUTH_CREDENTIALS = {"username": "demo", "password": "railgame123"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def _authenticate() -> str:
|
||||
response = client.post("/api/auth/login", json=AUTH_CREDENTIALS)
|
||||
assert response.status_code == 200
|
||||
return response.json()["accessToken"]
|
||||
|
||||
|
||||
def _station_payload(**overrides: Any) -> dict[str, Any]:
|
||||
payload = {
|
||||
"name": "Central",
|
||||
"latitude": 52.52,
|
||||
"longitude": 13.405,
|
||||
"osmId": "123",
|
||||
"code": "BER",
|
||||
"elevationM": 34.5,
|
||||
"isActive": True,
|
||||
}
|
||||
payload.update(overrides)
|
||||
return payload
|
||||
|
||||
|
||||
def _station_model(**overrides: Any) -> StationModel:
|
||||
now = datetime.now(timezone.utc)
|
||||
base = StationModel(
|
||||
id=str(uuid4()),
|
||||
name="Central",
|
||||
latitude=52.52,
|
||||
longitude=13.405,
|
||||
code="BER",
|
||||
osm_id="123",
|
||||
elevation_m=34.5,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
return base.model_copy(update=overrides)
|
||||
|
||||
|
||||
def test_list_stations_requires_authentication() -> None:
|
||||
response = client.get("/api/stations")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_list_stations_returns_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
|
||||
def fake_list_stations(db, include_inactive: bool) -> list[StationModel]:
|
||||
assert include_inactive is True
|
||||
return [_station_model()]
|
||||
|
||||
monkeypatch.setattr(stations_api, "list_stations", fake_list_stations)
|
||||
|
||||
response = client.get(
|
||||
"/api/stations",
|
||||
params={"include_inactive": "true"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert len(payload) == 1
|
||||
assert payload[0]["name"] == "Central"
|
||||
|
||||
|
||||
def test_create_station_delegates_to_service(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
seen: dict[str, StationCreate] = {}
|
||||
|
||||
def fake_create_station(db, payload: StationCreate) -> StationModel:
|
||||
seen["payload"] = payload
|
||||
return _station_model()
|
||||
|
||||
monkeypatch.setattr(stations_api, "create_station", fake_create_station)
|
||||
|
||||
response = client.post(
|
||||
"/api/stations",
|
||||
json=_station_payload(),
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == "Central"
|
||||
assert seen["payload"].name == "Central"
|
||||
|
||||
|
||||
def test_update_station_not_found_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
|
||||
def fake_update_station(
|
||||
db, station_id: str, payload: StationUpdate
|
||||
) -> StationModel:
|
||||
raise LookupError("Station not found")
|
||||
|
||||
monkeypatch.setattr(stations_api, "update_station", fake_update_station)
|
||||
|
||||
response = client.put(
|
||||
"/api/stations/123e4567-e89b-12d3-a456-426614174000",
|
||||
json={"name": "New Name"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Station not found"
|
||||
|
||||
|
||||
def test_archive_station_returns_updated_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
|
||||
def fake_archive_station(db, station_id: str) -> StationModel:
|
||||
return _station_model(is_active=False)
|
||||
|
||||
monkeypatch.setattr(stations_api, "archive_station", fake_archive_station)
|
||||
|
||||
response = client.post(
|
||||
"/api/stations/123e4567-e89b-12d3-a456-426614174000/archive",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["isActive"] is False
|
||||
67
backend/tests/test_stations_import.py
Normal file
67
backend/tests/test_stations_import.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from backend.scripts.stations_import import (
|
||||
build_overpass_query,
|
||||
normalize_station_elements,
|
||||
)
|
||||
|
||||
|
||||
def test_build_overpass_query_single_region() -> None:
|
||||
query = build_overpass_query("berlin_metropolitan")
|
||||
|
||||
# The query should reference the Berlin bounding box coordinates.
|
||||
assert "52.3381" in query # south
|
||||
assert "52.6755" in query # north
|
||||
assert "13.0884" in query # west
|
||||
assert "13.7611" in query # east
|
||||
assert "node" in query
|
||||
assert "out body" in query
|
||||
|
||||
|
||||
def test_build_overpass_query_all_regions_includes_union() -> None:
|
||||
query = build_overpass_query("all")
|
||||
|
||||
# Ensure multiple regions are present by checking for repeated bbox parentheses.
|
||||
assert query.count("node") >= 3
|
||||
assert query.strip().endswith("out skel qt;")
|
||||
|
||||
|
||||
def test_normalize_station_elements_filters_and_transforms() -> None:
|
||||
raw_elements = [
|
||||
{
|
||||
"type": "node",
|
||||
"id": 123,
|
||||
"lat": 52.5,
|
||||
"lon": 13.4,
|
||||
"tags": {
|
||||
"name": "Sample Station",
|
||||
"ref": "XYZ",
|
||||
"ele": "35.5",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "node",
|
||||
"id": 999,
|
||||
# Missing coordinates should be ignored
|
||||
"tags": {"name": "Broken"},
|
||||
},
|
||||
{
|
||||
"type": "node",
|
||||
"id": 456,
|
||||
"lat": 50.0,
|
||||
"lon": 8.0,
|
||||
"tags": {
|
||||
"name": "Disused Station",
|
||||
"disused": "yes",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
stations = normalize_station_elements(raw_elements)
|
||||
|
||||
assert len(stations) == 2
|
||||
primary = stations[0]
|
||||
assert primary["osm_id"] == "123"
|
||||
assert primary["name"] == "Sample Station"
|
||||
assert primary["code"] == "XYZ"
|
||||
assert primary["elevation_m"] == 35.5
|
||||
disused_station = stations[1]
|
||||
assert disused_station["is_active"] is False
|
||||
142
backend/tests/test_stations_load.py
Normal file
142
backend/tests/test_stations_load.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.scripts import stations_load
|
||||
|
||||
|
||||
def test_parse_station_entries_returns_models() -> None:
|
||||
entries = [
|
||||
{
|
||||
"name": "Central",
|
||||
"latitude": 52.52,
|
||||
"longitude": 13.405,
|
||||
"osm_id": "123",
|
||||
"code": "BER",
|
||||
"elevation_m": 34.5,
|
||||
"is_active": True,
|
||||
}
|
||||
]
|
||||
|
||||
parsed = stations_load._parse_station_entries(entries)
|
||||
|
||||
assert parsed[0].name == "Central"
|
||||
assert parsed[0].latitude == 52.52
|
||||
assert parsed[0].osm_id == "123"
|
||||
|
||||
|
||||
def test_parse_station_entries_invalid_raises_value_error() -> None:
|
||||
entries = [
|
||||
{
|
||||
"latitude": 52.52,
|
||||
"longitude": 13.405,
|
||||
"is_active": True,
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
stations_load._parse_station_entries(entries)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySession:
|
||||
committed: bool = False
|
||||
rolled_back: bool = False
|
||||
closed: bool = False
|
||||
|
||||
def __enter__(self) -> "DummySession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback) -> None:
|
||||
self.closed = True
|
||||
|
||||
def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyRepository:
|
||||
session: DummySession
|
||||
created: list = field(default_factory=list)
|
||||
|
||||
def create(self, data) -> None: # pragma: no cover - simple delegation
|
||||
self.created.append(data)
|
||||
|
||||
|
||||
class DummySessionFactory:
|
||||
def __call__(self) -> DummySession:
|
||||
return DummySession()
|
||||
|
||||
|
||||
def test_load_stations_commits_when_requested(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo_instances: list[DummyRepository] = []
|
||||
|
||||
def fake_session_local() -> DummySession:
|
||||
return DummySession()
|
||||
|
||||
def fake_repo(session: DummySession) -> DummyRepository:
|
||||
repo = DummyRepository(session)
|
||||
repo_instances.append(repo)
|
||||
return repo
|
||||
|
||||
monkeypatch.setattr(stations_load, "SessionLocal", fake_session_local)
|
||||
monkeypatch.setattr(stations_load, "StationRepository", fake_repo)
|
||||
|
||||
stations = stations_load._parse_station_entries(
|
||||
[
|
||||
{
|
||||
"name": "Central",
|
||||
"latitude": 52.52,
|
||||
"longitude": 13.405,
|
||||
"osm_id": "123",
|
||||
"is_active": True,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
created = stations_load.load_stations(stations, commit=True)
|
||||
|
||||
assert created == 1
|
||||
assert repo_instances[0].session.committed is True
|
||||
assert repo_instances[0].session.rolled_back is False
|
||||
assert len(repo_instances[0].created) == 1
|
||||
|
||||
|
||||
def test_load_stations_rolls_back_when_no_commit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
repo_instances: list[DummyRepository] = []
|
||||
|
||||
def fake_session_local() -> DummySession:
|
||||
return DummySession()
|
||||
|
||||
def fake_repo(session: DummySession) -> DummyRepository:
|
||||
repo = DummyRepository(session)
|
||||
repo_instances.append(repo)
|
||||
return repo
|
||||
|
||||
monkeypatch.setattr(stations_load, "SessionLocal", fake_session_local)
|
||||
monkeypatch.setattr(stations_load, "StationRepository", fake_repo)
|
||||
|
||||
stations = stations_load._parse_station_entries(
|
||||
[
|
||||
{
|
||||
"name": "Central",
|
||||
"latitude": 52.52,
|
||||
"longitude": 13.405,
|
||||
"osm_id": "123",
|
||||
"is_active": True,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
created = stations_load.load_stations(stations, commit=False)
|
||||
|
||||
assert created == 1
|
||||
assert repo_instances[0].session.committed is False
|
||||
assert repo_instances[0].session.rolled_back is True
|
||||
175
backend/tests/test_stations_service.py
Normal file
175
backend/tests/test_stations_service.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import StationCreate, StationUpdate
|
||||
from backend.app.services import stations as stations_service
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySession:
|
||||
flushed: bool = False
|
||||
committed: bool = False
|
||||
refreshed: List[object] = field(default_factory=list)
|
||||
|
||||
def flush(self) -> None:
|
||||
self.flushed = True
|
||||
|
||||
def refresh(self, instance: object) -> None: # pragma: no cover - simple setter
|
||||
self.refreshed.append(instance)
|
||||
|
||||
def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyStation:
|
||||
id: UUID
|
||||
name: str
|
||||
location: WKTElement
|
||||
osm_id: str | None
|
||||
code: str | None
|
||||
elevation_m: float | None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class DummyStationRepository:
|
||||
_store: Dict[UUID, DummyStation] = {}
|
||||
|
||||
def __init__(self, session: DummySession) -> None: # pragma: no cover - simple init
|
||||
self.session = session
|
||||
|
||||
@staticmethod
|
||||
def _point(latitude: float, longitude: float) -> WKTElement:
|
||||
return WKTElement(f"POINT({longitude} {latitude})", srid=4326)
|
||||
|
||||
def list(self) -> list[DummyStation]:
|
||||
return list(self._store.values())
|
||||
|
||||
def list_active(self) -> list[DummyStation]:
|
||||
return [station for station in self._store.values() if station.is_active]
|
||||
|
||||
def get(self, identifier: UUID) -> DummyStation | None:
|
||||
return self._store.get(identifier)
|
||||
|
||||
def create(self, payload: StationCreate) -> DummyStation:
|
||||
station = DummyStation(
|
||||
id=uuid4(),
|
||||
name=payload.name,
|
||||
location=self._point(payload.latitude, payload.longitude),
|
||||
osm_id=payload.osm_id,
|
||||
code=payload.code,
|
||||
elevation_m=payload.elevation_m,
|
||||
is_active=payload.is_active,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self._store[station.id] = station
|
||||
return station
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_store(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
DummyStationRepository._store = {}
|
||||
monkeypatch.setattr(stations_service, "StationRepository", DummyStationRepository)
|
||||
|
||||
|
||||
def test_create_station_persists_and_returns_model(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session = DummySession()
|
||||
payload = StationCreate(
|
||||
name="Central",
|
||||
latitude=52.52,
|
||||
longitude=13.405,
|
||||
osm_id="123",
|
||||
code="BER",
|
||||
elevation_m=34.5,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
result = stations_service.create_station(cast(Session, session), payload)
|
||||
|
||||
assert session.flushed is True
|
||||
assert session.committed is True
|
||||
assert result.name == "Central"
|
||||
assert result.latitude == pytest.approx(52.52)
|
||||
assert result.longitude == pytest.approx(13.405)
|
||||
assert result.osm_id == "123"
|
||||
|
||||
|
||||
def test_update_station_updates_geometry_and_metadata() -> None:
|
||||
session = DummySession()
|
||||
station_id = uuid4()
|
||||
DummyStationRepository._store[station_id] = DummyStation(
|
||||
id=station_id,
|
||||
name="Old Name",
|
||||
location=DummyStationRepository._point(50.0, 8.0),
|
||||
osm_id=None,
|
||||
code=None,
|
||||
elevation_m=None,
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
payload = StationUpdate(name="New Name", latitude=51.0, longitude=9.0)
|
||||
result = stations_service.update_station(
|
||||
cast(Session, session), str(station_id), payload
|
||||
)
|
||||
|
||||
assert result.name == "New Name"
|
||||
assert result.latitude == pytest.approx(51.0)
|
||||
assert result.longitude == pytest.approx(9.0)
|
||||
assert DummyStationRepository._store[station_id].name == "New Name"
|
||||
|
||||
|
||||
def test_update_station_requires_both_coordinates() -> None:
|
||||
session = DummySession()
|
||||
station_id = uuid4()
|
||||
DummyStationRepository._store[station_id] = DummyStation(
|
||||
id=station_id,
|
||||
name="Station",
|
||||
location=DummyStationRepository._point(50.0, 8.0),
|
||||
osm_id=None,
|
||||
code=None,
|
||||
elevation_m=None,
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
stations_service.update_station(
|
||||
cast(Session, session), str(station_id), StationUpdate(latitude=51.0)
|
||||
)
|
||||
|
||||
|
||||
def test_archive_station_marks_inactive() -> None:
|
||||
session = DummySession()
|
||||
station_id = uuid4()
|
||||
DummyStationRepository._store[station_id] = DummyStation(
|
||||
id=station_id,
|
||||
name="Station",
|
||||
location=DummyStationRepository._point(50.0, 8.0),
|
||||
osm_id=None,
|
||||
code=None,
|
||||
elevation_m=None,
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
result = stations_service.archive_station(cast(Session, session), str(station_id))
|
||||
|
||||
assert result.is_active is False
|
||||
assert DummyStationRepository._store[station_id].is_active is False
|
||||
Reference in New Issue
Block a user