"""Integration tests for the registration, login, password-reset and logout routes."""

from __future__ import annotations

from collections.abc import Callable, Iterator, Mapping
from contextlib import contextmanager
from typing import Any, cast
from urllib.parse import parse_qs, urlparse

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from models import Role, User, UserRole
from dependencies import get_auth_session, get_jwt_settings, require_current_user
from services.security import decode_access_token, hash_password
from services.session import AuthSession, SessionTokens
from tests.utils.security import random_password, random_token

COOKIE_SOURCE = "cookie"

# Sentinel distinguishing "no override was installed" from any real override.
_NO_OVERRIDE = object()


@pytest.fixture()
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
    """Yield a session built from ``session_factory`` and close it afterwards."""
    session = session_factory()
    try:
        yield session
    finally:
        session.close()


def _get_user(
    session: Session,
    *,
    email: str | None = None,
    username: str | None = None,
) -> User | None:
    """Return the single ``User`` matching the given filters, or ``None``.

    Filters that are ``None`` are not applied; when both are given they are
    combined (both must match).
    """
    stmt = select(User)
    if email is not None:
        stmt = stmt.where(User.email == email)
    if username is not None:
        stmt = stmt.where(User.username == username)
    return session.execute(stmt).scalar_one_or_none()


@contextmanager
def _override_dependencies(
    app: FastAPI,
    overrides: Mapping[Callable[..., Any], Callable[..., Any]],
) -> Iterator[None]:
    """Temporarily install ``overrides`` in ``app.dependency_overrides``.

    On exit -- including when the body raises -- every touched dependency is
    put back exactly as it was: a previously installed override is restored,
    and a dependency that had no override is removed again. This keeps one
    test from leaking (or destroying) overrides that other tests rely on.
    """
    saved = {
        dependency: app.dependency_overrides.get(dependency, _NO_OVERRIDE)
        for dependency in overrides
    }
    app.dependency_overrides.update(overrides)
    try:
        yield
    finally:
        for dependency, previous in saved.items():
            if previous is _NO_OVERRIDE:
                app.dependency_overrides.pop(dependency, None)
            else:
                app.dependency_overrides[dependency] = previous


def _anonymous_auth_session(app: FastAPI):
    """Return a context manager that makes ``get_auth_session`` anonymous."""
    return _override_dependencies(
        app,
        {get_auth_session: lambda: AuthSession.anonymous()},
    )


