feat: Implement user and role management with repositories

- Added RoleRepository and UserRepository for managing roles and users.
- Implemented methods for creating, retrieving, and assigning roles to users.
- Introduced functions to ensure default roles and an admin user exist in the system.
- Updated UnitOfWork to include user and role repositories.
- Created new security module for password hashing and JWT token management.
- Added tests for authentication flows, including registration, login, and password reset.
- Enhanced HTML templates for user registration, login, and password management with error handling.
- Added a logo image to the static assets.
This commit is contained in:
2025-11-09 21:48:35 +01:00
parent 53879a411f
commit 3601c2e422
22 changed files with 1955 additions and 132 deletions

View File

@@ -0,0 +1,210 @@
"""Add authentication and RBAC tables"""
from __future__ import annotations
from alembic import op
import sqlalchemy as sa
from passlib.context import CryptContext
from sqlalchemy.sql import column, table
# revision identifiers, used by Alembic.
revision = "20251109_02"
down_revision = "20251109_01"
branch_labels = None
depends_on = None
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
def upgrade() -> None:
op.create_table(
"users",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("email", sa.String(length=255), nullable=False),
sa.Column("username", sa.String(length=128), nullable=False),
sa.Column("password_hash", sa.String(length=255), nullable=False),
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
server_default=sa.true(),
),
sa.Column(
"is_superuser",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.UniqueConstraint("email", name="uq_users_email"),
sa.UniqueConstraint("username", name="uq_users_username"),
)
op.create_index(
"ix_users_active_superuser",
"users",
["is_active", "is_superuser"],
unique=False,
)
op.create_table(
"roles",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(length=64), nullable=False),
sa.Column("display_name", sa.String(length=128), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.UniqueConstraint("name", name="uq_roles_name"),
)
op.create_table(
"user_roles",
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("role_id", sa.Integer(), nullable=False),
sa.Column(
"granted_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column("granted_by", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["role_id"],
["roles.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["granted_by"],
["users.id"],
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("user_id", "role_id"),
sa.UniqueConstraint("user_id", "role_id",
name="uq_user_roles_user_role"),
)
op.create_index(
"ix_user_roles_role_id",
"user_roles",
["role_id"],
unique=False,
)
# Seed default roles
roles_table = table(
"roles",
column("id", sa.Integer()),
column("name", sa.String()),
column("display_name", sa.String()),
column("description", sa.Text()),
)
op.bulk_insert(
roles_table,
[
{
"id": 1,
"name": "admin",
"display_name": "Administrator",
"description": "Full platform access with user management rights.",
},
{
"id": 2,
"name": "project_manager",
"display_name": "Project Manager",
"description": "Manage projects, scenarios, and associated data.",
},
{
"id": 3,
"name": "analyst",
"display_name": "Analyst",
"description": "Review dashboards and scenario outputs.",
},
{
"id": 4,
"name": "viewer",
"display_name": "Viewer",
"description": "Read-only access to assigned projects and reports.",
},
],
)
admin_password_hash = password_context.hash("ChangeMe123!")
users_table = table(
"users",
column("id", sa.Integer()),
column("email", sa.String()),
column("username", sa.String()),
column("password_hash", sa.String()),
column("is_active", sa.Boolean()),
column("is_superuser", sa.Boolean()),
)
op.bulk_insert(
users_table,
[
{
"id": 1,
"email": "admin@calminer.local",
"username": "admin",
"password_hash": admin_password_hash,
"is_active": True,
"is_superuser": True,
}
],
)
user_roles_table = table(
"user_roles",
column("user_id", sa.Integer()),
column("role_id", sa.Integer()),
column("granted_by", sa.Integer()),
)
op.bulk_insert(
user_roles_table,
[
{
"user_id": 1,
"role_id": 1,
"granted_by": 1,
}
],
)
def downgrade() -> None:
op.drop_index("ix_user_roles_role_id", table_name="user_roles")
op.drop_table("user_roles")
op.drop_table("roles")
op.drop_index("ix_users_active_superuser", table_name="users")
op.drop_table("users")

BIN
alembic_test.db Normal file

Binary file not shown.

View File

@@ -16,3 +16,5 @@
- Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints. - Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints.
- Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views. - Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views.
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. - Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.

60
config/settings.py Normal file
View File

@@ -0,0 +1,60 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from datetime import timedelta
from functools import lru_cache
from services.security import JWTSettings
@dataclass(frozen=True, slots=True)
class Settings:
"""Application configuration sourced from environment variables."""
jwt_secret_key: str = "change-me"
jwt_algorithm: str = "HS256"
jwt_access_token_minutes: int = 15
jwt_refresh_token_days: int = 7
@classmethod
def from_environment(cls) -> "Settings":
"""Construct settings from environment variables."""
return cls(
jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"),
jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"),
jwt_access_token_minutes=cls._int_from_env(
"CALMINER_JWT_ACCESS_MINUTES", 15
),
jwt_refresh_token_days=cls._int_from_env(
"CALMINER_JWT_REFRESH_DAYS", 7
),
)
@staticmethod
def _int_from_env(name: str, default: int) -> int:
raw_value = os.getenv(name)
if raw_value is None:
return default
try:
return int(raw_value)
except ValueError:
return default
def jwt_settings(self) -> JWTSettings:
"""Build runtime JWT settings compatible with token helpers."""
return JWTSettings(
secret_key=self.jwt_secret_key,
algorithm=self.jwt_algorithm,
access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes),
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
)
@lru_cache(maxsize=1)
def get_settings() -> Settings:
"""Return cached application settings."""
return Settings.from_environment()

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
from config.settings import Settings, get_settings
from services.security import JWTSettings
from services.unit_of_work import UnitOfWork from services.unit_of_work import UnitOfWork
@@ -10,3 +12,15 @@ def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
with UnitOfWork() as uow: with UnitOfWork() as uow:
yield uow yield uow
def get_application_settings() -> Settings:
"""Provide cached application settings instance."""
return get_settings()
def get_jwt_settings() -> JWTSettings:
"""Provide JWT runtime configuration derived from settings."""
return get_settings().jwt_settings()

View File

@@ -10,6 +10,7 @@ from models import (
Scenario, Scenario,
SimulationParameter, SimulationParameter,
) )
from routes.auth import router as auth_router
from routes.dashboard import router as dashboard_router from routes.dashboard import router as dashboard_router
from routes.projects import router as projects_router from routes.projects import router as projects_router
from routes.scenarios import router as scenarios_router from routes.scenarios import router as scenarios_router
@@ -33,6 +34,7 @@ async def health() -> dict[str, str]:
app.include_router(dashboard_router) app.include_router(dashboard_router)
app.include_router(auth_router)
app.include_router(projects_router) app.include_router(projects_router)
app.include_router(scenarios_router) app.include_router(scenarios_router)

