"""Domain services for railway network aggregation.""" from datetime import datetime, timezone from decimal import Decimal from typing import Iterable, cast from geoalchemy2.elements import WKBElement, WKTElement from geoalchemy2.shape import to_shape try: # pragma: no cover - optional dependency guard from shapely.geometry import LineString, Point # type: ignore except ImportError: # pragma: no cover - allow running without shapely at import time Point = None # type: ignore[assignment] LineString = None # type: ignore[assignment] from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session from backend.app.models import StationModel, TrackModel, TrainModel from backend.app.repositories import StationRepository, TrackRepository, TrainRepository def _timestamp() -> datetime: return datetime.now(timezone.utc) def _fallback_snapshot() -> dict[str, list[dict[str, object]]]: now = _timestamp() stations = [ StationModel( id="station-1", name="Central", latitude=52.520008, longitude=13.404954, created_at=now, updated_at=now, ), StationModel( id="station-2", name="Harbor", latitude=53.551086, longitude=9.993682, created_at=now, updated_at=now, ), ] tracks = [ TrackModel( id="track-1", start_station_id="station-1", end_station_id="station-2", length_meters=289000.0, max_speed_kph=230.0, status="operational", is_bidirectional=True, coordinates=[ (stations[0].latitude, stations[0].longitude), (stations[1].latitude, stations[1].longitude), ], created_at=now, updated_at=now, ) ] trains = [ TrainModel( id="train-1", designation="ICE 123", capacity=400, max_speed_kph=300.0, operating_track_ids=[track.id for track in tracks], created_at=now, updated_at=now, ) ] return _serialize_snapshot(stations, tracks, trains) def _serialize_snapshot( stations: Iterable[StationModel], tracks: Iterable[TrackModel], trains: Iterable[TrainModel], ) -> dict[str, list[dict[str, object]]]: return { "stations": [station.model_dump(by_alias=True) for station in stations], "tracks": [track.model_dump(by_alias=True) for track in tracks], "trains": [train.model_dump(by_alias=True) for train in trains], } def _to_float(value: Decimal | float | int | None, default: float = 0.0) -> float: if value is None: return default if isinstance(value, Decimal): return float(value) return float(value) def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]: station_repo = StationRepository(session) track_repo = TrackRepository(session) train_repo = TrainRepository(session) try: stations_entities = station_repo.list_active() tracks_entities = track_repo.list_all() trains_entities = train_repo.list_all() except SQLAlchemyError: session.rollback() return _fallback_snapshot() if not stations_entities and not tracks_entities and not trains_entities: return _fallback_snapshot() station_models: list[StationModel] = [] for station in stations_entities: location = station.location geom = ( to_shape(cast(WKBElement | WKTElement, location)) if location is not None and Point is not None else None ) if Point is not None and geom is not None and isinstance(geom, Point): latitude = float(geom.y) longitude = float(geom.x) else: latitude = 0.0 longitude = 0.0 station_models.append( StationModel( id=str(station.id), name=station.name, latitude=latitude, longitude=longitude, created_at=cast(datetime, station.created_at), updated_at=cast(datetime, station.updated_at), ) ) track_models: list[TrackModel] = [] for track in tracks_entities: coordinates: list[tuple[float, float]] = [] geometry = track.track_geometry shape = ( to_shape(cast(WKBElement | WKTElement, geometry)) if geometry is not None and LineString is not None else None ) if ( LineString is not None and shape is not None and isinstance(shape, LineString) ): coords_list: list[tuple[float, float]] = [] for coord in shape.coords: lon = float(coord[0]) lat = float(coord[1]) coords_list.append((lat, lon)) coordinates = coords_list track_models.append( TrackModel( id=str(track.id), start_station_id=str(track.start_station_id), end_station_id=str(track.end_station_id), length_meters=_to_float(track.length_meters), max_speed_kph=_to_float(track.max_speed_kph), status=track.status, is_bidirectional=track.is_bidirectional, coordinates=coordinates, created_at=cast(datetime, track.created_at), updated_at=cast(datetime, track.updated_at), ) ) train_models: list[TrainModel] = [] for train in trains_entities: train_models.append( TrainModel( id=str(train.id), designation=train.designation, capacity=train.capacity, max_speed_kph=_to_float(train.max_speed_kph), operating_track_ids=[], created_at=cast(datetime, train.created_at), updated_at=cast(datetime, train.updated_at), ) ) return _serialize_snapshot(station_models, track_models, train_models)