- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks from the existing tracks between two stations.
- Added methods to check for existing combined tracks and to retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM IDs and track updates.
- Created migration scripts that add the combined tracks table and an OSM ID column to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
107 lines
3.0 KiB
Python
107 lines
3.0 KiB
Python
"""Service layer for primary track management operations."""

from __future__ import annotations

from itertools import combinations
from typing import Iterable

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
|
|
|
|
|
def list_tracks(session: Session) -> list[TrackModel]:
    """Return every track from the repository, converted to ``TrackModel``.

    Args:
        session: SQLAlchemy session used to build the ``TrackRepository``.

    Returns:
        One ``TrackModel`` per row, in the order ``list_all`` yields them.
    """
    records = TrackRepository(session).list_all()
    return list(map(TrackModel.model_validate, records))
|
|
|
|
|
|
def get_track(session: Session, track_id: str) -> TrackModel | None:
    """Look up a single track by ID.

    Args:
        session: SQLAlchemy session used to build the ``TrackRepository``.
        track_id: Identifier passed through to ``TrackRepository.get``.

    Returns:
        The track as a ``TrackModel``, or ``None`` when the repository finds
        no match.
    """
    record = TrackRepository(session).get(track_id)
    return None if record is None else TrackModel.model_validate(record)
|
|
|
|
|
|
def create_track(session: Session, payload: TrackCreate) -> TrackModel:
    """Insert a new track built from ``payload`` and commit it.

    Args:
        session: SQLAlchemy session the insert and commit run on.
        payload: Validated creation data handed to ``TrackRepository.create``.

    Returns:
        The newly created track as a ``TrackModel``.

    Raises:
        ValueError: When the insert or commit raises ``IntegrityError``; the
            session is rolled back first and the original error is chained.
            NOTE(review): every integrity violation is reported as a duplicate
            station pair, even if a different constraint fired — confirm.
    """
    track_repo = TrackRepository(session)
    try:
        created = track_repo.create(payload)
        session.commit()
    except IntegrityError as err:
        # Restore the session to a usable state before surfacing the error.
        session.rollback()
        raise ValueError("Track with the same station pair already exists") from err
    return TrackModel.model_validate(created)
|
|
|
|
|
|
def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
    """Apply ``payload`` to an existing track and commit the change.

    Args:
        session: SQLAlchemy session the update and commit run on.
        track_id: Identifier of the track to modify.
        payload: Update data handed to ``TrackRepository.update``.

    Returns:
        The updated track as a ``TrackModel``, or ``None`` when no track with
        ``track_id`` exists.

    Raises:
        IntegrityError: If the database rejects the update. The session is
            rolled back before the error propagates, matching the cleanup
            ``create_track`` performs.
    """
    repo = TrackRepository(session)
    track = repo.get(track_id)
    if track is None:
        return None

    try:
        repo.update(track, payload)
        session.commit()
    except IntegrityError:
        # After a failed flush/commit the session refuses further work until
        # it is rolled back; do that here so callers can keep using it.
        session.rollback()
        raise

    return TrackModel.model_validate(track)
|
|
|
|
|
|
def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
    """Delete a track and optionally rebuild combined tracks at its endpoints.

    Args:
        session: SQLAlchemy session the delete and commit run on.
        track_id: Identifier of the track to remove.
        regenerate: When true, ``regenerate_combined_tracks`` is called with
            the deleted track's start and end station IDs after the commit.

    Returns:
        ``True`` if the track existed and was deleted, ``False`` otherwise.
    """
    track = TrackRepository(session).get(track_id)
    if track is None:
        return False

    # Read the endpoints before deleting; once the deletion is committed the
    # instance's attributes may no longer be loadable.
    endpoints = [str(track.start_station_id), str(track.end_station_id)]

    session.delete(track)
    session.commit()

    if regenerate:
        regenerate_combined_tracks(session, endpoints)
    return True
|
|
|
|
|
|
def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
    """Delete and rebuild the combined tracks affected by changes at ``station_ids``.

    Every combined track that starts or ends at one of the given stations is
    deleted and the deletions are committed. Combined tracks are then
    recreated, via ``create_combined_track``, for:

    * each station pair whose combined track was just deleted (keeping the
      original start/end orientation), and
    * each pair of the given stations not already covered above.

    Rebuilding the deleted pairs is required: a combined track from an
    unaffected station C to an affected station A is removed by the first
    step and would otherwise vanish for good, even if C and A are still
    connected by tracks.

    Args:
        session: SQLAlchemy session used for the deletions and rebuilds.
        station_ids: Stations whose surrounding tracks changed. Values are
            compared as strings.

    Returns:
        The combined tracks ``create_combined_track`` returned (non-``None``
        results only), in rebuild order. Empty when ``station_ids`` is empty.
    """
    # Stored IDs are compared via str(), so normalize the input the same way.
    station_id_set = {str(station_id) for station_id in station_ids}
    if not station_id_set:
        return []

    combined_repo = CombinedTrackRepository(session)

    # Remove combined tracks touching these stations, remembering their
    # endpoints so they can be rebuilt afterwards.
    removed_pairs: list[tuple[str, str]] = []
    for combined in combined_repo.list_all():
        start_id = str(combined.start_station_id)
        end_id = str(combined.end_station_id)
        if {start_id, end_id} & station_id_set:
            removed_pairs.append((start_id, end_id))
            session.delete(combined)

    session.commit()

    # Imported lazily, presumably to avoid a circular import between the two
    # service modules — keep it local.
    from backend.app.services.combined_tracks import create_combined_track

    # De-duplicate on the unordered pair so (A, B) and (B, A) are rebuilt
    # only once; deleted pairs come first so their orientation wins.
    pairs_to_rebuild: list[tuple[str, str]] = []
    seen: set[frozenset[str]] = set()
    for pair in [*removed_pairs, *combinations(sorted(station_id_set), 2)]:
        key = frozenset(pair)
        if key in seen:
            continue
        seen.add(key)
        pairs_to_rebuild.append(pair)

    regenerated: list[CombinedTrackModel] = []
    for start_id, end_id in pairs_to_rebuild:
        result = create_combined_track(session, start_id, end_id)
        if result is not None:
            regenerated.append(result)
    return regenerated
|
|
|
|
|
|
# Names exported by a star-import of this module: read helpers first, then
# mutations, then the combined-track rebuild helper.
__all__ = [
    "list_tracks", "get_track",
    "create_track", "update_track", "delete_track",
    "regenerate_combined_tracks",
]
|