refactor: improve code formatting and organization across multiple files

This commit is contained in:
2025-10-11 17:40:56 +02:00
parent 47bbd7ab0c
commit 834150518a
23 changed files with 317 additions and 126 deletions

View File

@@ -25,17 +25,25 @@ async def login(credentials: LoginRequest) -> AuthResponse:
return issue_token_for_user(user)
@router.post("/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED
)
async def register(payload: RegisterRequest) -> AuthResponse:
try:
user = register_user(payload.username, payload.password, payload.full_name)
except ValueError as exc:
message = str(exc)
status_code = status.HTTP_409_CONFLICT if "exists" in message else status.HTTP_400_BAD_REQUEST
status_code = (
status.HTTP_409_CONFLICT
if "exists" in message
else status.HTTP_400_BAD_REQUEST
)
raise HTTPException(status_code=status_code, detail=message) from exc
return issue_token_for_user(user)
@router.get("/me", response_model=UserPublic)
async def read_current_user(current_user: UserPublic = Depends(get_current_user)) -> UserPublic:
async def read_current_user(
current_user: UserPublic = Depends(get_current_user),
) -> UserPublic:
return current_user

View File

@@ -1,11 +1,9 @@
from functools import lru_cache
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional
class Settings(BaseSettings):
project_name: str = "Rail Game API"
version: str = "0.1.0"

View File

@@ -25,12 +25,18 @@ def create_access_token(subject: str, expires_delta: timedelta | None = None) ->
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
)
to_encode: Dict[str, Any] = {"sub": subject, "exp": expire}
return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
return jwt.encode(
to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm
)
def decode_access_token(token: str) -> Dict[str, Any]:
settings = get_settings()
try:
return jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
except JWTError as exc: # pragma: no cover - specific error mapping handled by caller
return jwt.decode(
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
)
except (
JWTError
) as exc: # pragma: no cover - specific error mapping handled by caller
raise ValueError("Invalid token") from exc

View File

