from __future__ import annotations from uuid import UUID import sqlalchemy as sa from geoalchemy2.elements import WKTElement from sqlalchemy.orm import Session from backend.app.db.models import Track from backend.app.models import TrackCreate, TrackUpdate from backend.app.repositories.base import BaseRepository class TrackRepository(BaseRepository[Track]): model = Track def __init__(self, session: Session) -> None: super().__init__(session) def list_all(self) -> list[Track]: statement = sa.select(self.model) return list(self.session.scalars(statement)) def exists_by_osm_id(self, osm_id: str) -> bool: statement = sa.select(sa.exists().where(self.model.osm_id == osm_id)) return bool(self.session.scalar(statement)) def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None: """Find the shortest path between two stations using existing tracks. Returns a list of tracks that form the path, or None if no path exists. """ # Build adjacency list: station -> list of (neighbor_station, track) adjacency = self._build_track_graph() if start_station_id not in adjacency or end_station_id not in adjacency: return None # BFS to find shortest path from collections import deque # (current_station, path_so_far) queue = deque([(start_station_id, [])]) visited = set([start_station_id]) while queue: current_station, path = queue.popleft() if current_station == end_station_id: return path for neighbor, track in adjacency[current_station]: if neighbor not in visited: visited.add(neighbor) queue.append((neighbor, path + [track])) return None # No path found def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]: """Build a graph representation of tracks: station -> [(neighbor_station, track), ...]""" tracks = self.list_all() graph = {} for track in tracks: start_id = str(track.start_station_id) end_id = str(track.end_station_id) # Add bidirectional edges (assuming tracks are bidirectional) if start_id not in graph: graph[start_id] = [] if end_id not in graph: graph[end_id] = [] graph[start_id].append((end_id, track)) graph[end_id].append((start_id, track)) return graph def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]: """Combine the geometries of multiple tracks into a single coordinate sequence. Assumes tracks are in order and form a continuous path. """ if not tracks: return [] combined_coords = [] for i, track in enumerate(tracks): # Extract coordinates from track geometry coords = self._extract_coordinates_from_track(track) if i == 0: # First track: add all coordinates combined_coords.extend(coords) else: # Subsequent tracks: skip the first coordinate (shared with previous track) combined_coords.extend(coords[1:]) return combined_coords def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]: """Extract coordinate list from a track's geometry.""" # Convert WKT string to WKTElement, then to shapely geometry from geoalchemy2.elements import WKTElement from geoalchemy2.shape import to_shape try: wkt_element = WKTElement(track.track_geometry) geom = to_shape(wkt_element) if hasattr(geom, 'coords'): # For LineString, coords returns [(x, y), ...] where x=lon, y=lat # Convert to (lat, lon) return [(coord[1], coord[0]) for coord in geom.coords] except Exception: pass return [] @staticmethod def _ensure_uuid(value: UUID | str) -> UUID: if isinstance(value, UUID): return value return UUID(str(value)) @staticmethod def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement: if len(coordinates) < 2: raise ValueError( "Track geometry requires at least two coordinate pairs") parts = [f"{lon} {lat}" for lat, lon in coordinates] return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326) def create(self, data: TrackCreate) -> Track: coordinates = list(data.coordinates) geometry = self._line_string(coordinates) track = Track( osm_id=data.osm_id, name=data.name, start_station_id=self._ensure_uuid(data.start_station_id), end_station_id=self._ensure_uuid(data.end_station_id), length_meters=data.length_meters, max_speed_kph=data.max_speed_kph, is_bidirectional=data.is_bidirectional, status=data.status, track_geometry=geometry, ) self.session.add(track) return track def update(self, track: Track, payload: TrackUpdate) -> Track: if payload.start_station_id is not None: track.start_station_id = self._ensure_uuid( payload.start_station_id) if payload.end_station_id is not None: track.end_station_id = self._ensure_uuid(payload.end_station_id) if payload.coordinates is not None: track.track_geometry = self._line_string( list(payload.coordinates)) # type: ignore[assignment] if payload.osm_id is not None: track.osm_id = payload.osm_id if payload.name is not None: track.name = payload.name if payload.length_meters is not None: track.length_meters = payload.length_meters if payload.max_speed_kph is not None: track.max_speed_kph = payload.max_speed_kph if payload.is_bidirectional is not None: track.is_bidirectional = payload.is_bidirectional if payload.status is not None: track.status = payload.status return track