Files
calminer/tests/test_auth_routes.py

287 lines
8.9 KiB
Python

from __future__ import annotations
from collections.abc import Iterator
from typing import cast
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from models import Role, User, UserRole
from dependencies import get_auth_session, require_current_user
from services.security import hash_password
from services.session import AuthSession, SessionTokens
@pytest.fixture()
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
    """Provide an ORM session for inspecting database state from a test.

    The session is built by the ``session_factory`` fixture and is closed
    during fixture teardown once the test body has finished.
    """
    with session_factory() as session:
        yield session
def _get_user(
    session: Session,
    *,
    email: str | None = None,
    username: str | None = None,
) -> User | None:
    """Look up a single user by email and/or username.

    When both filters are given they are combined (AND). The query runs with
    ``populate_existing`` so that a ``User`` already held in ``session``'s
    identity map is refreshed from the database. The application under test
    may have written the row through a different session, so cached
    attributes could otherwise be stale.

    Args:
        session: Session used to run the query.
        email: Exact email to match, if given.
        username: Exact username to match, if given.

    Returns:
        The matching user, or ``None`` when no row matches.

    Raises:
        ValueError: If neither ``email`` nor ``username`` is provided. An
            unfiltered query would match every user instead of one.
    """
    if email is None and username is None:
        raise ValueError("_get_user requires email and/or username")
    stmt = select(User)
    if email is not None:
        stmt = stmt.where(User.email == email)
    if username is not None:
        stmt = stmt.where(User.username == username)
    # Overwrite identity-map state with freshly loaded column values.
    stmt = stmt.execution_options(populate_existing=True)
    return session.execute(stmt).scalar_one_or_none()
