Compare commits

..

13 Commits

Author SHA1 Message Date
68048ff574 feat: Add combined track functionality with repository and service layers
Some checks failed
Backend CI / lint-and-test (push) Failing after 2m27s
Frontend CI / lint-and-build (push) Successful in 57s
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks.
- Implemented logic to create combined tracks based on existing tracks between two stations.
- Added methods to check for existing combined tracks and retrieve constituent track IDs.
- Enhanced TrackModel and TrackRepository to support OSM ID and track updates.
- Created migration scripts for adding combined tracks table and OSM ID to tracks.
- Updated services and API endpoints to handle combined track operations.
- Added tests for combined track creation, repository methods, and API interactions.
2025-11-10 14:12:28 +01:00
f73ab7ad14 fix: simplify import statements in osm_refresh.py 2025-10-11 22:11:45 +02:00
3c97c47f7e feat: add track selection and display functionality in the app 2025-10-11 22:10:53 +02:00
c35049cd54 fix: formatting (black)
Some checks failed
Backend CI / lint-and-test (push) Failing after 1m54s
2025-10-11 21:58:32 +02:00
f9086d2d04 feat: initialize database with demo data on first run and update README
Some checks failed
Backend CI / lint-and-test (push) Failing after 1m33s
Frontend CI / lint-and-build (push) Successful in 17s
2025-10-11 21:52:30 +02:00
25ca7ab196 Add OSM Track Harvesting Policy and demo database initialization script
- Updated documentation to include OSM Track Harvesting Policy with details on railway types, service filters, usage filters, and geometry guardrails.
- Introduced a new script `init_demo_db.py` to automate the database setup process, including environment checks, running migrations, and loading OSM fixtures for demo data.
2025-10-11 21:37:25 +02:00
0b84ee953e fix: correct grammar and formatting in README 2025-10-11 21:37:01 +02:00
8877380f21 fix: revert README structure 2025-10-11 21:07:27 +02:00
4393f17c45 refactor: simplify stage plan return type and enhance test coverage for OSM refresh 2025-10-11 20:37:25 +02:00
e10b2ee71c docs: fix formatting 2025-10-11 20:23:08 +02:00
1c8adb36fe feat: Add OSM refresh script and update loading scripts for improved database handling 2025-10-11 20:21:14 +02:00
c2927f2f60 feat: Enhance track model and import functionality
- Added new fields to TrackModel: status, is_bidirectional, and coordinates.
- Updated network service to handle new track attributes and geometry extraction.
- Introduced CLI scripts for importing and loading tracks from OpenStreetMap.
- Implemented normalization of track elements to ensure valid geometries.
- Enhanced tests for track model, network service, and import/load scripts.
- Updated frontend to accommodate new track attributes and improve route computation.
- Documented OSM ingestion process in architecture and runtime views.
2025-10-11 19:54:10 +02:00
090dca29c2 feat: add route selection functionality and improve station handling
- Added `vitest` for testing and created initial tests for route utilities.
- Implemented route selection logic in the App component, allowing users to select start and end stations.
- Updated the NetworkMap component to reflect focused and selected stations, including visual indicators for start and end stations.
- Enhanced the route panel UI to display selected route information and estimated lengths.
- Introduced utility functions for building track adjacency and computing routes based on selected stations.
- Improved styling for route selection and station list items to enhance user experience.
2025-10-11 19:28:35 +02:00
48 changed files with 541929 additions and 276 deletions

287
README.md
View File

@@ -1,146 +1,213 @@
# Rail Game
A browser-based railway simulation game using real world railway maps from OpenStreetMap.
A browser-based railway simulation game using real-world railway maps from OpenStreetMap.
## At a glance
- Frontend: React + Vite (TypeScript)
- Backend: Python (FastAPI, SQLAlchemy)
- Database: PostgreSQL with PostGIS (spatial types)
- Mapping: Leaflet + OpenStreetMap
## Features
- Real world railway maps
- Interactive Leaflet map preview of the demo network snapshot
- Build and manage your own railway network
- Dynamic train schedules
- Real-world railway maps
- Interactive Leaflet map preview of a demo network snapshot
- Build and manage your railway network
- Dynamic train schedules and simulated trains
## Architecture
## Current project layout
The project is built using the following technologies:
This repository contains a full-stack demo app (frontend + backend), supporting scripts, docs and infra. Key folders:
- Frontend: HTML5, CSS3, JavaScript, React
- Backend: Python, FastAPI, Flask, SQLAlchemy
- Database: PostgreSQL with PostGIS extension
- Mapping: Leaflet, OpenStreetMap
- `backend/` — FastAPI application, models, services, migration scripts and backend tests.
- `frontend/` — React app (Vite) and frontend tests.
- `docs/` — Architecture docs and ADRs.
- `infra/` — Deployment assets (Dockerfiles, compose files, init scripts).
- `data/` — Fixtures and imported OSM snapshots.
- `scripts/` — Utility scripts (precommit helpers, setup hooks).
- `tests/` — End-to-end tests and cross-cutting tests.
## Project Structure
Planned structure for code and assets (folders created as needed):
```text
rail-game/
|-- backend/
| |-- app/
| | |-- api/ # FastAPI/Flask route handlers
| | |-- core/ # Config, startup, shared utilities
| | |-- models/ # SQLAlchemy models and schemas
| | |-- services/ # Domain logic and service layer
| | `-- websocket/ # Real-time communication handlers
| |-- tests/ # Backend unit and integration tests
| `-- requirements/ # Backend dependency lockfiles
|-- frontend/
| |-- public/ # Static assets served as-is
| |-- src/
| | |-- components/ # Reusable React components
| | |-- hooks/ # Custom React hooks
| | |-- pages/ # Top-level routed views
| | |-- state/ # Redux/Context stores and slices
| | |-- styles/ # Global and modular stylesheets
| | `-- utils/ # Frontend helpers and formatters
| `-- tests/ # Frontend unit and integration tests
|-- docs/ # Architecture docs, ADRs, guides
|-- infra/ # Deployment, IaC, Docker, CI workflows
|-- scripts/ # Tooling for setup, linting, migrations
|-- data/ # Seed data, fixtures, import/export tooling
`-- tests/ # End-to-end and cross-cutting tests
```
Use `infra/` to capture deployment assets (Dockerfiles, compose files, Terraform) and `.github/` for automation. Shared code that crosses layers should live in the respective service directories or dedicated packages under `backend/`.
Refer to the in-repo `docs/` for architecture decisions and deeper design notes.
## Installation
1. Clone the repository:
Below are concise, verified steps for getting the project running locally. Commands show both PowerShell (Windows) and Bash/macOS/Linux variants where they differ.
```bash
git clone https://github.com/zwitschi/rail-game.git
cd rail-game
```
## Prerequisites
2. Set up the backend (from the project root):
- Git
- Python 3.10+ (3.11 recommended) and pip
- Node.js 16+ (or the version required by `frontend/package.json`)
- PostgreSQL with PostGIS if you want to run the full DB-backed stack locally
- Docker & Docker Compose (optional, for containerized dev)
```bash
python -m venv .venv
.\.venv\Scripts\activate
python -m pip install -e .[dev]
```
### Clone repository
3. Set up the frontend:
PowerShell / Bash
```bash
cd frontend
npm install
cd ..
```
git clone https://github.com/zwitschi/rail-game.git
cd rail-game
4. Copy the sample environment file and adjust the database URLs per environment:
### Backend: create virtual environment and install
```bash
copy .env.example .env # PowerShell: Copy-Item .env.example .env
```
PowerShell
`DATABASE_URL`, `TEST_DATABASE_URL`, and `ALEMBIC_DATABASE_URL` control the runtime, test, and migration connections respectively.
5. (Optional) Point Git to the bundled hooks: `pwsh scripts/setup_hooks.ps1`.
6. Run database migrations to set up the schema:
python -m venv .venv
.\.venv\Scripts\Activate.ps1
python -m pip install -e .[dev]
```bash
cd backend
alembic upgrade head
cd ..
```
Bash / macOS / Linux
7. Start the development servers from separate terminals:
python -m venv .venv
source .venv/bin/activate
python -m pip install -e '.[dev]'
- Backend: `uvicorn backend.app.main:app --reload --port 8000`
- Frontend: `cd frontend && npm run dev`
### Notes
8. Open your browser: frontend runs at `http://localhost:5173`, backend API at `http://localhost:8000`.
9. Run quality checks:
- Installing editable extras (`.[dev]`) installs dev/test tools used by the backend (pytest, black, isort, alembic, etc.).
- Backend unit tests: `pytest`
- Backend formatters: `black backend/` and `isort backend/`
- Frontend lint: `cd frontend && npm run lint`
- Frontend type/build check: `cd frontend && npm run build`
### Environment file
10. Build for production:
Copy the sample `.env.example` to `.env` and adjust the database connection strings as needed.
- Frontend bundle: `cd frontend && npm run build`
- Backend container: `docker build -t rail-game-backend backend/`
PowerShell
11. Run containers:
- Backend: `docker run -p 8000:8000 rail-game-backend`
- Frontend: Serve `frontend/dist` with any static file host.
Copy-Item .env.example .env
## Database Migrations
Bash
- Alembic configuration lives in `backend/alembic.ini` with scripts under `backend/migrations/`.
- Generate new revisions with `alembic revision --autogenerate -m "short description"` (ensure models are imported before running autogenerate).
- Apply migrations via `alembic upgrade head`; rollback with `alembic downgrade -1` during development.
cp .env.example .env
## PostgreSQL Configuration
### Important environment variables
- **Database URLs**: The backend reads connection strings from the `.env` file. Set `DATABASE_URL` (development), `TEST_DATABASE_URL` (pytest/CI), and `ALEMBIC_DATABASE_URL` (migration runner). URLs use the SQLAlchemy format, e.g. `postgresql+psycopg://user:password@host:port/database`.
- **Required Extensions**: Migrations enable `postgis` for spatial types and `pgcrypto` for UUID generation. Ensure your Postgres instance has these extensions available.
- **Recommended Databases**: create `railgame_dev` and `railgame_test` (or variants) owned by a dedicated `railgame` role with privileges to create extensions.
- **Connection Debugging**: Toggle `DATABASE_ECHO=true` in `.env` to log SQL statements during development.
- `DATABASE_URL` — runtime DB connection for the app
- `TEST_DATABASE_URL` — database used by pytest in CI/local tests
- `ALEMBIC_DATABASE_URL` — used when running alembic outside the app process
## API Preview
### Database (Postgres + PostGIS)
- `GET /api/health` Lightweight readiness probe.
- `POST /api/auth/register` Creates an in-memory demo account and returns a JWT access token.
- `POST /api/auth/login` Exchanges credentials for a JWT access token (demo user: `demo` / `railgame123`).
- `GET /api/auth/me` Returns the current authenticated user profile.
- `GET /api/network` Returns a sample snapshot of stations, tracks, and trains (camelCase fields) generated from shared domain models; requires a valid bearer token.
If you run Postgres locally, create the dev/test DBs and ensure the `postgis` extension is available. Example (psql):
## Developer Tooling
-- create DBs (run in psql as a superuser or role with create privileges)
CREATE DATABASE railgame_dev;
CREATE DATABASE railgame_test;
- Install backend tooling in editable mode: `python -m pip install -e .[dev]`.
- Configure git hooks (Git for Windows works with these scripts): `pwsh scripts/setup_hooks.ps1`.
- Pre-commit hooks run `black`, `isort`, `pytest backend/tests`, and `npm run lint` if frontend dependencies are installed.
- Run the checks manually any time with `python scripts/precommit.py`.
- Frontend lint/format commands live in `frontend/package.json` (`npm run lint`, `npm run format`).
- Continuous integration runs via workflows in `.github/workflows/` covering backend lint/tests and frontend lint/build.
-- connect to the db and enable extensions
\c railgame_dev
CREATE EXTENSION IF NOT EXISTS postgis;
CREATE EXTENSION IF NOT EXISTS pgcrypto;
Adjust DB names and roles to match your `.env` values.
### Quick database setup (recommended)
For a streamlined setup, use the included initialization script after configuring your `.env` file:
PowerShell / Bash
python scripts/init_demo_db.py
This script validates your environment, runs migrations, and loads demo OSM data. Use `--dry-run` to preview changes, or `--region` to load specific regions.
**Note**: The script uses `python-dotenv` to load your `.env` file. If not installed, run `pip install python-dotenv`.
### Run migrations
PowerShell / Bash
cd backend
alembic upgrade head
cd ..
If you prefer to run alembic with a specific URL without editing `.env`, set `ALEMBIC_DATABASE_URL` in the environment before running the command.
### Load OSM fixtures (optional)
Use the included scripts to refresh stations and tracks from saved OSM fixtures. This step assumes the database is migrated and reachable.
PowerShell / Bash
# dry-run
python -m backend.scripts.osm_refresh --region all --no-commit
# commit to DB
python -m backend.scripts.osm_refresh --region all
See `backend/scripts/*.py` for more granular import options (`--skip-*` flags).
### Frontend
Install dependencies and run the dev server from the `frontend/` directory.
PowerShell / Bash
cd frontend
npm install
npm run dev
The frontend runs at `http://localhost:5173` by default (Vite). The React app talks to the backend API at the address configured in its environment (see `frontend` README or `vite` config).
### Run backend locally (development)
PowerShell / Bash
# from project root
uvicorn backend.app.main:app --reload --port 8000
The backend API listens at `http://localhost:8000` by default.
### Tests & linters
Backend
pytest
black backend/ && isort backend/
Frontend
cd frontend
npm run lint
npm run build # type/build check
### Docker / Compose (optional)
Build and run both services with Docker Compose if you prefer containers:
PowerShell / Bash
docker compose up --build
This starts all services (Postgres, Redis, backend, frontend) and automatically initializes the database with demo data on first run. The backend waits for the database to be ready before running migrations and loading OSM fixtures.
**Services:**
- Backend API: `http://localhost:8000`
- Frontend: `http://localhost:8080`
- Postgres: `localhost:5432`
- Redis: `localhost:6379`
This expects a working Docker environment and may require you to set DB URLs to point to the containerized Postgres service if one is defined in `docker-compose.yml`.
## Troubleshooting
- If migrations fail with missing PostGIS functions, ensure `postgis` is installed and enabled in the target database.
- If alembic autogenerate creates unexpected changes, confirm the models being imported match the app import path used by `alembic` (see `backend/migrations/env.py`).
- For authentication/debugging, the demo user is `demo` / `railgame123` (used by some integration tests and the demo auth flow).
- If frontend dev server fails due to node version, check `frontend/package.json` engines or use `nvm`/`nvm-windows` to match the recommended Node version.
## API preview
Some useful endpoints for local testing:
- `GET /api/health` — readiness probe
- `POST /api/auth/register` — demo account creation + JWT
- `POST /api/auth/login` — exchange credentials for JWT (demo user: `demo` / `railgame123`)
- `GET /api/auth/me` — current user profile (requires bearer token)
- `GET /api/network` — sample network snapshot (requires bearer token)
## Contributing
- See `docs/` for architecture and ADRs.
- Keep tests green and follow formatting rules (black, isort for Python; Prettier/ESLint for frontend).
- Open issues or PRs for bugs, features, or docs improvements.

