Files
calminer/services/security.py
zwitschi 6d496a599e feat: Resolve test suite regressions and enhance token tamper detection
feat: Add UI router to application for improved routing
style: Update breadcrumb styles in main.css and remove redundant styles from scenarios.css
2025-11-12 20:30:40 +01:00

223 lines
6.1 KiB
Python

from __future__ import annotations

import base64
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from hmac import compare_digest
from typing import Any, Dict, Iterable, Literal, Type

from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
# passlib reads ``argon2.__version__``; presumably argon2-cffi warns when that
# attribute is accessed (TODO confirm against the pinned version). Writing the
# installed distribution's version onto the module up front avoids the warning.
# This is strictly best-effort: if argon2 is not installed or the metadata
# lookup fails, the error is ignored.
try:  # pragma: no cover - compatibility shim for passlib/argon2 warning
    from importlib import metadata as importlib_metadata

    import argon2  # type: ignore

    setattr(
        argon2,
        "__version__",
        importlib_metadata.version("argon2-cffi"),
    )
except Exception:  # pragma: no cover - argon2 missing or metadata lookup failed
    pass
from pydantic import BaseModel, Field, ValidationError
# Shared passlib context used by hash_password/verify_password. argon2 is the
# only enabled scheme; deprecated="auto" lets passlib treat any non-default
# scheme as deprecated.
password_context = CryptContext(deprecated="auto", schemes=["argon2"])
def hash_password(password: str) -> str:
    """Return a storable hash of the plain-text *password*.

    Hashing is delegated to the module-level ``password_context`` (argon2).
    """
    digest = password_context.hash(password)
    return digest
def verify_password(candidate: str, hashed: str) -> bool:
    """Check a plain-text *candidate* against a stored password hash.

    Returns ``False`` for a wrong password and also when passlib rejects the
    stored hash with ``ValueError``, so a corrupt stored value never surfaces
    as an exception to the caller.
    """
    try:
        matches = password_context.verify(candidate, hashed)
    except ValueError:
        # The stored hash is malformed or uses a scheme passlib doesn't know.
        matches = False
    return matches
class TokenError(Exception):
    """Common ancestor for every JWT encoding or decoding failure in this module."""


class TokenDecodeError(TokenError):
    """The token could not be decoded, failed verification, or has a bad payload."""


class TokenExpiredError(TokenError):
    """The token's ``exp`` claim lies in the past."""


class TokenTypeMismatchError(TokenError):
    """The token's ``type`` claim is not the kind the caller asked for."""
# Allowed values of the ``type`` claim.
TokenKind = Literal["access", "refresh"]


class TokenPayload(BaseModel):
    """Claims shared by CalMiner access and refresh tokens.

    ``exp`` is kept as the raw integer claim; ``expires_at`` exposes it as a
    timezone-aware UTC datetime.
    """

    sub: str  # subject the token was issued for
    exp: int  # expiry as seconds since the Unix epoch
    type: TokenKind  # distinguishes access tokens from refresh tokens
    scopes: list[str] = Field(default_factory=list)  # empty when the claim is absent

    @property
    def expires_at(self) -> datetime:
        """Return ``exp`` converted to an aware UTC ``datetime``."""
        return datetime.fromtimestamp(self.exp, timezone.utc)
@dataclass(slots=True)
class JWTSettings:
"""Runtime configuration for JWT encoding and validation."""
secret_key: str
algorithm: str = "HS256"
access_token_ttl: timedelta = field(
default_factory=lambda: timedelta(minutes=15))
refresh_token_ttl: timedelta = field(
default_factory=lambda: timedelta(days=7))
def create_access_token(
    subject: str,
    settings: JWTSettings,
    *,
    scopes: Iterable[str] | None = None,
    expires_delta: timedelta | None = None,
    extra_claims: Dict[str, Any] | None = None,
) -> str:
    """Issue a signed access token for the provided subject.

    Args:
        subject: Value stored in the ``sub`` claim.
        settings: Signing key, algorithm and default lifetimes.
        scopes: Optional scopes written to the ``scopes`` claim.
        expires_delta: Lifetime override. ``None`` uses
            ``settings.access_token_ttl``; a zero or negative delta is used
            as given (e.g. to mint an already-expired token).
        extra_claims: Additional claims merged into the payload.

    Returns:
        The encoded JWT string.
    """
    # Compare with None explicitly: ``timedelta(0)`` is falsy, so an ``or``
    # fallback would silently swap an explicit zero lifetime for the default.
    lifetime = settings.access_token_ttl if expires_delta is None else expires_delta
    return _create_token(
        subject=subject,
        token_type="access",
        settings=settings,
        lifetime=lifetime,
        scopes=scopes,
        extra_claims=extra_claims,
    )
