from __future__ import annotations

import json
from uuid import UUID

import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session

from backend.app.db.models import CombinedTrack
from backend.app.models import CombinedTrackCreate
from backend.app.repositories.base import BaseRepository


class CombinedTrackRepository(BaseRepository[CombinedTrack]):
    """Persistence helpers for ``CombinedTrack`` rows.

    All queries run through ``self.session``; presumably that attribute is
    set by ``BaseRepository.__init__`` (defined outside this file -- confirm).
    """

    model = CombinedTrack

    def __init__(self, session: Session) -> None:
        """Bind the repository to *session* by delegating to the base class."""
        super().__init__(session)

    def list_all(self) -> list[CombinedTrack]:
        """Return every combined track as a list, with no filtering or ordering."""
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_between_stations(
        self,
        start_station_id: UUID | str,
        end_station_id: UUID | str,
    ) -> bool:
        """Check if a combined track already exists between two stations.

        Only the given direction is matched: a row whose start/end stations
        are swapped relative to the arguments does not count.

        Args:
            start_station_id: Station the track starts at (``UUID`` or its
                string form).
            end_station_id: Station the track ends at (``UUID`` or its
                string form).

        Returns:
            ``True`` if at least one matching row exists, else ``False``.

        Raises:
            ValueError: If either ID is not a valid UUID string.
        """
        # Normalise the same way ``create`` does, so the comparison values
        # have the same Python type as the values that were stored.
        start_id = self._ensure_uuid(start_station_id)
        end_id = self._ensure_uuid(end_station_id)
        statement = sa.select(
            sa.exists().where(
                sa.and_(
                    self.model.start_station_id == start_id,
                    self.model.end_station_id == end_id,
                )
            )
        )
        return bool(self.session.scalar(statement))

    def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
        """Extract constituent track IDs from a combined track.

        The IDs are stored as a JSON document in
        ``combined_track.constituent_track_ids`` (``create`` writes them with
        ``json.dumps``).

        Returns:
            The decoded list, or an empty list when the stored value is
            missing, is not valid JSON, or decodes to something other than a
            JSON array.
        """
        try:
            decoded = json.loads(combined_track.constituent_track_ids)
        except (json.JSONDecodeError, TypeError):
            # Best-effort read: malformed or missing data means "no IDs".
            return []
        # Valid JSON that is not an array (null, object, scalar) would break
        # callers that iterate the result, so treat it like malformed data.
        return decoded if isinstance(decoded, list) else []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        """Return *value* as a ``UUID``, parsing its string form if needed.

        Raises:
            ValueError: If the string form of *value* is not a valid UUID.
        """
        if isinstance(value, UUID):
            return value
        return UUID(str(value))

    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        """Build a ``LINESTRING`` WKT element (SRID 4326) from coordinate pairs.

        Each pair is unpacked as ``(lat, lon)`` and written to the WKT text
        as ``"lon lat"``, i.e. the two values are swapped on output.

        Raises:
            ValueError: If fewer than two coordinate pairs are supplied.
        """
        if len(coordinates) < 2:
            raise ValueError(
                "Combined track geometry requires at least two coordinate pairs"
            )
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

    def create(self, data: CombinedTrackCreate) -> CombinedTrack:
        """Build a ``CombinedTrack`` from *data* and add it to the session.

        The new object is only added to the session; this method does not
        flush or commit.

        Raises:
            ValueError: If *data* has fewer than two coordinate pairs, or a
                station ID that is not a valid UUID.
        """
        coordinates = list(data.coordinates)
        geometry = self._line_string(coordinates)
        # The ID list is persisted as JSON text; see get_constituent_track_ids.
        constituent_track_ids_json = json.dumps(data.constituent_track_ids)

        combined_track = CombinedTrack(
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
            length_meters=data.length_meters,
            max_speed_kph=data.max_speed_kph,
            is_bidirectional=data.is_bidirectional,
            status=data.status,
            combined_geometry=geometry,
            constituent_track_ids=constituent_track_ids_json,
        )
        self.session.add(combined_track)
        return combined_track