feat: Implement session management with middleware and update authentication flow
This commit is contained in:
177
middleware/auth_session.py
Normal file
177
middleware/auth_session.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import User
|
||||
from services.exceptions import EntityNotFoundError
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
)
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
SessionTokens,
|
||||
build_session_strategy,
|
||||
clear_session_cookies,
|
||||
extract_session_tokens,
|
||||
set_session_cookies,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
_AUTH_SCOPE = "auth"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ResolutionResult:
|
||||
session: AuthSession
|
||||
strategy: SessionStrategy
|
||||
jwt_settings: JWTSettings
|
||||
|
||||
|
||||
class AuthSessionMiddleware(BaseHTTPMiddleware):
|
||||
"""Resolve authenticated users from session cookies and refresh tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
settings_provider: Callable[[], Settings] = get_settings,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
|
||||
refresh_scopes: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(app)
|
||||
self._settings_provider = settings_provider
|
||||
self._unit_of_work_factory = unit_of_work_factory
|
||||
self._refresh_scopes = tuple(
|
||||
refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
resolved = self._resolve_session(request)
|
||||
response = await call_next(request)
|
||||
self._apply_session(response, resolved)
|
||||
return response
|
||||
|
||||
def _resolve_session(self, request: Request) -> _ResolutionResult:
|
||||
settings = self._settings_provider()
|
||||
jwt_settings = settings.jwt_settings()
|
||||
strategy = build_session_strategy(settings.session_settings())
|
||||
|
||||
tokens = extract_session_tokens(request, strategy)
|
||||
session = AuthSession(tokens=tokens)
|
||||
request.state.auth_session = session
|
||||
|
||||
if tokens.access_token:
|
||||
if self._try_access_token(session, tokens, jwt_settings):
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
if tokens.refresh_token:
|
||||
self._try_refresh_token(
|
||||
session, tokens.refresh_token, jwt_settings)
|
||||
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
def _try_access_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
tokens: SessionTokens,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> bool:
|
||||
try:
|
||||
payload = decode_access_token(
|
||||
tokens.access_token or "", jwt_settings)
|
||||
except TokenExpiredError:
|
||||
return False
|
||||
except (TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
return True
|
||||
|
||||
def _try_refresh_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
refresh_token: str,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
try:
|
||||
payload = decode_refresh_token(refresh_token, jwt_settings)
|
||||
except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
|
||||
access_token = create_access_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
new_refresh = create_refresh_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
session.issue_tokens(access_token=access_token,
|
||||
refresh_token=new_refresh)
|
||||
|
||||
def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
|
||||
candidate_scopes = set(scopes)
|
||||
return any(scope in candidate_scopes for scope in self._refresh_scopes)
|
||||
|
||||
def _load_user(self, subject: str) -> Optional[User]:
|
||||
try:
|
||||
user_id = int(subject)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
with self._unit_of_work_factory() as uow:
|
||||
if not uow.users:
|
||||
return None
|
||||
try:
|
||||
user = uow.users.get(user_id, with_roles=True)
|
||||
except EntityNotFoundError:
|
||||
return None
|
||||
return user
|
||||
|
||||
def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
|
||||
session = resolved.session
|
||||
if session.clear_cookies:
|
||||
clear_session_cookies(response, resolved.strategy)
|
||||
return
|
||||
|
||||
if session.issued_access_token:
|
||||
refresh_token = session.issued_refresh_token or session.tokens.refresh_token
|
||||
set_session_cookies(
|
||||
response,
|
||||
access_token=session.issued_access_token,
|
||||
refresh_token=refresh_token,
|
||||
strategy=resolved.strategy,
|
||||
jwt_settings=resolved.jwt_settings,
|
||||
)
|
||||
Reference in New Issue
Block a user