"""Starlette middleware that resolves the authenticated user for each request.

Session tokens are extracted from the incoming request, validated, and — when
only the refresh token is usable — rotated. Resulting cookie changes (new
tokens or clearing) are applied to the outgoing response.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Iterable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp

from config.settings import Settings, get_settings
from models import User
from services.exceptions import EntityNotFoundError
from services.security import (
    JWTSettings,
    TokenDecodeError,
    TokenError,
    TokenExpiredError,
    TokenTypeMismatchError,
    create_access_token,
    create_refresh_token,
    decode_access_token,
    decode_refresh_token,
)
from services.session import (
    AuthSession,
    SessionStrategy,
    SessionTokens,
    build_session_strategy,
    clear_session_cookies,
    extract_session_tokens,
    set_session_cookies,
)
from services.unit_of_work import UnitOfWork

# Scope that authorizes an access token; also the default refresh scope.
_AUTH_SCOPE = "auth"


@dataclass(slots=True)
class _ResolutionResult:
    """State carried from session resolution to response post-processing."""

    session: AuthSession  # per-request auth state (also on request.state)
    strategy: SessionStrategy  # cookie strategy built from session settings
    jwt_settings: JWTSettings  # settings used to decode and issue tokens


class AuthSessionMiddleware(BaseHTTPMiddleware):
    """Resolve authenticated users from session cookies and refresh tokens.

    For every request an ``AuthSession`` is created and stored on
    ``request.state.auth_session``. A valid access token authenticates the
    request directly; otherwise a refresh token, if present, is used to load
    the user and issue a fresh access/refresh token pair. After the downstream
    handler runs, cookies are cleared or set on the response accordingly.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        settings_provider: Callable[[], Settings] = get_settings,
        unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
        refresh_scopes: Iterable[str] | None = None,
    ) -> None:
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            settings_provider: Called on each request to obtain settings.
            unit_of_work_factory: Creates the unit of work used to load users.
            refresh_scopes: Scopes of which a refresh token must carry at least
                one to be accepted. A single string is treated as one scope;
                ``None`` or an empty collection falls back to ``("auth",)``.
        """
        super().__init__(app)
        self._settings_provider = settings_provider
        self._unit_of_work_factory = unit_of_work_factory
        self._refresh_scopes = self._normalize_scopes(refresh_scopes)

    @staticmethod
    def _normalize_scopes(scopes: Iterable[str] | None) -> tuple[str, ...]:
        """Return *scopes* as a non-empty tuple, defaulting to the auth scope."""
        if isinstance(scopes, str):
            # tuple("auth") would split the string into single characters.
            scopes = (scopes,) if scopes else None
        normalized = tuple(scopes) if scopes is not None else ()
        return normalized or (_AUTH_SCOPE,)

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Resolve the session, run the handler, then apply cookie changes."""
        # NOTE(review): _resolve_session runs synchronously on the event loop,
        # including the user lookup through the unit of work; consider
        # offloading it to a thread pool if that lookup performs blocking I/O.
        resolved = self._resolve_session(request)
        response = await call_next(request)
        self._apply_session(response, resolved)
        return response

    def _resolve_session(self, request: Request) -> _ResolutionResult:
        """Build the request's ``AuthSession`` from its access/refresh tokens."""
        settings = self._settings_provider()
        jwt_settings = settings.jwt_settings()
        strategy = build_session_strategy(settings.session_settings())
        tokens = extract_session_tokens(request, strategy)

        session = AuthSession(tokens=tokens)
        request.state.auth_session = session
        result = _ResolutionResult(
            session=session, strategy=strategy, jwt_settings=jwt_settings
        )

        if tokens.access_token and self._try_access_token(
            session, tokens, jwt_settings
        ):
            return result

        # Access token missing, expired, or rejected: fall back to refresh.
        # NOTE(review): a rejected access token has already triggered
        # session.mark_cleared(); confirm whether a subsequent successful
        # refresh is meant to override that, since _apply_session checks
        # clear_cookies before looking at issued tokens.
        if tokens.refresh_token:
            self._try_refresh_token(session, tokens.refresh_token, jwt_settings)
        return result

    def _try_access_token(
        self,
        session: AuthSession,
        tokens: SessionTokens,
        jwt_settings: JWTSettings,
    ) -> bool:
        """Authenticate *session* from the access token.

        Returns:
            True if the session was authenticated. An expired token returns
            False without clearing so the refresh token can still be tried;
            any other failure calls ``session.mark_cleared()`` and returns
            False.
        """
        try:
            payload = decode_access_token(tokens.access_token or "", jwt_settings)
        except TokenExpiredError:
            return False
        except (TokenDecodeError, TokenTypeMismatchError, TokenError):
            session.mark_cleared()
            return False

        user = self._load_user(payload.sub)
        if not user or not user.is_active or not self._is_access_scope_allowed(
            payload.scopes
        ):
            session.mark_cleared()
            return False

        session.user = user
        session.scopes = tuple(payload.scopes)
        return True

    def _try_refresh_token(
        self,
        session: AuthSession,
        refresh_token: str,
        jwt_settings: JWTSettings,
    ) -> None:
        """Authenticate *session* from the refresh token and rotate tokens.

        On success the user and scopes are set and a new access/refresh pair
        (carrying the refresh token's scopes) is recorded on the session via
        ``issue_tokens``. On any failure ``session.mark_cleared()`` is called.
        """
        try:
            payload = decode_refresh_token(refresh_token, jwt_settings)
        except (
            TokenExpiredError,
            TokenDecodeError,
            TokenTypeMismatchError,
            TokenError,
        ):
            session.mark_cleared()
            return

        user = self._load_user(payload.sub)
        if not user or not user.is_active or not self._is_refresh_scope_allowed(
            payload.scopes
        ):
            session.mark_cleared()
            return

        session.user = user
        session.scopes = tuple(payload.scopes)
        subject = str(user.id)
        access_token = create_access_token(
            subject, jwt_settings, scopes=payload.scopes
        )
        new_refresh = create_refresh_token(
            subject, jwt_settings, scopes=payload.scopes
        )
        session.issue_tokens(access_token=access_token, refresh_token=new_refresh)

    def _is_access_scope_allowed(self, scopes: Iterable[str]) -> bool:
        """Return True if an access token with *scopes* may authenticate.

        Accepts the auth scope, plus any scope that refresh tokens are
        allowed to carry: access tokens minted by ``_try_refresh_token``
        inherit the refresh token's scopes, and rejecting them here would
        clear the session on the very next request.
        """
        candidate_scopes = set(scopes)
        return _AUTH_SCOPE in candidate_scopes or self._is_refresh_scope_allowed(
            candidate_scopes
        )

    def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
        """Return True if *scopes* contains at least one configured refresh scope."""
        candidate_scopes = set(scopes)
        return any(scope in candidate_scopes for scope in self._refresh_scopes)

    def _load_user(self, subject: str) -> Optional[User]:
        """Load the user (with roles) whose integer id is *subject*.

        Returns:
            The user, or None if *subject* is not an integer id, the unit of
            work has no user repository, or no such user exists.
        """
        try:
            user_id = int(subject)
        except (TypeError, ValueError):
            # TypeError covers a missing/None subject claim.
            return None

        with self._unit_of_work_factory() as uow:
            if not uow.users:
                return None
            try:
                return uow.users.get(user_id, with_roles=True)
            except EntityNotFoundError:
                return None

    def _apply_session(
        self, response: Response, resolved: _ResolutionResult
    ) -> None:
        """Clear or set session cookies on *response* based on the session.

        Clearing takes precedence. Otherwise, if a new access token was
        issued, cookies are set with it and the newly issued refresh token
        (falling back to the request's original refresh token).
        """
        session = resolved.session
        if session.clear_cookies:
            clear_session_cookies(response, resolved.strategy)
            return

        if session.issued_access_token:
            refresh_token = (
                session.issued_refresh_token or session.tokens.refresh_token
            )
            set_session_cookies(
                response,
                access_token=session.issued_access_token,
                refresh_token=refresh_token,
                strategy=resolved.strategy,
                jwt_settings=resolved.jwt_settings,
            )