View File

@@ -4,6 +4,14 @@ from datetime import datetime
from typing import List, Optional from typing import List, Optional
from passlib.context import CryptContext from passlib.context import CryptContext
try: # pragma: no cover - defensive compatibility shim
import importlib.metadata as importlib_metadata
import argon2 # type: ignore
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
except Exception:
pass
from sqlalchemy import ( from sqlalchemy import (
Boolean, Boolean,
DateTime, DateTime,

473
routes/auth.py Normal file
View File

@@ -0,0 +1,473 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
from starlette.datastructures import FormData
from dependencies import get_jwt_settings, get_unit_of_work
from models import Role, User
from schemas.auth import (
LoginForm,
PasswordResetForm,
PasswordResetRequestForm,
RegistrationForm,
)
from services.exceptions import EntityConflictError
from services.security import (
JWTSettings,
TokenDecodeError,
TokenExpiredError,
TokenTypeMismatchError,
create_access_token,
create_refresh_token,
decode_access_token,
hash_password,
verify_password,
)
from services.repositories import RoleRepository, UserRepository
from services.unit_of_work import UnitOfWork
router = APIRouter(tags=["Authentication"])
templates = Jinja2Templates(directory="templates")
_PASSWORD_RESET_SCOPE = "password-reset"
_AUTH_SCOPE = "auth"
def _template(
request: Request,
template_name: str,
context: dict[str, Any],
*,
status_code: int = status.HTTP_200_OK,
) -> HTMLResponse:
return templates.TemplateResponse(
request,
template_name,
context,
status_code=status_code,
)
def _validation_errors(exc: ValidationError) -> list[str]:
return [error.get("msg", "Invalid input.") for error in exc.errors()]
def _scopes(include: Iterable[str]) -> list[str]:
return list(include)
def _normalise_form_data(form_data: FormData) -> dict[str, str]:
normalised: dict[str, str] = {}
for key, value in form_data.multi_items():
if isinstance(value, UploadFile):
str_value = value.filename or ""
else:
str_value = str(value)
normalised[key] = str_value
return normalised
def _require_users_repo(uow: UnitOfWork) -> UserRepository:
if not uow.users:
raise RuntimeError("User repository is not initialised")
return uow.users
def _require_roles_repo(uow: UnitOfWork) -> RoleRepository:
if not uow.roles:
raise RuntimeError("Role repository is not initialised")
return uow.roles
@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form")
def login_form(request: Request) -> HTMLResponse:
return _template(
request,
"login.html",
{
"form_action": request.url_for("auth.login_submit"),
"errors": [],
"username": "",
},
)
@router.post("/login", include_in_schema=False, name="auth.login_submit")
async def login_submit(
request: Request,
uow: UnitOfWork = Depends(get_unit_of_work),
jwt_settings: JWTSettings = Depends(get_jwt_settings),
):
form_data = _normalise_form_data(await request.form())
try:
form = LoginForm(**form_data)
except ValidationError as exc:
return _template(
request,
"login.html",
{
"form_action": request.url_for("auth.login_submit"),
"errors": _validation_errors(exc),
},
status_code=status.HTTP_400_BAD_REQUEST,
)
identifier = form.username
users_repo = _require_users_repo(uow)
user = _lookup_user(users_repo, identifier)
errors: list[str] = []
if not user or not verify_password(form.password, user.password_hash):
errors.append("Invalid username or password.")
elif not user.is_active:
errors.append("Account is inactive. Contact an administrator.")
if errors:
return _template(
request,
"login.html",
{
"form_action": request.url_for("auth.login_submit"),
"errors": errors,
"username": identifier,
},
status_code=status.HTTP_400_BAD_REQUEST,
)
assert user is not None # mypy hint - guarded above
user.last_login_at = datetime.now(timezone.utc)
access_token = create_access_token(
str(user.id),
jwt_settings,
scopes=_scopes((_AUTH_SCOPE,)),
)
refresh_token = create_refresh_token(
str(user.id),
jwt_settings,
scopes=_scopes((_AUTH_SCOPE,)),
)
response = RedirectResponse(
request.url_for("dashboard.home"),
status_code=status.HTTP_303_SEE_OTHER,
)
_set_auth_cookies(response, access_token, refresh_token, jwt_settings)
return response
def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
if "@" in identifier:
return users_repo.get_by_email(identifier.lower(), with_roles=True)
return users_repo.get_by_username(identifier, with_roles=True)
def _set_auth_cookies(
response: RedirectResponse,
access_token: str,
refresh_token: str,
jwt_settings: JWTSettings,
) -> None:
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
response.set_cookie(
"calminer_access_token",
access_token,
httponly=True,
secure=False,
samesite="lax",
max_age=max(access_ttl, 0) or None,
)
response.set_cookie(
"calminer_refresh_token",
refresh_token,
httponly=True,
secure=False,
samesite="lax",
max_age=max(refresh_ttl, 0) or None,
)
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
def register_form(request: Request) -> HTMLResponse:
return _template(
request,
"register.html",
{
"form_action": request.url_for("auth.register_submit"),
"errors": [],
"form_data": None,
},
)
@router.post("/register", include_in_schema=False, name="auth.register_submit")
async def register_submit(
request: Request,
uow: UnitOfWork = Depends(get_unit_of_work),
):
form_data = _normalise_form_data(await request.form())
try:
form = RegistrationForm(**form_data)
except ValidationError as exc:
return _registration_error_response(request, _validation_errors(exc))
errors: list[str] = []
users_repo = _require_users_repo(uow)
roles_repo = _require_roles_repo(uow)
uow.ensure_default_roles()
if users_repo.get_by_email(form.email):
errors.append("Email is already registered.")
if users_repo.get_by_username(form.username):
errors.append("Username is already taken.")
if errors:
return _registration_error_response(request, errors, form)
user = User(
email=form.email,
username=form.username,
password_hash=hash_password(form.password),
is_active=True,
is_superuser=False,
)
try:
created = users_repo.create(user)
except EntityConflictError:
return _registration_error_response(
request,
["An account with this username or email already exists."],
form,
)
viewer_role = _ensure_viewer_role(roles_repo)
if viewer_role is not None:
users_repo.assign_role(
user_id=created.id,
role_id=viewer_role.id,
granted_by=created.id,
)
redirect_url = request.url_for(
"auth.login_form").include_query_params(registered="1")
return RedirectResponse(
redirect_url,
status_code=status.HTTP_303_SEE_OTHER,
)
def _registration_error_response(
request: Request,
errors: list[str],
form: RegistrationForm | None = None,
) -> HTMLResponse:
context = {
"form_action": request.url_for("auth.register_submit"),
"errors": errors,
"form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None,
}
return _template(
request,
"register.html",
context,
status_code=status.HTTP_400_BAD_REQUEST,
)
def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None:
viewer = roles_repo.get_by_name("viewer")
if viewer:
return viewer
return roles_repo.get_by_name("viewer")
@router.get(
"/forgot-password",
response_class=HTMLResponse,
include_in_schema=False,
name="auth.password_reset_request_form",
)
def password_reset_request_form(request: Request) -> HTMLResponse:
return _template(
request,
"forgot_password.html",
{
"form_action": request.url_for("auth.password_reset_request_submit"),
"errors": [],
"message": None,
},
)
@router.post(
"/forgot-password",
include_in_schema=False,
name="auth.password_reset_request_submit",
)
async def password_reset_request_submit(
request: Request,
uow: UnitOfWork = Depends(get_unit_of_work),
jwt_settings: JWTSettings = Depends(get_jwt_settings),
):
form_data = _normalise_form_data(await request.form())
try:
form = PasswordResetRequestForm(**form_data)
except ValidationError as exc:
return _template(
request,
"forgot_password.html",
{
"form_action": request.url_for("auth.password_reset_request_submit"),
"errors": _validation_errors(exc),
"message": None,
},
status_code=status.HTTP_400_BAD_REQUEST,
)
users_repo = _require_users_repo(uow)
user = users_repo.get_by_email(form.email)
if not user:
return _template(
request,
"forgot_password.html",
{
"form_action": request.url_for("auth.password_reset_request_submit"),
"errors": [],
"message": "If an account exists, a reset link has been sent.",
},
)
token = create_access_token(
str(user.id),
jwt_settings,
scopes=_scopes((_PASSWORD_RESET_SCOPE,)),
expires_delta=timedelta(hours=1),
)
reset_url = request.url_for(
"auth.password_reset_form").include_query_params(token=token)
return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER)
@router.get(
"/reset-password",
response_class=HTMLResponse,
include_in_schema=False,
name="auth.password_reset_form",
)
def password_reset_form(
request: Request,
token: str | None = None,
jwt_settings: JWTSettings = Depends(get_jwt_settings),
) -> HTMLResponse:
errors: list[str] = []
if not token:
errors.append("Missing password reset token.")
else:
try:
payload = decode_access_token(token, jwt_settings)
if _PASSWORD_RESET_SCOPE not in payload.scopes:
errors.append("Invalid token scope.")
except TokenExpiredError:
errors.append(
"Token has expired. Please request a new password reset.")
except (TokenDecodeError, TokenTypeMismatchError):
errors.append("Invalid password reset token.")
return _template(
request,
"reset_password.html",
{
"form_action": request.url_for("auth.password_reset_submit"),
"token": token,
"errors": errors,
},
status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK,
)
@router.post(
"/reset-password",
include_in_schema=False,
name="auth.password_reset_submit",
)
async def password_reset_submit(
request: Request,
uow: UnitOfWork = Depends(get_unit_of_work),
jwt_settings: JWTSettings = Depends(get_jwt_settings),
):
form_data = _normalise_form_data(await request.form())
try:
form = PasswordResetForm(**form_data)
except ValidationError as exc:
return _template(
request,
"reset_password.html",
{
"form_action": request.url_for("auth.password_reset_submit"),
"token": form_data.get("token"),
"errors": _validation_errors(exc),
},
status_code=status.HTTP_400_BAD_REQUEST,
)
try:
payload = decode_access_token(form.token, jwt_settings)
except TokenExpiredError:
return _reset_error_response(
request,
form.token,
"Token has expired. Please request a new password reset.",
)
except (TokenDecodeError, TokenTypeMismatchError):
return _reset_error_response(
request,
form.token,
"Invalid password reset token.",
)
if _PASSWORD_RESET_SCOPE not in payload.scopes:
return _reset_error_response(
request,
form.token,
"Invalid password reset token scope.",
)
users_repo = _require_users_repo(uow)
user_id = int(payload.sub)
user = users_repo.get(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user.set_password(form.password)
if not user.is_active:
user.is_active = True
redirect_url = request.url_for(
"auth.login_form").include_query_params(reset="1")
return RedirectResponse(
redirect_url,
status_code=status.HTTP_303_SEE_OTHER,
)
def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse:
return _template(
request,
"reset_password.html",
{
"form_action": request.url_for("auth.password_reset_submit"),
"token": token,
"errors": [message],
},
status_code=status.HTTP_400_BAD_REQUEST,
)

67
schemas/auth.py Normal file
View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
class FormModel(BaseModel):
"""Base Pydantic model for HTML form submissions."""
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
class RegistrationForm(FormModel):
username: str = Field(min_length=3, max_length=128)
email: str = Field(min_length=5, max_length=255)
password: str = Field(min_length=8, max_length=256)
confirm_password: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
if "@" not in value or value.startswith("@") or value.endswith("@"):
raise ValueError("Invalid email address.")
local, domain = value.split("@", 1)
if not local or "." not in domain:
raise ValueError("Invalid email address.")
return value.lower()
@field_validator("confirm_password")
@classmethod
def passwords_match(cls, value: str, info: ValidationInfo) -> str:
password = info.data.get("password")
if password != value:
raise ValueError("Passwords do not match.")
return value
class LoginForm(FormModel):
username: str = Field(min_length=1, max_length=255)
password: str = Field(min_length=1, max_length=256)
class PasswordResetRequestForm(FormModel):
email: str = Field(min_length=5, max_length=255)
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
if "@" not in value or value.startswith("@") or value.endswith("@"):
raise ValueError("Invalid email address.")
local, domain = value.split("@", 1)
if not local or "." not in domain:
raise ValueError("Invalid email address.")
return value.lower()
class PasswordResetForm(FormModel):
token: str = Field(min_length=1)
password: str = Field(min_length=8, max_length=256)
confirm_password: str
@field_validator("confirm_password")
@classmethod
def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str:
password = info.data.get("password")
if password != value:
raise ValueError("Passwords do not match.")
return value

View File

@@ -8,7 +8,16 @@ from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload from sqlalchemy.orm import Session, joinedload, selectinload
from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter from models import (
FinancialInput,
Project,
Role,
Scenario,
ScenarioStatus,
SimulationParameter,
User,
UserRole,
)
from services.exceptions import EntityConflictError, EntityNotFoundError from services.exceptions import EntityConflictError, EntityNotFoundError
@@ -211,3 +220,220 @@ class SimulationParameterRepository:
raise EntityNotFoundError( raise EntityNotFoundError(
f"Simulation parameter {parameter_id} not found") f"Simulation parameter {parameter_id} not found")
self.session.delete(entity) self.session.delete(entity)
class RoleRepository:
"""Persistence operations for Role entities."""
def __init__(self, session: Session) -> None:
self.session = session
def list(self) -> Sequence[Role]:
stmt = select(Role).order_by(Role.name)
return self.session.execute(stmt).scalars().all()
def get(self, role_id: int) -> Role:
stmt = select(Role).where(Role.id == role_id)
role = self.session.execute(stmt).scalar_one_or_none()
if role is None:
raise EntityNotFoundError(f"Role {role_id} not found")
return role
def get_by_name(self, name: str) -> Role | None:
stmt = select(Role).where(Role.name == name)
return self.session.execute(stmt).scalar_one_or_none()
def create(self, role: Role) -> Role:
self.session.add(role)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"Role violates uniqueness constraints") from exc
return role
class UserRepository:
"""Persistence operations for User entities and their role assignments."""
def __init__(self, session: Session) -> None:
self.session = session
def list(self, *, with_roles: bool = False) -> Sequence[User]:
stmt = select(User).order_by(User.created_at)
if with_roles:
stmt = stmt.options(selectinload(User.roles))
return self.session.execute(stmt).scalars().all()
def _apply_role_option(self, stmt, with_roles: bool):
if with_roles:
stmt = stmt.options(
joinedload(User.role_assignments).joinedload(UserRole.role),
selectinload(User.roles),
)
return stmt
def get(self, user_id: int, *, with_roles: bool = False) -> User:
stmt = select(User).where(User.id == user_id).execution_options(
populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
user = result.scalar_one_or_none()
if user is None:
raise EntityNotFoundError(f"User {user_id} not found")
return user
def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
stmt = select(User).where(User.email == email).execution_options(
populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
return result.scalar_one_or_none()
def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
stmt = select(User).where(User.username ==
username).execution_options(populate_existing=True)
stmt = self._apply_role_option(stmt, with_roles)
result = self.session.execute(stmt)
if with_roles:
result = result.unique()
return result.scalar_one_or_none()
def create(self, user: User) -> User:
self.session.add(user)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"User violates uniqueness constraints") from exc
return user
def assign_role(
self,
*,
user_id: int,
role_id: int,
granted_by: int | None = None,
) -> UserRole:
stmt = select(UserRole).where(
UserRole.user_id == user_id,
UserRole.role_id == role_id,
)
assignment = self.session.execute(stmt).scalar_one_or_none()
if assignment:
return assignment
assignment = UserRole(
user_id=user_id,
role_id=role_id,
granted_by=granted_by,
)
self.session.add(assignment)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
raise EntityConflictError(
"Assignment violates constraints") from exc
return assignment
def revoke_role(self, *, user_id: int, role_id: int) -> None:
stmt = select(UserRole).where(
UserRole.user_id == user_id,
UserRole.role_id == role_id,
)
assignment = self.session.execute(stmt).scalar_one_or_none()
if assignment is None:
raise EntityNotFoundError(
f"Role {role_id} not assigned to user {user_id}")
self.session.delete(assignment)
self.session.flush()
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
{
"name": "admin",
"display_name": "Administrator",
"description": "Full platform access with user management rights.",
},
{
"name": "project_manager",
"display_name": "Project Manager",
"description": "Manage projects, scenarios, and associated data.",
},
{
"name": "analyst",
"display_name": "Analyst",
"description": "Review dashboards and scenario outputs.",
},
{
"name": "viewer",
"display_name": "Viewer",
"description": "Read-only access to assigned projects and reports.",
},
)
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
"""Ensure standard roles exist, creating missing ones.
Returns all current role records in creation order.
"""
roles: list[Role] = []
for definition in DEFAULT_ROLE_DEFINITIONS:
existing = role_repo.get_by_name(definition["name"])
if existing:
roles.append(existing)
continue
role = Role(**definition)
roles.append(role_repo.create(role))
return roles
def ensure_admin_user(
user_repo: UserRepository,
role_repo: RoleRepository,
*,
email: str,
username: str,
password: str,
) -> User:
"""Ensure an administrator user exists and holds the admin role."""
user = user_repo.get_by_email(email, with_roles=True)
if user is None:
user = User(
email=email,
username=username,
password_hash=User.hash_password(password),
is_active=True,
is_superuser=True,
)
user_repo.create(user)
else:
if not user.is_active:
user.is_active = True
if not user.is_superuser:
user.is_superuser = True
user_repo.session.flush()
admin_role = role_repo.get_by_name("admin")
if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called
admin_role = role_repo.create(
Role(
name="admin",
display_name="Administrator",
description="Full platform access with user management rights.",
)
)
user_repo.assign_role(
user_id=user.id,
role_id=admin_role.id,
granted_by=user.id,
)
return user

213
services/security.py Normal file
View File

@@ -0,0 +1,213 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, Literal, Type
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
try: # pragma: no cover - compatibility shim for passlib/argon2 warning
import importlib.metadata as importlib_metadata
import argon2 # type: ignore
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
except Exception: # pragma: no cover - executed only when metadata lookup fails
pass
from pydantic import BaseModel, Field, ValidationError
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
def hash_password(password: str) -> str:
"""Derive a secure hash for a plain-text password."""
return password_context.hash(password)
def verify_password(candidate: str, hashed: str) -> bool:
"""Verify that a candidate password matches a stored hash."""
try:
return password_context.verify(candidate, hashed)
except ValueError:
# Raised when the stored hash is malformed or uses an unknown scheme.
return False
class TokenError(Exception):
"""Base class for token encoding/decoding issues."""
class TokenDecodeError(TokenError):
"""Raised when a token cannot be decoded or validated."""
class TokenExpiredError(TokenError):
"""Raised when a token has expired."""
class TokenTypeMismatchError(TokenError):
"""Raised when a token type does not match the expected flavour."""
TokenKind = Literal["access", "refresh"]
class TokenPayload(BaseModel):
"""Shared fields for CalMiner JWT payloads."""
sub: str
exp: int
type: TokenKind
scopes: list[str] = Field(default_factory=list)
@property
def expires_at(self) -> datetime:
return datetime.fromtimestamp(self.exp, tz=timezone.utc)
@dataclass(slots=True)
class JWTSettings:
"""Runtime configuration for JWT encoding and validation."""
secret_key: str
algorithm: str = "HS256"
access_token_ttl: timedelta = field(
default_factory=lambda: timedelta(minutes=15))
refresh_token_ttl: timedelta = field(
default_factory=lambda: timedelta(days=7))
def create_access_token(
subject: str,
settings: JWTSettings,
*,
scopes: Iterable[str] | None = None,
expires_delta: timedelta | None = None,
extra_claims: Dict[str, Any] | None = None,
) -> str:
"""Issue a signed access token for the provided subject."""
lifetime = expires_delta or settings.access_token_ttl
return _create_token(
subject=subject,
token_type="access",
settings=settings,
lifetime=lifetime,
scopes=scopes,
extra_claims=extra_claims,
)
def create_refresh_token(
subject: str,
settings: JWTSettings,
*,
scopes: Iterable[str] | None = None,
expires_delta: timedelta | None = None,
extra_claims: Dict[str, Any] | None = None,
) -> str:
"""Issue a signed refresh token for the provided subject."""
lifetime = expires_delta or settings.refresh_token_ttl
return _create_token(
subject=subject,
token_type="refresh",
settings=settings,
lifetime=lifetime,
scopes=scopes,
extra_claims=extra_claims,
)
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
"""Validate and decode an access token."""
return _decode_token(token, settings, expected_type="access")
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
"""Validate and decode a refresh token."""
return _decode_token(token, settings, expected_type="refresh")
def _create_token(
*,
subject: str,
token_type: TokenKind,
settings: JWTSettings,
lifetime: timedelta,
scopes: Iterable[str] | None,
extra_claims: Dict[str, Any] | None,
) -> str:
now = datetime.now(timezone.utc)
expire = now + lifetime
payload: Dict[str, Any] = {
"sub": subject,
"type": token_type,
"iat": int(now.timestamp()),
"exp": int(expire.timestamp()),
}
if scopes:
payload["scopes"] = list(scopes)
if extra_claims:
payload.update(extra_claims)
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
def _decode_token(
token: str,
settings: JWTSettings,
expected_type: TokenKind,
) -> TokenPayload:
try:
decoded = jwt.decode(
token,
settings.secret_key,
algorithms=[settings.algorithm],
options={"verify_aud": False},
)
except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path
raise TokenExpiredError("Token has expired") from exc
except JWTError as exc: # pragma: no cover - jose error bubble
raise TokenDecodeError("Unable to decode token") from exc
try:
payload = _model_validate(TokenPayload, decoded)
except ValidationError as exc:
raise TokenDecodeError("Token payload validation failed") from exc
if payload.type != expected_type:
raise TokenTypeMismatchError(
f"Expected a {expected_type} token but received '{payload.type}'."
)
return payload
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
if hasattr(model, "model_validate"):
return model.model_validate(data) # type: ignore[attr-defined]
return model.parse_obj(data) # type: ignore[attr-defined]
__all__ = [
"JWTSettings",
"TokenDecodeError",
"TokenError",
"TokenExpiredError",
"TokenKind",
"TokenPayload",
"TokenTypeMismatchError",
"create_access_token",
"create_refresh_token",
"decode_access_token",
"decode_refresh_token",
"hash_password",
"password_context",
"verify_password",
]

View File

@@ -6,12 +6,16 @@ from typing import Callable, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
from models import Scenario from models import Role, Scenario
from services.repositories import ( from services.repositories import (
FinancialInputRepository, FinancialInputRepository,
ProjectRepository, ProjectRepository,
RoleRepository,
ScenarioRepository, ScenarioRepository,
SimulationParameterRepository, SimulationParameterRepository,
UserRepository,
ensure_admin_user as ensure_admin_user_record,
ensure_default_roles,
) )
from services.scenario_validation import ScenarioComparisonValidator from services.scenario_validation import ScenarioComparisonValidator
@@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self._session_factory = session_factory self._session_factory = session_factory
self.session: Session | None = None self.session: Session | None = None
self._scenario_validator: ScenarioComparisonValidator | None = None self._scenario_validator: ScenarioComparisonValidator | None = None
self.projects: ProjectRepository | None = None
self.scenarios: ScenarioRepository | None = None
self.financial_inputs: FinancialInputRepository | None = None
self.simulation_parameters: SimulationParameterRepository | None = None
self.users: UserRepository | None = None
self.roles: RoleRepository | None = None
def __enter__(self) -> "UnitOfWork": def __enter__(self) -> "UnitOfWork":
self.session = self._session_factory() self.session = self._session_factory()
@@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.financial_inputs = FinancialInputRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository( self.simulation_parameters = SimulationParameterRepository(
self.session) self.session)
self.users = UserRepository(self.session)
self.roles = RoleRepository(self.session)
self._scenario_validator = ScenarioComparisonValidator() self._scenario_validator = ScenarioComparisonValidator()
return self return self
@@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.session.rollback() self.session.rollback()
self.session.close() self.session.close()
self._scenario_validator = None self._scenario_validator = None
self.projects = None
self.scenarios = None
self.financial_inputs = None
self.simulation_parameters = None
self.users = None
self.roles = None
def flush(self) -> None: def flush(self) -> None:
if not self.session: if not self.session:
@@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
def validate_scenarios_for_comparison( def validate_scenarios_for_comparison(
self, scenario_ids: Sequence[int] self, scenario_ids: Sequence[int]
) -> list[Scenario]: ) -> list[Scenario]:
if not self.session or not self._scenario_validator: if not self.session or not self._scenario_validator or not self.scenarios:
raise RuntimeError("UnitOfWork session is not initialised") raise RuntimeError("UnitOfWork session is not initialised")
scenarios = [self.scenarios.get(scenario_id) scenarios = [self.scenarios.get(scenario_id)
@@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
if not self._scenario_validator: if not self._scenario_validator:
raise RuntimeError("UnitOfWork session is not initialised") raise RuntimeError("UnitOfWork session is not initialised")
self._scenario_validator.validate(scenarios) self._scenario_validator.validate(scenarios)
def ensure_default_roles(self) -> list[Role]:
if not self.roles:
raise RuntimeError("UnitOfWork session is not initialised")
return ensure_default_roles(self.roles)
def ensure_admin_user(
self,
*,
email: str,
username: str,
password: str,
) -> None:
if not self.users or not self.roles:
raise RuntimeError("UnitOfWork session is not initialised")
ensure_default_roles(self.roles)
ensure_admin_user_record(
self.users,
self.roles,
email=email,
username=username,
password=password,
)

BIN
static/img/logo_big.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

View File

@@ -1,14 +1,22 @@
{% extends "base.html" %} {% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {%
block content %}
{% block title %}Forgot Password{% endblock %}
{% block content %}
<div class="container"> <div class="container">
<h1>Forgot Password</h1> <h1>Forgot Password</h1>
<form id="forgot-password-form"> {% if errors %}
<div class="alert alert-error">
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
</div>
{% endif %} {% if message %}
<div class="alert alert-info">{{ message }}</div>
{% endif %}
<form id="forgot-password-form" method="post" action="{{ form_action }}">
<div class="form-group"> <div class="form-group">
<label for="email">Email:</label> <label for="email">Email:</label>
<input type="email" id="email" name="email" required> <input type="email" id="email" name="email" required />
</div> </div>
<button type="submit">Reset Password</button> <button type="submit">Reset Password</button>
</form> </form>

View File

@@ -1,18 +1,30 @@
{% extends "base.html" %} {% extends "base.html" %} {% block title %}Login{% endblock %} {% block content
%}
{% block title %}Login{% endblock %}
{% block content %}
<div class="container"> <div class="container">
<h1>Login</h1> <h1>Login</h1>
<form id="login-form"> {% if errors %}
<div class="alert alert-error">
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
</div>
{% endif %}
<form id="login-form" method="post" action="{{ form_action }}">
<div class="form-group"> <div class="form-group">
<label for="username">Username:</label> <label for="username">Username:</label>
<input type="text" id="username" name="username" required> <input
type="text"
id="username"
name="username"
value="{{ username | default('') }}"
required
/>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="password">Password:</label> <label for="password">Password:</label>
<input type="password" id="password" name="password" required> <input type="password" id="password" name="password" required />
</div> </div>
<button type="submit">Login</button> <button type="submit">Login</button>
</form> </form>

View File

@@ -1,52 +1,35 @@
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %} {% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %} {% set projects_href = request.url_for('projects.project_list_page') if request
{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %} else '/projects/ui' %} {% set project_create_href =
request.url_for('projects.create_project_form') if request else
{% set nav_groups = [ '/projects/create' %} {% set login_href = request.url_for('auth.login_form') if
{ request else '/login' %} {% set register_href =
"label": "Workspace", request.url_for('auth.register_form') if request else '/register' %} {% set
"links": [ forgot_href = request.url_for('auth.password_reset_request_form') if request
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, {"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}, {"href": project_create_href, "label": "New Project", "match_prefix":
], "/projects/create"}, ], }, { "label": "Insights", "links": [ {"href":
}, "/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label":
{ "Reporting"}, ], }, { "label": "Configuration", "links": [ { "href":
"label": "Insights", "/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings",
"links": [ "label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"},
{"href": "/ui/simulations", "label": "Simulations"}, ], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label":
{"href": "/ui/reporting", "label": "Reporting"}, "Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register",
], "match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password",
}, "match_prefix": "/forgot-password"}, ], }, ] %}
{
"label": "Configuration",
"links": [
{
"href": "/ui/settings",
"label": "Settings",
"children": [
{"href": "/theme-settings", "label": "Themes"},
{"href": "/ui/currencies", "label": "Currency Management"},
],
},
],
},
] %}
<nav class="sidebar-nav" aria-label="Primary navigation"> <nav class="sidebar-nav" aria-label="Primary navigation">
{% set current_path = request.url.path if request else "" %} {% set current_path = request.url.path if request else "" %} {% for group in
{% for group in nav_groups %} nav_groups %}
<div class="sidebar-section"> <div class="sidebar-section">
<div class="sidebar-section-label">{{ group.label }}</div> <div class="sidebar-section-label">{{ group.label }}</div>
<div class="sidebar-section-links"> <div class="sidebar-section-links">
{% for link in group.links %} {% for link in group.links %} {% set href = link.href %} {% set
{% set href = link.href %} match_prefix = link.get('match_prefix', href) %} {% if match_prefix == '/'
{% set match_prefix = link.get('match_prefix', href) %} %} {% set is_active = current_path == '/' %} {% else %} {% set is_active =
{% if match_prefix == '/' %} current_path.startswith(match_prefix) %} {% endif %}
{% set is_active = current_path == '/' %}
{% else %}
{% set is_active = current_path.startswith(match_prefix) %}
{% endif %}
<div class="sidebar-link-block"> <div class="sidebar-link-block">
<a <a
href="{{ href }}" href="{{ href }}"
@@ -56,13 +39,10 @@
</a> </a>
{% if link.children %} {% if link.children %}
<div class="sidebar-sublinks"> <div class="sidebar-sublinks">
{% for child in link.children %} {% for child in link.children %} {% set child_prefix =
{% set child_prefix = child.get('match_prefix', child.href) %} child.get('match_prefix', child.href) %} {% if child_prefix == '/' %}
{% if child_prefix == '/' %} {% set child_active = current_path == '/' %} {% else %} {% set
{% set child_active = current_path == '/' %} child_active = current_path.startswith(child_prefix) %} {% endif %}
{% else %}
{% set child_active = current_path.startswith(child_prefix) %}
{% endif %}
<a <a
href="{{ child.href }}" href="{{ child.href }}"
class="sidebar-sublink{% if child_active %} is-active{% endif %}" class="sidebar-sublink{% if child_active %} is-active{% endif %}"

