From 27262bdfa34dc18678e90f99a4219f55f136c2f3 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Sun, 9 Nov 2025 23:14:41 +0100 Subject: [PATCH] feat: Implement session management with middleware and update authentication flow --- config/settings.py | 70 ++++++++++ dependencies.py | 72 ++++++++++ main.py | 6 +- middleware/auth_session.py | 177 ++++++++++++++++++++++++ routes/auth.py | 67 +++++---- services/session.py | 192 ++++++++++++++++++++++++++ templates/partials/sidebar_nav.html | 141 ++++++++++++------- tests/test_auth_routes.py | 51 ++++++- tests/test_auth_session_middleware.py | 111 +++++++++++++++ 9 files changed, 804 insertions(+), 83 deletions(-) create mode 100644 middleware/auth_session.py create mode 100644 services/session.py create mode 100644 tests/test_auth_session_middleware.py diff --git a/config/settings.py b/config/settings.py index 0b928d4..217a487 100644 --- a/config/settings.py +++ b/config/settings.py @@ -5,9 +5,25 @@ from dataclasses import dataclass from datetime import timedelta from functools import lru_cache +from typing import Optional + from services.security import JWTSettings +@dataclass(frozen=True, slots=True) +class SessionSettings: + """Cookie and header configuration for session token transport.""" + + access_cookie_name: str + refresh_cookie_name: str + cookie_secure: bool + cookie_domain: Optional[str] + cookie_path: str + header_name: str + header_prefix: str + allow_header_fallback: bool + + @dataclass(frozen=True, slots=True) class Settings: """Application configuration sourced from environment variables.""" @@ -16,6 +32,14 @@ class Settings: jwt_algorithm: str = "HS256" jwt_access_token_minutes: int = 15 jwt_refresh_token_days: int = 7 + session_access_cookie_name: str = "calminer_access_token" + session_refresh_cookie_name: str = "calminer_refresh_token" + session_cookie_secure: bool = False + session_cookie_domain: Optional[str] = None + session_cookie_path: str = "/" + session_header_name: str = "Authorization" + session_header_prefix: str = "Bearer" + session_allow_header_fallback: bool = True @classmethod def from_environment(cls) -> "Settings": @@ -30,6 +54,26 @@ class Settings: jwt_refresh_token_days=cls._int_from_env( "CALMINER_JWT_REFRESH_DAYS", 7 ), + session_access_cookie_name=os.getenv( + "CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token" + ), + session_refresh_cookie_name=os.getenv( + "CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token" + ), + session_cookie_secure=cls._bool_from_env( + "CALMINER_SESSION_COOKIE_SECURE", False + ), + session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"), + session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"), + session_header_name=os.getenv( + "CALMINER_SESSION_HEADER_NAME", "Authorization" + ), + session_header_prefix=os.getenv( + "CALMINER_SESSION_HEADER_PREFIX", "Bearer" + ), + session_allow_header_fallback=cls._bool_from_env( + "CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True + ), ) @staticmethod @@ -42,6 +86,18 @@ class Settings: except ValueError: return default + @staticmethod + def _bool_from_env(name: str, default: bool) -> bool: + raw_value = os.getenv(name) + if raw_value is None: + return default + lowered = raw_value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + def jwt_settings(self) -> JWTSettings: """Build runtime JWT settings compatible with token helpers.""" @@ -52,6 +108,20 @@ class Settings: refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days), ) + def session_settings(self) -> SessionSettings: + """Provide transport configuration for session tokens.""" + + return SessionSettings( + access_cookie_name=self.session_access_cookie_name, + refresh_cookie_name=self.session_refresh_cookie_name, + cookie_secure=self.session_cookie_secure, + cookie_domain=self.session_cookie_domain, + cookie_path=self.session_cookie_path, + header_name=self.session_header_name, + header_prefix=self.session_header_prefix, + allow_header_fallback=self.session_allow_header_fallback, + ) + @lru_cache(maxsize=1) def get_settings() -> Settings: diff --git a/dependencies.py b/dependencies.py index b2b0774..a043f55 100644 --- a/dependencies.py +++ b/dependencies.py @@ -2,8 +2,18 @@ from __future__ import annotations from collections.abc import Generator +from fastapi import Depends, HTTPException, Request, status + from config.settings import Settings, get_settings +from models import User from services.security import JWTSettings +from services.session import ( + AuthSession, + SessionStrategy, + SessionTokens, + build_session_strategy, + extract_session_tokens, +) from services.unit_of_work import UnitOfWork @@ -24,3 +34,65 @@ def get_jwt_settings() -> JWTSettings: """Provide JWT runtime configuration derived from settings.""" return get_settings().jwt_settings() + + +def get_session_strategy( + settings: Settings = Depends(get_application_settings), +) -> SessionStrategy: + """Yield configured session transport strategy.""" + + return build_session_strategy(settings.session_settings()) + + +def get_session_tokens( + request: Request, + strategy: SessionStrategy = Depends(get_session_strategy), +) -> SessionTokens: + """Extract raw session tokens from the incoming request.""" + + existing = getattr(request.state, "auth_session", None) + if isinstance(existing, AuthSession): + return existing.tokens + + tokens = extract_session_tokens(request, strategy) + request.state.auth_session = AuthSession(tokens=tokens) + return tokens + + +def get_auth_session( + request: Request, + tokens: SessionTokens = Depends(get_session_tokens), +) -> AuthSession: + """Provide authentication session context for the current request.""" + + existing = getattr(request.state, "auth_session", None) + if isinstance(existing, AuthSession): + return existing + + if tokens.is_empty: + session = AuthSession.anonymous() + else: + session = AuthSession(tokens=tokens) + request.state.auth_session = session + return session + + +def get_current_user( + session: AuthSession = Depends(get_auth_session), +) -> User | None: + """Return the current authenticated user if present.""" + + return session.user + + +def require_current_user( + session: AuthSession = Depends(get_auth_session), +) -> User: + """Ensure that a request is authenticated and return the user context.""" + + if session.user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required.", + ) + return session.user diff --git a/main.py b/main.py index 456aa5a..6fd8fb8 100644 --- a/main.py +++ b/main.py @@ -2,8 +2,10 @@ from typing import Awaitable, Callable from fastapi import FastAPI, Request, Response from fastapi.staticfiles import StaticFiles -from middleware.validation import validate_json + from config.database import Base, engine +from middleware.auth_session import AuthSessionMiddleware +from middleware.validation import validate_json from models import ( FinancialInput, Project, @@ -20,6 +22,8 @@ Base.metadata.create_all(bind=engine) app = FastAPI() +app.add_middleware(AuthSessionMiddleware) + @app.middleware("http") async def json_validation( diff --git a/middleware/auth_session.py b/middleware/auth_session.py new file mode 100644 index 0000000..fb2cb50 --- /dev/null +++ b/middleware/auth_session.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, Optional + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp + +from config.settings import Settings, get_settings +from models import User +from services.exceptions import EntityNotFoundError +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + decode_refresh_token, +) +from services.session import ( + AuthSession, + SessionStrategy, + SessionTokens, + build_session_strategy, + clear_session_cookies, + extract_session_tokens, + set_session_cookies, +) +from services.unit_of_work import UnitOfWork + +_AUTH_SCOPE = "auth" + + +@dataclass(slots=True) +class _ResolutionResult: + session: AuthSession + strategy: SessionStrategy + jwt_settings: JWTSettings + + +class AuthSessionMiddleware(BaseHTTPMiddleware): + """Resolve authenticated users from session cookies and refresh tokens.""" + + def __init__( + self, + app: ASGIApp, + *, + settings_provider: Callable[[], Settings] = get_settings, + unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork, + refresh_scopes: Iterable[str] | None = None, + ) -> None: + super().__init__(app) + self._settings_provider = settings_provider + self._unit_of_work_factory = unit_of_work_factory + self._refresh_scopes = tuple( + refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + resolved = self._resolve_session(request) + response = await call_next(request) + self._apply_session(response, resolved) + return response + + def _resolve_session(self, request: Request) -> _ResolutionResult: + settings = self._settings_provider() + jwt_settings = settings.jwt_settings() + strategy = build_session_strategy(settings.session_settings()) + + tokens = extract_session_tokens(request, strategy) + session = AuthSession(tokens=tokens) + request.state.auth_session = session + + if tokens.access_token: + if self._try_access_token(session, tokens, jwt_settings): + return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) + + if tokens.refresh_token: + self._try_refresh_token( + session, tokens.refresh_token, jwt_settings) + + return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) + + def _try_access_token( + self, + session: AuthSession, + tokens: SessionTokens, + jwt_settings: JWTSettings, + ) -> bool: + try: + payload = decode_access_token( + tokens.access_token or "", jwt_settings) + except TokenExpiredError: + return False + except (TokenDecodeError, TokenTypeMismatchError, TokenError): + session.mark_cleared() + return False + + user = self._load_user(payload.sub) + if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes: + session.mark_cleared() + return False + + session.user = user + session.scopes = tuple(payload.scopes) + return True + + def _try_refresh_token( + self, + session: AuthSession, + refresh_token: str, + jwt_settings: JWTSettings, + ) -> None: + try: + payload = decode_refresh_token(refresh_token, jwt_settings) + except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError): + session.mark_cleared() + return + + user = self._load_user(payload.sub) + if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes): + session.mark_cleared() + return + + session.user = user + session.scopes = tuple(payload.scopes) + + access_token = create_access_token( + str(user.id), + jwt_settings, + scopes=payload.scopes, + ) + new_refresh = create_refresh_token( + str(user.id), + jwt_settings, + scopes=payload.scopes, + ) + session.issue_tokens(access_token=access_token, + refresh_token=new_refresh) + + def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool: + candidate_scopes = set(scopes) + return any(scope in candidate_scopes for scope in self._refresh_scopes) + + def _load_user(self, subject: str) -> Optional[User]: + try: + user_id = int(subject) + except ValueError: + return None + + with self._unit_of_work_factory() as uow: + if not uow.users: + return None + try: + user = uow.users.get(user_id, with_roles=True) + except EntityNotFoundError: + return None + return user + + def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None: + session = resolved.session + if session.clear_cookies: + clear_session_cookies(response, resolved.strategy) + return + + if session.issued_access_token: + refresh_token = session.issued_refresh_token or session.tokens.refresh_token + set_session_cookies( + response, + access_token=session.issued_access_token, + refresh_token=refresh_token, + strategy=resolved.strategy, + jwt_settings=resolved.jwt_settings, + ) diff --git a/routes/auth.py b/routes/auth.py index 71a8752..19beaa0 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -9,7 +9,13 @@ from fastapi.templating import Jinja2Templates from pydantic import ValidationError from starlette.datastructures import FormData -from dependencies import get_jwt_settings, get_unit_of_work +from dependencies import ( + get_auth_session, + get_jwt_settings, + get_session_strategy, + get_unit_of_work, + require_current_user, +) from models import Role, User from schemas.auth import ( LoginForm, @@ -29,6 +35,12 @@ from services.security import ( hash_password, verify_password, ) +from services.session import ( + AuthSession, + SessionStrategy, + clear_session_cookies, + set_session_cookies, +) from services.repositories import RoleRepository, UserRepository from services.unit_of_work import UnitOfWork @@ -103,6 +115,7 @@ async def login_submit( request: Request, uow: UnitOfWork = Depends(get_unit_of_work), jwt_settings: JWTSettings = Depends(get_jwt_settings), + session_strategy: SessionStrategy = Depends(get_session_strategy), ): form_data = _normalise_form_data(await request.form()) try: @@ -158,7 +171,31 @@ async def login_submit( request.url_for("dashboard.home"), status_code=status.HTTP_303_SEE_OTHER, ) - _set_auth_cookies(response, access_token, refresh_token, jwt_settings) + set_session_cookies( + response, + access_token=access_token, + refresh_token=refresh_token, + strategy=session_strategy, + jwt_settings=jwt_settings, + ) + return response + + +@router.get("/logout", include_in_schema=False, name="auth.logout") +async def logout( + request: Request, + _: User = Depends(require_current_user), + session: AuthSession = Depends(get_auth_session), + session_strategy: SessionStrategy = Depends(get_session_strategy), +) -> RedirectResponse: + session.mark_cleared() + redirect_url = request.url_for( + "auth.login_form").include_query_params(logout="1") + response = RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + clear_session_cookies(response, session_strategy) return response @@ -168,32 +205,6 @@ def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None: return users_repo.get_by_username(identifier, with_roles=True) -def _set_auth_cookies( - response: RedirectResponse, - access_token: str, - refresh_token: str, - jwt_settings: JWTSettings, -) -> None: - access_ttl = int(jwt_settings.access_token_ttl.total_seconds()) - refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds()) - response.set_cookie( - "calminer_access_token", - access_token, - httponly=True, - secure=False, - samesite="lax", - max_age=max(access_ttl, 0) or None, - ) - response.set_cookie( - "calminer_refresh_token", - refresh_token, - httponly=True, - secure=False, - samesite="lax", - max_age=max(refresh_ttl, 0) or None, - ) - - @router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form") def register_form(request: Request) -> HTMLResponse: return _template( diff --git a/services/session.py b/services/session.py new file mode 100644 index 0000000..b989c7a --- /dev/null +++ b/services/session.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, Optional, TYPE_CHECKING + +from fastapi import Request, Response + +from config.settings import SessionSettings +from services.security import JWTSettings + +if TYPE_CHECKING: # pragma: no cover - used only for static typing + from models import User + + +@dataclass(slots=True) +class SessionStrategy: + """Describe how authentication tokens are transported with requests.""" + + access_cookie_name: str + refresh_cookie_name: str + cookie_secure: bool + cookie_domain: Optional[str] + cookie_path: str + header_name: str + header_prefix: str + allow_header_fallback: bool = True + + @classmethod + def from_settings(cls, settings: SessionSettings) -> "SessionStrategy": + return cls( + access_cookie_name=settings.access_cookie_name, + refresh_cookie_name=settings.refresh_cookie_name, + cookie_secure=settings.cookie_secure, + cookie_domain=settings.cookie_domain, + cookie_path=settings.cookie_path, + header_name=settings.header_name, + header_prefix=settings.header_prefix, + allow_header_fallback=settings.allow_header_fallback, + ) + + +@dataclass(slots=True) +class SessionTokens: + """Raw access and refresh tokens extracted from the transport layer.""" + + access_token: Optional[str] + refresh_token: Optional[str] + access_token_source: Literal["cookie", "header", "none"] = "none" + + @property + def has_access(self) -> bool: + return bool(self.access_token) + + @property + def has_refresh(self) -> bool: + return bool(self.refresh_token) + + @property + def is_empty(self) -> bool: + return not self.has_access and not self.has_refresh + + +@dataclass(slots=True) +class AuthSession: + """Holds authenticated user context resolved from session tokens.""" + + tokens: SessionTokens + user: Optional["User"] = None + scopes: tuple[str, ...] = () + issued_access_token: Optional[str] = None + issued_refresh_token: Optional[str] = None + clear_cookies: bool = False + + @property + def is_authenticated(self) -> bool: + return self.user is not None + + @classmethod + def anonymous(cls) -> "AuthSession": + return cls(tokens=SessionTokens(access_token=None, refresh_token=None)) + + def issue_tokens( + self, + *, + access_token: str, + refresh_token: Optional[str] = None, + access_source: Literal["cookie", "header", "none"] = "cookie", + ) -> None: + self.issued_access_token = access_token + if refresh_token is not None: + self.issued_refresh_token = refresh_token + self.tokens = SessionTokens( + access_token=access_token, + refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token, + access_token_source=access_source, + ) + + def mark_cleared(self) -> None: + self.clear_cookies = True + self.tokens = SessionTokens(access_token=None, refresh_token=None) + self.user = None + self.scopes = () + + +def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens: + """Pull tokens from cookies or headers according to configured strategy.""" + + access_token: Optional[str] = None + refresh_token: Optional[str] = None + access_source: Literal["cookie", "header", "none"] = "none" + + if strategy.access_cookie_name in request.cookies: + access_token = request.cookies.get(strategy.access_cookie_name) or None + if access_token: + access_source = "cookie" + + if strategy.refresh_cookie_name in request.cookies: + refresh_token = request.cookies.get( + strategy.refresh_cookie_name) or None + + if not access_token and strategy.allow_header_fallback: + header_value = request.headers.get(strategy.header_name) + if header_value: + candidate = header_value.strip() + prefix = f"{strategy.header_prefix} " if strategy.header_prefix else "" + if prefix and candidate.lower().startswith(prefix.lower()): + candidate = candidate[len(prefix):].strip() + if candidate: + access_token = candidate + access_source = "header" + + return SessionTokens( + access_token=access_token, + refresh_token=refresh_token, + access_token_source=access_source, + ) + + +def build_session_strategy(settings: SessionSettings) -> SessionStrategy: + """Create a session strategy object from settings configuration.""" + + return SessionStrategy.from_settings(settings) + + +def set_session_cookies( + response: Response, + *, + access_token: str, + refresh_token: Optional[str], + strategy: SessionStrategy, + jwt_settings: JWTSettings, +) -> None: + """Persist session cookies on an outgoing response.""" + + access_ttl = int(jwt_settings.access_token_ttl.total_seconds()) + refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds()) + response.set_cookie( + strategy.access_cookie_name, + access_token, + httponly=True, + secure=strategy.cookie_secure, + samesite="lax", + max_age=max(access_ttl, 0) or None, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + if refresh_token is not None: + response.set_cookie( + strategy.refresh_cookie_name, + refresh_token, + httponly=True, + secure=strategy.cookie_secure, + samesite="lax", + max_age=max(refresh_ttl, 0) or None, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + + +def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None: + """Remove session cookies from the client.""" + + response.delete_cookie( + strategy.access_cookie_name, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + response.delete_cookie( + strategy.refresh_cookie_name, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) diff --git a/templates/partials/sidebar_nav.html b/templates/partials/sidebar_nav.html index 1980125..90e3c1a 100644 --- a/templates/partials/sidebar_nav.html +++ b/templates/partials/sidebar_nav.html @@ -1,60 +1,97 @@ {% set dashboard_href = request.url_for('dashboard.home') if request else '/' %} -{% set projects_href = request.url_for('projects.project_list_page') if request -else '/projects/ui' %} {% set project_create_href = -request.url_for('projects.create_project_form') if request else -'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if -request else '/login' %} {% set register_href = -request.url_for('auth.register_form') if request else '/register' %} {% set -forgot_href = request.url_for('auth.password_reset_request_form') if request -else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace", -"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, -{"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, -{"href": project_create_href, "label": "New Project", "match_prefix": -"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href": -"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label": -"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href": -"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings", -"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, -], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label": -"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register", -"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password", -"match_prefix": "/forgot-password"}, ], }, ] %} +{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %} +{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %} +{% set auth_session = request.state.auth_session if request else None %} +{% set is_authenticated = auth_session and auth_session.is_authenticated %} + +{% if is_authenticated %} + {% set logout_href = request.url_for('auth.logout') if request else '/logout' %} + {% set account_links = [ + {"href": logout_href, "label": "Logout", "match_prefix": "/logout"} + ] %} +{% else %} + {% set login_href = request.url_for('auth.login_form') if request else '/login' %} + {% set register_href = request.url_for('auth.register_form') if request else '/register' %} + {% set forgot_href = request.url_for('auth.password_reset_request_form') if request else '/forgot-password' %} + {% set account_links = [ + {"href": login_href, "label": "Login", "match_prefix": "/login"}, + {"href": register_href, "label": "Register", "match_prefix": "/register"}, + {"href": forgot_href, "label": "Forgot Password", "match_prefix": "/forgot-password"} + ] %} +{% endif %} +{% set nav_groups = [ + { + "label": "Workspace", + "links": [ + {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, + {"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, + {"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"} + ] + }, + { + "label": "Insights", + "links": [ + {"href": "/ui/simulations", "label": "Simulations"}, + {"href": "/ui/reporting", "label": "Reporting"} + ] + }, + { + "label": "Configuration", + "links": [ + { + "href": "/ui/settings", + "label": "Settings", + "children": [ + {"href": "/theme-settings", "label": "Themes"}, + {"href": "/ui/currencies", "label": "Currency Management"} + ] + } + ] + }, + { + "label": "Account", + "links": account_links + } +] %} diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py index 08b4956..9196503 100644 --- a/tests/test_auth_routes.py +++ b/tests/test_auth_routes.py @@ -1,15 +1,19 @@ from __future__ import annotations from collections.abc import Iterator +from typing import cast from urllib.parse import parse_qs, urlparse import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from models import Role, User, UserRole +from dependencies import get_auth_session, require_current_user from services.security import hash_password +from services.session import AuthSession, SessionTokens @pytest.fixture() @@ -223,7 +227,8 @@ class TestPasswordResetFlow: data={"email": "mismatch@example.com"}, follow_redirects=False, ) - token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0] + token = parse_qs(urlparse(request_response.headers["location"]).query)[ + "token"][0] submit_response = client.post( "/reset-password", @@ -236,4 +241,46 @@ class TestPasswordResetFlow: ) assert submit_response.status_code == 400 - assert "Passwords do not match" in submit_response.text \ No newline at end of file + assert "Passwords do not match" in submit_response.text + + +class TestLogoutFlow: + def test_logout_clears_cookies_and_redirects( + self, + client: TestClient, + db_session: Session, + ) -> None: + user = User( + email="logout@example.com", + username="logoutuser", + password_hash=hash_password("SecureP@ss1"), + is_active=True, + ) + db_session.add(user) + db_session.commit() + + session = AuthSession( + tokens=SessionTokens( + access_token="access-token", + refresh_token="refresh-token", + access_token_source="cookie", + ), + user=user, + ) + + app = cast(FastAPI, client.app) + app.dependency_overrides[require_current_user] = lambda: user + app.dependency_overrides[get_auth_session] = lambda: session + + try: + response = client.get("/logout", follow_redirects=False) + finally: + app.dependency_overrides.pop(require_current_user, None) + app.dependency_overrides.pop(get_auth_session, None) + + assert response.status_code == 303 + location = response.headers.get("location") + assert location and location.startswith("http://testserver/login") + set_cookie_header = response.headers.get("set-cookie") or "" + assert "calminer_access_token=" in set_cookie_header + assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower() diff --git a/tests/test_auth_session_middleware.py b/tests/test_auth_session_middleware.py new file mode 100644 index 0000000..d31ab71 --- /dev/null +++ b/tests/test_auth_session_middleware.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import Iterator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient +from sqlalchemy.orm import sessionmaker + +from config.settings import get_settings +from dependencies import get_unit_of_work, require_current_user +from middleware.auth_session import AuthSessionMiddleware +from models import User +from services.security import create_access_token, create_refresh_token, hash_password +from services.unit_of_work import UnitOfWork + + +@pytest.fixture() +def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]: + app = FastAPI() + + def _override_uow() -> Iterator[UnitOfWork]: + with UnitOfWork(session_factory=session_factory) as uow: + yield uow + + app.dependency_overrides[get_unit_of_work] = _override_uow + + @app.get("/me") + def read_me(user: User = Depends(require_current_user)) -> JSONResponse: + return JSONResponse({"id": user.id, "username": user.username}) + + app.add_middleware( + AuthSessionMiddleware, + unit_of_work_factory=lambda: UnitOfWork( + session_factory=session_factory), + ) + + client = TestClient(app) + try: + yield client + finally: + client.close() + + +def _create_user(session_factory: sessionmaker) -> User: + with UnitOfWork(session_factory=session_factory) as uow: + assert uow.users is not None + user = User( + email="jane@example.com", + username="jane", + password_hash=hash_password("secret"), + is_active=True, + is_superuser=False, + ) + uow.users.create(user) + return user + + +def _issue_tokens(user: User) -> tuple[str, str]: + settings = get_settings().jwt_settings() + access = create_access_token(str(user.id), settings, scopes=["auth"]) + refresh = create_refresh_token(str(user.id), settings, scopes=["auth"]) + return access, refresh + + +def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None: + user = _create_user(session_factory) + access, refresh = _issue_tokens(user) + + auth_app.cookies.set("calminer_access_token", access) + auth_app.cookies.set("calminer_refresh_token", refresh) + + response = auth_app.get("/me") + assert response.status_code == 200 + payload = response.json() + assert payload["id"] == user.id + assert payload["username"] == user.username + + +def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None: + user = _create_user(session_factory) + settings = get_settings().jwt_settings() + expired = create_access_token( + str(user.id), + settings, + scopes=["auth"], + expires_delta=timedelta(seconds=-1), + ) + refresh = create_refresh_token(str(user.id), settings, scopes=["auth"]) + + auth_app.cookies.set("calminer_access_token", expired) + auth_app.cookies.set("calminer_refresh_token", refresh) + + response = auth_app.get("/me") + assert response.status_code == 200 + new_access = response.cookies.get("calminer_access_token") + new_refresh = response.cookies.get("calminer_refresh_token") + assert new_access is not None and new_access != expired + assert new_refresh is not None + + +def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None: + auth_app.cookies.set("calminer_access_token", "invalid-token") + auth_app.cookies.set("calminer_refresh_token", "invalid-token") + + response = auth_app.get("/me") + assert response.status_code == 401 + set_cookies = response.headers.get_list("set-cookie") + assert any("calminer_access_token=" in value for value in set_cookies)