Files
rail-game/backend/app/services/network.py
zwitschi c2927f2f60 feat: Enhance track model and import functionality
- Added new fields to TrackModel: status, is_bidirectional, and coordinates.
- Updated network service to handle new track attributes and geometry extraction.
- Introduced CLI scripts for importing and loading tracks from OpenStreetMap.
- Implemented normalization of track elements to ensure valid geometries.
- Enhanced tests for track model, network service, and import/load scripts.
- Updated frontend to accommodate new track attributes and improve route computation.
- Documented OSM ingestion process in architecture and runtime views.
2025-10-11 19:54:10 +02:00

188 lines
6.1 KiB
Python

"""Domain services for railway network aggregation."""
from datetime import datetime, timezone
from decimal import Decimal
from typing import Iterable, cast
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
try: # pragma: no cover - optional dependency guard
from shapely.geometry import LineString, Point # type: ignore
except ImportError: # pragma: no cover - allow running without shapely at import time
Point = None # type: ignore[assignment]
LineString = None # type: ignore[assignment]
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from backend.app.models import StationModel, TrackModel, TrainModel
from backend.app.repositories import StationRepository, TrackRepository, TrainRepository
def _timestamp() -> datetime:
    """Return the current moment as a timezone-aware UTC ``datetime``."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now
def _fallback_snapshot() -> dict[str, list[dict[str, object]]]:
    """Build the hard-coded placeholder snapshot.

    The snapshot holds two stations, one track joining them and one train
    that operates on that track. Every record carries the same timestamp,
    taken once when the function is called.

    Returns:
        The serialized snapshot produced by ``_serialize_snapshot``.
    """
    stamp = _timestamp()

    central = StationModel(
        id="station-1",
        name="Central",
        latitude=52.520008,
        longitude=13.404954,
        created_at=stamp,
        updated_at=stamp,
    )
    harbor = StationModel(
        id="station-2",
        name="Harbor",
        latitude=53.551086,
        longitude=9.993682,
        created_at=stamp,
        updated_at=stamp,
    )

    # The track geometry is just the two endpoints, as (latitude, longitude) pairs.
    endpoints = [
        (central.latitude, central.longitude),
        (harbor.latitude, harbor.longitude),
    ]
    link = TrackModel(
        id="track-1",
        start_station_id="station-1",
        end_station_id="station-2",
        length_meters=289000.0,
        max_speed_kph=230.0,
        status="operational",
        is_bidirectional=True,
        coordinates=endpoints,
        created_at=stamp,
        updated_at=stamp,
    )

    express = TrainModel(
        id="train-1",
        designation="ICE 123",
        capacity=400,
        max_speed_kph=300.0,
        operating_track_ids=[link.id],
        created_at=stamp,
        updated_at=stamp,
    )

    return _serialize_snapshot([central, harbor], [link], [express])
def _serialize_snapshot(
    stations: Iterable[StationModel],
    tracks: Iterable[TrackModel],
    trains: Iterable[TrainModel],
) -> dict[str, list[dict[str, object]]]:
    """Serialize the given models into one snapshot dictionary.

    Each model is dumped with ``model_dump(by_alias=True)``.

    Args:
        stations: Station models to serialize.
        tracks: Track models to serialize.
        trains: Train models to serialize.

    Returns:
        A dictionary with the keys ``stations``, ``tracks`` and ``trains``
        (in that order), each mapping to the list of dumped models.
    """
    snapshot: dict[str, list[dict[str, object]]] = {}
    snapshot["stations"] = [model.model_dump(by_alias=True) for model in stations]
    snapshot["tracks"] = [model.model_dump(by_alias=True) for model in tracks]
    snapshot["trains"] = [model.model_dump(by_alias=True) for model in trains]
    return snapshot
def _to_float(value: Decimal | float | int | None, default: float = 0.0) -> float:
if value is None:
return default
if isinstance(value, Decimal):
return float(value)
return float(value)
def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]:
    """Load stations, tracks and trains through their repositories and serialize them.

    The hard-coded fallback snapshot is returned instead when a repository
    query raises ``SQLAlchemyError`` (the session is rolled back first) or
    when all three queries come back empty.

    Args:
        session: SQLAlchemy session handed to the three repositories.

    Returns:
        A dictionary with the keys ``stations``, ``tracks`` and ``trains``,
        each holding a list of serialized models.
    """
    stations_source = StationRepository(session)
    tracks_source = TrackRepository(session)
    trains_source = TrainRepository(session)

    try:
        station_rows = stations_source.list_active()
        track_rows = tracks_source.list_all()
        train_rows = trains_source.list_all()
    except SQLAlchemyError:
        session.rollback()
        return _fallback_snapshot()

    # Nothing stored at all: serve the placeholder data instead of an empty snapshot.
    if not (station_rows or track_rows or train_rows):
        return _fallback_snapshot()

    station_models: list[StationModel] = []
    for row in station_rows:
        # Stations without a decodable point geometry keep the 0.0 / 0.0 default.
        latitude, longitude = 0.0, 0.0
        raw_location = row.location
        if raw_location is not None and Point is not None:
            point = to_shape(cast(WKBElement | WKTElement, raw_location))
            if isinstance(point, Point):
                latitude, longitude = float(point.y), float(point.x)
        station_models.append(
            StationModel(
                id=str(row.id),
                name=row.name,
                latitude=latitude,
                longitude=longitude,
                created_at=cast(datetime, row.created_at),
                updated_at=cast(datetime, row.updated_at),
            )
        )

    track_models: list[TrackModel] = []
    for row in track_rows:
        # Tracks without a decodable line geometry keep an empty coordinate list.
        path: list[tuple[float, float]] = []
        raw_geometry = row.track_geometry
        if raw_geometry is not None and LineString is not None:
            line = to_shape(cast(WKBElement | WKTElement, raw_geometry))
            if isinstance(line, LineString):
                # Vertices are read as (longitude, latitude) and stored as
                # (latitude, longitude); indexing tolerates extra ordinates.
                path = [(float(vertex[1]), float(vertex[0])) for vertex in line.coords]
        track_models.append(
            TrackModel(
                id=str(row.id),
                start_station_id=str(row.start_station_id),
                end_station_id=str(row.end_station_id),
                length_meters=_to_float(row.length_meters),
                max_speed_kph=_to_float(row.max_speed_kph),
                status=row.status,
                is_bidirectional=row.is_bidirectional,
                coordinates=path,
                created_at=cast(datetime, row.created_at),
                updated_at=cast(datetime, row.updated_at),
            )
        )

    # Trains are emitted with an empty operating_track_ids list.
    train_models: list[TrainModel] = [
        TrainModel(
            id=str(row.id),
            designation=row.designation,
            capacity=row.capacity,
            max_speed_kph=_to_float(row.max_speed_kph),
            operating_track_ids=[],
            created_at=cast(datetime, row.created_at),
            updated_at=cast(datetime, row.updated_at),
        )
        for row in train_rows
    ]

    return _serialize_snapshot(station_models, track_models, train_models)