feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from backend.app.repositories.stations import StationRepository
|
||||
from backend.app.repositories.tracks import TrackRepository
|
||||
from backend.app.repositories.combined_tracks import CombinedTrackRepository
|
||||
from backend.app.repositories.train_schedules import TrainScheduleRepository
|
||||
from backend.app.repositories.trains import TrainRepository
|
||||
from backend.app.repositories.users import UserRepository
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"StationRepository",
|
||||
"TrainScheduleRepository",
|
||||
"TrackRepository",
|
||||
"CombinedTrackRepository",
|
||||
"TrainRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
|
||||
73
backend/app/repositories/combined_tracks.py
Normal file
73
backend/app/repositories/combined_tracks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import CombinedTrack
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
|
||||
model = CombinedTrack
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
super().__init__(session)
|
||||
|
||||
def list_all(self) -> list[CombinedTrack]:
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
|
||||
"""Check if a combined track already exists between two stations."""
|
||||
statement = sa.select(sa.exists().where(
|
||||
sa.and_(
|
||||
self.model.start_station_id == start_station_id,
|
||||
self.model.end_station_id == end_station_id
|
||||
)
|
||||
))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
|
||||
"""Extract constituent track IDs from a combined track."""
|
||||
try:
|
||||
return json.loads(combined_track.constituent_track_ids)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError(
|
||||
"Combined track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
def create(self, data: CombinedTrackCreate) -> CombinedTrack:
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
constituent_track_ids_json = json.dumps(data.constituent_track_ids)
|
||||
|
||||
combined_track = CombinedTrack(
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
length_meters=data.length_meters,
|
||||
max_speed_kph=data.max_speed_kph,
|
||||
is_bidirectional=data.is_bidirectional,
|
||||
status=data.status,
|
||||
combined_geometry=geometry,
|
||||
constituent_track_ids=constituent_track_ids_json,
|
||||
)
|
||||
self.session.add(combined_track)
|
||||
return combined_track
|
||||
@@ -7,7 +7,7 @@ from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import Track
|
||||
from backend.app.models import TrackCreate
|
||||
from backend.app.models import TrackCreate, TrackUpdate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
@@ -21,6 +21,102 @@ class TrackRepository(BaseRepository[Track]):
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_by_osm_id(self, osm_id: str) -> bool:
|
||||
statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None:
|
||||
"""Find the shortest path between two stations using existing tracks.
|
||||
|
||||
Returns a list of tracks that form the path, or None if no path exists.
|
||||
"""
|
||||
# Build adjacency list: station -> list of (neighbor_station, track)
|
||||
adjacency = self._build_track_graph()
|
||||
|
||||
if start_station_id not in adjacency or end_station_id not in adjacency:
|
||||
return None
|
||||
|
||||
# BFS to find shortest path
|
||||
from collections import deque
|
||||
|
||||
# (current_station, path_so_far)
|
||||
queue = deque([(start_station_id, [])])
|
||||
visited = set([start_station_id])
|
||||
|
||||
while queue:
|
||||
current_station, path = queue.popleft()
|
||||
|
||||
if current_station == end_station_id:
|
||||
return path
|
||||
|
||||
for neighbor, track in adjacency[current_station]:
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [track]))
|
||||
|
||||
return None # No path found
|
||||
|
||||
def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
|
||||
"""Build a graph representation of tracks: station -> [(neighbor_station, track), ...]"""
|
||||
tracks = self.list_all()
|
||||
graph = {}
|
||||
|
||||
for track in tracks:
|
||||
start_id = str(track.start_station_id)
|
||||
end_id = str(track.end_station_id)
|
||||
|
||||
# Add bidirectional edges (assuming tracks are bidirectional)
|
||||
if start_id not in graph:
|
||||
graph[start_id] = []
|
||||
if end_id not in graph:
|
||||
graph[end_id] = []
|
||||
|
||||
graph[start_id].append((end_id, track))
|
||||
graph[end_id].append((start_id, track))
|
||||
|
||||
return graph
|
||||
|
||||
def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
|
||||
"""Combine the geometries of multiple tracks into a single coordinate sequence.
|
||||
|
||||
Assumes tracks are in order and form a continuous path.
|
||||
"""
|
||||
if not tracks:
|
||||
return []
|
||||
|
||||
combined_coords = []
|
||||
|
||||
for i, track in enumerate(tracks):
|
||||
# Extract coordinates from track geometry
|
||||
coords = self._extract_coordinates_from_track(track)
|
||||
|
||||
if i == 0:
|
||||
# First track: add all coordinates
|
||||
combined_coords.extend(coords)
|
||||
else:
|
||||
# Subsequent tracks: skip the first coordinate (shared with previous track)
|
||||
combined_coords.extend(coords[1:])
|
||||
|
||||
return combined_coords
|
||||
|
||||
def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
|
||||
"""Extract coordinate list from a track's geometry."""
|
||||
# Convert WKT string to WKTElement, then to shapely geometry
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from geoalchemy2.shape import to_shape
|
||||
|
||||
try:
|
||||
wkt_element = WKTElement(track.track_geometry)
|
||||
geom = to_shape(wkt_element)
|
||||
if hasattr(geom, 'coords'):
|
||||
# For LineString, coords returns [(x, y), ...] where x=lon, y=lat
|
||||
# Convert to (lat, lon)
|
||||
return [(coord[1], coord[0]) for coord in geom.coords]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
@@ -30,7 +126,8 @@ class TrackRepository(BaseRepository[Track]):
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError("Track geometry requires at least two coordinate pairs")
|
||||
raise ValueError(
|
||||
"Track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
@@ -38,6 +135,7 @@ class TrackRepository(BaseRepository[Track]):
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
track = Track(
|
||||
osm_id=data.osm_id,
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
@@ -49,3 +147,26 @@ class TrackRepository(BaseRepository[Track]):
|
||||
)
|
||||
self.session.add(track)
|
||||
return track
|
||||
|
||||
def update(self, track: Track, payload: TrackUpdate) -> Track:
|
||||
if payload.start_station_id is not None:
|
||||
track.start_station_id = self._ensure_uuid(
|
||||
payload.start_station_id)
|
||||
if payload.end_station_id is not None:
|
||||
track.end_station_id = self._ensure_uuid(payload.end_station_id)
|
||||
if payload.coordinates is not None:
|
||||
track.track_geometry = self._line_string(
|
||||
list(payload.coordinates)) # type: ignore[assignment]
|
||||
if payload.osm_id is not None:
|
||||
track.osm_id = payload.osm_id
|
||||
if payload.name is not None:
|
||||
track.name = payload.name
|
||||
if payload.length_meters is not None:
|
||||
track.length_meters = payload.length_meters
|
||||
if payload.max_speed_kph is not None:
|
||||
track.max_speed_kph = payload.max_speed_kph
|
||||
if payload.is_bidirectional is not None:
|
||||
track.is_bidirectional = payload.is_bidirectional
|
||||
if payload.status is not None:
|
||||
track.status = payload.status
|
||||
return track
|
||||
|
||||
Reference in New Issue
Block a user