Compare commits

...

4 Commits

Author SHA1 Message Date
68048ff574 feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks based on existing tracks between two stations.
- Added methods to check for existing combined tracks and retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM ID and track updates.
- Created migration scripts for adding combined tracks table and OSM ID to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
2025-11-10 14:12:28 +01:00
f73ab7ad14 fix: simplify import statements in osm_refresh.py 2025-10-11 22:11:45 +02:00
3c97c47f7e feat: add track selection and display functionality in the app 2025-10-11 22:10:53 +02:00
c35049cd54 fix: formatting (black)
2025-10-11 21:58:32 +02:00
31 changed files with 1335 additions and 160 deletions

TODO.md (deleted file)

@@ -1,80 +0,0 @@
# Development TODO Plan
## Phase 1: Project Foundations
- [x] Initialize Git hooks, linting, and formatting tooling (ESLint, Prettier, isort, black).
- [x] Configure `pyproject.toml` or equivalent for backend dependency management.
- [x] Scaffold FastAPI application entrypoint with health-check endpoint.
- [x] Bootstrap React app with Vite/CRA, including routing skeleton and global state provider.
- [x] Define shared TypeScript/Python models for core domain entities (tracks, stations, trains).
- [x] Set up CI workflow for linting and test automation (GitHub Actions).
## Phase 2: Core Features
- [x] Implement authentication flow (backend JWT, frontend login/register forms).
- [x] Build map visualization integrating Leaflet with OSM tiles.
- [x] Define geographic bounding boxes and filtering rules for importing real-world stations from OpenStreetMap.
- [x] Implement an import script/CLI that pulls OSM station data and normalizes it to the PostGIS schema.
- [x] Expose backend CRUD endpoints for stations (create, update, archive) with validation and geometry handling.
- [x] Build React map tooling for selecting a station.
- [x] Enhance map UI to support selecting two stations and previewing the rail corridor between them.
- [x] Define track selection criteria and tagging rules for harvesting OSM rail segments within target regions.
- [x] Extend the importer to load track geometries and associate them with existing stations.
- [ ] Implement backend track-management APIs with length/speed validation and topology checks.
- [ ] Implement track path mapping along existing OSM rail segments between chosen stations.
- [ ] Design train connection manager requirements (link trains to operating tracks, manage consist data).
- [ ] Implement backend services and APIs to attach trains to routes and update assignments.
- [ ] Add UI flows for managing train connections, including visual feedback on the map.
- [ ] Establish train scheduling service with validation rules, conflict detection, and persistence APIs.
- [ ] Provide frontend scheduling tools (timeline or table view) for creating and editing train timetables.
- [ ] Develop frontend dashboards for resources, schedules, and achievements.
- [ ] Add real-time simulation updates (WebSocket layer, frontend subscription hooks).
## Phase 3: Data & Persistence
- [x] Design PostgreSQL/PostGIS schema and migrations (Alembic or similar).
- [x] Implement data access layer with SQLAlchemy and repository abstractions.
- [ ] Decide on canonical fixture scope (demo geography, sample trains) and document expected dataset size.
- [ ] Author fixture generation scripts that export JSON/GeoJSON compatible with the repository layer.
- [x] Create ingestion utilities to load fixtures into local and CI databases.
- [ ] Provision a Redis instance/container for local development.
- [ ] Add caching abstractions in backend services (e.g., network snapshot, map layers).
- [ ] Implement cache invalidation hooks tied to repository mutations.
## Phase 4: Testing & Quality
- [x] Write unit tests for backend services and models.
- [ ] Configure Jest/RTL testing utilities and shared mocks for Leaflet and network APIs.
- [ ] Write component tests for map controls, station builder UI, and dashboards.
- [ ] Add integration tests for custom hooks (network snapshot, scheduling forms).
- [x] Stand up Playwright/Cypress project structure with authentication helpers.
- [x] Script login end-to-end flow (Playwright).
- [ ] Script station creation end-to-end flow.
- [ ] Script track placement end-to-end flow.
- [ ] Script scheduling end-to-end flow.
- [ ] Define load/performance targets (requests per second, simulation latency) and tooling.
- [ ] Implement performance test harness covering scheduling and real-time updates.
## Phase 5: Deployment & Ops
- [x] Create Dockerfile for frontend.
- [x] Create Dockerfile for backend.
- [x] Create docker-compose for local development with Postgres/Redis dependencies.
- [ ] Add task runner commands to orchestrate container workflows.
- [ ] Set up CI/CD pipeline for automated builds, tests, and container publishing.
- [ ] Provision infrastructure scripts (Terraform/Ansible) targeting initial cloud environment.
- [ ] Define environment configuration strategy (secrets management, config maps).
- [ ] Configure observability stack (logging, metrics, tracing).
- [ ] Integrate tracing/logging exporters into backend services.
- [ ] Document deployment pipeline and release process.
## Phase 6: Polish & Expansion
- [ ] Add leaderboards and achievements logic with UI integration.
- [ ] Design data model changes required for achievements and ranking.
- [ ] Implement accessibility audit fixes (WCAG compliance).
- [ ] Conduct accessibility audit (contrast, keyboard navigation, screen reader paths).
- [ ] Optimize asset loading and introduce lazy loading strategies.
- [ ] Establish performance budgets for bundle size and render times.
- [ ] Evaluate multiplayer/coop roadmap and spike POCs where feasible.
- [ ] Prototype networking approach (WebRTC/WebSocket) for cooperative sessions.


@@ -25,7 +25,7 @@ EXPOSE 8000
# Initialize database with demo data if INIT_DEMO_DB is set
CMD ["sh", "-c", "\
-export PYTHONPATH=/app && \
+export PYTHONPATH=/app/backend && \
echo 'Waiting for database...' && \
while ! pg_isready -h db -p 5432 -U railgame >/dev/null 2>&1; do sleep 1; done && \
echo 'Database is ready!' && \

alembic.ini

@@ -1,6 +1,6 @@
[alembic]
script_location = migrations
-sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame
+sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame_dev
[loggers]
keys = root,sqlalchemy,alembic


@@ -4,9 +4,11 @@ from backend.app.api.auth import router as auth_router
from backend.app.api.health import router as health_router
from backend.app.api.network import router as network_router
from backend.app.api.stations import router as stations_router
+from backend.app.api.tracks import router as tracks_router

router = APIRouter()
router.include_router(health_router, tags=["health"])
router.include_router(auth_router)
router.include_router(network_router)
router.include_router(stations_router)
+router.include_router(tracks_router)

backend/app/api/tracks.py (new file)

@@ -0,0 +1,153 @@
from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from backend.app.api.deps import get_current_user, get_db
from backend.app.models import (
    CombinedTrackModel,
    TrackCreate,
    TrackUpdate,
    TrackModel,
    UserPublic,
)
from backend.app.services.combined_tracks import (
    create_combined_track,
    get_combined_track,
    list_combined_tracks,
)
from backend.app.services.tracks import (
    create_track,
    delete_track,
    regenerate_combined_tracks,
    update_track,
    get_track,
    list_tracks,
)

router = APIRouter(prefix="/tracks", tags=["tracks"])


@router.get("", response_model=list[TrackModel])
def read_tracks(
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> list[TrackModel]:
    """Return all base tracks."""
    return list_tracks(db)


@router.get("/combined", response_model=list[CombinedTrackModel])
def read_combined_tracks(
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> list[CombinedTrackModel]:
    """Return all combined tracks."""
    return list_combined_tracks(db)


@router.get("/{track_id}", response_model=TrackModel)
def read_track(
    track_id: str,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> TrackModel:
    track = get_track(db, track_id)
    if track is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Track {track_id} not found",
        )
    return track


@router.get("/combined/{combined_track_id}", response_model=CombinedTrackModel)
def read_combined_track(
    combined_track_id: str,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> CombinedTrackModel:
    combined_track = get_combined_track(db, combined_track_id)
    if combined_track is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Combined track {combined_track_id} not found",
        )
    return combined_track


@router.post("", response_model=TrackModel, status_code=status.HTTP_201_CREATED)
def create_track_endpoint(
    payload: TrackCreate,
    regenerate: bool = False,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> TrackModel:
    try:
        track = create_track(db, payload)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
        ) from exc
    if regenerate:
        regenerate_combined_tracks(
            db, [track.start_station_id, track.end_station_id])
    return track


@router.post(
    "/combined",
    response_model=CombinedTrackModel,
    status_code=status.HTTP_201_CREATED,
    summary="Create a combined track between two stations using pathfinding",
)
def create_combined_track_endpoint(
    start_station_id: str,
    end_station_id: str,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> CombinedTrackModel:
    combined_track = create_combined_track(
        db, start_station_id, end_station_id)
    if combined_track is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Could not create combined track: no path exists between stations or track already exists",
        )
    return combined_track


@router.put("/{track_id}", response_model=TrackModel)
def update_track_endpoint(
    track_id: str,
    payload: TrackUpdate,
    regenerate: bool = False,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> TrackModel:
    track = update_track(db, track_id, payload)
    if track is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Track {track_id} not found",
        )
    if regenerate:
        regenerate_combined_tracks(
            db, [track.start_station_id, track.end_station_id])
    return track


@router.delete("/{track_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_track_endpoint(
    track_id: str,
    regenerate: bool = False,
    _: UserPublic = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> None:
    deleted = delete_track(db, track_id, regenerate=regenerate)
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Track {track_id} not found",
        )
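Order-sensitive route matching in miniature. This is not FastAPI's actual matcher, just a stdlib sketch of why the literal `/tracks/combined` path in the router above is registered before the parameterized `/tracks/{track_id}` path: routes are tried in declaration order, so a parameterized route placed first would swallow `combined` as if it were a track ID.

```python
import re

# Hypothetical route table mirroring the declaration order in the diff above;
# the handler labels are placeholders, not the real endpoint functions.
ROUTES = [
    (re.compile(r"^/tracks/combined$"), "combined handler"),
    (re.compile(r"^/tracks/(?P<track_id>[^/]+)$"), "single-track handler"),
]

def dispatch(path):
    # First matching pattern wins, so literal routes must precede parameterized ones.
    for pattern, handler in ROUTES:
        if pattern.match(path):
            return handler
    return None

print(dispatch("/tracks/combined"))  # combined handler
print(dispatch("/tracks/1f2e3d"))    # single-track handler
```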

backend/app/db/models.py

@@ -41,11 +41,14 @@ class User(Base, TimestampMixin):
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
-    username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
-    email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
+    username: Mapped[str] = mapped_column(
+        String(64), unique=True, nullable=False)
+    email: Mapped[str | None] = mapped_column(
+        String(255), unique=True, nullable=True)
    full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
    password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
-    role: Mapped[str] = mapped_column(String(32), nullable=False, default="player")
+    role: Mapped[str] = mapped_column(
+        String(32), nullable=False, default="player")
    preferences: Mapped[str | None] = mapped_column(Text, nullable=True)

@@ -62,12 +65,50 @@ class Station(Base, TimestampMixin):
        Geometry(geometry_type="POINT", srid=4326), nullable=False
    )
    elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
-    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+    is_active: Mapped[bool] = mapped_column(
+        Boolean, nullable=False, default=True)


class Track(Base, TimestampMixin):
    __tablename__ = "tracks"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
    name: Mapped[str | None] = mapped_column(String(128), nullable=True)
    start_station_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("stations.id", ondelete="RESTRICT"),
        nullable=False,
    )
    end_station_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("stations.id", ondelete="RESTRICT"),
        nullable=False,
    )
    length_meters: Mapped[float | None] = mapped_column(
        Numeric(10, 2), nullable=True)
    max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_bidirectional: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=True
    )
    status: Mapped[str] = mapped_column(
        String(32), nullable=False, default="planned")
    track_geometry: Mapped[str] = mapped_column(
        Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
    )
    __table_args__ = (
        UniqueConstraint(
            "start_station_id", "end_station_id", name="uq_tracks_station_pair"
        ),
    )


class CombinedTrack(Base, TimestampMixin):
    __tablename__ = "combined_tracks"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )

@@ -82,19 +123,25 @@ class Track(Base, TimestampMixin):
        ForeignKey("stations.id", ondelete="RESTRICT"),
        nullable=False,
    )
-    length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
+    length_meters: Mapped[float | None] = mapped_column(
+        Numeric(10, 2), nullable=True)
    max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_bidirectional: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=True
    )
-    status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
-    track_geometry: Mapped[str] = mapped_column(
+    status: Mapped[str] = mapped_column(
+        String(32), nullable=False, default="planned")
+    combined_geometry: Mapped[str] = mapped_column(
        Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
    )
+    # JSON array of constituent track IDs
+    constituent_track_ids: Mapped[str] = mapped_column(
+        Text, nullable=False
+    )
    __table_args__ = (
        UniqueConstraint(
-            "start_station_id", "end_station_id", name="uq_tracks_station_pair"
+            "start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
        ),
    )

@@ -105,7 +152,8 @@ class Train(Base, TimestampMixin):
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
-    designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
+    designation: Mapped[str] = mapped_column(
+        String(64), nullable=False, unique=True)
    operator_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
    )

backend/app/models/__init__.py

@@ -8,11 +8,14 @@ from .auth import (
    UserPublic,
)
from .base import (
+    CombinedTrackCreate,
+    CombinedTrackModel,
    StationCreate,
    StationModel,
    StationUpdate,
    TrackCreate,
    TrackModel,
+    TrackUpdate,
    TrainCreate,
    TrainModel,
    TrainScheduleCreate,

@@ -33,9 +36,12 @@ __all__ = [
    "StationUpdate",
    "TrackCreate",
    "TrackModel",
+    "TrackUpdate",
    "TrainScheduleCreate",
    "TrainCreate",
    "TrainModel",
    "UserCreate",
    "to_camel",
+    "CombinedTrackCreate",
+    "CombinedTrackModel",
]

backend/app/models/base.py

@@ -51,13 +51,24 @@ class StationModel(IdentifiedModel[str]):
class TrackModel(IdentifiedModel[str]):
    start_station_id: str
    end_station_id: str
-    length_meters: float
-    max_speed_kph: float
+    length_meters: float | None = None
+    max_speed_kph: float | None = None
+    status: str | None = None
    is_bidirectional: bool = True
    coordinates: list[tuple[float, float]] = Field(default_factory=list)


class CombinedTrackModel(IdentifiedModel[str]):
    start_station_id: str
    end_station_id: str
    length_meters: float | None = None
    max_speed_kph: int | None = None
    status: str | None = None
    is_bidirectional: bool = True
    coordinates: list[tuple[float, float]] = Field(default_factory=list)
    constituent_track_ids: list[str] = Field(default_factory=list)


class TrainModel(IdentifiedModel[str]):
    designation: str
    capacity: int

@@ -89,6 +100,31 @@ class TrackCreate(CamelModel):
    start_station_id: str
    end_station_id: str
    coordinates: Sequence[tuple[float, float]]
    osm_id: str | None = None
    name: str | None = None
    length_meters: float | None = None
    max_speed_kph: int | None = None
    is_bidirectional: bool = True
    status: str = "planned"


class TrackUpdate(CamelModel):
    start_station_id: str | None = None
    end_station_id: str | None = None
    coordinates: Sequence[tuple[float, float]] | None = None
    osm_id: str | None = None
    name: str | None = None
    length_meters: float | None = None
    max_speed_kph: int | None = None
    is_bidirectional: bool | None = None
    status: str | None = None


class CombinedTrackCreate(CamelModel):
    start_station_id: str
    end_station_id: str
    coordinates: Sequence[tuple[float, float]]
    constituent_track_ids: list[str]
    name: str | None = None
    length_meters: float | None = None
    max_speed_kph: int | None = None
    # Read by CombinedTrackRepository.create; the combined-track service
    # overrides status with "operational".
    is_bidirectional: bool = True
    status: str = "planned"
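The coordinate tuples in these models are (lat, lon), while WKT `LINESTRING` text wants `lon lat`. A hypothetical standalone re-implementation of the repositories' `_line_string` logic makes the swap visible; the function name and sample Berlin-ish coordinates are illustrative, not from the source.

```python
# Mirrors the repositories' _line_string helper: input is (lat, lon) pairs,
# output is WKT, which orders axes as "lon lat".
def line_string_wkt(coordinates):
    if len(coordinates) < 2:
        raise ValueError("Track geometry requires at least two coordinate pairs")
    # Swap each (lat, lon) tuple into WKT's "lon lat" ordering.
    parts = [f"{lon} {lat}" for lat, lon in coordinates]
    return f"LINESTRING({', '.join(parts)})"

print(line_string_wkt([(52.52, 13.405), (52.5, 13.37)]))
# LINESTRING(13.405 52.52, 13.37 52.5)
```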

backend/app/repositories/__init__.py

@@ -2,6 +2,7 @@
from backend.app.repositories.stations import StationRepository
from backend.app.repositories.tracks import TrackRepository
+from backend.app.repositories.combined_tracks import CombinedTrackRepository
from backend.app.repositories.train_schedules import TrainScheduleRepository
from backend.app.repositories.trains import TrainRepository
from backend.app.repositories.users import UserRepository

@@ -10,6 +11,7 @@ __all__ = [
    "StationRepository",
    "TrainScheduleRepository",
    "TrackRepository",
+    "CombinedTrackRepository",
    "TrainRepository",
    "UserRepository",
]

backend/app/repositories/combined_tracks.py (new file)

@@ -0,0 +1,73 @@
from __future__ import annotations

import json
from uuid import UUID

import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session

from backend.app.db.models import CombinedTrack
from backend.app.models import CombinedTrackCreate
from backend.app.repositories.base import BaseRepository


class CombinedTrackRepository(BaseRepository[CombinedTrack]):
    model = CombinedTrack

    def __init__(self, session: Session) -> None:
        super().__init__(session)

    def list_all(self) -> list[CombinedTrack]:
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
        """Check if a combined track already exists between two stations."""
        statement = sa.select(sa.exists().where(
            sa.and_(
                self.model.start_station_id == start_station_id,
                self.model.end_station_id == end_station_id
            )
        ))
        return bool(self.session.scalar(statement))

    def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
        """Extract constituent track IDs from a combined track."""
        try:
            return json.loads(combined_track.constituent_track_ids)
        except (json.JSONDecodeError, TypeError):
            return []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        if isinstance(value, UUID):
            return value
        return UUID(str(value))

    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        if len(coordinates) < 2:
            raise ValueError(
                "Combined track geometry requires at least two coordinate pairs")
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

    def create(self, data: CombinedTrackCreate) -> CombinedTrack:
        coordinates = list(data.coordinates)
        geometry = self._line_string(coordinates)
        constituent_track_ids_json = json.dumps(data.constituent_track_ids)
        combined_track = CombinedTrack(
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
            length_meters=data.length_meters,
            max_speed_kph=data.max_speed_kph,
            is_bidirectional=data.is_bidirectional,
            status=data.status,
            combined_geometry=geometry,
            constituent_track_ids=constituent_track_ids_json,
        )
        self.session.add(combined_track)
        return combined_track
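The `constituent_track_ids` column is plain JSON text, serialized on create and parsed on read with an empty-list fallback. The round-trip can be sketched in isolation (the two helper names are illustrative; the repository does this inline):

```python
import json

# Serialize the ID list for the Text column, as CombinedTrackRepository.create does.
def dump_ids(ids):
    return json.dumps(ids)

# Parse it back, as get_constituent_track_ids does: malformed JSON or a NULL
# column value (TypeError) both degrade to an empty list.
def load_ids(raw):
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return []

stored = dump_ids(["track-1", "track-2"])
print(load_ids(stored))   # ['track-1', 'track-2']
print(load_ids(None))     # []
print(load_ids("{oops"))  # []
```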

backend/app/repositories/tracks.py

@@ -7,7 +7,7 @@
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session

from backend.app.db.models import Track
-from backend.app.models import TrackCreate
+from backend.app.models import TrackCreate, TrackUpdate
from backend.app.repositories.base import BaseRepository

@@ -21,6 +21,102 @@ class TrackRepository(BaseRepository[Track]):
        statement = sa.select(self.model)
        return list(self.session.scalars(statement))

    def exists_by_osm_id(self, osm_id: str) -> bool:
        statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
        return bool(self.session.scalar(statement))

    def find_path_between_stations(
        self, start_station_id: str, end_station_id: str
    ) -> list[Track] | None:
        """Find the shortest path between two stations using existing tracks.

        Returns a list of tracks that form the path, or None if no path exists.
        """
        # Build adjacency list: station -> list of (neighbor_station, track)
        adjacency = self._build_track_graph()
        if start_station_id not in adjacency or end_station_id not in adjacency:
            return None
        # BFS to find shortest path
        from collections import deque

        queue = deque([(start_station_id, [])])  # (current_station, path_so_far)
        visited = {start_station_id}
        while queue:
            current_station, path = queue.popleft()
            if current_station == end_station_id:
                return path
            for neighbor, track in adjacency[current_station]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [track]))
        return None  # No path found

    def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
        """Build a graph representation of tracks: station -> [(neighbor_station, track), ...]"""
        tracks = self.list_all()
        graph: dict[str, list[tuple[str, Track]]] = {}
        for track in tracks:
            start_id = str(track.start_station_id)
            end_id = str(track.end_station_id)
            # Add bidirectional edges (assuming tracks are bidirectional)
            if start_id not in graph:
                graph[start_id] = []
            if end_id not in graph:
                graph[end_id] = []
            graph[start_id].append((end_id, track))
            graph[end_id].append((start_id, track))
        return graph

    def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
        """Combine the geometries of multiple tracks into a single coordinate sequence.

        Assumes tracks are in order and form a continuous path.
        """
        if not tracks:
            return []
        combined_coords: list[tuple[float, float]] = []
        for i, track in enumerate(tracks):
            # Extract coordinates from track geometry
            coords = self._extract_coordinates_from_track(track)
            if i == 0:
                # First track: add all coordinates
                combined_coords.extend(coords)
            else:
                # Subsequent tracks: skip the first coordinate (shared with previous track)
                combined_coords.extend(coords[1:])
        return combined_coords

    def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
        """Extract coordinate list from a track's geometry."""
        # Convert WKT string to WKTElement, then to shapely geometry
        from geoalchemy2.elements import WKTElement
        from geoalchemy2.shape import to_shape

        try:
            wkt_element = WKTElement(track.track_geometry)
            geom = to_shape(wkt_element)
            if hasattr(geom, "coords"):
                # For LineString, coords returns [(x, y), ...] where x=lon, y=lat;
                # convert to (lat, lon)
                return [(coord[1], coord[0]) for coord in geom.coords]
        except Exception:
            pass
        return []

    @staticmethod
    def _ensure_uuid(value: UUID | str) -> UUID:
        if isinstance(value, UUID):
@@ -30,7 +126,8 @@ class TrackRepository(BaseRepository[Track]):
    @staticmethod
    def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
        if len(coordinates) < 2:
-            raise ValueError("Track geometry requires at least two coordinate pairs")
+            raise ValueError(
+                "Track geometry requires at least two coordinate pairs")
        parts = [f"{lon} {lat}" for lat, lon in coordinates]
        return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)

@@ -38,6 +135,7 @@ class TrackRepository(BaseRepository[Track]):
        coordinates = list(data.coordinates)
        geometry = self._line_string(coordinates)
        track = Track(
+            osm_id=data.osm_id,
            name=data.name,
            start_station_id=self._ensure_uuid(data.start_station_id),
            end_station_id=self._ensure_uuid(data.end_station_id),
@@ -49,3 +147,26 @@ class TrackRepository(BaseRepository[Track]):
        )
        self.session.add(track)
        return track

    def update(self, track: Track, payload: TrackUpdate) -> Track:
        if payload.start_station_id is not None:
            track.start_station_id = self._ensure_uuid(
                payload.start_station_id)
        if payload.end_station_id is not None:
            track.end_station_id = self._ensure_uuid(payload.end_station_id)
        if payload.coordinates is not None:
            track.track_geometry = self._line_string(
                list(payload.coordinates))  # type: ignore[assignment]
        if payload.osm_id is not None:
            track.osm_id = payload.osm_id
        if payload.name is not None:
            track.name = payload.name
        if payload.length_meters is not None:
            track.length_meters = payload.length_meters
        if payload.max_speed_kph is not None:
            track.max_speed_kph = payload.max_speed_kph
        if payload.is_bidirectional is not None:
            track.is_bidirectional = payload.is_bidirectional
        if payload.status is not None:
            track.status = payload.status
        return track
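The pathfinding and geometry-stitching pair above can be sketched with stdlib types only: tracks as `(start, end, coords)` tuples instead of ORM rows. BFS returns the path with the fewest tracks; stitching drops each subsequent segment's first point because it is shared with the previous segment's last point. Note that, like the repository code, this assumes each segment's coordinates already run in path direction; a track traversed end-to-start would need its coordinates reversed first.

```python
from collections import deque

# Stand-in for find_path_between_stations: BFS over an undirected station graph.
def find_path(tracks, start, end):
    graph = {}
    for track in tracks:
        s, e = track[0], track[1]
        graph.setdefault(s, []).append((e, track))
        graph.setdefault(e, []).append((s, track))  # every track treated as bidirectional
    if start not in graph or end not in graph:
        return None
    queue = deque([(start, [])])
    visited = {start}
    while queue:
        station, path = queue.popleft()
        if station == end:
            return path
        for neighbor, track in graph[station]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [track]))
    return None

# Stand-in for combine_track_geometries: concatenate, skipping shared endpoints.
def combine(path):
    coords = []
    for i, (_, _, segment) in enumerate(path):
        coords.extend(segment if i == 0 else segment[1:])
    return coords

tracks = [
    ("A", "B", [(0.0, 0.0), (1.0, 1.0)]),
    ("B", "C", [(1.0, 1.0), (2.0, 2.0)]),
]
print(combine(find_path(tracks, "A", "C")))  # [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
```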

backend/app/services/combined_tracks.py (new file)

@@ -0,0 +1,79 @@
"""Application services for combined track operations."""
from __future__ import annotations

from sqlalchemy.orm import Session

from backend.app.models import CombinedTrackCreate, CombinedTrackModel
from backend.app.repositories import CombinedTrackRepository, TrackRepository


def create_combined_track(
    session: Session, start_station_id: str, end_station_id: str
) -> CombinedTrackModel | None:
    """Create a combined track between two stations using pathfinding.

    Returns the created combined track, or None if no path exists or
    a combined track already exists between these stations.
    """
    combined_track_repo = CombinedTrackRepository(session)
    track_repo = TrackRepository(session)
    # Check if a combined track already exists
    if combined_track_repo.exists_between_stations(start_station_id, end_station_id):
        return None
    # Find a path between the stations
    path_tracks = track_repo.find_path_between_stations(
        start_station_id, end_station_id)
    if not path_tracks:
        return None
    # Combine geometries
    combined_coords = track_repo.combine_track_geometries(path_tracks)
    if len(combined_coords) < 2:
        return None
    # Calculate total length
    total_length = sum(track.length_meters or 0 for track in path_tracks)
    # The combined speed limit is the minimum of the constituent tracks' limits
    max_speeds = [
        track.max_speed_kph for track in path_tracks if track.max_speed_kph]
    max_speed = min(max_speeds) if max_speeds else None
    # Record the constituent track IDs
    constituent_track_ids = [str(track.id) for track in path_tracks]
    # Create the combined track
    create_data = CombinedTrackCreate(
        start_station_id=start_station_id,
        end_station_id=end_station_id,
        coordinates=combined_coords,
        constituent_track_ids=constituent_track_ids,
        length_meters=total_length if total_length > 0 else None,
        max_speed_kph=max_speed,
        status="operational",
    )
    combined_track = combined_track_repo.create(create_data)
    session.commit()
    return CombinedTrackModel.model_validate(combined_track)


def get_combined_track(session: Session, combined_track_id: str) -> CombinedTrackModel | None:
    """Get a combined track by ID."""
    try:
        combined_track_repo = CombinedTrackRepository(session)
        combined_track = combined_track_repo.get(combined_track_id)
        return CombinedTrackModel.model_validate(combined_track)
    except LookupError:
        return None


def list_combined_tracks(session: Session) -> list[CombinedTrackModel]:
    """List all combined tracks."""
    combined_track_repo = CombinedTrackRepository(session)
    combined_tracks = combined_track_repo.list_all()
    return [CombinedTrackModel.model_validate(ct) for ct in combined_tracks]
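The length/speed aggregation inside `create_combined_track`, in isolation: lengths sum across constituent tracks (missing lengths count as 0), and the combined speed limit is the slowest constituent limit, since the slowest segment bounds the whole route. The helper name and tuple representation are illustrative only.

```python
def aggregate(path_tracks):
    """path_tracks: list of (length_meters, max_speed_kph); either may be None."""
    total_length = sum(length or 0 for length, _ in path_tracks)
    max_speeds = [speed for _, speed in path_tracks if speed]
    return (
        total_length if total_length > 0 else None,
        min(max_speeds) if max_speeds else None,  # slowest segment wins
    )

print(aggregate([(1200.0, 160), (800.0, None), (450.0, 120)]))  # (2450.0, 120)
print(aggregate([(None, None)]))  # (None, None)
```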

backend/app/services/network.py

@@ -148,7 +148,11 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]:
        if geometry is not None and LineString is not None
        else None
    )
-    if LineString is not None and shape is not None and isinstance(shape, LineString):
+    if (
+        LineString is not None
+        and shape is not None
+        and isinstance(shape, LineString)
+    ):
        coords_list: list[tuple[float, float]] = []
        for coord in shape.coords:
            lon = float(coord[0])

backend/app/services/tracks.py (new file)

@@ -0,0 +1,106 @@
"""Service layer for primary track management operations."""
from __future__ import annotations

from typing import Iterable

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
from backend.app.repositories import CombinedTrackRepository, TrackRepository


def list_tracks(session: Session) -> list[TrackModel]:
    repo = TrackRepository(session)
    tracks = repo.list_all()
    return [TrackModel.model_validate(track) for track in tracks]


def get_track(session: Session, track_id: str) -> TrackModel | None:
    repo = TrackRepository(session)
    track = repo.get(track_id)
    if track is None:
        return None
    return TrackModel.model_validate(track)


def create_track(session: Session, payload: TrackCreate) -> TrackModel:
    repo = TrackRepository(session)
    try:
        track = repo.create(payload)
        session.commit()
    except IntegrityError as exc:
        session.rollback()
        raise ValueError(
            "Track with the same station pair already exists") from exc
    return TrackModel.model_validate(track)


def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
    repo = TrackRepository(session)
    track = repo.get(track_id)
    if track is None:
        return None
    repo.update(track, payload)
    session.commit()
    return TrackModel.model_validate(track)


def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
    repo = TrackRepository(session)
    track = repo.get(track_id)
    if track is None:
        return False
    start_station_id = str(track.start_station_id)
    end_station_id = str(track.end_station_id)
    session.delete(track)
    session.commit()
    if regenerate:
        regenerate_combined_tracks(session, [start_station_id, end_station_id])
    return True


def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
    combined_repo = CombinedTrackRepository(session)
    station_id_set = set(station_ids)
    if not station_id_set:
        return []
    # Remove combined tracks touching these stations
    for combined in combined_repo.list_all():
        if {str(combined.start_station_id), str(combined.end_station_id)} & station_id_set:
            session.delete(combined)
    session.commit()
    # Rebuild combined tracks between affected station pairs
    from backend.app.services.combined_tracks import create_combined_track

    regenerated: list[CombinedTrackModel] = []
    station_list = list(station_id_set)
    for i in range(len(station_list)):
        for j in range(i + 1, len(station_list)):
            result = create_combined_track(
                session, station_list[i], station_list[j])
            if result is not None:
                regenerated.append(result)
    return regenerated


__all__ = [
    "list_tracks",
    "get_track",
    "create_track",
    "update_track",
    "delete_track",
    "regenerate_combined_tracks",
]
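The nested-loop rebuild in `regenerate_combined_tracks` visits every unordered pair of affected stations; the same shape can be expressed with `itertools.combinations`. Here `create_fn` is a stub standing in for `create_combined_track`, returning None where no path exists, mirroring the service's contract.

```python
from itertools import combinations

def regenerate(station_ids, create_fn):
    regenerated = []
    # combinations over the deduplicated set mirrors the i < j double loop.
    for a, b in combinations(sorted(set(station_ids)), 2):
        result = create_fn(a, b)
        if result is not None:
            regenerated.append(result)
    return regenerated

# Stub: pretend only the s1-s3 pair is connected by base tracks.
created = regenerate(
    ["s1", "s2", "s3", "s2"],
    lambda a, b: f"{a}->{b}" if (a, b) == ("s1", "s3") else None,
)
print(created)  # ['s1->s3']
```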


@@ -0,0 +1,19 @@
"""Template for new Alembic migration scripts."""
from __future__ import annotations
import sqlalchemy as sa
from alembic import op
revision = '63d02d67b39e'
down_revision = '20251011_01'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('tracks', sa.Column(
'osm_id', sa.String(length=32), nullable=True))
def downgrade() -> None:
op.drop_column('tracks', 'osm_id')


@@ -0,0 +1,75 @@
"""Template for new Alembic migration scripts."""
from __future__ import annotations
from sqlalchemy.dialects import postgresql
from geoalchemy2.types import Geometry
import sqlalchemy as sa
from alembic import op
revision = 'e7d4bb03da04'
down_revision = '63d02d67b39e'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"combined_tracks",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("name", sa.String(length=128), nullable=True),
sa.Column("start_station_id", postgresql.UUID(
as_uuid=True), nullable=False),
sa.Column("end_station_id", postgresql.UUID(
as_uuid=True), nullable=False),
sa.Column("length_meters", sa.Numeric(10, 2), nullable=True),
sa.Column("max_speed_kph", sa.Integer(), nullable=True),
sa.Column(
"is_bidirectional",
sa.Boolean(),
nullable=False,
server_default=sa.text("true"),
),
sa.Column(
"status", sa.String(length=32), nullable=False, server_default="planned"
),
sa.Column(
"combined_geometry",
Geometry(geometry_type="LINESTRING", srid=4326),
nullable=False,
),
sa.Column("constituent_track_ids", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.ForeignKeyConstraint(
["start_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.ForeignKeyConstraint(
["end_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.UniqueConstraint(
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
),
)
op.create_index(
"ix_combined_tracks_geometry", "combined_tracks", ["combined_geometry"], postgresql_using="gist"
)
def downgrade() -> None:
op.drop_index("ix_combined_tracks_geometry", table_name="combined_tracks")
op.drop_table("combined_tracks")

View File

@@ -9,12 +9,7 @@ from pathlib import Path
from typing import Callable, Sequence
from backend.app.core.osm_config import DEFAULT_REGIONS
from backend.scripts import (
stations_import,
stations_load,
tracks_import,
tracks_load,
)
from backend.scripts import stations_import, stations_load, tracks_import, tracks_load
@dataclass(slots=True)

View File

@@ -98,14 +98,12 @@ def normalize_station_elements(
if not name:
continue
raw_code = tags.get("ref") or tags.get(
"railway:ref") or tags.get("local_ref")
raw_code = tags.get("ref") or tags.get("railway:ref") or tags.get("local_ref")
code = str(raw_code) if raw_code is not None else None
elevation_tag = tags.get("ele") or tags.get("elevation")
try:
elevation = float(
elevation_tag) if elevation_tag is not None else None
elevation = float(elevation_tag) if elevation_tag is not None else None
except (TypeError, ValueError):
elevation = None

View File

@@ -56,7 +56,8 @@ def build_overpass_query(region_name: str) -> str:
regions = DEFAULT_REGIONS
else:
regions = tuple(
region for region in DEFAULT_REGIONS if region.name == region_name)
region for region in DEFAULT_REGIONS if region.name == region_name
)
if not regions:
available = ", ".join(region.name for region in DEFAULT_REGIONS)
msg = f"Unknown region {region_name}. Available regions: [{available}]"
@@ -86,7 +87,9 @@ def perform_request(query: str) -> dict[str, Any]:
return json.loads(payload)
def normalize_track_elements(elements: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
def normalize_track_elements(
elements: Iterable[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert Overpass way elements into TrackCreate-compatible payloads."""
tracks: list[dict[str, Any]] = []

View File

@@ -22,6 +22,7 @@ from backend.app.repositories import StationRepository, TrackRepository
@dataclass(slots=True)
class ParsedTrack:
coordinates: list[tuple[float, float]]
osm_id: str | None = None
name: str | None = None
length_meters: float | None = None
max_speed_kph: float | None = None
@@ -91,7 +92,8 @@ def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTra
coordinates = entry.get("coordinates")
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
raise ValueError(
"Invalid track entry: 'coordinates' must contain at least two points")
"Invalid track entry: 'coordinates' must contain at least two points"
)
processed_coordinates: list[tuple[float, float]] = []
for pair in coordinates:
@@ -106,10 +108,12 @@ def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTra
max_speed = _safe_float(entry.get("maxSpeedKph"))
status = entry.get("status", "operational")
is_bidirectional = entry.get("isBidirectional", True)
osm_id = entry.get("osmId")
parsed.append(
ParsedTrack(
coordinates=processed_coordinates,
osm_id=str(osm_id) if osm_id else None,
name=str(name) if name else None,
length_meters=length,
max_speed_kph=max_speed,
@@ -133,6 +137,12 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
}
for track_data in tracks:
# Skip if track with this OSM ID already exists
if track_data.osm_id and track_repo.exists_by_osm_id(track_data.osm_id):
print(
f"Skipping track {track_data.osm_id} - already exists by OSM ID")
continue
start_station = _nearest_station(
track_data.coordinates[0],
station_index,
@@ -145,23 +155,31 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
)
if not start_station or not end_station:
print(
f"Skipping track {track_data.osm_id} - no start/end stations found")
continue
if start_station.id == end_station.id:
print(
f"Skipping track {track_data.osm_id} - start and end stations are the same")
continue
pair = (start_station.id, end_station.id)
if pair in existing_pairs:
print(
f"Skipping track {track_data.osm_id} - station pair {pair} already exists")
continue
length = track_data.length_meters or _polyline_length(
track_data.coordinates)
track_data.coordinates
)
max_speed = (
int(round(track_data.max_speed_kph))
if track_data.max_speed_kph is not None
else None
)
create_schema = TrackCreate(
osm_id=track_data.osm_id,
name=track_data.name,
start_station_id=start_station.id,
end_station_id=end_station.id,
@@ -229,7 +247,9 @@ def _to_point(geometry: WKBElement | WKTElement | Any):
try:
point = to_shape(geometry)
return point if getattr(point, "geom_type", None) == "Point" else None
except Exception: # pragma: no cover - defensive, should not happen with valid geometry
except (
Exception
): # pragma: no cover - defensive, should not happen with valid geometry
return None
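The `exists_by_osm_id` dedupe above can be exercised without a database; a sketch of the skip logic using an in-memory set of already-loaded OSM IDs (the real check goes through `TrackRepository.exists_by_osm_id`, whose implementation is outside this diff):

```python
def filter_new_tracks(parsed: list[dict], seen_osm_ids: set[str]) -> list[dict]:
    """Yield only tracks whose OSM ID has not been loaded yet."""
    loaded = []
    for track in parsed:
        osm_id = track.get("osm_id")
        if osm_id and osm_id in seen_osm_ids:
            continue  # mirrors the "already exists by OSM ID" skip
        if osm_id:
            seen_osm_ids.add(osm_id)
        loaded.append(track)
    return loaded


# Tracks without an OSM ID are never deduplicated on this key.
tracks = [{"osm_id": "way/1"}, {"osm_id": "way/1"}, {"osm_id": None}]
assert len(filter_new_tracks(tracks, set())) == 2
```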

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, List
from uuid import uuid4
import pytest
from backend.app.models import CombinedTrackModel
from backend.app.repositories.combined_tracks import CombinedTrackRepository
from backend.app.repositories.tracks import TrackRepository
from backend.app.services.combined_tracks import create_combined_track
@dataclass
class DummySession:
added: List[Any] = field(default_factory=list)
scalars_result: List[Any] = field(default_factory=list)
scalar_result: Any = None
statements: List[Any] = field(default_factory=list)
committed: bool = False
rolled_back: bool = False
closed: bool = False
def add(self, instance: Any) -> None:
self.added.append(instance)
def add_all(self, instances: list[Any]) -> None:
self.added.extend(instances)
def scalars(self, statement: Any) -> list[Any]:
self.statements.append(statement)
return list(self.scalars_result)
def scalar(self, statement: Any) -> Any:
self.statements.append(statement)
return self.scalar_result
def flush(
self, _objects: list[Any] | None = None
) -> None: # pragma: no cover - optional
return None
def commit(self) -> None: # pragma: no cover - optional
self.committed = True
def rollback(self) -> None: # pragma: no cover - optional
self.rolled_back = True
def close(self) -> None: # pragma: no cover - optional
self.closed = True
def _now() -> datetime:
return datetime.now(timezone.utc)
def test_combined_track_model_round_trip() -> None:
timestamp = _now()
combined_track = CombinedTrackModel(
id="combined-track-1",
start_station_id="station-1",
end_station_id="station-2",
length_meters=3000.0,
max_speed_kph=100,
status="operational",
is_bidirectional=True,
coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)],
constituent_track_ids=["track-1", "track-2"],
created_at=timestamp,
updated_at=timestamp,
)
assert combined_track.length_meters == 3000.0
assert combined_track.start_station_id != combined_track.end_station_id
assert len(combined_track.coordinates) == 3
assert len(combined_track.constituent_track_ids) == 2
def test_combined_track_repository_create() -> None:
"""Test creating a combined track through the repository."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Create test data
from backend.app.models import CombinedTrackCreate
create_data = CombinedTrackCreate(
start_station_id="550e8400-e29b-41d4-a716-446655440000",
end_station_id="550e8400-e29b-41d4-a716-446655440001",
coordinates=[(52.52, 13.405), (52.6, 13.5)],
constituent_track_ids=["track-1"],
length_meters=1500.0,
max_speed_kph=120,
status="operational",
)
combined_track = repo.create(create_data)
assert combined_track.start_station_id is not None
assert combined_track.end_station_id is not None
assert combined_track.length_meters == 1500.0
assert combined_track.max_speed_kph == 120
assert combined_track.status == "operational"
assert session.added and session.added[0] is combined_track
def test_combined_track_repository_exists_between_stations() -> None:
"""Test checking if combined track exists between stations."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Initially should not exist (scalar_result is None by default)
assert not repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
# Simulate existing combined track
session.scalar_result = True
assert repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
def test_combined_track_service_create_no_path() -> None:
"""Test creating combined track when no path exists."""
# Mock session and repositories
session = DummySession()
# Mock TrackRepository to return no path
class MockTrackRepository:
def __init__(self, session):
pass
def find_path_between_stations(self, start_id, end_id):
return None
# Mock CombinedTrackRepository
class MockCombinedTrackRepository:
def __init__(self, session):
pass
def exists_between_stations(self, start_id, end_id):
return False
# Patch the service to use mock repositories
import backend.app.services.combined_tracks as service_module
original_track_repo = service_module.TrackRepository
original_combined_repo = service_module.CombinedTrackRepository
service_module.TrackRepository = MockTrackRepository
service_module.CombinedTrackRepository = MockCombinedTrackRepository
try:
result = create_combined_track(
session, # type: ignore[arg-type]
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
assert result is None
finally:
# Restore original classes
service_module.TrackRepository = original_track_repo
service_module.CombinedTrackRepository = original_combined_repo
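The try/finally attribute swap above can also be written with `unittest.mock.patch.object`, which restores the original automatically even if the assertion fails. A self-contained sketch against stand-in classes (the real target would be the `TrackRepository` name in `backend.app.services.combined_tracks`):

```python
from unittest import mock


class RealRepo:
    def find_path_between_stations(self, start: str, end: str):
        return ["track-1"]


class Service:
    repo_cls = RealRepo  # stand-in for the module-level repository name

    def combine(self, start: str, end: str):
        return self.repo_cls().find_path_between_stations(start, end)


class NoPathRepo:
    def find_path_between_stations(self, start: str, end: str):
        return None


service = Service()
with mock.patch.object(Service, "repo_cls", NoPathRepo):
    assert service.combine("a", "b") is None  # patched: no path found
assert service.combine("a", "b") == ["track-1"]  # original restored
```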

View File

@@ -50,8 +50,7 @@ def test_network_snapshot_prefers_repository_data(
track = sample_entities["track"]
train = sample_entities["train"]
monkeypatch.setattr(StationRepository, "list_active",
lambda self: [station])
monkeypatch.setattr(StationRepository, "list_active", lambda self: [station])
monkeypatch.setattr(TrackRepository, "list_all", lambda self: [track])
monkeypatch.setattr(TrainRepository, "list_all", lambda self: [train])
@@ -59,8 +58,7 @@ def test_network_snapshot_prefers_repository_data(
assert snapshot["stations"]
assert snapshot["stations"][0]["name"] == station.name
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(
track.length_meters)
assert snapshot["tracks"][0]["lengthMeters"] == pytest.approx(track.length_meters)
assert snapshot["trains"][0]["designation"] == train.designation
assert snapshot["trains"][0]["operatingTrackIds"] == []
@@ -76,5 +74,4 @@ def test_network_snapshot_falls_back_when_repositories_empty(
assert snapshot["stations"]
assert snapshot["trains"]
assert any(station["name"] ==
"Central" for station in snapshot["stations"])
assert any(station["name"] == "Central" for station in snapshot["stations"])

View File

@@ -58,7 +58,9 @@ def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None:
assert labels == ["Load stations", "Load tracks"]
def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
def test_main_dry_run_lists_plan(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
raise AssertionError("runner should not be invoked during dry run")
@@ -76,7 +78,9 @@ def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path
assert "Load tracks" in captured
def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_main_executes_stages_in_order(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
calls: list[str] = []
def make_import(name: str):
@@ -98,14 +102,12 @@ def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path
return runner
monkeypatch.setattr(osm_refresh.stations_import, "main",
make_import("stations_import"))
monkeypatch.setattr(osm_refresh.tracks_import, "main",
make_import("tracks_import"))
monkeypatch.setattr(osm_refresh.stations_load, "main",
make_load("stations_load"))
monkeypatch.setattr(osm_refresh.tracks_load, "main",
make_load("tracks_load"))
monkeypatch.setattr(
osm_refresh.stations_import, "main", make_import("stations_import")
)
monkeypatch.setattr(osm_refresh.tracks_import, "main", make_import("tracks_import"))
monkeypatch.setattr(osm_refresh.stations_load, "main", make_load("stations_load"))
monkeypatch.setattr(osm_refresh.tracks_load, "main", make_load("tracks_load"))
exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])
@@ -118,7 +120,9 @@ def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path
]
def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_main_skip_import_flags(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
station_json = tmp_path / "stations.json"
station_json.write_text("{}", encoding="utf-8")
track_json = tmp_path / "tracks.json"
@@ -139,8 +143,7 @@ def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path)
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
monkeypatch.setattr(osm_refresh.stations_load,
"main", record("stations_load"))
monkeypatch.setattr(osm_refresh.stations_load, "main", record("stations_load"))
monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load"))
exit_code = osm_refresh.main(

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import pytest
from fastapi.testclient import TestClient
from backend.app.api import tracks as tracks_api
from backend.app.main import app
from backend.app.models import CombinedTrackModel, TrackModel
client = TestClient(app)
def _track_model(track_id: str = "track-1") -> TrackModel:
now = datetime.now(timezone.utc)
return TrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=None,
max_speed_kph=None,
status="planned",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _combined_model(track_id: str = "combined-1") -> CombinedTrackModel:
now = datetime.now(timezone.utc)
return CombinedTrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=1000,
max_speed_kph=120,
status="operational",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
constituent_track_ids=["track-1", "track-2"],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _authenticate() -> str:
response = client.post(
"/api/auth/login",
json={"username": "demo", "password": "railgame123"},
)
assert response.status_code == 200
return response.json()["accessToken"]
def test_list_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "list_tracks", lambda db: [_track_model()])
response = client.get(
"/api/tracks",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert isinstance(payload, list)
assert payload[0]["id"] == "track-1"
def test_get_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "get_track", lambda db, track_id: None)
response = client.get(
"/api/tracks/not-found",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_create_track_calls_service(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
captured: dict[str, Any] = {}
payload = {
"startStationId": "station-a",
"endStationId": "station-b",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
}
def fake_create(db: Any, data: Any) -> TrackModel:
assert data.start_station_id == "station-a"
captured["payload"] = data
return _track_model("track-new")
monkeypatch.setattr(tracks_api, "create_track", fake_create)
response = client.post(
"/api/tracks",
json=payload,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 201
body = response.json()
assert body["id"] == "track-new"
assert captured["payload"].end_station_id == "station-b"
def test_delete_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "delete_track", lambda db, tid, regenerate=False: False
)
response = client.delete(
"/api/tracks/missing",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_delete_track_success(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
seen: dict[str, Any] = {}
def fake_delete(db: Any, track_id: str, regenerate: bool = False) -> bool:
seen["track_id"] = track_id
seen["regenerate"] = regenerate
return True
monkeypatch.setattr(tracks_api, "delete_track", fake_delete)
response = client.delete(
"/api/tracks/track-99",
params={"regenerate": "true"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 204
assert seen["track_id"] == "track-99"
assert seen["regenerate"] is True
def test_list_combined_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "list_combined_tracks", lambda db: [_combined_model()]
)
response = client.get(
"/api/tracks/combined",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert len(payload) == 1
assert payload[0]["id"] == "combined-1"

View File

@@ -103,10 +103,12 @@ def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
track_repo_instance = DummyTrackRepository(session_instance)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(tracks_load, "StationRepository",
lambda session: station_repo_instance)
monkeypatch.setattr(tracks_load, "TrackRepository",
lambda session: track_repo_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[
@@ -137,20 +139,26 @@ def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> No
DummyStation(id="station-b", location=_point(52.6, 13.5)),
],
)
existing_track = type("ExistingTrack", (), {
"start_station_id": "station-a",
"end_station_id": "station-b",
})
existing_track = type(
"ExistingTrack",
(),
{
"start_station_id": "station-a",
"end_station_id": "station-b",
},
)
track_repo_instance = DummyTrackRepository(
session_instance,
existing=[existing_track],
)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(tracks_load, "StationRepository",
lambda session: station_repo_instance)
monkeypatch.setattr(tracks_load, "TrackRepository",
lambda session: track_repo_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[
@@ -168,7 +176,9 @@ def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> No
assert not track_repo_instance.created
def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch) -> None:
def test_load_tracks_skips_when_station_too_far(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_instance = DummySession()
station_repo_instance = DummyStationRepository(
session_instance,
@@ -179,10 +189,12 @@ def test_load_tracks_skips_when_station_too_far(monkeypatch: pytest.MonkeyPatch)
track_repo_instance = DummyTrackRepository(session_instance)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(tracks_load, "StationRepository",
lambda session: station_repo_instance)
monkeypatch.setattr(tracks_load, "TrackRepository",
lambda session: track_repo_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[

View File

@@ -70,8 +70,8 @@ The system interacts with:
- User registration and authentication
- Railway network building and management
- Train scheduling and simulation
- Map visualization and interaction
- Train scheduling and simulation
- Leaderboards and user profiles
**Out of Scope:**

View File

@@ -65,6 +65,7 @@
"integrity": "sha512-2BCOP7TN8M+gVDj7/ht3hsaO/B/n5oDbiAyyvnRlNOs+u1o+JWNYTQrmpuNp1/Wq2gcFrI01JAW+paEKDMx/CA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.27.1",
"@babel/generator": "^7.28.3",
@@ -1425,6 +1426,7 @@
"integrity": "sha512-2Q7WS25j4pS1cS8yw3d6buNCVJukOTeQ39bAnwR6sOJbaxvyCGebzTMypDFN82CxBLnl+lSWVdCCWbRY6y9yZQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1442,6 +1444,7 @@
"integrity": "sha512-RFA/bURkcKzx/X9oumPG9Vp3D3JUgus/d0b67KB0t5S/raciymilkOa66olh78MUI92QLbEJevO7rvqU/kjwKA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/prop-types": "*",
"csstype": "^3.0.2"
@@ -1493,6 +1496,7 @@
"integrity": "sha512-n1H6IcDhmmUEG7TNVSspGmiHHutt7iVKtZwRppD7e04wha5MrkV1h3pti9xQLcCMt6YWsncpoT0HMjkH1FNwWQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.46.0",
"@typescript-eslint/types": "8.46.0",
@@ -1829,6 +1833,7 @@
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -2202,6 +2207,7 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"baseline-browser-mapping": "^2.8.9",
"caniuse-lite": "^1.0.30001746",
@@ -2852,6 +2858,7 @@
"deprecated": "This version is no longer supported. Please see https://eslint.org/version-support for other options.",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@eslint-community/regexpp": "^4.6.1",
@@ -4538,7 +4545,8 @@
"version": "1.9.4",
"resolved": "https://registry.npmjs.org/leaflet/-/leaflet-1.9.4.tgz",
"integrity": "sha512-nxS1ynzJOmOlHp+iL3FyWqK89GtNL8U8rvlMOsQdTTssxZwCXh8N2NB3GDQOL+YR3XnWyZAxwQixURb+FA74PA==",
"license": "BSD-2-Clause"
"license": "BSD-2-Clause",
"peer": true
},
"node_modules/levn": {
"version": "0.4.1",
@@ -5323,6 +5331,7 @@
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0"
},
@@ -5335,6 +5344,7 @@
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0",
"scheduler": "^0.23.2"
@@ -6240,6 +6250,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -6328,6 +6339,7 @@
"integrity": "sha512-j3lYzGC3P+B5Yfy/pfKNgVEg4+UtcIJcVRt2cDjIOmhLourAqPqf8P7acgxeiSgUB7E3p2P8/3gNIgDLpwzs4g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.21.3",
"postcss": "^8.4.43",

View File

@@ -19,11 +19,13 @@ function App(): JSX.Element {
startId: string | null;
endId: string | null;
}>({ startId: null, endId: null });
const [selectedTrackId, setSelectedTrackId] = useState<string | null>(null);
useEffect(() => {
if (status !== 'success' || !data?.stations.length) {
setFocusedStationId(null);
setRouteSelection({ startId: null, endId: null });
setSelectedTrackId(null);
return;
}
@@ -92,8 +94,16 @@ function App(): JSX.Element {
return stationById.get(focusedStationId) ?? null;
}, [data, focusedStationId, stationById]);
const selectedTrack = useMemo(() => {
if (!data || !selectedTrackId) {
return null;
}
return data.tracks.find((track) => track.id === selectedTrackId) ?? null;
}, [data, selectedTrackId]);
const handleStationSelection = (stationId: string) => {
setFocusedStationId(stationId);
setSelectedTrackId(null);
setRouteSelection((current) => {
if (!current.startId || (current.startId && current.endId)) {
return { startId: stationId, endId: null };
@@ -111,6 +121,16 @@ function App(): JSX.Element {
setRouteSelection({ startId: null, endId: null });
};
const handleCreateTrack = () => {
if (!routeSelection.startId || !routeSelection.endId) {
return;
}
// TODO: Implement track creation API call
alert(
`Creating track between ${stationById.get(routeSelection.startId)?.name} and ${stationById.get(routeSelection.endId)?.name}`
);
};
return (
<div className="app-shell">
<header className="app-header">
@@ -146,7 +166,9 @@ function App(): JSX.Element {
startStationId={routeSelection.startId}
endStationId={routeSelection.endId}
routeSegments={routeSegments}
selectedTrackId={selectedTrackId}
onStationClick={handleStationSelection}
onTrackClick={setSelectedTrackId}
/>
</div>
<div className="route-panel">
@@ -206,6 +228,19 @@ function App(): JSX.Element {
</ol>
</div>
)}
{routeSelection.startId &&
routeSelection.endId &&
routeComputation.error && (
<div className="route-panel__actions">
<button
type="button"
className="primary-button"
onClick={handleCreateTrack}
>
Create Track
</button>
</div>
)}
</div>
<div className="grid">
<div>
@@ -318,6 +353,49 @@ function App(): JSX.Element {
</dl>
</div>
)}
{selectedTrack && (
<div className="selected-track">
<h3>Selected Track</h3>
<dl>
<div>
<dt>Start Station</dt>
<dd>
{stationById.get(selectedTrack.startStationId)?.name ??
'Unknown'}
</dd>
</div>
<div>
<dt>End Station</dt>
<dd>
{stationById.get(selectedTrack.endStationId)?.name ??
'Unknown'}
</dd>
</div>
<div>
<dt>Length</dt>
<dd>
{selectedTrack.lengthMeters > 0
? `${(selectedTrack.lengthMeters / 1000).toFixed(2)} km`
: 'N/A'}
</dd>
</div>
<div>
<dt>Max Speed</dt>
<dd>{selectedTrack.maxSpeedKph} km/h</dd>
</div>
{selectedTrack.status && (
<div>
<dt>Status</dt>
<dd>{selectedTrack.status}</dd>
</div>
)}
<div>
<dt>Bidirectional</dt>
<dd>{selectedTrack.isBidirectional ? 'Yes' : 'No'}</dd>
</div>
</dl>
</div>
)}
</div>
)}
</section>

View File

@@ -19,7 +19,9 @@ interface NetworkMapProps {
readonly startStationId?: string | null;
readonly endStationId?: string | null;
readonly routeSegments?: LatLngExpression[][];
readonly selectedTrackId?: string | null;
readonly onStationClick?: (stationId: string) => void;
readonly onTrackClick?: (trackId: string) => void;
}
interface StationPosition {
@@ -36,7 +38,9 @@ export function NetworkMap({
startStationId,
endStationId,
routeSegments = [],
selectedTrackId,
onStationClick,
onTrackClick,
}: NetworkMapProps): JSX.Element {
const stationPositions = useMemo<StationPosition[]>(() => {
return snapshot.stations.map((station) => ({
@@ -117,13 +121,43 @@ export function NetworkMap({
url="https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png"
/>
{focusedPosition ? <StationFocus position={focusedPosition} /> : null}
{trackSegments.map((segment, index) => (
<Polyline
key={`track-${index}`}
positions={segment}
pathOptions={{ color: '#334155', weight: 3, opacity: 0.8 }}
/>
))}
{trackSegments.map((segment, index) => {
const track = snapshot.tracks[index];
const isSelected = track.id === selectedTrackId;
return (
<Polyline
key={`track-${track.id}`}
positions={segment}
pathOptions={{
color: isSelected ? '#3b82f6' : '#334155',
weight: isSelected ? 5 : 3,
opacity: 0.8,
}}
eventHandlers={{
click: () => {
onTrackClick?.(track.id);
},
}}
>
<Tooltip>
{track.startStationId} → {track.endStationId}
<br />
Length:{' '}
{track.lengthMeters > 0
? `${(track.lengthMeters / 1000).toFixed(1)} km`
: 'N/A'}
<br />
Max Speed: {track.maxSpeedKph} km/h
{track.status && (
<>
<br />
Status: {track.status}
</>
)}
</Tooltip>
</Polyline>
);
})}
{routeSegments.map((segment, index) => (
<Polyline
key={`route-${index}`}

View File

@@ -266,6 +266,12 @@ body {
margin: 0;
}
.route-panel__actions {
margin-top: 1rem;
display: flex;
gap: 0.75rem;
}
.selected-station {
margin-top: 1rem;
padding: 1rem 1.25rem;
@@ -305,6 +311,45 @@ body {
color: rgba(226, 232, 240, 0.92);
}
.selected-track {
margin-top: 1rem;
padding: 1rem 1.25rem;
border-radius: 12px;
border: 1px solid rgba(59, 130, 246, 0.35);
background: rgba(37, 99, 235, 0.18);
display: grid;
gap: 0.75rem;
}
.selected-track h3 {
color: rgba(226, 232, 240, 0.9);
font-size: 1.1rem;
}
.selected-track dl {
display: grid;
gap: 0.45rem;
}
.selected-track dl > div {
display: flex;
align-items: baseline;
justify-content: space-between;
gap: 0.75rem;
}
.selected-track dt {
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 0.08em;
color: rgba(226, 232, 240, 0.6);
}
.selected-track dd {
font-size: 0.95rem;
color: rgba(226, 232, 240, 0.92);
}
@media (min-width: 768px) {
.snapshot-layout {
gap: 2rem;

View File

@@ -62,13 +62,23 @@ def check_database_url():
print(f"Using database: {database_url}")
def run_command(cmd, cwd=None, description=""):
def run_command(cmd, cwd=None, description="", env=None):
"""Run a shell command and return the result."""
print(f"\n>>> {description}")
print(f"Running: {' '.join(cmd)}")
try:
result = subprocess.run(cmd, cwd=cwd, check=True,
capture_output=True, text=True)
env_vars = os.environ.copy()
if env:
env_vars.update(env)
env_vars.setdefault("PYTHONPATH", "/app")
result = subprocess.run(
cmd,
cwd=cwd,
check=True,
capture_output=True,
text=True,
env=env_vars,
)
if result.stdout:
print(result.stdout)
return result
@@ -86,7 +96,7 @@ def run_migrations():
run_command(
['alembic', 'upgrade', 'head'],
cwd='backend',
description="Running database migrations"
description="Running database migrations",
)