diff --git a/TODO.md b/TODO.md index 840351c..80a3cc0 100644 --- a/TODO.md +++ b/TODO.md @@ -17,9 +17,9 @@ - [x] Implement an import script/CLI that pulls OSM station data and normalizes it to the PostGIS schema. - [x] Expose backend CRUD endpoints for stations (create, update, archive) with validation and geometry handling. - [x] Build React map tooling for selecting a station. -- [ ] Enhance map UI to support selecting two stations and previewing the rail corridor between them. -- [ ] Define track selection criteria and tagging rules for harvesting OSM rail segments within target regions. -- [ ] Extend the importer to load track geometries and associate them with existing stations. +- [x] Enhance map UI to support selecting two stations and previewing the rail corridor between them. +- [x] Define track selection criteria and tagging rules for harvesting OSM rail segments within target regions. +- [x] Extend the importer to load track geometries and associate them with existing stations. - [ ] Implement backend track-management APIs with length/speed validation and topology checks. - [ ] Implement track path mapping along existing OSM rail segments between chosen stations. - [ ] Design train connection manager requirements (link trains to operating tracks, manage consist data). diff --git a/backend/app/core/osm_config.py b/backend/app/core/osm_config.py index 48f9b5c..9881c1e 100644 --- a/backend/app/core/osm_config.py +++ b/backend/app/core/osm_config.py @@ -75,6 +75,18 @@ STATION_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = { } +# Tags that describe rail infrastructure usable for train routing. +TRACK_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = { + "railway": ( + "rail", + "light_rail", + "subway", + "tram", + "narrow_gauge", + ), +} + + def compile_overpass_filters(filters: Mapping[str, Iterable[str]]) -> str: """Build an Overpass boolean expression that matches the provided filters.""" @@ -89,5 +101,6 @@ __all__ = [ "BoundingBox", "DEFAULT_REGIONS", "STATION_TAG_FILTERS", + "TRACK_TAG_FILTERS", "compile_overpass_filters", ] diff --git a/backend/app/models/base.py b/backend/app/models/base.py index c00117b..a82e048 100644 --- a/backend/app/models/base.py +++ b/backend/app/models/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime from typing import Generic, Sequence, TypeVar -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field def to_camel(string: str) -> str: @@ -53,6 +53,9 @@ class TrackModel(IdentifiedModel[str]): end_station_id: str length_meters: float max_speed_kph: float + status: str | None = None + is_bidirectional: bool = True + coordinates: list[tuple[float, float]] = Field(default_factory=list) class TrainModel(IdentifiedModel[str]): diff --git a/backend/app/services/network.py b/backend/app/services/network.py index e661096..ebc4a8e 100644 --- a/backend/app/services/network.py +++ b/backend/app/services/network.py @@ -8,9 +8,10 @@ from geoalchemy2.elements import WKBElement, WKTElement from geoalchemy2.shape import to_shape try: # pragma: no cover - optional dependency guard - from shapely.geometry import Point # type: ignore + from shapely.geometry import LineString, Point # type: ignore except ImportError: # pragma: no cover - allow running without shapely at import time Point = None # type: ignore[assignment] + LineString = None # type: ignore[assignment] from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session @@ -51,6 +52,12 @@ def _fallback_snapshot() -> dict[str, list[dict[str, object]]]: end_station_id="station-2", length_meters=289000.0, max_speed_kph=230.0, + status="operational", + is_bidirectional=True, + coordinates=[ + (stations[0].latitude, stations[0].longitude), + (stations[1].latitude, stations[1].longitude), + ], created_at=now, updated_at=now, ) @@ -134,6 +141,20 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]] track_models: list[TrackModel] = [] for track in tracks_entities: + coordinates: list[tuple[float, float]] = [] + geometry = track.track_geometry + shape = ( + to_shape(cast(WKBElement | WKTElement, geometry)) + if geometry is not None and LineString is not None + else None + ) + if LineString is not None and shape is not None and isinstance(shape, LineString): + coords_list: list[tuple[float, float]] = [] + for coord in shape.coords: + lon = float(coord[0]) + lat = float(coord[1]) + coords_list.append((lat, lon)) + coordinates = coords_list track_models.append( TrackModel( id=str(track.id), @@ -141,6 +162,9 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]] end_station_id=str(track.end_station_id), length_meters=_to_float(track.length_meters), max_speed_kph=_to_float(track.max_speed_kph), + status=track.status, + is_bidirectional=track.is_bidirectional, + coordinates=coordinates, created_at=cast(datetime, track.created_at), updated_at=cast(datetime, track.updated_at), ) diff --git a/backend/scripts/tracks_import.py b/backend/scripts/tracks_import.py new file mode 100644 index 0000000..080d50b --- /dev/null +++ b/backend/scripts/tracks_import.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +"""CLI utility to export rail track geometries from OpenStreetMap.""" + +import argparse +import json +import math +import sys +from dataclasses import asdict +from pathlib import Path +from typing import Any, Iterable +from urllib.parse import quote_plus + +from backend.app.core.osm_config import ( + DEFAULT_REGIONS, + TRACK_TAG_FILTERS, + compile_overpass_filters, +) + +OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter" + + +def build_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Export OSM rail track ways for ingestion", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("data/osm_tracks.json"), + help=( + "Destination file for the exported track geometries " + "(default: data/osm_tracks.json)" + ), + ) + parser.add_argument( + "--region", + choices=[region.name for region in DEFAULT_REGIONS] + ["all"], + default="all", + help="Region name to export (default: all)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Do not fetch data; print the Overpass payload only", + ) + return parser + + +def build_overpass_query(region_name: str) -> str: + if region_name == "all": + regions = DEFAULT_REGIONS + else: + regions = tuple( + region for region in DEFAULT_REGIONS if region.name == region_name) + if not regions: + available = ", ".join(region.name for region in DEFAULT_REGIONS) + msg = f"Unknown region {region_name}. Available regions: [{available}]" + raise ValueError(msg) + + filters = compile_overpass_filters(TRACK_TAG_FILTERS) + + parts = ["[out:json][timeout:120];", "("] + for region in regions: + parts.append(f" way{filters}\n ({region.to_overpass_arg()});") + parts.append(")") + parts.append("; out body geom; >; out skel qt;") + return "\n".join(parts) + + +def perform_request(query: str) -> dict[str, Any]: + import urllib.request + + payload = f"data={quote_plus(query)}".encode("utf-8") + request = urllib.request.Request( + OVERPASS_ENDPOINT, + data=payload, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + with urllib.request.urlopen(request, timeout=180) as response: + payload = response.read() + return json.loads(payload) + + +def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert Overpass way elements into TrackCreate-compatible payloads.""" + + tracks: list[dict[str, Any]] = [] + for element in elements: + if element.get("type") != "way": + continue + + raw_geometry = element.get("geometry") or [] + coordinates: list[list[float]] = [] + for node in raw_geometry: + lat = node.get("lat") + lon = node.get("lon") + if lat is None or lon is None: + coordinates = [] + break + coordinates.append([float(lat), float(lon)]) + + if len(coordinates) < 2: + continue + + tags: dict[str, Any] = element.get("tags", {}) + name = tags.get("name") + maxspeed = _parse_maxspeed(tags.get("maxspeed")) + status = _derive_status(tags.get("railway")) + is_bidirectional = not _is_oneway(tags.get("oneway")) + + length_meters = _polyline_length(coordinates) + + tracks.append( + { + "osmId": str(element.get("id")), + "name": str(name) if name else None, + "lengthMeters": length_meters, + "maxSpeedKph": maxspeed, + "status": status, + "isBidirectional": is_bidirectional, + "coordinates": coordinates, + } + ) + + return tracks + + +def _parse_maxspeed(value: Any) -> float | None: + if value is None: + return None + + # Overpass may return values such as "80" or "80 km/h" or "signals". + if isinstance(value, (int, float)): + return float(value) + + text = str(value).strip() + number = "" + for char in text: + if char.isdigit() or char == ".": + number += char + elif number: + break + try: + return float(number) if number else None + except ValueError: + return None + + +def _derive_status(value: Any) -> str: + tag = str(value or "").lower() + if tag in {"abandoned", "disused"}: + return tag + if tag in {"construction", "proposed"}: + return "construction" + return "operational" + + +def _is_oneway(value: Any) -> bool: + if value is None: + return False + normalized = str(value).strip().lower() + return normalized in {"yes", "true", "1"} + + +def _polyline_length(points: list[list[float]]) -> float: + if len(points) < 2: + return 0.0 + + total = 0.0 + for index in range(len(points) - 1): + total += _haversine(points[index], points[index + 1]) + return total + + +def _haversine(a: list[float], b: list[float]) -> float: + """Return distance in meters between two [lat, lon] coordinates.""" + + lat1, lon1 = a + lat2, lon2 = b + radius = 6_371_000 + + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + sin_dphi = math.sin(delta_phi / 2) + sin_dlambda = math.sin(delta_lambda / 2) + root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2 + distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root)) + return distance + + +def main(argv: list[str] | None = None) -> int: + parser = build_argument_parser() + args = parser.parse_args(argv) + + query = build_overpass_query(args.region) + + if args.dry_run: + print(query) + return 0 + + output_path: Path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + + data = perform_request(query) + raw_elements = data.get("elements", []) + tracks = normalize_track_elements(raw_elements) + + payload = { + "metadata": { + "endpoint": OVERPASS_ENDPOINT, + "region": args.region, + "filters": TRACK_TAG_FILTERS, + "regions": [asdict(region) for region in DEFAULT_REGIONS], + "raw_count": len(raw_elements), + "track_count": len(tracks), + }, + "tracks": tracks, + } + + with output_path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2) + + print( + f"Normalized {len(tracks)} tracks from {len(raw_elements)} elements into {output_path}" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backend/scripts/tracks_load.py b/backend/scripts/tracks_load.py new file mode 100644 index 0000000..cda1484 --- /dev/null +++ b/backend/scripts/tracks_load.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +"""CLI for loading normalized track JSON into the database.""" + +import argparse +import json +import math +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable, Mapping, Sequence + +from geoalchemy2.elements import WKBElement, WKTElement +from geoalchemy2.shape import to_shape + +from backend.app.db.session import SessionLocal +from backend.app.models import TrackCreate +from backend.app.repositories import StationRepository, TrackRepository + + +@dataclass(slots=True) +class ParsedTrack: + coordinates: list[tuple[float, float]] + name: str | None = None + length_meters: float | None = None + max_speed_kph: float | None = None + status: str = "operational" + is_bidirectional: bool = True + + +@dataclass(slots=True) +class StationRef: + id: str + latitude: float + longitude: float + + +def build_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Load normalized track data into PostGIS", + ) + parser.add_argument( + "input", + type=Path, + help="Path to the normalized track JSON file produced by tracks_import.py", + ) + parser.add_argument( + "--commit/--no-commit", + dest="commit", + default=True, + help="Commit the transaction (default: commit). Use --no-commit for dry runs.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_argument_parser() + args = parser.parse_args(argv) + + if not args.input.exists(): + parser.error(f"Input file {args.input} does not exist") + + with args.input.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + + track_entries = payload.get("tracks") or [] + if not isinstance(track_entries, list): + parser.error("Invalid payload: 'tracks' must be a list") + + try: + tracks = _parse_track_entries(track_entries) + except ValueError as exc: + parser.error(str(exc)) + + created = load_tracks(tracks, commit=args.commit) + print(f"Loaded {created} tracks from {args.input}") + return 0 + + +def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]: + parsed: list[ParsedTrack] = [] + for entry in entries: + coordinates = entry.get("coordinates") + if not isinstance(coordinates, Sequence) or len(coordinates) < 2: + raise ValueError( + "Invalid track entry: 'coordinates' must contain at least two points") + + processed_coordinates: list[tuple[float, float]] = [] + for pair in coordinates: + if not isinstance(pair, Sequence) or len(pair) != 2: + raise ValueError( + f"Invalid coordinate pair {pair!r} in track entry") + lat, lon = pair + processed_coordinates.append((float(lat), float(lon))) + + name = entry.get("name") + length = _safe_float(entry.get("lengthMeters")) + max_speed = _safe_float(entry.get("maxSpeedKph")) + status = entry.get("status", "operational") + is_bidirectional = entry.get("isBidirectional", True) + + parsed.append( + ParsedTrack( + coordinates=processed_coordinates, + name=str(name) if name else None, + length_meters=length, + max_speed_kph=max_speed, + status=str(status) if status else "operational", + is_bidirectional=bool(is_bidirectional), + ) + ) + return parsed + + +def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int: + created = 0 + with SessionLocal() as session: + station_repo = StationRepository(session) + track_repo = TrackRepository(session) + + station_index = _build_station_index(station_repo.list_active()) + existing_pairs = { + (str(track.start_station_id), str(track.end_station_id)) + for track in track_repo.list_all() + } + + for track_data in tracks: + start_station = _nearest_station( + track_data.coordinates[0], station_index) + end_station = _nearest_station( + track_data.coordinates[-1], station_index) + + if not start_station or not end_station: + continue + + pair = (start_station.id, end_station.id) + if pair in existing_pairs: + continue + + length = track_data.length_meters or _polyline_length( + track_data.coordinates) + create_schema = TrackCreate( + name=track_data.name, + start_station_id=start_station.id, + end_station_id=end_station.id, + coordinates=track_data.coordinates, + length_meters=length, + max_speed_kph=track_data.max_speed_kph, + status=track_data.status, + is_bidirectional=track_data.is_bidirectional, + ) + + track_repo.create(create_schema) + existing_pairs.add(pair) + created += 1 + + if commit: + session.commit() + else: + session.rollback() + + return created + + +def _nearest_station( + coordinate: tuple[float, float], stations: Sequence[StationRef] +) -> StationRef | None: + best_station: StationRef | None = None + best_distance = math.inf + for station in stations: + distance = _haversine( + coordinate, (station.latitude, station.longitude)) + if distance < best_distance: + best_station = station + best_distance = distance + return best_station + + +def _build_station_index(stations: Iterable[Any]) -> list[StationRef]: + index: list[StationRef] = [] + for station in stations: + location = getattr(station, "location", None) + if location is None: + continue + point = _to_point(location) + if point is None: + continue + index.append( + StationRef( + id=str(station.id), + latitude=float(point.y), + longitude=float(point.x), + ) + ) + return index + + +def _to_point(geometry: WKBElement | WKTElement | Any): + try: + point = to_shape(geometry) + return point if getattr(point, "geom_type", None) == "Point" else None + except Exception: # pragma: no cover - defensive, should not happen with valid geometry + return None + + +def _polyline_length(points: Sequence[tuple[float, float]]) -> float: + if len(points) < 2: + return 0.0 + + total = 0.0 + for index in range(len(points) - 1): + total += _haversine(points[index], points[index + 1]) + return total + + +def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float: + lat1, lon1 = a + lat2, lon2 = b + radius = 6_371_000 + + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + sin_dphi = math.sin(delta_phi / 2) + sin_dlambda = math.sin(delta_lambda / 2) + root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2 + distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root)) + return distance + + +def _safe_float(value: Any) -> float | None: + if value is None or value == "": + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py index 9175a43..8ba4509 100644 --- a/backend/tests/test_models.py +++ b/backend/tests/test_models.py @@ -29,11 +29,15 @@ def test_track_model_properties() -> None: end_station_id="station-2", length_meters=1500.0, max_speed_kph=120.0, + status="operational", + is_bidirectional=True, + coordinates=[(52.52, 13.405), (52.6, 13.5)], created_at=timestamp, updated_at=timestamp, ) assert track.length_meters > 0 assert track.start_station_id != track.end_station_id + assert len(track.coordinates) == 2 def test_train_model_operating_tracks() -> None: diff --git a/backend/tests/test_network_service.py b/backend/tests/test_network_service.py index 83ceee5..3fb803a 100644 --- a/backend/tests/test_network_service.py +++ b/backend/tests/test_network_service.py @@ -26,6 +26,9 @@ def sample_entities() -> dict[str, SimpleNamespace]: end_station_id=station.id, length_meters=1234.5, max_speed_kph=160, + status="operational", + is_bidirectional=True, + track_geometry=None, created_at=timestamp, updated_at=timestamp, ) @@ -47,7 +50,8 @@ def test_network_snapshot_prefers_repository_data( track = sample_entities["track"] train = sample_entities["train"] - monkeypatch.setattr(StationRepository, "list_active", lambda self: [station]) + monkeypatch.setattr(StationRepository, "list_active", + lambda self: [station]) monkeypatch.setattr(TrackRepository, "list_all", lambda self: [track]) monkeypatch.setattr(TrainRepository, "list_all", lambda self: [train]) @@ -55,7 +59,8 @@ def test_network_snapshot_prefers_repository_data( assert snapshot["stations"] assert snapshot["stations"][0]["name"] == station.name - assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(track.length_meters) + assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx( + track.length_meters) assert snapshot["trains"][0]["designation"] == train.designation assert snapshot["trains"][0]["operatingTrackIds"] == [] @@ -71,4 +76,5 @@ def test_network_snapshot_falls_back_when_repositories_empty( assert snapshot["stations"] assert snapshot["trains"] - assert any(station["name"] == "Central" for station in snapshot["stations"]) + assert any(station["name"] == + "Central" for station in snapshot["stations"]) diff --git a/backend/tests/test_tracks_import.py b/backend/tests/test_tracks_import.py new file mode 100644 index 0000000..85d7e32 --- /dev/null +++ b/backend/tests/test_tracks_import.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from backend.scripts import tracks_import + + +def test_normalize_track_elements_excludes_invalid_geometries() -> None: + elements = [ + { + "type": "way", + "id": 123, + "geometry": [ + {"lat": 52.5, "lon": 13.4}, + {"lat": 52.6, "lon": 13.5}, + ], + "tags": { + "name": "Main Line", + "railway": "rail", + "maxspeed": "120", + }, + }, + { + "type": "way", + "id": 456, + "geometry": [ + {"lat": 51.0}, + ], + "tags": {"railway": "rail"}, + }, + { + "type": "node", + "id": 789, + }, + ] + + tracks = tracks_import.normalize_track_elements(elements) + + assert len(tracks) == 1 + track = tracks[0] + assert track["osmId"] == "123" + assert track["name"] == "Main Line" + assert track["maxSpeedKph"] == 120.0 + assert track["status"] == "operational" + assert track["isBidirectional"] is True + assert track["coordinates"] == [[52.5, 13.4], [52.6, 13.5]] + assert track["lengthMeters"] > 0 + + +def test_normalize_track_elements_marks_oneway_and_status() -> None: + elements = [ + { + "type": "way", + "id": 42, + "geometry": [ + {"lat": 48.1, "lon": 11.5}, + {"lat": 48.2, "lon": 11.6}, + ], + "tags": { + "railway": "disused", + "oneway": "yes", + }, + } + ] + + tracks = tracks_import.normalize_track_elements(elements) + + assert len(tracks) == 1 + track = tracks[0] + assert track["status"] == "disused" + assert track["isBidirectional"] is False diff --git a/backend/tests/test_tracks_load.py b/backend/tests/test_tracks_load.py new file mode 100644 index 0000000..7b5f7cf --- /dev/null +++ b/backend/tests/test_tracks_load.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List + +import pytest +from geoalchemy2.shape import from_shape +from shapely.geometry import Point + +from backend.scripts import tracks_load + + +def test_parse_track_entries_returns_models() -> None: + entries = [ + { + "name": "Connector", + "coordinates": [[52.5, 13.4], [52.6, 13.5]], + "lengthMeters": 1500, + "maxSpeedKph": 120, + "status": "operational", + "isBidirectional": True, + } + ] + + parsed = tracks_load._parse_track_entries(entries) + + assert parsed[0].name == "Connector" + assert parsed[0].coordinates[0] == (52.5, 13.4) + assert parsed[0].length_meters == 1500 + assert parsed[0].max_speed_kph == 120 + + +def test_parse_track_entries_invalid_raises_value_error() -> None: + entries = [ + { + "coordinates": [[52.5, 13.4]], + } + ] + + with pytest.raises(ValueError): + tracks_load._parse_track_entries(entries) + + +@dataclass +class DummySession: + committed: bool = False + rolled_back: bool = False + + def __enter__(self) -> "DummySession": + return self + + def __exit__(self, exc_type, exc, traceback) -> None: + pass + + def commit(self) -> None: + self.committed = True + + def rollback(self) -> None: + self.rolled_back = True + + +@dataclass +class DummyStation: + id: str + location: object + + +@dataclass +class DummyStationRepository: + session: DummySession + stations: List[DummyStation] + + def list_active(self) -> List[DummyStation]: + return self.stations + + +@dataclass +class DummyTrackRepository: + session: DummySession + created: list = field(default_factory=list) + existing: list = field(default_factory=list) + + def list_all(self): + return self.existing + + def create(self, data): # pragma: no cover - simple delegation + self.created.append(data) + + +def _point(lat: float, lon: float) -> object: + return from_shape(Point(lon, lat), srid=4326) + + +def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None: + session_instance = DummySession() + station_repo_instance = DummyStationRepository( + session_instance, + stations=[ + DummyStation(id="station-a", location=_point(52.5, 13.4)), + DummyStation(id="station-b", location=_point(52.6, 13.5)), + ], + ) + track_repo_instance = DummyTrackRepository(session_instance) + + monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance) + monkeypatch.setattr(tracks_load, "StationRepository", + lambda session: station_repo_instance) + monkeypatch.setattr(tracks_load, "TrackRepository", + lambda session: track_repo_instance) + + parsed = tracks_load._parse_track_entries( + [ + { + "name": "Connector", + "coordinates": [[52.5, 13.4], [52.6, 13.5]], + } + ] + ) + + created = tracks_load.load_tracks(parsed, commit=True) + + assert created == 1 + assert session_instance.committed is True + assert track_repo_instance.created + track = track_repo_instance.created[0] + assert track.start_station_id == "station-a" + assert track.end_station_id == "station-b" + assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)] + + +def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None: + session_instance = DummySession() + station_repo_instance = DummyStationRepository( + session_instance, + stations=[ + DummyStation(id="station-a", location=_point(52.5, 13.4)), + DummyStation(id="station-b", location=_point(52.6, 13.5)), + ], + ) + existing_track = type("ExistingTrack", (), { + "start_station_id": "station-a", + "end_station_id": "station-b", + }) + track_repo_instance = DummyTrackRepository( + session_instance, + existing=[existing_track], + ) + + monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance) + monkeypatch.setattr(tracks_load, "StationRepository", + lambda session: station_repo_instance) + monkeypatch.setattr(tracks_load, "TrackRepository", + lambda session: track_repo_instance) + + parsed = tracks_load._parse_track_entries( + [ + { + "name": "Connector", + "coordinates": [[52.5, 13.4], [52.6, 13.5]], + } + ] + ) + + created = tracks_load.load_tracks(parsed, commit=False) + + assert created == 0 + assert session_instance.rolled_back is True + assert not track_repo_instance.created diff --git a/docs/05_Building_Block_View.md b/docs/05_Building_Block_View.md index 4dbeb3a..94aaeb5 100644 --- a/docs/05_Building_Block_View.md +++ b/docs/05_Building_Block_View.md @@ -111,6 +111,7 @@ graph TD - **Health Module**: Lightweight readiness probes used by infrastructure checks. - **Network Module**: Serves read-only snapshots of stations, tracks, and trains using shared domain models (camelCase aliases for client compatibility). +- **OSM Ingestion CLI**: Script pairings (`stations_import`/`stations_load`, `tracks_import`/`tracks_load`) that harvest OpenStreetMap fixtures and persist normalized station and track geometries into PostGIS. - **Authentication Module**: JWT-based user registration, authentication, and authorization. The current prototype supports on-the-fly account creation backed by an in-memory user store and issues short-lived access tokens to validate the client flow end-to-end. - **Railway Calculation Module**: Algorithms for route optimization and scheduling (planned). - **Resource Management Module**: Logic for game economy and progression (planned). @@ -157,4 +158,3 @@ rail-game/ ``` Shared code that spans application layers should be surfaced through well-defined APIs within `backend/app/services` or exposed via frontend data contracts to keep coupling low. Infrastructure automation and CI/CD assets remain isolated under `infra/` to support multiple deployment targets. - diff --git a/docs/06_Runtime_View.md b/docs/06_Runtime_View.md index 8202459..cfdb057 100644 --- a/docs/06_Runtime_View.md +++ b/docs/06_Runtime_View.md @@ -78,7 +78,31 @@ sequenceDiagram F->>F: Render Leaflet map and snapshot summaries ``` -#### 6.2.4 Building Railway Network +#### 6.2.4 OSM Track Import and Load + +**Scenario**: Operator refreshes spatial fixtures by harvesting OSM railways and persisting them to PostGIS. + +**Description**: The paired CLI scripts `tracks_import.py` and `tracks_load.py` export candidate track segments from Overpass, associate endpoints with the nearest known stations, and store the resulting LINESTRING geometries. Dry-run flags allow inspection of the generated Overpass payload or database mutations before commit. + +```mermaid +sequenceDiagram + participant Ops as Operator + participant TI as tracks_import.py + participant OL as Overpass API + participant TL as tracks_load.py + participant DB as PostGIS + + Ops->>TI: Invoke with region + output path + TI->>OL: POST compiled Overpass query + OL-->>TI: Return rail way elements (JSON) + TI-->>Ops: Write normalized tracks JSON + Ops->>TL: Invoke with normalized JSON + TL->>DB: Fetch stations + existing tracks + TL->>DB: Insert snapped LINESTRING geometries + TL-->>Ops: Report committed track count +``` + +#### 6.2.5 Building Railway Network **Scenario**: User adds a new track segment to their railway network. @@ -101,7 +125,7 @@ sequenceDiagram F->>F: Update map display ``` -#### 6.2.5 Running Train Simulation +#### 6.2.6 Running Train Simulation **Scenario**: User starts a train simulation on their network. @@ -129,7 +153,7 @@ sequenceDiagram end ``` -#### 6.2.6 Saving Game Progress +#### 6.2.7 Saving Game Progress **Scenario**: User saves their current game state. @@ -154,4 +178,3 @@ sequenceDiagram - **Real-time Updates**: WebSocket connections for simulation updates, with fallback to polling - **Load Balancing**: Backend API can be scaled horizontally for multiple users - **CDN**: Static assets and map tiles served via CDN for faster loading - diff --git a/docs/architecture.md b/docs/architecture.md index 1b961c2..6b70996 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -100,6 +100,7 @@ The system interacts with: - Browser-native implementation for broad accessibility - Spatial database for efficient geographical queries +- Offline-friendly OSM ingestion pipeline that uses dedicated CLI scripts to export/import stations and tracks before seeding the database - Modular architecture allowing for future extensions (e.g., multiplayer) ## 5. Building Block View diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index a0af01c..b81158a 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -68,21 +68,22 @@ function App(): JSX.Element { [data] ); - const routeComputation = useMemo(() => { - const core = computeRoute({ - startId: routeSelection.startId, - endId: routeSelection.endId, - stationById, - adjacency: trackAdjacency, - }); + const routeComputation = useMemo( + () => + computeRoute({ + startId: routeSelection.startId, + endId: routeSelection.endId, + stationById, + adjacency: trackAdjacency, + }), + [routeSelection, stationById, trackAdjacency] + ); - const segments = core.stations ? buildSegmentsFromStations(core.stations) : []; - - return { - ...core, - segments, - }; - }, [routeSelection, stationById, trackAdjacency]); + const routeSegments = useMemo(() => { + return routeComputation.segments.map((segment) => + segment.map((pair) => [pair[0], pair[1]] as LatLngExpression) + ); + }, [routeComputation.segments]); const focusedStation = useMemo(() => { if (!data || !focusedStationId) { @@ -144,7 +145,7 @@ function App(): JSX.Element { focusedStationId={focusedStationId} startStationId={routeSelection.startId} endStationId={routeSelection.endId} - routeSegments={routeComputation.segments} + routeSegments={routeSegments} onStationClick={handleStationSelection} /> @@ -267,7 +268,9 @@ function App(): JSX.Element { {data.tracks.map((track) => (
  • {track.startStationId} → {track.endStationId} ·{' '} - {(track.lengthMeters / 1000).toFixed(1)} km + {track.lengthMeters > 0 + ? `${(track.lengthMeters / 1000).toFixed(1)} km` + : 'N/A'}
  • ))} @@ -329,16 +332,3 @@ export default App; function hasStation(stations: Station[], id: string): boolean { return stations.some((station) => station.id === id); } - -function buildSegmentsFromStations(stations: Station[]): LatLngExpression[][] { - const segments: LatLngExpression[][] = []; - for (let index = 0; index < stations.length - 1; index += 1) { - const current = stations[index]; - const next = stations[index + 1]; - segments.push([ - [current.latitude, current.longitude], - [next.latitude, next.longitude], - ]); - } - return segments; -} diff --git a/frontend/src/components/map/NetworkMap.tsx b/frontend/src/components/map/NetworkMap.tsx index 593ad9a..589f6fb 100644 --- a/frontend/src/components/map/NetworkMap.tsx +++ b/frontend/src/components/map/NetworkMap.tsx @@ -57,6 +57,12 @@ export function NetworkMap({ const trackSegments = useMemo(() => { return snapshot.tracks .map((track) => { + if (track.coordinates && track.coordinates.length >= 2) { + return track.coordinates.map( + (pair) => [pair[0], pair[1]] as LatLngExpression + ); + } + const start = stationLookup.get(track.startStationId); const end = stationLookup.get(track.endStationId); if (!start || !end) { diff --git a/frontend/src/types/domain.ts b/frontend/src/types/domain.ts index 4d525ab..3eb6247 100644 --- a/frontend/src/types/domain.ts +++ b/frontend/src/types/domain.ts @@ -22,6 +22,9 @@ export interface Track extends Identified { readonly endStationId: string; readonly lengthMeters: number; readonly maxSpeedKph: number; + readonly status?: string | null; + readonly isBidirectional?: boolean; + readonly coordinates: readonly [number, number][]; } export interface Train extends Identified { diff --git a/frontend/src/utils/route.test.ts b/frontend/src/utils/route.test.ts index 5ff5b40..7b74d78 100644 --- a/frontend/src/utils/route.test.ts +++ b/frontend/src/utils/route.test.ts @@ -48,6 +48,11 @@ describe('route utilities', () => { endStationId: 'station-b', lengthMeters: 1200, maxSpeedKph: 120, + coordinates: [ + [51.5, -0.1], + [51.51, -0.105], + [51.52, -0.11], + ], ...baseTimestamps, }, { @@ -56,6 +61,11 @@ describe('route utilities', () => { endStationId: 'station-c', lengthMeters: 1500, maxSpeedKph: 110, + coordinates: [ + [51.52, -0.11], + [51.53, -0.115], + [51.54, -0.12], + ], ...baseTimestamps, }, { @@ -64,6 +74,11 @@ describe('route utilities', () => { endStationId: 'station-d', lengthMeters: 900, maxSpeedKph: 115, + coordinates: [ + [51.54, -0.12], + [51.545, -0.13], + [51.55, -0.15], + ], ...baseTimestamps, }, ]; @@ -91,6 +106,9 @@ describe('route utilities', () => { 'track-cd', ]); expect(result.totalLength).toBe(1200 + 1500 + 900); + expect(result.segments).toHaveLength(3); + expect(result.segments[0][0]).toEqual([51.5, -0.1]); + expect(result.segments[2][result.segments[2].length - 1]).toEqual([51.55, -0.15]); }); it('returns an error when no path exists', () => { @@ -118,6 +136,10 @@ describe('route utilities', () => { endStationId: 'station-a', lengthMeters: 0, maxSpeedKph: 80, + coordinates: [ + [51.5, -0.1], + [51.5005, -0.1005], + ], ...baseTimestamps, }, ]; @@ -137,5 +159,58 @@ describe('route utilities', () => { expect(result.error).toBe( 'No rail connection found between the selected stations.' ); + expect(result.segments).toHaveLength(0); + }); + + it('reverses track geometry when traversing in the opposite direction', () => { + const stations: Station[] = [ + { + id: 'station-a', + name: 'Alpha', + latitude: 51.5, + longitude: -0.1, + ...baseTimestamps, + }, + { + id: 'station-b', + name: 'Bravo', + latitude: 51.52, + longitude: -0.11, + ...baseTimestamps, + }, + ]; + + const tracks: Track[] = [ + { + id: 'track-ab', + startStationId: 'station-a', + endStationId: 'station-b', + lengthMeters: 1200, + maxSpeedKph: 120, + coordinates: [ + [51.5, -0.1], + [51.52, -0.11], + ], + ...baseTimestamps, + }, + ]; + + const stationById = new Map(stations.map((station) => [station.id, station])); + const adjacency = buildTrackAdjacency(tracks); + + const result = computeRoute({ + startId: 'station-b', + endId: 'station-a', + stationById, + adjacency, + }); + + expect(result.error).toBeNull(); + expect(result.segments).toEqual([ + [ + [51.52, -0.11], + [51.5, -0.1], + ], + ]); }); }); diff --git a/frontend/src/utils/route.ts b/frontend/src/utils/route.ts index f9592af..0edebe8 100644 --- a/frontend/src/utils/route.ts +++ b/frontend/src/utils/route.ts @@ -1,8 +1,11 @@ import type { Station, Track } from '../types/domain'; +export type LatLngTuple = readonly [number, number]; + export interface NeighborEdge { readonly neighborId: string; readonly track: Track; + readonly isForward: boolean; } export type TrackAdjacency = Map; @@ -19,21 +22,22 @@ export interface RouteComputation { readonly tracks: Track[]; readonly totalLength: number | null; readonly error: string | null; + readonly segments: LatLngTuple[][]; } export function buildTrackAdjacency(tracks: readonly Track[]): TrackAdjacency { const adjacency: TrackAdjacency = new Map(); - const register = (fromId: string, toId: string, track: Track) => { + const register = (fromId: string, toId: string, track: Track, isForward: boolean) => { if (!adjacency.has(fromId)) { adjacency.set(fromId, []); } - adjacency.get(fromId)!.push({ neighborId: toId, track }); + adjacency.get(fromId)!.push({ neighborId: toId, track, isForward }); }; for (const track of tracks) { - register(track.startStationId, track.endStationId, track); - register(track.endStationId, track.startStationId, track); + register(track.startStationId, track.endStationId, track, true); + register(track.endStationId, track.startStationId, track, false); } return adjacency; @@ -55,6 +59,7 @@ export function computeRoute({ tracks: [], totalLength: null, error: 'Selected stations are no longer available.', + segments: [], }; } @@ -65,16 +70,17 @@ export function computeRoute({ tracks: [], totalLength: 0, error: null, + segments: [], }; } const visited = new Set(); const queue: string[] = []; - const parent = new Map(); + const parent = new Map(); queue.push(startId); visited.add(startId); - parent.set(startId, { prev: null, via: null }); + parent.set(startId, { prev: null, edge: null }); while (queue.length > 0) { const current = queue.shift()!; @@ -83,12 +89,13 @@ export function computeRoute({ } const neighbors = adjacency.get(current) ?? []; - for (const { neighborId, track } of neighbors) { + for (const edge of neighbors) { + const { neighborId } = edge; if (visited.has(neighborId)) { continue; } visited.add(neighborId); - parent.set(neighborId, { prev: current, via: track }); + parent.set(neighborId, { prev: current, edge }); queue.push(neighborId); } } @@ -99,11 +106,13 @@ export function computeRoute({ tracks: [], totalLength: null, error: 'No rail connection found between the selected stations.', + segments: [], }; } const stationPath: string[] = []; const trackSequence: Track[] = []; + const directions: boolean[] = []; let cursor: string | null = endId; while (cursor) { @@ -112,19 +121,22 @@ export function computeRoute({ break; } stationPath.push(cursor); - if (details.via) { - trackSequence.push(details.via); + if (details.edge) { + trackSequence.push(details.edge.track); + directions.push(details.edge.isForward); } cursor = details.prev; } stationPath.reverse(); trackSequence.reverse(); + directions.reverse(); const stations = stationPath .map((id) => stationById.get(id)) .filter((station): station is Station => Boolean(station)); + const segments = buildSegments(trackSequence, directions, stationById); const totalLength = computeTotalLength(trackSequence, stations); return { @@ -132,9 +144,50 @@ export function computeRoute({ tracks: trackSequence, totalLength, error: null, + segments, }; } +function buildSegments( + tracks: Track[], + directions: boolean[], + stationById: Map +): LatLngTuple[][] { + const segments: LatLngTuple[][] = []; + + for (let index = 0; index < tracks.length; index += 1) { + const track = tracks[index]; + const isForward = directions[index] ?? true; + const coordinates = extractTrackCoordinates(track, stationById); + if (coordinates.length < 2) { + continue; + } + segments.push(isForward ? coordinates : [...coordinates].reverse()); + } + + return segments; +} + +function extractTrackCoordinates( + track: Track, + stationById: Map +): LatLngTuple[] { + if (Array.isArray(track.coordinates) && track.coordinates.length >= 2) { + return track.coordinates.map((pair) => [pair[0], pair[1]] as LatLngTuple); + } + + const start = stationById.get(track.startStationId); + const end = stationById.get(track.endStationId); + if (!start || !end) { + return []; + } + + return [ + [start.latitude, start.longitude], + [end.latitude, end.longitude], + ]; +} + function computeTotalLength(tracks: Track[], stations: Station[]): number | null { if (tracks.length === 0 && stations.length <= 1) { return 0; @@ -181,5 +234,6 @@ function emptyResult(): RouteComputation { tracks: [], totalLength: null, error: null, + segments: [], }; }