from __future__ import annotations from dataclasses import dataclass from typing import Callable, Iterable, Optional from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.types import ASGIApp from config.settings import Settings, get_settings from sqlalchemy.orm.exc import DetachedInstanceError from models import User from monitoring.metrics import ACTIVE_CONNECTIONS from services.exceptions import EntityNotFoundError from services.security import ( JWTSettings, TokenDecodeError, TokenError, TokenExpiredError, TokenTypeMismatchError, create_access_token, create_refresh_token, decode_access_token, decode_refresh_token, ) from services.session import ( AuthSession, SessionStrategy, SessionTokens, build_session_strategy, clear_session_cookies, extract_session_tokens, set_session_cookies, ) from services.unit_of_work import UnitOfWork _AUTH_SCOPE = "auth" @dataclass(slots=True) class _ResolutionResult: session: AuthSession strategy: SessionStrategy jwt_settings: JWTSettings class AuthSessionMiddleware(BaseHTTPMiddleware): """Resolve authenticated users from session cookies and refresh tokens.""" _active_sessions: int = 0 def __init__( self, app: ASGIApp, *, settings_provider: Callable[[], Settings] = get_settings, unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork, refresh_scopes: Iterable[str] | None = None, ) -> None: super().__init__(app) self._settings_provider = settings_provider self._unit_of_work_factory = unit_of_work_factory self._refresh_scopes = tuple( refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: resolved = self._resolve_session(request) # Track active sessions for authenticated users try: user_active = bool(resolved.session.user and getattr( resolved.session.user, "is_active", False)) except DetachedInstanceError: user_active = False if user_active: AuthSessionMiddleware._active_sessions += 1 ACTIVE_CONNECTIONS.set(AuthSessionMiddleware._active_sessions) response: Response | None = None try: response = await call_next(request) return response finally: # Always decrement the active sessions counter if we incremented it. if user_active: AuthSessionMiddleware._active_sessions = max( 0, AuthSessionMiddleware._active_sessions - 1) ACTIVE_CONNECTIONS.set(AuthSessionMiddleware._active_sessions) # Only apply session cookies if a response was produced by downstream # application. If an exception occurred before a response was created # we avoid raising another error here. import logging if response is not None: try: self._apply_session(response, resolved) except Exception: logging.getLogger(__name__).exception( "Failed to apply session cookies to response" ) else: logging.getLogger(__name__).debug( "AuthSessionMiddleware: no response produced by downstream app (response is None)" ) def _resolve_session(self, request: Request) -> _ResolutionResult: settings = self._settings_provider() jwt_settings = settings.jwt_settings() strategy = build_session_strategy(settings.session_settings()) tokens = extract_session_tokens(request, strategy) session = AuthSession(tokens=tokens) request.state.auth_session = session if tokens.access_token: if self._try_access_token(session, tokens, jwt_settings): return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) if tokens.refresh_token: self._try_refresh_token( session, tokens.refresh_token, jwt_settings) return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) def _try_access_token( self, session: AuthSession, tokens: SessionTokens, jwt_settings: JWTSettings, ) -> bool: try: payload = decode_access_token( tokens.access_token or "", jwt_settings) except TokenExpiredError: return False except (TokenDecodeError, TokenTypeMismatchError, TokenError): session.mark_cleared() return False user = self._load_user(payload.sub) if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes: session.mark_cleared() return False session.user = user session.scopes = tuple(payload.scopes) session.set_role_slugs(role.name for role in getattr(user, "roles", []) if role) return True def _try_refresh_token( self, session: AuthSession, refresh_token: str, jwt_settings: JWTSettings, ) -> None: try: payload = decode_refresh_token(refresh_token, jwt_settings) except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError): session.mark_cleared() return user = self._load_user(payload.sub) if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes): session.mark_cleared() return session.user = user session.scopes = tuple(payload.scopes) session.set_role_slugs(role.name for role in getattr(user, "roles", []) if role) access_token = create_access_token( str(user.id), jwt_settings, scopes=payload.scopes, ) new_refresh = create_refresh_token( str(user.id), jwt_settings, scopes=payload.scopes, ) session.issue_tokens(access_token=access_token, refresh_token=new_refresh) def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool: candidate_scopes = set(scopes) return any(scope in candidate_scopes for scope in self._refresh_scopes) def _load_user(self, subject: str) -> Optional[User]: try: user_id = int(subject) except ValueError: return None with self._unit_of_work_factory() as uow: if not uow.users: return None try: user = uow.users.get(user_id, with_roles=True) except EntityNotFoundError: return None return user def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None: session = resolved.session if session.clear_cookies: clear_session_cookies(response, resolved.strategy) return if session.issued_access_token: refresh_token = session.issued_refresh_token or session.tokens.refresh_token set_session_cookies( response, access_token=session.issued_access_token, refresh_token=refresh_token, strategy=resolved.strategy, jwt_settings=resolved.jwt_settings, )