feat: Implement user and role management with repositories

- Added RoleRepository and UserRepository for managing roles and users.
- Implemented methods for creating, retrieving, and assigning roles to users.
- Introduced functions to ensure default roles and an admin user exist in the system.
- Updated UnitOfWork to include user and role repositories.
- Created new security module for password hashing and JWT token management.
- Added tests for authentication flows, including registration, login, and password reset.
- Enhanced HTML templates for user registration, login, and password management with error handling.
- Added a logo image to the static assets.
This commit is contained in:
2025-11-09 21:48:35 +01:00
parent 53879a411f
commit 3601c2e422
22 changed files with 1955 additions and 132 deletions

View File

@@ -8,7 +8,16 @@ from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload
from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter
from models import (
FinancialInput,
Project,
Role,
Scenario,
ScenarioStatus,
SimulationParameter,
User,
UserRole,
)
from services.exceptions import EntityConflictError, EntityNotFoundError
@@ -211,3 +220,220 @@ class SimulationParameterRepository:
raise EntityNotFoundError(
f"Simulation parameter {parameter_id} not found")
self.session.delete(entity)
class RoleRepository:
"""Persistence operations for Role entities."""
def __init__(self, session: Session) -> None:
self.session = session
def list(self) -> Sequence[Role]:
stmt = select(Role).order_by(Role.name)
return self.session.execute(stmt).scalars().all()
def get(self, role_id: int) -> Role:
stmt = select(Role).where(Role.id == role_id)
role = self.session.execute(stmt).scalar_one_or_none()
if role is None:
raise EntityNotFoundError(f"Role {role_id} not found")
return role
def get_by_name(self, name: str) -> Role | None:
stmt = select(Role).where(Role.name == name)
return self.session.execute(stmt).scalar_one_or_none()
def create(self, role: Role) -> Role:
self.session.add(role)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"Role violates uniqueness constraints") from exc
return role
class UserRepository:
"""Persistence operations for User entities and their role assignments."""
def __init__(self, session: Session) -> None:
self.session = session
def list(self, *, with_roles: bool = False) -> Sequence[User]:
stmt = select(User).order_by(User.created_at)
if with_roles:
stmt = stmt.options(selectinload(User.roles))
return self.session.execute(stmt).scalars().all()
def _apply_role_option(self, stmt, with_roles: bool):
if with_roles:
stmt = stmt.options(
joinedload(User.role_assignments).joinedload(UserRole.role),
selectinload(User.roles),
)
return stmt
def get(self, user_id: int, *, with_roles: bool = False) -> User:
stmt = select(User).where(User.id == user_id).execution_options(
populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
user = result.scalar_one_or_none()
if user is None:
raise EntityNotFoundError(f"User {user_id} not found")
return user
def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
stmt = select(User).where(User.email == email).execution_options(
populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
return result.scalar_one_or_none()
def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
stmt = select(User).where(User.username ==
username).execution_options(populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
return result.scalar_one_or_none()
def create(self, user: User) -> User:
self.session.add(user)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"User violates uniqueness constraints") from exc
return user
def assign_role(
self,
*,
user_id: int,
role_id: int,
granted_by: int | None = None,
) -> UserRole:
stmt = select(UserRole).where(
UserRole.user_id == user_id,
UserRole.role_id == role_id,
)
assignment = self.session.execute(stmt).scalar_one_or_none()
if assignment:
return assignment
assignment = UserRole(
user_id=user_id,
role_id=role_id,
granted_by=granted_by,
)
self.session.add(assignment)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"Assignment violates constraints") from exc
return assignment
def revoke_role(self, *, user_id: int, role_id: int) -> None:
stmt = select(UserRole).where(
UserRole.user_id == user_id,
UserRole.role_id == role_id,
)
assignment = self.session.execute(stmt).scalar_one_or_none()
if assignment is None:
raise EntityNotFoundError(
f"Role {role_id} not assigned to user {user_id}")
self.session.delete(assignment)
self.session.flush()
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
{
"name": "admin",
"display_name": "Administrator",
"description": "Full platform access with user management rights.",
},
{
"name": "project_manager",
"display_name": "Project Manager",
"description": "Manage projects, scenarios, and associated data.",
},
{
"name": "analyst",
"display_name": "Analyst",
"description": "Review dashboards and scenario outputs.",
},
{
"name": "viewer",
"display_name": "Viewer",
"description": "Read-only access to assigned projects and reports.",
},
)
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
"""Ensure standard roles exist, creating missing ones.
Returns all current role records in creation order.
"""
roles: list[Role] = []
for definition in DEFAULT_ROLE_DEFINITIONS:
existing = role_repo.get_by_name(definition["name"])
if existing:
roles.append(existing)
continue
role = Role(**definition)
roles.append(role_repo.create(role))
return roles
def ensure_admin_user(
user_repo: UserRepository,
role_repo: RoleRepository,
*,
email: str,
username: str,
password: str,
) -> User:
"""Ensure an administrator user exists and holds the admin role."""
user = user_repo.get_by_email(email, with_roles=True)
if user is None:
user = User(
email=email,
username=username,
password_hash=User.hash_password(password),
is_active=True,
is_superuser=True,
)
user_repo.create(user)
else:
if not user.is_active:
user.is_active = True
if not user.is_superuser:
user.is_superuser = True
user_repo.session.flush()
admin_role = role_repo.get_by_name("admin")
if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called
admin_role = role_repo.create(
Role(
name="admin",
display_name="Administrator",
description="Full platform access with user management rights.",
)
)
user_repo.assign_role(
user_id=user.id,
role_id=admin_role.id,
granted_by=user.id,
)
return user