View File

@@ -1,22 +1,40 @@
{% extends "base.html" %} {% extends "base.html" %} {% block title %}Register{% endblock %} {% block
content %}
{% block title %}Register{% endblock %}
{% block content %}
<div class="container"> <div class="container">
<h1>Register</h1> <h1>Register</h1>
<form id="register-form"> {% if errors %}
<div class="alert alert-error">
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
</div>
{% endif %}
<form id="register-form" method="post" action="{{ form_action }}">
<div class="form-group"> <div class="form-group">
<label for="username">Username:</label> <label for="username">Username:</label>
<input type="text" id="username" name="username" required> <input
type="text"
id="username"
name="username"
value="{{ form_data.username if form_data else '' }}"
required
/>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="email">Email:</label> <label for="email">Email:</label>
<input type="email" id="email" name="email" required> <input
type="email"
id="email"
name="email"
value="{{ form_data.email if form_data else '' }}"
required
/>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="password">Password:</label> <label for="password">Password:</label>
<input type="password" id="password" name="password" required> <input type="password" id="password" name="password" required />
</div> </div>
<button type="submit">Register</button> <button type="submit">Register</button>
</form> </form>

View File

@@ -0,0 +1,36 @@
{% extends "base.html" %} {% block title %}Reset Password{% endblock %} {% block
content %}
<div class="container">
<h1>Reset Password</h1>
{% if errors %}
<div class="alert alert-error">
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
</div>
{% endif %}
<form id="reset-password-form" method="post" action="{{ form_action }}">
<input type="hidden" name="token" value="{{ token | default('') }}" />
<div class="form-group">
<label for="password">New Password:</label>
<input type="password" id="password" name="password" required />
</div>
<div class="form-group">
<label for="confirm_password">Confirm Password:</label>
<input
type="password"
id="confirm_password"
name="confirm_password"
required
/>
</div>
<button type="submit">Update Password</button>
</form>
<p>
Remembered your password?
<a href="{{ request.url_for('auth.login_form') }}">Return to login</a>
</p>
</div>
{% endblock %}