80
TODO.md
View File

@@ -1,80 +0,0 @@
# Development TODO Plan
## Phase 1 Project Foundations
- [x] Initialize Git hooks, linting, and formatting tooling (ESLint, Prettier, isort, black).
- [x] Configure `pyproject.toml` or equivalent for backend dependency management.
- [x] Scaffold FastAPI application entrypoint with health-check endpoint.
- [x] Bootstrap React app with Vite/CRA, including routing skeleton and global state provider.
- [x] Define shared TypeScript/Python models for core domain entities (tracks, stations, trains).
- [x] Set up CI workflow for linting and test automation (GitHub Actions).
## Phase 2 Core Features
- [x] Implement authentication flow (backend JWT, frontend login/register forms).
- [x] Build map visualization integrating Leaflet with OSM tiles.
- [x] Define geographic bounding boxes and filtering rules for importing real-world stations from OpenStreetMap.
- [x] Implement an import script/CLI that pulls OSM station data and normalizes it to the PostGIS schema.
- [x] Expose backend CRUD endpoints for stations (create, update, archive) with validation and geometry handling.
- [ ] Build React map tooling for selecting a station.
- [ ] Build tools for station editing, including form validation.
- [ ] Define track selection criteria and tagging rules for harvesting OSM rail segments within target regions.
- [ ] Extend the importer to load track geometries and associate them with existing stations.
- [ ] Implement backend track-management APIs with length/speed validation and topology checks.
- [ ] Create a frontend track-drawing workflow (polyline editor, snapping to stations, undo/redo).
- [ ] Design train connection manager requirements (link trains to operating tracks, manage consist data).
- [ ] Implement backend services and APIs to attach trains to routes and update assignments.
- [ ] Add UI flows for managing train connections, including visual feedback on the map.
- [ ] Establish train scheduling service with validation rules, conflict detection, and persistence APIs.
- [ ] Provide frontend scheduling tools (timeline or table view) for creating and editing train timetables.
- [ ] Develop frontend dashboards for resources, schedules, and achievements.
- [ ] Add real-time simulation updates (WebSocket layer, frontend subscription hooks).
## Phase 3 Data & Persistence
- [x] Design PostgreSQL/PostGIS schema and migrations (Alembic or similar).
- [x] Implement data access layer with SQLAlchemy and repository abstractions.
- [ ] Decide on canonical fixture scope (demo geography, sample trains) and document expected dataset size.
- [ ] Author fixture generation scripts that export JSON/GeoJSON compatible with the repository layer.
- [x] Create ingestion utilities to load fixtures into local and CI databases.
- [ ] Provision a Redis instance/container for local development.
- [ ] Add caching abstractions in backend services (e.g., network snapshot, map layers).
- [ ] Implement cache invalidation hooks tied to repository mutations.
## Phase 4 Testing & Quality
- [x] Write unit tests for backend services and models.
- [ ] Configure Jest/RTL testing utilities and shared mocks for Leaflet and network APIs.
- [ ] Write component tests for map controls, station builder UI, and dashboards.
- [ ] Add integration tests for custom hooks (network snapshot, scheduling forms).
- [x] Stand up Playwright/Cypress project structure with authentication helpers.
- [x] Script login end-to-end flow (Playwright).
- [ ] Script station creation end-to-end flow.
- [ ] Script track placement end-to-end flow.
- [ ] Script scheduling end-to-end flow.
- [ ] Define load/performance targets (requests per second, simulation latency) and tooling.
- [ ] Implement performance test harness covering scheduling and real-time updates.
## Phase 5 Deployment & Ops
- [x] Create Dockerfile for frontend.
- [x] Create Dockerfile for backend.
- [x] Create docker-compose for local development with Postgres/Redis dependencies.
- [ ] Add task runner commands to orchestrate container workflows.
- [ ] Set up CI/CD pipeline for automated builds, tests, and container publishing.
- [ ] Provision infrastructure scripts (Terraform/Ansible) targeting initial cloud environment.
- [ ] Define environment configuration strategy (secrets management, config maps).
- [ ] Configure observability stack (logging, metrics, tracing).
- [ ] Integrate tracing/logging exporters into backend services.
- [ ] Document deployment pipeline and release process.
## Phase 6 Polish & Expansion
- [ ] Add leaderboards and achievements logic with UI integration.
- [ ] Design data model changes required for achievements and ranking.
- [ ] Implement accessibility audit fixes (WCAG compliance).
- [ ] Conduct accessibility audit (contrast, keyboard navigation, screen reader paths).
- [ ] Optimize asset loading and introduce lazy loading strategies.
- [ ] Establish performance budgets for bundle size and render times.
- [ ] Evaluate multiplayer/coop roadmap and spike POCs where feasible.
- [ ] Prototype networking approach (WebRTC/WebSocket) for cooperative sessions.

View File

