feat: Implement session management with middleware and update authentication flow

This commit is contained in:
2025-11-09 23:14:41 +01:00
parent 3601c2e422
commit 27262bdfa3
9 changed files with 804 additions and 83 deletions

View File

@@ -5,9 +5,25 @@ from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from functools import lru_cache from functools import lru_cache
from typing import Optional
from services.security import JWTSettings from services.security import JWTSettings
@dataclass(frozen=True, slots=True)
class SessionSettings:
"""Cookie and header configuration for session token transport."""
access_cookie_name: str
refresh_cookie_name: str
cookie_secure: bool
cookie_domain: Optional[str]
cookie_path: str
header_name: str
header_prefix: str
allow_header_fallback: bool
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class Settings: class Settings:
"""Application configuration sourced from environment variables.""" """Application configuration sourced from environment variables."""
@@ -16,6 +32,14 @@ class Settings:
jwt_algorithm: str = "HS256" jwt_algorithm: str = "HS256"
jwt_access_token_minutes: int = 15 jwt_access_token_minutes: int = 15
jwt_refresh_token_days: int = 7 jwt_refresh_token_days: int = 7
session_access_cookie_name: str = "calminer_access_token"
session_refresh_cookie_name: str = "calminer_refresh_token"
session_cookie_secure: bool = False
session_cookie_domain: Optional[str] = None
session_cookie_path: str = "/"
session_header_name: str = "Authorization"
session_header_prefix: str = "Bearer"
session_allow_header_fallback: bool = True
@classmethod @classmethod
def from_environment(cls) -> "Settings": def from_environment(cls) -> "Settings":
@@ -30,6 +54,26 @@ class Settings:
jwt_refresh_token_days=cls._int_from_env( jwt_refresh_token_days=cls._int_from_env(
"CALMINER_JWT_REFRESH_DAYS", 7 "CALMINER_JWT_REFRESH_DAYS", 7
), ),
session_access_cookie_name=os.getenv(
"CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token"
),
session_refresh_cookie_name=os.getenv(
"CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token"
),
session_cookie_secure=cls._bool_from_env(
"CALMINER_SESSION_COOKIE_SECURE", False
),
session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"),
session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"),
session_header_name=os.getenv(
"CALMINER_SESSION_HEADER_NAME", "Authorization"
),
session_header_prefix=os.getenv(
"CALMINER_SESSION_HEADER_PREFIX", "Bearer"
),
session_allow_header_fallback=cls._bool_from_env(
"CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True
),
) )
@staticmethod @staticmethod
@@ -42,6 +86,18 @@ class Settings:
except ValueError: except ValueError:
return default return default
@staticmethod
def _bool_from_env(name: str, default: bool) -> bool:
raw_value = os.getenv(name)
if raw_value is None:
return default
lowered = raw_value.strip().lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
return default
def jwt_settings(self) -> JWTSettings: def jwt_settings(self) -> JWTSettings:
"""Build runtime JWT settings compatible with token helpers.""" """Build runtime JWT settings compatible with token helpers."""
@@ -52,6 +108,20 @@ class Settings:
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days), refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
) )
def session_settings(self) -> SessionSettings:
"""Provide transport configuration for session tokens."""
return SessionSettings(
access_cookie_name=self.session_access_cookie_name,
refresh_cookie_name=self.session_refresh_cookie_name,
cookie_secure=self.session_cookie_secure,
cookie_domain=self.session_cookie_domain,
cookie_path=self.session_cookie_path,
header_name=self.session_header_name,
header_prefix=self.session_header_prefix,
allow_header_fallback=self.session_allow_header_fallback,
)
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_settings() -> Settings: def get_settings() -> Settings:

View File

