- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
173 lines
6.2 KiB
Python
173 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
from uuid import UUID
|
|
|
|
import sqlalchemy as sa
|
|
from geoalchemy2.elements import WKTElement
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.app.db.models import Track
|
|
from backend.app.models import TrackCreate, TrackUpdate
|
|
from backend.app.repositories.base import BaseRepository
|
|
|
|
|
|
class TrackRepository(BaseRepository[Track]):
    """Repository for ``Track`` rows: lookup, path-finding, geometry helpers, create/update.

    Coordinates handled by this class are ``(lat, lon)`` tuples; they are
    written to and read from WKT/shape geometry in ``lon lat`` (x, y) order.
    """

    model = Track

    def __init__(self, session: Session) -> None:
        """Bind the repository to *session* through the base repository constructor."""
        super().__init__(session)

    def list_all(self) -> list[Track]:
        """Return every ``Track`` selected through the session, as a list."""
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_by_osm_id(self, osm_id: str) -> bool:
        """Return ``True`` if at least one track has the given ``osm_id``."""
        statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
        return bool(self.session.scalar(statement))

    def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None:
        """Find the path with the fewest tracks between two stations.

        A breadth-first search is run over the graph produced by
        ``_build_track_graph``, in which every track can be traversed in both
        directions.

        Args:
            start_station_id: ID of the station the path starts at.
            end_station_id: ID of the station the path ends at.

        Returns:
            The tracks forming the path, in travel order (an empty list when
            both IDs are the same known station), or ``None`` if either
            station has no track attached or no path connects them.
        """
        from collections import deque

        # Graph keys are built with str(...), so normalise the arguments the
        # same way; this lets callers pass UUID objects as well as strings.
        start_key = str(start_station_id)
        end_key = str(end_station_id)

        adjacency = self._build_track_graph()
        if start_key not in adjacency or end_key not in adjacency:
            return None

        # station -> (previous station, track used to reach it). A station is
        # recorded the first time it is discovered, which doubles as the
        # "visited" set and avoids copying the partial path on every enqueue.
        came_from: dict[str, tuple[str, Track] | None] = {start_key: None}
        queue = deque([start_key])

        while queue:
            current_station = queue.popleft()

            if current_station == end_key:
                # Walk the predecessor chain back to the start, then flip it.
                path: list[Track] = []
                step = came_from[current_station]
                while step is not None:
                    previous_station, track = step
                    path.append(track)
                    step = came_from[previous_station]
                path.reverse()
                return path

            for neighbor, track in adjacency[current_station]:
                if neighbor not in came_from:
                    came_from[neighbor] = (current_station, track)
                    queue.append(neighbor)

        return None  # No path found

    def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
        """Build an adjacency map: station ID -> [(neighbor station ID, track), ...].

        Station IDs are stringified. Every track is inserted in both
        directions; the track's ``is_bidirectional`` flag is not consulted.
        """
        graph: dict[str, list[tuple[str, Track]]] = {}

        for track in self.list_all():
            start_id = str(track.start_station_id)
            end_id = str(track.end_station_id)

            graph.setdefault(start_id, []).append((end_id, track))
            graph.setdefault(end_id, []).append((start_id, track))

        return graph

    @staticmethod
    def _squared_distance(a: tuple[float, float], b: tuple[float, float]) -> float:
        """Return the squared planar distance between two coordinate pairs.

        Only used to decide which of two points is nearer to a third, so no
        square root or geodesic correction is applied.
        """
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

    def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
        """Combine the geometries of multiple tracks into one coordinate sequence.

        Tracks must be given in path order. Because paths may traverse a track
        against the direction its geometry was stored in, each segment is
        reversed when its last point lies nearer to the running end of the
        combined line than its first point does. Segments that are already
        oriented in travel direction are left untouched.

        Args:
            tracks: Tracks in the order they are travelled.

        Returns:
            ``(lat, lon)`` tuples for the whole path. The first coordinate of
            every track after the first is dropped, since it is taken to be
            shared with the previous track. Empty if *tracks* is empty.
        """
        if not tracks:
            return []

        segments = [self._extract_coordinates_from_track(track) for track in tracks]
        distance = self._squared_distance

        # The first segment has no predecessor to align with, so orient it
        # against the second one: whichever of its ends is nearer to the
        # second segment must be its tail.
        if len(segments) > 1 and segments[0] and segments[1]:
            first, second = segments[0], segments[1]
            head_gap = min(distance(first[0], second[0]), distance(first[0], second[-1]))
            tail_gap = min(distance(first[-1], second[0]), distance(first[-1], second[-1]))
            if head_gap < tail_gap:
                segments[0] = first[::-1]

        combined_coords: list[tuple[float, float]] = []

        for index, coords in enumerate(segments):
            if combined_coords and coords:
                tail = combined_coords[-1]
                # Strict comparison: a segment that already starts at the
                # tail is never flipped.
                if distance(coords[-1], tail) < distance(coords[0], tail):
                    coords = coords[::-1]

            if index == 0:
                # First track: add all coordinates
                combined_coords.extend(coords)
            else:
                # Subsequent tracks: skip the first coordinate (shared with previous track)
                combined_coords.extend(coords[1:])

        return combined_coords

    def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
        """Extract the coordinate list from a track's geometry.

        The geometry may be a WKT string or an existing GeoAlchemy2
        ``WKTElement``/``WKBElement``. Extraction is best-effort: any failure
        is logged and an empty list is returned instead of raising.

        Returns:
            ``(lat, lon)`` tuples, or ``[]`` when the geometry is missing,
            cannot be converted, or has no ``coords`` attribute.
        """
        import logging

        from geoalchemy2.elements import WKBElement, WKTElement
        from geoalchemy2.shape import to_shape

        raw_geometry = track.track_geometry
        if raw_geometry is None:
            return []

        try:
            # Spatial elements go to to_shape as-is; anything else is treated
            # as a WKT string and wrapped first.
            if isinstance(raw_geometry, (WKBElement, WKTElement)):
                element = raw_geometry
            else:
                element = WKTElement(raw_geometry)

            geom = to_shape(element)
            if hasattr(geom, "coords"):
                # For LineString, coords yields (x, y) where x=lon, y=lat;
                # convert to (lat, lon).
                return [(coord[1], coord[0]) for coord in geom.coords]
        except Exception:
            # Deliberately best-effort, but leave a trace instead of failing silently.
            logging.getLogger(__name__).warning(
                "Could not extract coordinates from geometry of track %s",
                getattr(track, "id", None),
                exc_info=True,
            )

        return []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        """Return *value* as a ``UUID``, parsing its string form if needed.

        Raises:
            ValueError: If the string form of *value* is not a valid UUID.
        """
        if isinstance(value, UUID):
            return value
        return UUID(str(value))

    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        """Build an SRID 4326 ``LINESTRING`` element from ``(lat, lon)`` pairs.

        Raises:
            ValueError: If fewer than two coordinate pairs are given.
        """
        if len(coordinates) < 2:
            raise ValueError(
                "Track geometry requires at least two coordinate pairs")
        # WKT wants "x y", i.e. "lon lat", so swap each pair while formatting.
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

    def create(self, data: TrackCreate) -> Track:
        """Build a ``Track`` from *data* and add it to the session.

        The track is only added to the session; this method does not flush
        or commit.

        Raises:
            ValueError: If *data* has fewer than two coordinates or a station
                ID that is not a valid UUID.
        """
        coordinates = list(data.coordinates)
        geometry = self._line_string(coordinates)
        track = Track(
            osm_id=data.osm_id,
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
            length_meters=data.length_meters,
            max_speed_kph=data.max_speed_kph,
            is_bidirectional=data.is_bidirectional,
            status=data.status,
            track_geometry=geometry,
        )
        self.session.add(track)
        return track

    def update(self, track: Track, payload: TrackUpdate) -> Track:
        """Apply the non-``None`` fields of *payload* to *track* and return it.

        A ``None`` field means "leave unchanged", so this method cannot reset
        a column to ``None``.

        Raises:
            ValueError: If new coordinates contain fewer than two pairs or a
                new station ID is not a valid UUID.
        """
        if payload.start_station_id is not None:
            track.start_station_id = self._ensure_uuid(
                payload.start_station_id)
        if payload.end_station_id is not None:
            track.end_station_id = self._ensure_uuid(payload.end_station_id)
        if payload.coordinates is not None:
            track.track_geometry = self._line_string(
                list(payload.coordinates))  # type: ignore[assignment]

        # Remaining fields are copied across verbatim when provided.
        for field_name in (
            "osm_id",
            "name",
            "length_meters",
            "max_speed_kph",
            "is_bidirectional",
            "status",
        ):
            new_value = getattr(payload, field_name)
            if new_value is not None:
                setattr(track, field_name, new_value)

        return track
|