diff --git a/alembic/versions/20251109_02_add_auth_tables.py b/alembic/versions/20251109_02_add_auth_tables.py new file mode 100644 index 0000000..c669bd2 --- /dev/null +++ b/alembic/versions/20251109_02_add_auth_tables.py @@ -0,0 +1,210 @@ +"""Add authentication and RBAC tables""" + +from __future__ import annotations + +from alembic import op +import sqlalchemy as sa +from passlib.context import CryptContext +from sqlalchemy.sql import column, table + +# revision identifiers, used by Alembic. +revision = "20251109_02" +down_revision = "20251109_01" +branch_labels = None +depends_on = None + +password_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +def upgrade() -> None: + op.create_table( + "users", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("email", sa.String(length=255), nullable=False), + sa.Column("username", sa.String(length=128), nullable=False), + sa.Column("password_hash", sa.String(length=255), nullable=False), + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.true(), + ), + sa.Column( + "is_superuser", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.UniqueConstraint("email", name="uq_users_email"), + sa.UniqueConstraint("username", name="uq_users_username"), + ) + op.create_index( + "ix_users_active_superuser", + "users", + ["is_active", "is_superuser"], + unique=False, + ) + + op.create_table( + "roles", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("name", sa.String(length=64), nullable=False), + sa.Column("display_name", sa.String(length=128), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.UniqueConstraint("name", name="uq_roles_name"), + ) + + op.create_table( + "user_roles", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column( + "granted_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("granted_by", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["role_id"], + ["roles.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["granted_by"], + ["users.id"], + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("user_id", "role_id"), + sa.UniqueConstraint("user_id", "role_id", + name="uq_user_roles_user_role"), + ) + op.create_index( + "ix_user_roles_role_id", + "user_roles", + ["role_id"], + unique=False, + ) + + # Seed default roles + roles_table = table( + "roles", + column("id", sa.Integer()), + column("name", sa.String()), + column("display_name", sa.String()), + column("description", sa.Text()), + ) + + op.bulk_insert( + roles_table, + [ + { + "id": 1, + "name": "admin", + "display_name": "Administrator", + "description": "Full platform access with user management rights.", + }, + { + "id": 2, + "name": "project_manager", + "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data.", + }, + { + "id": 3, + "name": "analyst", + "display_name": "Analyst", + "description": "Review dashboards and scenario outputs.", + }, + { + "id": 4, + "name": "viewer", + "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports.", + }, + ], + ) + + admin_password_hash = password_context.hash("ChangeMe123!") + + users_table = table( + "users", + column("id", sa.Integer()), + column("email", sa.String()), + column("username", sa.String()), + column("password_hash", sa.String()), + column("is_active", sa.Boolean()), + column("is_superuser", sa.Boolean()), + ) + + op.bulk_insert( + users_table, + [ + { + "id": 1, + "email": "admin@calminer.local", + "username": "admin", + "password_hash": admin_password_hash, + "is_active": True, + "is_superuser": True, + } + ], + ) + + user_roles_table = table( + "user_roles", + column("user_id", sa.Integer()), + column("role_id", sa.Integer()), + column("granted_by", sa.Integer()), + ) + + op.bulk_insert( + user_roles_table, + [ + { + "user_id": 1, + "role_id": 1, + "granted_by": 1, + } + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_user_roles_role_id", table_name="user_roles") + op.drop_table("user_roles") + + op.drop_table("roles") + + op.drop_index("ix_users_active_superuser", table_name="users") + op.drop_table("users") diff --git a/alembic_test.db b/alembic_test.db new file mode 100644 index 0000000..7c40e15 Binary files /dev/null and b/alembic_test.db differ diff --git a/changelog.md b/changelog.md index 49e0146..7606c0d 100644 --- a/changelog.md +++ b/changelog.md @@ -16,3 +16,5 @@ - Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints. - Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views. - Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. +- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios. +- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows. diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..0b928d4 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from datetime import timedelta +from functools import lru_cache + +from services.security import JWTSettings + + +@dataclass(frozen=True, slots=True) +class Settings: + """Application configuration sourced from environment variables.""" + + jwt_secret_key: str = "change-me" + jwt_algorithm: str = "HS256" + jwt_access_token_minutes: int = 15 + jwt_refresh_token_days: int = 7 + + @classmethod + def from_environment(cls) -> "Settings": + """Construct settings from environment variables.""" + + return cls( + jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"), + jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"), + jwt_access_token_minutes=cls._int_from_env( + "CALMINER_JWT_ACCESS_MINUTES", 15 + ), + jwt_refresh_token_days=cls._int_from_env( + "CALMINER_JWT_REFRESH_DAYS", 7 + ), + ) + + @staticmethod + def _int_from_env(name: str, default: int) -> int: + raw_value = os.getenv(name) + if raw_value is None: + return default + try: + return int(raw_value) + except ValueError: + return default + + def jwt_settings(self) -> JWTSettings: + """Build runtime JWT settings compatible with token helpers.""" + + return JWTSettings( + secret_key=self.jwt_secret_key, + algorithm=self.jwt_algorithm, + access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes), + refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days), + ) + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + """Return cached application settings.""" + + return Settings.from_environment() diff --git a/dependencies.py b/dependencies.py index e492586..b2b0774 100644 --- a/dependencies.py +++ b/dependencies.py @@ -2,6 +2,8 @@ from __future__ import annotations from collections.abc import Generator +from config.settings import Settings, get_settings +from services.security import JWTSettings from services.unit_of_work import UnitOfWork @@ -10,3 +12,15 @@ def get_unit_of_work() -> Generator[UnitOfWork, None, None]: with UnitOfWork() as uow: yield uow + + +def get_application_settings() -> Settings: + """Provide cached application settings instance.""" + + return get_settings() + + +def get_jwt_settings() -> JWTSettings: + """Provide JWT runtime configuration derived from settings.""" + + return get_settings().jwt_settings() diff --git a/main.py b/main.py index 000a0e0..456aa5a 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from models import ( Scenario, SimulationParameter, ) +from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router @@ -33,6 +34,7 @@ async def health() -> dict[str, str]: app.include_router(dashboard_router) +app.include_router(auth_router) app.include_router(projects_router) app.include_router(scenarios_router) diff --git a/models/user.py b/models/user.py index 67db19e..580c705 100644 --- a/models/user.py +++ b/models/user.py @@ -4,6 +4,14 @@ from datetime import datetime from typing import List, Optional from passlib.context import CryptContext + +try: # pragma: no cover - defensive compatibility shim + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: + pass from sqlalchemy import ( Boolean, DateTime, diff --git a/routes/auth.py b/routes/auth.py new file mode 100644 index 0000000..71a8752 --- /dev/null +++ b/routes/auth.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any, Iterable + +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates +from pydantic import ValidationError +from starlette.datastructures import FormData + +from dependencies import get_jwt_settings, get_unit_of_work +from models import Role, User +from schemas.auth import ( + LoginForm, + PasswordResetForm, + PasswordResetRequestForm, + RegistrationForm, +) +from services.exceptions import EntityConflictError +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + hash_password, + verify_password, +) +from services.repositories import RoleRepository, UserRepository +from services.unit_of_work import UnitOfWork + +router = APIRouter(tags=["Authentication"]) +templates = Jinja2Templates(directory="templates") + +_PASSWORD_RESET_SCOPE = "password-reset" +_AUTH_SCOPE = "auth" + + +def _template( + request: Request, + template_name: str, + context: dict[str, Any], + *, + status_code: int = status.HTTP_200_OK, +) -> HTMLResponse: + return templates.TemplateResponse( + request, + template_name, + context, + status_code=status_code, + ) + + +def _validation_errors(exc: ValidationError) -> list[str]: + return [error.get("msg", "Invalid input.") for error in exc.errors()] + + +def _scopes(include: Iterable[str]) -> list[str]: + return list(include) + + +def _normalise_form_data(form_data: FormData) -> dict[str, str]: + normalised: dict[str, str] = {} + for key, value in form_data.multi_items(): + if isinstance(value, UploadFile): + str_value = value.filename or "" + else: + str_value = str(value) + normalised[key] = str_value + return normalised + + +def _require_users_repo(uow: UnitOfWork) -> UserRepository: + if not uow.users: + raise RuntimeError("User repository is not initialised") + return uow.users + + +def _require_roles_repo(uow: UnitOfWork) -> RoleRepository: + if not uow.roles: + raise RuntimeError("Role repository is not initialised") + return uow.roles + + +@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form") +def login_form(request: Request) -> HTMLResponse: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": [], + "username": "", + }, + ) + + +@router.post("/login", include_in_schema=False, name="auth.login_submit") +async def login_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = LoginForm(**form_data) + except ValidationError as exc: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + identifier = form.username + users_repo = _require_users_repo(uow) + user = _lookup_user(users_repo, identifier) + errors: list[str] = [] + + if not user or not verify_password(form.password, user.password_hash): + errors.append("Invalid username or password.") + elif not user.is_active: + errors.append("Account is inactive. Contact an administrator.") + + if errors: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": errors, + "username": identifier, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + assert user is not None # mypy hint - guarded above + user.last_login_at = datetime.now(timezone.utc) + + access_token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + refresh_token = create_refresh_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + + response = RedirectResponse( + request.url_for("dashboard.home"), + status_code=status.HTTP_303_SEE_OTHER, + ) + _set_auth_cookies(response, access_token, refresh_token, jwt_settings) + return response + + +def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None: + if "@" in identifier: + return users_repo.get_by_email(identifier.lower(), with_roles=True) + return users_repo.get_by_username(identifier, with_roles=True) + + +def _set_auth_cookies( + response: RedirectResponse, + access_token: str, + refresh_token: str, + jwt_settings: JWTSettings, +) -> None: + access_ttl = int(jwt_settings.access_token_ttl.total_seconds()) + refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds()) + response.set_cookie( + "calminer_access_token", + access_token, + httponly=True, + secure=False, + samesite="lax", + max_age=max(access_ttl, 0) or None, + ) + response.set_cookie( + "calminer_refresh_token", + refresh_token, + httponly=True, + secure=False, + samesite="lax", + max_age=max(refresh_ttl, 0) or None, + ) + + +@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form") +def register_form(request: Request) -> HTMLResponse: + return _template( + request, + "register.html", + { + "form_action": request.url_for("auth.register_submit"), + "errors": [], + "form_data": None, + }, + ) + + +@router.post("/register", include_in_schema=False, name="auth.register_submit") +async def register_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), +): + form_data = _normalise_form_data(await request.form()) + try: + form = RegistrationForm(**form_data) + except ValidationError as exc: + return _registration_error_response(request, _validation_errors(exc)) + + errors: list[str] = [] + users_repo = _require_users_repo(uow) + roles_repo = _require_roles_repo(uow) + uow.ensure_default_roles() + + if users_repo.get_by_email(form.email): + errors.append("Email is already registered.") + if users_repo.get_by_username(form.username): + errors.append("Username is already taken.") + + if errors: + return _registration_error_response(request, errors, form) + + user = User( + email=form.email, + username=form.username, + password_hash=hash_password(form.password), + is_active=True, + is_superuser=False, + ) + + try: + created = users_repo.create(user) + except EntityConflictError: + return _registration_error_response( + request, + ["An account with this username or email already exists."], + form, + ) + + viewer_role = _ensure_viewer_role(roles_repo) + if viewer_role is not None: + users_repo.assign_role( + user_id=created.id, + role_id=viewer_role.id, + granted_by=created.id, + ) + + redirect_url = request.url_for( + "auth.login_form").include_query_params(registered="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _registration_error_response( + request: Request, + errors: list[str], + form: RegistrationForm | None = None, +) -> HTMLResponse: + context = { + "form_action": request.url_for("auth.register_submit"), + "errors": errors, + "form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None, + } + return _template( + request, + "register.html", + context, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + +def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None: + viewer = roles_repo.get_by_name("viewer") + if viewer: + return viewer + return roles_repo.get_by_name("viewer") + + +@router.get( + "/forgot-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_request_form", +) +def password_reset_request_form(request: Request) -> HTMLResponse: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": None, + }, + ) + + +@router.post( + "/forgot-password", + include_in_schema=False, + name="auth.password_reset_request_submit", +) +async def password_reset_request_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetRequestForm(**form_data) + except ValidationError as exc: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": _validation_errors(exc), + "message": None, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + users_repo = _require_users_repo(uow) + user = users_repo.get_by_email(form.email) + if not user: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": "If an account exists, a reset link has been sent.", + }, + ) + + token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_PASSWORD_RESET_SCOPE,)), + expires_delta=timedelta(hours=1), + ) + + reset_url = request.url_for( + "auth.password_reset_form").include_query_params(token=token) + return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER) + + +@router.get( + "/reset-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_form", +) +def password_reset_form( + request: Request, + token: str | None = None, + jwt_settings: JWTSettings = Depends(get_jwt_settings), +) -> HTMLResponse: + errors: list[str] = [] + if not token: + errors.append("Missing password reset token.") + else: + try: + payload = decode_access_token(token, jwt_settings) + if _PASSWORD_RESET_SCOPE not in payload.scopes: + errors.append("Invalid token scope.") + except TokenExpiredError: + errors.append( + "Token has expired. Please request a new password reset.") + except (TokenDecodeError, TokenTypeMismatchError): + errors.append("Invalid password reset token.") + + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": errors, + }, + status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK, + ) + + +@router.post( + "/reset-password", + include_in_schema=False, + name="auth.password_reset_submit", +) +async def password_reset_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetForm(**form_data) + except ValidationError as exc: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": form_data.get("token"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + try: + payload = decode_access_token(form.token, jwt_settings) + except TokenExpiredError: + return _reset_error_response( + request, + form.token, + "Token has expired. Please request a new password reset.", + ) + except (TokenDecodeError, TokenTypeMismatchError): + return _reset_error_response( + request, + form.token, + "Invalid password reset token.", + ) + + if _PASSWORD_RESET_SCOPE not in payload.scopes: + return _reset_error_response( + request, + form.token, + "Invalid password reset token scope.", + ) + + users_repo = _require_users_repo(uow) + user_id = int(payload.sub) + user = users_repo.get(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + + user.set_password(form.password) + if not user.is_active: + user.is_active = True + + redirect_url = request.url_for( + "auth.login_form").include_query_params(reset="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": [message], + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/schemas/auth.py b/schemas/auth.py new file mode 100644 index 0000000..3a16191 --- /dev/null +++ b/schemas/auth.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class FormModel(BaseModel): + """Base Pydantic model for HTML form submissions.""" + + model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) + + +class RegistrationForm(FormModel): + username: str = Field(min_length=3, max_length=128) + email: str = Field(min_length=5, max_length=255) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + @field_validator("confirm_password") + @classmethod + def passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value + + +class LoginForm(FormModel): + username: str = Field(min_length=1, max_length=255) + password: str = Field(min_length=1, max_length=256) + + +class PasswordResetRequestForm(FormModel): + email: str = Field(min_length=5, max_length=255) + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + +class PasswordResetForm(FormModel): + token: str = Field(min_length=1) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("confirm_password") + @classmethod + def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value diff --git a/services/repositories.py b/services/repositories.py index 5556638..1f82691 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -8,7 +8,16 @@ from sqlalchemy import select, func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload, selectinload -from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter +from models import ( + FinancialInput, + Project, + Role, + Scenario, + ScenarioStatus, + SimulationParameter, + User, + UserRole, +) from services.exceptions import EntityConflictError, EntityNotFoundError @@ -211,3 +220,220 @@ class SimulationParameterRepository: raise EntityNotFoundError( f"Simulation parameter {parameter_id} not found") self.session.delete(entity) + + +class RoleRepository: + """Persistence operations for Role entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self) -> Sequence[Role]: + stmt = select(Role).order_by(Role.name) + return self.session.execute(stmt).scalars().all() + + def get(self, role_id: int) -> Role: + stmt = select(Role).where(Role.id == role_id) + role = self.session.execute(stmt).scalar_one_or_none() + if role is None: + raise EntityNotFoundError(f"Role {role_id} not found") + return role + + def get_by_name(self, name: str) -> Role | None: + stmt = select(Role).where(Role.name == name) + return self.session.execute(stmt).scalar_one_or_none() + + def create(self, role: Role) -> Role: + self.session.add(role) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Role violates uniqueness constraints") from exc + return role + + +class UserRepository: + """Persistence operations for User entities and their role assignments.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self, *, with_roles: bool = False) -> Sequence[User]: + stmt = select(User).order_by(User.created_at) + if with_roles: + stmt = stmt.options(selectinload(User.roles)) + return self.session.execute(stmt).scalars().all() + + def _apply_role_option(self, stmt, with_roles: bool): + if with_roles: + stmt = stmt.options( + joinedload(User.role_assignments).joinedload(UserRole.role), + selectinload(User.roles), + ) + return stmt + + def get(self, user_id: int, *, with_roles: bool = False) -> User: + stmt = select(User).where(User.id == user_id).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + user = result.scalar_one_or_none() + if user is None: + raise EntityNotFoundError(f"User {user_id} not found") + return user + + def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.email == email).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.username == + username).execution_options(populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def create(self, user: User) -> User: + self.session.add(user) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "User violates uniqueness constraints") from exc + return user + + def assign_role( + self, + *, + user_id: int, + role_id: int, + granted_by: int | None = None, + ) -> UserRole: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment: + return assignment + + assignment = UserRole( + user_id=user_id, + role_id=role_id, + granted_by=granted_by, + ) + self.session.add(assignment) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Assignment violates constraints") from exc + return assignment + + def revoke_role(self, *, user_id: int, role_id: int) -> None: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment is None: + raise EntityNotFoundError( + f"Role {role_id} not assigned to user {user_id}") + self.session.delete(assignment) + self.session.flush() + + +DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = ( + { + "name": "admin", + "display_name": "Administrator", + "description": "Full platform access with user management rights.", + }, + { + "name": "project_manager", + "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data.", + }, + { + "name": "analyst", + "display_name": "Analyst", + "description": "Review dashboards and scenario outputs.", + }, + { + "name": "viewer", + "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports.", + }, +) + + +def ensure_default_roles(role_repo: RoleRepository) -> list[Role]: + """Ensure standard roles exist, creating missing ones. + + Returns all current role records in creation order. + """ + + roles: list[Role] = [] + for definition in DEFAULT_ROLE_DEFINITIONS: + existing = role_repo.get_by_name(definition["name"]) + if existing: + roles.append(existing) + continue + role = Role(**definition) + roles.append(role_repo.create(role)) + return roles + + +def ensure_admin_user( + user_repo: UserRepository, + role_repo: RoleRepository, + *, + email: str, + username: str, + password: str, +) -> User: + """Ensure an administrator user exists and holds the admin role.""" + + user = user_repo.get_by_email(email, with_roles=True) + if user is None: + user = User( + email=email, + username=username, + password_hash=User.hash_password(password), + is_active=True, + is_superuser=True, + ) + user_repo.create(user) + else: + if not user.is_active: + user.is_active = True + if not user.is_superuser: + user.is_superuser = True + user_repo.session.flush() + + admin_role = role_repo.get_by_name("admin") + if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called + admin_role = role_repo.create( + Role( + name="admin", + display_name="Administrator", + description="Full platform access with user management rights.", + ) + ) + + user_repo.assign_role( + user_id=user.id, + role_id=admin_role.id, + granted_by=user.id, + ) + return user diff --git a/services/security.py b/services/security.py new file mode 100644 index 0000000..02a078d --- /dev/null +++ b/services/security.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Iterable, Literal, Type + +from jose import ExpiredSignatureError, JWTError, jwt +from passlib.context import CryptContext + +try: # pragma: no cover - compatibility shim for passlib/argon2 warning + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: # pragma: no cover - executed only when metadata lookup fails + pass + +from pydantic import BaseModel, Field, ValidationError + +password_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +def hash_password(password: str) -> str: + """Derive a secure hash for a plain-text password.""" + + return password_context.hash(password) + + +def verify_password(candidate: str, hashed: str) -> bool: + """Verify that a candidate password matches a stored hash.""" + + try: + return password_context.verify(candidate, hashed) + except ValueError: + # Raised when the stored hash is malformed or uses an unknown scheme. + return False + + +class TokenError(Exception): + """Base class for token encoding/decoding issues.""" + + +class TokenDecodeError(TokenError): + """Raised when a token cannot be decoded or validated.""" + + +class TokenExpiredError(TokenError): + """Raised when a token has expired.""" + + +class TokenTypeMismatchError(TokenError): + """Raised when a token type does not match the expected flavour.""" + + +TokenKind = Literal["access", "refresh"] + + +class TokenPayload(BaseModel): + """Shared fields for CalMiner JWT payloads.""" + + sub: str + exp: int + type: TokenKind + scopes: list[str] = Field(default_factory=list) + + @property + def expires_at(self) -> datetime: + return datetime.fromtimestamp(self.exp, tz=timezone.utc) + + +@dataclass(slots=True) +class JWTSettings: + """Runtime configuration for JWT encoding and validation.""" + + secret_key: str + algorithm: str = "HS256" + access_token_ttl: timedelta = field( + default_factory=lambda: timedelta(minutes=15)) + refresh_token_ttl: timedelta = field( + default_factory=lambda: timedelta(days=7)) + + +def create_access_token( + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, +) -> str: + """Issue a signed access token for the provided subject.""" + + lifetime = expires_delta or settings.access_token_ttl + return _create_token( + subject=subject, + token_type="access", + settings=settings, + lifetime=lifetime, + scopes=scopes, + extra_claims=extra_claims, + ) + + +def create_refresh_token( + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, +) -> str: + """Issue a signed refresh token for the provided subject.""" + + lifetime = expires_delta or settings.refresh_token_ttl + return _create_token( + subject=subject, + token_type="refresh", + settings=settings, + lifetime=lifetime, + scopes=scopes, + extra_claims=extra_claims, + ) + + +def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode an access token.""" + + return _decode_token(token, settings, expected_type="access") + + +def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode a refresh token.""" + + return _decode_token(token, settings, expected_type="refresh") + + +def _create_token( + *, + subject: str, + token_type: TokenKind, + settings: JWTSettings, + lifetime: timedelta, + scopes: Iterable[str] | None, + extra_claims: Dict[str, Any] | None, +) -> str: + now = datetime.now(timezone.utc) + expire = now + lifetime + payload: Dict[str, Any] = { + "sub": subject, + "type": token_type, + "iat": int(now.timestamp()), + "exp": int(expire.timestamp()), + } + if scopes: + payload["scopes"] = list(scopes) + if extra_claims: + payload.update(extra_claims) + + return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) + + +def _decode_token( + token: str, + settings: JWTSettings, + expected_type: TokenKind, +) -> TokenPayload: + try: + decoded = jwt.decode( + token, + settings.secret_key, + algorithms=[settings.algorithm], + options={"verify_aud": False}, + ) + except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path + raise TokenExpiredError("Token has expired") from exc + except JWTError as exc: # pragma: no cover - jose error bubble + raise TokenDecodeError("Unable to decode token") from exc + + try: + payload = _model_validate(TokenPayload, decoded) + except ValidationError as exc: + raise TokenDecodeError("Token payload validation failed") from exc + + if payload.type != expected_type: + raise TokenTypeMismatchError( + f"Expected a {expected_type} token but received '{payload.type}'." + ) + + return payload + + +def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload: + if hasattr(model, "model_validate"): + return model.model_validate(data) # type: ignore[attr-defined] + return model.parse_obj(data) # type: ignore[attr-defined] + + +__all__ = [ + "JWTSettings", + "TokenDecodeError", + "TokenError", + "TokenExpiredError", + "TokenKind", + "TokenPayload", + "TokenTypeMismatchError", + "create_access_token", + "create_refresh_token", + "decode_access_token", + "decode_refresh_token", + "hash_password", + "password_context", + "verify_password", +] diff --git a/services/unit_of_work.py b/services/unit_of_work.py index 2e7b9e8..a51849e 100644 --- a/services/unit_of_work.py +++ b/services/unit_of_work.py @@ -6,12 +6,16 @@ from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal -from models import Scenario +from models import Role, Scenario from services.repositories import ( FinancialInputRepository, ProjectRepository, + RoleRepository, ScenarioRepository, SimulationParameterRepository, + UserRepository, + ensure_admin_user as ensure_admin_user_record, + ensure_default_roles, ) from services.scenario_validation import ScenarioComparisonValidator @@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self._session_factory = session_factory self.session: Session | None = None self._scenario_validator: ScenarioComparisonValidator | None = None + self.projects: ProjectRepository | None = None + self.scenarios: ScenarioRepository | None = None + self.financial_inputs: FinancialInputRepository | None = None + self.simulation_parameters: SimulationParameterRepository | None = None + self.users: UserRepository | None = None + self.roles: RoleRepository | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() @@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository( self.session) + self.users = UserRepository(self.session) + self.roles = RoleRepository(self.session) self._scenario_validator = ScenarioComparisonValidator() return self @@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): self.session.rollback() self.session.close() self._scenario_validator = None + self.projects = None + self.scenarios = None + self.financial_inputs = None + self.simulation_parameters = None + self.users = None + self.roles = None def flush(self) -> None: if not self.session: @@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): def validate_scenarios_for_comparison( self, scenario_ids: Sequence[int] ) -> list[Scenario]: - if not self.session or not self._scenario_validator: + if not self.session or not self._scenario_validator or not self.scenarios: raise RuntimeError("UnitOfWork session is not initialised") scenarios = [self.scenarios.get(scenario_id) @@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]): if not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") self._scenario_validator.validate(scenarios) + + def ensure_default_roles(self) -> list[Role]: + if not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + return ensure_default_roles(self.roles) + + def ensure_admin_user( + self, + *, + email: str, + username: str, + password: str, + ) -> None: + if not self.users or not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + ensure_default_roles(self.roles) + ensure_admin_user_record( + self.users, + self.roles, + email=email, + username=username, + password=password, + ) diff --git a/static/img/logo_big.png b/static/img/logo_big.png new file mode 100644 index 0000000..98f1a78 Binary files /dev/null and b/static/img/logo_big.png differ diff --git a/templates/forgot_password.html b/templates/forgot_password.html index 4d21fd3..2d257cb 100644 --- a/templates/forgot_password.html +++ b/templates/forgot_password.html @@ -1,17 +1,25 @@ -{% extends "base.html" %} - -{% block title %}Forgot Password{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {% +block content %}
-

