feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
166
backend/tests/test_combined_tracks.py
Normal file
166
backend/tests/test_combined_tracks.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.app.models import CombinedTrackModel
|
||||
from backend.app.repositories.combined_tracks import CombinedTrackRepository
|
||||
from backend.app.repositories.tracks import TrackRepository
|
||||
from backend.app.services.combined_tracks import create_combined_track
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySession:
|
||||
added: List[Any] = field(default_factory=list)
|
||||
scalars_result: List[Any] = field(default_factory=list)
|
||||
scalar_result: Any = None
|
||||
statements: List[Any] = field(default_factory=list)
|
||||
committed: bool = False
|
||||
rolled_back: bool = False
|
||||
closed: bool = False
|
||||
|
||||
def add(self, instance: Any) -> None:
|
||||
self.added.append(instance)
|
||||
|
||||
def add_all(self, instances: list[Any]) -> None:
|
||||
self.added.extend(instances)
|
||||
|
||||
def scalars(self, statement: Any) -> list[Any]:
|
||||
self.statements.append(statement)
|
||||
return list(self.scalars_result)
|
||||
|
||||
def scalar(self, statement: Any) -> Any:
|
||||
self.statements.append(statement)
|
||||
return self.scalar_result
|
||||
|
||||
def flush(
|
||||
self, _objects: list[Any] | None = None
|
||||
) -> None: # pragma: no cover - optional
|
||||
return None
|
||||
|
||||
def commit(self) -> None: # pragma: no cover - optional
|
||||
self.committed = True
|
||||
|
||||
def rollback(self) -> None: # pragma: no cover - optional
|
||||
self.rolled_back = True
|
||||
|
||||
def close(self) -> None: # pragma: no cover - optional
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def test_combined_track_model_round_trip() -> None:
|
||||
timestamp = _now()
|
||||
combined_track = CombinedTrackModel(
|
||||
id="combined-track-1",
|
||||
start_station_id="station-1",
|
||||
end_station_id="station-2",
|
||||
length_meters=3000.0,
|
||||
max_speed_kph=100,
|
||||
status="operational",
|
||||
is_bidirectional=True,
|
||||
coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)],
|
||||
constituent_track_ids=["track-1", "track-2"],
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
assert combined_track.length_meters == 3000.0
|
||||
assert combined_track.start_station_id != combined_track.end_station_id
|
||||
assert len(combined_track.coordinates) == 3
|
||||
assert len(combined_track.constituent_track_ids) == 2
|
||||
|
||||
|
||||
def test_combined_track_repository_create() -> None:
|
||||
"""Test creating a combined track through the repository."""
|
||||
session = DummySession()
|
||||
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
|
||||
|
||||
# Create test data
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
|
||||
create_data = CombinedTrackCreate(
|
||||
start_station_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
end_station_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
coordinates=[(52.52, 13.405), (52.6, 13.5)],
|
||||
constituent_track_ids=["track-1"],
|
||||
length_meters=1500.0,
|
||||
max_speed_kph=120,
|
||||
status="operational",
|
||||
)
|
||||
|
||||
combined_track = repo.create(create_data)
|
||||
|
||||
assert combined_track.start_station_id is not None
|
||||
assert combined_track.end_station_id is not None
|
||||
assert combined_track.length_meters == 1500.0
|
||||
assert combined_track.max_speed_kph == 120
|
||||
assert combined_track.status == "operational"
|
||||
assert session.added and session.added[0] is combined_track
|
||||
|
||||
|
||||
def test_combined_track_repository_exists_between_stations() -> None:
|
||||
"""Test checking if combined track exists between stations."""
|
||||
session = DummySession()
|
||||
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
|
||||
|
||||
# Initially should not exist (scalar_result is None by default)
|
||||
assert not repo.exists_between_stations(
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
|
||||
# Simulate existing combined track
|
||||
session.scalar_result = True
|
||||
assert repo.exists_between_stations(
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
|
||||
|
||||
def test_combined_track_service_create_no_path() -> None:
|
||||
"""Test creating combined track when no path exists."""
|
||||
# Mock session and repositories
|
||||
session = DummySession()
|
||||
|
||||
# Mock TrackRepository to return no path
|
||||
class MockTrackRepository:
|
||||
def __init__(self, session):
|
||||
pass
|
||||
|
||||
def find_path_between_stations(self, start_id, end_id):
|
||||
return None
|
||||
|
||||
# Mock CombinedTrackRepository
|
||||
class MockCombinedTrackRepository:
|
||||
def __init__(self, session):
|
||||
pass
|
||||
|
||||
def exists_between_stations(self, start_id, end_id):
|
||||
return False
|
||||
|
||||
# Patch the service to use mock repositories
|
||||
import backend.app.services.combined_tracks as service_module
|
||||
original_track_repo = service_module.TrackRepository
|
||||
original_combined_repo = service_module.CombinedTrackRepository
|
||||
|
||||
service_module.TrackRepository = MockTrackRepository
|
||||
service_module.CombinedTrackRepository = MockCombinedTrackRepository
|
||||
|
||||
try:
|
||||
result = create_combined_track(
|
||||
session, # type: ignore[arg-type]
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
assert result is None
|
||||
finally:
|
||||
# Restore original classes
|
||||
service_module.TrackRepository = original_track_repo
|
||||
service_module.CombinedTrackRepository = original_combined_repo
|
||||
Reference in New Issue
Block a user