405 lines
14 KiB
Python
405 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterator
|
|
from typing import cast
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from models import Role, User, UserRole
|
|
from dependencies import get_auth_session, require_current_user
|
|
from services.security import hash_password
|
|
from services.session import AuthSession, SessionTokens
|
|
from tests.conftest import app
|
|
from tests.utils.security import random_password, random_token
|
|
|
|
# Token-source tag used when building a fake cookie-authenticated AuthSession
# (passed as ``SessionTokens.access_token_source`` in the logout test).
COOKIE_SOURCE: str = "cookie"
|
|
|
|
|
|
@pytest.fixture()
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
    """Provide a dedicated ORM session for reading and seeding test data.

    The session is built from the ``session_factory`` fixture and is closed
    when the requesting test finishes, regardless of its outcome.
    """
    # Session's context manager calls close() on exit, mirroring try/finally.
    with session_factory() as session:
        yield session
|
|
|
|
|
|
def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None:
    """Look up a single user by email and/or username, reading current DB state.

    Args:
        session: ORM session used to run the query.
        email: When given, only a user with exactly this email matches.
        username: When given, only a user with exactly this username matches.

    Returns:
        The matching ``User``, or ``None`` when no row matches.

    Raises:
        ValueError: If neither ``email`` nor ``username`` is supplied. An
            unfiltered query would select every user instead of looking one up.
    """
    if email is None and username is None:
        raise ValueError("_get_user requires an email and/or a username filter")

    # populate_existing overwrites any User already held in this session's
    # identity map with the row just read, so assertions observe values that
    # were written elsewhere (e.g. by the app while handling a request).
    stmt = select(User).execution_options(populate_existing=True)
    if email is not None:
        stmt = stmt.where(User.email == email)
    if username is not None:
        stmt = stmt.where(User.username == username)
    return session.execute(stmt).scalar_one_or_none()