@@ -2,8 +2,18 @@ from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
from fastapi import Depends, HTTPException, Request, status
from config.settings import Settings, get_settings from config.settings import Settings, get_settings
from models import User
from services.security import JWTSettings from services.security import JWTSettings
from services.session import (
AuthSession,
SessionStrategy,
SessionTokens,
build_session_strategy,
extract_session_tokens,
)
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
@@ -24,3 +34,65 @@ def get_jwt_settings() -> JWTSettings:
"""Provide JWT runtime configuration derived from settings.""" """Provide JWT runtime configuration derived from settings."""
return get_settings().jwt_settings() return get_settings().jwt_settings()
def get_session_strategy(
settings: Settings = Depends(get_application_settings),
) -> SessionStrategy:
"""Yield configured session transport strategy."""
return build_session_strategy(settings.session_settings())
def get_session_tokens(
request: Request,
strategy: SessionStrategy = Depends(get_session_strategy),
) -> SessionTokens:
"""Extract raw session tokens from the incoming request."""
existing = getattr(request.state, "auth_session", None)
if isinstance(existing, AuthSession):
return existing.tokens
tokens = extract_session_tokens(request, strategy)
request.state.auth_session = AuthSession(tokens=tokens)
return tokens
def get_auth_session(
request: Request,
tokens: SessionTokens = Depends(get_session_tokens),
) -> AuthSession:
"""Provide authentication session context for the current request."""
existing = getattr(request.state, "auth_session", None)
if isinstance(existing, AuthSession):
return existing
if tokens.is_empty:
session = AuthSession.anonymous()
else:
session = AuthSession(tokens=tokens)
request.state.auth_session = session
return session
def get_current_user(
session: AuthSession = Depends(get_auth_session),
) -> User | None:
"""Return the current authenticated user if present."""
return session.user
def require_current_user(
session: AuthSession = Depends(get_auth_session),
) -> User:
"""Ensure that a request is authenticated and return the user context."""
if session.user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required.",
)
return session.user

View File

