- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
73 lines
2.7 KiB
Python
73 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from uuid import UUID
|
|
|
|
import sqlalchemy as sa
|
|
from geoalchemy2.elements import WKTElement
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.app.db.models import CombinedTrack
|
|
from backend.app.models import CombinedTrackCreate
|
|
from backend.app.repositories.base import BaseRepository
|
|
|
|
|
|
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
    """Repository for querying and creating ``CombinedTrack`` rows.

    All database access goes through the SQLAlchemy session handed to the
    constructor; nothing in this class flushes or commits.
    """

    model = CombinedTrack

    def __init__(self, session: Session) -> None:
        """Bind the repository to *session* via the base repository."""
        super().__init__(session)

    def list_all(self) -> list[CombinedTrack]:
        """Return every combined track selected from the model, as a list."""
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_between_stations(
        self,
        start_station_id: UUID | str,
        end_station_id: UUID | str,
    ) -> bool:
        """Check if a combined track already exists between two stations.

        Only the given direction is matched (``start_station_id`` as start,
        ``end_station_id`` as end); the reversed pair is not looked up.

        Args:
            start_station_id: Start station identifier, as a UUID or its
                string form.
            end_station_id: End station identifier, as a UUID or its string
                form.

        Returns:
            True when at least one matching row exists, otherwise False.

        Raises:
            ValueError: If an identifier is not a valid UUID string.
        """
        # Normalise the identifiers the same way create() does, so lookups
        # and inserts bind the same Python type to the station-ID columns.
        start_id = self._ensure_uuid(start_station_id)
        end_id = self._ensure_uuid(end_station_id)

        statement = sa.select(
            sa.exists().where(
                sa.and_(
                    self.model.start_station_id == start_id,
                    self.model.end_station_id == end_id,
                )
            )
        )
        return bool(self.session.scalar(statement))

    def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
        """Extract constituent track IDs from a combined track.

        The IDs are read from the JSON text stored in
        ``combined_track.constituent_track_ids``.

        Returns:
            The decoded list, or an empty list when the stored value is
            missing, is not valid JSON, or does not decode to a JSON array.
        """
        try:
            decoded = json.loads(combined_track.constituent_track_ids)
        except (json.JSONDecodeError, TypeError):
            # Best-effort read: unreadable payloads count as "no tracks".
            return []
        # Valid JSON that is not an array (``null``, an object, a number)
        # must not leak out of a method that promises a list.
        return decoded if isinstance(decoded, list) else []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        """Return *value* as a ``UUID``, parsing its string form if needed.

        Raises:
            ValueError: If the string form of *value* is not a valid UUID.
        """
        if isinstance(value, UUID):
            return value
        return UUID(str(value))

    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        """Build a ``LINESTRING`` WKT element (SRID 4326) from coordinate pairs.

        Each pair is unpacked as ``(lat, lon)`` and written as ``lon lat``.

        Raises:
            ValueError: If fewer than two coordinate pairs are supplied.
        """
        if len(coordinates) < 2:
            raise ValueError(
                "Combined track geometry requires at least two coordinate pairs")
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

    def create(self, data: CombinedTrackCreate) -> CombinedTrack:
        """Build a ``CombinedTrack`` from *data* and add it to the session.

        The object is only added to the session; no flush or commit is
        issued here.

        Args:
            data: Creation payload supplying the name, station IDs, length,
                speed, direction flag, status, coordinates and constituent
                track IDs.

        Returns:
            The newly constructed ``CombinedTrack`` instance.

        Raises:
            ValueError: If fewer than two coordinate pairs are supplied or a
                station identifier is not a valid UUID string.
        """
        geometry = self._line_string(list(data.coordinates))

        # Stringify each ID before encoding: json.dumps cannot serialise
        # UUID objects, and str() is a no-op for IDs that are already text.
        constituent_track_ids_json = json.dumps(
            [str(track_id) for track_id in (data.constituent_track_ids or ())]
        )

        combined_track = CombinedTrack(
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
            length_meters=data.length_meters,
            max_speed_kph=data.max_speed_kph,
            is_bidirectional=data.is_bidirectional,
            status=data.status,
            combined_geometry=geometry,
            constituent_track_ids=constituent_track_ids_json,
        )
        self.session.add(combined_track)
        return combined_track