feat: Implement session management with middleware and update authentication flow
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models import Role, User, UserRole
|
||||
from dependencies import get_auth_session, require_current_user
|
||||
from services.security import hash_password
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -223,7 +227,8 @@ class TestPasswordResetFlow:
|
||||
data={"email": "mismatch@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)[
|
||||
"token"][0]
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
@@ -236,4 +241,46 @@ class TestPasswordResetFlow:
|
||||
)
|
||||
|
||||
assert submit_response.status_code == 400
|
||||
assert "Passwords do not match" in submit_response.text
|
||||
assert "Passwords do not match" in submit_response.text
|
||||
|
||||
|
||||
class TestLogoutFlow:
|
||||
def test_logout_clears_cookies_and_redirects(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="logout@example.com",
|
||||
username="logoutuser",
|
||||
password_hash=hash_password("SecureP@ss1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
session = AuthSession(
|
||||
tokens=SessionTokens(
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
access_token_source="cookie",
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
|
||||
app = cast(FastAPI, client.app)
|
||||
app.dependency_overrides[require_current_user] = lambda: user
|
||||
app.dependency_overrides[get_auth_session] = lambda: session
|
||||
|
||||
try:
|
||||
response = client.get("/logout", follow_redirects=False)
|
||||
finally:
|
||||
app.dependency_overrides.pop(require_current_user, None)
|
||||
app.dependency_overrides.pop(get_auth_session, None)
|
||||
|
||||
assert response.status_code == 303
|
||||
location = response.headers.get("location")
|
||||
assert location and location.startswith("http://testserver/login")
|
||||
set_cookie_header = response.headers.get("set-cookie") or ""
|
||||
assert "calminer_access_token=" in set_cookie_header
|
||||
assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()
|
||||
|
||||
111
tests/test_auth_session_middleware.py
Normal file
111
tests/test_auth_session_middleware.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config.settings import get_settings
|
||||
from dependencies import get_unit_of_work, require_current_user
|
||||
from middleware.auth_session import AuthSessionMiddleware
|
||||
from models import User
|
||||
from services.security import create_access_token, create_refresh_token, hash_password
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
|
||||
app = FastAPI()
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
@app.get("/me")
|
||||
def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
|
||||
return JSONResponse({"id": user.id, "username": user.username})
|
||||
|
||||
app.add_middleware(
|
||||
AuthSessionMiddleware,
|
||||
unit_of_work_factory=lambda: UnitOfWork(
|
||||
session_factory=session_factory),
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
def _create_user(session_factory: sessionmaker) -> User:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
user = User(
|
||||
email="jane@example.com",
|
||||
username="jane",
|
||||
password_hash=hash_password("secret"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
uow.users.create(user)
|
||||
return user
|
||||
|
||||
|
||||
def _issue_tokens(user: User) -> tuple[str, str]:
|
||||
settings = get_settings().jwt_settings()
|
||||
access = create_access_token(str(user.id), settings, scopes=["auth"])
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
return access, refresh
|
||||
|
||||
|
||||
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
access, refresh = _issue_tokens(user)
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", access)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["id"] == user.id
|
||||
assert payload["username"] == user.username
|
||||
|
||||
|
||||
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
settings = get_settings().jwt_settings()
|
||||
expired = create_access_token(
|
||||
str(user.id),
|
||||
settings,
|
||||
scopes=["auth"],
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", expired)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
new_access = response.cookies.get("calminer_access_token")
|
||||
new_refresh = response.cookies.get("calminer_refresh_token")
|
||||
assert new_access is not None and new_access != expired
|
||||
assert new_refresh is not None
|
||||
|
||||
|
||||
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
|
||||
auth_app.cookies.set("calminer_access_token", "invalid-token")
|
||||
auth_app.cookies.set("calminer_refresh_token", "invalid-token")
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 401
|
||||
set_cookies = response.headers.get_list("set-cookie")
|
||||
assert any("calminer_access_token=" in value for value in set_cookies)
|
||||
Reference in New Issue
Block a user