From 834150518a5c7ce685c20092bef8c31a964dbca3 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Sat, 11 Oct 2025 17:40:56 +0200 Subject: [PATCH] refactor: improve code formatting and organization across multiple files --- backend/app/api/auth.py | 14 +- backend/app/core/config.py | 4 +- backend/app/core/security.py | 12 +- backend/app/db/models.py | 99 ++++++++--- backend/app/db/session.py | 18 +- backend/app/db/unit_of_work.py | 2 +- backend/app/models/__init__.py | 64 +++---- backend/app/models/base.py | 1 + backend/app/repositories/__init__.py | 2 +- backend/app/repositories/base.py | 2 +- backend/app/repositories/stations.py | 5 +- backend/app/repositories/tracks.py | 5 +- backend/app/repositories/train_schedules.py | 3 +- backend/app/repositories/trains.py | 5 +- backend/app/repositories/users.py | 8 +- backend/app/services/auth.py | 4 +- backend/app/services/network.py | 3 +- backend/migrations/env.py | 4 +- .../versions/20251011_01_initial_schema.py | 162 +++++++++++++++--- backend/tests/test_auth_api.py | 4 +- backend/tests/test_network_api.py | 4 +- backend/tests/test_network_service.py | 8 +- backend/tests/test_repositories.py | 10 +- 23 files changed, 317 insertions(+), 126 deletions(-) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 291102d..1214a87 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -25,17 +25,25 @@ async def login(credentials: LoginRequest) -> AuthResponse: return issue_token_for_user(user) -@router.post("/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED +) async def register(payload: RegisterRequest) -> AuthResponse: try: user = register_user(payload.username, payload.password, payload.full_name) except ValueError as exc: message = str(exc) - status_code = status.HTTP_409_CONFLICT if "exists" in message else status.HTTP_400_BAD_REQUEST + status_code = ( + status.HTTP_409_CONFLICT + if "exists" in message + else status.HTTP_400_BAD_REQUEST + ) raise HTTPException(status_code=status_code, detail=message) from exc return issue_token_for_user(user) @router.get("/me", response_model=UserPublic) -async def read_current_user(current_user: UserPublic = Depends(get_current_user)) -> UserPublic: +async def read_current_user( + current_user: UserPublic = Depends(get_current_user), +) -> UserPublic: return current_user diff --git a/backend/app/core/config.py b/backend/app/core/config.py index ec23336..058368b 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,11 +1,9 @@ from functools import lru_cache +from typing import Optional from pydantic_settings import BaseSettings, SettingsConfigDict -from typing import Optional - - class Settings(BaseSettings): project_name: str = "Rail Game API" version: str = "0.1.0" diff --git a/backend/app/core/security.py b/backend/app/core/security.py index a57d83a..1112272 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -25,12 +25,18 @@ def create_access_token(subject: str, expires_delta: timedelta | None = None) -> expires_delta or timedelta(minutes=settings.access_token_expire_minutes) ) to_encode: Dict[str, Any] = {"sub": subject, "exp": expire} - return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) + return jwt.encode( + to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm + ) def decode_access_token(token: str) -> Dict[str, Any]: settings = get_settings() try: - return jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]) - except JWTError as exc: # pragma: no cover - specific error mapping handled by caller + return jwt.decode( + token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm] + ) + except ( + JWTError + ) as exc: # pragma: no cover - specific error mapping handled by caller raise ValueError("Invalid token") from exc diff --git a/backend/app/db/models.py b/backend/app/db/models.py index b0d7ede..570e64e 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -3,7 +3,17 @@ from __future__ import annotations import uuid from geoalchemy2 import Geometry -from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, Numeric, String, Text, UniqueConstraint +from sqlalchemy import ( + Boolean, + DateTime, + Float, + ForeignKey, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.sql import func @@ -14,16 +24,23 @@ class Base(DeclarativeBase): class TimestampMixin: - created_at: Mapped[DateTime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) + created_at: Mapped[DateTime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) updated_at: Mapped[DateTime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, ) class User(Base, TimestampMixin): __tablename__ = "users" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True) full_name: Mapped[str | None] = mapped_column(String(128), nullable=True) @@ -35,11 +52,15 @@ class User(Base, TimestampMixin): class Station(Base, TimestampMixin): __tablename__ = "stations" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True) name: Mapped[str] = mapped_column(String(128), nullable=False) code: Mapped[str | None] = mapped_column(String(16), nullable=True) - location: Mapped[str] = mapped_column(Geometry(geometry_type="POINT", srid=4326), nullable=False) + location: Mapped[str] = mapped_column( + Geometry(geometry_type="POINT", srid=4326), nullable=False + ) elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) @@ -47,28 +68,50 @@ class Station(Base, TimestampMixin): class Track(Base, TimestampMixin): __tablename__ = "tracks" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) name: Mapped[str | None] = mapped_column(String(128), nullable=True) - start_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False) - end_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False) + start_station_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("stations.id", ondelete="RESTRICT"), + nullable=False, + ) + end_station_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("stations.id", ondelete="RESTRICT"), + nullable=False, + ) length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True) max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True) - is_bidirectional: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + is_bidirectional: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True + ) status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned") - track_geometry: Mapped[str] = mapped_column(Geometry(geometry_type="LINESTRING", srid=4326), nullable=False) + track_geometry: Mapped[str] = mapped_column( + Geometry(geometry_type="LINESTRING", srid=4326), nullable=False + ) __table_args__ = ( - UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"), + UniqueConstraint( + "start_station_id", "end_station_id", name="uq_tracks_station_pair" + ), ) class Train(Base, TimestampMixin): __tablename__ = "trains" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) - operator_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")) - home_station_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL")) + operator_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + home_station_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL") + ) capacity: Mapped[int] = mapped_column(Integer, nullable=False) max_speed_kph: Mapped[int] = mapped_column(Integer, nullable=False) consist: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -77,14 +120,28 @@ class Train(Base, TimestampMixin): class TrainSchedule(Base, TimestampMixin): __tablename__ = "train_schedules" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - train_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False) + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + train_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False + ) sequence_index: Mapped[int] = mapped_column(Integer, nullable=False) - station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="CASCADE"), nullable=False) - scheduled_arrival: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True) - scheduled_departure: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True) + station_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("stations.id", ondelete="CASCADE"), + nullable=False, + ) + scheduled_arrival: Mapped[DateTime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + scheduled_departure: Mapped[DateTime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) dwell_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True) __table_args__ = ( - UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"), + UniqueConstraint( + "train_id", "sequence_index", name="uq_train_schedule_sequence" + ), ) diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 0e41cba..342acd0 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -9,16 +9,20 @@ from backend.app.core.config import get_settings settings = get_settings() -engine = create_engine(settings.sqlalchemy_database_url, echo=settings.database_echo, future=True) -SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False) +engine = create_engine( + settings.sqlalchemy_database_url, echo=settings.database_echo, future=True +) +SessionLocal = sessionmaker( + bind=engine, autoflush=False, autocommit=False, expire_on_commit=False +) def get_db_session() -> Generator[Session, None, None]: - session = SessionLocal() - try: - yield session - finally: - session.close() + session = SessionLocal() + try: + yield session + finally: + session.close() __all__ = ["engine", "SessionLocal", "get_db_session"] diff --git a/backend/app/db/unit_of_work.py b/backend/app/db/unit_of_work.py index 8ce87e7..f2aea0b 100644 --- a/backend/app/db/unit_of_work.py +++ b/backend/app/db/unit_of_work.py @@ -8,9 +8,9 @@ from sqlalchemy.orm import Session from backend.app.db.session import SessionLocal from backend.app.repositories import ( StationRepository, + TrackRepository, TrainRepository, TrainScheduleRepository, - TrackRepository, UserRepository, ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index f8fed05..bd2e0bb 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,39 +1,39 @@ from .auth import ( - AuthResponse, - LoginRequest, - RegisterRequest, - TokenPayload, - TokenResponse, - UserInDB, - UserPublic, + AuthResponse, + LoginRequest, + RegisterRequest, + TokenPayload, + TokenResponse, + UserInDB, + UserPublic, ) from .base import ( - StationCreate, - StationModel, - TrackCreate, - TrackModel, - TrainScheduleCreate, - TrainCreate, - TrainModel, - UserCreate, - to_camel, + StationCreate, + StationModel, + TrackCreate, + TrackModel, + TrainCreate, + TrainModel, + TrainScheduleCreate, + UserCreate, + to_camel, ) __all__ = [ - "LoginRequest", - "RegisterRequest", - "AuthResponse", - "TokenPayload", - "TokenResponse", - "UserInDB", - "UserPublic", - "StationCreate", - "StationModel", - "TrackCreate", - "TrackModel", - "TrainScheduleCreate", - "TrainCreate", - "TrainModel", - "UserCreate", - "to_camel", + "LoginRequest", + "RegisterRequest", + "AuthResponse", + "TokenPayload", + "TokenResponse", + "UserInDB", + "UserPublic", + "StationCreate", + "StationModel", + "TrackCreate", + "TrackModel", + "TrainScheduleCreate", + "TrainCreate", + "TrainModel", + "UserCreate", + "to_camel", ] diff --git a/backend/app/models/base.py b/backend/app/models/base.py index b9a6aeb..b93f192 100644 --- a/backend/app/models/base.py +++ b/backend/app/models/base.py @@ -10,6 +10,7 @@ def to_camel(string: str) -> str: head, *tail = string.split("_") return head + "".join(part.capitalize() for part in tail) + IdT = TypeVar("IdT", bound=str) diff --git a/backend/app/repositories/__init__.py b/backend/app/repositories/__init__.py index 49dfca1..85f7970 100644 --- a/backend/app/repositories/__init__.py +++ b/backend/app/repositories/__init__.py @@ -1,8 +1,8 @@ """Repository abstractions for database access.""" from backend.app.repositories.stations import StationRepository -from backend.app.repositories.train_schedules import TrainScheduleRepository from backend.app.repositories.tracks import TrackRepository +from backend.app.repositories.train_schedules import TrainScheduleRepository from backend.app.repositories.trains import TrainRepository from backend.app.repositories.users import UserRepository diff --git a/backend/app/repositories/base.py b/backend/app/repositories/base.py index b1040f9..46ebd2e 100644 --- a/backend/app/repositories/base.py +++ b/backend/app/repositories/base.py @@ -1,8 +1,8 @@ from __future__ import annotations -import sqlalchemy as sa from typing import Generic, Iterable, Optional, Sequence, Type, TypeVar +import sqlalchemy as sa from sqlalchemy.orm import Session from backend.app.db.models import Base diff --git a/backend/app/repositories/stations.py b/backend/app/repositories/stations.py index 93cde66..db78b8e 100644 --- a/backend/app/repositories/stations.py +++ b/backend/app/repositories/stations.py @@ -1,13 +1,12 @@ from __future__ import annotations import sqlalchemy as sa +from geoalchemy2.elements import WKTElement from sqlalchemy.orm import Session -from geoalchemy2.elements import WKTElement - from backend.app.db.models import Station -from backend.app.repositories.base import BaseRepository from backend.app.models import StationCreate +from backend.app.repositories.base import BaseRepository class StationRepository(BaseRepository[Station]): diff --git a/backend/app/repositories/tracks.py b/backend/app/repositories/tracks.py index 84ca192..6afec95 100644 --- a/backend/app/repositories/tracks.py +++ b/backend/app/repositories/tracks.py @@ -1,13 +1,14 @@ from __future__ import annotations +from uuid import UUID + import sqlalchemy as sa from geoalchemy2.elements import WKTElement -from uuid import UUID from sqlalchemy.orm import Session from backend.app.db.models import Track -from backend.app.repositories.base import BaseRepository from backend.app.models import TrackCreate +from backend.app.repositories.base import BaseRepository class TrackRepository(BaseRepository[Track]): diff --git a/backend/app/repositories/train_schedules.py b/backend/app/repositories/train_schedules.py index 7e988f4..e7813e3 100644 --- a/backend/app/repositories/train_schedules.py +++ b/backend/app/repositories/train_schedules.py @@ -1,7 +1,8 @@ from __future__ import annotations -import sqlalchemy as sa from uuid import UUID + +import sqlalchemy as sa from sqlalchemy.orm import Session from backend.app.db.models import TrainSchedule diff --git a/backend/app/repositories/trains.py b/backend/app/repositories/trains.py index 40e7ef6..1fed71b 100644 --- a/backend/app/repositories/trains.py +++ b/backend/app/repositories/trains.py @@ -1,12 +1,13 @@ from __future__ import annotations -import sqlalchemy as sa from uuid import UUID + +import sqlalchemy as sa from sqlalchemy.orm import Session from backend.app.db.models import Train -from backend.app.repositories.base import BaseRepository from backend.app.models import TrainCreate +from backend.app.repositories.base import BaseRepository class TrainRepository(BaseRepository[Train]): diff --git a/backend/app/repositories/users.py b/backend/app/repositories/users.py index a1b2092..ed59352 100644 --- a/backend/app/repositories/users.py +++ b/backend/app/repositories/users.py @@ -17,11 +17,15 @@ class UserRepository(BaseRepository[User]): super().__init__(session) def get_by_username(self, username: str) -> User | None: - statement = sa.select(self.model).where(sa.func.lower(self.model.username) == username.lower()) + statement = sa.select(self.model).where( + sa.func.lower(self.model.username) == username.lower() + ) return self.session.scalar(statement) def list_recent(self, limit: int = 50) -> list[User]: - statement = sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit) + statement = ( + sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit) + ) return list(self.session.scalars(statement)) def create(self, data: UserCreate) -> User: diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py index 8402a82..29769f0 100644 --- a/backend/app/services/auth.py +++ b/backend/app/services/auth.py @@ -43,7 +43,9 @@ def to_public_user(user: UserInDB) -> UserPublic: return UserPublic(username=user.username, full_name=user.full_name) -def register_user(username: str, password: str, full_name: Optional[str] = None) -> UserInDB: +def register_user( + username: str, password: str, full_name: Optional[str] = None +) -> UserInDB: normalized_username = username.strip() if not normalized_username: raise ValueError("Username must not be empty") diff --git a/backend/app/services/network.py b/backend/app/services/network.py index 0868d7f..e661096 100644 --- a/backend/app/services/network.py +++ b/backend/app/services/network.py @@ -6,13 +6,14 @@ from typing import Iterable, cast from geoalchemy2.elements import WKBElement, WKTElement from geoalchemy2.shape import to_shape + try: # pragma: no cover - optional dependency guard from shapely.geometry import Point # type: ignore except ImportError: # pragma: no cover - allow running without shapely at import time Point = None # type: ignore[assignment] -from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from backend.app.models import StationModel, TrackModel, TrainModel from backend.app.repositories import StationRepository, TrackRepository, TrainRepository diff --git a/backend/migrations/env.py b/backend/migrations/env.py index 20ea5a8..45047b8 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -57,7 +57,9 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata, compare_type=True) + context.configure( + connection=connection, target_metadata=target_metadata, compare_type=True + ) with context.begin_transaction(): context.run_migrations() diff --git a/backend/migrations/versions/20251011_01_initial_schema.py b/backend/migrations/versions/20251011_01_initial_schema.py index 42aca06..548c2e3 100644 --- a/backend/migrations/versions/20251011_01_initial_schema.py +++ b/backend/migrations/versions/20251011_01_initial_schema.py @@ -20,79 +20,185 @@ def upgrade() -> None: op.create_table( "users", - sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), sa.Column("username", sa.String(length=64), nullable=False, unique=True), sa.Column("email", sa.String(length=255), nullable=True, unique=True), sa.Column("full_name", sa.String(length=128), nullable=True), sa.Column("password_hash", sa.String(length=256), nullable=False), - sa.Column("role", sa.String(length=32), nullable=False, server_default="player"), + sa.Column( + "role", sa.String(length=32), nullable=False, server_default="player" + ), sa.Column("preferences", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), ) op.create_table( "stations", - sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), sa.Column("osm_id", sa.String(length=32), nullable=True), sa.Column("name", sa.String(length=128), nullable=False), sa.Column("code", sa.String(length=16), nullable=True), - sa.Column("location", Geometry(geometry_type="POINT", srid=4326), nullable=False), + sa.Column( + "location", Geometry(geometry_type="POINT", srid=4326), nullable=False + ), sa.Column("elevation_m", sa.Float(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), + sa.Column( + "is_active", sa.Boolean(), nullable=False, server_default=sa.text("true") + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + ) + op.create_index( + "ix_stations_location", "stations", ["location"], postgresql_using="gist" ) - op.create_index("ix_stations_location", "stations", ["location"], postgresql_using="gist") op.create_table( "tracks", - sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), sa.Column("name", sa.String(length=128), nullable=True), sa.Column("start_station_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("end_station_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("length_meters", sa.Numeric(10, 2), nullable=True), sa.Column("max_speed_kph", sa.Integer(), nullable=True), - sa.Column("is_bidirectional", sa.Boolean(), nullable=False, server_default=sa.text("true")), - sa.Column("status", sa.String(length=32), nullable=False, server_default="planned"), - sa.Column("track_geometry", Geometry(geometry_type="LINESTRING", srid=4326), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.ForeignKeyConstraint(["start_station_id"], ["stations.id"], ondelete="RESTRICT"), - sa.ForeignKeyConstraint(["end_station_id"], ["stations.id"], ondelete="RESTRICT"), - sa.UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"), + sa.Column( + "is_bidirectional", + sa.Boolean(), + nullable=False, + server_default=sa.text("true"), + ), + sa.Column( + "status", sa.String(length=32), nullable=False, server_default="planned" + ), + sa.Column( + "track_geometry", + Geometry(geometry_type="LINESTRING", srid=4326), + nullable=False, + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["start_station_id"], ["stations.id"], ondelete="RESTRICT" + ), + sa.ForeignKeyConstraint( + ["end_station_id"], ["stations.id"], ondelete="RESTRICT" + ), + sa.UniqueConstraint( + "start_station_id", "end_station_id", name="uq_tracks_station_pair" + ), + ) + op.create_index( + "ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist" ) - op.create_index("ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist") op.create_table( "trains", - sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), sa.Column("designation", sa.String(length=64), nullable=False, unique=True), sa.Column("operator_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("home_station_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("capacity", sa.Integer(), nullable=False), sa.Column("max_speed_kph", sa.Integer(), nullable=False), sa.Column("consist", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), sa.ForeignKeyConstraint(["operator_id"], ["users.id"], ondelete="SET NULL"), - sa.ForeignKeyConstraint(["home_station_id"], ["stations.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint( + ["home_station_id"], ["stations.id"], ondelete="SET NULL" + ), ) op.create_table( "train_schedules", - sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), sa.Column("train_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("sequence_index", sa.Integer(), nullable=False), sa.Column("station_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("scheduled_arrival", sa.DateTime(timezone=True), nullable=True), sa.Column("scheduled_departure", sa.DateTime(timezone=True), nullable=True), sa.Column("dwell_seconds", sa.Integer(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("timezone('utc', now())"), + nullable=False, + ), sa.ForeignKeyConstraint(["train_id"], ["trains.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint(["station_id"], ["stations.id"], ondelete="CASCADE"), - sa.UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"), + sa.UniqueConstraint( + "train_id", "sequence_index", name="uq_train_schedule_sequence" + ), ) diff --git a/backend/tests/test_auth_api.py b/backend/tests/test_auth_api.py index 5f59c58..8d06ae2 100644 --- a/backend/tests/test_auth_api.py +++ b/backend/tests/test_auth_api.py @@ -35,9 +35,7 @@ def test_me_endpoint_returns_current_user() -> None: ) token = login.json()["accessToken"] - response = client.get( - "/api/auth/me", headers={"Authorization": f"Bearer {token}"} - ) + response = client.get("/api/auth/me", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 assert response.json()["username"] == "demo" diff --git a/backend/tests/test_network_api.py b/backend/tests/test_network_api.py index 5b9b7bd..696a273 100644 --- a/backend/tests/test_network_api.py +++ b/backend/tests/test_network_api.py @@ -20,9 +20,7 @@ def test_network_snapshot_requires_authentication() -> None: def test_network_snapshot_endpoint_returns_collections() -> None: token = _authenticate() - response = client.get( - "/api/network", headers={"Authorization": f"Bearer {token}"} - ) + response = client.get("/api/network", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 payload = response.json() diff --git a/backend/tests/test_network_service.py b/backend/tests/test_network_service.py index 785900b..83ceee5 100644 --- a/backend/tests/test_network_service.py +++ b/backend/tests/test_network_service.py @@ -40,7 +40,9 @@ def sample_entities() -> dict[str, SimpleNamespace]: return {"station": station, "track": track, "train": train} -def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace]) -> None: +def test_network_snapshot_prefers_repository_data( + monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace] +) -> None: station = sample_entities["station"] track = sample_entities["track"] train = sample_entities["train"] @@ -58,7 +60,9 @@ def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatc assert snapshot["trains"][0]["operatingTrackIds"] == [] -def test_network_snapshot_falls_back_when_repositories_empty(monkeypatch: pytest.MonkeyPatch) -> None: +def test_network_snapshot_falls_back_when_repositories_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: monkeypatch.setattr(StationRepository, "list_active", lambda self: []) monkeypatch.setattr(TrackRepository, "list_all", lambda self: []) monkeypatch.setattr(TrainRepository, "list_all", lambda self: []) diff --git a/backend/tests/test_repositories.py b/backend/tests/test_repositories.py index 15bdd81..644f6f4 100644 --- a/backend/tests/test_repositories.py +++ b/backend/tests/test_repositories.py @@ -6,6 +6,7 @@ from typing import Any, List, cast from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from backend.app.db.models import TrainSchedule, User from backend.app.db.unit_of_work import SqlAlchemyUnitOfWork @@ -23,7 +24,6 @@ from backend.app.repositories import ( TrainScheduleRepository, UserRepository, ) -from sqlalchemy.orm import Session @dataclass @@ -50,7 +50,9 @@ class DummySession: self.statements.append(statement) return self.scalar_result - def flush(self, _objects: list[Any] | None = None) -> None: # pragma: no cover - optional + def flush( + self, _objects: list[Any] | None = None + ) -> None: # pragma: no cover - optional return None def commit(self) -> None: # pragma: no cover - optional @@ -215,9 +217,7 @@ def test_unit_of_work_commits_and_closes_session() -> None: uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session)) with uow as active: - active.users.create( - UserCreate(username="demo", password_hash="hashed") - ) + active.users.create(UserCreate(username="demo", password_hash="hashed")) active.commit() assert session.committed