Files
rail-game/backend/app/services/network.py
zwitschi c35049cd54
Some checks failed
Backend CI / lint-and-test (push) Failing after 1m54s
fix: formatting (black)
2025-10-11 21:58:32 +02:00

192 lines
6.1 KiB
Python

"""Domain services for railway network aggregation."""
import logging
from datetime import datetime, timezone
from decimal import Decimal
from typing import Iterable, cast

from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape

try:  # pragma: no cover - optional dependency guard
    from shapely.geometry import LineString, Point  # type: ignore
except ImportError:  # pragma: no cover - allow running without shapely at import time
    Point = None  # type: ignore[assignment]
    LineString = None  # type: ignore[assignment]

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from backend.app.models import StationModel, TrackModel, TrainModel
from backend.app.repositories import StationRepository, TrackRepository, TrainRepository
def _timestamp() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc_zone = timezone.utc
    current = datetime.now(tz=utc_zone)
    return current
def _fallback_snapshot() -> dict[str, list[dict[str, object]]]:
    """Build the hard-coded network served when database data is unavailable.

    The snapshot holds two stations ("Central" and "Harbor"), one
    bidirectional track joining them, and one train assigned to that track.
    Every record shares a single UTC timestamp for ``created_at`` and
    ``updated_at``.

    Returns:
        The serialized ``{"stations", "tracks", "trains"}`` mapping.
    """
    created = _timestamp()

    central = StationModel(
        id="station-1",
        name="Central",
        latitude=52.520008,
        longitude=13.404954,
        created_at=created,
        updated_at=created,
    )
    harbor = StationModel(
        id="station-2",
        name="Harbor",
        latitude=53.551086,
        longitude=9.993682,
        created_at=created,
        updated_at=created,
    )

    # The track polyline is the two endpoints as (latitude, longitude) pairs.
    endpoints = [
        (central.latitude, central.longitude),
        (harbor.latitude, harbor.longitude),
    ]
    link = TrackModel(
        id="track-1",
        start_station_id="station-1",
        end_station_id="station-2",
        length_meters=289000.0,
        max_speed_kph=230.0,
        status="operational",
        is_bidirectional=True,
        coordinates=endpoints,
        created_at=created,
        updated_at=created,
    )

    express = TrainModel(
        id="train-1",
        designation="ICE 123",
        capacity=400,
        max_speed_kph=300.0,
        operating_track_ids=[link.id],
        created_at=created,
        updated_at=created,
    )

    return _serialize_snapshot([central, harbor], [link], [express])
def _serialize_snapshot(
    stations: Iterable[StationModel],
    tracks: Iterable[TrackModel],
    trains: Iterable[TrainModel],
) -> dict[str, list[dict[str, object]]]:
    """Dump station, track and train models into one JSON-ready mapping.

    Each model is converted with ``model_dump(by_alias=True)``. The inputs
    are consumed in the order stations, tracks, trains.

    Returns:
        ``{"stations": [...], "tracks": [...], "trains": [...]}``.
    """
    station_rows = [model.model_dump(by_alias=True) for model in stations]
    track_rows = [model.model_dump(by_alias=True) for model in tracks]
    train_rows = [model.model_dump(by_alias=True) for model in trains]
    return {
        "stations": station_rows,
        "tracks": track_rows,
        "trains": train_rows,
    }
def _to_float(value: Decimal | float | int | None, default: float = 0.0) -> float:
if value is None:
return default
if isinstance(value, Decimal):
return float(value)
return float(value)
def _station_lat_lon(location: object) -> tuple[float, float]:
    """Decode a station geometry value into ``(latitude, longitude)``.

    Returns ``(0.0, 0.0)`` when *location* is ``None``, shapely is not
    installed, or the decoded shape is not a ``Point``.
    """
    if location is None or Point is None:
        return 0.0, 0.0
    geom = to_shape(cast(WKBElement | WKTElement, location))
    if not isinstance(geom, Point):
        return 0.0, 0.0
    # The point's x holds the longitude and y the latitude.
    return float(geom.y), float(geom.x)


def _track_lat_lon_path(geometry: object) -> list[tuple[float, float]]:
    """Decode a track geometry value into ``(latitude, longitude)`` pairs.

    Returns an empty list when *geometry* is ``None``, shapely is not
    installed, or the decoded shape is not a ``LineString``.
    """
    if geometry is None or LineString is None:
        return []
    shape = to_shape(cast(WKBElement | WKTElement, geometry))
    if not isinstance(shape, LineString):
        return []
    # Vertices are (x=longitude, y=latitude); swap them to the (lat, lon)
    # order that TrackModel.coordinates uses.
    return [(float(coord[1]), float(coord[0])) for coord in shape.coords]


def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]:
    """Return the railway network as JSON-ready station, track and train lists.

    Reads active stations, all tracks and all trains through the repositories
    and converts them to API models serialized with ``by_alias=True``.

    The built-in fallback snapshot is returned instead when a query raises
    ``SQLAlchemyError`` (the error is logged and the session rolled back),
    or when none of the three queries returns any rows.

    Stations whose geometry cannot be decoded as a point are reported at
    ``(0.0, 0.0)``. Tracks without a decodable line get empty coordinates.

    Args:
        session: SQLAlchemy session shared by the repositories.

    Returns:
        Mapping with ``"stations"``, ``"tracks"`` and ``"trains"`` keys,
        each holding a list of serialized model dicts.
    """
    station_repo = StationRepository(session)
    track_repo = TrackRepository(session)
    train_repo = TrainRepository(session)

    try:
        stations_entities = station_repo.list_active()
        tracks_entities = track_repo.list_all()
        trains_entities = train_repo.list_all()
    except SQLAlchemyError:
        # Log with traceback so database failures are not hidden behind the
        # fallback data, then reset the session before degrading.
        logging.getLogger(__name__).warning(
            "Database query failed while building network snapshot; "
            "serving fallback data",
            exc_info=True,
        )
        session.rollback()
        return _fallback_snapshot()

    # When none of the queries returns rows, serve the fallback data as well.
    if not stations_entities and not tracks_entities and not trains_entities:
        return _fallback_snapshot()

    station_models: list[StationModel] = []
    for station in stations_entities:
        latitude, longitude = _station_lat_lon(station.location)
        station_models.append(
            StationModel(
                id=str(station.id),
                name=station.name,
                latitude=latitude,
                longitude=longitude,
                created_at=cast(datetime, station.created_at),
                updated_at=cast(datetime, station.updated_at),
            )
        )

    track_models: list[TrackModel] = [
        TrackModel(
            id=str(track.id),
            start_station_id=str(track.start_station_id),
            end_station_id=str(track.end_station_id),
            length_meters=_to_float(track.length_meters),
            max_speed_kph=_to_float(track.max_speed_kph),
            status=track.status,
            is_bidirectional=track.is_bidirectional,
            coordinates=_track_lat_lon_path(track.track_geometry),
            created_at=cast(datetime, track.created_at),
            updated_at=cast(datetime, track.updated_at),
        )
        for track in tracks_entities
    ]

    train_models: list[TrainModel] = [
        TrainModel(
            id=str(train.id),
            designation=train.designation,
            capacity=train.capacity,
            max_speed_kph=_to_float(train.max_speed_kph),
            # NOTE(review): track assignments are not loaded here, so every
            # persisted train reports no operating tracks — confirm intended.
            operating_track_ids=[],
            created_at=cast(datetime, train.created_at),
            updated_at=cast(datetime, train.updated_at),
        )
        for train in trains_entities
    ]

    return _serialize_snapshot(station_models, track_models, train_models)