feat: Implement session management with middleware and update authentication flow
This commit is contained in:
@@ -5,9 +5,25 @@ from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from services.security import JWTSettings
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SessionSettings:
|
||||
"""Cookie and header configuration for session token transport."""
|
||||
|
||||
access_cookie_name: str
|
||||
refresh_cookie_name: str
|
||||
cookie_secure: bool
|
||||
cookie_domain: Optional[str]
|
||||
cookie_path: str
|
||||
header_name: str
|
||||
header_prefix: str
|
||||
allow_header_fallback: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Settings:
|
||||
"""Application configuration sourced from environment variables."""
|
||||
@@ -16,6 +32,14 @@ class Settings:
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_minutes: int = 15
|
||||
jwt_refresh_token_days: int = 7
|
||||
session_access_cookie_name: str = "calminer_access_token"
|
||||
session_refresh_cookie_name: str = "calminer_refresh_token"
|
||||
session_cookie_secure: bool = False
|
||||
session_cookie_domain: Optional[str] = None
|
||||
session_cookie_path: str = "/"
|
||||
session_header_name: str = "Authorization"
|
||||
session_header_prefix: str = "Bearer"
|
||||
session_allow_header_fallback: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_environment(cls) -> "Settings":
|
||||
@@ -30,6 +54,26 @@ class Settings:
|
||||
jwt_refresh_token_days=cls._int_from_env(
|
||||
"CALMINER_JWT_REFRESH_DAYS", 7
|
||||
),
|
||||
session_access_cookie_name=os.getenv(
|
||||
"CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token"
|
||||
),
|
||||
session_refresh_cookie_name=os.getenv(
|
||||
"CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token"
|
||||
),
|
||||
session_cookie_secure=cls._bool_from_env(
|
||||
"CALMINER_SESSION_COOKIE_SECURE", False
|
||||
),
|
||||
session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"),
|
||||
session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"),
|
||||
session_header_name=os.getenv(
|
||||
"CALMINER_SESSION_HEADER_NAME", "Authorization"
|
||||
),
|
||||
session_header_prefix=os.getenv(
|
||||
"CALMINER_SESSION_HEADER_PREFIX", "Bearer"
|
||||
),
|
||||
session_allow_header_fallback=cls._bool_from_env(
|
||||
"CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -42,6 +86,18 @@ class Settings:
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _bool_from_env(name: str, default: bool) -> bool:
|
||||
raw_value = os.getenv(name)
|
||||
if raw_value is None:
|
||||
return default
|
||||
lowered = raw_value.strip().lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
def jwt_settings(self) -> JWTSettings:
|
||||
"""Build runtime JWT settings compatible with token helpers."""
|
||||
|
||||
@@ -52,6 +108,20 @@ class Settings:
|
||||
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
|
||||
)
|
||||
|
||||
def session_settings(self) -> SessionSettings:
|
||||
"""Provide transport configuration for session tokens."""
|
||||
|
||||
return SessionSettings(
|
||||
access_cookie_name=self.session_access_cookie_name,
|
||||
refresh_cookie_name=self.session_refresh_cookie_name,
|
||||
cookie_secure=self.session_cookie_secure,
|
||||
cookie_domain=self.session_cookie_domain,
|
||||
cookie_path=self.session_cookie_path,
|
||||
header_name=self.session_header_name,
|
||||
header_prefix=self.session_header_prefix,
|
||||
allow_header_fallback=self.session_allow_header_fallback,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
|
||||
@@ -2,8 +2,18 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import User
|
||||
from services.security import JWTSettings
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
SessionTokens,
|
||||
build_session_strategy,
|
||||
extract_session_tokens,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@@ -24,3 +34,65 @@ def get_jwt_settings() -> JWTSettings:
|
||||
"""Provide JWT runtime configuration derived from settings."""
|
||||
|
||||
return get_settings().jwt_settings()
|
||||
|
||||
|
||||
def get_session_strategy(
|
||||
settings: Settings = Depends(get_application_settings),
|
||||
) -> SessionStrategy:
|
||||
"""Yield configured session transport strategy."""
|
||||
|
||||
return build_session_strategy(settings.session_settings())
|
||||
|
||||
|
||||
def get_session_tokens(
|
||||
request: Request,
|
||||
strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
) -> SessionTokens:
|
||||
"""Extract raw session tokens from the incoming request."""
|
||||
|
||||
existing = getattr(request.state, "auth_session", None)
|
||||
if isinstance(existing, AuthSession):
|
||||
return existing.tokens
|
||||
|
||||
tokens = extract_session_tokens(request, strategy)
|
||||
request.state.auth_session = AuthSession(tokens=tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def get_auth_session(
|
||||
request: Request,
|
||||
tokens: SessionTokens = Depends(get_session_tokens),
|
||||
) -> AuthSession:
|
||||
"""Provide authentication session context for the current request."""
|
||||
|
||||
existing = getattr(request.state, "auth_session", None)
|
||||
if isinstance(existing, AuthSession):
|
||||
return existing
|
||||
|
||||
if tokens.is_empty:
|
||||
session = AuthSession.anonymous()
|
||||
else:
|
||||
session = AuthSession(tokens=tokens)
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
|
||||
def get_current_user(
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
) -> User | None:
|
||||
"""Return the current authenticated user if present."""
|
||||
|
||||
return session.user
|
||||
|
||||
|
||||
def require_current_user(
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
) -> User:
|
||||
"""Ensure that a request is authenticated and return the user context."""
|
||||
|
||||
if session.user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required.",
|
||||
)
|
||||
return session.user
|
||||
|
||||
6
main.py
6
main.py
@@ -2,8 +2,10 @@ from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from middleware.validation import validate_json
|
||||
|
||||
from config.database import Base, engine
|
||||
from middleware.auth_session import AuthSessionMiddleware
|
||||
from middleware.validation import validate_json
|
||||
from models import (
|
||||
FinancialInput,
|
||||
Project,
|
||||
@@ -20,6 +22,8 @@ Base.metadata.create_all(bind=engine)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(AuthSessionMiddleware)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def json_validation(
|
||||
|
||||
177
middleware/auth_session.py
Normal file
177
middleware/auth_session.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import User
|
||||
from services.exceptions import EntityNotFoundError
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
)
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
SessionTokens,
|
||||
build_session_strategy,
|
||||
clear_session_cookies,
|
||||
extract_session_tokens,
|
||||
set_session_cookies,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
_AUTH_SCOPE = "auth"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ResolutionResult:
|
||||
session: AuthSession
|
||||
strategy: SessionStrategy
|
||||
jwt_settings: JWTSettings
|
||||
|
||||
|
||||
class AuthSessionMiddleware(BaseHTTPMiddleware):
|
||||
"""Resolve authenticated users from session cookies and refresh tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
settings_provider: Callable[[], Settings] = get_settings,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
|
||||
refresh_scopes: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(app)
|
||||
self._settings_provider = settings_provider
|
||||
self._unit_of_work_factory = unit_of_work_factory
|
||||
self._refresh_scopes = tuple(
|
||||
refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
resolved = self._resolve_session(request)
|
||||
response = await call_next(request)
|
||||
self._apply_session(response, resolved)
|
||||
return response
|
||||
|
||||
def _resolve_session(self, request: Request) -> _ResolutionResult:
|
||||
settings = self._settings_provider()
|
||||
jwt_settings = settings.jwt_settings()
|
||||
strategy = build_session_strategy(settings.session_settings())
|
||||
|
||||
tokens = extract_session_tokens(request, strategy)
|
||||
session = AuthSession(tokens=tokens)
|
||||
request.state.auth_session = session
|
||||
|
||||
if tokens.access_token:
|
||||
if self._try_access_token(session, tokens, jwt_settings):
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
if tokens.refresh_token:
|
||||
self._try_refresh_token(
|
||||
session, tokens.refresh_token, jwt_settings)
|
||||
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
def _try_access_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
tokens: SessionTokens,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> bool:
|
||||
try:
|
||||
payload = decode_access_token(
|
||||
tokens.access_token or "", jwt_settings)
|
||||
except TokenExpiredError:
|
||||
return False
|
||||
except (TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
return True
|
||||
|
||||
def _try_refresh_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
refresh_token: str,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
try:
|
||||
payload = decode_refresh_token(refresh_token, jwt_settings)
|
||||
except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
|
||||
access_token = create_access_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
new_refresh = create_refresh_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
session.issue_tokens(access_token=access_token,
|
||||
refresh_token=new_refresh)
|
||||
|
||||
def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
|
||||
candidate_scopes = set(scopes)
|
||||
return any(scope in candidate_scopes for scope in self._refresh_scopes)
|
||||
|
||||
def _load_user(self, subject: str) -> Optional[User]:
|
||||
try:
|
||||
user_id = int(subject)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
with self._unit_of_work_factory() as uow:
|
||||
if not uow.users:
|
||||
return None
|
||||
try:
|
||||
user = uow.users.get(user_id, with_roles=True)
|
||||
except EntityNotFoundError:
|
||||
return None
|
||||
return user
|
||||
|
||||
def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
|
||||
session = resolved.session
|
||||
if session.clear_cookies:
|
||||
clear_session_cookies(response, resolved.strategy)
|
||||
return
|
||||
|
||||
if session.issued_access_token:
|
||||
refresh_token = session.issued_refresh_token or session.tokens.refresh_token
|
||||
set_session_cookies(
|
||||
response,
|
||||
access_token=session.issued_access_token,
|
||||
refresh_token=refresh_token,
|
||||
strategy=resolved.strategy,
|
||||
jwt_settings=resolved.jwt_settings,
|
||||
)
|
||||
@@ -9,7 +9,13 @@ from fastapi.templating import Jinja2Templates
|
||||
from pydantic import ValidationError
|
||||
from starlette.datastructures import FormData
|
||||
|
||||
from dependencies import get_jwt_settings, get_unit_of_work
|
||||
from dependencies import (
|
||||
get_auth_session,
|
||||
get_jwt_settings,
|
||||
get_session_strategy,
|
||||
get_unit_of_work,
|
||||
require_current_user,
|
||||
)
|
||||
from models import Role, User
|
||||
from schemas.auth import (
|
||||
LoginForm,
|
||||
@@ -29,6 +35,12 @@ from services.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
clear_session_cookies,
|
||||
set_session_cookies,
|
||||
)
|
||||
from services.repositories import RoleRepository, UserRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
@@ -103,6 +115,7 @@ async def login_submit(
|
||||
request: Request,
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
):
|
||||
form_data = _normalise_form_data(await request.form())
|
||||
try:
|
||||
@@ -158,7 +171,31 @@ async def login_submit(
|
||||
request.url_for("dashboard.home"),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
_set_auth_cookies(response, access_token, refresh_token, jwt_settings)
|
||||
set_session_cookies(
|
||||
response,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
strategy=session_strategy,
|
||||
jwt_settings=jwt_settings,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/logout", include_in_schema=False, name="auth.logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
_: User = Depends(require_current_user),
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
) -> RedirectResponse:
|
||||
session.mark_cleared()
|
||||
redirect_url = request.url_for(
|
||||
"auth.login_form").include_query_params(logout="1")
|
||||
response = RedirectResponse(
|
||||
redirect_url,
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
clear_session_cookies(response, session_strategy)
|
||||
return response
|
||||
|
||||
|
||||
@@ -168,32 +205,6 @@ def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
|
||||
return users_repo.get_by_username(identifier, with_roles=True)
|
||||
|
||||
|
||||
def _set_auth_cookies(
|
||||
response: RedirectResponse,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
||||
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
||||
response.set_cookie(
|
||||
"calminer_access_token",
|
||||
access_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="lax",
|
||||
max_age=max(access_ttl, 0) or None,
|
||||
)
|
||||
response.set_cookie(
|
||||
"calminer_refresh_token",
|
||||
refresh_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="lax",
|
||||
max_age=max(refresh_ttl, 0) or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
|
||||
def register_form(request: Request) -> HTMLResponse:
|
||||
return _template(
|
||||
|
||||
192
services/session.py
Normal file
192
services/session.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from config.settings import SessionSettings
|
||||
from services.security import JWTSettings
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static typing
|
||||
from models import User
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionStrategy:
|
||||
"""Describe how authentication tokens are transported with requests."""
|
||||
|
||||
access_cookie_name: str
|
||||
refresh_cookie_name: str
|
||||
cookie_secure: bool
|
||||
cookie_domain: Optional[str]
|
||||
cookie_path: str
|
||||
header_name: str
|
||||
header_prefix: str
|
||||
allow_header_fallback: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls, settings: SessionSettings) -> "SessionStrategy":
|
||||
return cls(
|
||||
access_cookie_name=settings.access_cookie_name,
|
||||
refresh_cookie_name=settings.refresh_cookie_name,
|
||||
cookie_secure=settings.cookie_secure,
|
||||
cookie_domain=settings.cookie_domain,
|
||||
cookie_path=settings.cookie_path,
|
||||
header_name=settings.header_name,
|
||||
header_prefix=settings.header_prefix,
|
||||
allow_header_fallback=settings.allow_header_fallback,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionTokens:
|
||||
"""Raw access and refresh tokens extracted from the transport layer."""
|
||||
|
||||
access_token: Optional[str]
|
||||
refresh_token: Optional[str]
|
||||
access_token_source: Literal["cookie", "header", "none"] = "none"
|
||||
|
||||
@property
|
||||
def has_access(self) -> bool:
|
||||
return bool(self.access_token)
|
||||
|
||||
@property
|
||||
def has_refresh(self) -> bool:
|
||||
return bool(self.refresh_token)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self.has_access and not self.has_refresh
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AuthSession:
|
||||
"""Holds authenticated user context resolved from session tokens."""
|
||||
|
||||
tokens: SessionTokens
|
||||
user: Optional["User"] = None
|
||||
scopes: tuple[str, ...] = ()
|
||||
issued_access_token: Optional[str] = None
|
||||
issued_refresh_token: Optional[str] = None
|
||||
clear_cookies: bool = False
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return self.user is not None
|
||||
|
||||
@classmethod
|
||||
def anonymous(cls) -> "AuthSession":
|
||||
return cls(tokens=SessionTokens(access_token=None, refresh_token=None))
|
||||
|
||||
def issue_tokens(
|
||||
self,
|
||||
*,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str] = None,
|
||||
access_source: Literal["cookie", "header", "none"] = "cookie",
|
||||
) -> None:
|
||||
self.issued_access_token = access_token
|
||||
if refresh_token is not None:
|
||||
self.issued_refresh_token = refresh_token
|
||||
self.tokens = SessionTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token,
|
||||
access_token_source=access_source,
|
||||
)
|
||||
|
||||
def mark_cleared(self) -> None:
|
||||
self.clear_cookies = True
|
||||
self.tokens = SessionTokens(access_token=None, refresh_token=None)
|
||||
self.user = None
|
||||
self.scopes = ()
|
||||
|
||||
|
||||
def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens:
|
||||
"""Pull tokens from cookies or headers according to configured strategy."""
|
||||
|
||||
access_token: Optional[str] = None
|
||||
refresh_token: Optional[str] = None
|
||||
access_source: Literal["cookie", "header", "none"] = "none"
|
||||
|
||||
if strategy.access_cookie_name in request.cookies:
|
||||
access_token = request.cookies.get(strategy.access_cookie_name) or None
|
||||
if access_token:
|
||||
access_source = "cookie"
|
||||
|
||||
if strategy.refresh_cookie_name in request.cookies:
|
||||
refresh_token = request.cookies.get(
|
||||
strategy.refresh_cookie_name) or None
|
||||
|
||||
if not access_token and strategy.allow_header_fallback:
|
||||
header_value = request.headers.get(strategy.header_name)
|
||||
if header_value:
|
||||
candidate = header_value.strip()
|
||||
prefix = f"{strategy.header_prefix} " if strategy.header_prefix else ""
|
||||
if prefix and candidate.lower().startswith(prefix.lower()):
|
||||
candidate = candidate[len(prefix):].strip()
|
||||
if candidate:
|
||||
access_token = candidate
|
||||
access_source = "header"
|
||||
|
||||
return SessionTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
access_token_source=access_source,
|
||||
)
|
||||
|
||||
|
||||
def build_session_strategy(settings: SessionSettings) -> SessionStrategy:
|
||||
"""Create a session strategy object from settings configuration."""
|
||||
|
||||
return SessionStrategy.from_settings(settings)
|
||||
|
||||
|
||||
def set_session_cookies(
|
||||
response: Response,
|
||||
*,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str],
|
||||
strategy: SessionStrategy,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
"""Persist session cookies on an outgoing response."""
|
||||
|
||||
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
||||
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
||||
response.set_cookie(
|
||||
strategy.access_cookie_name,
|
||||
access_token,
|
||||
httponly=True,
|
||||
secure=strategy.cookie_secure,
|
||||
samesite="lax",
|
||||
max_age=max(access_ttl, 0) or None,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
if refresh_token is not None:
|
||||
response.set_cookie(
|
||||
strategy.refresh_cookie_name,
|
||||
refresh_token,
|
||||
httponly=True,
|
||||
secure=strategy.cookie_secure,
|
||||
samesite="lax",
|
||||
max_age=max(refresh_ttl, 0) or None,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
|
||||
|
||||
def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None:
|
||||
"""Remove session cookies from the client."""
|
||||
|
||||
response.delete_cookie(
|
||||
strategy.access_cookie_name,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
response.delete_cookie(
|
||||
strategy.refresh_cookie_name,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
@@ -1,52 +1,88 @@
|
||||
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
||||
{% set projects_href = request.url_for('projects.project_list_page') if request
|
||||
else '/projects/ui' %} {% set project_create_href =
|
||||
request.url_for('projects.create_project_form') if request else
|
||||
'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if
|
||||
request else '/login' %} {% set register_href =
|
||||
request.url_for('auth.register_form') if request else '/register' %} {% set
|
||||
forgot_href = request.url_for('auth.password_reset_request_form') if request
|
||||
else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
|
||||
"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
||||
{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %}
|
||||
{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %}
|
||||
{% set auth_session = request.state.auth_session if request else None %}
|
||||
{% set is_authenticated = auth_session and auth_session.is_authenticated %}
|
||||
|
||||
{% if is_authenticated %}
|
||||
{% set logout_href = request.url_for('auth.logout') if request else '/logout' %}
|
||||
{% set account_links = [
|
||||
{"href": logout_href, "label": "Logout", "match_prefix": "/logout"}
|
||||
] %}
|
||||
{% else %}
|
||||
{% set login_href = request.url_for('auth.login_form') if request else '/login' %}
|
||||
{% set register_href = request.url_for('auth.register_form') if request else '/register' %}
|
||||
{% set forgot_href = request.url_for('auth.password_reset_request_form') if request else '/forgot-password' %}
|
||||
{% set account_links = [
|
||||
{"href": login_href, "label": "Login", "match_prefix": "/login"},
|
||||
{"href": register_href, "label": "Register", "match_prefix": "/register"},
|
||||
{"href": forgot_href, "label": "Forgot Password", "match_prefix": "/forgot-password"}
|
||||
] %}
|
||||
{% endif %}
|
||||
{% set nav_groups = [
|
||||
{
|
||||
"label": "Workspace",
|
||||
"links": [
|
||||
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
||||
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
||||
{"href": project_create_href, "label": "New Project", "match_prefix":
|
||||
"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href":
|
||||
"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label":
|
||||
"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href":
|
||||
"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings",
|
||||
"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"},
|
||||
], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label":
|
||||
"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register",
|
||||
"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password",
|
||||
"match_prefix": "/forgot-password"}, ], }, ] %}
|
||||
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Insights",
|
||||
"links": [
|
||||
{"href": "/ui/simulations", "label": "Simulations"},
|
||||
{"href": "/ui/reporting", "label": "Reporting"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Configuration",
|
||||
"links": [
|
||||
{
|
||||
"href": "/ui/settings",
|
||||
"label": "Settings",
|
||||
"children": [
|
||||
{"href": "/theme-settings", "label": "Themes"},
|
||||
{"href": "/ui/currencies", "label": "Currency Management"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Account",
|
||||
"links": account_links
|
||||
}
|
||||
] %}
|
||||
|
||||
<nav class="sidebar-nav" aria-label="Primary navigation">
|
||||
{% set current_path = request.url.path if request else "" %} {% for group in
|
||||
nav_groups %}
|
||||
{% set current_path = request.url.path if request else '' %}
|
||||
{% for group in nav_groups %}
|
||||
{% if group.links %}
|
||||
<div class="sidebar-section">
|
||||
<div class="sidebar-section-label">{{ group.label }}</div>
|
||||
<div class="sidebar-section-links">
|
||||
{% for link in group.links %} {% set href = link.href %} {% set
|
||||
match_prefix = link.get('match_prefix', href) %} {% if match_prefix == '/'
|
||||
%} {% set is_active = current_path == '/' %} {% else %} {% set is_active =
|
||||
current_path.startswith(match_prefix) %} {% endif %}
|
||||
{% for link in group.links %}
|
||||
{% set href = link.href %}
|
||||
{% set match_prefix = link.get('match_prefix', href) %}
|
||||
{% if match_prefix == '/' %}
|
||||
{% set is_active = current_path == '/' %}
|
||||
{% else %}
|
||||
{% set is_active = current_path.startswith(match_prefix) %}
|
||||
{% endif %}
|
||||
<div class="sidebar-link-block">
|
||||
<a
|
||||
href="{{ href }}"
|
||||
class="sidebar-link{% if is_active %} is-active{% endif %}"
|
||||
>
|
||||
<a href="{{ href }}" class="sidebar-link{% if is_active %} is-active{% endif %}">
|
||||
{{ link.label }}
|
||||
</a>
|
||||
{% if link.children %}
|
||||
<div class="sidebar-sublinks">
|
||||
{% for child in link.children %} {% set child_prefix =
|
||||
child.get('match_prefix', child.href) %} {% if child_prefix == '/' %}
|
||||
{% set child_active = current_path == '/' %} {% else %} {% set
|
||||
child_active = current_path.startswith(child_prefix) %} {% endif %}
|
||||
<a
|
||||
href="{{ child.href }}"
|
||||
class="sidebar-sublink{% if child_active %} is-active{% endif %}"
|
||||
>
|
||||
{% for child in link.children %}
|
||||
{% set child_prefix = child.get('match_prefix', child.href) %}
|
||||
{% if child_prefix == '/' %}
|
||||
{% set child_active = current_path == '/' %}
|
||||
{% else %}
|
||||
{% set child_active = current_path.startswith(child_prefix) %}
|
||||
{% endif %}
|
||||
<a href="{{ child.href }}" class="sidebar-sublink{% if child_active %} is-active{% endif %}">
|
||||
{{ child.label }}
|
||||
</a>
|
||||
{% endfor %}
|
||||
@@ -56,5 +92,6 @@ else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</nav>
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models import Role, User, UserRole
|
||||
from dependencies import get_auth_session, require_current_user
|
||||
from services.security import hash_password
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -223,7 +227,8 @@ class TestPasswordResetFlow:
|
||||
data={"email": "mismatch@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)[
|
||||
"token"][0]
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
@@ -237,3 +242,45 @@ class TestPasswordResetFlow:
|
||||
|
||||
assert submit_response.status_code == 400
|
||||
assert "Passwords do not match" in submit_response.text
|
||||
|
||||
|
||||
class TestLogoutFlow:
|
||||
def test_logout_clears_cookies_and_redirects(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="logout@example.com",
|
||||
username="logoutuser",
|
||||
password_hash=hash_password("SecureP@ss1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
session = AuthSession(
|
||||
tokens=SessionTokens(
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
access_token_source="cookie",
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
|
||||
app = cast(FastAPI, client.app)
|
||||
app.dependency_overrides[require_current_user] = lambda: user
|
||||
app.dependency_overrides[get_auth_session] = lambda: session
|
||||
|
||||
try:
|
||||
response = client.get("/logout", follow_redirects=False)
|
||||
finally:
|
||||
app.dependency_overrides.pop(require_current_user, None)
|
||||
app.dependency_overrides.pop(get_auth_session, None)
|
||||
|
||||
assert response.status_code == 303
|
||||
location = response.headers.get("location")
|
||||
assert location and location.startswith("http://testserver/login")
|
||||
set_cookie_header = response.headers.get("set-cookie") or ""
|
||||
assert "calminer_access_token=" in set_cookie_header
|
||||
assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()
|
||||
|
||||
111
tests/test_auth_session_middleware.py
Normal file
111
tests/test_auth_session_middleware.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config.settings import get_settings
|
||||
from dependencies import get_unit_of_work, require_current_user
|
||||
from middleware.auth_session import AuthSessionMiddleware
|
||||
from models import User
|
||||
from services.security import create_access_token, create_refresh_token, hash_password
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
|
||||
app = FastAPI()
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
@app.get("/me")
|
||||
def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
|
||||
return JSONResponse({"id": user.id, "username": user.username})
|
||||
|
||||
app.add_middleware(
|
||||
AuthSessionMiddleware,
|
||||
unit_of_work_factory=lambda: UnitOfWork(
|
||||
session_factory=session_factory),
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
def _create_user(session_factory: sessionmaker) -> User:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
user = User(
|
||||
email="jane@example.com",
|
||||
username="jane",
|
||||
password_hash=hash_password("secret"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
uow.users.create(user)
|
||||
return user
|
||||
|
||||
|
||||
def _issue_tokens(user: User) -> tuple[str, str]:
|
||||
settings = get_settings().jwt_settings()
|
||||
access = create_access_token(str(user.id), settings, scopes=["auth"])
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
return access, refresh
|
||||
|
||||
|
||||
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
access, refresh = _issue_tokens(user)
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", access)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["id"] == user.id
|
||||
assert payload["username"] == user.username
|
||||
|
||||
|
||||
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
settings = get_settings().jwt_settings()
|
||||
expired = create_access_token(
|
||||
str(user.id),
|
||||
settings,
|
||||
scopes=["auth"],
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", expired)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
new_access = response.cookies.get("calminer_access_token")
|
||||
new_refresh = response.cookies.get("calminer_refresh_token")
|
||||
assert new_access is not None and new_access != expired
|
||||
assert new_refresh is not None
|
||||
|
||||
|
||||
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
|
||||
auth_app.cookies.set("calminer_access_token", "invalid-token")
|
||||
auth_app.cookies.set("calminer_refresh_token", "invalid-token")
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 401
|
||||
set_cookies = response.headers.get_list("set-cookie")
|
||||
assert any("calminer_access_token=" in value for value in set_cookies)
|
||||
Reference in New Issue
Block a user