View File

@@ -12,6 +12,7 @@ from sqlalchemy.pool import StaticPool
from config.database import Base from config.database import Base
from dependencies import get_unit_of_work from dependencies import get_unit_of_work
from routes.auth import router as auth_router
from routes.dashboard import router as dashboard_router from routes.dashboard import router as dashboard_router
from routes.projects import router as projects_router from routes.projects import router as projects_router
from routes.scenarios import router as scenarios_router from routes.scenarios import router as scenarios_router
@@ -36,13 +37,15 @@ def engine() -> Iterator[Engine]:
@pytest.fixture() @pytest.fixture()
def session_factory(engine: Engine) -> Iterator[sessionmaker]: def session_factory(engine: Engine) -> Iterator[sessionmaker]:
testing_session = sessionmaker(bind=engine, expire_on_commit=False, future=True) testing_session = sessionmaker(
bind=engine, expire_on_commit=False, future=True)
yield testing_session yield testing_session
@pytest.fixture() @pytest.fixture()
def app(session_factory: sessionmaker) -> FastAPI: def app(session_factory: sessionmaker) -> FastAPI:
application = FastAPI() application = FastAPI()
application.include_router(auth_router)
application.include_router(dashboard_router) application.include_router(dashboard_router)
application.include_router(projects_router) application.include_router(projects_router)
application.include_router(scenarios_router) application.include_router(scenarios_router)

