from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, List from uuid import uuid4 import pytest from backend.app.models import CombinedTrackModel from backend.app.repositories.combined_tracks import CombinedTrackRepository from backend.app.repositories.tracks import TrackRepository from backend.app.services.combined_tracks import create_combined_track @dataclass class DummySession: added: List[Any] = field(default_factory=list) scalars_result: List[Any] = field(default_factory=list) scalar_result: Any = None statements: List[Any] = field(default_factory=list) committed: bool = False rolled_back: bool = False closed: bool = False def add(self, instance: Any) -> None: self.added.append(instance) def add_all(self, instances: list[Any]) -> None: self.added.extend(instances) def scalars(self, statement: Any) -> list[Any]: self.statements.append(statement) return list(self.scalars_result) def scalar(self, statement: Any) -> Any: self.statements.append(statement) return self.scalar_result def flush( self, _objects: list[Any] | None = None ) -> None: # pragma: no cover - optional return None def commit(self) -> None: # pragma: no cover - optional self.committed = True def rollback(self) -> None: # pragma: no cover - optional self.rolled_back = True def close(self) -> None: # pragma: no cover - optional self.closed = True def _now() -> datetime: return datetime.now(timezone.utc) def test_combined_track_model_round_trip() -> None: timestamp = _now() combined_track = CombinedTrackModel( id="combined-track-1", start_station_id="station-1", end_station_id="station-2", length_meters=3000.0, max_speed_kph=100, status="operational", is_bidirectional=True, coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)], constituent_track_ids=["track-1", "track-2"], created_at=timestamp, updated_at=timestamp, ) assert combined_track.length_meters == 3000.0 assert combined_track.start_station_id != combined_track.end_station_id assert len(combined_track.coordinates) == 3 assert len(combined_track.constituent_track_ids) == 2 def test_combined_track_repository_create() -> None: """Test creating a combined track through the repository.""" session = DummySession() repo = CombinedTrackRepository(session) # type: ignore[arg-type] # Create test data from backend.app.models import CombinedTrackCreate create_data = CombinedTrackCreate( start_station_id="550e8400-e29b-41d4-a716-446655440000", end_station_id="550e8400-e29b-41d4-a716-446655440001", coordinates=[(52.52, 13.405), (52.6, 13.5)], constituent_track_ids=["track-1"], length_meters=1500.0, max_speed_kph=120, status="operational", ) combined_track = repo.create(create_data) assert combined_track.start_station_id is not None assert combined_track.end_station_id is not None assert combined_track.length_meters == 1500.0 assert combined_track.max_speed_kph == 120 assert combined_track.status == "operational" assert session.added and session.added[0] is combined_track def test_combined_track_repository_exists_between_stations() -> None: """Test checking if combined track exists between stations.""" session = DummySession() repo = CombinedTrackRepository(session) # type: ignore[arg-type] # Initially should not exist (scalar_result is None by default) assert not repo.exists_between_stations( "550e8400-e29b-41d4-a716-446655440000", "550e8400-e29b-41d4-a716-446655440001" ) # Simulate existing combined track session.scalar_result = True assert repo.exists_between_stations( "550e8400-e29b-41d4-a716-446655440000", "550e8400-e29b-41d4-a716-446655440001" ) def test_combined_track_service_create_no_path() -> None: """Test creating combined track when no path exists.""" # Mock session and repositories session = DummySession() # Mock TrackRepository to return no path class MockTrackRepository: def __init__(self, session): pass def find_path_between_stations(self, start_id, end_id): return None # Mock CombinedTrackRepository class MockCombinedTrackRepository: def __init__(self, session): pass def exists_between_stations(self, start_id, end_id): return False # Patch the service to use mock repositories import backend.app.services.combined_tracks as service_module original_track_repo = service_module.TrackRepository original_combined_repo = service_module.CombinedTrackRepository service_module.TrackRepository = MockTrackRepository service_module.CombinedTrackRepository = MockCombinedTrackRepository try: result = create_combined_track( session, # type: ignore[arg-type] "550e8400-e29b-41d4-a716-446655440000", "550e8400-e29b-41d4-a716-446655440001" ) assert result is None finally: # Restore original classes service_module.TrackRepository = original_track_repo service_module.CombinedTrackRepository = original_combined_repo