"""Integration tests for ``AuthSessionMiddleware``.

Each test drives a minimal FastAPI app exposing ``GET /me`` (guarded by
``require_current_user``) through ``TestClient`` and checks how the middleware
reacts to the access/refresh token cookies sent with the request.
"""

from __future__ import annotations

from datetime import timedelta
from typing import Iterator

import pytest
from fastapi import Depends, FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker

from config.settings import get_settings
from dependencies import get_unit_of_work, require_current_user
from middleware.auth_session import AuthSessionMiddleware
from models import User
from services.security import (
    create_access_token,
    create_refresh_token,
    hash_password,
)
from services.unit_of_work import UnitOfWork

# Cookie names the tests send and inspect. Kept in one place so a typo in a
# single assertion cannot make a test silently check the wrong cookie.
_ACCESS_COOKIE = "calminer_access_token"
_REFRESH_COOKIE = "calminer_refresh_token"


@pytest.fixture()
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for a minimal app wrapped in ``AuthSessionMiddleware``.

    The app exposes ``GET /me``, which depends on ``require_current_user`` and
    echoes the resolved user's ``id`` and ``username``. Both the request-level
    ``get_unit_of_work`` dependency (via ``dependency_overrides``) and the
    middleware's ``unit_of_work_factory`` build ``UnitOfWork`` objects from the
    same ``session_factory``. ``session_factory`` is a fixture defined
    elsewhere (presumably ``conftest.py``). The client is closed on teardown.
    """
    app = FastAPI()

    def _override_uow() -> Iterator[UnitOfWork]:
        # Open one unit of work per request and hand it to the route.
        with UnitOfWork(session_factory=session_factory) as uow:
            yield uow

    app.dependency_overrides[get_unit_of_work] = _override_uow

    @app.get("/me")
    def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
        return JSONResponse({"id": user.id, "username": user.username})

    app.add_middleware(
        AuthSessionMiddleware,
        unit_of_work_factory=lambda: UnitOfWork(session_factory=session_factory),
    )

    client = TestClient(app)
    try:
        yield client
    finally:
        client.close()


def _create_user(session_factory: sessionmaker) -> User:
    """Create the fixed active, non-superuser test user ``jane`` and return it.

    The user is created through ``uow.users.create`` inside a ``UnitOfWork``
    context; presumably the unit of work commits on exit — confirm in
    ``UnitOfWork``.

    NOTE(review): the instance is returned after the unit of work has exited,
    and callers read ``user.id`` / ``user.username`` afterwards. That assumes
    the session does not expire attributes on commit (or that ``create``
    loads them) — confirm against the ``session_factory`` configuration.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        # Narrows the Optional repository type for type checkers.
        assert uow.users is not None
        user = User(
            email="jane@example.com",
            username="jane",
            password_hash=hash_password("secret"),
            is_active=True,
            is_superuser=False,
        )
        uow.users.create(user)
    return user


def _issue_tokens(
    user: User, *, access_expires_delta: timedelta | None = None
) -> tuple[str, str]:
    """Return an ``(access, refresh)`` token pair for *user* with scope ``auth``.

    Args:
        user: The user whose ``id`` (as a string) becomes the token subject.
        access_expires_delta: When given, it is forwarded as ``expires_delta``
            to ``create_access_token``. A negative value yields an
            already-expired access token. When omitted, the token service's
            own default lifetime applies.

    Returns:
        The encoded access token and refresh token, in that order.
    """
    settings = get_settings().jwt_settings()
    # Only pass ``expires_delta`` when requested, so the default call matches
    # create_access_token's own default rather than an explicit ``None``.
    access_overrides: dict[str, timedelta] = {}
    if access_expires_delta is not None:
        access_overrides["expires_delta"] = access_expires_delta
    access = create_access_token(
        str(user.id), settings, scopes=["auth"], **access_overrides
    )
    refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
    return access, refresh


def test_middleware_populates_current_user(
    auth_app: TestClient, session_factory: sessionmaker
) -> None:
    """A valid access/refresh cookie pair lets ``/me`` resolve the stored user."""
    user = _create_user(session_factory)
    access, refresh = _issue_tokens(user)
    auth_app.cookies.set(_ACCESS_COOKIE, access)
    auth_app.cookies.set(_REFRESH_COOKIE, refresh)

    response = auth_app.get("/me")

    assert response.status_code == 200, response.text
    payload = response.json()
    assert payload["id"] == user.id
    assert payload["username"] == user.username


def test_middleware_refreshes_expired_access_token(
    auth_app: TestClient, session_factory: sessionmaker
) -> None:
    """An expired access token plus a valid refresh token still authenticates.

    The response must also set a new access cookie that differs from the
    expired one, and set a refresh cookie.
    """
    user = _create_user(session_factory)
    expired, refresh = _issue_tokens(
        user, access_expires_delta=timedelta(seconds=-1)
    )
    auth_app.cookies.set(_ACCESS_COOKIE, expired)
    auth_app.cookies.set(_REFRESH_COOKIE, refresh)

    response = auth_app.get("/me")

    assert response.status_code == 200, response.text
    new_access = response.cookies.get(_ACCESS_COOKIE)
    new_refresh = response.cookies.get(_REFRESH_COOKIE)
    # Separate asserts so a failure shows which condition broke.
    assert new_access is not None
    assert new_access != expired
    assert new_refresh is not None


def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
    """Malformed tokens are rejected with 401.

    The response must carry a ``Set-Cookie`` header for the access cookie
    (presumably clearing it — the test only checks that the header is
    present).
    """
    auth_app.cookies.set(_ACCESS_COOKIE, "invalid-token")
    auth_app.cookies.set(_REFRESH_COOKIE, "invalid-token")

    response = auth_app.get("/me")

    assert response.status_code == 401, response.text
    set_cookies = response.headers.get_list("set-cookie")
    assert any(
        f"{_ACCESS_COOKIE}=" in value for value in set_cookies
    ), set_cookies