112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import timedelta
|
|
from typing import Iterator
|
|
|
|
import pytest
|
|
from fastapi import Depends, FastAPI
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from config.settings import get_settings
|
|
from dependencies import get_unit_of_work, require_current_user
|
|
from middleware.auth_session import AuthSessionMiddleware
|
|
from models import User
|
|
from services.security import create_access_token, create_refresh_token, hash_password
|
|
from services.unit_of_work import UnitOfWork
|
|
|
|
|
|
@pytest.fixture()
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
    """Yield a TestClient for a minimal app wired with ``AuthSessionMiddleware``.

    The app exposes a single ``GET /me`` route guarded by
    ``require_current_user``. Both the ``get_unit_of_work`` override and the
    middleware build their ``UnitOfWork`` from the test ``session_factory``.
    The client is closed when the fixture is torn down.
    """

    def make_uow() -> UnitOfWork:
        # Single construction point shared by the dependency override and the
        # middleware, so both talk to the same test database.
        return UnitOfWork(session_factory=session_factory)

    def override_unit_of_work() -> Iterator[UnitOfWork]:
        # Generator dependency so the UnitOfWork context manager is exited
        # after the dependency has been used.
        with make_uow() as uow:
            yield uow

    application = FastAPI()
    application.dependency_overrides[get_unit_of_work] = override_unit_of_work

    @application.get("/me")
    def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
        # Echo back the identity resolved by the auth dependency.
        return JSONResponse({"id": user.id, "username": user.username})

    application.add_middleware(
        AuthSessionMiddleware,
        unit_of_work_factory=make_uow,
    )

    test_client = TestClient(application)
    try:
        yield test_client
    finally:
        test_client.close()
|
|
|
|
|
|
def _create_user(session_factory: sessionmaker) -> User:
    """Persist and return an active, non-superuser account named ``jane``.

    The account uses email ``jane@example.com`` and a password hashed from
    ``"secret"``. It is saved through ``uow.users.create`` inside a fresh
    ``UnitOfWork`` built from *session_factory*.
    """
    with UnitOfWork(session_factory=session_factory) as uow:
        # Narrow the repository before use; the tests cannot proceed without it.
        assert uow.users is not None
        account = User(
            email="jane@example.com",
            username="jane",
            password_hash=hash_password("secret"),
            is_active=True,
            is_superuser=False,
        )
        uow.users.create(account)
    # NOTE(review): callers read ``id``/``username`` from this instance after
    # the UnitOfWork has exited; this assumes those attributes stay loaded
    # (e.g. the session does not expire them on commit) -- confirm.
    return account
|
|
|
|
|
|
def _issue_tokens(user: User) -> tuple[str, str]:
    """Mint an ``(access, refresh)`` token pair for *user* with the ``auth`` scope.

    Both tokens use the JWT settings from ``get_settings().jwt_settings()``,
    with ``str(user.id)`` as the subject.
    """
    jwt_settings = get_settings().jwt_settings()
    subject = str(user.id)
    return (
        create_access_token(subject, jwt_settings, scopes=["auth"]),
        create_refresh_token(subject, jwt_settings, scopes=["auth"]),
    )
|
|
|
|
|
|
def test_middleware_populates_current_user(
    auth_app: TestClient, session_factory: sessionmaker
) -> None:
    """Valid access and refresh cookies authenticate ``GET /me`` as the token's user."""
    account = _create_user(session_factory)
    access_token, refresh_token = _issue_tokens(account)

    for cookie_name, token in (
        ("calminer_access_token", access_token),
        ("calminer_refresh_token", refresh_token),
    ):
        auth_app.cookies.set(cookie_name, token)

    resp = auth_app.get("/me")

    assert resp.status_code == 200
    body = resp.json()
    assert body["id"] == account.id
    assert body["username"] == account.username
|
|
|
|
|
|
def test_middleware_refreshes_expired_access_token(
    auth_app: TestClient, session_factory: sessionmaker
) -> None:
    """An expired access token plus a valid refresh token still authenticates.

    The request must succeed as the token's subject. The response must also
    carry a new ``calminer_access_token`` cookie that differs from the expired
    one, plus a ``calminer_refresh_token`` cookie.
    """
    user = _create_user(session_factory)
    settings = get_settings().jwt_settings()
    # Expire the token well in the past rather than by a single second, so the
    # test does not depend on the decoder applying zero clock-skew leeway.
    expired = create_access_token(
        str(user.id),
        settings,
        scopes=["auth"],
        expires_delta=timedelta(minutes=-5),
    )
    refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])

    auth_app.cookies.set("calminer_access_token", expired)
    auth_app.cookies.set("calminer_refresh_token", refresh)

    response = auth_app.get("/me")

    assert response.status_code == 200
    # A 200 alone does not prove the refresh path resolved the right account;
    # check the identity echoed by /me as well.
    payload = response.json()
    assert payload["id"] == user.id
    assert payload["username"] == user.username

    new_access = response.cookies.get("calminer_access_token")
    new_refresh = response.cookies.get("calminer_refresh_token")
    assert new_access is not None and new_access != expired
    assert new_refresh is not None
|
|
|
|
|
|
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
    """Garbage tokens yield 401 plus a Set-Cookie header for the access-token cookie."""
    for cookie_name in ("calminer_access_token", "calminer_refresh_token"):
        auth_app.cookies.set(cookie_name, "invalid-token")

    resp = auth_app.get("/me")

    assert resp.status_code == 401
    access_cookie_headers = [
        header
        for header in resp.headers.get_list("set-cookie")
        if "calminer_access_token=" in header
    ]
    assert access_cookie_headers
|