class TestRegistrationFlow:
    """Tests for ``POST /register``."""

    def test_register_creates_user_and_assigns_role(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A valid registration redirects to login and creates a viewer user."""
        password = random_password()
        response = client.post(
            "/register",
            data={
                "username": "newuser",
                "email": "newuser@example.com",
                "password": password,
                "confirm_password": password,
            },
            follow_redirects=False,
        )
        assert response.status_code == 303

        location = response.headers.get("location")
        assert location
        parsed = urlparse(location)
        assert parsed.path == "/login"
        assert parse_qs(parsed.query).get("registered") == ["1"]

        created = _get_user(db_session, email="newuser@example.com")
        assert created is not None
        assert created.is_active

        role_stmt = select(Role).where(Role.name == "viewer")
        viewer_role = db_session.execute(role_stmt).scalar_one_or_none()
        assert viewer_role is not None

        # Exactly one link row between the new user and the viewer role.
        assignments = db_session.execute(
            select(UserRole).where(
                UserRole.user_id == created.id,
                UserRole.role_id == viewer_role.id,
            )
        ).scalars().all()
        assert len(assignments) == 1

    def test_register_duplicate_email_shows_error(
        self,
        client: TestClient,
    ) -> None:
        """Submitting the same registration twice yields a 400 with an error."""
        password = random_password()
        form = {
            "username": "existing",
            "email": "existing@example.com",
            "password": password,
            "confirm_password": password,
        }

        first = client.post("/register", data=form, follow_redirects=False)
        assert first.status_code == 303

        second = client.post("/register", data=form, follow_redirects=False)
        assert second.status_code == 400
        assert "Email is already registered" in second.text


class TestLoginFlow:
    """Tests for ``POST /login``."""

    def test_login_sets_tokens_and_updates_last_login(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A successful login sets both token cookies and ``last_login_at``."""
        password = random_password()
        user = User(
            email="login@example.com",
            username="loginuser",
            password_hash=hash_password(password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        response = client.post(
            "/login",
            data={"username": "loginuser", "password": password},
            follow_redirects=False,
        )
        assert response.status_code == 303
        assert response.headers.get("location") == "http://testserver/"

        set_cookie_header = response.headers.get("set-cookie", "")
        assert "calminer_access_token=" in set_cookie_header
        assert "calminer_refresh_token=" in set_cookie_header

        updated = _get_user(db_session, username="loginuser")
        assert updated is not None and updated.last_login_at is not None

    def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
        """Unknown credentials produce a 400 with an error message."""
        response = client.post(
            "/login",
            data={"username": "unknown", "password": "bad"},
            follow_redirects=False,
        )
        assert response.status_code == 400
        assert "Invalid username or password" in response.text


class TestPasswordResetFlow:
    """Tests for ``/forgot-password`` and ``/reset-password``."""

    def test_password_reset_round_trip(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Request a reset, follow the link, submit a new password, verify it."""
        old_password = random_password()
        user = User(
            email="reset@example.com",
            username="resetuser",
            password_hash=hash_password(old_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "reset@example.com"},
            follow_redirects=False,
        )
        assert request_response.status_code == 303

        reset_location = request_response.headers.get("location")
        assert reset_location is not None
        parsed = urlparse(reset_location)
        assert parsed.path == "/reset-password"
        token = parse_qs(parsed.query).get("token", [None])[0]
        assert token

        form_response = client.get(reset_location)
        assert form_response.status_code == 200

        new_password = random_password()
        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": new_password,
                "confirm_password": new_password,
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 303
        assert "reset=1" in (submit_response.headers.get("location") or "")

        # Reload the row so the freshly stored hash is checked.
        db_session.refresh(user)
        assert user.verify_password(new_password)

    def test_password_reset_with_unknown_email_shows_generic_message(
        self,
        client: TestClient,
    ) -> None:
        """An unknown email gets a 200 page with the generic message."""
        response = client.post(
            "/forgot-password",
            data={"email": "doesnotexist@example.com"},
            follow_redirects=False,
        )
        assert response.status_code == 200
        assert "If an account exists" in response.text

    def test_password_reset_mismatched_passwords_return_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Different password and confirmation values are rejected with a 400."""
        original_password = random_password()
        user = User(
            email="mismatch@example.com",
            username="mismatch",
            password_hash=hash_password(original_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        request_response = client.post(
            "/forgot-password",
            data={"email": "mismatch@example.com"},
            follow_redirects=False,
        )
        reset_query = urlparse(request_response.headers["location"]).query
        token = parse_qs(reset_query)["token"][0]

        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": random_password(),
                "confirm_password": random_password(),
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 400
        assert "Passwords do not match" in submit_response.text


class TestLogoutFlow:
    """Tests for ``GET /logout`` with explicitly injected auth dependencies."""

    def test_logout_clears_cookies_and_redirects(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Logout redirects to the login page and expires the access cookie."""
        logout_password = random_password()
        user = User(
            email="logout@example.com",
            username="logoutuser",
            password_hash=hash_password(logout_password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        session = AuthSession(
            tokens=SessionTokens(
                access_token=random_token(),
                refresh_token=random_token(),
                access_token_source=COOKIE_SOURCE,
            ),
            user=user,
        )

        app = cast(FastAPI, client.app)
        # Previously installed overrides are restored on exit rather than
        # dropped, so later tests still see them.
        with _override_dependencies(
            app,
            {
                require_current_user: lambda: user,
                get_auth_session: lambda: session,
            },
        ):
            response = client.get("/logout", follow_redirects=False)

        assert response.status_code == 303
        location = response.headers.get("location")
        assert location and location.startswith("http://testserver/login")

        set_cookie_header = response.headers.get("set-cookie") or ""
        assert "calminer_access_token=" in set_cookie_header
        assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()


class TestLoginFlowEndToEnd:
    """End-to-end checks of the login/logout pages and redirects."""

    def test_get_login_form_renders(self, client: TestClient) -> None:
        """``GET /login`` renders the login form."""
        response = client.get("/login")
        assert response.status_code == 200
        assert "login-form" in response.text
        assert "username" in response.text

    def test_unauthenticated_root_redirects_to_login(self, client: TestClient) -> None:
        """An anonymous request to ``/`` is redirected to the login page."""
        app = cast(FastAPI, client.app)
        # Temporarily override to anonymous session
        with _anonymous_auth_session(app):
            response = client.get("/", follow_redirects=False)
            assert response.status_code == 303
            assert response.headers.get("location") == "http://testserver/login"

    def test_login_success_redirects_to_dashboard_and_sets_session(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A login redirects to ``/`` and issues a decodable access token."""
        password = random_password()
        user = User(
            email="e2e@example.com",
            username="e2euser",
            password_hash=hash_password(password),
            is_active=True,
        )
        db_session.add(user)
        db_session.commit()

        app = cast(FastAPI, client.app)
        # Override to anonymous for login; the prior override is restored as
        # soon as the block exits, even if an assertion below fails.
        with _anonymous_auth_session(app):
            login_response = client.post(
                "/login",
                data={"username": "e2euser", "password": password},
                follow_redirects=False,
            )
            assert login_response.status_code == 303
            assert login_response.headers.get("location") == "http://testserver/"
            set_cookie_header = login_response.headers.get("set-cookie", "")
            assert "calminer_access_token=" in set_cookie_header

        access_cookie = client.cookies.get("calminer_access_token")
        refresh_cookie = client.cookies.get("calminer_refresh_token")
        assert access_cookie, "Access token cookie was not set"
        assert refresh_cookie, "Refresh token cookie was not set"

        jwt_settings = get_jwt_settings()
        payload = decode_access_token(access_cookie, jwt_settings)
        assert payload.sub == str(user.id)
        assert payload.scopes == ["auth"], "Unexpected access token scopes"

    def test_logout_redirects_to_login_and_clears_session(self, client: TestClient) -> None:
        """Logout expires the access cookie; ``/`` then redirects to login."""
        # Assuming authenticated from conftest
        logout_response = client.get("/logout", follow_redirects=False)
        assert logout_response.status_code == 303
        location = logout_response.headers.get("location")
        assert location and "login" in location

        set_cookie_header = logout_response.headers.get("set-cookie", "")
        assert "calminer_access_token=" in set_cookie_header
        assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()

        # After logout, GET / should redirect to login
        app = cast(FastAPI, client.app)
        with _anonymous_auth_session(app):
            root_response = client.get("/", follow_redirects=False)
            assert root_response.status_code == 303
            assert root_response.headers.get("location") == "http://testserver/login"

    def test_login_inactive_user_shows_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Correct credentials for an inactive account are rejected with a 400."""
        password = random_password()
        user = User(
            email="inactive@example.com",
            username="inactiveuser",
            password_hash=hash_password(password),
            is_active=False,
        )
        db_session.add(user)
        db_session.commit()

        app = cast(FastAPI, client.app)
        with _anonymous_auth_session(app):
            response = client.post(
                "/login",
                data={"username": "inactiveuser", "password": password},
                follow_redirects=False,
            )
            assert response.status_code == 400
            assert "Account is inactive" in response.text