View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from collections.abc import Iterator
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from config.database import Base
from models import Role, User
from services.repositories import (
RoleRepository,
UserRepository,
ensure_admin_user,
ensure_default_roles,
)
from services.unit_of_work import UnitOfWork
@pytest.fixture()
def engine() -> Iterator:
engine = create_engine("sqlite:///:memory:", future=True)
Base.metadata.create_all(bind=engine)
try:
yield engine
finally:
Base.metadata.drop_all(bind=engine)
@pytest.fixture()
def session(engine) -> Iterator[Session]:
TestingSession = sessionmaker(
bind=engine, expire_on_commit=False, future=True)
db = TestingSession()
try:
yield db
finally:
db.close()
def test_role_repository_create_and_lookup(session: Session) -> None:
repo = RoleRepository(session)
role = Role(name="custom", display_name="Custom",
description="Custom role")
repo.create(role)
retrieved = repo.get(role.id)
assert retrieved.name == "custom"
assert repo.get_by_name("custom") is retrieved
assert repo.list()[0].name == "custom"
def test_user_repository_assign_and_revoke_role(session: Session) -> None:
role_repo = RoleRepository(session)
user_repo = UserRepository(session)
analyst = role_repo.create(
Role(name="analyst", display_name="Analyst", description="Analyzes data")
)
user = User(
email="user@example.com",
username="user",
password_hash=User.hash_password("secret"),
)
user_repo.create(user)
assignment = user_repo.assign_role(
user_id=user.id, role_id=analyst.id, granted_by=None)
assert assignment.role_id == analyst.id
refreshed = user_repo.get(user.id, with_roles=True)
assert refreshed.roles[0].name == "analyst"
user_repo.revoke_role(user_id=user.id, role_id=analyst.id)
refreshed = user_repo.get(user.id, with_roles=True)
assert refreshed.roles == []
def test_default_role_and_admin_helpers(session: Session) -> None:
role_repo = RoleRepository(session)
user_repo = UserRepository(session)
roles = ensure_default_roles(role_repo)
assert {role.name for role in roles} == {
"admin", "project_manager", "analyst", "viewer"}
ensure_admin_user(
user_repo,
role_repo,
email="admin@example.com",
username="admin",
password="SecurePass1!",
)
admin = user_repo.get_by_email("admin@example.com", with_roles=True)
assert admin is not None
assert admin.is_superuser
assert {role.name for role in admin.roles} >= {"admin"}
# Idempotent behaviour on subsequent calls
ensure_admin_user(
user_repo,
role_repo,
email="admin@example.com",
username="admin",
password="SecurePass1!",
)
admin_again = user_repo.get_by_email("admin@example.com", with_roles=True)
assert admin_again is not None
assert len(admin_again.roles) == len(
{role.name for role in admin_again.roles})
def test_unit_of_work_exposes_auth_repositories(engine) -> None:
TestingSession = sessionmaker(
bind=engine, expire_on_commit=False, future=True)
with UnitOfWork(session_factory=TestingSession) as uow:
assert uow.users is not None
assert uow.roles is not None
roles = uow.ensure_default_roles()
assert any(role.name == "admin" for role in roles)
uow.ensure_admin_user(
email="uow-admin@example.com",
username="uow-admin",
password="AnotherSecret1!",
)
admin = uow.users.get_by_email(
"uow-admin@example.com", with_roles=True)
assert admin is not None
assert admin.is_superuser
assert any(role.name == "admin" for role in admin.roles)

