Add unit tests for station service and enhance documentation
Some checks failed: Backend CI / lint-and-test (push) failing after 37s

- Introduced unit tests for the station service, covering creation, updating, and archiving of stations.
- Added detailed building block view documentation outlining the architecture of the Rail Game system.
- Created runtime view documentation illustrating key user interactions and system behavior.
- Developed concepts documentation detailing domain models, architectural patterns, and security considerations.
- Updated architecture documentation to reference new detailed sections for building block and runtime views.
2025-10-11 18:52:25 +02:00
parent 2b9877a9d3
commit 615b63ba76
18 changed files with 1662 additions and 443 deletions

View File

@@ -3,8 +3,10 @@ from fastapi import APIRouter
from backend.app.api.auth import router as auth_router
from backend.app.api.health import router as health_router
from backend.app.api.network import router as network_router
from backend.app.api.stations import router as stations_router
router = APIRouter()
router.include_router(health_router, tags=["health"])
router.include_router(auth_router)
router.include_router(network_router)
router.include_router(stations_router)

View File

@@ -0,0 +1,94 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from backend.app.api.deps import get_current_user, get_db
from backend.app.models import StationCreate, StationModel, StationUpdate, UserPublic
from backend.app.services.stations import (
archive_station,
create_station,
get_station,
list_stations,
update_station,
)
router = APIRouter(prefix="/stations", tags=["stations"])
@router.get("", response_model=list[StationModel])
def read_stations(
include_inactive: bool = False,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> list[StationModel]:
return list_stations(db, include_inactive=include_inactive)
@router.get("/{station_id}", response_model=StationModel)
def read_station(
station_id: str,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StationModel:
try:
return get_station(db, station_id)
except LookupError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc
@router.post("", response_model=StationModel, status_code=status.HTTP_201_CREATED)
def create_station_endpoint(
payload: StationCreate,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StationModel:
try:
return create_station(db, payload)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc
@router.put("/{station_id}", response_model=StationModel)
def update_station_endpoint(
station_id: str,
payload: StationUpdate,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StationModel:
try:
return update_station(db, station_id, payload)
except LookupError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc
@router.post("/{station_id}/archive", response_model=StationModel)
def archive_station_endpoint(
station_id: str,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StationModel:
try:
return archive_station(db, station_id)
except LookupError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc
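
The router above maps service-level LookupError and ValueError exceptions to 404 and 400 responses. A minimal sketch of driving these endpoints in-process with FastAPI's TestClient, assuming a reachable database behind get_db; the /api prefix and the demo credentials are the ones used by the API tests added later in this commit:

from fastapi.testclient import TestClient

from backend.app.main import app

client = TestClient(app)

# Every /stations route requires a bearer token, so log in first.
token = client.post(
    "/api/auth/login", json={"username": "demo", "password": "railgame123"}
).json()["accessToken"]
headers = {"Authorization": f"Bearer {token}"}

# Create a station, then archive it; the camelCase keys come from CamelModel aliases.
created = client.post(
    "/api/stations",
    json={"name": "Central", "latitude": 52.52, "longitude": 13.405},
    headers=headers,
).json()
archived = client.post(f"/api/stations/{created['id']}/archive", headers=headers)
assert archived.json()["isActive"] is False

# An unknown UUID surfaces the service's LookupError as a 404.
missing = client.get(
    "/api/stations/123e4567-e89b-12d3-a456-426614174000", headers=headers
)
assert missing.status_code == 404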

View File

@@ -0,0 +1,93 @@
"""Geographic presets and tagging rules for OpenStreetMap imports."""

from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Mapping, Tuple
@dataclass(frozen=True)
class BoundingBox:
"""Geographic bounding box expressed as WGS84 coordinates."""
name: str
north: float
south: float
east: float
west: float
description: str | None = None
def __post_init__(self) -> None:
if self.north <= self.south:
msg = f"north ({self.north}) must be greater than south ({self.south})"
raise ValueError(msg)
if self.east <= self.west:
msg = f"east ({self.east}) must be greater than west ({self.west})"
raise ValueError(msg)
def contains(self, latitude: float, longitude: float) -> bool:
"""Return True when the given coordinate lies inside the bounding box."""
return (
self.south <= latitude <= self.north and self.west <= longitude <= self.east
)
def to_overpass_arg(self) -> str:
"""Return the bbox string used for Overpass API queries."""
return f"{self.south},{self.west},{self.north},{self.east}"
# Primary metropolitan areas we plan to support.
DEFAULT_REGIONS: Tuple[BoundingBox, ...] = (
BoundingBox(
name="berlin_metropolitan",
north=52.6755,
south=52.3381,
east=13.7611,
west=13.0884,
description="Berlin and surrounding rapid transit network",
),
BoundingBox(
name="hamburg_metropolitan",
north=53.7447,
south=53.3950,
east=10.3253,
west=9.7270,
description="Hamburg S-Bahn and harbor region",
),
BoundingBox(
name="munich_metropolitan",
north=48.2485,
south=47.9960,
east=11.7229,
west=11.3600,
description="Munich S-Bahn core and airport corridor",
),
)
# Tags that identify passenger stations and stops.
STATION_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
"railway": ("station", "halt", "stop"),
"public_transport": ("station", "stop_position", "platform"),
"train": ("yes", "regional", "suburban"),
}
def compile_overpass_filters(filters: Mapping[str, Iterable[str]]) -> str:
"""Build an Overpass boolean expression that matches the provided filters."""
parts: list[str] = []
for key, values in filters.items():
options = "|".join(sorted(set(values)))
parts.append(f' ["{key}"~"^({options})$"]')
return "\n".join(parts)
__all__ = [
"BoundingBox",
"DEFAULT_REGIONS",
"STATION_TAG_FILTERS",
"compile_overpass_filters",
]
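
To make the helpers above concrete, here is a small sketch run against this module as added; the leading space on each filter line is intentional, since build_overpass_query concatenates the block directly after a node token:

from backend.app.core.osm_config import (
    DEFAULT_REGIONS,
    STATION_TAG_FILTERS,
    compile_overpass_filters,
)

# Values are de-duplicated and sorted before being joined into a regex alternation.
print(compile_overpass_filters({"railway": ("station", "halt", "stop")}))
# prints (with a leading space):  ["railway"~"^(halt|station|stop)$"]

# A bounding box renders as the south,west,north,east string Overpass expects.
berlin = DEFAULT_REGIONS[0]
print(berlin.to_overpass_arg())        # 52.3381,13.0884,52.6755,13.7611
print(berlin.contains(52.52, 13.405))  # True (central Berlin)

# The full filter block spliced into queries by the import script:
print(compile_overpass_filters(STATION_TAG_FILTERS))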

View File

@@ -10,6 +10,7 @@ from .auth import (
from .base import (
StationCreate,
StationModel,
StationUpdate,
TrackCreate,
TrackModel,
TrainCreate,
@@ -29,6 +30,7 @@ __all__ = [
"UserPublic",
"StationCreate",
"StationModel",
"StationUpdate",
"TrackCreate",
"TrackModel",
"TrainScheduleCreate",

View File

@@ -42,6 +42,10 @@ class StationModel(IdentifiedModel[str]):
name: str
latitude: float
longitude: float
code: str | None = None
osm_id: str | None = None
elevation_m: float | None = None
is_active: bool = True
class TrackModel(IdentifiedModel[str]):
@@ -68,6 +72,16 @@ class StationCreate(CamelModel):
is_active: bool = True
class StationUpdate(CamelModel):
name: str | None = None
latitude: float | None = None
longitude: float | None = None
osm_id: str | None = None
code: str | None = None
elevation_m: float | None = None
is_active: bool | None = None
class TrackCreate(CamelModel):
start_station_id: str
end_station_id: str
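
StationUpdate is a patch-style model: every field is optional, and the station service only touches attributes that are not None. A short sketch, assuming CamelModel exposes camelCase aliases the way the API tests in this commit rely on:

from backend.app.models import StationUpdate

# Omitted fields stay None, so the service leaves them untouched.
patch = StationUpdate(name="Berlin Hbf", elevation_m=36.0)
assert patch.latitude is None and patch.is_active is None

# The same payload in camelCase, as it would arrive over the API.
from_api = StationUpdate.model_validate({"name": "Berlin Hbf", "elevationM": 36.0})
assert from_api.elevation_m == 36.0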

View File

@@ -0,0 +1,195 @@
"""Application services for station CRUD operations."""

from __future__ import annotations
from datetime import datetime, timezone
from typing import cast
from uuid import UUID
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
from sqlalchemy.orm import Session
from backend.app.db.models import Station
from backend.app.models import StationCreate, StationModel, StationUpdate
from backend.app.repositories import StationRepository
try: # pragma: no cover - optional dependency guard
from shapely.geometry import Point # type: ignore
except ImportError: # pragma: no cover - allow running without shapely at import time
Point = None # type: ignore[assignment]
def list_stations(
session: Session, include_inactive: bool = False
) -> list[StationModel]:
repo = StationRepository(session)
if include_inactive:
stations = repo.list()
else:
stations = repo.list_active()
return [_to_station_model(station) for station in stations]
def get_station(session: Session, station_id: str) -> StationModel:
repo = StationRepository(session)
station = _resolve_station(repo, station_id)
return _to_station_model(station)
def create_station(session: Session, payload: StationCreate) -> StationModel:
name = payload.name.strip()
if not name:
raise ValueError("Station name must not be empty")
_validate_coordinates(payload.latitude, payload.longitude)
repo = StationRepository(session)
station = repo.create(
StationCreate(
name=name,
latitude=payload.latitude,
longitude=payload.longitude,
osm_id=_normalize_optional(payload.osm_id),
code=_normalize_optional(payload.code),
elevation_m=payload.elevation_m,
is_active=payload.is_active,
)
)
session.flush()
session.refresh(station)
session.commit()
return _to_station_model(station)
def update_station(
session: Session, station_id: str, payload: StationUpdate
) -> StationModel:
repo = StationRepository(session)
station = _resolve_station(repo, station_id)
if payload.name is not None:
name = payload.name.strip()
if not name:
raise ValueError("Station name must not be empty")
station.name = name
if payload.latitude is not None or payload.longitude is not None:
if payload.latitude is None or payload.longitude is None:
raise ValueError("Both latitude and longitude must be provided together")
_validate_coordinates(payload.latitude, payload.longitude)
station.location = repo._point(
payload.latitude, payload.longitude
) # type: ignore[assignment]
if payload.osm_id is not None:
station.osm_id = _normalize_optional(payload.osm_id)
if payload.code is not None:
station.code = _normalize_optional(payload.code)
if payload.elevation_m is not None:
station.elevation_m = payload.elevation_m
if payload.is_active is not None:
station.is_active = payload.is_active
session.flush()
session.refresh(station)
session.commit()
return _to_station_model(station)
def archive_station(session: Session, station_id: str) -> StationModel:
repo = StationRepository(session)
station = _resolve_station(repo, station_id)
if station.is_active:
station.is_active = False
session.flush()
session.refresh(station)
session.commit()
return _to_station_model(station)
def _resolve_station(repo: StationRepository, station_id: str) -> Station:
identifier = _parse_station_id(station_id)
station = repo.get(identifier)
if station is None:
raise LookupError("Station not found")
return station
def _parse_station_id(station_id: str) -> UUID:
try:
return UUID(station_id)
except (ValueError, TypeError) as exc: # pragma: no cover - simple validation
raise ValueError("Invalid station identifier") from exc
def _validate_coordinates(latitude: float, longitude: float) -> None:
if not (-90.0 <= latitude <= 90.0):
raise ValueError("Latitude must be between -90 and 90 degrees")
if not (-180.0 <= longitude <= 180.0):
raise ValueError("Longitude must be between -180 and 180 degrees")
def _normalize_optional(value: str | None) -> str | None:
if value is None:
return None
normalized = value.strip()
return normalized or None
def _to_station_model(station: Station) -> StationModel:
latitude, longitude = _extract_coordinates(station.location)
created_at = station.created_at or datetime.now(timezone.utc)
updated_at = station.updated_at or created_at
return StationModel(
id=str(station.id),
name=station.name,
latitude=latitude,
longitude=longitude,
code=station.code,
osm_id=station.osm_id,
elevation_m=station.elevation_m,
is_active=station.is_active,
created_at=cast(datetime, created_at),
updated_at=cast(datetime, updated_at),
)
def _extract_coordinates(location: object) -> tuple[float, float]:
if location is None:
raise ValueError("Station location is unavailable")
# Attempt to leverage GeoAlchemy's shapely integration first.
try:
geometry = to_shape(cast(WKBElement | WKTElement, location))
if Point is not None and isinstance(geometry, Point):
return float(geometry.y), float(geometry.x)
except Exception: # pragma: no cover - fallback handles parsing
pass
if isinstance(location, WKTElement):
return _parse_wkt_point(location.data)
text = getattr(location, "desc", None)
if isinstance(text, str):
return _parse_wkt_point(text)
raise ValueError("Unable to read station geometry")
def _parse_wkt_point(wkt: str) -> tuple[float, float]:
marker = "POINT"
if not wkt.upper().startswith(marker):
raise ValueError("Unsupported geometry format")
start = wkt.find("(")
end = wkt.find(")", start)
if start == -1 or end == -1:
raise ValueError("Malformed POINT geometry")
coordinates = wkt[start + 1 : end].strip().split()
if len(coordinates) != 2:
raise ValueError("POINT geometry must contain two coordinates")
longitude, latitude = map(float, coordinates)
_validate_coordinates(latitude, longitude)
return latitude, longitude
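
The geometry helpers are the trickiest part of this module because WKT stores longitude before latitude while the API models expose latitude first. A tiny illustration using the module-private helpers (imported here purely for demonstration):

from backend.app.services.stations import _parse_wkt_point, _validate_coordinates

# WKT order is (lon lat); the helper returns (latitude, longitude).
assert _parse_wkt_point("POINT(13.405 52.52)") == (52.52, 13.405)

# Out-of-range coordinates are rejected before any geometry is built.
try:
    _validate_coordinates(95.0, 13.405)
except ValueError as exc:
    print(exc)  # Latitude must be between -90 and 90 degrees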

View File

@@ -0,0 +1,171 @@
"""CLI utility to import station data from OpenStreetMap."""

from __future__ import annotations
import argparse
import json
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import quote_plus
from backend.app.core.osm_config import (
DEFAULT_REGIONS,
STATION_TAG_FILTERS,
compile_overpass_filters,
)
OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Export OSM station nodes for ingestion"
)
parser.add_argument(
"--output",
type=Path,
default=Path("data/osm_stations.json"),
help="Destination file for the exported station nodes (default: data/osm_stations.json)",
)
parser.add_argument(
"--region",
choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
default="all",
help="Region name to export (default: all)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Do not fetch data; print the Overpass payload only",
)
return parser
def build_overpass_query(region_name: str) -> str:
if region_name == "all":
regions = DEFAULT_REGIONS
else:
regions = tuple(
region for region in DEFAULT_REGIONS if region.name == region_name
)
if not regions:
msg = f"Unknown region {region_name}. Available regions: {[region.name for region in DEFAULT_REGIONS]}"
raise ValueError(msg)
filters = compile_overpass_filters(STATION_TAG_FILTERS)
parts = ["[out:json][timeout:90];", "("]
for region in regions:
parts.append(f" node{filters}\n ({region.to_overpass_arg()});")
parts.append(")")
parts.append("; out body; >; out skel qt;")
return "\n".join(parts)
def perform_request(query: str) -> dict[str, Any]:
import urllib.request
payload = f"data={quote_plus(query)}".encode("utf-8")
request = urllib.request.Request(
OVERPASS_ENDPOINT,
data=payload,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
    with urllib.request.urlopen(request, timeout=120) as response:
        body = response.read()
    return json.loads(body)
def normalize_station_elements(
elements: Iterable[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert raw Overpass nodes into StationCreate-compatible payloads."""
stations: list[dict[str, Any]] = []
for element in elements:
if element.get("type") != "node":
continue
latitude = element.get("lat")
longitude = element.get("lon")
if latitude is None or longitude is None:
continue
tags: dict[str, Any] = element.get("tags", {})
name = tags.get("name")
if not name:
continue
        raw_code = (
            tags.get("ref") or tags.get("railway:ref") or tags.get("local_ref")
        )
        code = str(raw_code) if raw_code is not None else None
        elevation_tag = tags.get("ele") or tags.get("elevation")
        try:
            elevation = None if elevation_tag is None else float(elevation_tag)
        except (TypeError, ValueError):
            elevation = None
disused = str(tags.get("disused", "no")).lower() in {"yes", "true"}
railway_status = str(tags.get("railway", "")).lower()
abandoned = railway_status in {"abandoned", "disused"}
is_active = not (disused or abandoned)
stations.append(
{
"osm_id": str(element.get("id")),
"name": str(name),
"latitude": float(latitude),
"longitude": float(longitude),
"code": code,
"elevation_m": elevation,
"is_active": is_active,
}
)
return stations
def main(argv: list[str] | None = None) -> int:
parser = build_argument_parser()
args = parser.parse_args(argv)
query = build_overpass_query(args.region)
if args.dry_run:
print(query)
return 0
output_path: Path = args.output
output_path.parent.mkdir(parents=True, exist_ok=True)
data = perform_request(query)
raw_elements = data.get("elements", [])
stations = normalize_station_elements(raw_elements)
payload = {
"metadata": {
"endpoint": OVERPASS_ENDPOINT,
"region": args.region,
"filters": STATION_TAG_FILTERS,
"regions": [asdict(region) for region in DEFAULT_REGIONS],
"raw_count": len(raw_elements),
"station_count": len(stations),
},
"stations": stations,
}
with output_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2)
print(
f"Normalized {len(stations)} stations from {len(raw_elements)} elements into {output_path}"
)
return 0
if __name__ == "__main__":
sys.exit(main())
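
A safe way to exercise the script without touching the Overpass API is the --dry-run path, which only prints the generated query. A sketch of both modes; the output layout below mirrors the payload dict assembled in main:

from backend.scripts import stations_import

# Offline: print the Overpass QL for one region and exit.
exit_code = stations_import.main(["--region", "berlin_metropolitan", "--dry-run"])
assert exit_code == 0

# Online runs write a JSON document shaped roughly like:
# {
#   "metadata": {"endpoint": ..., "region": ..., "raw_count": ..., "station_count": ...},
#   "stations": [{"osm_id": ..., "name": ..., "latitude": ..., "longitude": ...,
#                 "code": ..., "elevation_m": ..., "is_active": ...}]
# }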

View File

@@ -0,0 +1,86 @@
"""CLI for loading normalized station JSON into the database."""

from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Iterable, Mapping
from backend.app.db.session import SessionLocal
from backend.app.models import StationCreate
from backend.app.repositories import StationRepository
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Load normalized station data into PostGIS"
)
parser.add_argument(
"input",
type=Path,
help="Path to the normalized station JSON file produced by stations_import.py",
)
    parser.add_argument(
        "--commit",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Commit the transaction (default: commit). Use --no-commit for dry runs.",
    )
return parser
def main(argv: list[str] | None = None) -> int:
parser = build_argument_parser()
args = parser.parse_args(argv)
if not args.input.exists():
parser.error(f"Input file {args.input} does not exist")
with args.input.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
stations_data = payload.get("stations") or []
if not isinstance(stations_data, list):
parser.error("Invalid payload: 'stations' must be a list")
try:
station_creates = _parse_station_entries(stations_data)
except ValueError as exc:
parser.error(str(exc))
created = load_stations(station_creates, commit=args.commit)
print(f"Loaded {created} stations from {args.input}")
return 0
def _parse_station_entries(entries: Iterable[Mapping[str, Any]]) -> list[StationCreate]:
parsed: list[StationCreate] = []
for entry in entries:
try:
parsed.append(StationCreate(**entry))
except Exception as exc: # pragma: no cover - validated in tests
raise ValueError(f"Invalid station entry {entry}: {exc}") from exc
return parsed
def load_stations(stations: Iterable[StationCreate], commit: bool = True) -> int:
created = 0
with SessionLocal() as session:
repo = StationRepository(session)
for create_schema in stations:
repo.create(create_schema)
created += 1
if commit:
session.commit()
else:
session.rollback()
return created
if __name__ == "__main__":
sys.exit(main())
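
For completeness, a sketch of chaining this loader after the exporter, assuming data/osm_stations.json exists and a PostGIS database is reachable through SessionLocal; --no-commit rolls the transaction back, so every entry is validated and the repository insert path is exercised without persisting anything:

from backend.scripts import stations_load

exit_code = stations_load.main(["data/osm_stations.json", "--no-commit"])
assert exit_code == 0  # prints "Loaded N stations from data/osm_stations.json"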

View File

@@ -0,0 +1,28 @@
from backend.app.core.osm_config import (
DEFAULT_REGIONS,
STATION_TAG_FILTERS,
BoundingBox,
compile_overpass_filters,
)
def test_default_regions_are_valid() -> None:
assert DEFAULT_REGIONS, "Expected at least one region definition"
for bbox in DEFAULT_REGIONS:
assert isinstance(bbox, BoundingBox)
assert bbox.north > bbox.south
assert bbox.east > bbox.west
# Berlin coordinates should fall inside Berlin bounding box for sanity
if bbox.name == "berlin_metropolitan":
assert bbox.contains(52.5200, 13.4050)
def test_station_tag_filters_compile_to_overpass_snippet() -> None:
compiled = compile_overpass_filters(STATION_TAG_FILTERS)
# Ensure each key is present with its values
for key, values in STATION_TAG_FILTERS.items():
assert key in compiled
for value in values:
assert value in compiled
# The snippet should be multi-line to preserve readability
assert "\n" in compiled

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from backend.app.api import stations as stations_api
from backend.app.main import app
from backend.app.models import StationCreate, StationModel, StationUpdate
AUTH_CREDENTIALS = {"username": "demo", "password": "railgame123"}
client = TestClient(app)
def _authenticate() -> str:
response = client.post("/api/auth/login", json=AUTH_CREDENTIALS)
assert response.status_code == 200
return response.json()["accessToken"]
def _station_payload(**overrides: Any) -> dict[str, Any]:
payload = {
"name": "Central",
"latitude": 52.52,
"longitude": 13.405,
"osmId": "123",
"code": "BER",
"elevationM": 34.5,
"isActive": True,
}
payload.update(overrides)
return payload
def _station_model(**overrides: Any) -> StationModel:
now = datetime.now(timezone.utc)
base = StationModel(
id=str(uuid4()),
name="Central",
latitude=52.52,
longitude=13.405,
code="BER",
osm_id="123",
elevation_m=34.5,
is_active=True,
created_at=now,
updated_at=now,
)
return base.model_copy(update=overrides)
def test_list_stations_requires_authentication() -> None:
response = client.get("/api/stations")
assert response.status_code == 401
def test_list_stations_returns_payload(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
def fake_list_stations(db, include_inactive: bool) -> list[StationModel]:
assert include_inactive is True
return [_station_model()]
monkeypatch.setattr(stations_api, "list_stations", fake_list_stations)
response = client.get(
"/api/stations",
params={"include_inactive": "true"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert len(payload) == 1
assert payload[0]["name"] == "Central"
def test_create_station_delegates_to_service(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
seen: dict[str, StationCreate] = {}
def fake_create_station(db, payload: StationCreate) -> StationModel:
seen["payload"] = payload
return _station_model()
monkeypatch.setattr(stations_api, "create_station", fake_create_station)
response = client.post(
"/api/stations",
json=_station_payload(),
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 201
assert response.json()["name"] == "Central"
assert seen["payload"].name == "Central"
def test_update_station_not_found_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
def fake_update_station(
db, station_id: str, payload: StationUpdate
) -> StationModel:
raise LookupError("Station not found")
monkeypatch.setattr(stations_api, "update_station", fake_update_station)
response = client.put(
"/api/stations/123e4567-e89b-12d3-a456-426614174000",
json={"name": "New Name"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Station not found"
def test_archive_station_returns_updated_model(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
def fake_archive_station(db, station_id: str) -> StationModel:
return _station_model(is_active=False)
monkeypatch.setattr(stations_api, "archive_station", fake_archive_station)
response = client.post(
"/api/stations/123e4567-e89b-12d3-a456-426614174000/archive",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
assert response.json()["isActive"] is False

View File

@@ -0,0 +1,67 @@
from backend.scripts.stations_import import (
build_overpass_query,
normalize_station_elements,
)
def test_build_overpass_query_single_region() -> None:
query = build_overpass_query("berlin_metropolitan")
# The query should reference the Berlin bounding box coordinates.
assert "52.3381" in query # south
assert "52.6755" in query # north
assert "13.0884" in query # west
assert "13.7611" in query # east
assert "node" in query
assert "out body" in query
def test_build_overpass_query_all_regions_includes_union() -> None:
query = build_overpass_query("all")
# Ensure multiple regions are present by checking for repeated bbox parentheses.
assert query.count("node") >= 3
assert query.strip().endswith("out skel qt;")
def test_normalize_station_elements_filters_and_transforms() -> None:
raw_elements = [
{
"type": "node",
"id": 123,
"lat": 52.5,
"lon": 13.4,
"tags": {
"name": "Sample Station",
"ref": "XYZ",
"ele": "35.5",
},
},
{
"type": "node",
"id": 999,
# Missing coordinates should be ignored
"tags": {"name": "Broken"},
},
{
"type": "node",
"id": 456,
"lat": 50.0,
"lon": 8.0,
"tags": {
"name": "Disused Station",
"disused": "yes",
},
},
]
stations = normalize_station_elements(raw_elements)
assert len(stations) == 2
primary = stations[0]
assert primary["osm_id"] == "123"
assert primary["name"] == "Sample Station"
assert primary["code"] == "XYZ"
assert primary["elevation_m"] == 35.5
disused_station = stations[1]
assert disused_station["is_active"] is False

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from dataclasses import dataclass, field
import pytest
from backend.scripts import stations_load
def test_parse_station_entries_returns_models() -> None:
entries = [
{
"name": "Central",
"latitude": 52.52,
"longitude": 13.405,
"osm_id": "123",
"code": "BER",
"elevation_m": 34.5,
"is_active": True,
}
]
parsed = stations_load._parse_station_entries(entries)
assert parsed[0].name == "Central"
assert parsed[0].latitude == 52.52
assert parsed[0].osm_id == "123"
def test_parse_station_entries_invalid_raises_value_error() -> None:
entries = [
{
"latitude": 52.52,
"longitude": 13.405,
"is_active": True,
}
]
with pytest.raises(ValueError):
stations_load._parse_station_entries(entries)
@dataclass
class DummySession:
committed: bool = False
rolled_back: bool = False
closed: bool = False
def __enter__(self) -> "DummySession":
return self
def __exit__(self, exc_type, exc, traceback) -> None:
self.closed = True
def commit(self) -> None:
self.committed = True
def rollback(self) -> None:
self.rolled_back = True
@dataclass
class DummyRepository:
session: DummySession
created: list = field(default_factory=list)
def create(self, data) -> None: # pragma: no cover - simple delegation
self.created.append(data)
class DummySessionFactory:
def __call__(self) -> DummySession:
return DummySession()
def test_load_stations_commits_when_requested(monkeypatch: pytest.MonkeyPatch) -> None:
repo_instances: list[DummyRepository] = []
def fake_session_local() -> DummySession:
return DummySession()
def fake_repo(session: DummySession) -> DummyRepository:
repo = DummyRepository(session)
repo_instances.append(repo)
return repo
monkeypatch.setattr(stations_load, "SessionLocal", fake_session_local)
monkeypatch.setattr(stations_load, "StationRepository", fake_repo)
stations = stations_load._parse_station_entries(
[
{
"name": "Central",
"latitude": 52.52,
"longitude": 13.405,
"osm_id": "123",
"is_active": True,
}
]
)
created = stations_load.load_stations(stations, commit=True)
assert created == 1
assert repo_instances[0].session.committed is True
assert repo_instances[0].session.rolled_back is False
assert len(repo_instances[0].created) == 1
def test_load_stations_rolls_back_when_no_commit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
repo_instances: list[DummyRepository] = []
def fake_session_local() -> DummySession:
return DummySession()
def fake_repo(session: DummySession) -> DummyRepository:
repo = DummyRepository(session)
repo_instances.append(repo)
return repo
monkeypatch.setattr(stations_load, "SessionLocal", fake_session_local)
monkeypatch.setattr(stations_load, "StationRepository", fake_repo)
stations = stations_load._parse_station_entries(
[
{
"name": "Central",
"latitude": 52.52,
"longitude": 13.405,
"osm_id": "123",
"is_active": True,
}
]
)
created = stations_load.load_stations(stations, commit=False)
assert created == 1
assert repo_instances[0].session.committed is False
assert repo_instances[0].session.rolled_back is True

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, cast
from uuid import UUID, uuid4
import pytest
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session
from backend.app.models import StationCreate, StationUpdate
from backend.app.services import stations as stations_service
@dataclass
class DummySession:
flushed: bool = False
committed: bool = False
refreshed: List[object] = field(default_factory=list)
def flush(self) -> None:
self.flushed = True
def refresh(self, instance: object) -> None: # pragma: no cover - simple setter
self.refreshed.append(instance)
def commit(self) -> None:
self.committed = True
@dataclass
class DummyStation:
id: UUID
name: str
location: WKTElement
osm_id: str | None
code: str | None
elevation_m: float | None
is_active: bool
created_at: datetime
updated_at: datetime
class DummyStationRepository:
_store: Dict[UUID, DummyStation] = {}
def __init__(self, session: DummySession) -> None: # pragma: no cover - simple init
self.session = session
@staticmethod
def _point(latitude: float, longitude: float) -> WKTElement:
return WKTElement(f"POINT({longitude} {latitude})", srid=4326)
def list(self) -> list[DummyStation]:
return list(self._store.values())
def list_active(self) -> list[DummyStation]:
return [station for station in self._store.values() if station.is_active]
def get(self, identifier: UUID) -> DummyStation | None:
return self._store.get(identifier)
def create(self, payload: StationCreate) -> DummyStation:
station = DummyStation(
id=uuid4(),
name=payload.name,
location=self._point(payload.latitude, payload.longitude),
osm_id=payload.osm_id,
code=payload.code,
elevation_m=payload.elevation_m,
is_active=payload.is_active,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
self._store[station.id] = station
return station
@pytest.fixture(autouse=True)
def reset_store(monkeypatch: pytest.MonkeyPatch) -> None:
DummyStationRepository._store = {}
monkeypatch.setattr(stations_service, "StationRepository", DummyStationRepository)
def test_create_station_persists_and_returns_model(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session = DummySession()
payload = StationCreate(
name="Central",
latitude=52.52,
longitude=13.405,
osm_id="123",
code="BER",
elevation_m=34.5,
is_active=True,
)
result = stations_service.create_station(cast(Session, session), payload)
assert session.flushed is True
assert session.committed is True
assert result.name == "Central"
assert result.latitude == pytest.approx(52.52)
assert result.longitude == pytest.approx(13.405)
assert result.osm_id == "123"
def test_update_station_updates_geometry_and_metadata() -> None:
session = DummySession()
station_id = uuid4()
DummyStationRepository._store[station_id] = DummyStation(
id=station_id,
name="Old Name",
location=DummyStationRepository._point(50.0, 8.0),
osm_id=None,
code=None,
elevation_m=None,
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
payload = StationUpdate(name="New Name", latitude=51.0, longitude=9.0)
result = stations_service.update_station(
cast(Session, session), str(station_id), payload
)
assert result.name == "New Name"
assert result.latitude == pytest.approx(51.0)
assert result.longitude == pytest.approx(9.0)
assert DummyStationRepository._store[station_id].name == "New Name"
def test_update_station_requires_both_coordinates() -> None:
session = DummySession()
station_id = uuid4()
DummyStationRepository._store[station_id] = DummyStation(
id=station_id,
name="Station",
location=DummyStationRepository._point(50.0, 8.0),
osm_id=None,
code=None,
elevation_m=None,
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
with pytest.raises(ValueError):
stations_service.update_station(
cast(Session, session), str(station_id), StationUpdate(latitude=51.0)
)
def test_archive_station_marks_inactive() -> None:
session = DummySession()
station_id = uuid4()
DummyStationRepository._store[station_id] = DummyStation(
id=station_id,
name="Station",
location=DummyStationRepository._point(50.0, 8.0),
osm_id=None,
code=None,
elevation_m=None,
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
result = stations_service.archive_station(cast(Session, session), str(station_id))
assert result.is_active is False
assert DummyStationRepository._store[station_id].is_active is False