@@ -8,15 +8,26 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
WORKDIR /app
RUN apt-get update \
&& apt-get install -y --no-install-recommends build-essential libpq-dev \
&& apt-get install -y --no-install-recommends build-essential libpq-dev postgresql-client \
&& rm -rf /var/lib/apt/lists/*
COPY backend/requirements/base.txt ./backend/requirements/base.txt
RUN pip install --upgrade pip \
&& pip install -r backend/requirements/base.txt
COPY scripts ./scripts
COPY .env.example ./.env.example
COPY .env* ./
COPY backend ./backend
EXPOSE 8000
CMD ["uvicorn", "backend.app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# Initialize database with demo data if INIT_DEMO_DB is set
CMD ["sh", "-c", "\
export PYTHONPATH=/app/backend && \
echo 'Waiting for database...' && \
while ! pg_isready -h db -p 5432 -U railgame >/dev/null 2>&1; do sleep 1; done && \
echo 'Database is ready!' && \
if [ \"$INIT_DEMO_DB\" = \"true\" ]; then python scripts/init_demo_db.py; fi && \
uvicorn backend.app.main:app --host 0.0.0.0 --port 8000"]

View File

@@ -1,6 +1,6 @@
[alembic]
script_location = migrations
sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame
sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame_dev
[loggers]
keys = root,sqlalchemy,alembic

View File

@@ -4,9 +4,11 @@ from backend.app.api.auth import router as auth_router
from backend.app.api.health import router as health_router
from backend.app.api.network import router as network_router
from backend.app.api.stations import router as stations_router
from backend.app.api.tracks import router as tracks_router
router = APIRouter()
router.include_router(health_router, tags=["health"])
router.include_router(auth_router)
router.include_router(network_router)
router.include_router(stations_router)
router.include_router(tracks_router)

153
backend/app/api/tracks.py Normal file
View File

@@ -0,0 +1,153 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from backend.app.api.deps import get_current_user, get_db
from backend.app.models import (
CombinedTrackModel,
TrackCreate,
TrackUpdate,
TrackModel,
UserPublic,
)
from backend.app.services.combined_tracks import (
create_combined_track,
get_combined_track,
list_combined_tracks,
)
from backend.app.services.tracks import (
create_track,
delete_track,
regenerate_combined_tracks,
update_track,
get_track,
list_tracks,
)
router = APIRouter(prefix="/tracks", tags=["tracks"])
@router.get("", response_model=list[TrackModel])
def read_combined_tracks(
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> list[TrackModel]:
"""Return all base tracks."""
return list_tracks(db)
@router.get("/combined", response_model=list[CombinedTrackModel])
def read_combined_tracks_combined(
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> list[CombinedTrackModel]:
return list_combined_tracks(db)
@router.get("/{track_id}", response_model=TrackModel)
def read_track(
track_id: str,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> TrackModel:
track = get_track(db, track_id)
if track is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Track {track_id} not found",
)
return track
@router.get("/combined/{combined_track_id}", response_model=CombinedTrackModel)
def read_combined_track(
combined_track_id: str,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> CombinedTrackModel:
combined_track = get_combined_track(db, combined_track_id)
if combined_track is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Combined track {combined_track_id} not found",
)
return combined_track
@router.post("", response_model=TrackModel, status_code=status.HTTP_201_CREATED)
def create_track_endpoint(
payload: TrackCreate,
regenerate: bool = False,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> TrackModel:
try:
track = create_track(db, payload)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc
if regenerate:
regenerate_combined_tracks(
db, [track.start_station_id, track.end_station_id])
return track
@router.post(
"/combined",
response_model=CombinedTrackModel,
status_code=status.HTTP_201_CREATED,
summary="Create a combined track between two stations using pathfinding",
)
def create_combined_track_endpoint(
start_station_id: str,
end_station_id: str,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> CombinedTrackModel:
combined_track = create_combined_track(
db, start_station_id, end_station_id)
if combined_track is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Could not create combined track: no path exists between stations or track already exists",
)
return combined_track
@router.put("/{track_id}", response_model=TrackModel)
def update_track_endpoint(
track_id: str,
payload: TrackUpdate,
regenerate: bool = False,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> TrackModel:
track = update_track(db, track_id, payload)
if track is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Track {track_id} not found",
)
if regenerate:
regenerate_combined_tracks(
db, [track.start_station_id, track.end_station_id])
return track
@router.delete("/{track_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_track_endpoint(
track_id: str,
regenerate: bool = False,
_: UserPublic = Depends(get_current_user),
db: Session = Depends(get_db),
) -> None:
deleted = delete_track(db, track_id, regenerate=regenerate)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Track {track_id} not found",
)

View File

@@ -75,6 +75,43 @@ STATION_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
}
# Tags that describe rail infrastructure usable for train routing.
TRACK_ALLOWED_RAILWAY_TYPES: Tuple[str, ...] = (
"rail",
"light_rail",
"subway",
"tram",
"narrow_gauge",
"disused",
"construction",
)
TRACK_TAG_FILTERS: Mapping[str, Tuple[str, ...]] = {
"railway": TRACK_ALLOWED_RAILWAY_TYPES,
}
# Track ingestion policy
TRACK_EXCLUDED_SERVICE_TAGS: Tuple[str, ...] = (
"yard",
"siding",
"spur",
"crossover",
"industrial",
"military",
)
TRACK_EXCLUDED_USAGE_TAGS: Tuple[str, ...] = (
"military",
"tourism",
)
TRACK_MIN_LENGTH_METERS: float = 75.0
TRACK_STATION_SNAP_RADIUS_METERS: float = 350.0
def compile_overpass_filters(filters: Mapping[str, Iterable[str]]) -> str:
"""Build an Overpass boolean expression that matches the provided filters."""
@@ -89,5 +126,11 @@ __all__ = [
"BoundingBox",
"DEFAULT_REGIONS",
"STATION_TAG_FILTERS",
"TRACK_ALLOWED_RAILWAY_TYPES",
"TRACK_TAG_FILTERS",
"TRACK_EXCLUDED_SERVICE_TAGS",
"TRACK_EXCLUDED_USAGE_TAGS",
"TRACK_MIN_LENGTH_METERS",
"TRACK_STATION_SNAP_RADIUS_METERS",
"compile_overpass_filters",
]

View File

@@ -41,11 +41,14 @@ class User(Base, TimestampMixin):
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
username: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False)
email: Mapped[str | None] = mapped_column(
String(255), unique=True, nullable=True)
full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
role: Mapped[str] = mapped_column(String(32), nullable=False, default="player")
role: Mapped[str] = mapped_column(
String(32), nullable=False, default="player")
preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -62,12 +65,50 @@ class Station(Base, TimestampMixin):
Geometry(geometry_type="POINT", srid=4326), nullable=False
)
elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
is_active: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True)
class Track(Base, TimestampMixin):
__tablename__ = "tracks"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
name: Mapped[str | None] = mapped_column(String(128), nullable=True)
start_station_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("stations.id", ondelete="RESTRICT"),
nullable=False,
)
end_station_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("stations.id", ondelete="RESTRICT"),
nullable=False,
)
length_meters: Mapped[float | None] = mapped_column(
Numeric(10, 2), nullable=True)
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
is_bidirectional: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
status: Mapped[str] = mapped_column(
String(32), nullable=False, default="planned")
track_geometry: Mapped[str] = mapped_column(
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
)
__table_args__ = (
UniqueConstraint(
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
),
)
class CombinedTrack(Base, TimestampMixin):
__tablename__ = "combined_tracks"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
@@ -82,19 +123,25 @@ class Track(Base, TimestampMixin):
ForeignKey("stations.id", ondelete="RESTRICT"),
nullable=False,
)
length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
length_meters: Mapped[float | None] = mapped_column(
Numeric(10, 2), nullable=True)
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
is_bidirectional: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
track_geometry: Mapped[str] = mapped_column(
status: Mapped[str] = mapped_column(
String(32), nullable=False, default="planned")
combined_geometry: Mapped[str] = mapped_column(
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
)
# JSON array of constituent track IDs
constituent_track_ids: Mapped[str] = mapped_column(
Text, nullable=False
)
__table_args__ = (
UniqueConstraint(
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
),
)
@@ -105,7 +152,8 @@ class Train(Base, TimestampMixin):
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
designation: Mapped[str] = mapped_column(
String(64), nullable=False, unique=True)
operator_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
)

View File

@@ -8,11 +8,14 @@ from .auth import (
UserPublic,
)
from .base import (
CombinedTrackCreate,
CombinedTrackModel,
StationCreate,
StationModel,
StationUpdate,
TrackCreate,
TrackModel,
TrackUpdate,
TrainCreate,
TrainModel,
TrainScheduleCreate,
@@ -33,9 +36,12 @@ __all__ = [
"StationUpdate",
"TrackCreate",
"TrackModel",
"TrackUpdate",
"TrainScheduleCreate",
"TrainCreate",
"TrainModel",
"UserCreate",
"to_camel",
"CombinedTrackCreate",
"CombinedTrackModel",
]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime
from typing import Generic, Sequence, TypeVar
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
def to_camel(string: str) -> str:
@@ -51,8 +51,22 @@ class StationModel(IdentifiedModel[str]):
class TrackModel(IdentifiedModel[str]):
start_station_id: str
end_station_id: str
length_meters: float
max_speed_kph: float
length_meters: float | None = None
max_speed_kph: float | None = None
status: str | None = None
is_bidirectional: bool = True
coordinates: list[tuple[float, float]] = Field(default_factory=list)
class CombinedTrackModel(IdentifiedModel[str]):
start_station_id: str
end_station_id: str
length_meters: float | None = None
max_speed_kph: int | None = None
status: str | None = None
is_bidirectional: bool = True
coordinates: list[tuple[float, float]] = Field(default_factory=list)
constituent_track_ids: list[str] = Field(default_factory=list)
class TrainModel(IdentifiedModel[str]):
@@ -86,6 +100,31 @@ class TrackCreate(CamelModel):
start_station_id: str
end_station_id: str
coordinates: Sequence[tuple[float, float]]
osm_id: str | None = None
name: str | None = None
length_meters: float | None = None
max_speed_kph: int | None = None
is_bidirectional: bool = True
status: str = "planned"
class TrackUpdate(CamelModel):
start_station_id: str | None = None
end_station_id: str | None = None
coordinates: Sequence[tuple[float, float]] | None = None
osm_id: str | None = None
name: str | None = None
length_meters: float | None = None
max_speed_kph: int | None = None
is_bidirectional: bool | None = None
status: str | None = None
class CombinedTrackCreate(CamelModel):
start_station_id: str
end_station_id: str
coordinates: Sequence[tuple[float, float]]
constituent_track_ids: list[str]
name: str | None = None
length_meters: float | None = None
max_speed_kph: int | None = None

View File

@@ -2,6 +2,7 @@
from backend.app.repositories.stations import StationRepository
from backend.app.repositories.tracks import TrackRepository
from backend.app.repositories.combined_tracks import CombinedTrackRepository
from backend.app.repositories.train_schedules import TrainScheduleRepository
from backend.app.repositories.trains import TrainRepository
from backend.app.repositories.users import UserRepository
@@ -10,6 +11,7 @@ __all__ = [
"StationRepository",
"TrainScheduleRepository",
"TrackRepository",
"CombinedTrackRepository",
"TrainRepository",
"UserRepository",
]

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import json
from uuid import UUID
import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session
from backend.app.db.models import CombinedTrack
from backend.app.models import CombinedTrackCreate
from backend.app.repositories.base import BaseRepository
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
model = CombinedTrack
def __init__(self, session: Session) -> None:
super().__init__(session)
def list_all(self) -> list[CombinedTrack]:
statement = sa.select(self.model)
return list(self.session.scalars(statement))
def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
"""Check if a combined track already exists between two stations."""
statement = sa.select(sa.exists().where(
sa.and_(
self.model.start_station_id == start_station_id,
self.model.end_station_id == end_station_id
)
))
return bool(self.session.scalar(statement))
def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
"""Extract constituent track IDs from a combined track."""
try:
return json.loads(combined_track.constituent_track_ids)
except (json.JSONDecodeError, TypeError):
return []
@staticmethod
def _ensure_uuid(value: UUID | str) -> UUID:
if isinstance(value, UUID):
return value
return UUID(str(value))
@staticmethod
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
if len(coordinates) < 2:
raise ValueError(
"Combined track geometry requires at least two coordinate pairs")
parts = [f"{lon} {lat}" for lat, lon in coordinates]
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
def create(self, data: CombinedTrackCreate) -> CombinedTrack:
coordinates = list(data.coordinates)
geometry = self._line_string(coordinates)
constituent_track_ids_json = json.dumps(data.constituent_track_ids)
combined_track = CombinedTrack(
name=data.name,
start_station_id=self._ensure_uuid(data.start_station_id),
end_station_id=self._ensure_uuid(data.end_station_id),
length_meters=data.length_meters,
max_speed_kph=data.max_speed_kph,
is_bidirectional=data.is_bidirectional,
status=data.status,
combined_geometry=geometry,
constituent_track_ids=constituent_track_ids_json,
)
self.session.add(combined_track)
return combined_track

View File

@@ -7,7 +7,7 @@ from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session
from backend.app.db.models import Track
from backend.app.models import TrackCreate
from backend.app.models import TrackCreate, TrackUpdate
from backend.app.repositories.base import BaseRepository
@@ -21,6 +21,102 @@ class TrackRepository(BaseRepository[Track]):
statement = sa.select(self.model)
return list(self.session.scalars(statement))
def exists_by_osm_id(self, osm_id: str) -> bool:
statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
return bool(self.session.scalar(statement))
def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None:
"""Find the shortest path between two stations using existing tracks.
Returns a list of tracks that form the path, or None if no path exists.
"""
# Build adjacency list: station -> list of (neighbor_station, track)
adjacency = self._build_track_graph()
if start_station_id not in adjacency or end_station_id not in adjacency:
return None
# BFS to find shortest path
from collections import deque
# (current_station, path_so_far)
queue = deque([(start_station_id, [])])
visited = set([start_station_id])
while queue:
current_station, path = queue.popleft()
if current_station == end_station_id:
return path
for neighbor, track in adjacency[current_station]:
if neighbor not in visited:
visited.add(neighbor)
queue.append((neighbor, path + [track]))
return None # No path found
def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
"""Build a graph representation of tracks: station -> [(neighbor_station, track), ...]"""
tracks = self.list_all()
graph = {}
for track in tracks:
start_id = str(track.start_station_id)
end_id = str(track.end_station_id)
# Add bidirectional edges (assuming tracks are bidirectional)
if start_id not in graph:
graph[start_id] = []
if end_id not in graph:
graph[end_id] = []
graph[start_id].append((end_id, track))
graph[end_id].append((start_id, track))
return graph
def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
"""Combine the geometries of multiple tracks into a single coordinate sequence.
Assumes tracks are in order and form a continuous path.
"""
if not tracks:
return []
combined_coords = []
for i, track in enumerate(tracks):
# Extract coordinates from track geometry
coords = self._extract_coordinates_from_track(track)
if i == 0:
# First track: add all coordinates
combined_coords.extend(coords)
else:
# Subsequent tracks: skip the first coordinate (shared with previous track)
combined_coords.extend(coords[1:])
return combined_coords
def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
"""Extract coordinate list from a track's geometry."""
# Convert WKT string to WKTElement, then to shapely geometry
from geoalchemy2.elements import WKTElement
from geoalchemy2.shape import to_shape
try:
wkt_element = WKTElement(track.track_geometry)
geom = to_shape(wkt_element)
if hasattr(geom, 'coords'):
# For LineString, coords returns [(x, y), ...] where x=lon, y=lat
# Convert to (lat, lon)
return [(coord[1], coord[0]) for coord in geom.coords]
except Exception:
pass
return []
@staticmethod
def _ensure_uuid(value: UUID | str) -> UUID:
if isinstance(value, UUID):
@@ -30,7 +126,8 @@ class TrackRepository(BaseRepository[Track]):
@staticmethod
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
if len(coordinates) < 2:
raise ValueError("Track geometry requires at least two coordinate pairs")
raise ValueError(
"Track geometry requires at least two coordinate pairs")
parts = [f"{lon} {lat}" for lat, lon in coordinates]
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
@@ -38,6 +135,7 @@ class TrackRepository(BaseRepository[Track]):
coordinates = list(data.coordinates)
geometry = self._line_string(coordinates)
track = Track(
osm_id=data.osm_id,
name=data.name,
start_station_id=self._ensure_uuid(data.start_station_id),
end_station_id=self._ensure_uuid(data.end_station_id),
@@ -49,3 +147,26 @@ class TrackRepository(BaseRepository[Track]):
)
self.session.add(track)
return track
def update(self, track: Track, payload: TrackUpdate) -> Track:
if payload.start_station_id is not None:
track.start_station_id = self._ensure_uuid(
payload.start_station_id)
if payload.end_station_id is not None:
track.end_station_id = self._ensure_uuid(payload.end_station_id)
if payload.coordinates is not None:
track.track_geometry = self._line_string(
list(payload.coordinates)) # type: ignore[assignment]
if payload.osm_id is not None:
track.osm_id = payload.osm_id
if payload.name is not None:
track.name = payload.name
if payload.length_meters is not None:
track.length_meters = payload.length_meters
if payload.max_speed_kph is not None:
track.max_speed_kph = payload.max_speed_kph
if payload.is_bidirectional is not None:
track.is_bidirectional = payload.is_bidirectional
if payload.status is not None:
track.status = payload.status
return track

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
"""Application services for combined track operations."""
from sqlalchemy.orm import Session
from backend.app.models import CombinedTrackCreate, CombinedTrackModel
from backend.app.repositories import CombinedTrackRepository, TrackRepository
def create_combined_track(
session: Session, start_station_id: str, end_station_id: str
) -> CombinedTrackModel | None:
"""Create a combined track between two stations using pathfinding.
Returns the created combined track, or None if no path exists or
a combined track already exists between these stations.
"""
combined_track_repo = CombinedTrackRepository(session)
track_repo = TrackRepository(session)
# Check if combined track already exists
if combined_track_repo.exists_between_stations(start_station_id, end_station_id):
return None
# Find path between stations
path_tracks = track_repo.find_path_between_stations(
start_station_id, end_station_id)
if not path_tracks:
return None
# Combine geometries
combined_coords = track_repo.combine_track_geometries(path_tracks)
if len(combined_coords) < 2:
return None
# Calculate total length
total_length = sum(track.length_meters or 0 for track in path_tracks)
# Get max speed (use the minimum speed of all tracks)
max_speeds = [
track.max_speed_kph for track in path_tracks if track.max_speed_kph]
max_speed = min(max_speeds) if max_speeds else None
# Get constituent track IDs
constituent_track_ids = [str(track.id) for track in path_tracks]
# Create combined track
create_data = CombinedTrackCreate(
start_station_id=start_station_id,
end_station_id=end_station_id,
coordinates=combined_coords,
constituent_track_ids=constituent_track_ids,
length_meters=total_length if total_length > 0 else None,
max_speed_kph=max_speed,
status="operational",
)
combined_track = combined_track_repo.create(create_data)
session.commit()
return CombinedTrackModel.model_validate(combined_track)
def get_combined_track(session: Session, combined_track_id: str) -> CombinedTrackModel | None:
"""Get a combined track by ID."""
try:
combined_track_repo = CombinedTrackRepository(session)
combined_track = combined_track_repo.get(combined_track_id)
return CombinedTrackModel.model_validate(combined_track)
except LookupError:
return None
def list_combined_tracks(session: Session) -> list[CombinedTrackModel]:
"""List all combined tracks."""
combined_track_repo = CombinedTrackRepository(session)
combined_tracks = combined_track_repo.list_all()
return [CombinedTrackModel.model_validate(ct) for ct in combined_tracks]

View File

@@ -8,9 +8,10 @@ from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
try: # pragma: no cover - optional dependency guard
from shapely.geometry import Point # type: ignore
from shapely.geometry import LineString, Point # type: ignore
except ImportError: # pragma: no cover - allow running without shapely at import time
Point = None # type: ignore[assignment]
LineString = None # type: ignore[assignment]
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
@@ -51,6 +52,12 @@ def _fallback_snapshot() -> dict[str, list[dict[str, object]]]:
end_station_id="station-2",
length_meters=289000.0,
max_speed_kph=230.0,
status="operational",
is_bidirectional=True,
coordinates=[
(stations[0].latitude, stations[0].longitude),
(stations[1].latitude, stations[1].longitude),
],
created_at=now,
updated_at=now,
)
@@ -134,6 +141,24 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]
track_models: list[TrackModel] = []
for track in tracks_entities:
coordinates: list[tuple[float, float]] = []
geometry = track.track_geometry
shape = (
to_shape(cast(WKBElement | WKTElement, geometry))
if geometry is not None and LineString is not None
else None
)
if (
LineString is not None
and shape is not None
and isinstance(shape, LineString)
):
coords_list: list[tuple[float, float]] = []
for coord in shape.coords:
lon = float(coord[0])
lat = float(coord[1])
coords_list.append((lat, lon))
coordinates = coords_list
track_models.append(
TrackModel(
id=str(track.id),
@@ -141,6 +166,9 @@ def get_network_snapshot(session: Session) -> dict[str, list[dict[str, object]]]
end_station_id=str(track.end_station_id),
length_meters=_to_float(track.length_meters),
max_speed_kph=_to_float(track.max_speed_kph),
status=track.status,
is_bidirectional=track.is_bidirectional,
coordinates=coordinates,
created_at=cast(datetime, track.created_at),
updated_at=cast(datetime, track.updated_at),
)

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
"""Service layer for primary track management operations."""
from typing import Iterable
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
from backend.app.repositories import CombinedTrackRepository, TrackRepository
def list_tracks(session: Session) -> list[TrackModel]:
repo = TrackRepository(session)
tracks = repo.list_all()
return [TrackModel.model_validate(track) for track in tracks]
def get_track(session: Session, track_id: str) -> TrackModel | None:
repo = TrackRepository(session)
track = repo.get(track_id)
if track is None:
return None
return TrackModel.model_validate(track)
def create_track(session: Session, payload: TrackCreate) -> TrackModel:
repo = TrackRepository(session)
try:
track = repo.create(payload)
session.commit()
except IntegrityError as exc:
session.rollback()
raise ValueError(
"Track with the same station pair already exists") from exc
return TrackModel.model_validate(track)
def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
repo = TrackRepository(session)
track = repo.get(track_id)
if track is None:
return None
repo.update(track, payload)
session.commit()
return TrackModel.model_validate(track)
def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
repo = TrackRepository(session)
track = repo.get(track_id)
if track is None:
return False
start_station_id = str(track.start_station_id)
end_station_id = str(track.end_station_id)
session.delete(track)
session.commit()
if regenerate:
regenerate_combined_tracks(session, [start_station_id, end_station_id])
return True
def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
combined_repo = CombinedTrackRepository(session)
station_id_set = set(station_ids)
if not station_id_set:
return []
# Remove combined tracks touching these stations
for combined in combined_repo.list_all():
if {str(combined.start_station_id), str(combined.end_station_id)} & station_id_set:
session.delete(combined)
session.commit()
# Rebuild combined tracks between affected station pairs
from backend.app.services.combined_tracks import create_combined_track
regenerated: list[CombinedTrackModel] = []
station_list = list(station_id_set)
for i in range(len(station_list)):
for j in range(i + 1, len(station_list)):
result = create_combined_track(
session, station_list[i], station_list[j])
if result is not None:
regenerated.append(result)
return regenerated
__all__ = [
"list_tracks",
"get_track",
"create_track",
"update_track",
"delete_track",
"regenerate_combined_tracks",
]

View File

@@ -0,0 +1,19 @@
"""Template for new Alembic migration scripts."""
from __future__ import annotations
import sqlalchemy as sa
from alembic import op
revision = '63d02d67b39e'
down_revision = '20251011_01'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('tracks', sa.Column(
'osm_id', sa.String(length=32), nullable=True))
def downgrade() -> None:
op.drop_column('tracks', 'osm_id')

View File

@@ -0,0 +1,75 @@
"""Template for new Alembic migration scripts."""
from __future__ import annotations
from sqlalchemy.dialects import postgresql
from geoalchemy2.types import Geometry
import sqlalchemy as sa
from alembic import op
revision = 'e7d4bb03da04'
down_revision = '63d02d67b39e'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"combined_tracks",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("name", sa.String(length=128), nullable=True),
sa.Column("start_station_id", postgresql.UUID(
as_uuid=True), nullable=False),
sa.Column("end_station_id", postgresql.UUID(
as_uuid=True), nullable=False),
sa.Column("length_meters", sa.Numeric(10, 2), nullable=True),
sa.Column("max_speed_kph", sa.Integer(), nullable=True),
sa.Column(
"is_bidirectional",
sa.Boolean(),
nullable=False,
server_default=sa.text("true"),
),
sa.Column(
"status", sa.String(length=32), nullable=False, server_default="planned"
),
sa.Column(
"combined_geometry",
Geometry(geometry_type="LINESTRING", srid=4326),
nullable=False,
),
sa.Column("constituent_track_ids", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.ForeignKeyConstraint(
["start_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.ForeignKeyConstraint(
["end_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.UniqueConstraint(
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
),
)
op.create_index(
"ix_combined_tracks_geometry", "combined_tracks", ["combined_geometry"], postgresql_using="gist"
)
def downgrade() -> None:
op.drop_index("ix_combined_tracks_geometry", table_name="combined_tracks")
op.drop_table("combined_tracks")

View File

@@ -0,0 +1,196 @@
from __future__ import annotations
"""Orchestrate the OSM station/track import and load pipeline."""
import argparse
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Sequence
from backend.app.core.osm_config import DEFAULT_REGIONS
from backend.scripts import stations_import, stations_load, tracks_import, tracks_load
@dataclass(slots=True)
class Stage:
label: str
runner: Callable[[list[str] | None], int]
args: list[str]
input_path: Path | None = None
output_path: Path | None = None
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Run the station and track import/load workflow in sequence.",
)
parser.add_argument(
"--region",
choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
default="all",
help="Region selector forwarded to the import scripts (default: all).",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data"),
help="Directory where intermediate JSON payloads are stored (default: data/).",
)
parser.add_argument(
"--stations-json",
type=Path,
help="Existing station JSON file to load; defaults to <output-dir>/osm_stations.json.",
)
parser.add_argument(
"--tracks-json",
type=Path,
help="Existing track JSON file to load; defaults to <output-dir>/osm_tracks.json.",
)
parser.add_argument(
"--skip-station-import",
action="store_true",
help="Skip the station import step (expects --stations-json to point to data).",
)
parser.add_argument(
"--skip-station-load",
action="store_true",
help="Skip loading stations into PostGIS.",
)
parser.add_argument(
"--skip-track-import",
action="store_true",
help="Skip the track import step (expects --tracks-json to point to data).",
)
parser.add_argument(
"--skip-track-load",
action="store_true",
help="Skip loading tracks into PostGIS.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print the planned stages without invoking Overpass or mutating the database.",
)
parser.add_argument(
"--commit",
dest="commit",
action="store_true",
default=True,
help="Commit database changes produced by the load steps (default).",
)
parser.add_argument(
"--no-commit",
dest="commit",
action="store_false",
help="Rollback database changes after load steps (dry run).",
)
return parser
def _build_stage_plan(args: argparse.Namespace) -> list[Stage]:
station_json = args.stations_json or args.output_dir / "osm_stations.json"
track_json = args.tracks_json or args.output_dir / "osm_tracks.json"
stages: list[Stage] = []
if not args.skip_station_import:
stages.append(
Stage(
label="Import stations",
runner=stations_import.main,
args=["--output", str(station_json), "--region", args.region],
output_path=station_json,
)
)
if not args.skip_station_load:
load_args = [str(station_json)]
if not args.commit:
load_args.append("--no-commit")
stages.append(
Stage(
label="Load stations",
runner=stations_load.main,
args=load_args,
input_path=station_json,
)
)
if not args.skip_track_import:
stages.append(
Stage(
label="Import tracks",
runner=tracks_import.main,
args=["--output", str(track_json), "--region", args.region],
output_path=track_json,
)
)
if not args.skip_track_load:
load_args = [str(track_json)]
if not args.commit:
load_args.append("--no-commit")
stages.append(
Stage(
label="Load tracks",
runner=tracks_load.main,
args=load_args,
input_path=track_json,
)
)
return stages
def _describe_plan(stages: Sequence[Stage]) -> None:
if not stages:
print("No stages selected; nothing to do.")
return
print("Selected stages:")
for stage in stages:
detail = " ".join(stage.args) if stage.args else "<no args>"
print(f" - {stage.label}: {detail}")
def _execute_stage(stage: Stage) -> None:
print(f"\n>>> {stage.label}")
if stage.output_path is not None:
stage.output_path.parent.mkdir(parents=True, exist_ok=True)
if stage.input_path is not None and not stage.input_path.exists():
raise RuntimeError(
f"Expected input file {stage.input_path} for {stage.label}; run the import step first or provide an existing file."
)
try:
exit_code = stage.runner(stage.args)
except SystemExit as exc: # argparse.error exits via SystemExit
exit_code = int(exc.code or 0)
if exit_code:
raise RuntimeError(f"{stage.label} failed with exit code {exit_code}.")
def main(argv: list[str] | None = None) -> int:
parser = build_argument_parser()
args = parser.parse_args(argv)
stages = _build_stage_plan(args)
if args.dry_run:
print("Dry run: the following stages would run in order.")
_describe_plan(stages)
return 0
for stage in stages:
_execute_stage(stage)
print("\nOSM refresh pipeline completed successfully.")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -98,14 +98,12 @@ def normalize_station_elements(
if not name:
continue
raw_code = tags.get("ref") or tags.get(
"railway:ref") or tags.get("local_ref")
raw_code = tags.get("ref") or tags.get("railway:ref") or tags.get("local_ref")
code = str(raw_code) if raw_code is not None else None
elevation_tag = tags.get("ele") or tags.get("elevation")
try:
elevation = float(
elevation_tag) if elevation_tag is not None else None
elevation = float(elevation_tag) if elevation_tag is not None else None
except (TypeError, ValueError):
elevation = None

View File

@@ -23,10 +23,17 @@ def build_argument_parser() -> argparse.ArgumentParser:
help="Path to the normalized station JSON file produced by stations_import.py",
)
parser.add_argument(
"--commit/--no-commit",
"--commit",
dest="commit",
action="store_true",
default=True,
help="Commit the transaction (default: commit). Use --no-commit for dry runs.",
help="Commit the transaction after loading (default).",
)
parser.add_argument(
"--no-commit",
dest="commit",
action="store_false",
help="Rollback the transaction after loading (useful for dry runs).",
)
return parser

View File

@@ -0,0 +1,262 @@
from __future__ import annotations
"""CLI utility to export rail track geometries from OpenStreetMap."""
import argparse
import json
import math
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Iterable, Mapping
from urllib.parse import quote_plus
from backend.app.core.osm_config import (
DEFAULT_REGIONS,
TRACK_ALLOWED_RAILWAY_TYPES,
TRACK_EXCLUDED_SERVICE_TAGS,
TRACK_EXCLUDED_USAGE_TAGS,
TRACK_MIN_LENGTH_METERS,
TRACK_TAG_FILTERS,
compile_overpass_filters,
)
OVERPASS_ENDPOINT = "https://overpass-api.de/api/interpreter"
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Export OSM rail track ways for ingestion",
)
parser.add_argument(
"--output",
type=Path,
default=Path("data/osm_tracks.json"),
help=(
"Destination file for the exported track geometries "
"(default: data/osm_tracks.json)"
),
)
parser.add_argument(
"--region",
choices=[region.name for region in DEFAULT_REGIONS] + ["all"],
default="all",
help="Region name to export (default: all)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Do not fetch data; print the Overpass payload only",
)
return parser
def build_overpass_query(region_name: str) -> str:
if region_name == "all":
regions = DEFAULT_REGIONS
else:
regions = tuple(
region for region in DEFAULT_REGIONS if region.name == region_name
)
if not regions:
available = ", ".join(region.name for region in DEFAULT_REGIONS)
msg = f"Unknown region {region_name}. Available regions: [{available}]"
raise ValueError(msg)
filters = compile_overpass_filters(TRACK_TAG_FILTERS)
parts = ["[out:json][timeout:120];", "("]
for region in regions:
parts.append(f" way{filters}\n ({region.to_overpass_arg()});")
parts.append(")")
parts.append("; out body geom; >; out skel qt;")
return "\n".join(parts)
def perform_request(query: str) -> dict[str, Any]:
import urllib.request
payload = f"data={quote_plus(query)}".encode("utf-8")
request = urllib.request.Request(
OVERPASS_ENDPOINT,
data=payload,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
with urllib.request.urlopen(request, timeout=180) as response:
payload = response.read()
return json.loads(payload)
def normalize_track_elements(
elements: Iterable[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert Overpass way elements into TrackCreate-compatible payloads."""
tracks: list[dict[str, Any]] = []
for element in elements:
if element.get("type") != "way":
continue
raw_geometry = element.get("geometry") or []
coordinates: list[list[float]] = []
for node in raw_geometry:
lat = node.get("lat")
lon = node.get("lon")
if lat is None or lon is None:
coordinates = []
break
coordinates.append([float(lat), float(lon)])
if len(coordinates) < 2:
continue
tags: dict[str, Any] = element.get("tags", {})
length_meters = _polyline_length(coordinates)
if not _should_include_track(tags, length_meters):
continue
name = tags.get("name")
maxspeed = _parse_maxspeed(tags.get("maxspeed"))
status = _derive_status(tags.get("railway"))
is_bidirectional = not _is_oneway(tags.get("oneway"))
tracks.append(
{
"osmId": str(element.get("id")),
"name": str(name) if name else None,
"lengthMeters": length_meters,
"maxSpeedKph": maxspeed,
"status": status,
"isBidirectional": is_bidirectional,
"coordinates": coordinates,
}
)
return tracks
def _parse_maxspeed(value: Any) -> float | None:
if value is None:
return None
# Overpass may return values such as "80" or "80 km/h" or "signals".
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
number = ""
for char in text:
if char.isdigit() or char == ".":
number += char
elif number:
break
try:
return float(number) if number else None
except ValueError:
return None
def _derive_status(value: Any) -> str:
tag = str(value or "").lower()
if tag in {"abandoned", "disused"}:
return tag
if tag in {"construction", "proposed"}:
return "construction"
return "operational"
def _should_include_track(tags: Mapping[str, Any], length_meters: float) -> bool:
railway = str(tags.get("railway", "")).lower()
if railway not in TRACK_ALLOWED_RAILWAY_TYPES:
return False
if length_meters < TRACK_MIN_LENGTH_METERS:
return False
service = str(tags.get("service", "")).lower()
if service and service in TRACK_EXCLUDED_SERVICE_TAGS:
return False
usage = str(tags.get("usage", "")).lower()
if usage and usage in TRACK_EXCLUDED_USAGE_TAGS:
return False
return True
def _is_oneway(value: Any) -> bool:
if value is None:
return False
normalized = str(value).strip().lower()
return normalized in {"yes", "true", "1"}
def _polyline_length(points: list[list[float]]) -> float:
if len(points) < 2:
return 0.0
total = 0.0
for index in range(len(points) - 1):
total += _haversine(points[index], points[index + 1])
return total
def _haversine(a: list[float], b: list[float]) -> float:
"""Return distance in meters between two [lat, lon] coordinates."""
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def main(argv: list[str] | None = None) -> int:
parser = build_argument_parser()
args = parser.parse_args(argv)
query = build_overpass_query(args.region)
if args.dry_run:
print(query)
return 0
output_path: Path = args.output
output_path.parent.mkdir(parents=True, exist_ok=True)
data = perform_request(query)
raw_elements = data.get("elements", [])
tracks = normalize_track_elements(raw_elements)
payload = {
"metadata": {
"endpoint": OVERPASS_ENDPOINT,
"region": args.region,
"filters": TRACK_TAG_FILTERS,
"regions": [asdict(region) for region in DEFAULT_REGIONS],
"raw_count": len(raw_elements),
"track_count": len(tracks),
},
"tracks": tracks,
}
with output_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2)
print(
f"Normalized {len(tracks)} tracks from {len(raw_elements)} elements into {output_path}"
)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,293 @@
from __future__ import annotations
"""CLI for loading normalized track JSON into the database."""
import argparse
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
from backend.app.core.osm_config import TRACK_STATION_SNAP_RADIUS_METERS
from backend.app.db.session import SessionLocal
from backend.app.models import TrackCreate
from backend.app.repositories import StationRepository, TrackRepository
@dataclass(slots=True)
class ParsedTrack:
coordinates: list[tuple[float, float]]
osm_id: str | None = None
name: str | None = None
length_meters: float | None = None
max_speed_kph: float | None = None
status: str = "operational"
is_bidirectional: bool = True
@dataclass(slots=True)
class StationRef:
id: str
latitude: float
longitude: float
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Load normalized track data into PostGIS",
)
parser.add_argument(
"input",
type=Path,
help="Path to the normalized track JSON file produced by tracks_import.py",
)
parser.add_argument(
"--commit",
dest="commit",
action="store_true",
default=True,
help="Commit the transaction after loading (default).",
)
parser.add_argument(
"--no-commit",
dest="commit",
action="store_false",
help="Rollback the transaction after loading (useful for dry runs).",
)
return parser
def main(argv: list[str] | None = None) -> int:
parser = build_argument_parser()
args = parser.parse_args(argv)
if not args.input.exists():
parser.error(f"Input file {args.input} does not exist")
with args.input.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
track_entries = payload.get("tracks") or []
if not isinstance(track_entries, list):
parser.error("Invalid payload: 'tracks' must be a list")
try:
tracks = _parse_track_entries(track_entries)
except ValueError as exc:
parser.error(str(exc))
created = load_tracks(tracks, commit=args.commit)
print(f"Loaded {created} tracks from {args.input}")
return 0
def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTrack]:
parsed: list[ParsedTrack] = []
for entry in entries:
coordinates = entry.get("coordinates")
if not isinstance(coordinates, Sequence) or len(coordinates) < 2:
raise ValueError(
"Invalid track entry: 'coordinates' must contain at least two points"
)
processed_coordinates: list[tuple[float, float]] = []
for pair in coordinates:
if not isinstance(pair, Sequence) or len(pair) != 2:
raise ValueError(
f"Invalid coordinate pair {pair!r} in track entry")
lat, lon = pair
processed_coordinates.append((float(lat), float(lon)))
name = entry.get("name")
length = _safe_float(entry.get("lengthMeters"))
max_speed = _safe_float(entry.get("maxSpeedKph"))
status = entry.get("status", "operational")
is_bidirectional = entry.get("isBidirectional", True)
osm_id = entry.get("osmId")
parsed.append(
ParsedTrack(
coordinates=processed_coordinates,
osm_id=str(osm_id) if osm_id else None,
name=str(name) if name else None,
length_meters=length,
max_speed_kph=max_speed,
status=str(status) if status else "operational",
is_bidirectional=bool(is_bidirectional),
)
)
return parsed
def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
created = 0
with SessionLocal() as session:
station_repo = StationRepository(session)
track_repo = TrackRepository(session)
station_index = _build_station_index(station_repo.list_active())
existing_pairs = {
(str(track.start_station_id), str(track.end_station_id))
for track in track_repo.list_all()
}
for track_data in tracks:
# Skip if track with this OSM ID already exists
if track_data.osm_id and track_repo.exists_by_osm_id(track_data.osm_id):
print(
f"Skipping track {track_data.osm_id} - already exists by OSM ID")
continue
start_station = _nearest_station(
track_data.coordinates[0],
station_index,
TRACK_STATION_SNAP_RADIUS_METERS,
)
end_station = _nearest_station(
track_data.coordinates[-1],
station_index,
TRACK_STATION_SNAP_RADIUS_METERS,
)
if not start_station or not end_station:
print(
f"Skipping track {track_data.osm_id} - no start/end stations found")
continue
if start_station.id == end_station.id:
print(
f"Skipping track {track_data.osm_id} - start and end stations are the same")
continue
pair = (start_station.id, end_station.id)
if pair in existing_pairs:
print(
f"Skipping track {track_data.osm_id} - station pair {pair} already exists")
continue
length = track_data.length_meters or _polyline_length(
track_data.coordinates
)
max_speed = (
int(round(track_data.max_speed_kph))
if track_data.max_speed_kph is not None
else None
)
create_schema = TrackCreate(
osm_id=track_data.osm_id,
name=track_data.name,
start_station_id=start_station.id,
end_station_id=end_station.id,
coordinates=track_data.coordinates,
length_meters=length,
max_speed_kph=max_speed,
status=track_data.status,
is_bidirectional=track_data.is_bidirectional,
)
track_repo.create(create_schema)
existing_pairs.add(pair)
created += 1
if commit:
session.commit()
else:
session.rollback()
return created
def _nearest_station(
coordinate: tuple[float, float],
stations: Sequence[StationRef],
max_distance_meters: float,
) -> StationRef | None:
best_station: StationRef | None = None
best_distance = math.inf
for station in stations:
distance = _haversine(
coordinate, (station.latitude, station.longitude))
if distance < best_distance:
best_station = station
best_distance = distance
if best_distance <= max_distance_meters:
return best_station
return None
def _build_station_index(stations: Iterable[Any]) -> list[StationRef]:
index: list[StationRef] = []
for station in stations:
location = getattr(station, "location", None)
if location is None:
continue
point = _to_point(location)
if point is None:
continue
latitude = getattr(point, "y", None)
longitude = getattr(point, "x", None)
if latitude is None or longitude is None:
continue
index.append(
StationRef(
id=str(station.id),
latitude=float(latitude),
longitude=float(longitude),
)
)
return index
def _to_point(geometry: WKBElement | WKTElement | Any):
try:
point = to_shape(geometry)
return point if getattr(point, "geom_type", None) == "Point" else None
except (
Exception
): # pragma: no cover - defensive, should not happen with valid geometry
return None
def _polyline_length(points: Sequence[tuple[float, float]]) -> float:
if len(points) < 2:
return 0.0
total = 0.0
for index in range(len(points) - 1):
total += _haversine(points[index], points[index + 1])
return total
def _haversine(a: tuple[float, float], b: tuple[float, float]) -> float:
lat1, lon1 = a
lat2, lon2 = b
radius = 6_371_000
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
sin_dphi = math.sin(delta_phi / 2)
sin_dlambda = math.sin(delta_lambda / 2)
root = sin_dphi**2 + math.cos(phi1) * math.cos(phi2) * sin_dlambda**2
distance = 2 * radius * math.atan2(math.sqrt(root), math.sqrt(1 - root))
return distance
def _safe_float(value: Any) -> float | None:
if value is None or value == "":
return None
try:
return float(value)
except (TypeError, ValueError):
return None
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, List
from uuid import uuid4
import pytest
from backend.app.models import CombinedTrackModel
from backend.app.repositories.combined_tracks import CombinedTrackRepository
from backend.app.repositories.tracks import TrackRepository
from backend.app.services.combined_tracks import create_combined_track
@dataclass
class DummySession:
added: List[Any] = field(default_factory=list)
scalars_result: List[Any] = field(default_factory=list)
scalar_result: Any = None
statements: List[Any] = field(default_factory=list)
committed: bool = False
rolled_back: bool = False
closed: bool = False
def add(self, instance: Any) -> None:
self.added.append(instance)
def add_all(self, instances: list[Any]) -> None:
self.added.extend(instances)
def scalars(self, statement: Any) -> list[Any]:
self.statements.append(statement)
return list(self.scalars_result)
def scalar(self, statement: Any) -> Any:
self.statements.append(statement)
return self.scalar_result
def flush(
self, _objects: list[Any] | None = None
) -> None: # pragma: no cover - optional
return None
def commit(self) -> None: # pragma: no cover - optional
self.committed = True
def rollback(self) -> None: # pragma: no cover - optional
self.rolled_back = True
def close(self) -> None: # pragma: no cover - optional
self.closed = True
def _now() -> datetime:
return datetime.now(timezone.utc)
def test_combined_track_model_round_trip() -> None:
timestamp = _now()
combined_track = CombinedTrackModel(
id="combined-track-1",
start_station_id="station-1",
end_station_id="station-2",
length_meters=3000.0,
max_speed_kph=100,
status="operational",
is_bidirectional=True,
coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)],
constituent_track_ids=["track-1", "track-2"],
created_at=timestamp,
updated_at=timestamp,
)
assert combined_track.length_meters == 3000.0
assert combined_track.start_station_id != combined_track.end_station_id
assert len(combined_track.coordinates) == 3
assert len(combined_track.constituent_track_ids) == 2
def test_combined_track_repository_create() -> None:
"""Test creating a combined track through the repository."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Create test data
from backend.app.models import CombinedTrackCreate
create_data = CombinedTrackCreate(
start_station_id="550e8400-e29b-41d4-a716-446655440000",
end_station_id="550e8400-e29b-41d4-a716-446655440001",
coordinates=[(52.52, 13.405), (52.6, 13.5)],
constituent_track_ids=["track-1"],
length_meters=1500.0,
max_speed_kph=120,
status="operational",
)
combined_track = repo.create(create_data)
assert combined_track.start_station_id is not None
assert combined_track.end_station_id is not None
assert combined_track.length_meters == 1500.0
assert combined_track.max_speed_kph == 120
assert combined_track.status == "operational"
assert session.added and session.added[0] is combined_track
def test_combined_track_repository_exists_between_stations() -> None:
"""Test checking if combined track exists between stations."""
session = DummySession()
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
# Initially should not exist (scalar_result is None by default)
assert not repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
# Simulate existing combined track
session.scalar_result = True
assert repo.exists_between_stations(
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
def test_combined_track_service_create_no_path() -> None:
"""Test creating combined track when no path exists."""
# Mock session and repositories
session = DummySession()
# Mock TrackRepository to return no path
class MockTrackRepository:
def __init__(self, session):
pass
def find_path_between_stations(self, start_id, end_id):
return None
# Mock CombinedTrackRepository
class MockCombinedTrackRepository:
def __init__(self, session):
pass
def exists_between_stations(self, start_id, end_id):
return False
# Patch the service to use mock repositories
import backend.app.services.combined_tracks as service_module
original_track_repo = service_module.TrackRepository
original_combined_repo = service_module.CombinedTrackRepository
service_module.TrackRepository = MockTrackRepository
service_module.CombinedTrackRepository = MockCombinedTrackRepository
try:
result = create_combined_track(
session, # type: ignore[arg-type]
"550e8400-e29b-41d4-a716-446655440000",
"550e8400-e29b-41d4-a716-446655440001"
)
assert result is None
finally:
# Restore original classes
service_module.TrackRepository = original_track_repo
service_module.CombinedTrackRepository = original_combined_repo

View File

@@ -29,11 +29,15 @@ def test_track_model_properties() -> None:
end_station_id="station-2",
length_meters=1500.0,
max_speed_kph=120.0,
status="operational",
is_bidirectional=True,
coordinates=[(52.52, 13.405), (52.6, 13.5)],
created_at=timestamp,
updated_at=timestamp,
)
assert track.length_meters > 0
assert track.start_station_id != track.end_station_id
assert len(track.coordinates) == 2
def test_train_model_operating_tracks() -> None:

View File

@@ -26,6 +26,9 @@ def sample_entities() -> dict[str, SimpleNamespace]:
end_station_id=station.id,
length_meters=1234.5,
max_speed_kph=160,
status="operational",
is_bidirectional=True,
track_geometry=None,
created_at=timestamp,
updated_at=timestamp,
)

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
from argparse import Namespace
from pathlib import Path
import pytest
from backend.scripts import osm_refresh
def _namespace(output_dir: Path, **overrides: object) -> Namespace:
defaults: dict[str, object] = {
"region": "all",
"output_dir": output_dir,
"stations_json": None,
"tracks_json": None,
"skip_station_import": False,
"skip_station_load": False,
"skip_track_import": False,
"skip_track_load": False,
"dry_run": False,
"commit": True,
}
defaults.update(overrides)
return Namespace(**defaults)
def test_build_stage_plan_default_sequence(tmp_path: Path) -> None:
stages = osm_refresh._build_stage_plan(_namespace(tmp_path))
labels = [stage.label for stage in stages]
assert labels == [
"Import stations",
"Load stations",
"Import tracks",
"Load tracks",
]
expected_station_path = tmp_path / "osm_stations.json"
expected_track_path = tmp_path / "osm_tracks.json"
assert stages[0].output_path == expected_station_path
assert stages[1].input_path == expected_station_path
assert stages[2].output_path == expected_track_path
assert stages[3].input_path == expected_track_path
def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None:
stages = osm_refresh._build_stage_plan(
_namespace(
tmp_path,
skip_station_import=True,
skip_track_import=True,
)
)
labels = [stage.label for stage in stages]
assert labels == ["Load stations", "Load tracks"]
def test_main_dry_run_lists_plan(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
raise AssertionError("runner should not be invoked during dry run")
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
monkeypatch.setattr(osm_refresh.stations_load, "main", fail)
monkeypatch.setattr(osm_refresh.tracks_load, "main", fail)
exit_code = osm_refresh.main(["--dry-run", "--output-dir", str(tmp_path)])
assert exit_code == 0
captured = capsys.readouterr().out
assert "Dry run" in captured
assert "Import stations" in captured
assert "Load tracks" in captured
def test_main_executes_stages_in_order(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
calls: list[str] = []
def make_import(name: str):
def runner(args: list[str] | None) -> int:
assert args is not None
calls.append(name)
output_index = args.index("--output") + 1
output_path = Path(args[output_index])
output_path.write_text("{}", encoding="utf-8")
return 0
return runner
def make_load(name: str):
def runner(args: list[str] | None) -> int:
assert args is not None
calls.append(name)
return 0
return runner
monkeypatch.setattr(
osm_refresh.stations_import, "main", make_import("stations_import")
)
monkeypatch.setattr(osm_refresh.tracks_import, "main", make_import("tracks_import"))
monkeypatch.setattr(osm_refresh.stations_load, "main", make_load("stations_load"))
monkeypatch.setattr(osm_refresh.tracks_load, "main", make_load("tracks_load"))
exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])
assert exit_code == 0
assert calls == [
"stations_import",
"stations_load",
"tracks_import",
"tracks_load",
]
def test_main_skip_import_flags(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
station_json = tmp_path / "stations.json"
station_json.write_text("{}", encoding="utf-8")
track_json = tmp_path / "tracks.json"
track_json.write_text("{}", encoding="utf-8")
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
raise AssertionError("import stage should be skipped")
calls: list[str] = []
def record(name: str):
def runner(args: list[str] | None) -> int:
assert args is not None
calls.append(name)
return 0
return runner
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
monkeypatch.setattr(osm_refresh.stations_load, "main", record("stations_load"))
monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load"))
exit_code = osm_refresh.main(
[
"--skip-station-import",
"--skip-track-import",
"--stations-json",
str(station_json),
"--tracks-json",
str(track_json),
]
)
assert exit_code == 0
assert calls == ["stations_load", "tracks_load"]

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import pytest
from fastapi.testclient import TestClient
from backend.app.api import tracks as tracks_api
from backend.app.main import app
from backend.app.models import CombinedTrackModel, TrackModel
client = TestClient(app)
def _track_model(track_id: str = "track-1") -> TrackModel:
now = datetime.now(timezone.utc)
return TrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=None,
max_speed_kph=None,
status="planned",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _combined_model(track_id: str = "combined-1") -> CombinedTrackModel:
now = datetime.now(timezone.utc)
return CombinedTrackModel(
id=track_id,
start_station_id="station-a",
end_station_id="station-b",
length_meters=1000,
max_speed_kph=120,
status="operational",
coordinates=[(52.5, 13.4), (52.6, 13.5)],
constituent_track_ids=["track-1", "track-2"],
is_bidirectional=True,
created_at=now,
updated_at=now,
)
def _authenticate() -> str:
response = client.post(
"/api/auth/login",
json={"username": "demo", "password": "railgame123"},
)
assert response.status_code == 200
return response.json()["accessToken"]
def test_list_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "list_tracks", lambda db: [_track_model()])
response = client.get(
"/api/tracks",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert isinstance(payload, list)
assert payload[0]["id"] == "track-1"
def test_get_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(tracks_api, "get_track", lambda db, track_id: None)
response = client.get(
"/api/tracks/not-found",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_create_track_calls_service(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
captured: dict[str, Any] = {}
payload = {
"startStationId": "station-a",
"endStationId": "station-b",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
}
def fake_create(db: Any, data: Any) -> TrackModel:
assert data.start_station_id == "station-a"
captured["payload"] = data
return _track_model("track-new")
monkeypatch.setattr(tracks_api, "create_track", fake_create)
response = client.post(
"/api/tracks",
json=payload,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 201
body = response.json()
assert body["id"] == "track-new"
assert captured["payload"].end_station_id == "station-b"
def test_delete_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "delete_track", lambda db, tid, regenerate=False: False
)
response = client.delete(
"/api/tracks/missing",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 404
def test_delete_track_success(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
seen: dict[str, Any] = {}
def fake_delete(db: Any, track_id: str, regenerate: bool = False) -> bool:
seen["track_id"] = track_id
seen["regenerate"] = regenerate
return True
monkeypatch.setattr(tracks_api, "delete_track", fake_delete)
response = client.delete(
"/api/tracks/track-99",
params={"regenerate": "true"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 204
assert seen["track_id"] == "track-99"
assert seen["regenerate"] is True
def test_list_combined_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
token = _authenticate()
monkeypatch.setattr(
tracks_api, "list_combined_tracks", lambda db: [_combined_model()]
)
response = client.get(
"/api/tracks/combined",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert len(payload) == 1
assert payload[0]["id"] == "combined-1"

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
from backend.scripts import tracks_import
def test_normalize_track_elements_excludes_invalid_geometries() -> None:
elements = [
{
"type": "way",
"id": 123,
"geometry": [
{"lat": 52.5, "lon": 13.4},
{"lat": 52.6, "lon": 13.5},
],
"tags": {
"name": "Main Line",
"railway": "rail",
"maxspeed": "120",
},
},
{
"type": "way",
"id": 456,
"geometry": [
{"lat": 51.0},
],
"tags": {"railway": "rail"},
},
{
"type": "node",
"id": 789,
},
]
tracks = tracks_import.normalize_track_elements(elements)
assert len(tracks) == 1
track = tracks[0]
assert track["osmId"] == "123"
assert track["name"] == "Main Line"
assert track["maxSpeedKph"] == 120.0
assert track["status"] == "operational"
assert track["isBidirectional"] is True
assert track["coordinates"] == [[52.5, 13.4], [52.6, 13.5]]
assert track["lengthMeters"] > 0
def test_normalize_track_elements_marks_oneway_and_status() -> None:
elements = [
{
"type": "way",
"id": 42,
"geometry": [
{"lat": 48.1, "lon": 11.5},
{"lat": 48.2, "lon": 11.6},
],
"tags": {
"railway": "disused",
"oneway": "yes",
},
}
]
tracks = tracks_import.normalize_track_elements(elements)
assert len(tracks) == 1
track = tracks[0]
assert track["status"] == "disused"
assert track["isBidirectional"] is False
def test_normalize_track_elements_skips_service_tracks() -> None:
elements = [
{
"type": "way",
"id": 77,
"geometry": [
{"lat": 52.5000, "lon": 13.4000},
{"lat": 52.5010, "lon": 13.4010},
],
"tags": {
"railway": "rail",
"service": "yard",
},
}
]
tracks = tracks_import.normalize_track_elements(elements)
assert tracks == []
def test_normalize_track_elements_skips_short_tracks() -> None:
elements = [
{
"type": "way",
"id": 81,
"geometry": [
{"lat": 52.500000, "lon": 13.400000},
{"lat": 52.500100, "lon": 13.400050},
],
"tags": {
"railway": "rail",
},
}
]
tracks = tracks_import.normalize_track_elements(elements)
assert tracks == []

View File

@@ -0,0 +1,212 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List
import pytest
from geoalchemy2.shape import from_shape
from shapely.geometry import Point
from backend.scripts import tracks_load
def test_parse_track_entries_returns_models() -> None:
entries = [
{
"name": "Connector",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
"lengthMeters": 1500,
"maxSpeedKph": 120,
"status": "operational",
"isBidirectional": True,
}
]
parsed = tracks_load._parse_track_entries(entries)
assert parsed[0].name == "Connector"
assert parsed[0].coordinates[0] == (52.5, 13.4)
assert parsed[0].length_meters == 1500
assert parsed[0].max_speed_kph == 120
def test_parse_track_entries_invalid_raises_value_error() -> None:
entries = [
{
"coordinates": [[52.5, 13.4]],
}
]
with pytest.raises(ValueError):
tracks_load._parse_track_entries(entries)
@dataclass
class DummySession:
committed: bool = False
rolled_back: bool = False
def __enter__(self) -> "DummySession":
return self
def __exit__(self, exc_type, exc, traceback) -> None:
pass
def commit(self) -> None:
self.committed = True
def rollback(self) -> None:
self.rolled_back = True
@dataclass
class DummyStation:
id: str
location: object
@dataclass
class DummyStationRepository:
session: DummySession
stations: List[DummyStation]
def list_active(self) -> List[DummyStation]:
return self.stations
@dataclass
class DummyTrackRepository:
session: DummySession
created: list = field(default_factory=list)
existing: list = field(default_factory=list)
def list_all(self):
return self.existing
def create(self, data): # pragma: no cover - simple delegation
self.created.append(data)
def _point(lat: float, lon: float) -> object:
return from_shape(Point(lon, lat), srid=4326)
def test_load_tracks_creates_entries(monkeypatch: pytest.MonkeyPatch) -> None:
session_instance = DummySession()
station_repo_instance = DummyStationRepository(
session_instance,
stations=[
DummyStation(id="station-a", location=_point(52.5, 13.4)),
DummyStation(id="station-b", location=_point(52.6, 13.5)),
],
)
track_repo_instance = DummyTrackRepository(session_instance)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[
{
"name": "Connector",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
}
]
)
created = tracks_load.load_tracks(parsed, commit=True)
assert created == 1
assert session_instance.committed is True
assert track_repo_instance.created
track = track_repo_instance.created[0]
assert track.start_station_id == "station-a"
assert track.end_station_id == "station-b"
assert track.coordinates == [(52.5, 13.4), (52.6, 13.5)]
def test_load_tracks_skips_existing_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
session_instance = DummySession()
station_repo_instance = DummyStationRepository(
session_instance,
stations=[
DummyStation(id="station-a", location=_point(52.5, 13.4)),
DummyStation(id="station-b", location=_point(52.6, 13.5)),
],
)
existing_track = type(
"ExistingTrack",
(),
{
"start_station_id": "station-a",
"end_station_id": "station-b",
},
)
track_repo_instance = DummyTrackRepository(
session_instance,
existing=[existing_track],
)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[
{
"name": "Connector",
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
}
]
)
created = tracks_load.load_tracks(parsed, commit=False)
assert created == 0
assert session_instance.rolled_back is True
assert not track_repo_instance.created
def test_load_tracks_skips_when_station_too_far(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_instance = DummySession()
station_repo_instance = DummyStationRepository(
session_instance,
stations=[
DummyStation(id="remote-station", location=_point(53.5, 14.5)),
],
)
track_repo_instance = DummyTrackRepository(session_instance)
monkeypatch.setattr(tracks_load, "SessionLocal", lambda: session_instance)
monkeypatch.setattr(
tracks_load, "StationRepository", lambda session: station_repo_instance
)
monkeypatch.setattr(
tracks_load, "TrackRepository", lambda session: track_repo_instance
)
parsed = tracks_load._parse_track_entries(
[
{
"name": "Isolated Segment",
"coordinates": [[52.5, 13.4], [52.51, 13.41]],
}
]
)
created = tracks_load.load_tracks(parsed, commit=True)
assert created == 0
assert session_instance.committed is True
assert not track_repo_instance.created

9782
data/osm_stations.json Normal file

File diff suppressed because it is too large Load Diff

527625
data/osm_tracks.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,3 @@
version: "3.9"
services:
db:
build:
@@ -27,6 +25,7 @@ services:
DATABASE_URL: postgresql+psycopg://railgame:railgame@db:5432/railgame_dev
TEST_DATABASE_URL: postgresql+psycopg://railgame:railgame@db:5432/railgame_test
REDIS_URL: redis://redis:6379/0
INIT_DEMO_DB: "true"
depends_on:
- db
- redis

View File

@@ -111,6 +111,7 @@ graph TD
- **Health Module**: Lightweight readiness probes used by infrastructure checks.
- **Network Module**: Serves read-only snapshots of stations, tracks, and trains using shared domain models (camelCase aliases for client compatibility).
- **OSM Ingestion CLI**: Script pairings (`stations_import`/`stations_load`, `tracks_import`/`tracks_load`) that harvest OpenStreetMap fixtures and persist normalized station and track geometries into PostGIS.
- **Authentication Module**: JWT-based user registration, authentication, and authorization. The current prototype supports on-the-fly account creation backed by an in-memory user store and issues short-lived access tokens to validate the client flow end-to-end.
- **Railway Calculation Module**: Algorithms for route optimization and scheduling (planned).
- **Resource Management Module**: Logic for game economy and progression (planned).
@@ -157,4 +158,3 @@ rail-game/
```
Shared code that spans application layers should be surfaced through well-defined APIs within `backend/app/services` or exposed via frontend data contracts to keep coupling low. Infrastructure automation and CI/CD assets remain isolated under `infra/` to support multiple deployment targets.

View File

@@ -78,7 +78,31 @@ sequenceDiagram
F->>F: Render Leaflet map and snapshot summaries
```
#### 6.2.4 Building Railway Network
#### 6.2.4 OSM Track Import and Load
**Scenario**: Operator refreshes spatial fixtures by harvesting OSM railways and persisting them to PostGIS.
**Description**: The paired CLI scripts `tracks_import.py` and `tracks_load.py` export candidate track segments from Overpass, associate endpoints with the nearest known stations, and store the resulting LINESTRING geometries. Dry-run flags allow inspection of the generated Overpass payload or database mutations before commit.
```mermaid
sequenceDiagram
participant Ops as Operator
participant TI as tracks_import.py
participant OL as Overpass API
participant TL as tracks_load.py
participant DB as PostGIS
Ops->>TI: Invoke with region + output path
TI->>OL: POST compiled Overpass query
OL-->>TI: Return rail way elements (JSON)
TI-->>Ops: Write normalized tracks JSON
Ops->>TL: Invoke with normalized JSON
TL->>DB: Fetch stations + existing tracks
TL->>DB: Insert snapped LINESTRING geometries
TL-->>Ops: Report committed track count
```
#### 6.2.5 Building Railway Network
**Scenario**: User adds a new track segment to their railway network.
@@ -101,7 +125,7 @@ sequenceDiagram
F->>F: Update map display
```
#### 6.2.5 Running Train Simulation
#### 6.2.6 Running Train Simulation
**Scenario**: User starts a train simulation on their network.
@@ -129,7 +153,7 @@ sequenceDiagram
end
```
#### 6.2.6 Saving Game Progress
#### 6.2.7 Saving Game Progress
**Scenario**: User saves their current game state.
@@ -154,4 +178,3 @@ sequenceDiagram
- **Real-time Updates**: WebSocket connections for simulation updates, with fallback to polling
- **Load Balancing**: Backend API can be scaled horizontally for multiple users
- **CDN**: Static assets and map tiles served via CDN for faster loading

View File

@@ -55,6 +55,13 @@ Dynamic simulation of train operations:
- **Fallback Mechanisms**: Polling as alternative when WebSockets unavailable
- **Event-Driven Updates**: Push notifications for game state changes
#### 8.2.4 OSM Track Harvesting Policy
- **Railway Types**: Importer requests `rail`, `light_rail`, `subway`, `tram`, `narrow_gauge`, plus `construction` and `disused` variants to capture build-state metadata.
- **Service Filters**: `service` tags such as `yard`, `siding`, `spur`, `crossover`, `industrial`, or `military` are excluded to focus on mainline traffic.
- **Usage Filters**: Ways flagged with `usage=military` or `usage=tourism` are skipped; unspecified usage defaults to accepted.
- **Geometry Guardrails**: Segments shorter than 75 meters are discarded and track endpoints must snap to an existing station within 350 meters or the segment is ignored during loading.
### 8.3 User Interface Concepts
#### 8.3.1 Component-Based Architecture
@@ -127,4 +134,3 @@ Dynamic simulation of train operations:
- **Lazy Loading**: On-demand loading of components and data
- **Caching Layers**: Redis for frequently accessed data
- **Asset Optimization**: Minification and compression of static resources

View File

@@ -70,8 +70,8 @@ The system interacts with:
- User registration and authentication
- Railway network building and management
- Train scheduling and simulation
- Map visualization and interaction
- Train scheduling and simulation
- Leaderboards and user profiles
**Out of Scope:**
@@ -100,6 +100,7 @@ The system interacts with:
- Browser-native implementation for broad accessibility
- Spatial database for efficient geographical queries
- Offline-friendly OSM ingestion pipeline that uses dedicated CLI scripts to export/import stations and tracks before seeding the database
- Modular architecture allowing for future extensions (e.g., multiplayer)
## 5. Building Block View

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,7 @@
"preview": "vite preview",
"lint": "eslint \"src/**/*.{ts,tsx}\"",
"format": "prettier --write \"src/**/*.{ts,tsx,css}\"",
"test": "vitest run",
"test:e2e": "playwright test"
},
"dependencies": {
@@ -34,6 +35,7 @@
"eslint-plugin-react-hooks": "^4.6.2",
"prettier": "^3.3.3",
"typescript": "^5.5.3",
"vite": "^5.4.0"
"vite": "^5.4.0",
"vitest": "^1.6.0"
}
}

