"""Password hashing and JWT issuing/validation helpers."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from hmac import compare_digest
from typing import Any, Dict, Iterable, Literal, Type

from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext

try:  # pragma: no cover - compatibility shim for passlib/argon2 warning
    import importlib.metadata as importlib_metadata
    import argon2  # type: ignore

    # Expose the installed distribution version as ``argon2.__version__``.
    setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
except Exception:  # pragma: no cover - executed only when metadata lookup fails
    # Deliberately best-effort: the shim is optional and must never block import.
    pass

from pydantic import BaseModel, Field, ValidationError

password_context = CryptContext(schemes=["argon2"], deprecated="auto")


def hash_password(password: str) -> str:
    """Derive a secure hash for a plain-text password."""
    return password_context.hash(password)


def verify_password(candidate: str, hashed: str) -> bool:
    """Verify that a candidate password matches a stored hash.

    Returns ``False`` (instead of raising) when the stored hash cannot be
    processed and ``password_context.verify`` raises ``ValueError``.
    """
    try:
        return password_context.verify(candidate, hashed)
    except ValueError:
        # Raised when the stored hash is malformed or uses an unknown scheme.
        return False


class TokenError(Exception):
    """Base class for token encoding/decoding issues."""


class TokenDecodeError(TokenError):
    """Raised when a token cannot be decoded or validated."""


class TokenExpiredError(TokenError):
    """Raised when a token has expired."""


class TokenTypeMismatchError(TokenError):
    """Raised when a token type does not match the expected flavour."""


TokenKind = Literal["access", "refresh"]


class TokenPayload(BaseModel):
    """Shared fields for CalMiner JWT payloads."""

    sub: str
    exp: int
    type: TokenKind
    scopes: list[str] = Field(default_factory=list)

    @property
    def expires_at(self) -> datetime:
        """Return ``exp`` as a timezone-aware UTC ``datetime``."""
        return datetime.fromtimestamp(self.exp, tz=timezone.utc)


@dataclass(slots=True)
class JWTSettings:
    """Runtime configuration for JWT encoding and validation."""

    secret_key: str
    algorithm: str = "HS256"
    # Default lifetimes used when a caller does not pass ``expires_delta``.
    access_token_ttl: timedelta = field(
        default_factory=lambda: timedelta(minutes=15))
    refresh_token_ttl: timedelta = field(
        default_factory=lambda: timedelta(days=7))


def create_access_token(
    subject: str,
    settings: JWTSettings,
    *,
    scopes: Iterable[str] | None = None,
    expires_delta: timedelta | None = None,
    extra_claims: Dict[str, Any] | None = None,
) -> str:
    """Issue a signed access token for the provided subject.

    Args:
        subject: Value stored in the ``sub`` claim.
        settings: Signing key, algorithm and default lifetimes.
        scopes: Optional scopes; omitted from the payload when empty.
        expires_delta: Lifetime override. ``None`` selects
            ``settings.access_token_ttl``; any other value, including a
            zero ``timedelta``, is used as given.
        extra_claims: Additional claims merged into the payload last, so
            they can override the standard ones.

    Returns:
        The encoded JWT string.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so an
    # ``or`` fallback would silently swap a zero lifetime for the default.
    lifetime = (
        settings.access_token_ttl if expires_delta is None else expires_delta
    )
    return _create_token(
        subject=subject,
        token_type="access",
        settings=settings,
        lifetime=lifetime,
        scopes=scopes,
        extra_claims=extra_claims,
    )


