feat: Add combined track functionality with repository and service layers
Some checks failed
Backend CI / lint-and-test (push) Failing after 2m27s
Frontend CI / lint-and-build (push) Successful in 57s

- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks based on existing tracks between two stations.
- Added methods to check for existing combined tracks and retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM ID and track updates.
- Created migration scripts for adding combined tracks table and OSM ID to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
2025-11-10 14:12:28 +01:00
parent f73ab7ad14
commit 68048ff574
21 changed files with 1107 additions and 103 deletions

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, List
from uuid import uuid4
import pytest
from backend.app.models import CombinedTrackModel
from backend.app.repositories.combined_tracks import CombinedTrackRepository
from backend.app.repositories.tracks import TrackRepository
from backend.app.services.combined_tracks import create_combined_track
@dataclass
class DummySession:
added: List[Any] = field(default_factory=list)
scalars_result: List[Any] = field(default_factory=list)
scalar_result: Any = None
statements: List[Any] = field(default_factory=list)
committed: bool = False
rolled_back: bool = False
closed: bool = False
def add(self, instance: Any) -> None:
self.added.append(instance)
def add_all(self, instances: list[Any]) -> None:
self.added.extend(instances)
def scalars(self, statement: Any) -> list[Any]:
self.statements.append(statement)
return list(self.scalars_result)
def scalar(self, statement: Any) -> Any:
self.statements.append(statement)
return self.scalar_result
def flush(
self, _objects: list[Any] | None = None
) -> None: # pragma: no cover - optional
return None
def commit(self) -> None: # pragma: no cover - optional
self.committed = True
def rollback(self) -> None: # pragma: no cover - optional
self.rolled_back = True
def close(self) -> None: # pragma: no cover - optional
self.closed = True
def _now() -> datetime:
return datetime.now(timezone.utc)
def test_combined_track_model_round_trip() -> None:
timestamp = _now()
combined_track = CombinedTrackModel(
id="combined-track-1",
start_station_id="station-1",
end_station_id="station-2",
length_meters=3000.0,
max_speed_kph=100,
status="operational",
is_bidirectional=True,
coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)],
constituent_track_ids=["track-1", "track-2"],
created_at=timestamp,
updated_at=timestamp,
)
assert combined_track.length_meters == 3000.0
assert combined_track.start_station_id != combined_track.end_station_id
assert len(combined_track.coordinates) == 3
assert len(combined_track.constituent_track_ids) == 2
def test_combined_track_repository_create() -> None:
"""Test creating a combined track through the repository."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Create test data
from backend.app.models import CombinedTrackCreate
create_data = CombinedTrackCreate(
start_station_id="550e8400-e29b-41d4-a716-446655440000",
end_station_id="550e8400-e29b-41d4-a716-446655440001",
coordinates=[(52.52, 13.405), (52.6, 13.5)],
constituent_track_ids=["track-1"],
length_meters=1500.0,
max_speed_kph=120,
status="operational",
)
combined_track = repo.create(create_data)
assert combined_track.start_station_id is not None
assert combined_track.end_station_id is not None
assert combined_track.length_meters == 1500.0
assert combined_track.max_speed_kph == 120
assert combined_track.status == "operational"
assert session.added and session.added[0] is combined_track
def test_combined_track_repository_exists_between_stations() -> None:
"""Test checking if combined track exists between stations."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Initially should not exist (scalar_result is None by default)
assert not repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
# Simulate existing combined track
session.scalar_result = True
assert repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
def test_combined_track_service_create_no_path() -> None:
"""Test creating combined track when no path exists."""
# Mock session and repositories
session = DummySession()
# Mock TrackRepository to return no path
class MockTrackRepository:
def __init__(self, session):
pass
def find_path_between_stations(self, start_id, end_id):
return None
# Mock CombinedTrackRepository
class MockCombinedTrackRepository:
def __init__(self, session):
pass
def exists_between_stations(self, start_id, end_id):
return False
# Patch the service to use mock repositories
import backend.app.services.combined_tracks as service_module
original_track_repo = service_module.TrackRepository
original_combined_repo = service_module.CombinedTrackRepository
service_module.TrackRepository = MockTrackRepository
service_module.CombinedTrackRepository = MockCombinedTrackRepository
try:
result = create_combined_track(
session, # type: ignore[arg-type]
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
assert result is None
finally:
# Restore original classes
service_module.TrackRepository = original_track_repo
service_module.CombinedTrackRepository = original_combined_repo

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import pytest
from fastapi.testclient import TestClient
from backend.app.api import tracks as tracks_api
from backend.app.main import app
from backend.app.models import CombinedTrackModel, TrackModel
client = TestClient(app)
def _track_model(track_id: str = "track-1") -> TrackModel:
now = datetime.now(timezone.utc)
return TrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=None,
max_speed_kph=None,
status="planned",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _combined_model(track_id: str = "combined-1") -> CombinedTrackModel:
now = datetime.now(timezone.utc)
return CombinedTrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=1000,
max_speed_kph=120,
status="operational",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
constituent_track_ids=["track-1", "track-2"],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _authenticate() -> str:
response = client.post(
"/api/auth/login",
json={"username": "demo", "password": "railgame123"},
)
assert response.status_code == 200
return response.json()["accessToken"]
def test_list_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "list_tracks", lambda db: [_track_model()])
response = client.get(
"/api/tracks",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert isinstance(payload, list)
assert payload[0]["id"] == "track-1"
def test_get_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "get_track", lambda db, track_id: None)
response = client.get(
"/api/tracks/not-found",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_create_track_calls_service(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
captured: dict[str, Any] = {}
payload = {
"startStationId": "station-a",
"endStationId": "station-b",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
}
def fake_create(db: Any, data: Any) -> TrackModel:
assert data.start_station_id == "station-a"
captured["payload"] = data
return _track_model("track-new")
monkeypatch.setattr(tracks_api, "create_track", fake_create)
response = client.post(
"/api/tracks",
json=payload,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 201
body = response.json()
assert body["id"] == "track-new"
assert captured["payload"].end_station_id == "station-b"
def test_delete_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "delete_track", lambda db, tid, regenerate=False: False
)
response = client.delete(
"/api/tracks/missing",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_delete_track_success(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
seen: dict[str, Any] = {}
def fake_delete(db: Any, track_id: str, regenerate: bool = False) -> bool:
seen["track_id"] = track_id
seen["regenerate"] = regenerate
return True
monkeypatch.setattr(tracks_api, "delete_track", fake_delete)
response = client.delete(
"/api/tracks/track-99",
params={"regenerate": "true"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 204
assert seen["track_id"] == "track-99"
assert seen["regenerate"] is True
def test_list_combined_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "list_combined_tracks", lambda db: [_combined_model()]
)
response = client.get(
"/api/tracks/combined",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert len(payload) == 1
assert payload[0]["id"] == "combined-1"