"""Unit tests for the password-hashing and JWT helpers in ``services.security``."""

from __future__ import annotations

from datetime import timedelta

import pytest

from services.security import (
    JWTSettings,
    TokenDecodeError,
    TokenExpiredError,
    TokenTypeMismatchError,
    create_access_token,
    create_refresh_token,
    decode_access_token,
    decode_refresh_token,
    hash_password,
    verify_password,
)

# NOTE(review): decode_refresh_token is imported but no test below calls it;
# TODO: add refresh-token round-trip coverage once its return contract is confirmed.

# Signing secret shared by every JWT test in this module.
_TEST_SECRET = "unit-test-secret"


def _settings() -> JWTSettings:
    """Build the JWT settings used by the token tests."""
    return JWTSettings(secret_key=_TEST_SECRET)


def _tamper_signature(token: str) -> str:
    """Return *token* with the first character of its signature segment changed.

    The last base64url character of a signature may carry unused padding
    bits (e.g. 43 characters encode the 32 bytes of an HS256 signature,
    leaving 2 spare bits), and lenient base64 decoders ignore those bits.
    Changing only that character can therefore decode to the *same*
    signature bytes, making a tamper test pass or fail at random. The first
    character maps entirely onto bits of the first signature byte, so
    changing it always alters the decoded signature.
    """
    signing_input, _, signature = token.rpartition(".")
    replacement = "A" if signature[0] != "A" else "B"
    return f"{signing_input}.{replacement}{signature[1:]}"


def test_hash_password_round_trip() -> None:
    """The hash differs from the plaintext and verifies only the original password."""
    hashed = hash_password("super-secret")
    assert hashed != "super-secret"
    assert verify_password("super-secret", hashed)
    assert not verify_password("incorrect", hashed)


def test_verify_password_handles_malformed_hash() -> None:
    """A malformed stored hash yields a falsy result instead of raising."""
    assert not verify_password("secret", "not-a-valid-hash")


def test_access_token_roundtrip() -> None:
    """An access token decodes back to its subject, type and scopes."""
    settings = _settings()
    token = create_access_token(
        "user-id-123",
        settings,
        scopes=("read", "write"),
        extra_claims={"custom": "value"},
    )
    payload = decode_access_token(token, settings)
    assert payload.sub == "user-id-123"
    assert payload.type == "access"
    # Scopes are passed as a tuple but compared against a list here.
    assert payload.scopes == ["read", "write"]


def test_refresh_token_type_mismatch() -> None:
    """Decoding a refresh token as an access token raises TokenTypeMismatchError."""
    settings = _settings()
    token = create_refresh_token("user-id-456", settings)
    with pytest.raises(TokenTypeMismatchError):
        decode_access_token(token, settings)


def test_decode_expired_token() -> None:
    """A token whose expiry is already in the past raises TokenExpiredError."""
    settings = _settings()
    # A negative delta places the expiry 5 seconds before issuance.
    expired_token = create_access_token(
        "user-id-789",
        settings,
        expires_delta=timedelta(seconds=-5),
    )
    with pytest.raises(TokenExpiredError):
        decode_access_token(expired_token, settings)


def test_decode_tampered_token() -> None:
    """A token whose signature has been altered raises TokenDecodeError."""
    settings = _settings()
    token = create_access_token("user-id-321", settings)
    tampered = _tamper_signature(token)
    assert tampered != token
    with pytest.raises(TokenDecodeError):
        decode_access_token(tampered, settings)