feat: Implement user and role management with repositories
- Added RoleRepository and UserRepository for managing roles and users. - Implemented methods for creating, retrieving, and assigning roles to users. - Introduced functions to ensure default roles and an admin user exist in the system. - Updated UnitOfWork to include user and role repositories. - Created new security module for password hashing and JWT token management. - Added tests for authentication flows, including registration, login, and password reset. - Enhanced HTML templates for user registration, login, and password management with error handling. - Added a logo image to the static assets.
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_unit_of_work
|
||||
from routes.auth import router as auth_router
|
||||
from routes.dashboard import router as dashboard_router
|
||||
from routes.projects import router as projects_router
|
||||
from routes.scenarios import router as scenarios_router
|
||||
@@ -36,13 +37,15 @@ def engine() -> Iterator[Engine]:
|
||||
|
||||
@pytest.fixture()
|
||||
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||
testing_session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
||||
testing_session = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
yield testing_session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(session_factory: sessionmaker) -> FastAPI:
|
||||
application = FastAPI()
|
||||
application.include_router(auth_router)
|
||||
application.include_router(dashboard_router)
|
||||
application.include_router(projects_router)
|
||||
application.include_router(scenarios_router)
|
||||
|
||||
135
tests/test_auth_repositories.py
Normal file
135
tests/test_auth_repositories.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from models import Role, User
|
||||
from services.repositories import (
|
||||
RoleRepository,
|
||||
UserRepository,
|
||||
ensure_admin_user,
|
||||
ensure_default_roles,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator:
|
||||
engine = create_engine("sqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session(engine) -> Iterator[Session]:
|
||||
TestingSession = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
db = TestingSession()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_role_repository_create_and_lookup(session: Session) -> None:
|
||||
repo = RoleRepository(session)
|
||||
role = Role(name="custom", display_name="Custom",
|
||||
description="Custom role")
|
||||
repo.create(role)
|
||||
|
||||
retrieved = repo.get(role.id)
|
||||
assert retrieved.name == "custom"
|
||||
assert repo.get_by_name("custom") is retrieved
|
||||
assert repo.list()[0].name == "custom"
|
||||
|
||||
|
||||
def test_user_repository_assign_and_revoke_role(session: Session) -> None:
|
||||
role_repo = RoleRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
analyst = role_repo.create(
|
||||
Role(name="analyst", display_name="Analyst", description="Analyzes data")
|
||||
)
|
||||
user = User(
|
||||
email="user@example.com",
|
||||
username="user",
|
||||
password_hash=User.hash_password("secret"),
|
||||
)
|
||||
user_repo.create(user)
|
||||
|
||||
assignment = user_repo.assign_role(
|
||||
user_id=user.id, role_id=analyst.id, granted_by=None)
|
||||
assert assignment.role_id == analyst.id
|
||||
|
||||
refreshed = user_repo.get(user.id, with_roles=True)
|
||||
assert refreshed.roles[0].name == "analyst"
|
||||
|
||||
user_repo.revoke_role(user_id=user.id, role_id=analyst.id)
|
||||
refreshed = user_repo.get(user.id, with_roles=True)
|
||||
assert refreshed.roles == []
|
||||
|
||||
|
||||
def test_default_role_and_admin_helpers(session: Session) -> None:
|
||||
role_repo = RoleRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
roles = ensure_default_roles(role_repo)
|
||||
assert {role.name for role in roles} == {
|
||||
"admin", "project_manager", "analyst", "viewer"}
|
||||
|
||||
ensure_admin_user(
|
||||
user_repo,
|
||||
role_repo,
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password="SecurePass1!",
|
||||
)
|
||||
|
||||
admin = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||
assert admin is not None
|
||||
assert admin.is_superuser
|
||||
assert {role.name for role in admin.roles} >= {"admin"}
|
||||
|
||||
# Idempotent behaviour on subsequent calls
|
||||
ensure_admin_user(
|
||||
user_repo,
|
||||
role_repo,
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password="SecurePass1!",
|
||||
)
|
||||
admin_again = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||
assert admin_again is not None
|
||||
assert len(admin_again.roles) == len(
|
||||
{role.name for role in admin_again.roles})
|
||||
|
||||
|
||||
def test_unit_of_work_exposes_auth_repositories(engine) -> None:
|
||||
TestingSession = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
|
||||
with UnitOfWork(session_factory=TestingSession) as uow:
|
||||
assert uow.users is not None
|
||||
assert uow.roles is not None
|
||||
|
||||
roles = uow.ensure_default_roles()
|
||||
assert any(role.name == "admin" for role in roles)
|
||||
|
||||
uow.ensure_admin_user(
|
||||
email="uow-admin@example.com",
|
||||
username="uow-admin",
|
||||
password="AnotherSecret1!",
|
||||
)
|
||||
|
||||
admin = uow.users.get_by_email(
|
||||
"uow-admin@example.com", with_roles=True)
|
||||
assert admin is not None
|
||||
assert admin.is_superuser
|
||||
assert any(role.name == "admin" for role in admin.roles)
|
||||
239
tests/test_auth_routes.py
Normal file
239
tests/test_auth_routes.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models import Role, User, UserRole
|
||||
from services.security import hash_password
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
|
||||
session = session_factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None:
|
||||
stmt = select(User)
|
||||
if email is not None:
|
||||
stmt = stmt.where(User.email == email)
|
||||
if username is not None:
|
||||
stmt = stmt.where(User.username == username)
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
class TestRegistrationFlow:
|
||||
def test_register_creates_user_and_assigns_role(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
response = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "newuser",
|
||||
"email": "newuser@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 303
|
||||
location = response.headers.get("location")
|
||||
assert location
|
||||
parsed = urlparse(location)
|
||||
assert parsed.path == "/login"
|
||||
assert parse_qs(parsed.query).get("registered") == ["1"]
|
||||
|
||||
created = _get_user(db_session, email="newuser@example.com")
|
||||
assert created is not None
|
||||
assert created.is_active
|
||||
|
||||
role_stmt = select(Role).where(Role.name == "viewer")
|
||||
viewer_role = db_session.execute(role_stmt).scalar_one_or_none()
|
||||
assert viewer_role is not None
|
||||
|
||||
assignments = db_session.execute(
|
||||
select(UserRole).where(
|
||||
UserRole.user_id == created.id,
|
||||
UserRole.role_id == viewer_role.id,
|
||||
)
|
||||
).scalars().all()
|
||||
assert len(assignments) == 1
|
||||
|
||||
def test_register_duplicate_email_shows_error(
|
||||
self,
|
||||
client: TestClient,
|
||||
) -> None:
|
||||
first = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "existing",
|
||||
"email": "existing@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert first.status_code == 303
|
||||
|
||||
second = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "existing",
|
||||
"email": "existing@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert second.status_code == 400
|
||||
assert "Email is already registered" in second.text
|
||||
|
||||
|
||||
class TestLoginFlow:
|
||||
def test_login_sets_tokens_and_updates_last_login(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
password = "MySecur3Pass!"
|
||||
user = User(
|
||||
email="login@example.com",
|
||||
username="loginuser",
|
||||
password_hash=hash_password(password),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
response = client.post(
|
||||
"/login",
|
||||
data={"username": "loginuser", "password": password},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 303
|
||||
assert response.headers.get("location") == "http://testserver/"
|
||||
set_cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "calminer_access_token=" in set_cookie_header
|
||||
assert "calminer_refresh_token=" in set_cookie_header
|
||||
|
||||
updated = _get_user(db_session, username="loginuser")
|
||||
assert updated is not None and updated.last_login_at is not None
|
||||
|
||||
def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
|
||||
response = client.post(
|
||||
"/login",
|
||||
data={"username": "unknown", "password": "bad"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid username or password" in response.text
|
||||
|
||||
|
||||
class TestPasswordResetFlow:
|
||||
def test_password_reset_round_trip(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="reset@example.com",
|
||||
username="resetuser",
|
||||
password_hash=hash_password("OldP@ssword1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
request_response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "reset@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert request_response.status_code == 303
|
||||
reset_location = request_response.headers.get("location")
|
||||
assert reset_location is not None
|
||||
parsed = urlparse(reset_location)
|
||||
assert parsed.path == "/reset-password"
|
||||
token = parse_qs(parsed.query).get("token", [None])[0]
|
||||
assert token
|
||||
|
||||
form_response = client.get(reset_location)
|
||||
assert form_response.status_code == 200
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
data={
|
||||
"token": token,
|
||||
"password": "N3wP@ssword!",
|
||||
"confirm_password": "N3wP@ssword!",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert submit_response.status_code == 303
|
||||
assert "reset=1" in (submit_response.headers.get("location") or "")
|
||||
|
||||
db_session.refresh(user)
|
||||
assert user.verify_password("N3wP@ssword!")
|
||||
|
||||
def test_password_reset_with_unknown_email_shows_generic_message(
|
||||
self,
|
||||
client: TestClient,
|
||||
) -> None:
|
||||
response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "doesnotexist@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "If an account exists" in response.text
|
||||
|
||||
def test_password_reset_mismatched_passwords_return_error(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="mismatch@example.com",
|
||||
username="mismatch",
|
||||
password_hash=hash_password("OldP@ssword1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
request_response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "mismatch@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
data={
|
||||
"token": token,
|
||||
"password": "NewPass123!",
|
||||
"confirm_password": "Different123!",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert submit_response.status_code == 400
|
||||
assert "Passwords do not match" in submit_response.text
|
||||
76
tests/test_security.py
Normal file
76
tests/test_security.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
def test_hash_password_round_trip() -> None:
|
||||
hashed = hash_password("super-secret")
|
||||
|
||||
assert hashed != "super-secret"
|
||||
assert verify_password("super-secret", hashed)
|
||||
assert not verify_password("incorrect", hashed)
|
||||
|
||||
|
||||
def test_verify_password_handles_malformed_hash() -> None:
|
||||
assert not verify_password("secret", "not-a-valid-hash")
|
||||
|
||||
|
||||
def test_access_token_roundtrip() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
|
||||
token = create_access_token(
|
||||
"user-id-123",
|
||||
settings,
|
||||
scopes=("read", "write"),
|
||||
extra_claims={"custom": "value"},
|
||||
)
|
||||
|
||||
payload = decode_access_token(token, settings)
|
||||
|
||||
assert payload.sub == "user-id-123"
|
||||
assert payload.type == "access"
|
||||
assert payload.scopes == ["read", "write"]
|
||||
|
||||
|
||||
def test_refresh_token_type_mismatch() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
token = create_refresh_token("user-id-456", settings)
|
||||
|
||||
with pytest.raises(TokenTypeMismatchError):
|
||||
decode_access_token(token, settings)
|
||||
|
||||
|
||||
def test_decode_expired_token() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
expired_token = create_access_token(
|
||||
"user-id-789",
|
||||
settings,
|
||||
expires_delta=timedelta(seconds=-5),
|
||||
)
|
||||
|
||||
with pytest.raises(TokenExpiredError):
|
||||
decode_access_token(expired_token, settings)
|
||||
|
||||
|
||||
def test_decode_tampered_token() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
token = create_access_token("user-id-321", settings)
|
||||
tampered = token[:-1] + ("a" if token[-1] != "a" else "b")
|
||||
|
||||
with pytest.raises(TokenDecodeError):
|
||||
decode_access_token(tampered, settings)
|
||||
Reference in New Issue
Block a user