diff --git a/alembic/versions/20251109_02_add_auth_tables.py b/alembic/versions/20251109_02_add_auth_tables.py new file mode 100644 index 0000000..c669bd2 --- /dev/null +++ b/alembic/versions/20251109_02_add_auth_tables.py @@ -0,0 +1,210 @@ +"""Add authentication and RBAC tables""" + +from __future__ import annotations + +from alembic import op +import sqlalchemy as sa +from passlib.context import CryptContext +from sqlalchemy.sql import column, table + +# revision identifiers, used by Alembic. +revision = "20251109_02" +down_revision = "20251109_01" +branch_labels = None +depends_on = None + +password_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +def upgrade() -> None: + op.create_table( + "users", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("email", sa.String(length=255), nullable=False), + sa.Column("username", sa.String(length=128), nullable=False), + sa.Column("password_hash", sa.String(length=255), nullable=False), + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.true(), + ), + sa.Column( + "is_superuser", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.UniqueConstraint("email", name="uq_users_email"), + sa.UniqueConstraint("username", name="uq_users_username"), + ) + op.create_index( + "ix_users_active_superuser", + "users", + ["is_active", "is_superuser"], + unique=False, + ) + + op.create_table( + "roles", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("name", sa.String(length=64), nullable=False), + sa.Column("display_name", sa.String(length=128), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.UniqueConstraint("name", name="uq_roles_name"), + ) + + op.create_table( + "user_roles", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column( + "granted_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("granted_by", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["role_id"], + ["roles.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["granted_by"], + ["users.id"], + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("user_id", "role_id"), + sa.UniqueConstraint("user_id", "role_id", + name="uq_user_roles_user_role"), + ) + op.create_index( + "ix_user_roles_role_id", + "user_roles", + ["role_id"], + unique=False, + ) + + # Seed default roles + roles_table = table( + "roles", + column("id", sa.Integer()), + column("name", sa.String()), + column("display_name", sa.String()), + column("description", sa.Text()), + ) + + op.bulk_insert( + roles_table, + [ + { + "id": 1, + "name": "admin", + "display_name": "Administrator", + "description": "Full platform access with user management rights.", + }, + { + "id": 2, + "name": "project_manager", + "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data.", + }, + { + "id": 3, + "name": "analyst", + "display_name": "Analyst", + "description": "Review dashboards and scenario outputs.", + }, + { + "id": 4, + "name": "viewer", + "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports.", + }, + ], + ) + + admin_password_hash = password_context.hash("ChangeMe123!") + + users_table = table( + "users", + column("id", sa.Integer()), + column("email", sa.String()), + column("username", sa.String()), + column("password_hash", sa.String()), + column("is_active", sa.Boolean()), + column("is_superuser", sa.Boolean()), + ) + + op.bulk_insert( + users_table, + [ + { + "id": 1, + "email": "admin@calminer.local", + "username": "admin", + "password_hash": admin_password_hash, + "is_active": True, + "is_superuser": True, + } + ], + ) + + user_roles_table = table( + "user_roles", + column("user_id", sa.Integer()), + column("role_id", sa.Integer()), + column("granted_by", sa.Integer()), + ) + + op.bulk_insert( + user_roles_table, + [ + { + "user_id": 1, + "role_id": 1, + "granted_by": 1, + } + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_user_roles_role_id", table_name="user_roles") + op.drop_table("user_roles") + + op.drop_table("roles") + + op.drop_index("ix_users_active_superuser", table_name="users") + op.drop_table("users") diff --git a/alembic_test.db b/alembic_test.db new file mode 100644 index 0000000..7c40e15 Binary files /dev/null and b/alembic_test.db differ diff --git a/changelog.md b/changelog.md index 49e0146..7606c0d 100644 --- a/changelog.md +++ b/changelog.md @@ -16,3 +16,5 @@ - Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints. - Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views. - Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. +- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios. +- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows. diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..0b928d4 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from datetime import timedelta +from functools import lru_cache + +from services.security import JWTSettings + + +@dataclass(frozen=True, slots=True) +class Settings: + """Application configuration sourced from environment variables.""" + + jwt_secret_key: str = "change-me" + jwt_algorithm: str = "HS256" + jwt_access_token_minutes: int = 15 + jwt_refresh_token_days: int = 7 + + @classmethod + def from_environment(cls) -> "Settings": + """Construct settings from environment variables.""" + + return cls( + jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"), + jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"), + jwt_access_token_minutes=cls._int_from_env( + "CALMINER_JWT_ACCESS_MINUTES", 15 + ), + jwt_refresh_token_days=cls._int_from_env( + "CALMINER_JWT_REFRESH_DAYS", 7 + ), + ) + + @staticmethod + def _int_from_env(name: str, default: int) -> int: + raw_value = os.getenv(name) + if raw_value is None: + return default + try: + return int(raw_value) + except ValueError: + return default + + def jwt_settings(self) -> JWTSettings: + """Build runtime JWT settings compatible with token helpers.""" + + return JWTSettings( + secret_key=self.jwt_secret_key, + algorithm=self.jwt_algorithm, + access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes), + refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days), + ) + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + """Return cached application settings.""" + + return Settings.from_environment() diff --git a/dependencies.py b/dependencies.py index e492586..b2b0774 100644 --- a/dependencies.py +++ b/dependencies.py @@ -2,6 +2,8 @@ from __future__ import annotations from collections.abc import Generator +from config.settings import Settings, get_settings +from services.security import JWTSettings from services.unit_of_work import UnitOfWork @@ -10,3 +12,15 @@ def get_unit_of_work() -> Generator[UnitOfWork, None, None]: with UnitOfWork() as uow: yield uow + + +def get_application_settings() -> Settings: + """Provide cached application settings instance.""" + + return get_settings() + + +def get_jwt_settings() -> JWTSettings: + """Provide JWT runtime configuration derived from settings.""" + + return get_settings().jwt_settings() diff --git a/main.py b/main.py index 000a0e0..456aa5a 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from models import ( Scenario, SimulationParameter, ) +from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router @@ -33,6 +34,7 @@ async def health() -> dict[str, str]: app.include_router(dashboard_router) +app.include_router(auth_router) app.include_router(projects_router) app.include_router(scenarios_router) diff --git a/models/user.py b/models/user.py index 67db19e..580c705 100644 --- a/models/user.py +++ b/models/user.py @@ -4,6 +4,14 @@ from datetime import datetime from typing import List, Optional from passlib.context import CryptContext + +try: # pragma: no cover - defensive compatibility shim + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: + pass from sqlalchemy import ( Boolean, DateTime, diff --git a/routes/auth.py b/routes/auth.py new file mode 100644 index 0000000..71a8752 --- /dev/null +++ b/routes/auth.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any, Iterable + +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates +from pydantic import ValidationError +from starlette.datastructures import FormData + +from dependencies import get_jwt_settings, get_unit_of_work +from models import Role, User +from schemas.auth import ( + LoginForm, + PasswordResetForm, + PasswordResetRequestForm, + RegistrationForm, +) +from services.exceptions import EntityConflictError +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + hash_password, + verify_password, +) +from services.repositories import RoleRepository, UserRepository +from services.unit_of_work import UnitOfWork + +router = APIRouter(tags=["Authentication"]) +templates = Jinja2Templates(directory="templates") + +_PASSWORD_RESET_SCOPE = "password-reset" +_AUTH_SCOPE = "auth" + + +def _template( + request: Request, + template_name: str, + context: dict[str, Any], + *, + status_code: int = status.HTTP_200_OK, +) -> HTMLResponse: + return templates.TemplateResponse( + request, + template_name, + context, + status_code=status_code, + ) + + +def _validation_errors(exc: ValidationError) -> list[str]: + return [error.get("msg", "Invalid input.") for error in exc.errors()] + + +def _scopes(include: Iterable[str]) -> list[str]: + return list(include) + + +def _normalise_form_data(form_data: FormData) -> dict[str, str]: + normalised: dict[str, str] = {} + for key, value in form_data.multi_items(): + if isinstance(value, UploadFile): + str_value = value.filename or "" + else: + str_value = str(value) + normalised[key] = str_value + return normalised + + +def _require_users_repo(uow: UnitOfWork) -> UserRepository: + if not uow.users: + raise RuntimeError("User repository is not initialised") + return uow.users + + +def _require_roles_repo(uow: UnitOfWork) -> RoleRepository: + if not uow.roles: + raise RuntimeError("Role repository is not initialised") + return uow.roles + + +@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form") +def login_form(request: Request) -> HTMLResponse: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": [], + "username": "", + }, + ) + + +@router.post("/login", include_in_schema=False, name="auth.login_submit") +async def login_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = LoginForm(**form_data) + except ValidationError as exc: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + identifier = form.username + users_repo = _require_users_repo(uow) + user = _lookup_user(users_repo, identifier) + errors: list[str] = [] + + if not user or not verify_password(form.password, user.password_hash): + errors.append("Invalid username or password.") + elif not user.is_active: + errors.append("Account is inactive. Contact an administrator.") + + if errors: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": errors, + "username": identifier, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + assert user is not None # mypy hint - guarded above + user.last_login_at = datetime.now(timezone.utc) + + access_token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + refresh_token = create_refresh_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + + response = RedirectResponse( + request.url_for("dashboard.home"), + status_code=status.HTTP_303_SEE_OTHER, + ) + _set_auth_cookies(response, access_token, refresh_token, jwt_settings) + return response + + +def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None: + if "@" in identifier: + return users_repo.get_by_email(identifier.lower(), with_roles=True) + return users_repo.get_by_username(identifier, with_roles=True) + + +def _set_auth_cookies( + response: RedirectResponse, + access_token: str, + refresh_token: str, + jwt_settings: JWTSettings, +) -> None: + access_ttl = int(jwt_settings.access_token_ttl.total_seconds()) + refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds()) + response.set_cookie( + "calminer_access_token", + access_token, + httponly=True, + secure=False, + samesite="lax", + max_age=max(access_ttl, 0) or None, + ) + response.set_cookie( + "calminer_refresh_token", + refresh_token, + httponly=True, + secure=False, + samesite="lax", + max_age=max(refresh_ttl, 0) or None, + ) + + +@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form") +def register_form(request: Request) -> HTMLResponse: + return _template( + request, + "register.html", + { + "form_action": request.url_for("auth.register_submit"), + "errors": [], + "form_data": None, + }, + ) + + +@router.post("/register", include_in_schema=False, name="auth.register_submit") +async def register_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), +): + form_data = _normalise_form_data(await request.form()) + try: + form = RegistrationForm(**form_data) + except ValidationError as exc: + return _registration_error_response(request, _validation_errors(exc)) + + errors: list[str] = [] + users_repo = _require_users_repo(uow) + roles_repo = _require_roles_repo(uow) + uow.ensure_default_roles() + + if users_repo.get_by_email(form.email): + errors.append("Email is already registered.") + if users_repo.get_by_username(form.username): + errors.append("Username is already taken.") + + if errors: + return _registration_error_response(request, errors, form) + + user = User( + email=form.email, + username=form.username, + password_hash=hash_password(form.password), + is_active=True, + is_superuser=False, + ) + + try: + created = users_repo.create(user) + except EntityConflictError: + return _registration_error_response( + request, + ["An account with this username or email already exists."], + form, + ) + + viewer_role = _ensure_viewer_role(roles_repo) + if viewer_role is not None: + users_repo.assign_role( + user_id=created.id, + role_id=viewer_role.id, + granted_by=created.id, + ) + + redirect_url = request.url_for( + "auth.login_form").include_query_params(registered="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _registration_error_response( + request: Request, + errors: list[str], + form: RegistrationForm | None = None, +) -> HTMLResponse: + context = { + "form_action": request.url_for("auth.register_submit"), + "errors": errors, + "form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None, + } + return _template( + request, + "register.html", + context, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + +def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None: + viewer = roles_repo.get_by_name("viewer") + if viewer: + return viewer + return roles_repo.get_by_name("viewer") + + +@router.get( + "/forgot-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_request_form", +) +def password_reset_request_form(request: Request) -> HTMLResponse: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": None, + }, + ) + + +@router.post( + "/forgot-password", + include_in_schema=False, + name="auth.password_reset_request_submit", +) +async def password_reset_request_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetRequestForm(**form_data) + except ValidationError as exc: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": _validation_errors(exc), + "message": None, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + users_repo = _require_users_repo(uow) + user = users_repo.get_by_email(form.email) + if not user: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": "If an account exists, a reset link has been sent.", + }, + ) + + token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_PASSWORD_RESET_SCOPE,)), + expires_delta=timedelta(hours=1), + ) + + reset_url = request.url_for( + "auth.password_reset_form").include_query_params(token=token) + return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER) + + +@router.get( + "/reset-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_form", +) +def password_reset_form( + request: Request, + token: str | None = None, + jwt_settings: JWTSettings = Depends(get_jwt_settings), +) -> HTMLResponse: + errors: list[str] = [] + if not token: + errors.append("Missing password reset token.") + else: + try: + payload = decode_access_token(token, jwt_settings) + if _PASSWORD_RESET_SCOPE not in payload.scopes: + errors.append("Invalid token scope.") + except TokenExpiredError: + errors.append( + "Token has expired. Please request a new password reset.") + except (TokenDecodeError, TokenTypeMismatchError): + errors.append("Invalid password reset token.") + + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": errors, + }, + status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK, + ) + + +@router.post( + "/reset-password", + include_in_schema=False, + name="auth.password_reset_submit", +) +async def password_reset_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetForm(**form_data) + except ValidationError as exc: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": form_data.get("token"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + try: + payload = decode_access_token(form.token, jwt_settings) + except TokenExpiredError: + return _reset_error_response( + request, + form.token, + "Token has expired. Please request a new password reset.", + ) + except (TokenDecodeError, TokenTypeMismatchError): + return _reset_error_response( + request, + form.token, + "Invalid password reset token.", + ) + + if _PASSWORD_RESET_SCOPE not in payload.scopes: + return _reset_error_response( + request, + form.token, + "Invalid password reset token scope.", + ) + + users_repo = _require_users_repo(uow) + user_id = int(payload.sub) + user = users_repo.get(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + + user.set_password(form.password) + if not user.is_active: + user.is_active = True + + redirect_url = request.url_for( + "auth.login_form").include_query_params(reset="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": [message], + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/schemas/auth.py b/schemas/auth.py new file mode 100644 index 0000000..3a16191 --- /dev/null +++ b/schemas/auth.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class FormModel(BaseModel): + """Base Pydantic model for HTML form submissions.""" + + model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) + + +class RegistrationForm(FormModel): + username: str = Field(min_length=3, max_length=128) + email: str = Field(min_length=5, max_length=255) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + @field_validator("confirm_password") + @classmethod + def passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value + + +class LoginForm(FormModel): + username: str = Field(min_length=1, max_length=255) + password: str = Field(min_length=1, max_length=256) + + +class PasswordResetRequestForm(FormModel): + email: str = Field(min_length=5, max_length=255) + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + +class PasswordResetForm(FormModel): + token: str = Field(min_length=1) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("confirm_password") + @classmethod + def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value diff --git a/services/repositories.py b/services/repositories.py index 5556638..1f82691 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -8,7 +8,16 @@ from sqlalchemy import select, func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload, selectinload -from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter +from models import ( + FinancialInput, + Project, + Role, + Scenario, + ScenarioStatus, + SimulationParameter, + User, + UserRole, +) from services.exceptions import EntityConflictError, EntityNotFoundError @@ -211,3 +220,220 @@ class SimulationParameterRepository: raise EntityNotFoundError( f"Simulation parameter {parameter_id} not found") self.session.delete(entity) + + +class RoleRepository: + """Persistence operations for Role entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self) -> Sequence[Role]: + stmt = select(Role).order_by(Role.name) + return self.session.execute(stmt).scalars().all() + + def get(self, role_id: int) -> Role: + stmt = select(Role).where(Role.id == role_id) + role = self.session.execute(stmt).scalar_one_or_none() + if role is None: + raise EntityNotFoundError(f"Role {role_id} not found") + return role + + def get_by_name(self, name: str) -> Role | None: + stmt = select(Role).where(Role.name == name) + return self.session.execute(stmt).scalar_one_or_none() + + def create(self, role: Role) -> Role: + self.session.add(role) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Role violates uniqueness constraints") from exc + return role + + +class UserRepository: + """Persistence operations for User entities and their role assignments.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self, *, with_roles: bool = False) -> Sequence[User]: + stmt = select(User).order_by(User.created_at) + if with_roles: + stmt = stmt.options(selectinload(User.roles)) + return self.session.execute(stmt).scalars().all() + + def _apply_role_option(self, stmt, with_roles: bool): + if with_roles: + stmt = stmt.options( + joinedload(User.role_assignments).joinedload(UserRole.role), + selectinload(User.roles), + ) + return stmt + + def get(self, user_id: int, *, with_roles: bool = False) -> User: + stmt = select(User).where(User.id == user_id).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + user = result.scalar_one_or_none() + if user is None: + raise EntityNotFoundError(f"User {user_id} not found") + return user + + def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.email == email).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.username == + username).execution_options(populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def create(self, user: User) -> User: + self.session.add(user) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "User violates uniqueness constraints") from exc + return user + + def assign_role( + self, + *, + user_id: int, + role_id: int, + granted_by: int | None = None, + ) -> UserRole: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment: + return assignment + + assignment = UserRole( + user_id=user_id, + role_id=role_id, + granted_by=granted_by, + ) + self.session.add(assignment) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Assignment violates constraints") from exc + return assignment + + def revoke_role(self, *, user_id: int, role_id: int) -> None: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment is None: + raise EntityNotFoundError( + f"Role {role_id} not assigned to user {user_id}") + self.session.delete(assignment) + self.session.flush() + + +DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = ( + { + "name": "admin", + "display_name": "Administrator", + "description": "Full platform access with user management rights.", + }, + { + "name": "project_manager", + "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data.", + }, + { + "name": "analyst", + "display_name": "Analyst", + "description": "Review dashboards and scenario outputs.", + }, + { + "name": "viewer", + "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports.", + }, +) + + +def ensure_default_roles(role_repo: RoleRepository) -> list[Role]: + """Ensure standard roles exist, creating missing ones. + + Returns all current role records in creation order. + """ + + roles: list[Role] = [] + for definition in DEFAULT_ROLE_DEFINITIONS: + existing = role_repo.get_by_name(definition["name"]) + if existing: + roles.append(existing) + continue + role = Role(**definition) + roles.append(role_repo.create(role)) + return roles + + +def ensure_admin_user( + user_repo: UserRepository, + role_repo: RoleRepository, + *, + email: str, + username: str, + password: str, +) -> User: + """Ensure an administrator user exists and holds the admin role.""" + + user = user_repo.get_by_email(email, with_roles=True) + if user is None: + user = User( + email=email, + username=username, + password_hash=User.hash_password(password), + is_active=True, + is_superuser=True, + ) + user_repo.create(user) + else: + if not user.is_active: + user.is_active = True + if not user.is_superuser: + user.is_superuser = True + user_repo.session.flush() + + admin_role = role_repo.get_by_name("admin") + if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called + admin_role = role_repo.create( + Role( + name="admin", + display_name="Administrator", + description="Full platform access with user management rights.", + ) + ) + + user_repo.assign_role( + user_id=user.id, + role_id=admin_role.id, + granted_by=user.id, + ) + return user diff --git a/services/security.py b/services/security.py new file mode 100644 index 0000000..02a078d --- /dev/null +++ b/services/security.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Iterable, Literal, Type + +from jose import ExpiredSignatureError, JWTError, jwt +from passlib.context import CryptContext + +try: # pragma: no cover - compatibility shim for passlib/argon2 warning + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: # pragma: no cover - executed only when metadata lookup fails + pass + +from pydantic import BaseModel, Field, ValidationError + +password_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +def hash_password(password: str) -> str: + """Derive a secure hash for a plain-text password.""" + + return password_context.hash(password) + + +def verify_password(candidate: str, hashed: str) -> bool: + """Verify that a candidate password matches a stored hash.""" + + try: + return password_context.verify(candidate, hashed) + except ValueError: + # Raised when the stored hash is malformed or uses an unknown scheme. + return False + + +class TokenError(Exception): + """Base class for token encoding/decoding issues.""" + + +class TokenDecodeError(TokenError): + """Raised when a token cannot be decoded or validated.""" + + +class TokenExpiredError(TokenError): + """Raised when a token has expired.""" + + +class TokenTypeMismatchError(TokenError): + """Raised when a token type does not match the expected flavour.""" + + +TokenKind = Literal["access", "refresh"] + + +class TokenPayload(BaseModel): + """Shared fields for CalMiner JWT payloads.""" + + sub: str + exp: int + type: TokenKind + scopes: list[str] = Field(default_factory=list) + + @property + def expires_at(self) -> datetime: + return datetime.fromtimestamp(self.exp, tz=timezone.utc) + + +@dataclass(slots=True) +class JWTSettings: + """Runtime configuration for JWT encoding and validation.""" + + secret_key: str + algorithm: str = "HS256" + access_token_ttl: timedelta = field( + default_factory=lambda: timedelta(minutes=15)) + refresh_token_ttl: timedelta = field( + default_factory=lambda: timedelta(days=7)) + + +def create_access_token( + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, +) -> str: + """Issue a signed access token for the provided subject.""" + + lifetime = expires_delta or settings.access_token_ttl + return _create_token( + subject=subject, + token_type="access", + settings=settings, + lifetime=lifetime, + scopes=scopes, + extra_claims=extra_claims, + ) + + +def create_refresh_token( + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, +) -> str: + """Issue a signed refresh token for the provided subject.""" + + lifetime = expires_delta or settings.refresh_token_ttl + return _create_token( + subject=subject, + token_type="refresh", + settings=settings, + lifetime=lifetime, + scopes=scopes, + extra_claims=extra_claims, + ) + + +def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode an access token.""" + + return _decode_token(token, settings, expected_type="access") + + +def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode a refresh token.""" + + return _decode_token(token, settings, expected_type="refresh") + + +def _create_token( + *, + subject: str, + token_type: TokenKind, + settings: JWTSettings, + lifetime: timedelta, + scopes: Iterable[str] | None, + extra_claims: Dict[str, Any] | None, +) -> str: + now = datetime.now(timezone.utc) + expire = now + lifetime + payload: Dict[str, Any] = { + "sub": subject, + "type": token_type, + "iat": int(now.timestamp()), + "exp": int(expire.timestamp()), + } + if scopes: + payload["scopes"] = list(scopes) + if extra_claims: + payload.update(extra_claims) + + return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) + + +def _decode_token( + token: str, + settings: JWTSettings, + expected_type: TokenKind, +) -> TokenPayload: + try: + decoded = jwt.decode( + token, + settings.secret_key, + algorithms=[settings.algorithm], + options={"verify_aud": False}, + ) + except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path + raise TokenExpiredError("Token has expired") from exc + except JWTError as exc: # pragma: no cover - jose error bubble + raise TokenDecodeError("Unable to decode token") from exc + + try: + payload = _model_validate(TokenPayload, decoded) + except ValidationError as exc: + raise TokenDecodeError("Token payload validation failed") from exc + + if payload.type != expected_type: + raise TokenTypeMismatchError( + f"Expected a {expected_type} token but received '{payload.type}'." + ) + + return payload + + +def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload: + if hasattr(model, "model_validate"): + return model.model_validate(data) # type: ignore[attr-defined] + return model.parse_obj(data) # type: ignore[attr-defined] + + +__all__ = [ + "JWTSettings", + "TokenDecodeError", + "TokenError", + "TokenExpiredError", + "TokenKind", + "TokenPayload", + "TokenTypeMismatchError", + "create_access_token", + "create_refresh_token", + "decode_access_token", + "decode_refresh_token", + "hash_password", + "password_context", + "verify_password", +] diff --git a/services/unit_of_work.py b/services/unit_of_work.py index 2e7b9e8..a51849e 100644 --- a/services/unit_of_work.py +++ b/services/unit_of_work.py @@ -6,12 +6,16 @@ from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal -from models import Scenario +from models import Role, Scenario from services.repositories import ( FinancialInputRepository, ProjectRepository, + RoleRepository, ScenarioRepository, SimulationParameterRepository, + UserRepository, + ensure_admin_user as ensure_admin_user_record, + ensure_default_roles, ) from services.scenario_validation import ScenarioComparisonValidator @@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self._session_factory = session_factory self.session: Session | None = None self._scenario_validator: ScenarioComparisonValidator | None = None + self.projects: ProjectRepository | None = None + self.scenarios: ScenarioRepository | None = None + self.financial_inputs: FinancialInputRepository | None = None + self.simulation_parameters: SimulationParameterRepository | None = None + self.users: UserRepository | None = None + self.roles: RoleRepository | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() @@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository( self.session) + self.users = UserRepository(self.session) + self.roles = RoleRepository(self.session) self._scenario_validator = ScenarioComparisonValidator() return self @@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self.session.rollback() self.session.close() self._scenario_validator = None + self.projects = None + self.scenarios = None + self.financial_inputs = None + self.simulation_parameters = None + self.users = None + self.roles = None def flush(self) -> None: if not self.session: @@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): def validate_scenarios_for_comparison( self, scenario_ids: Sequence[int] ) -> list[Scenario]: - if not self.session or not self._scenario_validator: + if not self.session or not self._scenario_validator or not self.scenarios: raise RuntimeError("UnitOfWork session is not initialised") scenarios = [self.scenarios.get(scenario_id) @@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): if not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") self._scenario_validator.validate(scenarios) + + def ensure_default_roles(self) -> list[Role]: + if not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + return ensure_default_roles(self.roles) + + def ensure_admin_user( + self, + *, + email: str, + username: str, + password: str, + ) -> None: + if not self.users or not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + ensure_default_roles(self.roles) + ensure_admin_user_record( + self.users, + self.roles, + email=email, + username=username, + password=password, + ) diff --git a/static/img/logo_big.png b/static/img/logo_big.png new file mode 100644 index 0000000..98f1a78 Binary files /dev/null and b/static/img/logo_big.png differ diff --git a/templates/forgot_password.html b/templates/forgot_password.html index 4d21fd3..2d257cb 100644 --- a/templates/forgot_password.html +++ b/templates/forgot_password.html @@ -1,17 +1,25 @@ -{% extends "base.html" %} - -{% block title %}Forgot Password{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {% +block content %}
Remember your password? Login here
+Remember your password? Login here
Don't have an account? Register here
- +Don't have an account? Register here
+