239
tests/test_auth_routes.py Normal file
View File

@@ -0,0 +1,239 @@
from __future__ import annotations
from collections.abc import Iterator
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from models import Role, User, UserRole
from services.security import hash_password
@pytest.fixture()
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
session = session_factory()
try:
yield session
finally:
session.close()
def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None:
stmt = select(User)
if email is not None:
stmt = stmt.where(User.email == email)
if username is not None:
stmt = stmt.where(User.username == username)
return session.execute(stmt).scalar_one_or_none()
class TestRegistrationFlow:
def test_register_creates_user_and_assigns_role(
self,
client: TestClient,
db_session: Session,
) -> None:
response = client.post(
"/register",
data={
"username": "newuser",
"email": "newuser@example.com",
"password": "ComplexP@ss1",
"confirm_password": "ComplexP@ss1",
},
follow_redirects=False,
)
assert response.status_code == 303
location = response.headers.get("location")
assert location
parsed = urlparse(location)
assert parsed.path == "/login"
assert parse_qs(parsed.query).get("registered") == ["1"]
created = _get_user(db_session, email="newuser@example.com")
assert created is not None
assert created.is_active
role_stmt = select(Role).where(Role.name == "viewer")
viewer_role = db_session.execute(role_stmt).scalar_one_or_none()
assert viewer_role is not None
assignments = db_session.execute(
select(UserRole).where(
UserRole.user_id == created.id,
UserRole.role_id == viewer_role.id,
)
).scalars().all()
assert len(assignments) == 1
def test_register_duplicate_email_shows_error(
self,
client: TestClient,
) -> None:
first = client.post(
"/register",
data={
"username": "existing",
"email": "existing@example.com",
"password": "ComplexP@ss1",
"confirm_password": "ComplexP@ss1",
},
follow_redirects=False,
)
assert first.status_code == 303
second = client.post(
"/register",
data={
"username": "existing",
"email": "existing@example.com",
"password": "ComplexP@ss1",
"confirm_password": "ComplexP@ss1",
},
follow_redirects=False,
)
assert second.status_code == 400
assert "Email is already registered" in second.text
class TestLoginFlow:
def test_login_sets_tokens_and_updates_last_login(
self,
client: TestClient,
db_session: Session,
) -> None:
password = "MySecur3Pass!"
user = User(
email="login@example.com",
username="loginuser",
password_hash=hash_password(password),
is_active=True,
)
db_session.add(user)
db_session.commit()
response = client.post(
"/login",
data={"username": "loginuser", "password": password},
follow_redirects=False,
)
assert response.status_code == 303
assert response.headers.get("location") == "http://testserver/"
set_cookie_header = response.headers.get("set-cookie", "")
assert "calminer_access_token=" in set_cookie_header
assert "calminer_refresh_token=" in set_cookie_header
updated = _get_user(db_session, username="loginuser")
assert updated is not None and updated.last_login_at is not None
def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
response = client.post(
"/login",
data={"username": "unknown", "password": "bad"},
follow_redirects=False,
)
assert response.status_code == 400
assert "Invalid username or password" in response.text
class TestPasswordResetFlow:
def test_password_reset_round_trip(
self,
client: TestClient,
db_session: Session,
) -> None:
user = User(
email="reset@example.com",
username="resetuser",
password_hash=hash_password("OldP@ssword1"),
is_active=True,
)
db_session.add(user)
db_session.commit()
request_response = client.post(
"/forgot-password",
data={"email": "reset@example.com"},
follow_redirects=False,
)
assert request_response.status_code == 303
reset_location = request_response.headers.get("location")
assert reset_location is not None
parsed = urlparse(reset_location)
assert parsed.path == "/reset-password"
token = parse_qs(parsed.query).get("token", [None])[0]
assert token
form_response = client.get(reset_location)
assert form_response.status_code == 200
submit_response = client.post(
"/reset-password",
data={
"token": token,
"password": "N3wP@ssword!",
"confirm_password": "N3wP@ssword!",
},
follow_redirects=False,
)
assert submit_response.status_code == 303
assert "reset=1" in (submit_response.headers.get("location") or "")
db_session.refresh(user)
assert user.verify_password("N3wP@ssword!")
def test_password_reset_with_unknown_email_shows_generic_message(
self,
client: TestClient,
) -> None:
response = client.post(
"/forgot-password",
data={"email": "doesnotexist@example.com"},
follow_redirects=False,
)
assert response.status_code == 200
assert "If an account exists" in response.text
def test_password_reset_mismatched_passwords_return_error(
self,
client: TestClient,
db_session: Session,
) -> None:
user = User(
email="mismatch@example.com",
username="mismatch",
password_hash=hash_password("OldP@ssword1"),
is_active=True,
)
db_session.add(user)
db_session.commit()
request_response = client.post(
"/forgot-password",
data={"email": "mismatch@example.com"},
follow_redirects=False,
)
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
submit_response = client.post(
"/reset-password",
data={
"token": token,
"password": "NewPass123!",
"confirm_password": "Different123!",
},
follow_redirects=False,
)
assert submit_response.status_code == 400
assert "Passwords do not match" in submit_response.text

