# calminer/tests/test_auth_session_middleware.py

from __future__ import annotations
from datetime import timedelta
from typing import Iterator
import pytest
from fastapi import Depends, FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker
from config.settings import get_settings
from dependencies import get_unit_of_work, require_current_user
from middleware.auth_session import AuthSessionMiddleware
from models import User
from services.security import create_access_token, create_refresh_token, hash_password
from services.unit_of_work import UnitOfWork
@pytest.fixture()
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a TestClient for a minimal app wrapped in ``AuthSessionMiddleware``.

    The app exposes one route, ``GET /me``, which depends on
    ``require_current_user`` and echoes the resolved user's id and username.
    Both the route's ``get_unit_of_work`` dependency (via an override) and the
    middleware build their ``UnitOfWork`` from the test ``session_factory``.
    The client is closed when the fixture is torn down.
    """

    def _uow_from_test_sessions() -> UnitOfWork:
        # Factory handed to the middleware; builds a fresh UnitOfWork per call.
        return UnitOfWork(session_factory=session_factory)

    def _override_uow() -> Iterator[UnitOfWork]:
        # Generator dependency: keeps the UnitOfWork context open while the
        # route handler runs, then exits it.
        with _uow_from_test_sessions() as uow:
            yield uow

    app = FastAPI()
    app.dependency_overrides[get_unit_of_work] = _override_uow

    @app.get("/me")
    def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
        return JSONResponse({"id": user.id, "username": user.username})

    app.add_middleware(
        AuthSessionMiddleware,
        unit_of_work_factory=_uow_from_test_sessions,
    )

    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()
def _create_user(session_factory: sessionmaker) -> User:
    """Create and return an active, non-superuser account named ``jane``.

    The password is ``"secret"``, stored via ``hash_password``. The user is
    passed to ``uow.users.create`` inside its own ``UnitOfWork``. The returned
    instance is used after that context has exited.
    NOTE(review): this relies on ``user.id`` still being readable once the
    UnitOfWork has closed — confirm the session does not expire it.
    """
    new_user = User(
        email="jane@example.com",
        username="jane",
        password_hash=hash_password("secret"),
        is_active=True,
        is_superuser=False,
    )
    with UnitOfWork(session_factory=session_factory) as uow:
        assert uow.users is not None
        uow.users.create(new_user)
    return new_user
def _issue_tokens(user: User) -> tuple[str, str]:
    """Return an ``(access, refresh)`` token pair for *user*, both scoped to ``auth``."""
    jwt_settings = get_settings().jwt_settings()
    subject = str(user.id)
    return (
        create_access_token(subject, jwt_settings, scopes=["auth"]),
        create_refresh_token(subject, jwt_settings, scopes=["auth"]),
    )
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
    """Valid access and refresh cookies let ``/me`` resolve the logged-in user."""
    user = _create_user(session_factory)
    access_token, refresh_token = _issue_tokens(user)
    for cookie_name, token in (
        ("calminer_access_token", access_token),
        ("calminer_refresh_token", refresh_token),
    ):
        auth_app.cookies.set(cookie_name, token)

    response = auth_app.get("/me")

    assert response.status_code == 200
    body = response.json()
    assert body["id"] == user.id
    assert body["username"] == user.username
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
    """An expired access token plus a valid refresh token still authenticates.

    The request must resolve to the same user who owns the tokens. The response
    must carry a new ``calminer_access_token`` cookie that differs from the
    expired one, and a ``calminer_refresh_token`` cookie.
    """
    user = _create_user(session_factory)
    settings = get_settings().jwt_settings()
    # A negative lifetime produces an access token that is already expired, so
    # the request can only succeed if the refresh token is honoured.
    expired = create_access_token(
        str(user.id),
        settings,
        scopes=["auth"],
        expires_delta=timedelta(seconds=-1),
    )
    refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
    auth_app.cookies.set("calminer_access_token", expired)
    auth_app.cookies.set("calminer_refresh_token", refresh)

    response = auth_app.get("/me")

    assert response.status_code == 200
    # A 200 alone does not prove the refresh path resolved the right principal.
    payload = response.json()
    assert payload["id"] == user.id
    assert payload["username"] == user.username

    new_access = response.cookies.get("calminer_access_token")
    new_refresh = response.cookies.get("calminer_refresh_token")
    assert new_access is not None and new_access != expired
    assert new_refresh is not None
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
    """Malformed tokens yield 401 and a ``Set-Cookie`` for the access-token cookie."""
    for cookie_name in ("calminer_access_token", "calminer_refresh_token"):
        auth_app.cookies.set(cookie_name, "invalid-token")

    response = auth_app.get("/me")

    assert response.status_code == 401
    # Only checks that some Set-Cookie header mentions the access cookie
    # (presumably clearing it); its value and attributes are not inspected.
    assert any(
        "calminer_access_token=" in header
        for header in response.headers.get_list("set-cookie")
    )