from __future__ import annotations from collections.abc import Iterator from typing import cast from urllib.parse import parse_qs, urlparse import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from models import Role, User, UserRole from dependencies import get_auth_session, require_current_user from services.security import hash_password from services.session import AuthSession, SessionTokens @pytest.fixture() def db_session(session_factory: sessionmaker) -> Iterator[Session]: session = session_factory() try: yield session finally: session.close() def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None: stmt = select(User) if email is not None: stmt = stmt.where(User.email == email) if username is not None: stmt = stmt.where(User.username == username) return session.execute(stmt).scalar_one_or_none() class TestRegistrationFlow: def test_register_creates_user_and_assigns_role( self, client: TestClient, db_session: Session, ) -> None: response = client.post( "/register", data={ "username": "newuser", "email": "newuser@example.com", "password": "ComplexP@ss1", "confirm_password": "ComplexP@ss1", }, follow_redirects=False, ) assert response.status_code == 303 location = response.headers.get("location") assert location parsed = urlparse(location) assert parsed.path == "/login" assert parse_qs(parsed.query).get("registered") == ["1"] created = _get_user(db_session, email="newuser@example.com") assert created is not None assert created.is_active role_stmt = select(Role).where(Role.name == "viewer") viewer_role = db_session.execute(role_stmt).scalar_one_or_none() assert viewer_role is not None assignments = db_session.execute( select(UserRole).where( UserRole.user_id == created.id, UserRole.role_id == viewer_role.id, ) ).scalars().all() assert len(assignments) == 1 def test_register_duplicate_email_shows_error( self, client: TestClient, ) -> None: first = client.post( "/register", data={ "username": "existing", "email": "existing@example.com", "password": "ComplexP@ss1", "confirm_password": "ComplexP@ss1", }, follow_redirects=False, ) assert first.status_code == 303 second = client.post( "/register", data={ "username": "existing", "email": "existing@example.com", "password": "ComplexP@ss1", "confirm_password": "ComplexP@ss1", }, follow_redirects=False, ) assert second.status_code == 400 assert "Email is already registered" in second.text class TestLoginFlow: def test_login_sets_tokens_and_updates_last_login( self, client: TestClient, db_session: Session, ) -> None: password = "MySecur3Pass!" user = User( email="login@example.com", username="loginuser", password_hash=hash_password(password), is_active=True, ) db_session.add(user) db_session.commit() response = client.post( "/login", data={"username": "loginuser", "password": password}, follow_redirects=False, ) assert response.status_code == 303 assert response.headers.get("location") == "http://testserver/" set_cookie_header = response.headers.get("set-cookie", "") assert "calminer_access_token=" in set_cookie_header assert "calminer_refresh_token=" in set_cookie_header updated = _get_user(db_session, username="loginuser") assert updated is not None and updated.last_login_at is not None def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None: response = client.post( "/login", data={"username": "unknown", "password": "bad"}, follow_redirects=False, ) assert response.status_code == 400 assert "Invalid username or password" in response.text class TestPasswordResetFlow: def test_password_reset_round_trip( self, client: TestClient, db_session: Session, ) -> None: user = User( email="reset@example.com", username="resetuser", password_hash=hash_password("OldP@ssword1"), is_active=True, ) db_session.add(user) db_session.commit() request_response = client.post( "/forgot-password", data={"email": "reset@example.com"}, follow_redirects=False, ) assert request_response.status_code == 303 reset_location = request_response.headers.get("location") assert reset_location is not None parsed = urlparse(reset_location) assert parsed.path == "/reset-password" token = parse_qs(parsed.query).get("token", [None])[0] assert token form_response = client.get(reset_location) assert form_response.status_code == 200 submit_response = client.post( "/reset-password", data={ "token": token, "password": "N3wP@ssword!", "confirm_password": "N3wP@ssword!", }, follow_redirects=False, ) assert submit_response.status_code == 303 assert "reset=1" in (submit_response.headers.get("location") or "") db_session.refresh(user) assert user.verify_password("N3wP@ssword!") def test_password_reset_with_unknown_email_shows_generic_message( self, client: TestClient, ) -> None: response = client.post( "/forgot-password", data={"email": "doesnotexist@example.com"}, follow_redirects=False, ) assert response.status_code == 200 assert "If an account exists" in response.text def test_password_reset_mismatched_passwords_return_error( self, client: TestClient, db_session: Session, ) -> None: user = User( email="mismatch@example.com", username="mismatch", password_hash=hash_password("OldP@ssword1"), is_active=True, ) db_session.add(user) db_session.commit() request_response = client.post( "/forgot-password", data={"email": "mismatch@example.com"}, follow_redirects=False, ) token = parse_qs(urlparse(request_response.headers["location"]).query)[ "token"][0] submit_response = client.post( "/reset-password", data={ "token": token, "password": "NewPass123!", "confirm_password": "Different123!", }, follow_redirects=False, ) assert submit_response.status_code == 400 assert "Passwords do not match" in submit_response.text class TestLogoutFlow: def test_logout_clears_cookies_and_redirects( self, client: TestClient, db_session: Session, ) -> None: user = User( email="logout@example.com", username="logoutuser", password_hash=hash_password("SecureP@ss1"), is_active=True, ) db_session.add(user) db_session.commit() session = AuthSession( tokens=SessionTokens( access_token="access-token", refresh_token="refresh-token", access_token_source="cookie", ), user=user, ) app = cast(FastAPI, client.app) app.dependency_overrides[require_current_user] = lambda: user app.dependency_overrides[get_auth_session] = lambda: session try: response = client.get("/logout", follow_redirects=False) finally: app.dependency_overrides.pop(require_current_user, None) app.dependency_overrides.pop(get_auth_session, None) assert response.status_code == 303 location = response.headers.get("location") assert location and location.startswith("http://testserver/login") set_cookie_header = response.headers.get("set-cookie") or "" assert "calminer_access_token=" in set_cookie_header assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()