feat: Implement session management with middleware and update authentication flow
This commit is contained in:
@@ -5,9 +5,25 @@ from dataclasses import dataclass
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from services.security import JWTSettings
|
from services.security import JWTSettings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class SessionSettings:
|
||||||
|
"""Cookie and header configuration for session token transport."""
|
||||||
|
|
||||||
|
access_cookie_name: str
|
||||||
|
refresh_cookie_name: str
|
||||||
|
cookie_secure: bool
|
||||||
|
cookie_domain: Optional[str]
|
||||||
|
cookie_path: str
|
||||||
|
header_name: str
|
||||||
|
header_prefix: str
|
||||||
|
allow_header_fallback: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class Settings:
|
class Settings:
|
||||||
"""Application configuration sourced from environment variables."""
|
"""Application configuration sourced from environment variables."""
|
||||||
@@ -16,6 +32,14 @@ class Settings:
|
|||||||
jwt_algorithm: str = "HS256"
|
jwt_algorithm: str = "HS256"
|
||||||
jwt_access_token_minutes: int = 15
|
jwt_access_token_minutes: int = 15
|
||||||
jwt_refresh_token_days: int = 7
|
jwt_refresh_token_days: int = 7
|
||||||
|
session_access_cookie_name: str = "calminer_access_token"
|
||||||
|
session_refresh_cookie_name: str = "calminer_refresh_token"
|
||||||
|
session_cookie_secure: bool = False
|
||||||
|
session_cookie_domain: Optional[str] = None
|
||||||
|
session_cookie_path: str = "/"
|
||||||
|
session_header_name: str = "Authorization"
|
||||||
|
session_header_prefix: str = "Bearer"
|
||||||
|
session_allow_header_fallback: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environment(cls) -> "Settings":
|
def from_environment(cls) -> "Settings":
|
||||||
@@ -30,6 +54,26 @@ class Settings:
|
|||||||
jwt_refresh_token_days=cls._int_from_env(
|
jwt_refresh_token_days=cls._int_from_env(
|
||||||
"CALMINER_JWT_REFRESH_DAYS", 7
|
"CALMINER_JWT_REFRESH_DAYS", 7
|
||||||
),
|
),
|
||||||
|
session_access_cookie_name=os.getenv(
|
||||||
|
"CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token"
|
||||||
|
),
|
||||||
|
session_refresh_cookie_name=os.getenv(
|
||||||
|
"CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token"
|
||||||
|
),
|
||||||
|
session_cookie_secure=cls._bool_from_env(
|
||||||
|
"CALMINER_SESSION_COOKIE_SECURE", False
|
||||||
|
),
|
||||||
|
session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"),
|
||||||
|
session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"),
|
||||||
|
session_header_name=os.getenv(
|
||||||
|
"CALMINER_SESSION_HEADER_NAME", "Authorization"
|
||||||
|
),
|
||||||
|
session_header_prefix=os.getenv(
|
||||||
|
"CALMINER_SESSION_HEADER_PREFIX", "Bearer"
|
||||||
|
),
|
||||||
|
session_allow_header_fallback=cls._bool_from_env(
|
||||||
|
"CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -42,6 +86,18 @@ class Settings:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _bool_from_env(name: str, default: bool) -> bool:
|
||||||
|
raw_value = os.getenv(name)
|
||||||
|
if raw_value is None:
|
||||||
|
return default
|
||||||
|
lowered = raw_value.strip().lower()
|
||||||
|
if lowered in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if lowered in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
return default
|
||||||
|
|
||||||
def jwt_settings(self) -> JWTSettings:
|
def jwt_settings(self) -> JWTSettings:
|
||||||
"""Build runtime JWT settings compatible with token helpers."""
|
"""Build runtime JWT settings compatible with token helpers."""
|
||||||
|
|
||||||
@@ -52,6 +108,20 @@ class Settings:
|
|||||||
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
|
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def session_settings(self) -> SessionSettings:
|
||||||
|
"""Provide transport configuration for session tokens."""
|
||||||
|
|
||||||
|
return SessionSettings(
|
||||||
|
access_cookie_name=self.session_access_cookie_name,
|
||||||
|
refresh_cookie_name=self.session_refresh_cookie_name,
|
||||||
|
cookie_secure=self.session_cookie_secure,
|
||||||
|
cookie_domain=self.session_cookie_domain,
|
||||||
|
cookie_path=self.session_cookie_path,
|
||||||
|
header_name=self.session_header_name,
|
||||||
|
header_prefix=self.session_header_prefix,
|
||||||
|
allow_header_fallback=self.session_allow_header_fallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
|
|||||||
@@ -2,8 +2,18 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
|
||||||
from config.settings import Settings, get_settings
|
from config.settings import Settings, get_settings
|
||||||
|
from models import User
|
||||||
from services.security import JWTSettings
|
from services.security import JWTSettings
|
||||||
|
from services.session import (
|
||||||
|
AuthSession,
|
||||||
|
SessionStrategy,
|
||||||
|
SessionTokens,
|
||||||
|
build_session_strategy,
|
||||||
|
extract_session_tokens,
|
||||||
|
)
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
@@ -24,3 +34,65 @@ def get_jwt_settings() -> JWTSettings:
|
|||||||
"""Provide JWT runtime configuration derived from settings."""
|
"""Provide JWT runtime configuration derived from settings."""
|
||||||
|
|
||||||
return get_settings().jwt_settings()
|
return get_settings().jwt_settings()
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_strategy(
|
||||||
|
settings: Settings = Depends(get_application_settings),
|
||||||
|
) -> SessionStrategy:
|
||||||
|
"""Yield configured session transport strategy."""
|
||||||
|
|
||||||
|
return build_session_strategy(settings.session_settings())
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_tokens(
|
||||||
|
request: Request,
|
||||||
|
strategy: SessionStrategy = Depends(get_session_strategy),
|
||||||
|
) -> SessionTokens:
|
||||||
|
"""Extract raw session tokens from the incoming request."""
|
||||||
|
|
||||||
|
existing = getattr(request.state, "auth_session", None)
|
||||||
|
if isinstance(existing, AuthSession):
|
||||||
|
return existing.tokens
|
||||||
|
|
||||||
|
tokens = extract_session_tokens(request, strategy)
|
||||||
|
request.state.auth_session = AuthSession(tokens=tokens)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_session(
|
||||||
|
request: Request,
|
||||||
|
tokens: SessionTokens = Depends(get_session_tokens),
|
||||||
|
) -> AuthSession:
|
||||||
|
"""Provide authentication session context for the current request."""
|
||||||
|
|
||||||
|
existing = getattr(request.state, "auth_session", None)
|
||||||
|
if isinstance(existing, AuthSession):
|
||||||
|
return existing
|
||||||
|
|
||||||
|
if tokens.is_empty:
|
||||||
|
session = AuthSession.anonymous()
|
||||||
|
else:
|
||||||
|
session = AuthSession(tokens=tokens)
|
||||||
|
request.state.auth_session = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
session: AuthSession = Depends(get_auth_session),
|
||||||
|
) -> User | None:
|
||||||
|
"""Return the current authenticated user if present."""
|
||||||
|
|
||||||
|
return session.user
|
||||||
|
|
||||||
|
|
||||||
|
def require_current_user(
|
||||||
|
session: AuthSession = Depends(get_auth_session),
|
||||||
|
) -> User:
|
||||||
|
"""Ensure that a request is authenticated and return the user context."""
|
||||||
|
|
||||||
|
if session.user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required.",
|
||||||
|
)
|
||||||
|
return session.user
|
||||||
|
|||||||
6
main.py
6
main.py
@@ -2,8 +2,10 @@ from typing import Awaitable, Callable
|
|||||||
|
|
||||||
from fastapi import FastAPI, Request, Response
|
from fastapi import FastAPI, Request, Response
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from middleware.validation import validate_json
|
|
||||||
from config.database import Base, engine
|
from config.database import Base, engine
|
||||||
|
from middleware.auth_session import AuthSessionMiddleware
|
||||||
|
from middleware.validation import validate_json
|
||||||
from models import (
|
from models import (
|
||||||
FinancialInput,
|
FinancialInput,
|
||||||
Project,
|
Project,
|
||||||
@@ -20,6 +22,8 @@ Base.metadata.create_all(bind=engine)
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
app.add_middleware(AuthSessionMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def json_validation(
|
async def json_validation(
|
||||||
|
|||||||
177
middleware/auth_session.py
Normal file
177
middleware/auth_session.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Iterable, Optional
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from config.settings import Settings, get_settings
|
||||||
|
from models import User
|
||||||
|
from services.exceptions import EntityNotFoundError
|
||||||
|
from services.security import (
|
||||||
|
JWTSettings,
|
||||||
|
TokenDecodeError,
|
||||||
|
TokenError,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenTypeMismatchError,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_access_token,
|
||||||
|
decode_refresh_token,
|
||||||
|
)
|
||||||
|
from services.session import (
|
||||||
|
AuthSession,
|
||||||
|
SessionStrategy,
|
||||||
|
SessionTokens,
|
||||||
|
build_session_strategy,
|
||||||
|
clear_session_cookies,
|
||||||
|
extract_session_tokens,
|
||||||
|
set_session_cookies,
|
||||||
|
)
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
_AUTH_SCOPE = "auth"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _ResolutionResult:
|
||||||
|
session: AuthSession
|
||||||
|
strategy: SessionStrategy
|
||||||
|
jwt_settings: JWTSettings
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSessionMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Resolve authenticated users from session cookies and refresh tokens."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
*,
|
||||||
|
settings_provider: Callable[[], Settings] = get_settings,
|
||||||
|
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
|
||||||
|
refresh_scopes: Iterable[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
self._settings_provider = settings_provider
|
||||||
|
self._unit_of_work_factory = unit_of_work_factory
|
||||||
|
self._refresh_scopes = tuple(
|
||||||
|
refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||||
|
resolved = self._resolve_session(request)
|
||||||
|
response = await call_next(request)
|
||||||
|
self._apply_session(response, resolved)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _resolve_session(self, request: Request) -> _ResolutionResult:
|
||||||
|
settings = self._settings_provider()
|
||||||
|
jwt_settings = settings.jwt_settings()
|
||||||
|
strategy = build_session_strategy(settings.session_settings())
|
||||||
|
|
||||||
|
tokens = extract_session_tokens(request, strategy)
|
||||||
|
session = AuthSession(tokens=tokens)
|
||||||
|
request.state.auth_session = session
|
||||||
|
|
||||||
|
if tokens.access_token:
|
||||||
|
if self._try_access_token(session, tokens, jwt_settings):
|
||||||
|
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||||
|
|
||||||
|
if tokens.refresh_token:
|
||||||
|
self._try_refresh_token(
|
||||||
|
session, tokens.refresh_token, jwt_settings)
|
||||||
|
|
||||||
|
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||||
|
|
||||||
|
def _try_access_token(
|
||||||
|
self,
|
||||||
|
session: AuthSession,
|
||||||
|
tokens: SessionTokens,
|
||||||
|
jwt_settings: JWTSettings,
|
||||||
|
) -> bool:
|
||||||
|
try:
|
||||||
|
payload = decode_access_token(
|
||||||
|
tokens.access_token or "", jwt_settings)
|
||||||
|
except TokenExpiredError:
|
||||||
|
return False
|
||||||
|
except (TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||||
|
session.mark_cleared()
|
||||||
|
return False
|
||||||
|
|
||||||
|
user = self._load_user(payload.sub)
|
||||||
|
if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
|
||||||
|
session.mark_cleared()
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.user = user
|
||||||
|
session.scopes = tuple(payload.scopes)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _try_refresh_token(
|
||||||
|
self,
|
||||||
|
session: AuthSession,
|
||||||
|
refresh_token: str,
|
||||||
|
jwt_settings: JWTSettings,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
payload = decode_refresh_token(refresh_token, jwt_settings)
|
||||||
|
except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||||
|
session.mark_cleared()
|
||||||
|
return
|
||||||
|
|
||||||
|
user = self._load_user(payload.sub)
|
||||||
|
if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
|
||||||
|
session.mark_cleared()
|
||||||
|
return
|
||||||
|
|
||||||
|
session.user = user
|
||||||
|
session.scopes = tuple(payload.scopes)
|
||||||
|
|
||||||
|
access_token = create_access_token(
|
||||||
|
str(user.id),
|
||||||
|
jwt_settings,
|
||||||
|
scopes=payload.scopes,
|
||||||
|
)
|
||||||
|
new_refresh = create_refresh_token(
|
||||||
|
str(user.id),
|
||||||
|
jwt_settings,
|
||||||
|
scopes=payload.scopes,
|
||||||
|
)
|
||||||
|
session.issue_tokens(access_token=access_token,
|
||||||
|
refresh_token=new_refresh)
|
||||||
|
|
||||||
|
def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
|
||||||
|
candidate_scopes = set(scopes)
|
||||||
|
return any(scope in candidate_scopes for scope in self._refresh_scopes)
|
||||||
|
|
||||||
|
def _load_user(self, subject: str) -> Optional[User]:
|
||||||
|
try:
|
||||||
|
user_id = int(subject)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._unit_of_work_factory() as uow:
|
||||||
|
if not uow.users:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
user = uow.users.get(user_id, with_roles=True)
|
||||||
|
except EntityNotFoundError:
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
|
||||||
|
session = resolved.session
|
||||||
|
if session.clear_cookies:
|
||||||
|
clear_session_cookies(response, resolved.strategy)
|
||||||
|
return
|
||||||
|
|
||||||
|
if session.issued_access_token:
|
||||||
|
refresh_token = session.issued_refresh_token or session.tokens.refresh_token
|
||||||
|
set_session_cookies(
|
||||||
|
response,
|
||||||
|
access_token=session.issued_access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
strategy=resolved.strategy,
|
||||||
|
jwt_settings=resolved.jwt_settings,
|
||||||
|
)
|
||||||
@@ -9,7 +9,13 @@ from fastapi.templating import Jinja2Templates
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from starlette.datastructures import FormData
|
from starlette.datastructures import FormData
|
||||||
|
|
||||||
from dependencies import get_jwt_settings, get_unit_of_work
|
from dependencies import (
|
||||||
|
get_auth_session,
|
||||||
|
get_jwt_settings,
|
||||||
|
get_session_strategy,
|
||||||
|
get_unit_of_work,
|
||||||
|
require_current_user,
|
||||||
|
)
|
||||||
from models import Role, User
|
from models import Role, User
|
||||||
from schemas.auth import (
|
from schemas.auth import (
|
||||||
LoginForm,
|
LoginForm,
|
||||||
@@ -29,6 +35,12 @@ from services.security import (
|
|||||||
hash_password,
|
hash_password,
|
||||||
verify_password,
|
verify_password,
|
||||||
)
|
)
|
||||||
|
from services.session import (
|
||||||
|
AuthSession,
|
||||||
|
SessionStrategy,
|
||||||
|
clear_session_cookies,
|
||||||
|
set_session_cookies,
|
||||||
|
)
|
||||||
from services.repositories import RoleRepository, UserRepository
|
from services.repositories import RoleRepository, UserRepository
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
@@ -103,6 +115,7 @@ async def login_submit(
|
|||||||
request: Request,
|
request: Request,
|
||||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||||
|
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||||
):
|
):
|
||||||
form_data = _normalise_form_data(await request.form())
|
form_data = _normalise_form_data(await request.form())
|
||||||
try:
|
try:
|
||||||
@@ -158,7 +171,31 @@ async def login_submit(
|
|||||||
request.url_for("dashboard.home"),
|
request.url_for("dashboard.home"),
|
||||||
status_code=status.HTTP_303_SEE_OTHER,
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
)
|
)
|
||||||
_set_auth_cookies(response, access_token, refresh_token, jwt_settings)
|
set_session_cookies(
|
||||||
|
response,
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
strategy=session_strategy,
|
||||||
|
jwt_settings=jwt_settings,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logout", include_in_schema=False, name="auth.logout")
|
||||||
|
async def logout(
|
||||||
|
request: Request,
|
||||||
|
_: User = Depends(require_current_user),
|
||||||
|
session: AuthSession = Depends(get_auth_session),
|
||||||
|
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
session.mark_cleared()
|
||||||
|
redirect_url = request.url_for(
|
||||||
|
"auth.login_form").include_query_params(logout="1")
|
||||||
|
response = RedirectResponse(
|
||||||
|
redirect_url,
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
clear_session_cookies(response, session_strategy)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@@ -168,32 +205,6 @@ def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
|
|||||||
return users_repo.get_by_username(identifier, with_roles=True)
|
return users_repo.get_by_username(identifier, with_roles=True)
|
||||||
|
|
||||||
|
|
||||||
def _set_auth_cookies(
|
|
||||||
response: RedirectResponse,
|
|
||||||
access_token: str,
|
|
||||||
refresh_token: str,
|
|
||||||
jwt_settings: JWTSettings,
|
|
||||||
) -> None:
|
|
||||||
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
|
||||||
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
|
||||||
response.set_cookie(
|
|
||||||
"calminer_access_token",
|
|
||||||
access_token,
|
|
||||||
httponly=True,
|
|
||||||
secure=False,
|
|
||||||
samesite="lax",
|
|
||||||
max_age=max(access_ttl, 0) or None,
|
|
||||||
)
|
|
||||||
response.set_cookie(
|
|
||||||
"calminer_refresh_token",
|
|
||||||
refresh_token,
|
|
||||||
httponly=True,
|
|
||||||
secure=False,
|
|
||||||
samesite="lax",
|
|
||||||
max_age=max(refresh_ttl, 0) or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
|
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
|
||||||
def register_form(request: Request) -> HTMLResponse:
|
def register_form(request: Request) -> HTMLResponse:
|
||||||
return _template(
|
return _template(
|
||||||
|
|||||||
192
services/session.py
Normal file
192
services/session.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
|
||||||
|
from config.settings import SessionSettings
|
||||||
|
from services.security import JWTSettings
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - used only for static typing
|
||||||
|
from models import User
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class SessionStrategy:
|
||||||
|
"""Describe how authentication tokens are transported with requests."""
|
||||||
|
|
||||||
|
access_cookie_name: str
|
||||||
|
refresh_cookie_name: str
|
||||||
|
cookie_secure: bool
|
||||||
|
cookie_domain: Optional[str]
|
||||||
|
cookie_path: str
|
||||||
|
header_name: str
|
||||||
|
header_prefix: str
|
||||||
|
allow_header_fallback: bool = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls, settings: SessionSettings) -> "SessionStrategy":
|
||||||
|
return cls(
|
||||||
|
access_cookie_name=settings.access_cookie_name,
|
||||||
|
refresh_cookie_name=settings.refresh_cookie_name,
|
||||||
|
cookie_secure=settings.cookie_secure,
|
||||||
|
cookie_domain=settings.cookie_domain,
|
||||||
|
cookie_path=settings.cookie_path,
|
||||||
|
header_name=settings.header_name,
|
||||||
|
header_prefix=settings.header_prefix,
|
||||||
|
allow_header_fallback=settings.allow_header_fallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class SessionTokens:
|
||||||
|
"""Raw access and refresh tokens extracted from the transport layer."""
|
||||||
|
|
||||||
|
access_token: Optional[str]
|
||||||
|
refresh_token: Optional[str]
|
||||||
|
access_token_source: Literal["cookie", "header", "none"] = "none"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_access(self) -> bool:
|
||||||
|
return bool(self.access_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_refresh(self) -> bool:
|
||||||
|
return bool(self.refresh_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return not self.has_access and not self.has_refresh
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AuthSession:
|
||||||
|
"""Holds authenticated user context resolved from session tokens."""
|
||||||
|
|
||||||
|
tokens: SessionTokens
|
||||||
|
user: Optional["User"] = None
|
||||||
|
scopes: tuple[str, ...] = ()
|
||||||
|
issued_access_token: Optional[str] = None
|
||||||
|
issued_refresh_token: Optional[str] = None
|
||||||
|
clear_cookies: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_authenticated(self) -> bool:
|
||||||
|
return self.user is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def anonymous(cls) -> "AuthSession":
|
||||||
|
return cls(tokens=SessionTokens(access_token=None, refresh_token=None))
|
||||||
|
|
||||||
|
def issue_tokens(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
access_token: str,
|
||||||
|
refresh_token: Optional[str] = None,
|
||||||
|
access_source: Literal["cookie", "header", "none"] = "cookie",
|
||||||
|
) -> None:
|
||||||
|
self.issued_access_token = access_token
|
||||||
|
if refresh_token is not None:
|
||||||
|
self.issued_refresh_token = refresh_token
|
||||||
|
self.tokens = SessionTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token,
|
||||||
|
access_token_source=access_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
def mark_cleared(self) -> None:
|
||||||
|
self.clear_cookies = True
|
||||||
|
self.tokens = SessionTokens(access_token=None, refresh_token=None)
|
||||||
|
self.user = None
|
||||||
|
self.scopes = ()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens:
|
||||||
|
"""Pull tokens from cookies or headers according to configured strategy."""
|
||||||
|
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
access_source: Literal["cookie", "header", "none"] = "none"
|
||||||
|
|
||||||
|
if strategy.access_cookie_name in request.cookies:
|
||||||
|
access_token = request.cookies.get(strategy.access_cookie_name) or None
|
||||||
|
if access_token:
|
||||||
|
access_source = "cookie"
|
||||||
|
|
||||||
|
if strategy.refresh_cookie_name in request.cookies:
|
||||||
|
refresh_token = request.cookies.get(
|
||||||
|
strategy.refresh_cookie_name) or None
|
||||||
|
|
||||||
|
if not access_token and strategy.allow_header_fallback:
|
||||||
|
header_value = request.headers.get(strategy.header_name)
|
||||||
|
if header_value:
|
||||||
|
candidate = header_value.strip()
|
||||||
|
prefix = f"{strategy.header_prefix} " if strategy.header_prefix else ""
|
||||||
|
if prefix and candidate.lower().startswith(prefix.lower()):
|
||||||
|
candidate = candidate[len(prefix):].strip()
|
||||||
|
if candidate:
|
||||||
|
access_token = candidate
|
||||||
|
access_source = "header"
|
||||||
|
|
||||||
|
return SessionTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
access_token_source=access_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_session_strategy(settings: SessionSettings) -> SessionStrategy:
|
||||||
|
"""Create a session strategy object from settings configuration."""
|
||||||
|
|
||||||
|
return SessionStrategy.from_settings(settings)
|
||||||
|
|
||||||
|
|
||||||
|
def set_session_cookies(
|
||||||
|
response: Response,
|
||||||
|
*,
|
||||||
|
access_token: str,
|
||||||
|
refresh_token: Optional[str],
|
||||||
|
strategy: SessionStrategy,
|
||||||
|
jwt_settings: JWTSettings,
|
||||||
|
) -> None:
|
||||||
|
"""Persist session cookies on an outgoing response."""
|
||||||
|
|
||||||
|
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
||||||
|
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
||||||
|
response.set_cookie(
|
||||||
|
strategy.access_cookie_name,
|
||||||
|
access_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=strategy.cookie_secure,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=max(access_ttl, 0) or None,
|
||||||
|
domain=strategy.cookie_domain,
|
||||||
|
path=strategy.cookie_path,
|
||||||
|
)
|
||||||
|
if refresh_token is not None:
|
||||||
|
response.set_cookie(
|
||||||
|
strategy.refresh_cookie_name,
|
||||||
|
refresh_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=strategy.cookie_secure,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=max(refresh_ttl, 0) or None,
|
||||||
|
domain=strategy.cookie_domain,
|
||||||
|
path=strategy.cookie_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None:
|
||||||
|
"""Remove session cookies from the client."""
|
||||||
|
|
||||||
|
response.delete_cookie(
|
||||||
|
strategy.access_cookie_name,
|
||||||
|
domain=strategy.cookie_domain,
|
||||||
|
path=strategy.cookie_path,
|
||||||
|
)
|
||||||
|
response.delete_cookie(
|
||||||
|
strategy.refresh_cookie_name,
|
||||||
|
domain=strategy.cookie_domain,
|
||||||
|
path=strategy.cookie_path,
|
||||||
|
)
|
||||||
@@ -1,52 +1,88 @@
|
|||||||
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
||||||
{% set projects_href = request.url_for('projects.project_list_page') if request
|
{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %}
|
||||||
else '/projects/ui' %} {% set project_create_href =
|
{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %}
|
||||||
request.url_for('projects.create_project_form') if request else
|
{% set auth_session = request.state.auth_session if request else None %}
|
||||||
'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if
|
{% set is_authenticated = auth_session and auth_session.is_authenticated %}
|
||||||
request else '/login' %} {% set register_href =
|
|
||||||
request.url_for('auth.register_form') if request else '/register' %} {% set
|
{% if is_authenticated %}
|
||||||
forgot_href = request.url_for('auth.password_reset_request_form') if request
|
{% set logout_href = request.url_for('auth.logout') if request else '/logout' %}
|
||||||
else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
|
{% set account_links = [
|
||||||
"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
{"href": logout_href, "label": "Logout", "match_prefix": "/logout"}
|
||||||
|
] %}
|
||||||
|
{% else %}
|
||||||
|
{% set login_href = request.url_for('auth.login_form') if request else '/login' %}
|
||||||
|
{% set register_href = request.url_for('auth.register_form') if request else '/register' %}
|
||||||
|
{% set forgot_href = request.url_for('auth.password_reset_request_form') if request else '/forgot-password' %}
|
||||||
|
{% set account_links = [
|
||||||
|
{"href": login_href, "label": "Login", "match_prefix": "/login"},
|
||||||
|
{"href": register_href, "label": "Register", "match_prefix": "/register"},
|
||||||
|
{"href": forgot_href, "label": "Forgot Password", "match_prefix": "/forgot-password"}
|
||||||
|
] %}
|
||||||
|
{% endif %}
|
||||||
|
{% set nav_groups = [
|
||||||
|
{
|
||||||
|
"label": "Workspace",
|
||||||
|
"links": [
|
||||||
|
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
||||||
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
||||||
{"href": project_create_href, "label": "New Project", "match_prefix":
|
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}
|
||||||
"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href":
|
]
|
||||||
"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label":
|
},
|
||||||
"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href":
|
{
|
||||||
"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings",
|
"label": "Insights",
|
||||||
"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"},
|
"links": [
|
||||||
], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label":
|
{"href": "/ui/simulations", "label": "Simulations"},
|
||||||
"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register",
|
{"href": "/ui/reporting", "label": "Reporting"}
|
||||||
"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password",
|
]
|
||||||
"match_prefix": "/forgot-password"}, ], }, ] %}
|
},
|
||||||
|
{
|
||||||
|
"label": "Configuration",
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"href": "/ui/settings",
|
||||||
|
"label": "Settings",
|
||||||
|
"children": [
|
||||||
|
{"href": "/theme-settings", "label": "Themes"},
|
||||||
|
{"href": "/ui/currencies", "label": "Currency Management"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": "Account",
|
||||||
|
"links": account_links
|
||||||
|
}
|
||||||
|
] %}
|
||||||
|
|
||||||
<nav class="sidebar-nav" aria-label="Primary navigation">
|
<nav class="sidebar-nav" aria-label="Primary navigation">
|
||||||
{% set current_path = request.url.path if request else "" %} {% for group in
|
{% set current_path = request.url.path if request else '' %}
|
||||||
nav_groups %}
|
{% for group in nav_groups %}
|
||||||
|
{% if group.links %}
|
||||||
<div class="sidebar-section">
|
<div class="sidebar-section">
|
||||||
<div class="sidebar-section-label">{{ group.label }}</div>
|
<div class="sidebar-section-label">{{ group.label }}</div>
|
||||||
<div class="sidebar-section-links">
|
<div class="sidebar-section-links">
|
||||||
{% for link in group.links %} {% set href = link.href %} {% set
|
{% for link in group.links %}
|
||||||
match_prefix = link.get('match_prefix', href) %} {% if match_prefix == '/'
|
{% set href = link.href %}
|
||||||
%} {% set is_active = current_path == '/' %} {% else %} {% set is_active =
|
{% set match_prefix = link.get('match_prefix', href) %}
|
||||||
current_path.startswith(match_prefix) %} {% endif %}
|
{% if match_prefix == '/' %}
|
||||||
|
{% set is_active = current_path == '/' %}
|
||||||
|
{% else %}
|
||||||
|
{% set is_active = current_path.startswith(match_prefix) %}
|
||||||
|
{% endif %}
|
||||||
<div class="sidebar-link-block">
|
<div class="sidebar-link-block">
|
||||||
<a
|
<a href="{{ href }}" class="sidebar-link{% if is_active %} is-active{% endif %}">
|
||||||
href="{{ href }}"
|
|
||||||
class="sidebar-link{% if is_active %} is-active{% endif %}"
|
|
||||||
>
|
|
||||||
{{ link.label }}
|
{{ link.label }}
|
||||||
</a>
|
</a>
|
||||||
{% if link.children %}
|
{% if link.children %}
|
||||||
<div class="sidebar-sublinks">
|
<div class="sidebar-sublinks">
|
||||||
{% for child in link.children %} {% set child_prefix =
|
{% for child in link.children %}
|
||||||
child.get('match_prefix', child.href) %} {% if child_prefix == '/' %}
|
{% set child_prefix = child.get('match_prefix', child.href) %}
|
||||||
{% set child_active = current_path == '/' %} {% else %} {% set
|
{% if child_prefix == '/' %}
|
||||||
child_active = current_path.startswith(child_prefix) %} {% endif %}
|
{% set child_active = current_path == '/' %}
|
||||||
<a
|
{% else %}
|
||||||
href="{{ child.href }}"
|
{% set child_active = current_path.startswith(child_prefix) %}
|
||||||
class="sidebar-sublink{% if child_active %} is-active{% endif %}"
|
{% endif %}
|
||||||
>
|
<a href="{{ child.href }}" class="sidebar-sublink{% if child_active %} is-active{% endif %}">
|
||||||
{{ child.label }}
|
{{ child.label }}
|
||||||
</a>
|
</a>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
@@ -56,5 +92,6 @@ else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</nav>
|
</nav>
|
||||||
|
|||||||
@@ -1,15 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
from typing import cast
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from models import Role, User, UserRole
|
from models import Role, User, UserRole
|
||||||
|
from dependencies import get_auth_session, require_current_user
|
||||||
from services.security import hash_password
|
from services.security import hash_password
|
||||||
|
from services.session import AuthSession, SessionTokens
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@@ -223,7 +227,8 @@ class TestPasswordResetFlow:
|
|||||||
data={"email": "mismatch@example.com"},
|
data={"email": "mismatch@example.com"},
|
||||||
follow_redirects=False,
|
follow_redirects=False,
|
||||||
)
|
)
|
||||||
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
|
token = parse_qs(urlparse(request_response.headers["location"]).query)[
|
||||||
|
"token"][0]
|
||||||
|
|
||||||
submit_response = client.post(
|
submit_response = client.post(
|
||||||
"/reset-password",
|
"/reset-password",
|
||||||
@@ -237,3 +242,45 @@ class TestPasswordResetFlow:
|
|||||||
|
|
||||||
assert submit_response.status_code == 400
|
assert submit_response.status_code == 400
|
||||||
assert "Passwords do not match" in submit_response.text
|
assert "Passwords do not match" in submit_response.text
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogoutFlow:
|
||||||
|
def test_logout_clears_cookies_and_redirects(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
user = User(
|
||||||
|
email="logout@example.com",
|
||||||
|
username="logoutuser",
|
||||||
|
password_hash=hash_password("SecureP@ss1"),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
session = AuthSession(
|
||||||
|
tokens=SessionTokens(
|
||||||
|
access_token="access-token",
|
||||||
|
refresh_token="refresh-token",
|
||||||
|
access_token_source="cookie",
|
||||||
|
),
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = cast(FastAPI, client.app)
|
||||||
|
app.dependency_overrides[require_current_user] = lambda: user
|
||||||
|
app.dependency_overrides[get_auth_session] = lambda: session
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.get("/logout", follow_redirects=False)
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(require_current_user, None)
|
||||||
|
app.dependency_overrides.pop(get_auth_session, None)
|
||||||
|
|
||||||
|
assert response.status_code == 303
|
||||||
|
location = response.headers.get("location")
|
||||||
|
assert location and location.startswith("http://testserver/login")
|
||||||
|
set_cookie_header = response.headers.get("set-cookie") or ""
|
||||||
|
assert "calminer_access_token=" in set_cookie_header
|
||||||
|
assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()
|
||||||
|
|||||||
111
tests/test_auth_session_middleware.py
Normal file
111
tests/test_auth_session_middleware.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
from dependencies import get_unit_of_work, require_current_user
|
||||||
|
from middleware.auth_session import AuthSessionMiddleware
|
||||||
|
from models import User
|
||||||
|
from services.security import create_access_token, create_refresh_token, hash_password
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
def _override_uow() -> Iterator[UnitOfWork]:
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
yield uow
|
||||||
|
|
||||||
|
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
|
||||||
|
return JSONResponse({"id": user.id, "username": user.username})
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
AuthSessionMiddleware,
|
||||||
|
unit_of_work_factory=lambda: UnitOfWork(
|
||||||
|
session_factory=session_factory),
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
try:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_user(session_factory: sessionmaker) -> User:
|
||||||
|
with UnitOfWork(session_factory=session_factory) as uow:
|
||||||
|
assert uow.users is not None
|
||||||
|
user = User(
|
||||||
|
email="jane@example.com",
|
||||||
|
username="jane",
|
||||||
|
password_hash=hash_password("secret"),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
uow.users.create(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_tokens(user: User) -> tuple[str, str]:
|
||||||
|
settings = get_settings().jwt_settings()
|
||||||
|
access = create_access_token(str(user.id), settings, scopes=["auth"])
|
||||||
|
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||||
|
return access, refresh
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||||
|
user = _create_user(session_factory)
|
||||||
|
access, refresh = _issue_tokens(user)
|
||||||
|
|
||||||
|
auth_app.cookies.set("calminer_access_token", access)
|
||||||
|
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||||
|
|
||||||
|
response = auth_app.get("/me")
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["id"] == user.id
|
||||||
|
assert payload["username"] == user.username
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||||
|
user = _create_user(session_factory)
|
||||||
|
settings = get_settings().jwt_settings()
|
||||||
|
expired = create_access_token(
|
||||||
|
str(user.id),
|
||||||
|
settings,
|
||||||
|
scopes=["auth"],
|
||||||
|
expires_delta=timedelta(seconds=-1),
|
||||||
|
)
|
||||||
|
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||||
|
|
||||||
|
auth_app.cookies.set("calminer_access_token", expired)
|
||||||
|
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||||
|
|
||||||
|
response = auth_app.get("/me")
|
||||||
|
assert response.status_code == 200
|
||||||
|
new_access = response.cookies.get("calminer_access_token")
|
||||||
|
new_refresh = response.cookies.get("calminer_refresh_token")
|
||||||
|
assert new_access is not None and new_access != expired
|
||||||
|
assert new_refresh is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
|
||||||
|
auth_app.cookies.set("calminer_access_token", "invalid-token")
|
||||||
|
auth_app.cookies.set("calminer_refresh_token", "invalid-token")
|
||||||
|
|
||||||
|
response = auth_app.get("/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
set_cookies = response.headers.get_list("set-cookie")
|
||||||
|
assert any("calminer_access_token=" in value for value in set_cookies)
|
||||||
Reference in New Issue
Block a user