@@ -2,8 +2,10 @@ from typing import Awaitable, Callable
from fastapi import FastAPI, Request, Response from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from middleware.validation import validate_json
from config.database import Base, engine from config.database import Base, engine
from middleware.auth_session import AuthSessionMiddleware
from middleware.validation import validate_json
from models import ( from models import (
FinancialInput, FinancialInput,
Project, Project,
@@ -20,6 +22,8 @@ Base.metadata.create_all(bind=engine)
app = FastAPI() app = FastAPI()
app.add_middleware(AuthSessionMiddleware)
@app.middleware("http") @app.middleware("http")
async def json_validation( async def json_validation(

177
middleware/auth_session.py Normal file
View File

@@ -0,0 +1,177 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Iterable, Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp
from config.settings import Settings, get_settings
from models import User
from services.exceptions import EntityNotFoundError
from services.security import (
JWTSettings,
TokenDecodeError,
TokenError,
TokenExpiredError,
TokenTypeMismatchError,
create_access_token,
create_refresh_token,
decode_access_token,
decode_refresh_token,
)
from services.session import (
AuthSession,
SessionStrategy,
SessionTokens,
build_session_strategy,
clear_session_cookies,
extract_session_tokens,
set_session_cookies,
)
from services.unit_of_work import UnitOfWork
_AUTH_SCOPE = "auth"
@dataclass(slots=True)
class _ResolutionResult:
session: AuthSession
strategy: SessionStrategy
jwt_settings: JWTSettings
class AuthSessionMiddleware(BaseHTTPMiddleware):
"""Resolve authenticated users from session cookies and refresh tokens."""
def __init__(
self,
app: ASGIApp,
*,
settings_provider: Callable[[], Settings] = get_settings,
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
refresh_scopes: Iterable[str] | None = None,
) -> None:
super().__init__(app)
self._settings_provider = settings_provider
self._unit_of_work_factory = unit_of_work_factory
self._refresh_scopes = tuple(
refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,)
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
resolved = self._resolve_session(request)
response = await call_next(request)
self._apply_session(response, resolved)
return response
def _resolve_session(self, request: Request) -> _ResolutionResult:
settings = self._settings_provider()
jwt_settings = settings.jwt_settings()
strategy = build_session_strategy(settings.session_settings())
tokens = extract_session_tokens(request, strategy)
session = AuthSession(tokens=tokens)
request.state.auth_session = session
if tokens.access_token:
if self._try_access_token(session, tokens, jwt_settings):
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
if tokens.refresh_token:
self._try_refresh_token(
session, tokens.refresh_token, jwt_settings)
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
def _try_access_token(
self,
session: AuthSession,
tokens: SessionTokens,
jwt_settings: JWTSettings,
) -> bool:
try:
payload = decode_access_token(
tokens.access_token or "", jwt_settings)
except TokenExpiredError:
return False
except (TokenDecodeError, TokenTypeMismatchError, TokenError):
session.mark_cleared()
return False
user = self._load_user(payload.sub)
if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
session.mark_cleared()
return False
session.user = user
session.scopes = tuple(payload.scopes)
return True
def _try_refresh_token(
self,
session: AuthSession,
refresh_token: str,
jwt_settings: JWTSettings,
) -> None:
try:
payload = decode_refresh_token(refresh_token, jwt_settings)
except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
session.mark_cleared()
return
user = self._load_user(payload.sub)
if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
session.mark_cleared()
return
session.user = user
session.scopes = tuple(payload.scopes)
access_token = create_access_token(
str(user.id),
jwt_settings,
scopes=payload.scopes,
)
new_refresh = create_refresh_token(
str(user.id),
jwt_settings,
scopes=payload.scopes,
)
session.issue_tokens(access_token=access_token,
refresh_token=new_refresh)
def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
candidate_scopes = set(scopes)
return any(scope in candidate_scopes for scope in self._refresh_scopes)
def _load_user(self, subject: str) -> Optional[User]:
try:
user_id = int(subject)
except ValueError:
return None
with self._unit_of_work_factory() as uow:
if not uow.users:
return None
try:
user = uow.users.get(user_id, with_roles=True)
except EntityNotFoundError:
return None
return user
def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
session = resolved.session
if session.clear_cookies:
clear_session_cookies(response, resolved.strategy)
return
if session.issued_access_token:
refresh_token = session.issued_refresh_token or session.tokens.refresh_token
set_session_cookies(
response,
access_token=session.issued_access_token,
refresh_token=refresh_token,
strategy=resolved.strategy,
jwt_settings=resolved.jwt_settings,
)

View File

@@ -9,7 +9,13 @@ from fastapi.templating import Jinja2Templates
from pydantic import ValidationError from pydantic import ValidationError
from starlette.datastructures import FormData from starlette.datastructures import FormData
from dependencies import get_jwt_settings, get_unit_of_work from dependencies import (
get_auth_session,
get_jwt_settings,
get_session_strategy,
get_unit_of_work,
require_current_user,
)
from models import Role, User from models import Role, User
from schemas.auth import ( from schemas.auth import (
LoginForm, LoginForm,
@@ -29,6 +35,12 @@ from services.security import (
hash_password, hash_password,
verify_password, verify_password,
) )
from services.session import (
AuthSession,
SessionStrategy,
clear_session_cookies,
set_session_cookies,
)
from services.repositories import RoleRepository, UserRepository from services.repositories import RoleRepository, UserRepository
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
@@ -103,6 +115,7 @@ async def login_submit(
request: Request, request: Request,
uow: UnitOfWork = Depends(get_unit_of_work), uow: UnitOfWork = Depends(get_unit_of_work),
jwt_settings: JWTSettings = Depends(get_jwt_settings), jwt_settings: JWTSettings = Depends(get_jwt_settings),
session_strategy: SessionStrategy = Depends(get_session_strategy),
): ):
form_data = _normalise_form_data(await request.form()) form_data = _normalise_form_data(await request.form())
try: try:
@@ -158,7 +171,31 @@ async def login_submit(
request.url_for("dashboard.home"), request.url_for("dashboard.home"),
status_code=status.HTTP_303_SEE_OTHER, status_code=status.HTTP_303_SEE_OTHER,
) )
_set_auth_cookies(response, access_token, refresh_token, jwt_settings) set_session_cookies(
response,
access_token=access_token,
refresh_token=refresh_token,
strategy=session_strategy,
jwt_settings=jwt_settings,
)
return response
@router.get("/logout", include_in_schema=False, name="auth.logout")
async def logout(
request: Request,
_: User = Depends(require_current_user),
session: AuthSession = Depends(get_auth_session),
session_strategy: SessionStrategy = Depends(get_session_strategy),
) -> RedirectResponse:
session.mark_cleared()
redirect_url = request.url_for(
"auth.login_form").include_query_params(logout="1")
response = RedirectResponse(
redirect_url,
status_code=status.HTTP_303_SEE_OTHER,
)
clear_session_cookies(response, session_strategy)
return response return response
@@ -168,32 +205,6 @@ def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
return users_repo.get_by_username(identifier, with_roles=True) return users_repo.get_by_username(identifier, with_roles=True)
def _set_auth_cookies(
response: RedirectResponse,
access_token: str,
refresh_token: str,
jwt_settings: JWTSettings,
) -> None:
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
response.set_cookie(
"calminer_access_token",
access_token,
httponly=True,
secure=False,
samesite="lax",
max_age=max(access_ttl, 0) or None,
)
response.set_cookie(
"calminer_refresh_token",
refresh_token,
httponly=True,
secure=False,
samesite="lax",
max_age=max(refresh_ttl, 0) or None,
)
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form") @router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
def register_form(request: Request) -> HTMLResponse: def register_form(request: Request) -> HTMLResponse:
return _template( return _template(

192
services/session.py Normal file
View File

@@ -0,0 +1,192 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Optional, TYPE_CHECKING
from fastapi import Request, Response
from config.settings import SessionSettings
from services.security import JWTSettings
if TYPE_CHECKING: # pragma: no cover - used only for static typing
from models import User
@dataclass(slots=True)
class SessionStrategy:
"""Describe how authentication tokens are transported with requests."""
access_cookie_name: str
refresh_cookie_name: str
cookie_secure: bool
cookie_domain: Optional[str]
cookie_path: str
header_name: str
header_prefix: str
allow_header_fallback: bool = True
@classmethod
def from_settings(cls, settings: SessionSettings) -> "SessionStrategy":
return cls(
access_cookie_name=settings.access_cookie_name,
refresh_cookie_name=settings.refresh_cookie_name,
cookie_secure=settings.cookie_secure,
cookie_domain=settings.cookie_domain,
cookie_path=settings.cookie_path,
header_name=settings.header_name,
header_prefix=settings.header_prefix,
allow_header_fallback=settings.allow_header_fallback,
)
@dataclass(slots=True)
class SessionTokens:
"""Raw access and refresh tokens extracted from the transport layer."""
access_token: Optional[str]
refresh_token: Optional[str]
access_token_source: Literal["cookie", "header", "none"] = "none"
@property
def has_access(self) -> bool:
return bool(self.access_token)
@property
def has_refresh(self) -> bool:
return bool(self.refresh_token)
@property
def is_empty(self) -> bool:
return not self.has_access and not self.has_refresh
@dataclass(slots=True)
class AuthSession:
"""Holds authenticated user context resolved from session tokens."""
tokens: SessionTokens
user: Optional["User"] = None
scopes: tuple[str, ...] = ()
issued_access_token: Optional[str] = None
issued_refresh_token: Optional[str] = None
clear_cookies: bool = False
@property
def is_authenticated(self) -> bool:
return self.user is not None
@classmethod
def anonymous(cls) -> "AuthSession":
return cls(tokens=SessionTokens(access_token=None, refresh_token=None))
def issue_tokens(
self,
*,
access_token: str,
refresh_token: Optional[str] = None,
access_source: Literal["cookie", "header", "none"] = "cookie",
) -> None:
self.issued_access_token = access_token
if refresh_token is not None:
self.issued_refresh_token = refresh_token
self.tokens = SessionTokens(
access_token=access_token,
refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token,
access_token_source=access_source,
)
def mark_cleared(self) -> None:
self.clear_cookies = True
self.tokens = SessionTokens(access_token=None, refresh_token=None)
self.user = None
self.scopes = ()
def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens:
"""Pull tokens from cookies or headers according to configured strategy."""
access_token: Optional[str] = None
refresh_token: Optional[str] = None
access_source: Literal["cookie", "header", "none"] = "none"
if strategy.access_cookie_name in request.cookies:
access_token = request.cookies.get(strategy.access_cookie_name) or None
if access_token:
access_source = "cookie"
if strategy.refresh_cookie_name in request.cookies:
refresh_token = request.cookies.get(
strategy.refresh_cookie_name) or None
if not access_token and strategy.allow_header_fallback:
header_value = request.headers.get(strategy.header_name)
if header_value:
candidate = header_value.strip()
prefix = f"{strategy.header_prefix} " if strategy.header_prefix else ""
if prefix and candidate.lower().startswith(prefix.lower()):
candidate = candidate[len(prefix):].strip()
if candidate:
access_token = candidate
access_source = "header"
return SessionTokens(
access_token=access_token,
refresh_token=refresh_token,
access_token_source=access_source,
)
def build_session_strategy(settings: SessionSettings) -> SessionStrategy:
"""Create a session strategy object from settings configuration."""
return SessionStrategy.from_settings(settings)
def set_session_cookies(
response: Response,
*,
access_token: str,
refresh_token: Optional[str],
strategy: SessionStrategy,
jwt_settings: JWTSettings,
) -> None:
"""Persist session cookies on an outgoing response."""
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
response.set_cookie(
strategy.access_cookie_name,
access_token,
httponly=True,
secure=strategy.cookie_secure,
samesite="lax",
max_age=max(access_ttl, 0) or None,
domain=strategy.cookie_domain,
path=strategy.cookie_path,
)
if refresh_token is not None:
response.set_cookie(
strategy.refresh_cookie_name,
refresh_token,
httponly=True,
secure=strategy.cookie_secure,
samesite="lax",
max_age=max(refresh_ttl, 0) or None,
domain=strategy.cookie_domain,
path=strategy.cookie_path,
)
def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None:
"""Remove session cookies from the client."""
response.delete_cookie(
strategy.access_cookie_name,
domain=strategy.cookie_domain,
path=strategy.cookie_path,
)
response.delete_cookie(
strategy.refresh_cookie_name,
domain=strategy.cookie_domain,
path=strategy.cookie_path,
)

View File

@@ -1,60 +1,97 @@
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %} {% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
{% set projects_href = request.url_for('projects.project_list_page') if request {% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %}
else '/projects/ui' %} {% set project_create_href = {% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %}
request.url_for('projects.create_project_form') if request else {% set auth_session = request.state.auth_session if request else None %}
'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if {% set is_authenticated = auth_session and auth_session.is_authenticated %}
request else '/login' %} {% set register_href =
request.url_for('auth.register_form') if request else '/register' %} {% set {% if is_authenticated %}
forgot_href = request.url_for('auth.password_reset_request_form') if request {% set logout_href = request.url_for('auth.logout') if request else '/logout' %}
else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace", {% set account_links = [
"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, {"href": logout_href, "label": "Logout", "match_prefix": "/logout"}
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, ] %}
{"href": project_create_href, "label": "New Project", "match_prefix": {% else %}
"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href": {% set login_href = request.url_for('auth.login_form') if request else '/login' %}
"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label": {% set register_href = request.url_for('auth.register_form') if request else '/register' %}
"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href": {% set forgot_href = request.url_for('auth.password_reset_request_form') if request else '/forgot-password' %}
"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings", {% set account_links = [
"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, {"href": login_href, "label": "Login", "match_prefix": "/login"},
], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label": {"href": register_href, "label": "Register", "match_prefix": "/register"},
"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register", {"href": forgot_href, "label": "Forgot Password", "match_prefix": "/forgot-password"}
"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password", ] %}
"match_prefix": "/forgot-password"}, ], }, ] %} {% endif %}
{% set nav_groups = [
{
"label": "Workspace",
"links": [
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}
]
},
{
"label": "Insights",
"links": [
{"href": "/ui/simulations", "label": "Simulations"},
{"href": "/ui/reporting", "label": "Reporting"}
]
},
{
"label": "Configuration",
"links": [
{
"href": "/ui/settings",
"label": "Settings",
"children": [
{"href": "/theme-settings", "label": "Themes"},
{"href": "/ui/currencies", "label": "Currency Management"}
]
}
]
},
{
"label": "Account",
"links": account_links
}
] %}
<nav class="sidebar-nav" aria-label="Primary navigation"> <nav class="sidebar-nav" aria-label="Primary navigation">
{% set current_path = request.url.path if request else "" %} {% for group in {% set current_path = request.url.path if request else '' %}
nav_groups %} {% for group in nav_groups %}
<div class="sidebar-section"> {% if group.links %}
<div class="sidebar-section-label">{{ group.label }}</div> <div class="sidebar-section">
<div class="sidebar-section-links"> <div class="sidebar-section-label">{{ group.label }}</div>
{% for link in group.links %} {% set href = link.href %} {% set <div class="sidebar-section-links">
match_prefix = link.get('match_prefix', href) %} {% if match_prefix == '/' {% for link in group.links %}
%} {% set is_active = current_path == '/' %} {% else %} {% set is_active = {% set href = link.href %}
current_path.startswith(match_prefix) %} {% endif %} {% set match_prefix = link.get('match_prefix', href) %}
<div class="sidebar-link-block"> {% if match_prefix == '/' %}
<a {% set is_active = current_path == '/' %}
href="{{ href }}" {% else %}
class="sidebar-link{% if is_active %} is-active{% endif %}" {% set is_active = current_path.startswith(match_prefix) %}
> {% endif %}
{{ link.label }} <div class="sidebar-link-block">
</a> <a href="{{ href }}" class="sidebar-link{% if is_active %} is-active{% endif %}">
{% if link.children %} {{ link.label }}
<div class="sidebar-sublinks"> </a>
{% for child in link.children %} {% set child_prefix = {% if link.children %}
child.get('match_prefix', child.href) %} {% if child_prefix == '/' %} <div class="sidebar-sublinks">
{% set child_active = current_path == '/' %} {% else %} {% set {% for child in link.children %}
child_active = current_path.startswith(child_prefix) %} {% endif %} {% set child_prefix = child.get('match_prefix', child.href) %}
<a {% if child_prefix == '/' %}
href="{{ child.href }}" {% set child_active = current_path == '/' %}
class="sidebar-sublink{% if child_active %} is-active{% endif %}" {% else %}
> {% set child_active = current_path.startswith(child_prefix) %}
{{ child.label }} {% endif %}
</a> <a href="{{ child.href }}" class="sidebar-sublink{% if child_active %} is-active{% endif %}">
{{ child.label }}
</a>
{% endfor %}
</div>
{% endif %}
</div>
{% endfor %} {% endfor %}
</div> </div>
{% endif %}
</div> </div>
{% endfor %} {% endif %}
</div>
</div>
{% endfor %} {% endfor %}
</nav> </nav>

