"""End-to-end tests for the registration, login and password-reset HTTP flows."""

from __future__ import annotations

from collections.abc import Iterator
from urllib.parse import parse_qs, urlparse

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from models import Role, User, UserRole
from services.security import hash_password


@pytest.fixture()
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
    """Yield a fresh ORM session from ``session_factory``; always close it afterwards."""
    session = session_factory()
    try:
        yield session
    finally:
        session.close()


def _get_user(
    session: Session,
    *,
    email: str | None = None,
    username: str | None = None,
) -> User | None:
    """Return the ``User`` matching the given filters, or ``None`` if none match.

    Filters left as ``None`` are not applied. Because ``scalar_one_or_none`` is
    used, more than one matching row raises, so callers should pass at least one
    filter that identifies a single user.
    """
    stmt = select(User)
    if email is not None:
        stmt = stmt.where(User.email == email)
    if username is not None:
        stmt = stmt.where(User.username == username)
    return session.execute(stmt).scalar_one_or_none()


def _create_active_user(
    session: Session,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Persist and commit an active ``User`` whose password hash matches *password*."""
    user = User(
        email=email,
        username=username,
        password_hash=hash_password(password),
        is_active=True,
    )
    session.add(user)
    session.commit()
    return user


def _request_reset_location(client: TestClient, email: str) -> str:
    """Submit the forgot-password form for *email* and return the redirect target.

    Asserts the endpoint answered with a 303 redirect carrying a ``location``
    header, so a failure is reported as a clear assertion, not a ``KeyError``.
    """
    response = client.post(
        "/forgot-password",
        data={"email": email},
        follow_redirects=False,
    )
    assert response.status_code == 303
    location = response.headers.get("location")
    assert location
    return location


class TestRegistrationFlow:
    """Tests for the ``/register`` endpoint."""

    def test_register_creates_user_and_assigns_role(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A valid registration redirects to login and creates an active viewer."""
        response = client.post(
            "/register",
            data={
                "username": "newuser",
                "email": "newuser@example.com",
                "password": "ComplexP@ss1",
                "confirm_password": "ComplexP@ss1",
            },
            follow_redirects=False,
        )
        assert response.status_code == 303

        location = response.headers.get("location")
        assert location
        parsed = urlparse(location)
        assert parsed.path == "/login"
        assert parse_qs(parsed.query).get("registered") == ["1"]

        created = _get_user(db_session, email="newuser@example.com")
        assert created is not None
        assert created.is_active

        viewer_role = db_session.execute(
            select(Role).where(Role.name == "viewer")
        ).scalar_one_or_none()
        assert viewer_role is not None

        assignments = (
            db_session.execute(
                select(UserRole).where(
                    UserRole.user_id == created.id,
                    UserRole.role_id == viewer_role.id,
                )
            )
            .scalars()
            .all()
        )
        assert len(assignments) == 1

    def test_register_duplicate_email_shows_error(
        self,
        client: TestClient,
    ) -> None:
        """Re-registering an existing email is rejected with a 400 and a message."""
        first = client.post(
            "/register",
            data={
                "username": "existing",
                "email": "existing@example.com",
                "password": "ComplexP@ss1",
                "confirm_password": "ComplexP@ss1",
            },
            follow_redirects=False,
        )
        assert first.status_code == 303

        # Use a different username so the ONLY conflict is the email address;
        # otherwise a username-uniqueness check could mask the email check.
        second = client.post(
            "/register",
            data={
                "username": "anotheruser",
                "email": "existing@example.com",
                "password": "ComplexP@ss1",
                "confirm_password": "ComplexP@ss1",
            },
            follow_redirects=False,
        )
        assert second.status_code == 400
        assert "Email is already registered" in second.text


class TestLoginFlow:
    """Tests for the ``/login`` endpoint."""

    def test_login_sets_tokens_and_updates_last_login(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """A successful login sets both auth cookies and records ``last_login_at``."""
        password = "MySecur3Pass!"
        _create_active_user(
            db_session,
            email="login@example.com",
            username="loginuser",
            password=password,
        )

        response = client.post(
            "/login",
            data={"username": "loginuser", "password": password},
            follow_redirects=False,
        )
        assert response.status_code == 303
        assert response.headers.get("location") == "http://testserver/"

        set_cookie_header = response.headers.get("set-cookie", "")
        assert "calminer_access_token=" in set_cookie_header
        assert "calminer_refresh_token=" in set_cookie_header

        # The user instance is still in this session's identity map; expire it so
        # the re-query reloads the row written by the request handler instead of
        # returning possibly-stale cached attributes.
        db_session.expire_all()
        updated = _get_user(db_session, username="loginuser")
        assert updated is not None
        assert updated.last_login_at is not None

    def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
        """Unknown credentials yield a 400 with a generic error message."""
        response = client.post(
            "/login",
            data={"username": "unknown", "password": "bad"},
            follow_redirects=False,
        )
        assert response.status_code == 400
        assert "Invalid username or password" in response.text


class TestPasswordResetFlow:
    """Tests for the ``/forgot-password`` and ``/reset-password`` endpoints."""

    def test_password_reset_round_trip(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Requesting a reset, opening the form and submitting changes the password."""
        user = _create_active_user(
            db_session,
            email="reset@example.com",
            username="resetuser",
            password="OldP@ssword1",
        )

        reset_location = _request_reset_location(client, "reset@example.com")
        parsed = urlparse(reset_location)
        assert parsed.path == "/reset-password"
        token = parse_qs(parsed.query).get("token", [None])[0]
        assert token

        form_response = client.get(reset_location)
        assert form_response.status_code == 200

        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": "N3wP@ssword!",
                "confirm_password": "N3wP@ssword!",
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 303
        assert "reset=1" in (submit_response.headers.get("location") or "")

        # Reload from the database so the new hash written by the handler is seen.
        db_session.refresh(user)
        assert user.verify_password("N3wP@ssword!")

    def test_password_reset_with_unknown_email_shows_generic_message(
        self,
        client: TestClient,
    ) -> None:
        """An unknown email gets a generic response rather than a redirect."""
        response = client.post(
            "/forgot-password",
            data={"email": "doesnotexist@example.com"},
            follow_redirects=False,
        )
        assert response.status_code == 200
        assert "If an account exists" in response.text

    def test_password_reset_mismatched_passwords_return_error(
        self,
        client: TestClient,
        db_session: Session,
    ) -> None:
        """Submitting non-matching passwords is rejected with a 400."""
        _create_active_user(
            db_session,
            email="mismatch@example.com",
            username="mismatch",
            password="OldP@ssword1",
        )

        reset_location = _request_reset_location(client, "mismatch@example.com")
        token = parse_qs(urlparse(reset_location).query)["token"][0]

        submit_response = client.post(
            "/reset-password",
            data={
                "token": token,
                "password": "NewPass123!",
                "confirm_password": "Different123!",
            },
            follow_redirects=False,
        )
        assert submit_response.status_code == 400
        assert "Passwords do not match" in submit_response.text