def create_refresh_token(
    subject: str,
    settings: JWTSettings,
    *,
    scopes: Iterable[str] | None = None,
    expires_delta: timedelta | None = None,
    extra_claims: Dict[str, Any] | None = None,
) -> str:
    """Issue a signed refresh token for the provided subject.

    Args:
        subject: Value stored in the ``sub`` claim.
        settings: Signing key, algorithm and default lifetimes.
        scopes: Optional scopes written to the ``scopes`` claim.
        expires_delta: Lifetime override. ``None`` uses
            ``settings.refresh_token_ttl``; a zero or negative delta is used
            as given (e.g. to mint an already-expired token).
        extra_claims: Additional claims merged into the payload.

    Returns:
        The encoded JWT string.
    """
    # Compare with None explicitly: ``timedelta(0)`` is falsy, so an ``or``
    # fallback would silently swap an explicit zero lifetime for the default.
    lifetime = settings.refresh_token_ttl if expires_delta is None else expires_delta
    return _create_token(
        subject=subject,
        token_type="refresh",
        settings=settings,
        lifetime=lifetime,
        scopes=scopes,
        extra_claims=extra_claims,
    )
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
    """Verify *token* as an access token and return its validated claims.

    A token whose ``type`` claim is not ``"access"`` raises
    ``TokenTypeMismatchError``.
    """
    return _decode_token(token=token, settings=settings, expected_type="access")
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
    """Verify *token* as a refresh token and return its validated claims.

    A token whose ``type`` claim is not ``"refresh"`` raises
    ``TokenTypeMismatchError``.
    """
    return _decode_token(token=token, settings=settings, expected_type="refresh")
def _create_token(
    *,
    subject: str,
    token_type: TokenKind,
    settings: JWTSettings,
    lifetime: timedelta,
    scopes: Iterable[str] | None,
    extra_claims: Dict[str, Any] | None,
) -> str:
    """Assemble the claim set, sign it, and return the encoded JWT.

    The payload always carries ``sub``, ``type``, ``iat`` and ``exp``; the two
    timestamps are whole seconds since the Unix epoch (fractions truncated).
    ``scopes`` is included only when a truthy iterable is supplied, and
    ``extra_claims`` is merged in last.
    """
    issued_at = datetime.now(timezone.utc)
    claims: Dict[str, Any] = dict(
        sub=subject,
        type=token_type,
        iat=int(issued_at.timestamp()),
        exp=int((issued_at + lifetime).timestamp()),
    )
    if scopes:
        claims["scopes"] = list(scopes)
    # NOTE(review): because extra claims are applied last they can overwrite
    # sub/type/iat/exp (e.g. turn an access token into a refresh token);
    # callers should not pass those keys.
    claims.update(extra_claims or {})
    return jwt.encode(claims, settings.secret_key, algorithm=settings.algorithm)
def _decode_token(
    token: str,
    settings: JWTSettings,
    expected_type: TokenKind,
) -> TokenPayload:
    """Verify *token* and return its validated payload.

    Args:
        token: Compact-serialised JWT (``header.payload.signature``).
        settings: Key and the single algorithm accepted for verification.
        expected_type: Value the ``type`` claim must have.

    Returns:
        The validated ``TokenPayload``.

    Raises:
        TokenExpiredError: jose reports the token as expired.
        TokenDecodeError: The token fails jose verification, a segment is not
            in canonical base64url form, or the claims fail validation.
        TokenTypeMismatchError: The ``type`` claim differs from *expected_type*.
    """
    try:
        decoded = jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.algorithm],
            options={"verify_aud": False},
        )
    except ExpiredSignatureError as exc:  # pragma: no cover - jose marks this path
        raise TokenExpiredError("Token has expired") from exc
    except JWTError as exc:  # pragma: no cover - jose error bubble
        raise TokenDecodeError("Unable to decode token") from exc

    # Python's non-validating base64 decoding (which python-jose is assumed to
    # use — verify) ignores stray characters, padding, and unused trailing
    # bits. That lets a signed token be rewritten into a different string that
    # still verifies. Reject any segment that is not the canonical encoding of
    # its own bytes. Unlike re-signing the claims and comparing strings, this
    # works for every algorithm (including randomized ES*/PS* signatures) and
    # for legitimately signed tokens with extra header fields.
    if not all(_is_canonical_b64url(segment) for segment in token.split(".")):
        raise TokenDecodeError("Token contents have been altered.")

    try:
        payload = _model_validate(TokenPayload, decoded)
    except ValidationError as exc:
        raise TokenDecodeError("Token payload validation failed") from exc

    if payload.type != expected_type:
        raise TokenTypeMismatchError(
            f"Expected a {expected_type} token but received '{payload.type}'."
        )
    return payload


def _is_canonical_b64url(segment: str) -> bool:
    """Return True when *segment* is the unpadded base64url encoding of its bytes."""
    padded = segment + "=" * (-len(segment) % 4)
    try:
        raw = base64.urlsafe_b64decode(padded)
    except ValueError:
        # binascii.Error subclasses ValueError; non-ASCII input also lands here.
        return False
    canonical = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
    return compare_digest(canonical, segment)
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
    """Validate *data* into *model* under either pydantic major version.

    Uses ``model_validate`` (pydantic v2) when the model provides it and
    falls back to ``parse_obj`` (pydantic v1) otherwise.
    """
    validate_v2 = getattr(model, "model_validate", None)
    if validate_v2 is None:
        return model.parse_obj(data)  # type: ignore[attr-defined]
    return validate_v2(data)  # type: ignore[misc]
# Public API of this module: the names exported by ``from ... import *``.
# Private helpers (the underscore-prefixed functions) are intentionally excluded.
__all__ = [
    "JWTSettings",
    "TokenDecodeError",
    "TokenError",
    "TokenExpiredError",
    "TokenKind",
    "TokenPayload",
    "TokenTypeMismatchError",
    "create_access_token",
    "create_refresh_token",
    "decode_access_token",
    "decode_refresh_token",
    "hash_password",
    "password_context",
    "verify_password",
]