class TestRegistrationFlow:
    """Tests for the ``/register`` form endpoint."""

    @staticmethod
    def _registration_form(username: str, email: str) -> dict[str, str]:
        """Build a valid registration form payload for ``username``/``email``."""
        return {
            "username": username,
            "email": email,
            "password": "ComplexP@ss1",
            "confirm_password": "ComplexP@ss1",
        }

    def test_register_creates_user_and_assigns_role(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A successful registration redirects to login and creates an active viewer."""
        response = client.post(
            "/register",
            data=self._registration_form("newuser", "newuser@example.com"),
            follow_redirects=False,
        )
        assert response.status_code == 303
        location = response.headers.get("location")
        assert location
        parsed = urlparse(location)
        assert parsed.path == "/login"
        assert parse_qs(parsed.query).get("registered") == ["1"]

        created = _get_user(db_session, email="newuser@example.com")
        assert created is not None
        assert created.is_active

        viewer_role = db_session.execute(
            select(Role).where(Role.name == "viewer")
        ).scalar_one_or_none()
        assert viewer_role is not None
        assignments = db_session.execute(
            select(UserRole).where(
                UserRole.user_id == created.id,
                UserRole.role_id == viewer_role.id,
            )
        ).scalars().all()
        assert len(assignments) == 1

    def test_register_duplicate_email_shows_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Re-using an email is rejected even when the username is new."""
        first = client.post(
            "/register",
            data=self._registration_form("existing", "existing@example.com"),
            follow_redirects=False,
        )
        assert first.status_code == 303

        # Use a different username so the only conflict is the email address;
        # otherwise a username-uniqueness check could mask the email check.
        second = client.post(
            "/register",
            data=self._registration_form("existing_other", "existing@example.com"),
            follow_redirects=False,
        )
        assert second.status_code == 400
        assert "Email is already registered" in second.text
        # The rejected attempt must not have created an account.
        assert _get_user(db_session, username="existing_other") is None
class TestLoginFlow:
    """Tests for the ``/login`` form endpoint."""

    def test_login_sets_tokens_and_updates_last_login(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Valid credentials redirect home, set both token cookies, and stamp last login."""
        password = "MySecur3Pass!"
        user = User(
            email="login@example.com",
            username="loginuser",
            password_hash=hash_password(password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        response = client.post(
            "/login",
            data={"username": "loginuser", "password": password},
            follow_redirects=False,
        )
        assert response.status_code == 303
        assert response.headers.get("location") == "http://testserver/"
        set_cookie_header = response.headers.get("set-cookie", "")
        assert "calminer_access_token=" in set_cookie_header
        assert "calminer_refresh_token=" in set_cookie_header

        # The login request updated the row outside this session; expire the
        # cached ``user`` so the re-read does not depend on the factory's
        # ``expire_on_commit`` setting.
        db_session.expire_all()
        updated = _get_user(db_session, username="loginuser")
        assert updated is not None
        assert updated.last_login_at is not None

    def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
        """Unknown credentials re-render the form with a 400 and an error message."""
        response = client.post(
            "/login",
            data={"username": "unknown", "password": "bad"},
            follow_redirects=False,
        )
        assert response.status_code == 400
        assert "Invalid username or password" in response.text
class TestPasswordResetFlow:
    """Tests for the ``/forgot-password`` and ``/reset-password`` endpoints."""

    def test_password_reset_round_trip(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Requesting a reset, opening the form, and submitting replaces the password."""
        user = User(
            email="reset@example.com",
            username="resetuser",
            password_hash=hash_password("OldP@ssword1"),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "reset@example.com"},
            follow_redirects=False,
        )
        assert request_response.status_code == 303
        reset_location = request_response.headers.get("location")
        assert reset_location is not None
        parsed = urlparse(reset_location)
        assert parsed.path == "/reset-password"
        token = parse_qs(parsed.query).get("token", [None])[0]
        assert token

        form_response = client.get(reset_location)
        assert form_response.status_code == 200

        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": "N3wP@ssword!",
                "confirm_password": "N3wP@ssword!",
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 303
        assert "reset=1" in (submit_response.headers.get("location") or "")

        db_session.refresh(user)
        assert user.verify_password("N3wP@ssword!")
        # The previous password must no longer be accepted.
        assert not user.verify_password("OldP@ssword1")

    def test_password_reset_with_unknown_email_shows_generic_message(
        self,
        client: TestClient,
    ) -> None:
        """An unknown email gets the generic confirmation page, not a redirect."""
        response = client.post(
            "/forgot-password",
            data={"email": "doesnotexist@example.com"},
            follow_redirects=False,
        )
        assert response.status_code == 200
        assert "If an account exists" in response.text

    def test_password_reset_mismatched_passwords_return_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Mismatched confirmation is rejected and leaves the password unchanged."""
        user = User(
            email="mismatch@example.com",
            username="mismatch",
            password_hash=hash_password("OldP@ssword1"),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "mismatch@example.com"},
            follow_redirects=False,
        )
        # Check the redirect explicitly so a failed request produces a clear
        # assertion error instead of a KeyError on the missing header.
        assert request_response.status_code == 303
        location = request_response.headers.get("location")
        assert location is not None
        token = parse_qs(urlparse(location).query).get("token", [None])[0]
        assert token

        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": "NewPass123!",
                "confirm_password": "Different123!",
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 400
        assert "Passwords do not match" in submit_response.text

        db_session.refresh(user)
        assert user.verify_password("OldP@ssword1")
class TestLogoutFlow:
    """Tests for the ``/logout`` endpoint."""

    def test_logout_clears_cookies_and_redirects(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Logout redirects to the login page and expires the access-token cookie."""
        user = User(
            email="logout@example.com",
            username="logoutuser",
            password_hash=hash_password("SecureP@ss1"),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()
        auth_session = AuthSession(
            tokens=SessionTokens(
                access_token="access-token",
                refresh_token="refresh-token",
                access_token_source="cookie",
            ),
            user=user,
        )

        app = cast(FastAPI, client.app)
        overrides = app.dependency_overrides
        patched = (require_current_user, get_auth_session)
        # Remember overrides that may already be installed (e.g. by fixtures)
        # so they are restored afterwards rather than discarded.
        previous = {dep: overrides[dep] for dep in patched if dep in overrides}
        overrides[require_current_user] = lambda: user
        overrides[get_auth_session] = lambda: auth_session
        try:
            response = client.get("/logout", follow_redirects=False)
        finally:
            for dep in patched:
                if dep in previous:
                    overrides[dep] = previous[dep]
                else:
                    overrides.pop(dep, None)

        assert response.status_code == 303
        location = response.headers.get("location")
        assert location and location.startswith("http://testserver/login")
        set_cookie_header = response.headers.get("set-cookie") or ""
        assert "calminer_access_token=" in set_cookie_header
        assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()