feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
73
backend/app/repositories/combined_tracks.py
Normal file
73
backend/app/repositories/combined_tracks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import CombinedTrack
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
|
||||
model = CombinedTrack
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
super().__init__(session)
|
||||
|
||||
def list_all(self) -> list[CombinedTrack]:
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
|
||||
"""Check if a combined track already exists between two stations."""
|
||||
statement = sa.select(sa.exists().where(
|
||||
sa.and_(
|
||||
self.model.start_station_id == start_station_id,
|
||||
self.model.end_station_id == end_station_id
|
||||
)
|
||||
))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
|
||||
"""Extract constituent track IDs from a combined track."""
|
||||
try:
|
||||
return json.loads(combined_track.constituent_track_ids)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError(
|
||||
"Combined track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
def create(self, data: CombinedTrackCreate) -> CombinedTrack:
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
constituent_track_ids_json = json.dumps(data.constituent_track_ids)
|
||||
|
||||
combined_track = CombinedTrack(
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
length_meters=data.length_meters,
|
||||
max_speed_kph=data.max_speed_kph,
|
||||
is_bidirectional=data.is_bidirectional,
|
||||
status=data.status,
|
||||
combined_geometry=geometry,
|
||||
constituent_track_ids=constituent_track_ids_json,
|
||||
)
|
||||
self.session.add(combined_track)
|
||||
return combined_track
|
||||
Reference in New Issue
Block a user