|
|
|
|
|
|
class TestRegistrationFlow:
    """Cover the ``/register`` form: successful sign-up and duplicate-email rejection."""

    def test_register_creates_user_and_assigns_role(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A valid registration redirects to login and creates an active viewer."""
        password = random_password()
        response = client.post(
            "/register",
            data={
                "username": "newuser",
                "email": "newuser@example.com",
                "password": password,
                "confirm_password": password,
            },
            follow_redirects=False,
        )

        assert response.status_code == 303
        location = response.headers.get("location")
        assert location
        parsed = urlparse(location)
        assert parsed.path == "/login"
        assert parse_qs(parsed.query).get("registered") == ["1"]

        created = _get_user(db_session, email="newuser@example.com")
        assert created is not None
        assert created.is_active

        viewer_role = db_session.execute(
            select(Role).where(Role.name == "viewer")
        ).scalar_one_or_none()
        assert viewer_role is not None

        # Exactly one link row: the viewer role must be assigned once, not twice.
        assignments = db_session.execute(
            select(UserRole).where(
                UserRole.user_id == created.id,
                UserRole.role_id == viewer_role.id,
            )
        ).scalars().all()
        assert len(assignments) == 1

    def test_register_duplicate_email_shows_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Re-using a registered email is rejected even when the username is new."""
        password = random_password()
        first = client.post(
            "/register",
            data={
                "username": "existing",
                "email": "existing@example.com",
                "password": password,
                "confirm_password": password,
            },
            follow_redirects=False,
        )
        assert first.status_code == 303

        # A different username ensures the rejection can only come from the
        # duplicate email, not from a duplicate-username check.
        second = client.post(
            "/register",
            data={
                "username": "existing2",
                "email": "existing@example.com",
                "password": password,
                "confirm_password": password,
            },
            follow_redirects=False,
        )

        assert second.status_code == 400
        assert "Email is already registered" in second.text

        # The rejected attempt must not have persisted another account.
        assert _get_user(db_session, username="existing2") is None
        matches = db_session.execute(
            select(User).where(User.email == "existing@example.com")
        ).scalars().all()
        assert len(matches) == 1
|
|
|
|
|
|
class TestLoginFlow:
    """Cover ``POST /login`` with valid and with unknown credentials."""

    def test_login_sets_tokens_and_updates_last_login(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Good credentials redirect home, set both auth cookies, and stamp last login."""
        secret = random_password()
        account = User(
            email="login@example.com",
            username="loginuser",
            password_hash=hash_password(secret),
            is_active=True,
        )
        db_session.add(account)
        db_session.commit()

        form = {"username": "loginuser", "password": secret}
        reply = client.post("/login", data=form, follow_redirects=False)

        assert reply.status_code == 303
        assert reply.headers.get("location") == "http://testserver/"
        cookies_header = reply.headers.get("set-cookie", "")
        for cookie_name in ("calminer_access_token", "calminer_refresh_token"):
            assert f"{cookie_name}=" in cookies_header

        reloaded = _get_user(db_session, username="loginuser")
        assert reloaded is not None
        assert reloaded.last_login_at is not None

    def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
        """An unknown user gets a 400 response carrying a generic credentials error."""
        reply = client.post(
            "/login",
            data={"username": "unknown", "password": "bad"},
            follow_redirects=False,
        )

        assert reply.status_code == 400
        assert "Invalid username or password" in reply.text
|
|
|
|
|
|
class TestPasswordResetFlow:
    """Cover the forgot-password / reset-password round trip and its error paths."""

    def test_password_reset_round_trip(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A token from ``/forgot-password`` lets the user replace their password."""
        old_password = random_password()
        user = User(
            email="reset@example.com",
            username="resetuser",
            password_hash=hash_password(old_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "reset@example.com"},
            follow_redirects=False,
        )

        assert request_response.status_code == 303
        reset_location = request_response.headers.get("location")
        assert reset_location is not None
        parsed = urlparse(reset_location)
        assert parsed.path == "/reset-password"
        token = parse_qs(parsed.query).get("token", [None])[0]
        assert token

        form_response = client.get(reset_location)
        assert form_response.status_code == 200

        new_password = random_password()
        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": new_password,
                "confirm_password": new_password,
            },
            follow_redirects=False,
        )

        assert submit_response.status_code == 303
        done_location = submit_response.headers.get("location") or ""
        # Parse the query rather than substring-match, so e.g. "preset=1" can't pass.
        assert parse_qs(urlparse(done_location).query).get("reset") == ["1"]

        db_session.refresh(user)
        assert user.verify_password(new_password)
        # The old password must stop working once the reset succeeds.
        assert not user.verify_password(old_password)

    def test_password_reset_with_unknown_email_shows_generic_message(
        self,
        client: TestClient,
    ) -> None:
        """Unknown emails get the same generic message, without a redirect."""
        response = client.post(
            "/forgot-password",
            data={"email": "doesnotexist@example.com"},
            follow_redirects=False,
        )

        assert response.status_code == 200
        assert "If an account exists" in response.text

    def test_password_reset_mismatched_passwords_return_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A mismatched confirmation is rejected and leaves the stored password intact."""
        original_password = random_password()
        user = User(
            email="mismatch@example.com",
            username="mismatch",
            password_hash=hash_password(original_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "mismatch@example.com"},
            follow_redirects=False,
        )
        # Check the redirect first so a failure reports the status, not a KeyError.
        assert request_response.status_code == 303
        token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]

        new_password = random_password()
        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": new_password,
                # Derived from new_password so the two values can never coincide.
                "confirm_password": new_password + "x",
            },
            follow_redirects=False,
        )

        assert submit_response.status_code == 400
        assert "Passwords do not match" in submit_response.text

        db_session.refresh(user)
        assert user.verify_password(original_password)
|
|
|
|
|
|
class TestLogoutFlow:
    """Cover ``GET /logout`` with an injected cookie-based session."""

    def test_logout_clears_cookies_and_redirects(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Logout redirects to the login page and expires the access-token cookie."""
        logout_password = random_password()
        user = User(
            email="logout@example.com",
            username="logoutuser",
            password_hash=hash_password(logout_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        auth_session = AuthSession(
            tokens=SessionTokens(
                access_token=random_token(),
                refresh_token=random_token(),
                access_token_source=COOKIE_SOURCE,
            ),
            user=user,
        )

        app = cast(FastAPI, client.app)
        # Snapshot every override (including any installed by shared fixtures)
        # so they are reinstated afterwards instead of being dropped.
        saved_overrides = dict(app.dependency_overrides)
        app.dependency_overrides[require_current_user] = lambda: user
        app.dependency_overrides[get_auth_session] = lambda: auth_session
        try:
            response = client.get("/logout", follow_redirects=False)
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)

        assert response.status_code == 303
        location = response.headers.get("location")
        assert location and location.startswith("http://testserver/login")
        set_cookie_header = response.headers.get("set-cookie") or ""
        assert "calminer_access_token=" in set_cookie_header
        # Accept either way of expiring a cookie: Max-Age=0 or an expires attribute.
        assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()
|
|
|
|
|
|
class TestLoginFlowEndToEnd:
    """Drive the login/logout pages through real redirects and cookies.

    Tests that swap in an anonymous ``get_auth_session`` snapshot
    ``app.dependency_overrides`` first and restore it afterwards. This keeps
    any overrides installed by shared fixtures in place for later requests.
    """

    def test_get_login_form_renders(self, client: TestClient) -> None:
        """The login page renders the login form with a username field."""
        response = client.get("/login")
        assert response.status_code == 200
        assert "login-form" in response.text
        assert "username" in response.text

    def test_unauthenticated_root_redirects_to_login(self, client: TestClient) -> None:
        """An anonymous visit to ``/`` is redirected to the login page."""
        app = cast(FastAPI, client.app)
        saved_overrides = dict(app.dependency_overrides)
        # Temporarily treat the request as coming from an anonymous session.
        app.dependency_overrides[get_auth_session] = lambda: AuthSession.anonymous()
        try:
            response = client.get("/", follow_redirects=False)
            assert response.status_code == 303
            assert response.headers.get("location") == "http://testserver/login"
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)

    def test_login_success_redirects_to_dashboard_and_sets_session(
        self, client: TestClient, db_session: Session
    ) -> None:
        """Logging in issues cookies that the real session dependency accepts on ``/``."""
        password = random_password()
        user = User(
            email="e2e@example.com",
            username="e2euser",
            password_hash=hash_password(password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        app = cast(FastAPI, client.app)
        saved_overrides = dict(app.dependency_overrides)
        # Log in as an anonymous visitor so the handler issues fresh cookies.
        app.dependency_overrides[get_auth_session] = lambda: AuthSession.anonymous()
        try:
            login_response = client.post(
                "/login",
                data={"username": "e2euser", "password": password},
                follow_redirects=False,
            )
            assert login_response.status_code == 303
            assert login_response.headers.get("location") == "http://testserver/"
            set_cookie_header = login_response.headers.get("set-cookie", "")
            assert "calminer_access_token=" in set_cookie_header

            # Drop the anonymous stand-in (and any fixture-provided session) so
            # the real dependency must resolve the user from the login cookies.
            # Redirects are not followed: a bounce to /login has to fail here
            # instead of rendering a page that might also mention "Dashboard".
            # NOTE(review): assumes the cookies are sent to http://testserver,
            # i.e. they are not flagged Secure -- confirm the cookie settings.
            app.dependency_overrides.pop(get_auth_session, None)
            dashboard_response = client.get("/", follow_redirects=False)
            assert dashboard_response.status_code == 200
            assert "Dashboard" in dashboard_response.text or "metrics" in dashboard_response.text
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)

    def test_logout_redirects_to_login_and_clears_session(self, client: TestClient) -> None:
        """Logout expires the access cookie; an anonymous ``/`` then bounces to login."""
        # Assumes the shared fixtures authenticate the client by default
        # (that setup is not visible in this file).
        logout_response = client.get("/logout", follow_redirects=False)
        assert logout_response.status_code == 303
        location = logout_response.headers.get("location")
        assert location and urlparse(location).path == "/login"
        set_cookie_header = logout_response.headers.get("set-cookie", "")
        assert "calminer_access_token=" in set_cookie_header
        assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()

        # After logout, an anonymous visit to / should redirect to login.
        app = cast(FastAPI, client.app)
        saved_overrides = dict(app.dependency_overrides)
        app.dependency_overrides[get_auth_session] = lambda: AuthSession.anonymous()
        try:
            root_response = client.get("/", follow_redirects=False)
            assert root_response.status_code == 303
            assert root_response.headers.get("location") == "http://testserver/login"
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)

    def test_login_inactive_user_shows_error(self, client: TestClient, db_session: Session) -> None:
        """Correct credentials for a deactivated account are rejected with a 400."""
        password = random_password()
        user = User(
            email="inactive@example.com",
            username="inactiveuser",
            password_hash=hash_password(password),
            is_active=False,
        )
        db_session.add(user)
        db_session.commit()

        app = cast(FastAPI, client.app)
        saved_overrides = dict(app.dependency_overrides)
        app.dependency_overrides[get_auth_session] = lambda: AuthSession.anonymous()
        try:
            response = client.post(
                "/login",
                data={"username": "inactiveuser", "password": password},
                follow_redirects=False,
            )
            assert response.status_code == 400
            assert "Account is inactive" in response.text
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)
|