feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
79
backend/app/services/combined_tracks.py
Normal file
79
backend/app/services/combined_tracks.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Application services for combined track operations."""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackCreate, CombinedTrackModel
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def create_combined_track(
|
||||
session: Session, start_station_id: str, end_station_id: str
|
||||
) -> CombinedTrackModel | None:
|
||||
"""Create a combined track between two stations using pathfinding.
|
||||
|
||||
Returns the created combined track, or None if no path exists or
|
||||
a combined track already exists between these stations.
|
||||
"""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
track_repo = TrackRepository(session)
|
||||
|
||||
# Check if combined track already exists
|
||||
if combined_track_repo.exists_between_stations(start_station_id, end_station_id):
|
||||
return None
|
||||
|
||||
# Find path between stations
|
||||
path_tracks = track_repo.find_path_between_stations(
|
||||
start_station_id, end_station_id)
|
||||
if not path_tracks:
|
||||
return None
|
||||
|
||||
# Combine geometries
|
||||
combined_coords = track_repo.combine_track_geometries(path_tracks)
|
||||
if len(combined_coords) < 2:
|
||||
return None
|
||||
|
||||
# Calculate total length
|
||||
total_length = sum(track.length_meters or 0 for track in path_tracks)
|
||||
|
||||
# Get max speed (use the minimum speed of all tracks)
|
||||
max_speeds = [
|
||||
track.max_speed_kph for track in path_tracks if track.max_speed_kph]
|
||||
max_speed = min(max_speeds) if max_speeds else None
|
||||
|
||||
# Get constituent track IDs
|
||||
constituent_track_ids = [str(track.id) for track in path_tracks]
|
||||
|
||||
# Create combined track
|
||||
create_data = CombinedTrackCreate(
|
||||
start_station_id=start_station_id,
|
||||
end_station_id=end_station_id,
|
||||
coordinates=combined_coords,
|
||||
constituent_track_ids=constituent_track_ids,
|
||||
length_meters=total_length if total_length > 0 else None,
|
||||
max_speed_kph=max_speed,
|
||||
status="operational",
|
||||
)
|
||||
|
||||
combined_track = combined_track_repo.create(create_data)
|
||||
session.commit()
|
||||
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
|
||||
|
||||
def get_combined_track(session: Session, combined_track_id: str) -> CombinedTrackModel | None:
|
||||
"""Get a combined track by ID."""
|
||||
try:
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_track = combined_track_repo.get(combined_track_id)
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
except LookupError:
|
||||
return None
|
||||
|
||||
|
||||
def list_combined_tracks(session: Session) -> list[CombinedTrackModel]:
|
||||
"""List all combined tracks."""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_tracks = combined_track_repo.list_all()
|
||||
return [CombinedTrackModel.model_validate(ct) for ct in combined_tracks]
|
||||
106
backend/app/services/tracks.py
Normal file
106
backend/app/services/tracks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Service layer for primary track management operations."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def list_tracks(session: Session) -> list[TrackModel]:
|
||||
repo = TrackRepository(session)
|
||||
tracks = repo.list_all()
|
||||
return [TrackModel.model_validate(track) for track in tracks]
|
||||
|
||||
|
||||
def get_track(session: Session, track_id: str) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def create_track(session: Session, payload: TrackCreate) -> TrackModel:
|
||||
repo = TrackRepository(session)
|
||||
try:
|
||||
track = repo.create(payload)
|
||||
session.commit()
|
||||
except IntegrityError as exc:
|
||||
session.rollback()
|
||||
raise ValueError(
|
||||
"Track with the same station pair already exists") from exc
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
|
||||
repo.update(track, payload)
|
||||
session.commit()
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return False
|
||||
|
||||
start_station_id = str(track.start_station_id)
|
||||
end_station_id = str(track.end_station_id)
|
||||
|
||||
session.delete(track)
|
||||
session.commit()
|
||||
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(session, [start_station_id, end_station_id])
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
|
||||
combined_repo = CombinedTrackRepository(session)
|
||||
|
||||
station_id_set = set(station_ids)
|
||||
if not station_id_set:
|
||||
return []
|
||||
|
||||
# Remove combined tracks touching these stations
|
||||
for combined in combined_repo.list_all():
|
||||
if {str(combined.start_station_id), str(combined.end_station_id)} & station_id_set:
|
||||
session.delete(combined)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Rebuild combined tracks between affected station pairs
|
||||
from backend.app.services.combined_tracks import create_combined_track
|
||||
|
||||
regenerated: list[CombinedTrackModel] = []
|
||||
station_list = list(station_id_set)
|
||||
for i in range(len(station_list)):
|
||||
for j in range(i + 1, len(station_list)):
|
||||
result = create_combined_track(
|
||||
session, station_list[i], station_list[j])
|
||||
if result is not None:
|
||||
regenerated.append(result)
|
||||
return regenerated
|
||||
|
||||
|
||||
__all__ = [
|
||||
"list_tracks",
|
||||
"get_track",
|
||||
"create_track",
|
||||
"update_track",
|
||||
"delete_track",
|
||||
"regenerate_combined_tracks",
|
||||
]
|
||||
Reference in New Issue
Block a user