Files
calminer/middleware/auth_session.py
zwitschi 6e466a3fd2 Refactor database initialization and remove Alembic migrations
- Removed legacy Alembic migration files and consolidated schema management into a new Pydantic-backed initializer (`scripts/init_db.py`).
- Updated `main.py` to ensure the new DB initializer runs on startup, maintaining idempotency.
- Adjusted session management in `config/database.py` to prevent DetachedInstanceError.
- Introduced new enums in `models/enums.py` for better organization and clarity.
- Refactored various models to utilize the new enums, improving code maintainability.
- Enhanced middleware to handle JSON validation more robustly, ensuring non-JSON requests do not trigger JSON errors.
- Added tests for middleware and enums to ensure expected behavior and consistency.
- Updated changelog to reflect significant changes and improvements.
2025-11-12 16:29:44 +01:00

217 lines
7.3 KiB
Python

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Callable, Iterable, Optional

from fastapi import Request, Response
from sqlalchemy.orm.exc import DetachedInstanceError
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp

from config.settings import Settings, get_settings
from models import User
from monitoring.metrics import ACTIVE_CONNECTIONS
from services.exceptions import EntityNotFoundError
from services.security import (
    JWTSettings,
    TokenDecodeError,
    TokenError,
    TokenExpiredError,
    TokenTypeMismatchError,
    create_access_token,
    create_refresh_token,
    decode_access_token,
    decode_refresh_token,
)
from services.session import (
    AuthSession,
    SessionStrategy,
    SessionTokens,
    build_session_strategy,
    clear_session_cookies,
    extract_session_tokens,
    set_session_cookies,
)
from services.unit_of_work import UnitOfWork
# Scope an access token must carry for the middleware to accept it; it is also
# the default scope accepted when a session is renewed from a refresh token.
_AUTH_SCOPE = "auth"
@dataclass(slots=True)
class _ResolutionResult:
    """Bundle produced while resolving a request, reused when writing the response."""

    # Session state built from the request's tokens (user, scopes, cookie actions).
    session: AuthSession
    # Strategy used to extract the tokens; reused to set or clear the cookies.
    strategy: SessionStrategy
    # JWT configuration used to decode tokens; forwarded when cookies are set.
    jwt_settings: JWTSettings
