refactor: improve code formatting and organization across multiple files
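The changes are consistent with running an import sorter plus an opinionated line-length formatter over the backend (Black and isort with the default 88-character limit would produce this shape; the tooling is an assumption, since the commit does not name it). Long calls and signatures either stay on one line when they fit, wrap their arguments onto a single indented continuation line, or are fully exploded with one argument per line and a trailing comma. A minimal sketch of that target style, using placeholder values rather than code from this repository:

    from sqlalchemy import create_engine

    url = "postgresql+psycopg2://localhost/rail_game"  # placeholder URL

    # Fits within the limit: stays on one line.
    engine = create_engine(url, echo=False, future=True)

    # Too long for one line: arguments move to a single indented continuation line.
    engine = create_engine(
        url, echo=False, future=True, pool_pre_ping=True, pool_size=10, max_overflow=20
    )

    # One argument per line with a trailing comma; this exploded form is kept when the
    # continuation line would still overflow or a trailing comma is already present.
    engine = create_engine(
        url,
        echo=False,
        future=True,
        connect_args={"options": "-csearch_path=public"},
    )

The reformatted hunks follow.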
@@ -25,17 +25,25 @@ async def login(credentials: LoginRequest) -> AuthResponse:
     return issue_token_for_user(user)


-@router.post("/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
+@router.post(
+    "/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED
+)
 async def register(payload: RegisterRequest) -> AuthResponse:
     try:
         user = register_user(payload.username, payload.password, payload.full_name)
     except ValueError as exc:
         message = str(exc)
-        status_code = status.HTTP_409_CONFLICT if "exists" in message else status.HTTP_400_BAD_REQUEST
+        status_code = (
+            status.HTTP_409_CONFLICT
+            if "exists" in message
+            else status.HTTP_400_BAD_REQUEST
+        )
         raise HTTPException(status_code=status_code, detail=message) from exc
     return issue_token_for_user(user)


 @router.get("/me", response_model=UserPublic)
-async def read_current_user(current_user: UserPublic = Depends(get_current_user)) -> UserPublic:
+async def read_current_user(
+    current_user: UserPublic = Depends(get_current_user),
+) -> UserPublic:
     return current_user
@@ -1,11 +1,9 @@
 from functools import lru_cache
+from typing import Optional

 from pydantic_settings import BaseSettings, SettingsConfigDict


-from typing import Optional
-
-
 class Settings(BaseSettings):
     project_name: str = "Rail Game API"
     version: str = "0.1.0"
@@ -25,12 +25,18 @@ def create_access_token(subject: str, expires_delta: timedelta | None = None) ->
         expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
     )
     to_encode: Dict[str, Any] = {"sub": subject, "exp": expire}
-    return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
+    return jwt.encode(
+        to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm
+    )


 def decode_access_token(token: str) -> Dict[str, Any]:
     settings = get_settings()
     try:
-        return jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
-    except JWTError as exc:  # pragma: no cover - specific error mapping handled by caller
+        return jwt.decode(
+            token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
+        )
+    except (
+        JWTError
+    ) as exc:  # pragma: no cover - specific error mapping handled by caller
         raise ValueError("Invalid token") from exc
@@ -3,7 +3,17 @@ from __future__ import annotations
 import uuid

 from geoalchemy2 import Geometry
-from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, Numeric, String, Text, UniqueConstraint
+from sqlalchemy import (
+    Boolean,
+    DateTime,
+    Float,
+    ForeignKey,
+    Integer,
+    Numeric,
+    String,
+    Text,
+    UniqueConstraint,
+)
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 from sqlalchemy.sql import func
@@ -14,16 +24,23 @@ class Base(DeclarativeBase):


 class TimestampMixin:
-    created_at: Mapped[DateTime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    created_at: Mapped[DateTime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
     updated_at: Mapped[DateTime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
+        DateTime(timezone=True),
+        server_default=func.now(),
+        onupdate=func.now(),
+        nullable=False,
     )


 class User(Base, TimestampMixin):
     __tablename__ = "users"

-    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+    )
     username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
     email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True)
     full_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
@@ -35,11 +52,15 @@ class User(Base, TimestampMixin):
 class Station(Base, TimestampMixin):
     __tablename__ = "stations"

-    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+    )
     osm_id: Mapped[str | None] = mapped_column(String(32), nullable=True)
     name: Mapped[str] = mapped_column(String(128), nullable=False)
     code: Mapped[str | None] = mapped_column(String(16), nullable=True)
-    location: Mapped[str] = mapped_column(Geometry(geometry_type="POINT", srid=4326), nullable=False)
+    location: Mapped[str] = mapped_column(
+        Geometry(geometry_type="POINT", srid=4326), nullable=False
+    )
     elevation_m: Mapped[float | None] = mapped_column(Float, nullable=True)
     is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

@@ -47,28 +68,50 @@ class Station(Base, TimestampMixin):
 class Track(Base, TimestampMixin):
     __tablename__ = "tracks"

-    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+    )
     name: Mapped[str | None] = mapped_column(String(128), nullable=True)
-    start_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False)
-    end_station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="RESTRICT"), nullable=False)
+    start_station_id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True),
+        ForeignKey("stations.id", ondelete="RESTRICT"),
+        nullable=False,
+    )
+    end_station_id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True),
+        ForeignKey("stations.id", ondelete="RESTRICT"),
+        nullable=False,
+    )
     length_meters: Mapped[float | None] = mapped_column(Numeric(10, 2), nullable=True)
     max_speed_kph: Mapped[int | None] = mapped_column(Integer, nullable=True)
-    is_bidirectional: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+    is_bidirectional: Mapped[bool] = mapped_column(
+        Boolean, nullable=False, default=True
+    )
     status: Mapped[str] = mapped_column(String(32), nullable=False, default="planned")
-    track_geometry: Mapped[str] = mapped_column(Geometry(geometry_type="LINESTRING", srid=4326), nullable=False)
+    track_geometry: Mapped[str] = mapped_column(
+        Geometry(geometry_type="LINESTRING", srid=4326), nullable=False
+    )

     __table_args__ = (
-        UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"),
+        UniqueConstraint(
+            "start_station_id", "end_station_id", name="uq_tracks_station_pair"
+        ),
     )


 class Train(Base, TimestampMixin):
     __tablename__ = "trains"

-    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+    )
     designation: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
-    operator_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"))
-    home_station_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL"))
+    operator_id: Mapped[uuid.UUID | None] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL")
+    )
+    home_station_id: Mapped[uuid.UUID | None] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("stations.id", ondelete="SET NULL")
+    )
     capacity: Mapped[int] = mapped_column(Integer, nullable=False)
     max_speed_kph: Mapped[int] = mapped_column(Integer, nullable=False)
     consist: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -77,14 +120,28 @@ class Train(Base, TimestampMixin):
 class TrainSchedule(Base, TimestampMixin):
     __tablename__ = "train_schedules"

-    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-    train_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False)
+    id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
+    )
+    train_id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("trains.id", ondelete="CASCADE"), nullable=False
+    )
     sequence_index: Mapped[int] = mapped_column(Integer, nullable=False)
-    station_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("stations.id", ondelete="CASCADE"), nullable=False)
-    scheduled_arrival: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-    scheduled_departure: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    station_id: Mapped[uuid.UUID] = mapped_column(
+        UUID(as_uuid=True),
+        ForeignKey("stations.id", ondelete="CASCADE"),
+        nullable=False,
+    )
+    scheduled_arrival: Mapped[DateTime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
+    scheduled_departure: Mapped[DateTime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
     dwell_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)

     __table_args__ = (
-        UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"),
+        UniqueConstraint(
+            "train_id", "sequence_index", name="uq_train_schedule_sequence"
+        ),
     )
@@ -9,16 +9,20 @@ from backend.app.core.config import get_settings

 settings = get_settings()

-engine = create_engine(settings.sqlalchemy_database_url, echo=settings.database_echo, future=True)
-SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False)
+engine = create_engine(
+    settings.sqlalchemy_database_url, echo=settings.database_echo, future=True
+)
+SessionLocal = sessionmaker(
+    bind=engine, autoflush=False, autocommit=False, expire_on_commit=False
+)


 def get_db_session() -> Generator[Session, None, None]:
     session = SessionLocal()
     try:
         yield session
     finally:
         session.close()


 __all__ = ["engine", "SessionLocal", "get_db_session"]
@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
 from backend.app.db.session import SessionLocal
 from backend.app.repositories import (
     StationRepository,
+    TrackRepository,
     TrainRepository,
     TrainScheduleRepository,
-    TrackRepository,
     UserRepository,
 )

@@ -1,39 +1,39 @@
 from .auth import (
     AuthResponse,
     LoginRequest,
     RegisterRequest,
     TokenPayload,
     TokenResponse,
     UserInDB,
     UserPublic,
 )
 from .base import (
     StationCreate,
     StationModel,
     TrackCreate,
     TrackModel,
-    TrainScheduleCreate,
-    TrainCreate,
-    TrainModel,
+    TrainCreate,
+    TrainModel,
+    TrainScheduleCreate,
     UserCreate,
     to_camel,
 )

 __all__ = [
     "LoginRequest",
     "RegisterRequest",
     "AuthResponse",
     "TokenPayload",
     "TokenResponse",
     "UserInDB",
     "UserPublic",
     "StationCreate",
     "StationModel",
     "TrackCreate",
     "TrackModel",
     "TrainScheduleCreate",
     "TrainCreate",
     "TrainModel",
     "UserCreate",
     "to_camel",
 ]
@@ -10,6 +10,7 @@ def to_camel(string: str) -> str:
     head, *tail = string.split("_")
     return head + "".join(part.capitalize() for part in tail)


 IdT = TypeVar("IdT", bound=str)

+
@@ -1,8 +1,8 @@
 """Repository abstractions for database access."""

 from backend.app.repositories.stations import StationRepository
-from backend.app.repositories.train_schedules import TrainScheduleRepository
 from backend.app.repositories.tracks import TrackRepository
+from backend.app.repositories.train_schedules import TrainScheduleRepository
 from backend.app.repositories.trains import TrainRepository
 from backend.app.repositories.users import UserRepository

@@ -1,8 +1,8 @@
 from __future__ import annotations

-import sqlalchemy as sa
 from typing import Generic, Iterable, Optional, Sequence, Type, TypeVar

+import sqlalchemy as sa
 from sqlalchemy.orm import Session

 from backend.app.db.models import Base
@@ -1,13 +1,12 @@
 from __future__ import annotations

 import sqlalchemy as sa
+from geoalchemy2.elements import WKTElement
 from sqlalchemy.orm import Session

-from geoalchemy2.elements import WKTElement
-
 from backend.app.db.models import Station
-from backend.app.repositories.base import BaseRepository
 from backend.app.models import StationCreate
+from backend.app.repositories.base import BaseRepository


 class StationRepository(BaseRepository[Station]):
@@ -1,13 +1,14 @@
 from __future__ import annotations

+from uuid import UUID
+
 import sqlalchemy as sa
 from geoalchemy2.elements import WKTElement
-from uuid import UUID
 from sqlalchemy.orm import Session

 from backend.app.db.models import Track
-from backend.app.repositories.base import BaseRepository
 from backend.app.models import TrackCreate
+from backend.app.repositories.base import BaseRepository


 class TrackRepository(BaseRepository[Track]):
@@ -1,7 +1,8 @@
 from __future__ import annotations

-import sqlalchemy as sa
 from uuid import UUID
+
+import sqlalchemy as sa
 from sqlalchemy.orm import Session

 from backend.app.db.models import TrainSchedule
@@ -1,12 +1,13 @@
 from __future__ import annotations

-import sqlalchemy as sa
 from uuid import UUID
+
+import sqlalchemy as sa
 from sqlalchemy.orm import Session

 from backend.app.db.models import Train
-from backend.app.repositories.base import BaseRepository
 from backend.app.models import TrainCreate
+from backend.app.repositories.base import BaseRepository


 class TrainRepository(BaseRepository[Train]):
@@ -17,11 +17,15 @@ class UserRepository(BaseRepository[User]):
         super().__init__(session)

     def get_by_username(self, username: str) -> User | None:
-        statement = sa.select(self.model).where(sa.func.lower(self.model.username) == username.lower())
+        statement = sa.select(self.model).where(
+            sa.func.lower(self.model.username) == username.lower()
+        )
         return self.session.scalar(statement)

     def list_recent(self, limit: int = 50) -> list[User]:
-        statement = sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit)
+        statement = (
+            sa.select(self.model).order_by(self.model.created_at.desc()).limit(limit)
+        )
         return list(self.session.scalars(statement))

     def create(self, data: UserCreate) -> User:
@@ -43,7 +43,9 @@ def to_public_user(user: UserInDB) -> UserPublic:
     return UserPublic(username=user.username, full_name=user.full_name)


-def register_user(username: str, password: str, full_name: Optional[str] = None) -> UserInDB:
+def register_user(
+    username: str, password: str, full_name: Optional[str] = None
+) -> UserInDB:
     normalized_username = username.strip()
     if not normalized_username:
         raise ValueError("Username must not be empty")
@@ -6,13 +6,14 @@ from typing import Iterable, cast

 from geoalchemy2.elements import WKBElement, WKTElement
 from geoalchemy2.shape import to_shape

 try:  # pragma: no cover - optional dependency guard
     from shapely.geometry import Point  # type: ignore
 except ImportError:  # pragma: no cover - allow running without shapely at import time
     Point = None  # type: ignore[assignment]

-from sqlalchemy.orm import Session
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
+
 from backend.app.models import StationModel, TrackModel, TrainModel
 from backend.app.repositories import StationRepository, TrackRepository, TrainRepository
@@ -57,7 +57,9 @@ def run_migrations_online() -> None:
     )

     with connectable.connect() as connection:
-        context.configure(connection=connection, target_metadata=target_metadata, compare_type=True)
+        context.configure(
+            connection=connection, target_metadata=target_metadata, compare_type=True
+        )

         with context.begin_transaction():
             context.run_migrations()
@@ -20,79 +20,185 @@ def upgrade() -> None:

     op.create_table(
         "users",
-        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
+        sa.Column(
+            "id",
+            postgresql.UUID(as_uuid=True),
+            primary_key=True,
+            server_default=sa.text("gen_random_uuid()"),
+        ),
         sa.Column("username", sa.String(length=64), nullable=False, unique=True),
         sa.Column("email", sa.String(length=255), nullable=True, unique=True),
         sa.Column("full_name", sa.String(length=128), nullable=True),
         sa.Column("password_hash", sa.String(length=256), nullable=False),
-        sa.Column("role", sa.String(length=32), nullable=False, server_default="player"),
+        sa.Column(
+            "role", sa.String(length=32), nullable=False, server_default="player"
+        ),
         sa.Column("preferences", sa.Text(), nullable=True),
-        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
     )

     op.create_table(
         "stations",
-        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
+        sa.Column(
+            "id",
+            postgresql.UUID(as_uuid=True),
+            primary_key=True,
+            server_default=sa.text("gen_random_uuid()"),
+        ),
         sa.Column("osm_id", sa.String(length=32), nullable=True),
         sa.Column("name", sa.String(length=128), nullable=False),
         sa.Column("code", sa.String(length=16), nullable=True),
-        sa.Column("location", Geometry(geometry_type="POINT", srid=4326), nullable=False),
+        sa.Column(
+            "location", Geometry(geometry_type="POINT", srid=4326), nullable=False
+        ),
        sa.Column("elevation_m", sa.Float(), nullable=True),
-        sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")),
-        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
+        sa.Column(
+            "is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")
+        ),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+    )
+    op.create_index(
+        "ix_stations_location", "stations", ["location"], postgresql_using="gist"
     )
-    op.create_index("ix_stations_location", "stations", ["location"], postgresql_using="gist")

     op.create_table(
         "tracks",
-        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
+        sa.Column(
+            "id",
+            postgresql.UUID(as_uuid=True),
+            primary_key=True,
+            server_default=sa.text("gen_random_uuid()"),
+        ),
         sa.Column("name", sa.String(length=128), nullable=True),
         sa.Column("start_station_id", postgresql.UUID(as_uuid=True), nullable=False),
         sa.Column("end_station_id", postgresql.UUID(as_uuid=True), nullable=False),
         sa.Column("length_meters", sa.Numeric(10, 2), nullable=True),
         sa.Column("max_speed_kph", sa.Integer(), nullable=True),
-        sa.Column("is_bidirectional", sa.Boolean(), nullable=False, server_default=sa.text("true")),
-        sa.Column("status", sa.String(length=32), nullable=False, server_default="planned"),
-        sa.Column("track_geometry", Geometry(geometry_type="LINESTRING", srid=4326), nullable=False),
-        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.ForeignKeyConstraint(["start_station_id"], ["stations.id"], ondelete="RESTRICT"),
-        sa.ForeignKeyConstraint(["end_station_id"], ["stations.id"], ondelete="RESTRICT"),
-        sa.UniqueConstraint("start_station_id", "end_station_id", name="uq_tracks_station_pair"),
+        sa.Column(
+            "is_bidirectional",
+            sa.Boolean(),
+            nullable=False,
+            server_default=sa.text("true"),
+        ),
+        sa.Column(
+            "status", sa.String(length=32), nullable=False, server_default="planned"
+        ),
+        sa.Column(
+            "track_geometry",
+            Geometry(geometry_type="LINESTRING", srid=4326),
+            nullable=False,
+        ),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(
+            ["start_station_id"], ["stations.id"], ondelete="RESTRICT"
+        ),
+        sa.ForeignKeyConstraint(
+            ["end_station_id"], ["stations.id"], ondelete="RESTRICT"
+        ),
+        sa.UniqueConstraint(
+            "start_station_id", "end_station_id", name="uq_tracks_station_pair"
+        ),
+    )
+    op.create_index(
+        "ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist"
     )
-    op.create_index("ix_tracks_geometry", "tracks", ["track_geometry"], postgresql_using="gist")

     op.create_table(
         "trains",
-        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
+        sa.Column(
+            "id",
+            postgresql.UUID(as_uuid=True),
+            primary_key=True,
+            server_default=sa.text("gen_random_uuid()"),
+        ),
         sa.Column("designation", sa.String(length=64), nullable=False, unique=True),
         sa.Column("operator_id", postgresql.UUID(as_uuid=True), nullable=True),
         sa.Column("home_station_id", postgresql.UUID(as_uuid=True), nullable=True),
         sa.Column("capacity", sa.Integer(), nullable=False),
         sa.Column("max_speed_kph", sa.Integer(), nullable=False),
         sa.Column("consist", sa.Text(), nullable=True),
-        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
         sa.ForeignKeyConstraint(["operator_id"], ["users.id"], ondelete="SET NULL"),
-        sa.ForeignKeyConstraint(["home_station_id"], ["stations.id"], ondelete="SET NULL"),
+        sa.ForeignKeyConstraint(
+            ["home_station_id"], ["stations.id"], ondelete="SET NULL"
+        ),
     )

     op.create_table(
         "train_schedules",
-        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
+        sa.Column(
+            "id",
+            postgresql.UUID(as_uuid=True),
+            primary_key=True,
+            server_default=sa.text("gen_random_uuid()"),
+        ),
         sa.Column("train_id", postgresql.UUID(as_uuid=True), nullable=False),
         sa.Column("sequence_index", sa.Integer(), nullable=False),
         sa.Column("station_id", postgresql.UUID(as_uuid=True), nullable=False),
         sa.Column("scheduled_arrival", sa.DateTime(timezone=True), nullable=True),
         sa.Column("scheduled_departure", sa.DateTime(timezone=True), nullable=True),
         sa.Column("dwell_seconds", sa.Integer(), nullable=True),
-        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
-        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("timezone('utc', now())"), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("timezone('utc', now())"),
+            nullable=False,
+        ),
         sa.ForeignKeyConstraint(["train_id"], ["trains.id"], ondelete="CASCADE"),
         sa.ForeignKeyConstraint(["station_id"], ["stations.id"], ondelete="CASCADE"),
-        sa.UniqueConstraint("train_id", "sequence_index", name="uq_train_schedule_sequence"),
+        sa.UniqueConstraint(
+            "train_id", "sequence_index", name="uq_train_schedule_sequence"
+        ),
     )


@@ -35,9 +35,7 @@ def test_me_endpoint_returns_current_user() -> None:
     )
     token = login.json()["accessToken"]

-    response = client.get(
-        "/api/auth/me", headers={"Authorization": f"Bearer {token}"}
-    )
+    response = client.get("/api/auth/me", headers={"Authorization": f"Bearer {token}"})
     assert response.status_code == 200
     assert response.json()["username"] == "demo"

@@ -20,9 +20,7 @@ def test_network_snapshot_requires_authentication() -> None:

 def test_network_snapshot_endpoint_returns_collections() -> None:
     token = _authenticate()
-    response = client.get(
-        "/api/network", headers={"Authorization": f"Bearer {token}"}
-    )
+    response = client.get("/api/network", headers={"Authorization": f"Bearer {token}"})
     assert response.status_code == 200

     payload = response.json()
@@ -40,7 +40,9 @@ def sample_entities() -> dict[str, SimpleNamespace]:
     return {"station": station, "track": track, "train": train}


-def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace]) -> None:
+def test_network_snapshot_prefers_repository_data(
+    monkeypatch: pytest.MonkeyPatch, sample_entities: dict[str, SimpleNamespace]
+) -> None:
     station = sample_entities["station"]
     track = sample_entities["track"]
     train = sample_entities["train"]
@@ -58,7 +60,9 @@ def test_network_snapshot_prefers_repository_data(monkeypatch: pytest.MonkeyPatc
     assert snapshot["trains"][0]["operatingTrackIds"] == []


-def test_network_snapshot_falls_back_when_repositories_empty(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_network_snapshot_falls_back_when_repositories_empty(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
     monkeypatch.setattr(StationRepository, "list_active", lambda self: [])
     monkeypatch.setattr(TrackRepository, "list_all", lambda self: [])
     monkeypatch.setattr(TrainRepository, "list_all", lambda self: [])
@@ -6,6 +6,7 @@ from typing import Any, List, cast
 from uuid import uuid4

 import pytest
+from sqlalchemy.orm import Session

 from backend.app.db.models import TrainSchedule, User
 from backend.app.db.unit_of_work import SqlAlchemyUnitOfWork
@@ -23,7 +24,6 @@ from backend.app.repositories import (
     TrainScheduleRepository,
     UserRepository,
 )
-from sqlalchemy.orm import Session


 @dataclass
@@ -50,7 +50,9 @@ class DummySession:
         self.statements.append(statement)
         return self.scalar_result

-    def flush(self, _objects: list[Any] | None = None) -> None:  # pragma: no cover - optional
+    def flush(
+        self, _objects: list[Any] | None = None
+    ) -> None:  # pragma: no cover - optional
         return None

     def commit(self) -> None:  # pragma: no cover - optional
@@ -215,9 +217,7 @@ def test_unit_of_work_commits_and_closes_session() -> None:
     uow = SqlAlchemyUnitOfWork(lambda: cast(Session, session))

     with uow as active:
-        active.users.create(
-            UserCreate(username="demo", password_hash="hashed")
-        )
+        active.users.create(UserCreate(username="demo", password_hash="hashed"))
         active.commit()

     assert session.committed