Forgot Password

-
-
- - -
- -
-

Remember your password? Login here

+

Forgot Password

+ {% if errors %} +
+ +
+ {% endif %} {% if message %} +
{{ message }}
+ {% endif %} +
+
+ + +
+ +
+

Remember your password? Login here

{% endblock %} diff --git a/templates/login.html b/templates/login.html index 6c2eb00..7279b8f 100644 --- a/templates/login.html +++ b/templates/login.html @@ -1,22 +1,34 @@ -{% extends "base.html" %} - -{% block title %}Login{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Login{% endblock %} {% block content +%}
-

Login

-
-
- - -
-
- - -
- -
-

Don't have an account? Register here

-

Forgot password?

+

Login

+ {% if errors %} +
+ +
+ {% endif %} +
+
+ + +
+
+ + +
+ +
+

Don't have an account? Register here

+

Forgot password?

{% endblock %} diff --git a/templates/partials/sidebar_nav.html b/templates/partials/sidebar_nav.html index 8fec8b3..1980125 100644 --- a/templates/partials/sidebar_nav.html +++ b/templates/partials/sidebar_nav.html @@ -1,80 +1,60 @@ {% set dashboard_href = request.url_for('dashboard.home') if request else '/' %} -{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %} -{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %} - -{% set nav_groups = [ - { - "label": "Workspace", - "links": [ - {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, - {"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, - {"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}, - ], - }, - { - "label": "Insights", - "links": [ - {"href": "/ui/simulations", "label": "Simulations"}, - {"href": "/ui/reporting", "label": "Reporting"}, - ], - }, - { - "label": "Configuration", - "links": [ - { - "href": "/ui/settings", - "label": "Settings", - "children": [ - {"href": "/theme-settings", "label": "Themes"}, - {"href": "/ui/currencies", "label": "Currency Management"}, - ], - }, - ], - }, -] %} +{% set projects_href = request.url_for('projects.project_list_page') if request +else '/projects/ui' %} {% set project_create_href = +request.url_for('projects.create_project_form') if request else +'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if +request else '/login' %} {% set register_href = +request.url_for('auth.register_form') if request else '/register' %} {% set +forgot_href = request.url_for('auth.password_reset_request_form') if request +else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace", +"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"}, +{"href": projects_href, "label": "Projects", "match_prefix": "/projects"}, +{"href": project_create_href, "label": "New Project", "match_prefix": +"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href": +"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label": +"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href": +"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings", +"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, +], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label": +"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register", +"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password", +"match_prefix": "/forgot-password"}, ], }, ] %} diff --git a/templates/register.html b/templates/register.html index 04a7b4e..d422405 100644 --- a/templates/register.html +++ b/templates/register.html @@ -1,25 +1,43 @@ -{% extends "base.html" %} - -{% block title %}Register{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Register{% endblock %} {% block +content %}
-

Register

-
-
- - -
-
- - -
-
- - -
- -
-

Already have an account? Login here

+

Register

+ {% if errors %} +
+ +
+ {% endif %} +
+
+ + +
+
+ + +
+
+ + +
+ +
+

Already have an account? Login here

{% endblock %} diff --git a/templates/reset_password.html b/templates/reset_password.html new file mode 100644 index 0000000..66c3007 --- /dev/null +++ b/templates/reset_password.html @@ -0,0 +1,36 @@ +{% extends "base.html" %} {% block title %}Reset Password{% endblock %} {% block +content %} +
+

Reset Password

+ {% if errors %} +
+ +
+ {% endif %} +
+ +
+ + +
+
+ + +
+ +
+

+ Remembered your password? + Return to login +

+
+{% endblock %} diff --git a/tests/conftest.py b/tests/conftest.py index 5ad55cc..6ec51e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from sqlalchemy.pool import StaticPool from config.database import Base from dependencies import get_unit_of_work +from routes.auth import router as auth_router from routes.dashboard import router as dashboard_router from routes.projects import router as projects_router from routes.scenarios import router as scenarios_router @@ -36,13 +37,15 @@ def engine() -> Iterator[Engine]: @pytest.fixture() def session_factory(engine: Engine) -> Iterator[sessionmaker]: - testing_session = sessionmaker(bind=engine, expire_on_commit=False, future=True) + testing_session = sessionmaker( + bind=engine, expire_on_commit=False, future=True) yield testing_session @pytest.fixture() def app(session_factory: sessionmaker) -> FastAPI: application = FastAPI() + application.include_router(auth_router) application.include_router(dashboard_router) application.include_router(projects_router) application.include_router(scenarios_router) diff --git a/tests/test_auth_repositories.py b/tests/test_auth_repositories.py new file mode 100644 index 0000000..9c656db --- /dev/null +++ b/tests/test_auth_repositories.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from collections.abc import Iterator + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from config.database import Base +from models import Role, User +from services.repositories import ( + RoleRepository, + UserRepository, + ensure_admin_user, + ensure_default_roles, +) +from services.unit_of_work import UnitOfWork + + +@pytest.fixture() +def engine() -> Iterator: + engine = create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all(bind=engine) + try: + yield engine + finally: + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture() +def session(engine) -> Iterator[Session]: + TestingSession = sessionmaker( + bind=engine, expire_on_commit=False, future=True) + db = TestingSession() + try: + yield db + finally: + db.close() + + +def test_role_repository_create_and_lookup(session: Session) -> None: + repo = RoleRepository(session) + role = Role(name="custom", display_name="Custom", + description="Custom role") + repo.create(role) + + retrieved = repo.get(role.id) + assert retrieved.name == "custom" + assert repo.get_by_name("custom") is retrieved + assert repo.list()[0].name == "custom" + + +def test_user_repository_assign_and_revoke_role(session: Session) -> None: + role_repo = RoleRepository(session) + user_repo = UserRepository(session) + + analyst = role_repo.create( + Role(name="analyst", display_name="Analyst", description="Analyzes data") + ) + user = User( + email="user@example.com", + username="user", + password_hash=User.hash_password("secret"), + ) + user_repo.create(user) + + assignment = user_repo.assign_role( + user_id=user.id, role_id=analyst.id, granted_by=None) + assert assignment.role_id == analyst.id + + refreshed = user_repo.get(user.id, with_roles=True) + assert refreshed.roles[0].name == "analyst" + + user_repo.revoke_role(user_id=user.id, role_id=analyst.id) + refreshed = user_repo.get(user.id, with_roles=True) + assert refreshed.roles == [] + + +def test_default_role_and_admin_helpers(session: Session) -> None: + role_repo = RoleRepository(session) + user_repo = UserRepository(session) + + roles = ensure_default_roles(role_repo) + assert {role.name for role in roles} == { + "admin", "project_manager", "analyst", "viewer"} + + ensure_admin_user( + user_repo, + role_repo, + email="admin@example.com", + username="admin", + password="SecurePass1!", + ) + + admin = user_repo.get_by_email("admin@example.com", with_roles=True) + assert admin is not None + assert admin.is_superuser + assert {role.name for role in admin.roles} >= {"admin"} + + # Idempotent behaviour on subsequent calls + ensure_admin_user( + user_repo, + role_repo, + email="admin@example.com", + username="admin", + password="SecurePass1!", + ) + admin_again = user_repo.get_by_email("admin@example.com", with_roles=True) + assert admin_again is not None + assert len(admin_again.roles) == len( + {role.name for role in admin_again.roles}) + + +def test_unit_of_work_exposes_auth_repositories(engine) -> None: + TestingSession = sessionmaker( + bind=engine, expire_on_commit=False, future=True) + + with UnitOfWork(session_factory=TestingSession) as uow: + assert uow.users is not None + assert uow.roles is not None + + roles = uow.ensure_default_roles() + assert any(role.name == "admin" for role in roles) + + uow.ensure_admin_user( + email="uow-admin@example.com", + username="uow-admin", + password="AnotherSecret1!", + ) + + admin = uow.users.get_by_email( + "uow-admin@example.com", with_roles=True) + assert admin is not None + assert admin.is_superuser + assert any(role.name == "admin" for role in admin.roles) diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py new file mode 100644 index 0000000..08b4956 --- /dev/null +++ b/tests/test_auth_routes.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from collections.abc import Iterator +from urllib.parse import parse_qs, urlparse + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from models import Role, User, UserRole +from services.security import hash_password + + +@pytest.fixture() +def db_session(session_factory: sessionmaker) -> Iterator[Session]: + session = session_factory() + try: + yield session + finally: + session.close() + + +def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None: + stmt = select(User) + if email is not None: + stmt = stmt.where(User.email == email) + if username is not None: + stmt = stmt.where(User.username == username) + return session.execute(stmt).scalar_one_or_none() + + +class TestRegistrationFlow: + def test_register_creates_user_and_assigns_role( + self, + client: TestClient, + db_session: Session, + ) -> None: + response = client.post( + "/register", + data={ + "username": "newuser", + "email": "newuser@example.com", + "password": "ComplexP@ss1", + "confirm_password": "ComplexP@ss1", + }, + follow_redirects=False, + ) + + assert response.status_code == 303 + location = response.headers.get("location") + assert location + parsed = urlparse(location) + assert parsed.path == "/login" + assert parse_qs(parsed.query).get("registered") == ["1"] + + created = _get_user(db_session, email="newuser@example.com") + assert created is not None + assert created.is_active + + role_stmt = select(Role).where(Role.name == "viewer") + viewer_role = db_session.execute(role_stmt).scalar_one_or_none() + assert viewer_role is not None + + assignments = db_session.execute( + select(UserRole).where( + UserRole.user_id == created.id, + UserRole.role_id == viewer_role.id, + ) + ).scalars().all() + assert len(assignments) == 1 + + def test_register_duplicate_email_shows_error( + self, + client: TestClient, + ) -> None: + first = client.post( + "/register", + data={ + "username": "existing", + "email": "existing@example.com", + "password": "ComplexP@ss1", + "confirm_password": "ComplexP@ss1", + }, + follow_redirects=False, + ) + assert first.status_code == 303 + + second = client.post( + "/register", + data={ + "username": "existing", + "email": "existing@example.com", + "password": "ComplexP@ss1", + "confirm_password": "ComplexP@ss1", + }, + follow_redirects=False, + ) + + assert second.status_code == 400 + assert "Email is already registered" in second.text + + +class TestLoginFlow: + def test_login_sets_tokens_and_updates_last_login( + self, + client: TestClient, + db_session: Session, + ) -> None: + password = "MySecur3Pass!" + user = User( + email="login@example.com", + username="loginuser", + password_hash=hash_password(password), + is_active=True, + ) + db_session.add(user) + db_session.commit() + + response = client.post( + "/login", + data={"username": "loginuser", "password": password}, + follow_redirects=False, + ) + + assert response.status_code == 303 + assert response.headers.get("location") == "http://testserver/" + set_cookie_header = response.headers.get("set-cookie", "") + assert "calminer_access_token=" in set_cookie_header + assert "calminer_refresh_token=" in set_cookie_header + + updated = _get_user(db_session, username="loginuser") + assert updated is not None and updated.last_login_at is not None + + def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None: + response = client.post( + "/login", + data={"username": "unknown", "password": "bad"}, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "Invalid username or password" in response.text + + +class TestPasswordResetFlow: + def test_password_reset_round_trip( + self, + client: TestClient, + db_session: Session, + ) -> None: + user = User( + email="reset@example.com", + username="resetuser", + password_hash=hash_password("OldP@ssword1"), + is_active=True, + ) + db_session.add(user) + db_session.commit() + + request_response = client.post( + "/forgot-password", + data={"email": "reset@example.com"}, + follow_redirects=False, + ) + + assert request_response.status_code == 303 + reset_location = request_response.headers.get("location") + assert reset_location is not None + parsed = urlparse(reset_location) + assert parsed.path == "/reset-password" + token = parse_qs(parsed.query).get("token", [None])[0] + assert token + + form_response = client.get(reset_location) + assert form_response.status_code == 200 + + submit_response = client.post( + "/reset-password", + data={ + "token": token, + "password": "N3wP@ssword!", + "confirm_password": "N3wP@ssword!", + }, + follow_redirects=False, + ) + + assert submit_response.status_code == 303 + assert "reset=1" in (submit_response.headers.get("location") or "") + + db_session.refresh(user) + assert user.verify_password("N3wP@ssword!") + + def test_password_reset_with_unknown_email_shows_generic_message( + self, + client: TestClient, + ) -> None: + response = client.post( + "/forgot-password", + data={"email": "doesnotexist@example.com"}, + follow_redirects=False, + ) + + assert response.status_code == 200 + assert "If an account exists" in response.text + + def test_password_reset_mismatched_passwords_return_error( + self, + client: TestClient, + db_session: Session, + ) -> None: + user = User( + email="mismatch@example.com", + username="mismatch", + password_hash=hash_password("OldP@ssword1"), + is_active=True, + ) + db_session.add(user) + db_session.commit() + + request_response = client.post( + "/forgot-password", + data={"email": "mismatch@example.com"}, + follow_redirects=False, + ) + token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0] + + submit_response = client.post( + "/reset-password", + data={ + "token": token, + "password": "NewPass123!", + "confirm_password": "Different123!", + }, + follow_redirects=False, + ) + + assert submit_response.status_code == 400 + assert "Passwords do not match" in submit_response.text \ No newline at end of file diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..607716a --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from datetime import timedelta + +import pytest + +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + decode_refresh_token, + hash_password, + verify_password, +) + + +def test_hash_password_round_trip() -> None: + hashed = hash_password("super-secret") + + assert hashed != "super-secret" + assert verify_password("super-secret", hashed) + assert not verify_password("incorrect", hashed) + + +def test_verify_password_handles_malformed_hash() -> None: + assert not verify_password("secret", "not-a-valid-hash") + + +def test_access_token_roundtrip() -> None: + settings = JWTSettings(secret_key="unit-test-secret") + + token = create_access_token( + "user-id-123", + settings, + scopes=("read", "write"), + extra_claims={"custom": "value"}, + ) + + payload = decode_access_token(token, settings) + + assert payload.sub == "user-id-123" + assert payload.type == "access" + assert payload.scopes == ["read", "write"] + + +def test_refresh_token_type_mismatch() -> None: + settings = JWTSettings(secret_key="unit-test-secret") + token = create_refresh_token("user-id-456", settings) + + with pytest.raises(TokenTypeMismatchError): + decode_access_token(token, settings) + + +def test_decode_expired_token() -> None: + settings = JWTSettings(secret_key="unit-test-secret") + expired_token = create_access_token( + "user-id-789", + settings, + expires_delta=timedelta(seconds=-5), + ) + + with pytest.raises(TokenExpiredError): + decode_access_token(expired_token, settings) + + +def test_decode_tampered_token() -> None: + settings = JWTSettings(secret_key="unit-test-secret") + token = create_access_token("user-id-321", settings) + tampered = token[:-1] + ("a" if token[-1] != "a" else "b") + + with pytest.raises(TokenDecodeError): + decode_access_token(tampered, settings)