View File

@@ -1,38 +1,135 @@
import './styles/global.css';
import type { LatLngExpression } from 'leaflet';
import { useEffect, useMemo, useState } from 'react';
import { LoginForm } from './components/auth/LoginForm';
import { NetworkMap } from './components/map/NetworkMap';
import { useNetworkSnapshot } from './hooks/useNetworkSnapshot';
import { useAuth } from './state/AuthContext';
import type { Station } from './types/domain';
import { buildTrackAdjacency, computeRoute } from './utils/route';
function App(): JSX.Element {
const { token, user, status: authStatus, logout } = useAuth();
const isAuthenticated = authStatus === 'authenticated' && token !== null;
const { data, status, error } = useNetworkSnapshot(isAuthenticated ? token : null);
const [selectedStationId, setSelectedStationId] = useState<string | null>(null);
const [focusedStationId, setFocusedStationId] = useState<string | null>(null);
const [routeSelection, setRouteSelection] = useState<{
startId: string | null;
endId: string | null;
}>({ startId: null, endId: null });
const [selectedTrackId, setSelectedTrackId] = useState<string | null>(null);
useEffect(() => {
if (status !== 'success' || !data?.stations.length) {
setSelectedStationId(null);
setFocusedStationId(null);
setRouteSelection({ startId: null, endId: null });
setSelectedTrackId(null);
return;
}
if (
!selectedStationId ||
!data.stations.some((station) => station.id === selectedStationId)
) {
setSelectedStationId(data.stations[0].id);
if (!focusedStationId || !hasStation(data.stations, focusedStationId)) {
setFocusedStationId(data.stations[0].id);
}
}, [status, data, selectedStationId]);
}, [status, data, focusedStationId]);
const selectedStation = useMemo(() => {
if (!data || !selectedStationId) {
useEffect(() => {
if (status !== 'success' || !data) {
return;
}
setRouteSelection((current) => {
const startExists = current.startId
? hasStation(data.stations, current.startId)
: false;
const endExists = current.endId
? hasStation(data.stations, current.endId)
: false;
return {
startId: startExists ? current.startId : null,
endId: endExists ? current.endId : null,
};
});
}, [status, data]);
const stationById = useMemo(() => {
if (!data) {
return new Map<string, Station>();
}
const lookup = new Map<string, Station>();
for (const station of data.stations) {
lookup.set(station.id, station);
}
return lookup;
}, [data]);
const trackAdjacency = useMemo(
() => buildTrackAdjacency(data ? data.tracks : []),
[data]
);
const routeComputation = useMemo(
() =>
computeRoute({
startId: routeSelection.startId,
endId: routeSelection.endId,
stationById,
adjacency: trackAdjacency,
}),
[routeSelection, stationById, trackAdjacency]
);
const routeSegments = useMemo<LatLngExpression[][]>(() => {
return routeComputation.segments.map((segment) =>
segment.map((pair) => [pair[0], pair[1]] as LatLngExpression)
);
}, [routeComputation.segments]);
const focusedStation = useMemo(() => {
if (!data || !focusedStationId) {
return null;
}
return data.stations.find((station) => station.id === selectedStationId) ?? null;
}, [data, selectedStationId]);
return stationById.get(focusedStationId) ?? null;
}, [data, focusedStationId, stationById]);
const selectedTrack = useMemo(() => {
if (!data || !selectedTrackId) {
return null;
}
return data.tracks.find((track) => track.id === selectedTrackId) ?? null;
}, [data, selectedTrackId]);
const handleStationSelection = (stationId: string) => {
setFocusedStationId(stationId);
setSelectedTrackId(null);
setRouteSelection((current) => {
if (!current.startId || (current.startId && current.endId)) {
return { startId: stationId, endId: null };
}
if (current.startId === stationId) {
return { startId: stationId, endId: null };
}
return { startId: current.startId, endId: stationId };
});
};
const clearRouteSelection = () => {
setRouteSelection({ startId: null, endId: null });
};
const handleCreateTrack = () => {
if (!routeSelection.startId || !routeSelection.endId) {
return;
}
// TODO: Implement track creation API call
alert(
`Creating track between ${stationById.get(routeSelection.startId)?.name} and ${stationById.get(routeSelection.endId)?.name}`
);
};
return (
<div className="app-shell">
@@ -65,10 +162,86 @@ function App(): JSX.Element {
<div className="map-wrapper">
<NetworkMap
snapshot={data}
selectedStationId={selectedStationId}
onSelectStation={(id) => setSelectedStationId(id)}
focusedStationId={focusedStationId}
startStationId={routeSelection.startId}
endStationId={routeSelection.endId}
routeSegments={routeSegments}
selectedTrackId={selectedTrackId}
onStationClick={handleStationSelection}
onTrackClick={setSelectedTrackId}
/>
</div>
<div className="route-panel">
<div className="route-panel__header">
<h3>Route Selection</h3>
<button
type="button"
className="ghost-button"
onClick={clearRouteSelection}
disabled={!routeSelection.startId && !routeSelection.endId}
>
Clear
</button>
</div>
<p className="route-panel__hint">
Click a station to set the origin, then click another station to
preview the rail corridor between them.
</p>
<dl className="route-panel__meta">
<div>
<dt>Origin</dt>
<dd>
{routeSelection.startId
? (stationById.get(routeSelection.startId)?.name ??
'Unknown station')
: 'Choose a station'}
</dd>
</div>
<div>
<dt>Destination</dt>
<dd>
{routeSelection.endId
? (stationById.get(routeSelection.endId)?.name ??
'Unknown station')
: 'Choose a station'}
</dd>
</div>
<div>
<dt>Estimated Length</dt>
<dd>
{routeComputation.totalLength !== null
? `${(routeComputation.totalLength / 1000).toFixed(2)} km`
: 'N/A'}
</dd>
</div>
</dl>
{routeComputation.error && (
<p className="route-panel__error">{routeComputation.error}</p>
)}
{!routeComputation.error && routeComputation.stations && (
<div className="route-panel__path">
<span>Path:</span>
<ol>
{routeComputation.stations.map((station) => (
<li key={`route-station-${station.id}`}>{station.name}</li>
))}
</ol>
</div>
)}
{routeSelection.startId &&
routeSelection.endId &&
routeComputation.error && (
<div className="route-panel__actions">
<button
type="button"
className="primary-button"
onClick={handleCreateTrack}
>
Create Track
</button>
</div>
)}
</div>
<div className="grid">
<div>
<h3>Stations</h3>
@@ -78,12 +251,20 @@ function App(): JSX.Element {
<button
type="button"
className={`station-list-item${
station.id === selectedStationId
station.id === focusedStationId
? ' station-list-item--selected'
: ''
}${
station.id === routeSelection.startId
? ' station-list-item--start'
: ''
}${
station.id === routeSelection.endId
? ' station-list-item--end'
: ''
}`}
aria-pressed={station.id === selectedStationId}
onClick={() => setSelectedStationId(station.id)}
aria-pressed={station.id === focusedStationId}
onClick={() => handleStationSelection(station.id)}
>
<span className="station-list-item__name">
{station.name}
@@ -92,6 +273,14 @@ function App(): JSX.Element {
{station.latitude.toFixed(3)},{' '}
{station.longitude.toFixed(3)}
</span>
{station.id === routeSelection.startId && (
<span className="station-list-item__badge">Origin</span>
)}
{station.id === routeSelection.endId && (
<span className="station-list-item__badge">
Destination
</span>
)}
</button>
</li>
))}
@@ -114,54 +303,99 @@ function App(): JSX.Element {
{data.tracks.map((track) => (
<li key={track.id}>
{track.startStationId} {track.endStationId} ·{' '}
{(track.lengthMeters / 1000).toFixed(1)} km
{track.lengthMeters > 0
? `${(track.lengthMeters / 1000).toFixed(1)} km`
: 'N/A'}
</li>
))}
</ul>
</div>
</div>
{selectedStation && (
{focusedStation && (
<div className="selected-station">
<h3>Selected Station</h3>
<h3>Focused Station</h3>
<dl>
<div>
<dt>Name</dt>
<dd>{selectedStation.name}</dd>
<dd>{focusedStation.name}</dd>
</div>
<div>
<dt>Coordinates</dt>
<dd>
{selectedStation.latitude.toFixed(5)},{' '}
{selectedStation.longitude.toFixed(5)}
{focusedStation.latitude.toFixed(5)},{' '}
{focusedStation.longitude.toFixed(5)}
</dd>
</div>
{selectedStation.code && (
{focusedStation.code && (
<div>
<dt>Code</dt>
<dd>{selectedStation.code}</dd>
<dd>{focusedStation.code}</dd>
</div>
)}
{typeof selectedStation.elevationM === 'number' && (
{typeof focusedStation.elevationM === 'number' && (
<div>
<dt>Elevation</dt>
<dd>{selectedStation.elevationM.toFixed(1)} m</dd>
<dd>{focusedStation.elevationM.toFixed(1)} m</dd>
</div>
)}
{selectedStation.osmId && (
{focusedStation.osmId && (
<div>
<dt>OSM ID</dt>
<dd>{selectedStation.osmId}</dd>
<dd>{focusedStation.osmId}</dd>
</div>
)}
<div>
<dt>Status</dt>
<dd>
{(selectedStation.isActive ?? true) ? 'Active' : 'Inactive'}
{(focusedStation.isActive ?? true) ? 'Active' : 'Inactive'}
</dd>
</div>
</dl>
</div>
)}
{selectedTrack && (
<div className="selected-track">
<h3>Selected Track</h3>
<dl>
<div>
<dt>Start Station</dt>
<dd>
{stationById.get(selectedTrack.startStationId)?.name ??
'Unknown'}
</dd>
</div>
<div>
<dt>End Station</dt>
<dd>
{stationById.get(selectedTrack.endStationId)?.name ??
'Unknown'}
</dd>
</div>
<div>
<dt>Length</dt>
<dd>
{selectedTrack.lengthMeters > 0
? `${(selectedTrack.lengthMeters / 1000).toFixed(2)} km`
: 'N/A'}
</dd>
</div>
<div>
<dt>Max Speed</dt>
<dd>{selectedTrack.maxSpeedKph} km/h</dd>
</div>
{selectedTrack.status && (
<div>
<dt>Status</dt>
<dd>{selectedTrack.status}</dd>
</div>
)}
<div>
<dt>Bidirectional</dt>
<dd>{selectedTrack.isBidirectional ? 'Yes' : 'No'}</dd>
</div>
</dl>
</div>
)}
</div>
)}
</section>
@@ -172,3 +406,7 @@ function App(): JSX.Element {
}
export default App;
function hasStation(stations: Station[], id: string): boolean {
return stations.some((station) => station.id === id);
}

View File

@@ -15,8 +15,13 @@ import 'leaflet/dist/leaflet.css';
interface NetworkMapProps {
readonly snapshot: NetworkSnapshot;
readonly selectedStationId?: string | null;
readonly onSelectStation?: (stationId: string) => void;
readonly focusedStationId?: string | null;
readonly startStationId?: string | null;
readonly endStationId?: string | null;
readonly routeSegments?: LatLngExpression[][];
readonly selectedTrackId?: string | null;
readonly onStationClick?: (stationId: string) => void;
readonly onTrackClick?: (trackId: string) => void;
}
interface StationPosition {
@@ -29,8 +34,13 @@ const DEFAULT_CENTER: LatLngExpression = [51.505, -0.09];
export function NetworkMap({
snapshot,
selectedStationId,
onSelectStation,
focusedStationId,
startStationId,
endStationId,
routeSegments = [],
selectedTrackId,
onStationClick,
onTrackClick,
}: NetworkMapProps): JSX.Element {
const stationPositions = useMemo<StationPosition[]>(() => {
return snapshot.stations.map((station) => ({
@@ -51,6 +61,12 @@ export function NetworkMap({
const trackSegments = useMemo(() => {
return snapshot.tracks
.map((track) => {
if (track.coordinates && track.coordinates.length >= 2) {
return track.coordinates.map(
(pair) => [pair[0], pair[1]] as LatLngExpression
);
}
const start = stationLookup.get(track.startStationId);
const end = stationLookup.get(track.endStationId);
if (!start || !end) {
@@ -86,12 +102,12 @@ export function NetworkMap({
] as LatLngBoundsExpression;
}, [stationPositions]);
const selectedPosition = useMemo(() => {
if (!selectedStationId) {
const focusedPosition = useMemo(() => {
if (!focusedStationId) {
return null;
}
return stationLookup.get(selectedStationId) ?? null;
}, [selectedStationId, stationLookup]);
return stationLookup.get(focusedStationId) ?? null;
}, [focusedStationId, stationLookup]);
return (
<MapContainer
@@ -104,35 +120,82 @@ export function NetworkMap({
attribution='&copy; <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors'
url="https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png"
/>
{selectedPosition ? <StationFocus position={selectedPosition} /> : null}
{trackSegments.map((segment, index) => (
{focusedPosition ? <StationFocus position={focusedPosition} /> : null}
{trackSegments.map((segment, index) => {
const track = snapshot.tracks[index];
const isSelected = track.id === selectedTrackId;
return (
<Polyline
key={`track-${track.id}`}
positions={segment}
pathOptions={{
color: isSelected ? '#3b82f6' : '#334155',
weight: isSelected ? 5 : 3,
opacity: 0.8,
}}
eventHandlers={{
click: () => {
onTrackClick?.(track.id);
},
}}
>
<Tooltip>
{track.startStationId} {track.endStationId}
<br />
Length:{' '}
{track.lengthMeters > 0
? `${(track.lengthMeters / 1000).toFixed(1)} km`
: 'N/A'}
<br />
Max Speed: {track.maxSpeedKph} km/h
{track.status && (
<>
<br />
Status: {track.status}
</>
)}
</Tooltip>
</Polyline>
);
})}
{routeSegments.map((segment, index) => (
<Polyline
key={`track-${index}`}
key={`route-${index}`}
positions={segment}
pathOptions={{ color: '#38bdf8', weight: 4 }}
pathOptions={{ color: '#facc15', weight: 6, opacity: 0.9 }}
/>
))}
{stationPositions.map((station) => (
<CircleMarker
key={station.id}
center={station.position}
radius={station.id === selectedStationId ? 9 : 6}
radius={station.id === focusedStationId ? 9 : 6}
pathOptions={{
color: station.id === selectedStationId ? '#34d399' : '#f97316',
fillColor: station.id === selectedStationId ? '#6ee7b7' : '#fed7aa',
fillOpacity: 0.95,
weight: station.id === selectedStationId ? 3 : 1,
color: resolveMarkerStroke(
station.id,
startStationId,
endStationId,
focusedStationId
),
fillColor: resolveMarkerFill(
station.id,
startStationId,
endStationId,
focusedStationId
),
fillOpacity: 0.96,
weight: station.id === focusedStationId ? 3 : 1,
}}
eventHandlers={{
click: () => {
onSelectStation?.(station.id);
onStationClick?.(station.id);
},
}}
>
<Tooltip
direction="top"
offset={[0, -8]}
permanent={station.id === selectedStationId}
permanent={station.id === focusedStationId}
sticky
>
{station.name}
@@ -152,3 +215,39 @@ function StationFocus({ position }: { position: LatLngExpression }): null {
return null;
}
function resolveMarkerStroke(
stationId: string,
startStationId?: string | null,
endStationId?: string | null,
focusedStationId?: string | null
): string {
if (stationId === startStationId) {
return '#38bdf8';
}
if (stationId === endStationId) {
return '#fb923c';
}
if (stationId === focusedStationId) {
return '#22c55e';
}
return '#f97316';
}
function resolveMarkerFill(
stationId: string,
startStationId?: string | null,
endStationId?: string | null,
focusedStationId?: string | null
): string {
if (stationId === startStationId) {
return '#bae6fd';
}
if (stationId === endStationId) {
return '#fed7aa';
}
if (stationId === focusedStationId) {
return '#bbf7d0';
}
return '#ffe4c7';
}

View File

@@ -102,6 +102,7 @@ body {
background-color 0.18s ease,
border-color 0.18s ease,
transform 0.18s ease;
flex-wrap: wrap;
}
.station-list-item:hover,
@@ -118,6 +119,16 @@ body {
box-shadow: 0 8px 18px -10px rgba(45, 212, 191, 0.65);
}
.station-list-item--start {
border-color: rgba(56, 189, 248, 0.8);
background: rgba(14, 165, 233, 0.2);
}
.station-list-item--end {
border-color: rgba(249, 115, 22, 0.8);
background: rgba(234, 88, 12, 0.18);
}
.station-list-item__name {
font-weight: 600;
}
@@ -128,6 +139,27 @@ body {
font-family: 'Fira Code', 'Source Code Pro', monospace;
}
.station-list-item__badge {
font-size: 0.75rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.05em;
padding: 0.1rem 0.45rem;
border-radius: 999px;
background: rgba(148, 163, 184, 0.18);
color: rgba(226, 232, 240, 0.85);
}
.station-list-item--start .station-list-item__badge {
background: rgba(56, 189, 248, 0.35);
color: #0ea5e9;
}
.station-list-item--end .station-list-item__badge {
background: rgba(249, 115, 22, 0.35);
color: #f97316;
}
.grid h3 {
margin-bottom: 0.5rem;
font-size: 1.1rem;
@@ -151,6 +183,95 @@ body {
width: 100%;
}
.route-panel {
display: grid;
gap: 0.85rem;
padding: 1.1rem 1.35rem;
border-radius: 12px;
border: 1px solid rgba(250, 204, 21, 0.3);
background: rgba(161, 98, 7, 0.16);
}
.route-panel__header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 1rem;
}
.route-panel__hint {
font-size: 0.9rem;
color: rgba(226, 232, 240, 0.78);
}
.route-panel__meta {
display: grid;
gap: 0.45rem;
}
.route-panel__meta > div {
display: flex;
justify-content: space-between;
align-items: baseline;
gap: 0.75rem;
}
.route-panel__meta dt {
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 0.06em;
color: rgba(226, 232, 240, 0.65);
}
.route-panel__meta dd {
font-size: 0.95rem;
color: rgba(226, 232, 240, 0.92);
}
.route-panel__error {
color: #f87171;
font-weight: 600;
}
.route-panel__path {
display: flex;
gap: 0.6rem;
align-items: baseline;
}
.route-panel__path span {
font-size: 0.85rem;
color: rgba(226, 232, 240, 0.7);
text-transform: uppercase;
letter-spacing: 0.06em;
}
.route-panel__path ol {
display: flex;
flex-wrap: wrap;
gap: 0.4rem;
list-style: none;
padding: 0;
margin: 0;
}
.route-panel__path li::after {
content: '→';
margin-left: 0.35rem;
color: rgba(250, 204, 21, 0.75);
}
.route-panel__path li:last-child::after {
content: '';
margin: 0;
}
.route-panel__actions {
margin-top: 1rem;
display: flex;
gap: 0.75rem;
}
.selected-station {
margin-top: 1rem;
padding: 1rem 1.25rem;
@@ -190,6 +311,45 @@ body {
color: rgba(226, 232, 240, 0.92);
}
.selected-track {
margin-top: 1rem;
padding: 1rem 1.25rem;
border-radius: 12px;
border: 1px solid rgba(59, 130, 246, 0.35);
background: rgba(37, 99, 235, 0.18);
display: grid;
gap: 0.75rem;
}
.selected-track h3 {
color: rgba(226, 232, 240, 0.9);
font-size: 1.1rem;
}
.selected-track dl {
display: grid;
gap: 0.45rem;
}
.selected-track dl > div {
display: flex;
align-items: baseline;
justify-content: space-between;
gap: 0.75rem;
}
.selected-track dt {
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 0.08em;
color: rgba(226, 232, 240, 0.6);
}
.selected-track dd {
font-size: 0.95rem;
color: rgba(226, 232, 240, 0.92);
}
@media (min-width: 768px) {
.snapshot-layout {
gap: 2rem;

View File

@@ -22,6 +22,9 @@ export interface Track extends Identified {
readonly endStationId: string;
readonly lengthMeters: number;
readonly maxSpeedKph: number;
readonly status?: string | null;
readonly isBidirectional?: boolean;
readonly coordinates: readonly [number, number][];
}
export interface Train extends Identified {

View File

@@ -0,0 +1,216 @@
import { describe, expect, it } from 'vitest';
import { buildTrackAdjacency, computeRoute } from './route';
import type { Station, Track } from '../types/domain';
const baseTimestamps = {
createdAt: '2024-01-01T00:00:00Z',
updatedAt: '2024-01-01T00:00:00Z',
};
describe('route utilities', () => {
it('finds a multi-hop path across connected tracks', () => {
const stations: Station[] = [
{
id: 'station-a',
name: 'Alpha',
latitude: 51.5,
longitude: -0.1,
...baseTimestamps,
},
{
id: 'station-b',
name: 'Bravo',
latitude: 51.52,
longitude: -0.11,
...baseTimestamps,
},
{
id: 'station-c',
name: 'Charlie',
latitude: 51.54,
longitude: -0.12,
...baseTimestamps,
},
{
id: 'station-d',
name: 'Delta',
latitude: 51.55,
longitude: -0.15,
...baseTimestamps,
},
];
const tracks: Track[] = [
{
id: 'track-ab',
startStationId: 'station-a',
endStationId: 'station-b',
lengthMeters: 1200,
maxSpeedKph: 120,
coordinates: [
[51.5, -0.1],
[51.51, -0.105],
[51.52, -0.11],
],
...baseTimestamps,
},
{
id: 'track-bc',
startStationId: 'station-b',
endStationId: 'station-c',
lengthMeters: 1500,
maxSpeedKph: 110,
coordinates: [
[51.52, -0.11],
[51.53, -0.115],
[51.54, -0.12],
],
...baseTimestamps,
},
{
id: 'track-cd',
startStationId: 'station-c',
endStationId: 'station-d',
lengthMeters: 900,
maxSpeedKph: 115,
coordinates: [
[51.54, -0.12],
[51.545, -0.13],
[51.55, -0.15],
],
...baseTimestamps,
},
];
const stationById = new Map(stations.map((station) => [station.id, station]));
const adjacency = buildTrackAdjacency(tracks);
const result = computeRoute({
startId: 'station-a',
endId: 'station-d',
stationById,
adjacency,
});
expect(result.error).toBeNull();
expect(result.stations?.map((station) => station.id)).toEqual([
'station-a',
'station-b',
'station-c',
'station-d',
]);
expect(result.tracks.map((track) => track.id)).toEqual([
'track-ab',
'track-bc',
'track-cd',
]);
expect(result.totalLength).toBe(1200 + 1500 + 900);
expect(result.segments).toHaveLength(3);
expect(result.segments[0][0]).toEqual([51.5, -0.1]);
expect(result.segments[2][result.segments[2].length - 1]).toEqual([51.55, -0.15]);
});
it('returns an error when no path exists', () => {
const stations: Station[] = [
{
id: 'station-a',
name: 'Alpha',
latitude: 51.5,
longitude: -0.1,
...baseTimestamps,
},
{
id: 'station-b',
name: 'Bravo',
latitude: 51.6,
longitude: -0.2,
...baseTimestamps,
},
];
const tracks: Track[] = [
{
id: 'track-self',
startStationId: 'station-a',
endStationId: 'station-a',
lengthMeters: 0,
maxSpeedKph: 80,
coordinates: [
[51.5, -0.1],
[51.5005, -0.1005],
],
...baseTimestamps,
},
];
const stationById = new Map(stations.map((station) => [station.id, station]));
const adjacency = buildTrackAdjacency(tracks);
const result = computeRoute({
startId: 'station-a',
endId: 'station-b',
stationById,
adjacency,
});
expect(result.stations).toBeNull();
expect(result.tracks).toHaveLength(0);
expect(result.error).toBe(
'No rail connection found between the selected stations.'
);
expect(result.segments).toHaveLength(0);
});
it('reverses track geometry when traversing in the opposite direction', () => {
const stations: Station[] = [
{
id: 'station-a',
name: 'Alpha',
latitude: 51.5,
longitude: -0.1,
...baseTimestamps,
},
{
id: 'station-b',
name: 'Bravo',
latitude: 51.52,
longitude: -0.11,
...baseTimestamps,
},
];
const tracks: Track[] = [
{
id: 'track-ab',
startStationId: 'station-a',
endStationId: 'station-b',
lengthMeters: 1200,
maxSpeedKph: 120,
coordinates: [
[51.5, -0.1],
[51.52, -0.11],
],
...baseTimestamps,
},
];
const stationById = new Map(stations.map((station) => [station.id, station]));
const adjacency = buildTrackAdjacency(tracks);
const result = computeRoute({
startId: 'station-b',
endId: 'station-a',
stationById,
adjacency,
});
expect(result.error).toBeNull();
expect(result.segments).toEqual([
[
[51.52, -0.11],
[51.5, -0.1],
],
]);
});
});

239
frontend/src/utils/route.ts Normal file
View File

@@ -0,0 +1,239 @@
import type { Station, Track } from '../types/domain';
export type LatLngTuple = readonly [number, number];
export interface NeighborEdge {
readonly neighborId: string;
readonly track: Track;
readonly isForward: boolean;
}
export type TrackAdjacency = Map<string, NeighborEdge[]>;
export interface ComputeRouteParams {
readonly startId?: string | null;
readonly endId?: string | null;
readonly stationById: Map<string, Station>;
readonly adjacency: TrackAdjacency;
}
export interface RouteComputation {
readonly stations: Station[] | null;
readonly tracks: Track[];
readonly totalLength: number | null;
readonly error: string | null;
readonly segments: LatLngTuple[][];
}
export function buildTrackAdjacency(tracks: readonly Track[]): TrackAdjacency {
const adjacency: TrackAdjacency = new Map();
const register = (fromId: string, toId: string, track: Track, isForward: boolean) => {
if (!adjacency.has(fromId)) {
adjacency.set(fromId, []);
}
adjacency.get(fromId)!.push({ neighborId: toId, track, isForward });
};
for (const track of tracks) {
register(track.startStationId, track.endStationId, track, true);
register(track.endStationId, track.startStationId, track, false);
}
return adjacency;
}
export function computeRoute({
startId,
endId,
stationById,
adjacency,
}: ComputeRouteParams): RouteComputation {
if (!startId || !endId) {
return emptyResult();
}
if (!stationById.has(startId) || !stationById.has(endId)) {
return {
stations: null,
tracks: [],
totalLength: null,
error: 'Selected stations are no longer available.',
segments: [],
};
}
if (startId === endId) {
const station = stationById.get(startId);
return {
stations: station ? [station] : null,
tracks: [],
totalLength: 0,
error: null,
segments: [],
};
}
const visited = new Set<string>();
const queue: string[] = [];
const parent = new Map<string, { prev: string | null; edge: NeighborEdge | null }>();
queue.push(startId);
visited.add(startId);
parent.set(startId, { prev: null, edge: null });
while (queue.length > 0) {
const current = queue.shift()!;
if (current === endId) {
break;
}
const neighbors = adjacency.get(current) ?? [];
for (const edge of neighbors) {
const { neighborId } = edge;
if (visited.has(neighborId)) {
continue;
}
visited.add(neighborId);
parent.set(neighborId, { prev: current, edge });
queue.push(neighborId);
}
}
if (!parent.has(endId)) {
return {
stations: null,
tracks: [],
totalLength: null,
error: 'No rail connection found between the selected stations.',
segments: [],
};
}
const stationPath: string[] = [];
const trackSequence: Track[] = [];
const directions: boolean[] = [];
let cursor: string | null = endId;
while (cursor) {
const details = parent.get(cursor);
if (!details) {
break;
}
stationPath.push(cursor);
if (details.edge) {
trackSequence.push(details.edge.track);
directions.push(details.edge.isForward);
}
cursor = details.prev;
}
stationPath.reverse();
trackSequence.reverse();
directions.reverse();
const stations = stationPath
.map((id) => stationById.get(id))
.filter((station): station is Station => Boolean(station));
const segments = buildSegments(trackSequence, directions, stationById);
const totalLength = computeTotalLength(trackSequence, stations);
return {
stations,
tracks: trackSequence,
totalLength,
error: null,
segments,
};
}
function buildSegments(
tracks: Track[],
directions: boolean[],
stationById: Map<string, Station>
): LatLngTuple[][] {
const segments: LatLngTuple[][] = [];
for (let index = 0; index < tracks.length; index += 1) {
const track = tracks[index];
const isForward = directions[index] ?? true;
const coordinates = extractTrackCoordinates(track, stationById);
if (coordinates.length < 2) {
continue;
}
segments.push(isForward ? coordinates : [...coordinates].reverse());
}
return segments;
}
function extractTrackCoordinates(
track: Track,
stationById: Map<string, Station>
): LatLngTuple[] {
if (Array.isArray(track.coordinates) && track.coordinates.length >= 2) {
return track.coordinates.map((pair) => [pair[0], pair[1]] as LatLngTuple);
}
const start = stationById.get(track.startStationId);
const end = stationById.get(track.endStationId);
if (!start || !end) {
return [];
}
return [
[start.latitude, start.longitude],
[end.latitude, end.longitude],
];
}
function computeTotalLength(tracks: Track[], stations: Station[]): number | null {
if (tracks.length === 0 && stations.length <= 1) {
return 0;
}
const hasTrackLengths = tracks.every(
(track) =>
typeof track.lengthMeters === 'number' && Number.isFinite(track.lengthMeters)
);
if (hasTrackLengths) {
return tracks.reduce((total, track) => total + (track.lengthMeters ?? 0), 0);
}
if (stations.length < 2) {
return null;
}
let total = 0;
for (let index = 0; index < stations.length - 1; index += 1) {
total += haversineDistance(stations[index], stations[index + 1]);
}
return total;
}
function haversineDistance(a: Station, b: Station): number {
const R = 6371_000;
const toRad = (value: number) => (value * Math.PI) / 180;
const dLat = toRad(b.latitude - a.latitude);
const dLon = toRad(b.longitude - a.longitude);
const lat1 = toRad(a.latitude);
const lat2 = toRad(b.latitude);
const sinDLat = Math.sin(dLat / 2);
const sinDLon = Math.sin(dLon / 2);
const root = sinDLat * sinDLat + Math.cos(lat1) * Math.cos(lat2) * sinDLon * sinDLon;
const c = 2 * Math.atan2(Math.sqrt(root), Math.sqrt(1 - root));
return R * c;
}
function emptyResult(): RouteComputation {
return {
stations: null,
tracks: [],
totalLength: null,
error: null,
segments: [],
};
}

View File

@@ -15,7 +15,7 @@
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
"types": ["vite/client"]
"types": ["vite/client", "vitest"]
},
"include": ["src"]
}

View File

@@ -0,0 +1,8 @@
import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
include: ['src/**/*.test.ts'],
environment: 'node',
},
});

169
scripts/init_demo_db.py Normal file
View File

@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Initialize the database with demo data for the Rail Game.
This script automates the database setup process:
1. Validates environment setup
2. Runs database migrations
3. Loads OSM fixtures for demo data
Usage:
python scripts/init_demo_db.py [--dry-run] [--region REGION]
Requirements:
- Virtual environment activated
- .env file configured with DATABASE_URL
- PostgreSQL with PostGIS running
"""
import argparse
import os
import subprocess
import sys
from pathlib import Path
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
print("WARNING: python-dotenv not installed. .env file will not be loaded automatically.")
print("Install with: pip install python-dotenv")
def check_virtualenv():
"""Check if we're running in a virtual environment."""
# Skip virtualenv check in Docker containers
if os.getenv('INIT_DEMO_DB') == 'true':
return
if not hasattr(sys, 'real_prefix') and not (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix):
print("ERROR: Virtual environment not activated. Run:")
print(" .venv\\Scripts\\Activate.ps1 (PowerShell)")
print(" source .venv/bin/activate (Bash/macOS/Linux)")
sys.exit(1)
def check_env_file():
"""Check if .env file exists."""
env_file = Path('.env')
if not env_file.exists():
print("ERROR: .env file not found. Copy .env.example to .env and configure:")
print(" Copy-Item .env.example .env (PowerShell)")
print(" cp .env.example .env (Bash)")
sys.exit(1)
def check_database_url():
"""Check if DATABASE_URL is set in environment."""
database_url = os.getenv('DATABASE_URL')
if not database_url:
print("ERROR: DATABASE_URL not set. Check your .env file.")
sys.exit(1)
print(f"Using database: {database_url}")
def run_command(cmd, cwd=None, description="", env=None):
"""Run a shell command and return the result."""
print(f"\n>>> {description}")
print(f"Running: {' '.join(cmd)}")
try:
env_vars = os.environ.copy()
if env:
env_vars.update(env)
env_vars.setdefault("PYTHONPATH", "/app")
result = subprocess.run(
cmd,
cwd=cwd,
check=True,
capture_output=True,
text=True,
env=env_vars,
)
if result.stdout:
print(result.stdout)
return result
except subprocess.CalledProcessError as e:
print(f"ERROR: Command failed with exit code {e.returncode}")
if e.stdout:
print(e.stdout)
if e.stderr:
print(e.stderr)
sys.exit(1)
def run_migrations():
"""Run database migrations using alembic."""
run_command(
['alembic', 'upgrade', 'head'],
cwd='backend',
description="Running database migrations",
)
def load_osm_fixtures(region, dry_run=False):
"""Load OSM fixtures for demo data."""
cmd = ['python', '-m', 'backend.scripts.osm_refresh', '--region', region]
if dry_run:
cmd.append('--no-commit')
description = f"Loading OSM fixtures (dry run) for region: {region}"
else:
description = f"Loading OSM fixtures for region: {region}"
run_command(cmd, description=description)
def main():
parser = argparse.ArgumentParser(
description="Initialize database with demo data")
parser.add_argument(
'--region',
default='all',
help='OSM region to load (default: all)'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Dry run: run migrations and load fixtures without committing'
)
parser.add_argument(
'--skip-migrations',
action='store_true',
help='Skip running migrations'
)
parser.add_argument(
'--skip-fixtures',
action='store_true',
help='Skip loading OSM fixtures'
)
args = parser.parse_args()
print("Rail Game Database Initialization")
print("=" * 40)
# Pre-flight checks
check_virtualenv()
check_env_file()
check_database_url()
# Run migrations
if not args.skip_migrations:
run_migrations()
else:
print("Skipping migrations (--skip-migrations)")
# Load fixtures
if not args.skip_fixtures:
load_osm_fixtures(args.region, args.dry_run)
else:
print("Skipping fixtures (--skip-fixtures)")
print("\n✅ Database initialization completed successfully!")
if args.dry_run:
print("Note: This was a dry run. No data was committed to the database.")
else:
print("Demo data loaded. You can now start the backend server.")
if __name__ == '__main__':
main()