Files
rail-game/backend/app/repositories/tracks.py
zwitschi 68048ff574
Some checks failed
Backend CI / lint-and-test (push) Failing after 2m27s
Frontend CI / lint-and-build (push) Successful in 57s
feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks based on existing tracks between two stations.
- Added methods to check for existing combined tracks and retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM ID and track updates.
- Created migration scripts for adding combined tracks table and OSM ID to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
2025-11-10 14:12:28 +01:00

173 lines
6.2 KiB
Python

from __future__ import annotations

import logging
from collections import deque
from uuid import UUID

import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session

from backend.app.db.models import Track
from backend.app.models import TrackCreate, TrackUpdate
from backend.app.repositories.base import BaseRepository
class TrackRepository(BaseRepository[Track]):
    """Data access for :class:`Track` rows.

    Besides create/update/list operations, this repository can search the
    existing track network for a path between two stations and stitch the
    geometries of the tracks along that path into one coordinate sequence.
    """

    model = Track

    def __init__(self, session: Session) -> None:
        super().__init__(session)

    def list_all(self) -> list[Track]:
        """Return every track (no filtering, no explicit ordering)."""
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_by_osm_id(self, osm_id: str) -> bool:
        """Return ``True`` if at least one track has the given ``osm_id``."""
        statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
        return bool(self.session.scalar(statement))

    def find_path_between_stations(
        self, start_station_id: UUID | str, end_station_id: UUID | str
    ) -> list[Track] | None:
        """Find the path with the fewest tracks between two stations.

        Runs a breadth-first search over the graph built by
        ``_build_track_graph`` (every track usable in both directions).

        Args:
            start_station_id: Station to start from (UUID or its string form).
            end_station_id: Station to reach (UUID or its string form).

        Returns:
            The tracks along the path in travel order. Returns ``None`` when
            either station has no tracks or the two are not connected. When
            both ids name the same station, returns an empty list.
        """
        # Graph keys are str(UUID); normalise so UUID arguments match too.
        start_key = str(start_station_id)
        end_key = str(end_station_id)

        adjacency = self._build_track_graph()
        if start_key not in adjacency or end_key not in adjacency:
            return None

        # Record how each station was first reached instead of copying the
        # whole path into every queue entry; the BFS visit order (and hence
        # the returned path) is the same as with per-entry path copies.
        came_from: dict[str, tuple[str, Track]] = {}
        visited = {start_key}
        queue = deque([start_key])
        while queue:
            station = queue.popleft()
            if station == end_key:
                return self._reconstruct_path(came_from, start_key, end_key)
            for neighbor, track in adjacency[station]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    came_from[neighbor] = (station, track)
                    queue.append(neighbor)
        return None  # No path found

    @staticmethod
    def _reconstruct_path(
        came_from: dict[str, tuple[str, Track]], start_key: str, end_key: str
    ) -> list[Track]:
        """Follow BFS predecessor links from ``end_key`` back to ``start_key``."""
        path: list[Track] = []
        station = end_key
        while station != start_key:
            station, track = came_from[station]
            path.append(track)
        path.reverse()
        return path

    def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
        """Build an adjacency list: station id -> [(neighbor station id, track)].

        Every track adds an edge in both directions.
        NOTE(review): ``Track.is_bidirectional`` is not consulted here, so
        one-way tracks can be traversed backwards — confirm this is intended.
        """
        graph: dict[str, list[tuple[str, Track]]] = {}
        for track in self.list_all():
            start_id = str(track.start_station_id)
            end_id = str(track.end_station_id)
            graph.setdefault(start_id, []).append((end_id, track))
            graph.setdefault(end_id, []).append((start_id, track))
        return graph

    def combine_track_geometries(
        self, tracks: list[Track]
    ) -> list[tuple[float, float]]:
        """Combine the geometries of consecutive tracks into one (lat, lon) list.

        ``tracks`` must be in path order (as returned by
        ``find_path_between_stations``). A track's stored coordinate order may
        run against the direction of travel, because the path search uses
        tracks in both directions. Each segment is therefore flipped when
        needed so that it starts where the previous one ended. A junction
        point shared by two segments is emitted only once. Tracks whose
        geometry cannot be read are skipped.
        """
        segments = [
            coords
            for coords in map(self._extract_coordinates_from_track, tracks)
            if coords
        ]
        if not segments:
            return []

        first = segments[0]
        if len(segments) > 1:
            # Orient the first segment so that its last point is the end that
            # lies closest to the second segment.
            second_ends = (segments[1][0], segments[1][-1])
            head_gap = min(self._squared_gap(first[0], p) for p in second_ends)
            tail_gap = min(self._squared_gap(first[-1], p) for p in second_ends)
            if head_gap < tail_gap:
                first = first[::-1]

        combined = list(first)
        for coords in segments[1:]:
            joint = combined[-1]
            if self._squared_gap(coords[-1], joint) < self._squared_gap(
                coords[0], joint
            ):
                coords = coords[::-1]
            if coords[0] == joint:
                # Shared junction point: don't emit it twice.
                coords = coords[1:]
            combined.extend(coords)
        return combined

    @staticmethod
    def _squared_gap(a: tuple[float, float], b: tuple[float, float]) -> float:
        """Squared planar distance between two points (only used for comparisons)."""
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

    def _extract_coordinates_from_track(
        self, track: Track
    ) -> list[tuple[float, float]]:
        """Return a track's geometry as a list of (lat, lon) pairs.

        Accepts a spatial element (a ``WKBElement`` loaded from the database or
        the ``WKTElement`` assigned by ``create``/``update``) or a plain WKT
        string. Returns ``[]`` if there is no geometry, or if it has no single
        coordinate sequence or cannot be parsed; parse failures are logged.
        """
        from geoalchemy2.shape import to_shape

        geometry = track.track_geometry
        if geometry is None:
            return []
        if isinstance(geometry, str):
            # Only raw strings need wrapping; re-wrapping a spatial element in
            # WKTElement fails because it expects WKT text.
            geometry = WKTElement(geometry)
        try:
            shape = to_shape(geometry)
            # Shapely coords are (x, y[, z]) with x=lon, y=lat; swap to (lat, lon).
            return [(coord[1], coord[0]) for coord in shape.coords]
        except Exception:
            # Best effort: multi-part/polygon geometries expose no single
            # ``coords`` sequence, and malformed data fails to parse.
            logging.getLogger(__name__).warning(
                "Could not extract coordinates from geometry of track %r",
                track,
                exc_info=True,
            )
            return []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        """Return ``value`` as a UUID, parsing its string form if necessary."""
        if isinstance(value, UUID):
            return value
        return UUID(str(value))

    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        """Build an SRID 4326 LINESTRING from (lat, lon) pairs.

        WKT lists points as "lon lat", so each pair is swapped when formatted.

        Raises:
            ValueError: if fewer than two coordinate pairs are given.
        """
        if len(coordinates) < 2:
            raise ValueError(
                "Track geometry requires at least two coordinate pairs")
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

    def create(self, data: TrackCreate) -> Track:
        """Build a new track from ``data`` and add it to the session.

        The caller is responsible for flushing/committing.
        """
        coordinates = list(data.coordinates)
        geometry = self._line_string(coordinates)
        track = Track(
            osm_id=data.osm_id,
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
            length_meters=data.length_meters,
            max_speed_kph=data.max_speed_kph,
            is_bidirectional=data.is_bidirectional,
            status=data.status,
            track_geometry=geometry,
        )
        self.session.add(track)
        return track

    def update(self, track: Track, payload: TrackUpdate) -> Track:
        """Apply the non-``None`` fields of ``payload`` to ``track`` in place.

        Fields left as ``None`` in the payload are not changed, so this method
        cannot clear a field to ``None``.
        """
        if payload.start_station_id is not None:
            track.start_station_id = self._ensure_uuid(
                payload.start_station_id)
        if payload.end_station_id is not None:
            track.end_station_id = self._ensure_uuid(payload.end_station_id)
        if payload.coordinates is not None:
            track.track_geometry = self._line_string(
                list(payload.coordinates))  # type: ignore[assignment]
        if payload.osm_id is not None:
            track.osm_id = payload.osm_id
        if payload.name is not None:
            track.name = payload.name
        if payload.length_meters is not None:
            track.length_meters = payload.length_meters
        if payload.max_speed_kph is not None:
            track.max_speed_kph = payload.max_speed_kph
        if payload.is_bidirectional is not None:
            track.is_bidirectional = payload.is_bidirectional
        if payload.status is not None:
            track.status = payload.status
        return track