213
services/security.py Normal file
View File

@@ -0,0 +1,213 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, Literal, Type
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
try: # pragma: no cover - compatibility shim for passlib/argon2 warning
import importlib.metadata as importlib_metadata
import argon2 # type: ignore
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
except Exception: # pragma: no cover - executed only when metadata lookup fails
pass
from pydantic import BaseModel, Field, ValidationError
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
def hash_password(password: str) -> str:
"""Derive a secure hash for a plain-text password."""
return password_context.hash(password)
def verify_password(candidate: str, hashed: str) -> bool:
"""Verify that a candidate password matches a stored hash."""
try:
return password_context.verify(candidate, hashed)
except ValueError:
# Raised when the stored hash is malformed or uses an unknown scheme.
return False
class TokenError(Exception):
"""Base class for token encoding/decoding issues."""
class TokenDecodeError(TokenError):
"""Raised when a token cannot be decoded or validated."""
class TokenExpiredError(TokenError):
"""Raised when a token has expired."""
class TokenTypeMismatchError(TokenError):
"""Raised when a token type does not match the expected flavour."""
TokenKind = Literal["access", "refresh"]
class TokenPayload(BaseModel):
"""Shared fields for CalMiner JWT payloads."""
sub: str
exp: int
type: TokenKind
scopes: list[str] = Field(default_factory=list)
@property
def expires_at(self) -> datetime:
return datetime.fromtimestamp(self.exp, tz=timezone.utc)
@dataclass(slots=True)
class JWTSettings:
"""Runtime configuration for JWT encoding and validation."""
secret_key: str
algorithm: str = "HS256"
access_token_ttl: timedelta = field(
default_factory=lambda: timedelta(minutes=15))
refresh_token_ttl: timedelta = field(
default_factory=lambda: timedelta(days=7))
def create_access_token(
subject: str,
settings: JWTSettings,
*,
scopes: Iterable[str] | None = None,
expires_delta: timedelta | None = None,
extra_claims: Dict[str, Any] | None = None,
) -> str:
"""Issue a signed access token for the provided subject."""
lifetime = expires_delta or settings.access_token_ttl
return _create_token(
subject=subject,
token_type="access",
settings=settings,
lifetime=lifetime,
scopes=scopes,
extra_claims=extra_claims,
)
def create_refresh_token(
subject: str,
settings: JWTSettings,
*,
scopes: Iterable[str] | None = None,
expires_delta: timedelta | None = None,
extra_claims: Dict[str, Any] | None = None,
) -> str:
"""Issue a signed refresh token for the provided subject."""
lifetime = expires_delta or settings.refresh_token_ttl
return _create_token(
subject=subject,
token_type="refresh",
settings=settings,
lifetime=lifetime,
scopes=scopes,
extra_claims=extra_claims,
)
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
"""Validate and decode an access token."""
return _decode_token(token, settings, expected_type="access")
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
"""Validate and decode a refresh token."""
return _decode_token(token, settings, expected_type="refresh")
def _create_token(
*,
subject: str,
token_type: TokenKind,
settings: JWTSettings,
lifetime: timedelta,
scopes: Iterable[str] | None,
extra_claims: Dict[str, Any] | None,
) -> str:
now = datetime.now(timezone.utc)
expire = now + lifetime
payload: Dict[str, Any] = {
"sub": subject,
"type": token_type,
"iat": int(now.timestamp()),
"exp": int(expire.timestamp()),
}
if scopes:
payload["scopes"] = list(scopes)
if extra_claims:
payload.update(extra_claims)
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
def _decode_token(
token: str,
settings: JWTSettings,
expected_type: TokenKind,
) -> TokenPayload:
try:
decoded = jwt.decode(
token,
settings.secret_key,
algorithms=[settings.algorithm],
options={"verify_aud": False},
)
except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path
raise TokenExpiredError("Token has expired") from exc
except JWTError as exc: # pragma: no cover - jose error bubble
raise TokenDecodeError("Unable to decode token") from exc
try:
payload = _model_validate(TokenPayload, decoded)
except ValidationError as exc:
raise TokenDecodeError("Token payload validation failed") from exc
if payload.type != expected_type:
raise TokenTypeMismatchError(
f"Expected a {expected_type} token but received '{payload.type}'."
)
return payload
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
if hasattr(model, "model_validate"):
return model.model_validate(data) # type: ignore[attr-defined]
return model.parse_obj(data) # type: ignore[attr-defined]
__all__ = [
"JWTSettings",
"TokenDecodeError",
"TokenError",
"TokenExpiredError",
"TokenKind",
"TokenPayload",
"TokenTypeMismatchError",
"create_access_token",
"create_refresh_token",
"decode_access_token",
"decode_refresh_token",
"hash_password",
"password_context",
"verify_password",
]