def create_refresh_token(
    subject: str,
    settings: JWTSettings,
    *,
    scopes: Iterable[str] | None = None,
    expires_delta: timedelta | None = None,
    extra_claims: Dict[str, Any] | None = None,
) -> str:
    """Issue a signed refresh token for the provided subject.

    Args:
        subject: Value stored in the ``sub`` claim.
        settings: Signing key, algorithm and default lifetimes.
        scopes: Optional scopes; omitted from the payload when empty.
        expires_delta: Lifetime override. ``None`` selects
            ``settings.refresh_token_ttl``; any other value, including a
            zero ``timedelta``, is used as given.
        extra_claims: Additional claims merged into the payload last, so
            they can override the standard ones.

    Returns:
        The encoded JWT string.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so an
    # ``or`` fallback would silently swap a zero lifetime for the default.
    lifetime = (
        settings.refresh_token_ttl if expires_delta is None else expires_delta
    )
    return _create_token(
        subject=subject,
        token_type="refresh",
        settings=settings,
        lifetime=lifetime,
        scopes=scopes,
        extra_claims=extra_claims,
    )


def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
    """Validate and decode an access token.

    Raises:
        TokenExpiredError: The token has expired.
        TokenDecodeError: The token cannot be decoded or validated.
        TokenTypeMismatchError: The token is not an access token.
    """
    return _decode_token(token, settings, expected_type="access")


def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
    """Validate and decode a refresh token.

    Raises:
        TokenExpiredError: The token has expired.
        TokenDecodeError: The token cannot be decoded or validated.
        TokenTypeMismatchError: The token is not a refresh token.
    """
    return _decode_token(token, settings, expected_type="refresh")


def _create_token(
    *,
    subject: str,
    token_type: TokenKind,
    settings: JWTSettings,
    lifetime: timedelta,
    scopes: Iterable[str] | None,
    extra_claims: Dict[str, Any] | None,
) -> str:
    """Build the claim set for a token and sign it.

    ``iat`` and ``exp`` are whole-second UTC timestamps; ``exp`` is
    ``iat`` plus ``lifetime``. ``scopes`` is only written when truthy and
    ``extra_claims`` is applied last, overriding any earlier key.
    """
    now = datetime.now(timezone.utc)
    expire = now + lifetime
    payload: Dict[str, Any] = {
        "sub": subject,
        "type": token_type,
        "iat": int(now.timestamp()),
        "exp": int(expire.timestamp()),
    }
    if scopes:
        payload["scopes"] = list(scopes)
    if extra_claims:
        payload.update(extra_claims)
    return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)


def _decode_token(
    token: str,
    settings: JWTSettings,
    expected_type: TokenKind,
) -> TokenPayload:
    """Decode ``token``, validate its payload and check its type.

    Args:
        token: Encoded JWT string to validate.
        settings: Key and algorithm used for verification.
        expected_type: Token flavour the caller requires.

    Returns:
        The validated ``TokenPayload``.

    Raises:
        TokenExpiredError: jose reported an expired signature.
        TokenDecodeError: jose could not decode the token, the token does
            not match its re-encoded form, or the payload fails validation.
        TokenTypeMismatchError: ``type`` differs from ``expected_type``.
    """
    try:
        decoded = jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.algorithm],
            options={"verify_aud": False},
        )
    except ExpiredSignatureError as exc:  # pragma: no cover - jose marks this path
        raise TokenExpiredError("Token has expired") from exc
    except JWTError as exc:  # pragma: no cover - jose error bubble
        raise TokenDecodeError("Unable to decode token") from exc

    # Re-sign the decoded claims and require an exact match with the
    # presented token string.
    # NOTE(review): presumably this rejects non-canonical encodings of an
    # otherwise valid token; it also assumes tokens were produced by
    # ``jwt.encode`` with a deterministic algorithm - confirm before
    # switching away from HMAC.
    expected_token = jwt.encode(
        decoded,
        settings.secret_key,
        algorithm=settings.algorithm,
    )
    # Compare as bytes: ``compare_digest`` only accepts ASCII ``str`` and
    # raises TypeError otherwise, which would escape as a non-TokenError.
    if not compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise TokenDecodeError("Token contents have been altered.")

    try:
        payload = _model_validate(TokenPayload, decoded)
    except ValidationError as exc:
        raise TokenDecodeError("Token payload validation failed") from exc

    if payload.type != expected_type:
        raise TokenTypeMismatchError(
            f"Expected a {expected_type} token but received '{payload.type}'."
        )
    return payload


def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
    """Validate ``data`` against ``model``.

    Uses ``model_validate`` when the model class provides it and falls
    back to ``parse_obj`` otherwise.
    """
    if hasattr(model, "model_validate"):
        return model.model_validate(data)  # type: ignore[attr-defined]
    return model.parse_obj(data)  # type: ignore[attr-defined]


__all__ = [
    "JWTSettings",
    "TokenDecodeError",
    "TokenError",
    "TokenExpiredError",
    "TokenKind",
    "TokenPayload",
    "TokenTypeMismatchError",
    "create_access_token",
    "create_refresh_token",
    "decode_access_token",
    "decode_refresh_token",
    "hash_password",
    "password_context",
    "verify_password",
]