76
tests/test_security.py Normal file
View File

@@ -0,0 +1,76 @@
from __future__ import annotations
from datetime import timedelta
import pytest
from services.security import (
JWTSettings,
TokenDecodeError,
TokenExpiredError,
TokenTypeMismatchError,
create_access_token,
create_refresh_token,
decode_access_token,
decode_refresh_token,
hash_password,
verify_password,
)
def test_hash_password_round_trip() -> None:
hashed = hash_password("super-secret")
assert hashed != "super-secret"
assert verify_password("super-secret", hashed)
assert not verify_password("incorrect", hashed)
def test_verify_password_handles_malformed_hash() -> None:
assert not verify_password("secret", "not-a-valid-hash")
def test_access_token_roundtrip() -> None:
settings = JWTSettings(secret_key="unit-test-secret")
token = create_access_token(
"user-id-123",
settings,
scopes=("read", "write"),
extra_claims={"custom": "value"},
)
payload = decode_access_token(token, settings)
assert payload.sub == "user-id-123"
assert payload.type == "access"
assert payload.scopes == ["read", "write"]
def test_refresh_token_type_mismatch() -> None:
settings = JWTSettings(secret_key="unit-test-secret")
token = create_refresh_token("user-id-456", settings)
with pytest.raises(TokenTypeMismatchError):
decode_access_token(token, settings)
def test_decode_expired_token() -> None:
settings = JWTSettings(secret_key="unit-test-secret")
expired_token = create_access_token(
"user-id-789",
settings,
expires_delta=timedelta(seconds=-5),
)
with pytest.raises(TokenExpiredError):
decode_access_token(expired_token, settings)
def test_decode_tampered_token() -> None:
settings = JWTSettings(secret_key="unit-test-secret")
token = create_access_token("user-id-321", settings)
tampered = token[:-1] + ("a" if token[-1] != "a" else "b")
with pytest.raises(TokenDecodeError):
decode_access_token(tampered, settings)