View File

@@ -6,12 +6,16 @@ from typing import Callable, Sequence
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models import Scenario
from models import Role, Scenario
from services.repositories import (
FinancialInputRepository,
ProjectRepository,
RoleRepository,
ScenarioRepository,
SimulationParameterRepository,
UserRepository,
ensure_admin_user as ensure_admin_user_record,
ensure_default_roles,
)
from services.scenario_validation import ScenarioComparisonValidator
@@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self._session_factory = session_factory
self.session: Session | None = None
self._scenario_validator: ScenarioComparisonValidator | None = None
self.projects: ProjectRepository | None = None
self.scenarios: ScenarioRepository | None = None
self.financial_inputs: FinancialInputRepository | None = None
self.simulation_parameters: SimulationParameterRepository | None = None
self.users: UserRepository | None = None
self.roles: RoleRepository | None = None
def __enter__(self) -> "UnitOfWork":
self.session = self._session_factory()
@@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(
self.session)
self.users = UserRepository(self.session)
self.roles = RoleRepository(self.session)
self._scenario_validator = ScenarioComparisonValidator()
return self
@@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.session.rollback()
self.session.close()
self._scenario_validator = None
self.projects = None
self.scenarios = None
self.financial_inputs = None
self.simulation_parameters = None
self.users = None
self.roles = None
def flush(self) -> None:
if not self.session:
@@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
def validate_scenarios_for_comparison(
self, scenario_ids: Sequence[int]
) -> list[Scenario]:
if not self.session or not self._scenario_validator:
if not self.session or not self._scenario_validator or not self.scenarios:
raise RuntimeError("UnitOfWork session is not initialised")
scenarios = [self.scenarios.get(scenario_id)
@@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
if not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised")
self._scenario_validator.validate(scenarios)
def ensure_default_roles(self) -> list[Role]:
if not self.roles:
raise RuntimeError("UnitOfWork session is not initialised")
return ensure_default_roles(self.roles)
def ensure_admin_user(
self,
*,
email: str,
username: str,
password: str,
) -> None:
if not self.users or not self.roles:
raise RuntimeError("UnitOfWork session is not initialised")
ensure_default_roles(self.roles)
ensure_admin_user_record(
self.users,
self.roles,
email=email,
username=username,
password=password,
)