feat: Add combined track functionality with repository and service layers
- Introduced CombinedTrackModel, CombinedTrackCreate, and CombinedTrackRepository for managing combined tracks. - Implemented logic to create combined tracks based on existing tracks between two stations. - Added methods to check for existing combined tracks and retrieve constituent track IDs. - Enhanced TrackModel and TrackRepository to support OSM ID and track updates. - Created migration scripts for adding combined tracks table and OSM ID to tracks. - Updated services and API endpoints to handle combined track operations. - Added tests for combined track creation, repository methods, and API interactions.
This commit is contained in:
80
TODO.md
80
TODO.md
@@ -1,80 +0,0 @@
|
||||
# Development TODO Plan
|
||||
|
||||
## Phase 1 – Project Foundations
|
||||
|
||||
- [x] Initialize Git hooks, linting, and formatting tooling (ESLint, Prettier, isort, black).
|
||||
- [x] Configure `pyproject.toml` or equivalent for backend dependency management.
|
||||
- [x] Scaffold FastAPI application entrypoint with health-check endpoint.
|
||||
- [x] Bootstrap React app with Vite/CRA, including routing skeleton and global state provider.
|
||||
- [x] Define shared TypeScript/Python models for core domain entities (tracks, stations, trains).
|
||||
- [x] Set up CI workflow for linting and test automation (GitHub Actions).
|
||||
|
||||
## Phase 2 – Core Features
|
||||
|
||||
- [x] Implement authentication flow (backend JWT, frontend login/register forms).
|
||||
- [x] Build map visualization integrating Leaflet with OSM tiles.
|
||||
- [x] Define geographic bounding boxes and filtering rules for importing real-world stations from OpenStreetMap.
|
||||
- [x] Implement an import script/CLI that pulls OSM station data and normalizes it to the PostGIS schema.
|
||||
- [x] Expose backend CRUD endpoints for stations (create, update, archive) with validation and geometry handling.
|
||||
- [x] Build React map tooling for selecting a station.
|
||||
- [x] Enhance map UI to support selecting two stations and previewing the rail corridor between them.
|
||||
- [x] Define track selection criteria and tagging rules for harvesting OSM rail segments within target regions.
|
||||
- [x] Extend the importer to load track geometries and associate them with existing stations.
|
||||
- [ ] Implement backend track-management APIs with length/speed validation and topology checks.
|
||||
- [ ] Implement track path mapping along existing OSM rail segments between chosen stations.
|
||||
- [ ] Design train connection manager requirements (link trains to operating tracks, manage consist data).
|
||||
- [ ] Implement backend services and APIs to attach trains to routes and update assignments.
|
||||
- [ ] Add UI flows for managing train connections, including visual feedback on the map.
|
||||
- [ ] Establish train scheduling service with validation rules, conflict detection, and persistence APIs.
|
||||
- [ ] Provide frontend scheduling tools (timeline or table view) for creating and editing train timetables.
|
||||
- [ ] Develop frontend dashboards for resources, schedules, and achievements.
|
||||
- [ ] Add real-time simulation updates (WebSocket layer, frontend subscription hooks).
|
||||
|
||||
## Phase 3 – Data & Persistence
|
||||
|
||||
- [x] Design PostgreSQL/PostGIS schema and migrations (Alembic or similar).
|
||||
- [x] Implement data access layer with SQLAlchemy and repository abstractions.
|
||||
- [ ] Decide on canonical fixture scope (demo geography, sample trains) and document expected dataset size.
|
||||
- [ ] Author fixture generation scripts that export JSON/GeoJSON compatible with the repository layer.
|
||||
- [x] Create ingestion utilities to load fixtures into local and CI databases.
|
||||
- [ ] Provision a Redis instance/container for local development.
|
||||
- [ ] Add caching abstractions in backend services (e.g., network snapshot, map layers).
|
||||
- [ ] Implement cache invalidation hooks tied to repository mutations.
|
||||
|
||||
## Phase 4 – Testing & Quality
|
||||
|
||||
- [x] Write unit tests for backend services and models.
|
||||
- [ ] Configure Jest/RTL testing utilities and shared mocks for Leaflet and network APIs.
|
||||
- [ ] Write component tests for map controls, station builder UI, and dashboards.
|
||||
- [ ] Add integration tests for custom hooks (network snapshot, scheduling forms).
|
||||
- [x] Stand up Playwright/Cypress project structure with authentication helpers.
|
||||
- [x] Script login end-to-end flow (Playwright).
|
||||
- [ ] Script station creation end-to-end flow.
|
||||
- [ ] Script track placement end-to-end flow.
|
||||
- [ ] Script scheduling end-to-end flow.
|
||||
- [ ] Define load/performance targets (requests per second, simulation latency) and tooling.
|
||||
- [ ] Implement performance test harness covering scheduling and real-time updates.
|
||||
|
||||
## Phase 5 – Deployment & Ops
|
||||
|
||||
- [x] Create Dockerfile for frontend.
|
||||
- [x] Create Dockerfile for backend.
|
||||
- [x] Create docker-compose for local development with Postgres/Redis dependencies.
|
||||
- [ ] Add task runner commands to orchestrate container workflows.
|
||||
- [ ] Set up CI/CD pipeline for automated builds, tests, and container publishing.
|
||||
- [ ] Provision infrastructure scripts (Terraform/Ansible) targeting initial cloud environment.
|
||||
- [ ] Define environment configuration strategy (secrets management, config maps).
|
||||
- [ ] Configure observability stack (logging, metrics, tracing).
|
||||
- [ ] Integrate tracing/logging exporters into backend services.
|
||||
- [ ] Document deployment pipeline and release process.
|
||||
|
||||
## Phase 6 – Polish & Expansion
|
||||
|
||||
- [ ] Add leaderboards and achievements logic with UI integration.
|
||||
- [ ] Design data model changes required for achievements and ranking.
|
||||
- [ ] Implement accessibility audit fixes (WCAG compliance).
|
||||
- [ ] Conduct accessibility audit (contrast, keyboard navigation, screen reader paths).
|
||||
- [ ] Optimize asset loading and introduce lazy loading strategies.
|
||||
- [ ] Establish performance budgets for bundle size and render times.
|
||||
- [ ] Evaluate multiplayer/coop roadmap and spike POCs where feasible.
|
||||
- [ ] Prototype networking approach (WebRTC/WebSocket) for cooperative sessions.
|
||||
@@ -25,7 +25,7 @@ EXPOSE 8000
|
||||
|
||||
# Initialize database with demo data if INIT_DEMO_DB is set
|
||||
CMD ["sh", "-c", "\
|
||||
export PYTHONPATH=/app && \
|
||||
export PYTHONPATH=/app/backend && \
|
||||
echo 'Waiting for database...' && \
|
||||
while ! pg_isready -h db -p 5432 -U railgame >/dev/null 2>&1; do sleep 1; done && \
|
||||
echo 'Database is ready!' && \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[alembic]
|
||||
script_location = migrations
|
||||
sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame
|
||||
sqlalchemy.url = postgresql+psycopg://railgame:railgame@localhost:5432/railgame_dev
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -4,9 +4,11 @@ from backend.app.api.auth import router as auth_router
|
||||
from backend.app.api.health import router as health_router
|
||||
from backend.app.api.network import router as network_router
|
||||
from backend.app.api.stations import router as stations_router
|
||||
from backend.app.api.tracks import router as tracks_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
router.include_router(auth_router)
|
||||
router.include_router(network_router)
|
||||
router.include_router(stations_router)
|
||||
router.include_router(tracks_router)
|
||||
|
||||
153
backend/app/api/tracks.py
Normal file
153
backend/app/api/tracks.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.api.deps import get_current_user, get_db
|
||||
from backend.app.models import (
|
||||
CombinedTrackModel,
|
||||
TrackCreate,
|
||||
TrackUpdate,
|
||||
TrackModel,
|
||||
UserPublic,
|
||||
)
|
||||
from backend.app.services.combined_tracks import (
|
||||
create_combined_track,
|
||||
get_combined_track,
|
||||
list_combined_tracks,
|
||||
)
|
||||
from backend.app.services.tracks import (
|
||||
create_track,
|
||||
delete_track,
|
||||
regenerate_combined_tracks,
|
||||
update_track,
|
||||
get_track,
|
||||
list_tracks,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tracks", tags=["tracks"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TrackModel])
|
||||
def read_combined_tracks(
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[TrackModel]:
|
||||
"""Return all base tracks."""
|
||||
return list_tracks(db)
|
||||
|
||||
|
||||
@router.get("/combined", response_model=list[CombinedTrackModel])
|
||||
def read_combined_tracks_combined(
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[CombinedTrackModel]:
|
||||
return list_combined_tracks(db)
|
||||
|
||||
|
||||
@router.get("/{track_id}", response_model=TrackModel)
|
||||
def read_track(
|
||||
track_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
track = get_track(db, track_id)
|
||||
if track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
return track
|
||||
|
||||
|
||||
@router.get("/combined/{combined_track_id}", response_model=CombinedTrackModel)
|
||||
def read_combined_track(
|
||||
combined_track_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> CombinedTrackModel:
|
||||
combined_track = get_combined_track(db, combined_track_id)
|
||||
if combined_track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Combined track {combined_track_id} not found",
|
||||
)
|
||||
return combined_track
|
||||
|
||||
|
||||
@router.post("", response_model=TrackModel, status_code=status.HTTP_201_CREATED)
|
||||
def create_track_endpoint(
|
||||
payload: TrackCreate,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
try:
|
||||
track = create_track(db, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(
|
||||
db, [track.start_station_id, track.end_station_id])
|
||||
|
||||
return track
|
||||
|
||||
|
||||
@router.post(
|
||||
"/combined",
|
||||
response_model=CombinedTrackModel,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a combined track between two stations using pathfinding",
|
||||
)
|
||||
def create_combined_track_endpoint(
|
||||
start_station_id: str,
|
||||
end_station_id: str,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> CombinedTrackModel:
|
||||
combined_track = create_combined_track(
|
||||
db, start_station_id, end_station_id)
|
||||
if combined_track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Could not create combined track: no path exists between stations or track already exists",
|
||||
)
|
||||
return combined_track
|
||||
|
||||
|
||||
@router.put("/{track_id}", response_model=TrackModel)
|
||||
def update_track_endpoint(
|
||||
track_id: str,
|
||||
payload: TrackUpdate,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TrackModel:
|
||||
track = update_track(db, track_id, payload)
|
||||
if track is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(
|
||||
db, [track.start_station_id, track.end_station_id])
|
||||
return track
|
||||
|
||||
|
||||
@router.delete("/{track_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_track_endpoint(
|
||||
track_id: str,
|
||||
regenerate: bool = False,
|
||||
_: UserPublic = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> None:
|
||||
deleted = delete_track(db, track_id, regenerate=regenerate)
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Track {track_id} not found",
|
||||
)
|
||||
@@ -41,11 +41,14 @@ class User(Base, TimestampMixin):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False)
|
||||
email: Mapped[str | None] = mapped_column(
|
||||
String(255), unique=True, nullable=True)
|
||||
full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), nullable=False, default="player")
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="player")
|
||||
preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
|
||||
@@ -62,12 +65,50 @@ class Station(Base, TimestampMixin):
|
||||
Geometry(geometry_type="POINT", srid=4326), nullable=False
|
||||
)
|
||||
elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class Track(Base, TimestampMixin):
|
||||
__tablename__ = "tracks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
start_station_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
end_station_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
length_meters: Mapped[float | None] = mapped_column(
|
||||
Numeric(10, 2), nullable=True)
|
||||
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
is_bidirectional: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="planned")
|
||||
track_geometry: Mapped[str] = mapped_column(
|
||||
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CombinedTrack(Base, TimestampMixin):
|
||||
__tablename__ = "combined_tracks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
@@ -82,19 +123,25 @@ class Track(Base, TimestampMixin):
|
||||
ForeignKey("stations.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
|
||||
length_meters: Mapped[float | None] = mapped_column(
|
||||
Numeric(10, 2), nullable=True)
|
||||
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
is_bidirectional: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
|
||||
track_geometry: Mapped[str] = mapped_column(
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="planned")
|
||||
combined_geometry: Mapped[str] = mapped_column(
|
||||
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
|
||||
)
|
||||
# JSON array of constituent track IDs
|
||||
constituent_track_ids: Mapped[str] = mapped_column(
|
||||
Text, nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
|
||||
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,7 +152,8 @@ class Train(Base, TimestampMixin):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
designation: Mapped[str] = mapped_column(
|
||||
String(64), nullable=False, unique=True)
|
||||
operator_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
|
||||
)
|
||||
|
||||
@@ -8,11 +8,14 @@ from .auth import (
|
||||
UserPublic,
|
||||
)
|
||||
from .base import (
|
||||
CombinedTrackCreate,
|
||||
CombinedTrackModel,
|
||||
StationCreate,
|
||||
StationModel,
|
||||
StationUpdate,
|
||||
TrackCreate,
|
||||
TrackModel,
|
||||
TrackUpdate,
|
||||
TrainCreate,
|
||||
TrainModel,
|
||||
TrainScheduleCreate,
|
||||
@@ -33,9 +36,12 @@ __all__ = [
|
||||
"StationUpdate",
|
||||
"TrackCreate",
|
||||
"TrackModel",
|
||||
"TrackUpdate",
|
||||
"TrainScheduleCreate",
|
||||
"TrainCreate",
|
||||
"TrainModel",
|
||||
"UserCreate",
|
||||
"to_camel",
|
||||
"CombinedTrackCreate",
|
||||
"CombinedTrackModel",
|
||||
]
|
||||
|
||||
@@ -51,13 +51,24 @@ class StationModel(IdentifiedModel[str]):
|
||||
class TrackModel(IdentifiedModel[str]):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
length_meters: float
|
||||
max_speed_kph: float
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: float | None = None
|
||||
status: str | None = None
|
||||
is_bidirectional: bool = True
|
||||
coordinates: list[tuple[float, float]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CombinedTrackModel(IdentifiedModel[str]):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
status: str | None = None
|
||||
is_bidirectional: bool = True
|
||||
coordinates: list[tuple[float, float]] = Field(default_factory=list)
|
||||
constituent_track_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TrainModel(IdentifiedModel[str]):
|
||||
designation: str
|
||||
capacity: int
|
||||
@@ -89,6 +100,31 @@ class TrackCreate(CamelModel):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
coordinates: Sequence[tuple[float, float]]
|
||||
osm_id: str | None = None
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
is_bidirectional: bool = True
|
||||
status: str = "planned"
|
||||
|
||||
|
||||
class TrackUpdate(CamelModel):
|
||||
start_station_id: str | None = None
|
||||
end_station_id: str | None = None
|
||||
coordinates: Sequence[tuple[float, float]] | None = None
|
||||
osm_id: str | None = None
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
is_bidirectional: bool | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class CombinedTrackCreate(CamelModel):
|
||||
start_station_id: str
|
||||
end_station_id: str
|
||||
coordinates: Sequence[tuple[float, float]]
|
||||
constituent_track_ids: list[str]
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: int | None = None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from backend.app.repositories.stations import StationRepository
|
||||
from backend.app.repositories.tracks import TrackRepository
|
||||
from backend.app.repositories.combined_tracks import CombinedTrackRepository
|
||||
from backend.app.repositories.train_schedules import TrainScheduleRepository
|
||||
from backend.app.repositories.trains import TrainRepository
|
||||
from backend.app.repositories.users import UserRepository
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"StationRepository",
|
||||
"TrainScheduleRepository",
|
||||
"TrackRepository",
|
||||
"CombinedTrackRepository",
|
||||
"TrainRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
|
||||
73
backend/app/repositories/combined_tracks.py
Normal file
73
backend/app/repositories/combined_tracks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import CombinedTrack
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CombinedTrackRepository(BaseRepository[CombinedTrack]):
|
||||
model = CombinedTrack
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
super().__init__(session)
|
||||
|
||||
def list_all(self) -> list[CombinedTrack]:
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_between_stations(self, start_station_id: str, end_station_id: str) -> bool:
|
||||
"""Check if a combined track already exists between two stations."""
|
||||
statement = sa.select(sa.exists().where(
|
||||
sa.and_(
|
||||
self.model.start_station_id == start_station_id,
|
||||
self.model.end_station_id == end_station_id
|
||||
)
|
||||
))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def get_constituent_track_ids(self, combined_track: CombinedTrack) -> list[str]:
|
||||
"""Extract constituent track IDs from a combined track."""
|
||||
try:
|
||||
return json.loads(combined_track.constituent_track_ids)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError(
|
||||
"Combined track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
def create(self, data: CombinedTrackCreate) -> CombinedTrack:
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
constituent_track_ids_json = json.dumps(data.constituent_track_ids)
|
||||
|
||||
combined_track = CombinedTrack(
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
length_meters=data.length_meters,
|
||||
max_speed_kph=data.max_speed_kph,
|
||||
is_bidirectional=data.is_bidirectional,
|
||||
status=data.status,
|
||||
combined_geometry=geometry,
|
||||
constituent_track_ids=constituent_track_ids_json,
|
||||
)
|
||||
self.session.add(combined_track)
|
||||
return combined_track
|
||||
@@ -7,7 +7,7 @@ from geoalchemy2.elements import WKTElement
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db.models import Track
|
||||
from backend.app.models import TrackCreate
|
||||
from backend.app.models import TrackCreate, TrackUpdate
|
||||
from backend.app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
@@ -21,6 +21,102 @@ class TrackRepository(BaseRepository[Track]):
|
||||
statement = sa.select(self.model)
|
||||
return list(self.session.scalars(statement))
|
||||
|
||||
def exists_by_osm_id(self, osm_id: str) -> bool:
|
||||
statement = sa.select(sa.exists().where(self.model.osm_id == osm_id))
|
||||
return bool(self.session.scalar(statement))
|
||||
|
||||
def find_path_between_stations(self, start_station_id: str, end_station_id: str) -> list[Track] | None:
|
||||
"""Find the shortest path between two stations using existing tracks.
|
||||
|
||||
Returns a list of tracks that form the path, or None if no path exists.
|
||||
"""
|
||||
# Build adjacency list: station -> list of (neighbor_station, track)
|
||||
adjacency = self._build_track_graph()
|
||||
|
||||
if start_station_id not in adjacency or end_station_id not in adjacency:
|
||||
return None
|
||||
|
||||
# BFS to find shortest path
|
||||
from collections import deque
|
||||
|
||||
# (current_station, path_so_far)
|
||||
queue = deque([(start_station_id, [])])
|
||||
visited = set([start_station_id])
|
||||
|
||||
while queue:
|
||||
current_station, path = queue.popleft()
|
||||
|
||||
if current_station == end_station_id:
|
||||
return path
|
||||
|
||||
for neighbor, track in adjacency[current_station]:
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [track]))
|
||||
|
||||
return None # No path found
|
||||
|
||||
def _build_track_graph(self) -> dict[str, list[tuple[str, Track]]]:
|
||||
"""Build a graph representation of tracks: station -> [(neighbor_station, track), ...]"""
|
||||
tracks = self.list_all()
|
||||
graph = {}
|
||||
|
||||
for track in tracks:
|
||||
start_id = str(track.start_station_id)
|
||||
end_id = str(track.end_station_id)
|
||||
|
||||
# Add bidirectional edges (assuming tracks are bidirectional)
|
||||
if start_id not in graph:
|
||||
graph[start_id] = []
|
||||
if end_id not in graph:
|
||||
graph[end_id] = []
|
||||
|
||||
graph[start_id].append((end_id, track))
|
||||
graph[end_id].append((start_id, track))
|
||||
|
||||
return graph
|
||||
|
||||
def combine_track_geometries(self, tracks: list[Track]) -> list[tuple[float, float]]:
|
||||
"""Combine the geometries of multiple tracks into a single coordinate sequence.
|
||||
|
||||
Assumes tracks are in order and form a continuous path.
|
||||
"""
|
||||
if not tracks:
|
||||
return []
|
||||
|
||||
combined_coords = []
|
||||
|
||||
for i, track in enumerate(tracks):
|
||||
# Extract coordinates from track geometry
|
||||
coords = self._extract_coordinates_from_track(track)
|
||||
|
||||
if i == 0:
|
||||
# First track: add all coordinates
|
||||
combined_coords.extend(coords)
|
||||
else:
|
||||
# Subsequent tracks: skip the first coordinate (shared with previous track)
|
||||
combined_coords.extend(coords[1:])
|
||||
|
||||
return combined_coords
|
||||
|
||||
def _extract_coordinates_from_track(self, track: Track) -> list[tuple[float, float]]:
|
||||
"""Extract coordinate list from a track's geometry."""
|
||||
# Convert WKT string to WKTElement, then to shapely geometry
|
||||
from geoalchemy2.elements import WKTElement
|
||||
from geoalchemy2.shape import to_shape
|
||||
|
||||
try:
|
||||
wkt_element = WKTElement(track.track_geometry)
|
||||
geom = to_shape(wkt_element)
|
||||
if hasattr(geom, 'coords'):
|
||||
# For LineString, coords returns [(x, y), ...] where x=lon, y=lat
|
||||
# Convert to (lat, lon)
|
||||
return [(coord[1], coord[0]) for coord in geom.coords]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ensure_uuid(value: UUID | str) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
@@ -30,7 +126,8 @@ class TrackRepository(BaseRepository[Track]):
|
||||
@staticmethod
|
||||
def _line_string(coordinates: list[tuple[float, float]]) -> WKTElement:
|
||||
if len(coordinates) < 2:
|
||||
raise ValueError("Track geometry requires at least two coordinate pairs")
|
||||
raise ValueError(
|
||||
"Track geometry requires at least two coordinate pairs")
|
||||
parts = [f"{lon} {lat}" for lat, lon in coordinates]
|
||||
return WKTElement(f"LINESTRING({', '.join(parts)})", srid=4326)
|
||||
|
||||
@@ -38,6 +135,7 @@ class TrackRepository(BaseRepository[Track]):
|
||||
coordinates = list(data.coordinates)
|
||||
geometry = self._line_string(coordinates)
|
||||
track = Track(
|
||||
osm_id=data.osm_id,
|
||||
name=data.name,
|
||||
start_station_id=self._ensure_uuid(data.start_station_id),
|
||||
end_station_id=self._ensure_uuid(data.end_station_id),
|
||||
@@ -49,3 +147,26 @@ class TrackRepository(BaseRepository[Track]):
|
||||
)
|
||||
self.session.add(track)
|
||||
return track
|
||||
|
||||
def update(self, track: Track, payload: TrackUpdate) -> Track:
|
||||
if payload.start_station_id is not None:
|
||||
track.start_station_id = self._ensure_uuid(
|
||||
payload.start_station_id)
|
||||
if payload.end_station_id is not None:
|
||||
track.end_station_id = self._ensure_uuid(payload.end_station_id)
|
||||
if payload.coordinates is not None:
|
||||
track.track_geometry = self._line_string(
|
||||
list(payload.coordinates)) # type: ignore[assignment]
|
||||
if payload.osm_id is not None:
|
||||
track.osm_id = payload.osm_id
|
||||
if payload.name is not None:
|
||||
track.name = payload.name
|
||||
if payload.length_meters is not None:
|
||||
track.length_meters = payload.length_meters
|
||||
if payload.max_speed_kph is not None:
|
||||
track.max_speed_kph = payload.max_speed_kph
|
||||
if payload.is_bidirectional is not None:
|
||||
track.is_bidirectional = payload.is_bidirectional
|
||||
if payload.status is not None:
|
||||
track.status = payload.status
|
||||
return track
|
||||
|
||||
79
backend/app/services/combined_tracks.py
Normal file
79
backend/app/services/combined_tracks.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Application services for combined track operations."""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackCreate, CombinedTrackModel
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def create_combined_track(
|
||||
session: Session, start_station_id: str, end_station_id: str
|
||||
) -> CombinedTrackModel | None:
|
||||
"""Create a combined track between two stations using pathfinding.
|
||||
|
||||
Returns the created combined track, or None if no path exists or
|
||||
a combined track already exists between these stations.
|
||||
"""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
track_repo = TrackRepository(session)
|
||||
|
||||
# Check if combined track already exists
|
||||
if combined_track_repo.exists_between_stations(start_station_id, end_station_id):
|
||||
return None
|
||||
|
||||
# Find path between stations
|
||||
path_tracks = track_repo.find_path_between_stations(
|
||||
start_station_id, end_station_id)
|
||||
if not path_tracks:
|
||||
return None
|
||||
|
||||
# Combine geometries
|
||||
combined_coords = track_repo.combine_track_geometries(path_tracks)
|
||||
if len(combined_coords) < 2:
|
||||
return None
|
||||
|
||||
# Calculate total length
|
||||
total_length = sum(track.length_meters or 0 for track in path_tracks)
|
||||
|
||||
# Get max speed (use the minimum speed of all tracks)
|
||||
max_speeds = [
|
||||
track.max_speed_kph for track in path_tracks if track.max_speed_kph]
|
||||
max_speed = min(max_speeds) if max_speeds else None
|
||||
|
||||
# Get constituent track IDs
|
||||
constituent_track_ids = [str(track.id) for track in path_tracks]
|
||||
|
||||
# Create combined track
|
||||
create_data = CombinedTrackCreate(
|
||||
start_station_id=start_station_id,
|
||||
end_station_id=end_station_id,
|
||||
coordinates=combined_coords,
|
||||
constituent_track_ids=constituent_track_ids,
|
||||
length_meters=total_length if total_length > 0 else None,
|
||||
max_speed_kph=max_speed,
|
||||
status="operational",
|
||||
)
|
||||
|
||||
combined_track = combined_track_repo.create(create_data)
|
||||
session.commit()
|
||||
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
|
||||
|
||||
def get_combined_track(session: Session, combined_track_id: str) -> CombinedTrackModel | None:
|
||||
"""Get a combined track by ID."""
|
||||
try:
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_track = combined_track_repo.get(combined_track_id)
|
||||
return CombinedTrackModel.model_validate(combined_track)
|
||||
except LookupError:
|
||||
return None
|
||||
|
||||
|
||||
def list_combined_tracks(session: Session) -> list[CombinedTrackModel]:
|
||||
"""List all combined tracks."""
|
||||
combined_track_repo = CombinedTrackRepository(session)
|
||||
combined_tracks = combined_track_repo.list_all()
|
||||
return [CombinedTrackModel.model_validate(ct) for ct in combined_tracks]
|
||||
106
backend/app/services/tracks.py
Normal file
106
backend/app/services/tracks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Service layer for primary track management operations."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.models import CombinedTrackModel, TrackCreate, TrackModel, TrackUpdate
|
||||
from backend.app.repositories import CombinedTrackRepository, TrackRepository
|
||||
|
||||
|
||||
def list_tracks(session: Session) -> list[TrackModel]:
|
||||
repo = TrackRepository(session)
|
||||
tracks = repo.list_all()
|
||||
return [TrackModel.model_validate(track) for track in tracks]
|
||||
|
||||
|
||||
def get_track(session: Session, track_id: str) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def create_track(session: Session, payload: TrackCreate) -> TrackModel:
|
||||
repo = TrackRepository(session)
|
||||
try:
|
||||
track = repo.create(payload)
|
||||
session.commit()
|
||||
except IntegrityError as exc:
|
||||
session.rollback()
|
||||
raise ValueError(
|
||||
"Track with the same station pair already exists") from exc
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def update_track(session: Session, track_id: str, payload: TrackUpdate) -> TrackModel | None:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return None
|
||||
|
||||
repo.update(track, payload)
|
||||
session.commit()
|
||||
|
||||
return TrackModel.model_validate(track)
|
||||
|
||||
|
||||
def delete_track(session: Session, track_id: str, regenerate: bool = False) -> bool:
|
||||
repo = TrackRepository(session)
|
||||
track = repo.get(track_id)
|
||||
if track is None:
|
||||
return False
|
||||
|
||||
start_station_id = str(track.start_station_id)
|
||||
end_station_id = str(track.end_station_id)
|
||||
|
||||
session.delete(track)
|
||||
session.commit()
|
||||
|
||||
if regenerate:
|
||||
regenerate_combined_tracks(session, [start_station_id, end_station_id])
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def regenerate_combined_tracks(session: Session, station_ids: Iterable[str]) -> list[CombinedTrackModel]:
|
||||
combined_repo = CombinedTrackRepository(session)
|
||||
|
||||
station_id_set = set(station_ids)
|
||||
if not station_id_set:
|
||||
return []
|
||||
|
||||
# Remove combined tracks touching these stations
|
||||
for combined in combined_repo.list_all():
|
||||
if {str(combined.start_station_id), str(combined.end_station_id)} & station_id_set:
|
||||
session.delete(combined)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Rebuild combined tracks between affected station pairs
|
||||
from backend.app.services.combined_tracks import create_combined_track
|
||||
|
||||
regenerated: list[CombinedTrackModel] = []
|
||||
station_list = list(station_id_set)
|
||||
for i in range(len(station_list)):
|
||||
for j in range(i + 1, len(station_list)):
|
||||
result = create_combined_track(
|
||||
session, station_list[i], station_list[j])
|
||||
if result is not None:
|
||||
regenerated.append(result)
|
||||
return regenerated
|
||||
|
||||
|
||||
__all__ = [
|
||||
"list_tracks",
|
||||
"get_track",
|
||||
"create_track",
|
||||
"update_track",
|
||||
"delete_track",
|
||||
"regenerate_combined_tracks",
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Template for new Alembic migration scripts."""
|
||||
|
||||
from __future__ import annotations
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '63d02d67b39e'
|
||||
down_revision = '20251011_01'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('tracks', sa.Column(
|
||||
'osm_id', sa.String(length=32), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('tracks', 'osm_id')
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Template for new Alembic migration scripts."""
|
||||
|
||||
from __future__ import annotations
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from geoalchemy2.types import Geometry
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = 'e7d4bb03da04'
|
||||
down_revision = '63d02d67b39e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"combined_tracks",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column("name", sa.String(length=128), nullable=True),
|
||||
sa.Column("start_station_id", postgresql.UUID(
|
||||
as_uuid=True), nullable=False),
|
||||
sa.Column("end_station_id", postgresql.UUID(
|
||||
as_uuid=True), nullable=False),
|
||||
sa.Column("length_meters", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("max_speed_kph", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"is_bidirectional",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("true"),
|
||||
),
|
||||
sa.Column(
|
||||
"status", sa.String(length=32), nullable=False, server_default="planned"
|
||||
),
|
||||
sa.Column(
|
||||
"combined_geometry",
|
||||
Geometry(geometry_type="LINESTRING", srid=4326),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("constituent_track_ids", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("timezone('utc', now())"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("timezone('utc', now())"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["start_station_id"], ["stations.id"], ondelete="RESTRICT"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["end_station_id"], ["stations.id"], ondelete="RESTRICT"
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"start_station_id", "end_station_id", name="uq_combined_tracks_station_pair"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_combined_tracks_geometry", "combined_tracks", ["combined_geometry"], postgresql_using="gist"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_combined_tracks_geometry", table_name="combined_tracks")
|
||||
op.drop_table("combined_tracks")
|
||||
@@ -22,6 +22,7 @@ from backend.app.repositories import StationRepository, TrackRepository
|
||||
@dataclass(slots=True)
|
||||
class ParsedTrack:
|
||||
coordinates: list[tuple[float, float]]
|
||||
osm_id: str | None = None
|
||||
name: str | None = None
|
||||
length_meters: float | None = None
|
||||
max_speed_kph: float | None = None
|
||||
@@ -97,7 +98,8 @@ def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTra
|
||||
processed_coordinates: list[tuple[float, float]] = []
|
||||
for pair in coordinates:
|
||||
if not isinstance(pair, Sequence) or len(pair) != 2:
|
||||
raise ValueError(f"Invalid coordinate pair {pair!r} in track entry")
|
||||
raise ValueError(
|
||||
f"Invalid coordinate pair {pair!r} in track entry")
|
||||
lat, lon = pair
|
||||
processed_coordinates.append((float(lat), float(lon)))
|
||||
|
||||
@@ -106,10 +108,12 @@ def _parse_track_entries(entries: Iterable[Mapping[str, Any]]) -> list[ParsedTra
|
||||
max_speed = _safe_float(entry.get("maxSpeedKph"))
|
||||
status = entry.get("status", "operational")
|
||||
is_bidirectional = entry.get("isBidirectional", True)
|
||||
osm_id = entry.get("osmId")
|
||||
|
||||
parsed.append(
|
||||
ParsedTrack(
|
||||
coordinates=processed_coordinates,
|
||||
osm_id=str(osm_id) if osm_id else None,
|
||||
name=str(name) if name else None,
|
||||
length_meters=length,
|
||||
max_speed_kph=max_speed,
|
||||
@@ -133,6 +137,12 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
|
||||
}
|
||||
|
||||
for track_data in tracks:
|
||||
# Skip if track with this OSM ID already exists
|
||||
if track_data.osm_id and track_repo.exists_by_osm_id(track_data.osm_id):
|
||||
print(
|
||||
f"Skipping track {track_data.osm_id} - already exists by OSM ID")
|
||||
continue
|
||||
|
||||
start_station = _nearest_station(
|
||||
track_data.coordinates[0],
|
||||
station_index,
|
||||
@@ -145,13 +155,19 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
|
||||
)
|
||||
|
||||
if not start_station or not end_station:
|
||||
print(
|
||||
f"Skipping track {track_data.osm_id} - no start/end stations found")
|
||||
continue
|
||||
|
||||
if start_station.id == end_station.id:
|
||||
print(
|
||||
f"Skipping track {track_data.osm_id} - start and end stations are the same")
|
||||
continue
|
||||
|
||||
pair = (start_station.id, end_station.id)
|
||||
if pair in existing_pairs:
|
||||
print(
|
||||
f"Skipping track {track_data.osm_id} - station pair {pair} already exists")
|
||||
continue
|
||||
|
||||
length = track_data.length_meters or _polyline_length(
|
||||
@@ -163,6 +179,7 @@ def load_tracks(tracks: Iterable[ParsedTrack], commit: bool = True) -> int:
|
||||
else None
|
||||
)
|
||||
create_schema = TrackCreate(
|
||||
osm_id=track_data.osm_id,
|
||||
name=track_data.name,
|
||||
start_station_id=start_station.id,
|
||||
end_station_id=end_station.id,
|
||||
@@ -193,7 +210,8 @@ def _nearest_station(
|
||||
best_station: StationRef | None = None
|
||||
best_distance = math.inf
|
||||
for station in stations:
|
||||
distance = _haversine(coordinate, (station.latitude, station.longitude))
|
||||
distance = _haversine(
|
||||
coordinate, (station.latitude, station.longitude))
|
||||
if distance < best_distance:
|
||||
best_station = station
|
||||
best_distance = distance
|
||||
|
||||
166
backend/tests/test_combined_tracks.py
Normal file
166
backend/tests/test_combined_tracks.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.app.models import CombinedTrackModel
|
||||
from backend.app.repositories.combined_tracks import CombinedTrackRepository
|
||||
from backend.app.repositories.tracks import TrackRepository
|
||||
from backend.app.services.combined_tracks import create_combined_track
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySession:
|
||||
added: List[Any] = field(default_factory=list)
|
||||
scalars_result: List[Any] = field(default_factory=list)
|
||||
scalar_result: Any = None
|
||||
statements: List[Any] = field(default_factory=list)
|
||||
committed: bool = False
|
||||
rolled_back: bool = False
|
||||
closed: bool = False
|
||||
|
||||
def add(self, instance: Any) -> None:
|
||||
self.added.append(instance)
|
||||
|
||||
def add_all(self, instances: list[Any]) -> None:
|
||||
self.added.extend(instances)
|
||||
|
||||
def scalars(self, statement: Any) -> list[Any]:
|
||||
self.statements.append(statement)
|
||||
return list(self.scalars_result)
|
||||
|
||||
def scalar(self, statement: Any) -> Any:
|
||||
self.statements.append(statement)
|
||||
return self.scalar_result
|
||||
|
||||
def flush(
|
||||
self, _objects: list[Any] | None = None
|
||||
) -> None: # pragma: no cover - optional
|
||||
return None
|
||||
|
||||
def commit(self) -> None: # pragma: no cover - optional
|
||||
self.committed = True
|
||||
|
||||
def rollback(self) -> None: # pragma: no cover - optional
|
||||
self.rolled_back = True
|
||||
|
||||
def close(self) -> None: # pragma: no cover - optional
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def test_combined_track_model_round_trip() -> None:
|
||||
timestamp = _now()
|
||||
combined_track = CombinedTrackModel(
|
||||
id="combined-track-1",
|
||||
start_station_id="station-1",
|
||||
end_station_id="station-2",
|
||||
length_meters=3000.0,
|
||||
max_speed_kph=100,
|
||||
status="operational",
|
||||
is_bidirectional=True,
|
||||
coordinates=[(52.52, 13.405), (52.6, 13.5), (52.7, 13.6)],
|
||||
constituent_track_ids=["track-1", "track-2"],
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
assert combined_track.length_meters == 3000.0
|
||||
assert combined_track.start_station_id != combined_track.end_station_id
|
||||
assert len(combined_track.coordinates) == 3
|
||||
assert len(combined_track.constituent_track_ids) == 2
|
||||
|
||||
|
||||
def test_combined_track_repository_create() -> None:
|
||||
"""Test creating a combined track through the repository."""
|
||||
session = DummySession()
|
||||
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
|
||||
|
||||
# Create test data
|
||||
from backend.app.models import CombinedTrackCreate
|
||||
|
||||
create_data = CombinedTrackCreate(
|
||||
start_station_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
end_station_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
coordinates=[(52.52, 13.405), (52.6, 13.5)],
|
||||
constituent_track_ids=["track-1"],
|
||||
length_meters=1500.0,
|
||||
max_speed_kph=120,
|
||||
status="operational",
|
||||
)
|
||||
|
||||
combined_track = repo.create(create_data)
|
||||
|
||||
assert combined_track.start_station_id is not None
|
||||
assert combined_track.end_station_id is not None
|
||||
assert combined_track.length_meters == 1500.0
|
||||
assert combined_track.max_speed_kph == 120
|
||||
assert combined_track.status == "operational"
|
||||
assert session.added and session.added[0] is combined_track
|
||||
|
||||
|
||||
def test_combined_track_repository_exists_between_stations() -> None:
|
||||
"""Test checking if combined track exists between stations."""
|
||||
session = DummySession()
|
||||
repo = CombinedTrackRepository(session) # type: ignore[arg-type]
|
||||
|
||||
# Initially should not exist (scalar_result is None by default)
|
||||
assert not repo.exists_between_stations(
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
|
||||
# Simulate existing combined track
|
||||
session.scalar_result = True
|
||||
assert repo.exists_between_stations(
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
|
||||
|
||||
def test_combined_track_service_create_no_path() -> None:
|
||||
"""Test creating combined track when no path exists."""
|
||||
# Mock session and repositories
|
||||
session = DummySession()
|
||||
|
||||
# Mock TrackRepository to return no path
|
||||
class MockTrackRepository:
|
||||
def __init__(self, session):
|
||||
pass
|
||||
|
||||
def find_path_between_stations(self, start_id, end_id):
|
||||
return None
|
||||
|
||||
# Mock CombinedTrackRepository
|
||||
class MockCombinedTrackRepository:
|
||||
def __init__(self, session):
|
||||
pass
|
||||
|
||||
def exists_between_stations(self, start_id, end_id):
|
||||
return False
|
||||
|
||||
# Patch the service to use mock repositories
|
||||
import backend.app.services.combined_tracks as service_module
|
||||
original_track_repo = service_module.TrackRepository
|
||||
original_combined_repo = service_module.CombinedTrackRepository
|
||||
|
||||
service_module.TrackRepository = MockTrackRepository
|
||||
service_module.CombinedTrackRepository = MockCombinedTrackRepository
|
||||
|
||||
try:
|
||||
result = create_combined_track(
|
||||
session, # type: ignore[arg-type]
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"550e8400-e29b-41d4-a716-446655440001"
|
||||
)
|
||||
assert result is None
|
||||
finally:
|
||||
# Restore original classes
|
||||
service_module.TrackRepository = original_track_repo
|
||||
service_module.CombinedTrackRepository = original_combined_repo
|
||||
158
backend/tests/test_tracks_api.py
Normal file
158
backend/tests/test_tracks_api.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.app.api import tracks as tracks_api
|
||||
from backend.app.main import app
|
||||
from backend.app.models import CombinedTrackModel, TrackModel
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def _track_model(track_id: str = "track-1") -> TrackModel:
|
||||
now = datetime.now(timezone.utc)
|
||||
return TrackModel(
|
||||
id=track_id,
|
||||
start_station_id="station-a",
|
||||
end_station_id="station-b",
|
||||
length_meters=None,
|
||||
max_speed_kph=None,
|
||||
status="planned",
|
||||
coordinates=[(52.5, 13.4), (52.6, 13.5)],
|
||||
is_bidirectional=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _combined_model(track_id: str = "combined-1") -> CombinedTrackModel:
|
||||
now = datetime.now(timezone.utc)
|
||||
return CombinedTrackModel(
|
||||
id=track_id,
|
||||
start_station_id="station-a",
|
||||
end_station_id="station-b",
|
||||
length_meters=1000,
|
||||
max_speed_kph=120,
|
||||
status="operational",
|
||||
coordinates=[(52.5, 13.4), (52.6, 13.5)],
|
||||
constituent_track_ids=["track-1", "track-2"],
|
||||
is_bidirectional=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _authenticate() -> str:
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "demo", "password": "railgame123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["accessToken"]
|
||||
|
||||
|
||||
def test_list_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
monkeypatch.setattr(tracks_api, "list_tracks", lambda db: [_track_model()])
|
||||
|
||||
response = client.get(
|
||||
"/api/tracks",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert isinstance(payload, list)
|
||||
assert payload[0]["id"] == "track-1"
|
||||
|
||||
|
||||
def test_get_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
monkeypatch.setattr(tracks_api, "get_track", lambda db, track_id: None)
|
||||
|
||||
response = client.get(
|
||||
"/api/tracks/not-found",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_create_track_calls_service(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
captured: dict[str, Any] = {}
|
||||
payload = {
|
||||
"startStationId": "station-a",
|
||||
"endStationId": "station-b",
|
||||
"coordinates": [[52.5, 13.4], [52.6, 13.5]],
|
||||
}
|
||||
|
||||
def fake_create(db: Any, data: Any) -> TrackModel:
|
||||
assert data.start_station_id == "station-a"
|
||||
captured["payload"] = data
|
||||
return _track_model("track-new")
|
||||
|
||||
monkeypatch.setattr(tracks_api, "create_track", fake_create)
|
||||
|
||||
response = client.post(
|
||||
"/api/tracks",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert body["id"] == "track-new"
|
||||
assert captured["payload"].end_station_id == "station-b"
|
||||
|
||||
|
||||
def test_delete_track_returns_404(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
monkeypatch.setattr(
|
||||
tracks_api, "delete_track", lambda db, tid, regenerate=False: False
|
||||
)
|
||||
|
||||
response = client.delete(
|
||||
"/api/tracks/missing",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_delete_track_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
seen: dict[str, Any] = {}
|
||||
|
||||
def fake_delete(db: Any, track_id: str, regenerate: bool = False) -> bool:
|
||||
seen["track_id"] = track_id
|
||||
seen["regenerate"] = regenerate
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(tracks_api, "delete_track", fake_delete)
|
||||
|
||||
response = client.delete(
|
||||
"/api/tracks/track-99",
|
||||
params={"regenerate": "true"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert seen["track_id"] == "track-99"
|
||||
assert seen["regenerate"] is True
|
||||
|
||||
|
||||
def test_list_combined_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
token = _authenticate()
|
||||
monkeypatch.setattr(
|
||||
tracks_api, "list_combined_tracks", lambda db: [_combined_model()]
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/tracks/combined",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert len(payload) == 1
|
||||
assert payload[0]["id"] == "combined-1"
|
||||
@@ -70,8 +70,8 @@ The system interacts with:
|
||||
|
||||
- User registration and authentication
|
||||
- Railway network building and management
|
||||
- Train scheduling and simulation
|
||||
- Map visualization and interaction
|
||||
- Train scheduling and simulation
|
||||
- Leaderboards and user profiles
|
||||
|
||||
**Out of Scope:**
|
||||
|
||||
14
frontend/package-lock.json
generated
14
frontend/package-lock.json
generated
@@ -65,6 +65,7 @@
|
||||
"integrity": "sha512-2BCOP7TN8M+gVDj7/ht3hsaO/B/n5oDbiAyyvnRlNOs+u1o+JWNYTQrmpuNp1/Wq2gcFrI01JAW+paEKDMx/CA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.27.1",
|
||||
"@babel/generator": "^7.28.3",
|
||||
@@ -1425,6 +1426,7 @@
|
||||
"integrity": "sha512-2Q7WS25j4pS1cS8yw3d6buNCVJukOTeQ39bAnwR6sOJbaxvyCGebzTMypDFN82CxBLnl+lSWVdCCWbRY6y9yZQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1442,6 +1444,7 @@
|
||||
"integrity": "sha512-RFA/bURkcKzx/X9oumPG9Vp3D3JUgus/d0b67KB0t5S/raciymilkOa66olh78MUI92QLbEJevO7rvqU/kjwKA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/prop-types": "*",
|
||||
"csstype": "^3.0.2"
|
||||
@@ -1493,6 +1496,7 @@
|
||||
"integrity": "sha512-n1H6IcDhmmUEG7TNVSspGmiHHutt7iVKtZwRppD7e04wha5MrkV1h3pti9xQLcCMt6YWsncpoT0HMjkH1FNwWQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.46.0",
|
||||
"@typescript-eslint/types": "8.46.0",
|
||||
@@ -1829,6 +1833,7 @@
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -2202,6 +2207,7 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"baseline-browser-mapping": "^2.8.9",
|
||||
"caniuse-lite": "^1.0.30001746",
|
||||
@@ -2852,6 +2858,7 @@
|
||||
"deprecated": "This version is no longer supported. Please see https://eslint.org/version-support for other options.",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.2.0",
|
||||
"@eslint-community/regexpp": "^4.6.1",
|
||||
@@ -4538,7 +4545,8 @@
|
||||
"version": "1.9.4",
|
||||
"resolved": "https://registry.npmjs.org/leaflet/-/leaflet-1.9.4.tgz",
|
||||
"integrity": "sha512-nxS1ynzJOmOlHp+iL3FyWqK89GtNL8U8rvlMOsQdTTssxZwCXh8N2NB3GDQOL+YR3XnWyZAxwQixURb+FA74PA==",
|
||||
"license": "BSD-2-Clause"
|
||||
"license": "BSD-2-Clause",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/levn": {
|
||||
"version": "0.4.1",
|
||||
@@ -5323,6 +5331,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
|
||||
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0"
|
||||
},
|
||||
@@ -5335,6 +5344,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
|
||||
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0",
|
||||
"scheduler": "^0.23.2"
|
||||
@@ -6240,6 +6250,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -6328,6 +6339,7 @@
|
||||
"integrity": "sha512-j3lYzGC3P+B5Yfy/pfKNgVEg4+UtcIJcVRt2cDjIOmhLourAqPqf8P7acgxeiSgUB7E3p2P8/3gNIgDLpwzs4g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.21.3",
|
||||
"postcss": "^8.4.43",
|
||||
|
||||
@@ -62,13 +62,23 @@ def check_database_url():
|
||||
print(f"Using database: {database_url}")
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, description=""):
|
||||
def run_command(cmd, cwd=None, description="", env=None):
|
||||
"""Run a shell command and return the result."""
|
||||
print(f"\n>>> {description}")
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
try:
|
||||
result = subprocess.run(cmd, cwd=cwd, check=True,
|
||||
capture_output=True, text=True)
|
||||
env_vars = os.environ.copy()
|
||||
if env:
|
||||
env_vars.update(env)
|
||||
env_vars.setdefault("PYTHONPATH", "/app")
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env_vars,
|
||||
)
|
||||
if result.stdout:
|
||||
print(result.stdout)
|
||||
return result
|
||||
@@ -86,7 +96,7 @@ def run_migrations():
|
||||
run_command(
|
||||
['alembic', 'upgrade', 'head'],
|
||||
cwd='backend',
|
||||
description="Running database migrations"
|
||||
description="Running database migrations",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user