@@ -3,7 +3,17 @@ from __future__ import annotations
import uuid
from geoalchemy2 import Geometry
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, Numeric, String, Text, UniqueConstraint
from sqlalchemy import (
Boolean,
DateTime,
Float,
ForeignKey,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.sql import func
@@ -14,16 +24,23 @@ class Base(DeclarativeBase):
class TimestampMixin:
created_at: Mapped[DateTime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
class User(Base, TimestampMixin):
__tablename__ = "users"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
@@ -35,11 +52,15 @@ class User(Base, TimestampMixin):
class Station(Base, TimestampMixin):
__tablename__ = "stations"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
name: Mapped[str] = mapped_column(String(128), nullable=False)
code: Mapped[str | None] = mapped_column(String(16), nullable=True)
location: Mapped[str] = mapped_column(Geometry(geometry_type="POINT", srid=4326), nullable=False)
location: Mapped[str] = mapped_column(
Geometry(geometry_type="POINT", srid=4326), nullable=False
)
elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@@ -47,28 +68,50 @@ class Station(Base, TimestampMixin):
class Track(Base, TimestampMixin):
__tablename__ = "tracks"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str | None] = mapped_column(String(128), nullable=True)
start_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False)
end_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False)
start_station_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("stations.id", ondelete="RESTRICT"),
nullable=False,
)
end_station_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("stations.id", ondelete="RESTRICT"),
nullable=False,
)
length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
is_bidirectional: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
is_bidirectional: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
track_geometry: Mapped[str] = mapped_column(Geometry(geometry_type="LINESTRING", srid=4326), nullable=False)
track_geometry: Mapped[str] = mapped_column(
Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
)
__table_args__ = (
UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"),
UniqueConstraint(
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
),
)
class Train(Base, TimestampMixin):
__tablename__ = "trains"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
operator_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"))
home_station_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL"))
operator_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
)
home_station_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL")
)
capacity: Mapped[int] = mapped_column(Integer, nullable=False)
max_speed_kph: Mapped[int] = mapped_column(Integer, nullable=False)
consist: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -77,14 +120,28 @@ class Train(Base, TimestampMixin):
class TrainSchedule(Base, TimestampMixin):
__tablename__ = "train_schedules"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
train_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
train_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False
)
sequence_index: Mapped[int] = mapped_column(Integer, nullable=False)
station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="CASCADE"), nullable=False)
scheduled_arrival: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
scheduled_departure: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
station_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("stations.id", ondelete="CASCADE"),
nullable=False,
)
scheduled_arrival: Mapped[DateTime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
scheduled_departure: Mapped[DateTime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
dwell_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
__table_args__ = (
UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"),
UniqueConstraint(
"train_id", "sequence_index", name="uq_train_schedule_sequence"
),
)

View File

@@ -9,16 +9,20 @@ from backend.app.core.config import get_settings
settings = get_settings()
engine = create_engine(settings.sqlalchemy_database_url, echo=settings.database_echo, future=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False)
engine = create_engine(
settings.sqlalchemy_database_url, echo=settings.database_echo, future=True
)
SessionLocal = sessionmaker(
bind=engine, autoflush=False, autocommit=False, expire_on_commit=False
)
def get_db_session() -> Generator[Session, None, None]:
session = SessionLocal()
try:
yield session
finally:
session.close()
session = SessionLocal()
try:
yield session
finally:
session.close()
__all__ = ["engine", "SessionLocal", "get_db_session"]

View File

@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from backend.app.db.session import SessionLocal
from backend.app.repositories import (
StationRepository,
TrackRepository,
TrainRepository,
TrainScheduleRepository,
TrackRepository,
UserRepository,
)

View File

@@ -1,39 +1,39 @@
from .auth import (
AuthResponse,
LoginRequest,
RegisterRequest,
TokenPayload,
TokenResponse,
UserInDB,
UserPublic,
AuthResponse,
LoginRequest,
RegisterRequest,
TokenPayload,
TokenResponse,
UserInDB,
UserPublic,
)
from .base import (
StationCreate,
StationModel,
TrackCreate,
TrackModel,
TrainScheduleCreate,
TrainCreate,
TrainModel,
UserCreate,
to_camel,
StationCreate,
StationModel,
TrackCreate,
TrackModel,
TrainCreate,
TrainModel,
TrainScheduleCreate,
UserCreate,
to_camel,
)
__all__ = [
"LoginRequest",
"RegisterRequest",
"AuthResponse",
"TokenPayload",
"TokenResponse",
"UserInDB",
"UserPublic",
"StationCreate",
"StationModel",
"TrackCreate",
"TrackModel",
"TrainScheduleCreate",
"TrainCreate",
"TrainModel",
"UserCreate",
"to_camel",
"LoginRequest",
"RegisterRequest",
"AuthResponse",
"TokenPayload",
"TokenResponse",
"UserInDB",
"UserPublic",
"StationCreate",
"StationModel",
"TrackCreate",
"TrackModel",
"TrainScheduleCreate",
"TrainCreate",
"TrainModel",
"UserCreate",
"to_camel",
]

View File

@@ -10,6 +10,7 @@ def to_camel(string: str) -> str:
head, *tail = string.split("_")
return head + "".join(part.capitalize() for part in tail)
IdT = TypeVar("IdT", bound=str)

View File

@@ -1,8 +1,8 @@
"""Repository abstractions for database access."""
from backend.app.repositories.stations import StationRepository
from backend.app.repositories.train_schedules import TrainScheduleRepository
from backend.app.repositories.tracks import TrackRepository
from backend.app.repositories.train_schedules import TrainScheduleRepository
from backend.app.repositories.trains import TrainRepository
from backend.app.repositories.users import UserRepository

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import sqlalchemy as sa
from typing import Generic, Iterable, Optional, Sequence, Type, TypeVar
import sqlalchemy as sa
from sqlalchemy.orm import Session
from backend.app.db.models import Base

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from sqlalchemy.orm import Session
from geoalchemy2.elements import WKTElement
from backend.app.db.models import Station
from backend.app.repositories.base import BaseRepository
from backend.app.models import StationCreate
from backend.app.repositories.base import BaseRepository
class StationRepository(BaseRepository[Station]):

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
from uuid import UUID
import sqlalchemy as sa
from geoalchemy2.elements import WKTElement
from uuid import UUID
from sqlalchemy.orm import Session
from backend.app.db.models import Track
from backend.app.repositories.base import BaseRepository
from backend.app.models import TrackCreate
from backend.app.repositories.base import BaseRepository
class TrackRepository(BaseRepository[Track]):

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import sqlalchemy as sa
from uuid import UUID
import sqlalchemy as sa
from sqlalchemy.orm import Session
from backend.app.db.models import TrainSchedule

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import sqlalchemy as sa
from uuid import UUID
import sqlalchemy as sa
from sqlalchemy.orm import Session
from backend.app.db.models import Train
from backend.app.repositories.base import BaseRepository
from backend.app.models import TrainCreate
from backend.app.repositories.base import BaseRepository
class TrainRepository(BaseRepository[Train]):

View File

@@ -17,11 +17,15 @@ class UserRepository(BaseRepository[User]):
super().__init__(session)
def get_by_username(self, username: str) -> User | None:
statement = sa.select(self.model).where(sa.func.lower(self.model.username) == username.lower())
statement = sa.select(self.model).where(
sa.func.lower(self.model.username) == username.lower()
)
return self.session.scalar(statement)
def list_recent(self, limit: int = 50) -> list[User]:
statement = sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit)
statement = (
sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit)
)
return list(self.session.scalars(statement))
def create(self, data: UserCreate) -> User:

View File

@@ -43,7 +43,9 @@ def to_public_user(user: UserInDB) -> UserPublic:
return UserPublic(username=user.username, full_name=user.full_name)
def register_user(username: str, password: str, full_name: Optional[str] = None) -> UserInDB:
def register_user(
username: str, password: str, full_name: Optional[str] = None
) -> UserInDB:
normalized_username = username.strip()
if not normalized_username:
raise ValueError("Username must not be empty")

View File

@@ -6,13 +6,14 @@ from typing import Iterable, cast
from geoalchemy2.elements import WKBElement, WKTElement
from geoalchemy2.shape import to_shape
try: # pragma: no cover - optional dependency guard
from shapely.geometry import Point # type: ignore
except ImportError: # pragma: no cover - allow running without shapely at import time
Point = None # type: ignore[assignment]
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from backend.app.models import StationModel, TrackModel, TrainModel
from backend.app.repositories import StationRepository, TrackRepository, TrainRepository

View File

@@ -57,7 +57,9 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata, compare_type=True)
context.configure(
connection=connection, target_metadata=target_metadata, compare_type=True
)
with context.begin_transaction():
context.run_migrations()

View File

@@ -20,79 +20,185 @@ def upgrade() -> None:
op.create_table(
"users",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("username", sa.String(length=64), nullable=False, unique=True),
sa.Column("email", sa.String(length=255), nullable=True, unique=True),
sa.Column("full_name", sa.String(length=128), nullable=True),
sa.Column("password_hash", sa.String(length=256), nullable=False),
sa.Column("role", sa.String(length=32), nullable=False, server_default="player"),
sa.Column(
"role", sa.String(length=32), nullable=False, server_default="player"
),
sa.Column("preferences", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
)
op.create_table(
"stations",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("osm_id", sa.String(length=32), nullable=True),
sa.Column("name", sa.String(length=128), nullable=False),
sa.Column("code", sa.String(length=16), nullable=True),
sa.Column("location", Geometry(geometry_type="POINT", srid=4326), nullable=False),
sa.Column(
"location", Geometry(geometry_type="POINT", srid=4326), nullable=False
),
sa.Column("elevation_m", sa.Float(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
)
op.create_index(
"ix_stations_location", "stations", ["location"], postgresql_using="gist"
)
op.create_index("ix_stations_location", "stations", ["location"], postgresql_using="gist")
op.create_table(
"tracks",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("name", sa.String(length=128), nullable=True),
sa.Column("start_station_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("end_station_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("length_meters", sa.Numeric(10, 2), nullable=True),
sa.Column("max_speed_kph", sa.Integer(), nullable=True),
sa.Column("is_bidirectional", sa.Boolean(), nullable=False, server_default=sa.text("true")),
sa.Column("status", sa.String(length=32), nullable=False, server_default="planned"),
sa.Column("track_geometry", Geometry(geometry_type="LINESTRING", srid=4326), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.ForeignKeyConstraint(["start_station_id"], ["stations.id"], ondelete="RESTRICT"),
sa.ForeignKeyConstraint(["end_station_id"], ["stations.id"], ondelete="RESTRICT"),
sa.UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"),
sa.Column(
"is_bidirectional",
sa.Boolean(),
nullable=False,
server_default=sa.text("true"),
),
sa.Column(
"status", sa.String(length=32), nullable=False, server_default="planned"
),
sa.Column(
"track_geometry",
Geometry(geometry_type="LINESTRING", srid=4326),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.ForeignKeyConstraint(
["start_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.ForeignKeyConstraint(
["end_station_id"], ["stations.id"], ondelete="RESTRICT"
),
sa.UniqueConstraint(
"start_station_id", "end_station_id", name="uq_tracks_station_pair"
),
)
op.create_index(
"ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist"
)
op.create_index("ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist")
op.create_table(
"trains",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("designation", sa.String(length=64), nullable=False, unique=True),
sa.Column("operator_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("home_station_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("capacity", sa.Integer(), nullable=False),
sa.Column("max_speed_kph", sa.Integer(), nullable=False),
sa.Column("consist", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.ForeignKeyConstraint(["operator_id"], ["users.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(["home_station_id"], ["stations.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(
["home_station_id"], ["stations.id"], ondelete="SET NULL"
),
)
op.create_table(
"train_schedules",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("train_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sequence_index", sa.Integer(), nullable=False),
sa.Column("station_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("scheduled_arrival", sa.DateTime(timezone=True), nullable=True),
sa.Column("scheduled_departure", sa.DateTime(timezone=True), nullable=True),
sa.Column("dwell_seconds", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("timezone('utc', now())"),
nullable=False,
),
sa.ForeignKeyConstraint(["train_id"], ["trains.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["station_id"], ["stations.id"], ondelete="CASCADE"),
sa.UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"),
sa.UniqueConstraint(
"train_id", "sequence_index", name="uq_train_schedule_sequence"
),
)

View File

@@ -35,9 +35,7 @@ def test_me_endpoint_returns_current_user() -> None:
)
token = login.json()["accessToken"]
response = client.get(
"/api/auth/me", headers={"Authorization": f"Bearer {token}"}
)
response = client.get("/api/auth/me", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert response.json()["username"] == "demo"

View File

@@ -20,9 +20,7 @@ def test_network_snapshot_requires_authentication() -> None:
def test_network_snapshot_endpoint_returns_collections() -> None:
token = _authenticate()
response = client.get(
"/api/network", headers={"Authorization": f"Bearer {token}"}
)
response = client.get("/api/network", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
payload = response.json()

View File

@@ -40,7 +40,9 @@ def sample_entities() -> dict[str, SimpleNamespace]:
return {"station": station, "track": track, "train": train}
def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace]) -> None:
def test_network_snapshot_prefers_repository_data(
monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace]
) -> None:
station = sample_entities["station"]
track = sample_entities["track"]
train = sample_entities["train"]
@@ -58,7 +60,9 @@ def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatc
assert snapshot["trains"][0]["operatingTrackIds"] == []
def test_network_snapshot_falls_back_when_repositories_empty(monkeypatch: pytest.MonkeyPatch) -> None:
def test_network_snapshot_falls_back_when_repositories_empty(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(StationRepository, "list_active", lambda self: [])
monkeypatch.setattr(TrackRepository, "list_all", lambda self: [])
monkeypatch.setattr(TrainRepository, "list_all", lambda self: [])

View File

@@ -6,6 +6,7 @@ from typing import Any, List, cast
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from backend.app.db.models import TrainSchedule, User
from backend.app.db.unit_of_work import SqlAlchemyUnitOfWork
@@ -23,7 +24,6 @@ from backend.app.repositories import (
TrainScheduleRepository,
UserRepository,
)
from sqlalchemy.orm import Session
@dataclass
@@ -50,7 +50,9 @@ class DummySession:
self.statements.append(statement)
return self.scalar_result
def flush(self, _objects: list[Any] | None = None) -> None: # pragma: no cover - optional
def flush(
self, _objects: list[Any] | None = None
) -> None: # pragma: no cover - optional
return None
def commit(self) -> None: # pragma: no cover - optional
@@ -215,9 +217,7 @@ def test_unit_of_work_commits_and_closes_session() -> None:
uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session))
with uow as active:
active.users.create(
UserCreate(username="demo", password_hash="hashed")
)
active.users.create(UserCreate(username="demo", password_hash="hashed"))
active.commit()
assert session.committed