class AuthSessionMiddleware(BaseHTTPMiddleware):
    """Resolve authenticated users from session cookies and refresh tokens.

    For each request the middleware extracts the session tokens, tries the
    access token first and falls back to the refresh token, stores the
    resulting ``AuthSession`` on ``request.state.auth_session`` and, once the
    downstream application has produced a response, sets or clears the
    session cookies accordingly.
    """

    # Number of in-flight requests made by active, authenticated users. Stored
    # on this class (not on instances) and mirrored into ACTIVE_CONNECTIONS.
    _active_sessions: int = 0

    def __init__(
        self,
        app: ASGIApp,
        *,
        settings_provider: Callable[[], Settings] = get_settings,
        unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
        refresh_scopes: Iterable[str] | None = None,
    ) -> None:
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            settings_provider: Callable returning the current ``Settings``.
            unit_of_work_factory: Callable returning a ``UnitOfWork`` context
                manager used to load users.
            refresh_scopes: Scopes that allow a refresh token to renew the
                session. ``None`` or an empty iterable falls back to
                ``(_AUTH_SCOPE,)``. A single string is treated as one scope.
        """
        super().__init__(app)
        self._settings_provider = settings_provider
        self._unit_of_work_factory = unit_of_work_factory

        if isinstance(refresh_scopes, str):
            # A bare string is one scope; tuple("auth") would split it into
            # single characters.
            scopes: tuple[str, ...] = (refresh_scopes,) if refresh_scopes else ()
        elif refresh_scopes:
            scopes = tuple(refresh_scopes)
        else:
            scopes = ()
        # Decide on the fallback only after materialising: a generator is
        # always truthy, even when it yields nothing.
        self._refresh_scopes = scopes or (_AUTH_SCOPE,)

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Resolve the session, run the downstream app and apply session cookies.

        Failures while applying cookies are logged and swallowed so that they
        never replace the downstream response. Exceptions raised by the
        downstream application propagate unchanged.
        """
        logger = logging.getLogger(__name__)
        resolved = self._resolve_session(request)

        # Only requests from active users count towards the gauge. Reading
        # attributes of a detached ORM instance can raise, in which case the
        # request is simply not counted.
        try:
            user = resolved.session.user
            user_active = bool(user and getattr(user, "is_active", False))
        except DetachedInstanceError:
            user_active = False

        if user_active:
            self._adjust_active_sessions(1)

        response: Response | None = None
        try:
            response = await call_next(request)
            return response
        finally:
            # Always undo the increment, even when downstream raised.
            if user_active:
                self._adjust_active_sessions(-1)

            # Cookies can only be written when downstream produced a response;
            # never raise from here, it would mask the original outcome.
            if response is not None:
                try:
                    self._apply_session(response, resolved)
                except Exception:
                    logger.exception(
                        "Failed to apply session cookies to response"
                    )
            else:
                logger.debug(
                    "AuthSessionMiddleware: no response produced by downstream app (response is None)"
                )

    @staticmethod
    def _adjust_active_sessions(delta: int) -> None:
        """Shift the shared active-session counter by *delta* and publish it.

        The counter is clamped at zero and always stored on
        ``AuthSessionMiddleware`` itself so subclasses share one value.
        """
        AuthSessionMiddleware._active_sessions = max(
            0, AuthSessionMiddleware._active_sessions + delta
        )
        ACTIVE_CONNECTIONS.set(AuthSessionMiddleware._active_sessions)

    def _resolve_session(self, request: Request) -> _ResolutionResult:
        """Build the ``AuthSession`` for *request* and attach it to ``request.state``.

        The access token is tried first; the refresh token is used only when
        the access token is missing or was not accepted.
        """
        settings = self._settings_provider()
        jwt_settings = settings.jwt_settings()
        strategy = build_session_strategy(settings.session_settings())
        tokens = extract_session_tokens(request, strategy)

        session = AuthSession(tokens=tokens)
        request.state.auth_session = session

        authenticated = bool(tokens.access_token) and self._try_access_token(
            session, tokens, jwt_settings
        )
        if not authenticated and tokens.refresh_token:
            self._try_refresh_token(session, tokens.refresh_token, jwt_settings)

        return _ResolutionResult(
            session=session, strategy=strategy, jwt_settings=jwt_settings
        )

    def _try_access_token(
        self,
        session: AuthSession,
        tokens: SessionTokens,
        jwt_settings: JWTSettings,
    ) -> bool:
        """Authenticate *session* from the access token.

        Returns:
            ``True`` when the token decoded, the user exists and is active and
            the token carries ``_AUTH_SCOPE``; ``False`` otherwise.
        """
        try:
            payload = decode_access_token(
                tokens.access_token or "", jwt_settings)
        except TokenExpiredError:
            # Expired is not treated as invalid: leave the cookies alone so
            # the caller can still try the refresh token.
            return False
        except (TokenDecodeError, TokenTypeMismatchError, TokenError):
            session.mark_cleared()
            return False

        user = self._load_user(payload.sub)
        if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
            session.mark_cleared()
            return False

        session.user = user
        session.scopes = tuple(payload.scopes)
        return True

    def _try_refresh_token(
        self,
        session: AuthSession,
        refresh_token: str,
        jwt_settings: JWTSettings,
    ) -> None:
        """Authenticate *session* from the refresh token and issue new tokens.

        On any failure the session is marked as cleared. On success the user
        and scopes are stored and a fresh access/refresh token pair is issued
        on the session with the same scopes.
        """
        try:
            payload = decode_refresh_token(refresh_token, jwt_settings)
        except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
            session.mark_cleared()
            return

        user = self._load_user(payload.sub)
        if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
            session.mark_cleared()
            return

        session.user = user
        session.scopes = tuple(payload.scopes)

        subject = str(user.id)
        access_token = create_access_token(
            subject,
            jwt_settings,
            scopes=payload.scopes,
        )
        new_refresh = create_refresh_token(
            subject,
            jwt_settings,
            scopes=payload.scopes,
        )
        # NOTE(review): when an invalid access token was rejected just before,
        # mark_cleared() has already been called on this session. Confirm that
        # issue_tokens() resets that flag; otherwise _apply_session clears the
        # cookies despite the successful refresh.
        session.issue_tokens(access_token=access_token,
                             refresh_token=new_refresh)

    def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
        """Return ``True`` when *scopes* shares at least one configured refresh scope."""
        return not set(scopes).isdisjoint(self._refresh_scopes)

    def _load_user(self, subject: str) -> Optional[User]:
        """Load the user identified by a token subject.

        Returns:
            The user (fetched with roles), or ``None`` when the subject is not
            an integer id, no user repository is available, or no such user
            exists.
        """
        try:
            user_id = int(subject)
        except (TypeError, ValueError):
            # TypeError covers a missing or non-string subject (e.g. None).
            return None

        with self._unit_of_work_factory() as uow:
            if not uow.users:
                return None
            try:
                return uow.users.get(user_id, with_roles=True)
            except EntityNotFoundError:
                return None

    def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
        """Write the session outcome to *response* cookies.

        Clearing takes precedence over issuing. When new tokens were issued
        and no new refresh token exists, the refresh token from the request is
        reused.
        """
        session = resolved.session

        if session.clear_cookies:
            clear_session_cookies(response, resolved.strategy)
            return

        if not session.issued_access_token:
            return

        refresh_token = session.issued_refresh_token or session.tokens.refresh_token
        set_session_cookies(
            response,
            access_token=session.issued_access_token,
            refresh_token=refresh_token,
            strategy=resolved.strategy,
            jwt_settings=resolved.jwt_settings,
        )