View File

@@ -1,15 +1,19 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterator from collections.abc import Iterator
from typing import cast
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import pytest import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from models import Role, User, UserRole from models import Role, User, UserRole
from dependencies import get_auth_session, require_current_user
from services.security import hash_password from services.security import hash_password
from services.session import AuthSession, SessionTokens
@pytest.fixture() @pytest.fixture()
@@ -223,7 +227,8 @@ class TestPasswordResetFlow:
data={"email": "mismatch@example.com"}, data={"email": "mismatch@example.com"},
follow_redirects=False, follow_redirects=False,
) )
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0] token = parse_qs(urlparse(request_response.headers["location"]).query)[
"token"][0]
submit_response = client.post( submit_response = client.post(
"/reset-password", "/reset-password",
@@ -237,3 +242,45 @@ class TestPasswordResetFlow:
assert submit_response.status_code == 400 assert submit_response.status_code == 400
assert "Passwords do not match" in submit_response.text assert "Passwords do not match" in submit_response.text
class TestLogoutFlow:
def test_logout_clears_cookies_and_redirects(
self,
client: TestClient,
db_session: Session,
) -> None:
user = User(
email="logout@example.com",
username="logoutuser",
password_hash=hash_password("SecureP@ss1"),
is_active=True,
)
db_session.add(user)
db_session.commit()
session = AuthSession(
tokens=SessionTokens(
access_token="access-token",
refresh_token="refresh-token",
access_token_source="cookie",
),
user=user,
)
app = cast(FastAPI, client.app)
app.dependency_overrides[require_current_user] = lambda: user
app.dependency_overrides[get_auth_session] = lambda: session
try:
response = client.get("/logout", follow_redirects=False)
finally:
app.dependency_overrides.pop(require_current_user, None)
app.dependency_overrides.pop(get_auth_session, None)
assert response.status_code == 303
location = response.headers.get("location")
assert location and location.startswith("http://testserver/login")
set_cookie_header = response.headers.get("set-cookie") or ""
assert "calminer_access_token=" in set_cookie_header
assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()

View File

@@ -0,0 +1,111 @@
from __future__ import annotations
from datetime import timedelta
from typing import Iterator
import pytest
from fastapi import Depends, FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker
from config.settings import get_settings
from dependencies import get_unit_of_work, require_current_user
from middleware.auth_session import AuthSessionMiddleware
from models import User
from services.security import create_access_token, create_refresh_token, hash_password
from services.unit_of_work import UnitOfWork
@pytest.fixture()
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
app = FastAPI()
def _override_uow() -> Iterator[UnitOfWork]:
with UnitOfWork(session_factory=session_factory) as uow:
yield uow
app.dependency_overrides[get_unit_of_work] = _override_uow
@app.get("/me")
def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
return JSONResponse({"id": user.id, "username": user.username})
app.add_middleware(
AuthSessionMiddleware,
unit_of_work_factory=lambda: UnitOfWork(
session_factory=session_factory),
)
client = TestClient(app)
try:
yield client
finally:
client.close()
def _create_user(session_factory: sessionmaker) -> User:
with UnitOfWork(session_factory=session_factory) as uow:
assert uow.users is not None
user = User(
email="jane@example.com",
username="jane",
password_hash=hash_password("secret"),
is_active=True,
is_superuser=False,
)
uow.users.create(user)
return user
def _issue_tokens(user: User) -> tuple[str, str]:
settings = get_settings().jwt_settings()
access = create_access_token(str(user.id), settings, scopes=["auth"])
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
return access, refresh
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
user = _create_user(session_factory)
access, refresh = _issue_tokens(user)
auth_app.cookies.set("calminer_access_token", access)
auth_app.cookies.set("calminer_refresh_token", refresh)
response = auth_app.get("/me")
assert response.status_code == 200
payload = response.json()
assert payload["id"] == user.id
assert payload["username"] == user.username
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
user = _create_user(session_factory)
settings = get_settings().jwt_settings()
expired = create_access_token(
str(user.id),
settings,
scopes=["auth"],
expires_delta=timedelta(seconds=-1),
)
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
auth_app.cookies.set("calminer_access_token", expired)
auth_app.cookies.set("calminer_refresh_token", refresh)
response = auth_app.get("/me")
assert response.status_code == 200
new_access = response.cookies.get("calminer_access_token")
new_refresh = response.cookies.get("calminer_refresh_token")
assert new_access is not None and new_access != expired
assert new_refresh is not None
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
auth_app.cookies.set("calminer_access_token", "invalid-token")
auth_app.cookies.set("calminer_refresh_token", "invalid-token")
response = auth_app.get("/me")
assert response.status_code == 401
set_cookies = response.headers.get_list("set-cookie")
assert any("calminer_access_token=" in value for value in set_cookies)