feat: Implement user and role management with repositories
- Added RoleRepository and UserRepository for managing roles and users. - Implemented methods for creating, retrieving, and assigning roles to users. - Introduced functions to ensure default roles and an admin user exist in the system. - Updated UnitOfWork to include user and role repositories. - Created new security module for password hashing and JWT token management. - Added tests for authentication flows, including registration, login, and password reset. - Enhanced HTML templates for user registration, login, and password management with error handling. - Added a logo image to the static assets.
This commit is contained in:
@@ -8,7 +8,16 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter
|
||||
from models import (
|
||||
FinancialInput,
|
||||
Project,
|
||||
Role,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
SimulationParameter,
|
||||
User,
|
||||
UserRole,
|
||||
)
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
|
||||
|
||||
@@ -211,3 +220,220 @@ class SimulationParameterRepository:
|
||||
raise EntityNotFoundError(
|
||||
f"Simulation parameter {parameter_id} not found")
|
||||
self.session.delete(entity)
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Persistence operations for Role entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self) -> Sequence[Role]:
|
||||
stmt = select(Role).order_by(Role.name)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, role_id: int) -> Role:
|
||||
stmt = select(Role).where(Role.id == role_id)
|
||||
role = self.session.execute(stmt).scalar_one_or_none()
|
||||
if role is None:
|
||||
raise EntityNotFoundError(f"Role {role_id} not found")
|
||||
return role
|
||||
|
||||
def get_by_name(self, name: str) -> Role | None:
|
||||
stmt = select(Role).where(Role.name == name)
|
||||
return self.session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def create(self, role: Role) -> Role:
|
||||
self.session.add(role)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"Role violates uniqueness constraints") from exc
|
||||
return role
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Persistence operations for User entities and their role assignments."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self, *, with_roles: bool = False) -> Sequence[User]:
|
||||
stmt = select(User).order_by(User.created_at)
|
||||
if with_roles:
|
||||
stmt = stmt.options(selectinload(User.roles))
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def _apply_role_option(self, stmt, with_roles: bool):
|
||||
if with_roles:
|
||||
stmt = stmt.options(
|
||||
joinedload(User.role_assignments).joinedload(UserRole.role),
|
||||
selectinload(User.roles),
|
||||
)
|
||||
return stmt
|
||||
|
||||
def get(self, user_id: int, *, with_roles: bool = False) -> User:
|
||||
stmt = select(User).where(User.id == user_id).execution_options(
|
||||
populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise EntityNotFoundError(f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
|
||||
stmt = select(User).where(User.email == email).execution_options(
|
||||
populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
|
||||
stmt = select(User).where(User.username ==
|
||||
username).execution_options(populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def create(self, user: User) -> User:
|
||||
self.session.add(user)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"User violates uniqueness constraints") from exc
|
||||
return user
|
||||
|
||||
def assign_role(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
role_id: int,
|
||||
granted_by: int | None = None,
|
||||
) -> UserRole:
|
||||
stmt = select(UserRole).where(
|
||||
UserRole.user_id == user_id,
|
||||
UserRole.role_id == role_id,
|
||||
)
|
||||
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||
if assignment:
|
||||
return assignment
|
||||
|
||||
assignment = UserRole(
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
granted_by=granted_by,
|
||||
)
|
||||
self.session.add(assignment)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"Assignment violates constraints") from exc
|
||||
return assignment
|
||||
|
||||
def revoke_role(self, *, user_id: int, role_id: int) -> None:
|
||||
stmt = select(UserRole).where(
|
||||
UserRole.user_id == user_id,
|
||||
UserRole.role_id == role_id,
|
||||
)
|
||||
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||
if assignment is None:
|
||||
raise EntityNotFoundError(
|
||||
f"Role {role_id} not assigned to user {user_id}")
|
||||
self.session.delete(assignment)
|
||||
self.session.flush()
|
||||
|
||||
|
||||
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
|
||||
{
|
||||
"name": "admin",
|
||||
"display_name": "Administrator",
|
||||
"description": "Full platform access with user management rights.",
|
||||
},
|
||||
{
|
||||
"name": "project_manager",
|
||||
"display_name": "Project Manager",
|
||||
"description": "Manage projects, scenarios, and associated data.",
|
||||
},
|
||||
{
|
||||
"name": "analyst",
|
||||
"display_name": "Analyst",
|
||||
"description": "Review dashboards and scenario outputs.",
|
||||
},
|
||||
{
|
||||
"name": "viewer",
|
||||
"display_name": "Viewer",
|
||||
"description": "Read-only access to assigned projects and reports.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
|
||||
"""Ensure standard roles exist, creating missing ones.
|
||||
|
||||
Returns all current role records in creation order.
|
||||
"""
|
||||
|
||||
roles: list[Role] = []
|
||||
for definition in DEFAULT_ROLE_DEFINITIONS:
|
||||
existing = role_repo.get_by_name(definition["name"])
|
||||
if existing:
|
||||
roles.append(existing)
|
||||
continue
|
||||
role = Role(**definition)
|
||||
roles.append(role_repo.create(role))
|
||||
return roles
|
||||
|
||||
|
||||
def ensure_admin_user(
|
||||
user_repo: UserRepository,
|
||||
role_repo: RoleRepository,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""Ensure an administrator user exists and holds the admin role."""
|
||||
|
||||
user = user_repo.get_by_email(email, with_roles=True)
|
||||
if user is None:
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash=User.hash_password(password),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
user_repo.create(user)
|
||||
else:
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
if not user.is_superuser:
|
||||
user.is_superuser = True
|
||||
user_repo.session.flush()
|
||||
|
||||
admin_role = role_repo.get_by_name("admin")
|
||||
if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called
|
||||
admin_role = role_repo.create(
|
||||
Role(
|
||||
name="admin",
|
||||
display_name="Administrator",
|
||||
description="Full platform access with user management rights.",
|
||||
)
|
||||
)
|
||||
|
||||
user_repo.assign_role(
|
||||
user_id=user.id,
|
||||
role_id=admin_role.id,
|
||||
granted_by=user.id,
|
||||
)
|
||||
return user
|
||||
|
||||
213
services/security.py
Normal file
213
services/security.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Iterable, Literal, Type
|
||||
|
||||
from jose import ExpiredSignatureError, JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
try: # pragma: no cover - compatibility shim for passlib/argon2 warning
|
||||
import importlib.metadata as importlib_metadata
|
||||
import argon2 # type: ignore
|
||||
|
||||
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
|
||||
except Exception: # pragma: no cover - executed only when metadata lookup fails
|
||||
pass
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Derive a secure hash for a plain-text password."""
|
||||
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(candidate: str, hashed: str) -> bool:
|
||||
"""Verify that a candidate password matches a stored hash."""
|
||||
|
||||
try:
|
||||
return password_context.verify(candidate, hashed)
|
||||
except ValueError:
|
||||
# Raised when the stored hash is malformed or uses an unknown scheme.
|
||||
return False
|
||||
|
||||
|
||||
class TokenError(Exception):
|
||||
"""Base class for token encoding/decoding issues."""
|
||||
|
||||
|
||||
class TokenDecodeError(TokenError):
|
||||
"""Raised when a token cannot be decoded or validated."""
|
||||
|
||||
|
||||
class TokenExpiredError(TokenError):
|
||||
"""Raised when a token has expired."""
|
||||
|
||||
|
||||
class TokenTypeMismatchError(TokenError):
|
||||
"""Raised when a token type does not match the expected flavour."""
|
||||
|
||||
|
||||
TokenKind = Literal["access", "refresh"]
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""Shared fields for CalMiner JWT payloads."""
|
||||
|
||||
sub: str
|
||||
exp: int
|
||||
type: TokenKind
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def expires_at(self) -> datetime:
|
||||
return datetime.fromtimestamp(self.exp, tz=timezone.utc)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class JWTSettings:
|
||||
"""Runtime configuration for JWT encoding and validation."""
|
||||
|
||||
secret_key: str
|
||||
algorithm: str = "HS256"
|
||||
access_token_ttl: timedelta = field(
|
||||
default_factory=lambda: timedelta(minutes=15))
|
||||
refresh_token_ttl: timedelta = field(
|
||||
default_factory=lambda: timedelta(days=7))
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str,
|
||||
settings: JWTSettings,
|
||||
*,
|
||||
scopes: Iterable[str] | None = None,
|
||||
expires_delta: timedelta | None = None,
|
||||
extra_claims: Dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Issue a signed access token for the provided subject."""
|
||||
|
||||
lifetime = expires_delta or settings.access_token_ttl
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="access",
|
||||
settings=settings,
|
||||
lifetime=lifetime,
|
||||
scopes=scopes,
|
||||
extra_claims=extra_claims,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: str,
|
||||
settings: JWTSettings,
|
||||
*,
|
||||
scopes: Iterable[str] | None = None,
|
||||
expires_delta: timedelta | None = None,
|
||||
extra_claims: Dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Issue a signed refresh token for the provided subject."""
|
||||
|
||||
lifetime = expires_delta or settings.refresh_token_ttl
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="refresh",
|
||||
settings=settings,
|
||||
lifetime=lifetime,
|
||||
scopes=scopes,
|
||||
extra_claims=extra_claims,
|
||||
)
|
||||
|
||||
|
||||
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||
"""Validate and decode an access token."""
|
||||
|
||||
return _decode_token(token, settings, expected_type="access")
|
||||
|
||||
|
||||
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||
"""Validate and decode a refresh token."""
|
||||
|
||||
return _decode_token(token, settings, expected_type="refresh")
|
||||
|
||||
|
||||
def _create_token(
|
||||
*,
|
||||
subject: str,
|
||||
token_type: TokenKind,
|
||||
settings: JWTSettings,
|
||||
lifetime: timedelta,
|
||||
scopes: Iterable[str] | None,
|
||||
extra_claims: Dict[str, Any] | None,
|
||||
) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
expire = now + lifetime
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": subject,
|
||||
"type": token_type,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
if scopes:
|
||||
payload["scopes"] = list(scopes)
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
|
||||
def _decode_token(
|
||||
token: str,
|
||||
settings: JWTSettings,
|
||||
expected_type: TokenKind,
|
||||
) -> TokenPayload:
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
settings.secret_key,
|
||||
algorithms=[settings.algorithm],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path
|
||||
raise TokenExpiredError("Token has expired") from exc
|
||||
except JWTError as exc: # pragma: no cover - jose error bubble
|
||||
raise TokenDecodeError("Unable to decode token") from exc
|
||||
|
||||
try:
|
||||
payload = _model_validate(TokenPayload, decoded)
|
||||
except ValidationError as exc:
|
||||
raise TokenDecodeError("Token payload validation failed") from exc
|
||||
|
||||
if payload.type != expected_type:
|
||||
raise TokenTypeMismatchError(
|
||||
f"Expected a {expected_type} token but received '{payload.type}'."
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
|
||||
if hasattr(model, "model_validate"):
|
||||
return model.model_validate(data) # type: ignore[attr-defined]
|
||||
return model.parse_obj(data) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"JWTSettings",
|
||||
"TokenDecodeError",
|
||||
"TokenError",
|
||||
"TokenExpiredError",
|
||||
"TokenKind",
|
||||
"TokenPayload",
|
||||
"TokenTypeMismatchError",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_access_token",
|
||||
"decode_refresh_token",
|
||||
"hash_password",
|
||||
"password_context",
|
||||
"verify_password",
|
||||
]
|
||||
@@ -6,12 +6,16 @@ from typing import Callable, Sequence
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from models import Scenario
|
||||
from models import Role, Scenario
|
||||
from services.repositories import (
|
||||
FinancialInputRepository,
|
||||
ProjectRepository,
|
||||
RoleRepository,
|
||||
ScenarioRepository,
|
||||
SimulationParameterRepository,
|
||||
UserRepository,
|
||||
ensure_admin_user as ensure_admin_user_record,
|
||||
ensure_default_roles,
|
||||
)
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
|
||||
@@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
self._session_factory = session_factory
|
||||
self.session: Session | None = None
|
||||
self._scenario_validator: ScenarioComparisonValidator | None = None
|
||||
self.projects: ProjectRepository | None = None
|
||||
self.scenarios: ScenarioRepository | None = None
|
||||
self.financial_inputs: FinancialInputRepository | None = None
|
||||
self.simulation_parameters: SimulationParameterRepository | None = None
|
||||
self.users: UserRepository | None = None
|
||||
self.roles: RoleRepository | None = None
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
self.session = self._session_factory()
|
||||
@@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
self.financial_inputs = FinancialInputRepository(self.session)
|
||||
self.simulation_parameters = SimulationParameterRepository(
|
||||
self.session)
|
||||
self.users = UserRepository(self.session)
|
||||
self.roles = RoleRepository(self.session)
|
||||
self._scenario_validator = ScenarioComparisonValidator()
|
||||
return self
|
||||
|
||||
@@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
self._scenario_validator = None
|
||||
self.projects = None
|
||||
self.scenarios = None
|
||||
self.financial_inputs = None
|
||||
self.simulation_parameters = None
|
||||
self.users = None
|
||||
self.roles = None
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.session:
|
||||
@@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
def validate_scenarios_for_comparison(
|
||||
self, scenario_ids: Sequence[int]
|
||||
) -> list[Scenario]:
|
||||
if not self.session or not self._scenario_validator:
|
||||
if not self.session or not self._scenario_validator or not self.scenarios:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
|
||||
scenarios = [self.scenarios.get(scenario_id)
|
||||
@@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
if not self._scenario_validator:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self._scenario_validator.validate(scenarios)
|
||||
|
||||
def ensure_default_roles(self) -> list[Role]:
|
||||
if not self.roles:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
return ensure_default_roles(self.roles)
|
||||
|
||||
def ensure_admin_user(
|
||||
self,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> None:
|
||||
if not self.users or not self.roles:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
ensure_default_roles(self.roles)
|
||||
ensure_admin_user_record(
|
||||
self.users,
|
||||
self.roles,
|
||||
email=email,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user