feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
@@ -4,9 +4,11 @@ from backend.app.api.auth import router as auth_router
|
||||
from backend.app.api.health import router as health_router
|
||||
from backend.app.api.network import router as network_router
|
||||
from backend.app.api.stations import router as stations_router
|
||||
from backend.app.api.tracks import router as tracks_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
router.include_router(auth_router)
|
||||
router.include_router(network_router)
|
||||
router.include_router(stations_router)
|
||||
router.include_router(tracks_router)
|
||||
|
||||
153
backend/app/api/tracks.py
Normal file
153
backend/app/api/tracks.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.api.deps import get_current_user, get_db
|
||||
from backend.app.models import (
|
||||
CombinedTrackModel,
|
||||
TrackCreate,
|
||||
TrackUpdate,
|
||||
TrackModel,
|
||||
UserPublic,
|
||||
)
|
||||
from backend.app.services.combined_tracks import (
|
||||
create_combined_track,
|
||||
get_combined_track,
|
||||
list_combined_tracks,
|
||||
)
|
||||
from backend.app.services.tracks import (
|
||||
create_track,
|
||||
delete_track,
|
||||
regenerate_combined_tracks,
|
||||
update_track,
|
||||
get_track,
|
||||
list_tracks,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tracks", tags=["tracks"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TrackModel])
|
||||
def read_combined_tracks(
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[TrackModel]:
|
||||
"""Return all base tracks."""
|
||||
return list_tracks(db)
|
||||
|
||||
|
||||
@router.get("/combined", response_model=list[CombinedTrackModel])
|
||||
def read_combined_tracks_combined(
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[CombinedTrackModel]:
|
||||
return list_combined_tracks(db)
|
||||
|
||||
|
||||
@router.get("/{track_id}", response_model=TrackModel)
|
||||
def read_track(
|
||||
track_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
track = get_track(db, track_id)
|
||||
if track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
return track
|
||||
|
||||
|
||||
@router.get("/combined/{combined_track_id}", response_model=CombinedTrackModel)
|
||||
def read_combined_track(
|
||||
combined_track_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> CombinedTrackModel:
|
||||
combined_track = get_combined_track(db, combined_track_id)
|
||||
if combined_track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Combined track {combined_track_id} not found",
|
||||
)
|
||||
return combined_track
|
||||
|
||||
|
||||
@router.post("", response_model=TrackModel, status_code=status.HTTP_201_CREATED)
|
||||
def create_track_endpoint(
|
||||
payload: TrackCreate,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
try:
|
||||
track = create_track(db, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(
|
||||
db, [track.start_station_id, track.end_station_id])
|
||||
|
||||
return track
|
||||
|
||||
|
||||
@router.post(
|
||||
"/combined",
|
||||
response_model=CombinedTrackModel,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a combined track between two stations using pathfinding",
|
||||
)
|
||||
def create_combined_track_endpoint(
|
||||
start_station_id: str,
|
||||
end_station_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> CombinedTrackModel:
|
||||
combined_track = create_combined_track(
|
||||
db, start_station_id, end_station_id)
|
||||
if combined_track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Could not create combined track: no path exists between stations or track already exists",
|
||||
)
|
||||
return combined_track
|
||||
|
||||
|
||||
@router.put("/{track_id}", response_model=TrackModel)
|
||||
def update_track_endpoint(
|
||||
track_id: str,
|
||||
payload: TrackUpdate,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
track = update_track(db, track_id, payload)
|
||||
if track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(
|
||||
db, [track.start_station_id, track.end_station_id])
|
||||
return track
|
||||
|
||||
|
||||
@router.delete("/{track_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_track_endpoint(
|
||||
track_id: str,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> None:
|
||||
deleted = delete_track(db, track_id, regenerate=regenerate)
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
@@ -41,11 +41,14 @@ class User(Base, TimestampMixin):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False)
|
||||
email: Mapped[str | None] = mapped_column(
|
||||
String(255), unique=True, nullable=True)
|
||||
full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), nullable=False, default="player")
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="player")
|
||||
preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
|
||||
@@ -62,12 +65,50 @@ class Station(Base, TimestampMixin):
|
||||
Geometry(geometry_type="POINT", srid=4326), nullable=False
|
||||
)
|
||||
elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class Track(Base, TimestampMixin):
|
||||
__tablename__ = "tracks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
start_station_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
end_station_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
length_meters: Mapped[float | None] = mapped_column(
|
||||
Numeric(10, 2), nullable=True)
|
||||
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
is_bidirectional: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="planned")
|
||||
track_geometry: Mapped[str] = mapped_column(
|
||||
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CombinedTrack(Base, TimestampMixin):
|
||||
__tablename__ = "combined_tracks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
@@ -82,19 +123,25 @@ class Track(Base, TimestampMixin):
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
|
||||
length_meters: Mapped[float | None] = mapped_column(
|
||||
Numeric(10, 2), nullable=True)
|
||||
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
is_bidirectional: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
|
||||
track_geometry: Mapped[str] = mapped_column(
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="planned")
|
||||
combined_geometry: Mapped[str] = mapped_column(
|
||||
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
|
||||
)
|
||||
# JSON array of constituent track IDs
|
||||
constituent_track_ids: Mapped[str] = mapped_column(
|
||||
Text, nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
|
||||
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,7 +152,8 @@ class Train(Base, TimestampMixin):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
designation: Mapped[str] = mapped_column(
|
||||
String(64), nullable=False, unique=True)
|
||||
operator_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
|
||||
)
|
||||
|
||||
@@ -8,11 +8,14 @@ from .auth import (
|
||||
UserPublic,
|
||||
)
|
||||
from .base import (
|
||||
CombinedTrackCreate,
|
||||
CombinedTrackModel,
|
||||
StationCreate,
|
||||
StationModel,
|
||||
StationUpdate,
|
||||
TrackCreate,
|
||||
TrackModel,
|
||||
TrackUpdate,
|
||||
TrainCreate,
|
||||
TrainModel,
|
||||
TrainScheduleCreate,
|
||||
@@ -33,9 +36,12 @@ __all__ = [
|
||||
"StationUpdate",
|
||||
"TrackCreate",
|
||||
"TrackModel",
|
||||
"TrackUpdate",
|
||||
"TrainScheduleCreate",
|
||||
"TrainCreate",
|
||||
"TrainModel",
|
||||
"UserCreate",
|
||||
"to_camel",
|
||||
"CombinedTrackCreate",
|
||||
"CombinedTrackModel",
|
||||
]
|
||||
|
||||
@@ -51,13 +51,24 @@ class StationModel(IdentifiedModel[str]):
|
||||
class TrackModel(IdentifiedModel[str]):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
length_meters: float
|
||||
max_speed_kph: float
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: float | None = None
|
||||
status: str | None = None
|
||||
is_bidirectional: bool = True
|
||||
coordinates: list[tuple[float, float]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CombinedTrackModel(IdentifiedModel[str]):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
status: str | None = None
|
||||
is_bidirectional: bool = True
|
||||
coordinates: list[tuple[float, float]] = Field(default_factory=list)
|
||||
constituent_track_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TrainModel(IdentifiedModel[str]):
|
||||
designation: str
|
||||
capacity: int
|
||||
@@ -89,6 +100,31 @@ class TrackCreate(CamelModel):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
coordinates: Sequence[tuple[float, float]]
|
||||
osm_id: str | None = None
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
is_bidirectional: bool = True
|
||||
status: str = "planned"
|
||||
|
||||
|
||||
class TrackUpdate(CamelModel):
|
||||
start_station_id: str | None = None
|
||||
end_station_id: str | None = None
|
||||
coordinates: Sequence[tuple[float, float]] | None = None
|
||||
osm_id: str | None = None
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
is_bidirectional: bool | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class CombinedTrackCreate(CamelModel):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
coordinates: Sequence[tuple[float, float]]
|
||||
constituent_track_ids: list[str]
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from backend.app.repositories.stations import StationRepository
|
||||
from backend.app.repositories.tracks import TrackRepository
|
||||
from backend.app.repositories.combined_tracks import CombinedTrackRepository
|
||||
from backend.app.repositories.train_schedules import TrainScheduleRepository
|
||||
from backend.app.repositories.trains import TrainRepository
|
||||
from backend.app.repositories.users import UserRepository
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"StationRepository",
|
||||
"TrainScheduleRepository",
|
||||
"TrackRepository",
|
||||
"CombinedTrackRepository",
|
||||
"TrainRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
|
||||
73
backend/app/repositories/combined_tracks.py
Normal file
73
backend/app/repositories/combined_tracks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import CombinedTrack
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
|
||||
model = CombinedTrack
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
super().__init__(session)
|
||||
|
||||
def list_all(self) -> list[CombinedTrack]:
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
|
||||
"""Check if a combined track already exists between two stations."""
|
||||
statement = sa.select(sa.exists().where(
|
||||
sa.and_(
|
||||
self.model.start_station_id == start_station_id,
|
||||
self.model.end_station_id == end_station_id
|
||||
)
|
||||
))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
|
||||
"""Extract constituent track IDs from a combined track."""
|
||||
try:
|
||||
return json.loads(combined_track.constituent_track_ids)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError(
|
||||
"Combined track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
def create(self, data: CombinedTrackCreate) -> CombinedTrack:
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
constituent_track_ids_json = json.dumps(data.constituent_track_ids)
|
||||
|
||||
combined_track = CombinedTrack(
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
length_meters=data.length_meters,
|
||||
max_speed_kph=data.max_speed_kph,
|
||||
is_bidirectional=data.is_bidirectional,
|
||||
status=data.status,
|
||||
combined_geometry=geometry,
|
||||
constituent_track_ids=constituent_track_ids_json,
|
||||
)
|
||||
self.session.add(combined_track)
|
||||
return combined_track
|
||||
@@ -7,7 +7,7 @@ from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import Track
|
||||
from backend.app.models import TrackCreate
|
||||
from backend.app.models import TrackCreate, TrackUpdate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
@@ -21,6 +21,102 @@ class TrackRepository(BaseRepository[Track]):
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_by_osm_id(self, osm_id: str) -> bool:
|
||||
statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None:
|
||||
"""Find the shortest path between two stations using existing tracks.
|
||||
|
||||
Returns a list of tracks that form the path, or None if no path exists.
|
||||
"""
|
||||
# Build adjacency list: station -> list of (neighbor_station, track)
|
||||
adjacency = self._build_track_graph()
|
||||
|
||||
if start_station_id not in adjacency or end_station_id not in adjacency:
|
||||
return None
|
||||
|
||||
# BFS to find shortest path
|
||||
from collections import deque
|
||||
|
||||
# (current_station, path_so_far)
|
||||
queue = deque([(start_station_id, [])])
|
||||
visited = set([start_station_id])
|
||||
|
||||
while queue:
|
||||
current_station, path = queue.popleft()
|
||||
|
||||
if current_station == end_station_id:
|
||||
return path
|
||||
|
||||
for neighbor, track in adjacency[current_station]:
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [track]))
|
||||
|
||||
return None # No path found
|
||||
|
||||
def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
|
||||
"""Build a graph representation of tracks: station -> [(neighbor_station, track), ...]"""
|
||||
tracks = self.list_all()
|
||||
graph = {}
|
||||
|
||||
for track in tracks:
|
||||
start_id = str(track.start_station_id)
|
||||
end_id = str(track.end_station_id)
|
||||
|
||||
# Add bidirectional edges (assuming tracks are bidirectional)
|
||||
if start_id not in graph:
|
||||
graph[start_id] = []
|
||||
if end_id not in graph:
|
||||
graph[end_id] = []
|
||||
|
||||
graph[start_id].append((end_id, track))
|
||||
graph[end_id].append((start_id, track))
|
||||
|
||||
return graph
|
||||
|
||||
def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
|
||||
"""Combine the geometries of multiple tracks into a single coordinate sequence.
|
||||
|
||||
Assumes tracks are in order and form a continuous path.
|
||||
"""
|
||||
if not tracks:
|
||||
return []
|
||||
|
||||
combined_coords = []
|
||||
|
||||
for i, track in enumerate(tracks):
|
||||
# Extract coordinates from track geometry
|
||||
coords = self._extract_coordinates_from_track(track)
|
||||
|
||||
if i == 0:
|
||||
# First track: add all coordinates
|
||||
combined_coords.extend(coords)
|
||||
else:
|
||||
# Subsequent tracks: skip the first coordinate (shared with previous track)
|
||||
combined_coords.extend(coords[1:])
|
||||
|
||||
return combined_coords
|
||||
|
||||
def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
|
||||
"""Extract coordinate list from a track's geometry."""
|
||||
# Convert WKT string to WKTElement, then to shapely geometry
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from geoalchemy2.shape import to_shape
|
||||
|
||||
try:
|
||||
wkt_element = WKTElement(track.track_geometry)
|
||||
geom = to_shape(wkt_element)
|
||||
if hasattr(geom, 'coords'):
|
||||
# For LineString, coords returns [(x, y), ...] where x=lon, y=lat
|
||||
# Convert to (lat, lon)
|
||||
return [(coord[1], coord[0]) for coord in geom.coords]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
@@ -30,7 +126,8 @@ class TrackRepository(BaseRepository[Track]):
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError("Track geometry requires at least two coordinate pairs")
|
||||
raise ValueError(
|
||||
"Track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
@@ -38,6 +135,7 @@ class TrackRepository(BaseRepository[Track]):
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
track = Track(
|
||||
osm_id=data.osm_id,
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
@@ -49,3 +147,26 @@ class TrackRepository(BaseRepository[Track]):
|
||||
)
|
||||
self.session.add(track)
|
||||
return track
|
||||
|
||||
def update(self, track: Track, payload: TrackUpdate) -> Track:
|
||||
if payload.start_station_id is not None:
|
||||
track.start_station_id = self._ensure_uuid(
|
||||
payload.start_station_id)
|
||||
if payload.end_station_id is not None:
|
||||
track.end_station_id = self._ensure_uuid(payload.end_station_id)
|
||||
if payload.coordinates is not None:
|
||||
track.track_geometry = self._line_string(
|
||||
list(payload.coordinates)) # type: ignore[assignment]
|
||||
if payload.osm_id is not None:
|
||||
track.osm_id = payload.osm_id
|
||||
if payload.name is not None:
|
||||
track.name = payload.name
|
||||
if payload.length_meters is not None:
|
||||
track.length_meters = payload.length_meters
|
||||
if payload.max_speed_kph is not None:
|
||||
track.max_speed_kph = payload.max_speed_kph
|
||||
if payload.is_bidirectional is not None:
|
||||
track.is_bidirectional = payload.is_bidirectional
|
||||
if payload.status is not None:
|
||||
track.status = payload.status
|
||||
return track
|
||||
|
||||
79
backend/app/services/combined_tracks.py
Normal file
79
backend/app/services/combined_tracks.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Application services for combined track operations."""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackCreate, CombinedTrackModel
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def create_combined_track(
|
||||
session: Session, start_station_id: str, end_station_id: str
|
||||
) -> CombinedTrackModel | None:
|
||||
"""Create a combined track between two stations using pathfinding.
|
||||
|
||||
Returns the created combined track, or None if no path exists or
|
||||
a combined track already exists between these stations.
|
||||
"""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
track_repo = TrackRepository(session)
|
||||
|
||||
# Check if combined track already exists
|
||||
if combined_track_repo.exists_between_stations(start_station_id, end_station_id):
|
||||
return None
|
||||
|
||||
# Find path between stations
|
||||
path_tracks = track_repo.find_path_between_stations(
|
||||
start_station_id, end_station_id)
|
||||
if not path_tracks:
|
||||
return None
|
||||
|
||||
# Combine geometries
|
||||
combined_coords = track_repo.combine_track_geometries(path_tracks)
|
||||
if len(combined_coords) < 2:
|
||||
return None
|
||||
|
||||
# Calculate total length
|
||||
total_length = sum(track.length_meters or 0 for track in path_tracks)
|
||||
|
||||
# Get max speed (use the minimum speed of all tracks)
|
||||
max_speeds = [
|
||||
track.max_speed_kph for track in path_tracks if track.max_speed_kph]
|
||||
max_speed = min(max_speeds) if max_speeds else None
|
||||
|
||||
# Get constituent track IDs
|
||||
constituent_track_ids = [str(track.id) for track in path_tracks]
|
||||
|
||||
# Create combined track
|
||||
create_data = CombinedTrackCreate(
|
||||
start_station_id=start_station_id,
|
||||
end_station_id=end_station_id,
|
||||
coordinates=combined_coords,
|
||||
constituent_track_ids=constituent_track_ids,
|
||||
length_meters=total_length if total_length > 0 else None,
|
||||
max_speed_kph=max_speed,
|
||||
status="operational",
|
||||
)
|
||||
|
||||
combined_track = combined_track_repo.create(create_data)
|
||||
session.commit()
|
||||
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
|
||||
|
||||
def get_combined_track(session: Session, combined_track_id: str) -> CombinedTrackModel | None:
|
||||
"""Get a combined track by ID."""
|
||||
try:
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_track = combined_track_repo.get(combined_track_id)
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
except LookupError:
|
||||
return None
|
||||
|
||||
|
||||
def list_combined_tracks(session: Session) -> list[CombinedTrackModel]:
|
||||
"""List all combined tracks."""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_tracks = combined_track_repo.list_all()
|
||||
return [CombinedTrackModel.model_validate(ct) for ct in combined_tracks]
|
||||
106
backend/app/services/tracks.py
Normal file
106
backend/app/services/tracks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Service layer for primary track management operations."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def list_tracks(session: Session) -> list[TrackModel]:
|
||||
repo = TrackRepository(session)
|
||||
tracks = repo.list_all()
|
||||
return [TrackModel.model_validate(track) for track in tracks]
|
||||
|
||||
|
||||
def get_track(session: Session, track_id: str) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def create_track(session: Session, payload: TrackCreate) -> TrackModel:
|
||||
repo = TrackRepository(session)
|
||||
try:
|
||||
track = repo.create(payload)
|
||||
session.commit()
|
||||
except IntegrityError as exc:
|
||||
session.rollback()
|
||||
raise ValueError(
|
||||
"Track with the same station pair already exists") from exc
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
|
||||
repo.update(track, payload)
|
||||
session.commit()
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return False
|
||||
|
||||
start_station_id = str(track.start_station_id)
|
||||
end_station_id = str(track.end_station_id)
|
||||
|
||||
session.delete(track)
|
||||
session.commit()
|
||||
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(session, [start_station_id, end_station_id])
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
|
||||
combined_repo = CombinedTrackRepository(session)
|
||||
|
||||
station_id_set = set(station_ids)
|
||||
if not station_id_set:
|
||||
return []
|
||||
|
||||
# Remove combined tracks touching these stations
|
||||
for combined in combined_repo.list_all():
|
||||
if {str(combined.start_station_id), str(combined.end_station_id)} & station_id_set:
|
||||
session.delete(combined)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Rebuild combined tracks between affected station pairs
|
||||
from backend.app.services.combined_tracks import create_combined_track
|
||||
|
||||
regenerated: list[CombinedTrackModel] = []
|
||||
station_list = list(station_id_set)
|
||||
for i in range(len(station_list)):
|
||||
for j in range(i + 1, len(station_list)):
|
||||
result = create_combined_track(
|
||||
session, station_list[i], station_list[j])
|
||||
if result is not None:
|
||||
regenerated.append(result)
|
||||
return regenerated
|
||||
|
||||
|
||||
__all__ = [
|
||||
"list_tracks",
|
||||
"get_track",
|
||||
"create_track",
|
||||
"update_track",
|
||||
"delete_track",
|
||||
"regenerate_combined_tracks",
|
||||
]
|
||||
Reference in New Issue
Block a user