feat: Implement user and role management with repositories
- Added RoleRepository and UserRepository for managing roles and users. - Implemented methods for creating, retrieving, and assigning roles to users. - Introduced functions to ensure default roles and an admin user exist in the system. - Updated UnitOfWork to include user and role repositories. - Created new security module for password hashing and JWT token management. - Added tests for authentication flows, including registration, login, and password reset. - Enhanced HTML templates for user registration, login, and password management with error handling. - Added a logo image to the static assets.
This commit is contained in:
210
alembic/versions/20251109_02_add_auth_tables.py
Normal file
210
alembic/versions/20251109_02_add_auth_tables.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Add authentication and RBAC tables"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from sqlalchemy.sql import column, table
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "20251109_02"
|
||||||
|
down_revision = "20251109_01"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True),
|
||||||
|
sa.Column("email", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("username", sa.String(length=128), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"is_active",
|
||||||
|
sa.Boolean(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.true(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"is_superuser",
|
||||||
|
sa.Boolean(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.false(),
|
||||||
|
),
|
||||||
|
sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.UniqueConstraint("email", name="uq_users_email"),
|
||||||
|
sa.UniqueConstraint("username", name="uq_users_username"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_users_active_superuser",
|
||||||
|
"users",
|
||||||
|
["is_active", "is_superuser"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"roles",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True),
|
||||||
|
sa.Column("name", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("display_name", sa.String(length=128), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.UniqueConstraint("name", name="uq_roles_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"user_roles",
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("role_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"granted_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column("granted_by", sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["role_id"],
|
||||||
|
["roles.id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["granted_by"],
|
||||||
|
["users.id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "role_id"),
|
||||||
|
sa.UniqueConstraint("user_id", "role_id",
|
||||||
|
name="uq_user_roles_user_role"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_roles_role_id",
|
||||||
|
"user_roles",
|
||||||
|
["role_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Seed default roles
|
||||||
|
roles_table = table(
|
||||||
|
"roles",
|
||||||
|
column("id", sa.Integer()),
|
||||||
|
column("name", sa.String()),
|
||||||
|
column("display_name", sa.String()),
|
||||||
|
column("description", sa.Text()),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.bulk_insert(
|
||||||
|
roles_table,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "admin",
|
||||||
|
"display_name": "Administrator",
|
||||||
|
"description": "Full platform access with user management rights.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"name": "project_manager",
|
||||||
|
"display_name": "Project Manager",
|
||||||
|
"description": "Manage projects, scenarios, and associated data.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"name": "analyst",
|
||||||
|
"display_name": "Analyst",
|
||||||
|
"description": "Review dashboards and scenario outputs.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4,
|
||||||
|
"name": "viewer",
|
||||||
|
"display_name": "Viewer",
|
||||||
|
"description": "Read-only access to assigned projects and reports.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_password_hash = password_context.hash("ChangeMe123!")
|
||||||
|
|
||||||
|
users_table = table(
|
||||||
|
"users",
|
||||||
|
column("id", sa.Integer()),
|
||||||
|
column("email", sa.String()),
|
||||||
|
column("username", sa.String()),
|
||||||
|
column("password_hash", sa.String()),
|
||||||
|
column("is_active", sa.Boolean()),
|
||||||
|
column("is_superuser", sa.Boolean()),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.bulk_insert(
|
||||||
|
users_table,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"email": "admin@calminer.local",
|
||||||
|
"username": "admin",
|
||||||
|
"password_hash": admin_password_hash,
|
||||||
|
"is_active": True,
|
||||||
|
"is_superuser": True,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
user_roles_table = table(
|
||||||
|
"user_roles",
|
||||||
|
column("user_id", sa.Integer()),
|
||||||
|
column("role_id", sa.Integer()),
|
||||||
|
column("granted_by", sa.Integer()),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.bulk_insert(
|
||||||
|
user_roles_table,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"user_id": 1,
|
||||||
|
"role_id": 1,
|
||||||
|
"granted_by": 1,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_user_roles_role_id", table_name="user_roles")
|
||||||
|
op.drop_table("user_roles")
|
||||||
|
|
||||||
|
op.drop_table("roles")
|
||||||
|
|
||||||
|
op.drop_index("ix_users_active_superuser", table_name="users")
|
||||||
|
op.drop_table("users")
|
||||||
BIN
alembic_test.db
Normal file
BIN
alembic_test.db
Normal file
Binary file not shown.
@@ -16,3 +16,5 @@
|
|||||||
- Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints.
|
- Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints.
|
||||||
- Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views.
|
- Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views.
|
||||||
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
||||||
|
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
|
||||||
|
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
|
||||||
|
|||||||
60
config/settings.py
Normal file
60
config/settings.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import timedelta
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from services.security import JWTSettings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class Settings:
|
||||||
|
"""Application configuration sourced from environment variables."""
|
||||||
|
|
||||||
|
jwt_secret_key: str = "change-me"
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
jwt_access_token_minutes: int = 15
|
||||||
|
jwt_refresh_token_days: int = 7
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_environment(cls) -> "Settings":
|
||||||
|
"""Construct settings from environment variables."""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"),
|
||||||
|
jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"),
|
||||||
|
jwt_access_token_minutes=cls._int_from_env(
|
||||||
|
"CALMINER_JWT_ACCESS_MINUTES", 15
|
||||||
|
),
|
||||||
|
jwt_refresh_token_days=cls._int_from_env(
|
||||||
|
"CALMINER_JWT_REFRESH_DAYS", 7
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _int_from_env(name: str, default: int) -> int:
|
||||||
|
raw_value = os.getenv(name)
|
||||||
|
if raw_value is None:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(raw_value)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def jwt_settings(self) -> JWTSettings:
|
||||||
|
"""Build runtime JWT settings compatible with token helpers."""
|
||||||
|
|
||||||
|
return JWTSettings(
|
||||||
|
secret_key=self.jwt_secret_key,
|
||||||
|
algorithm=self.jwt_algorithm,
|
||||||
|
access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes),
|
||||||
|
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Return cached application settings."""
|
||||||
|
|
||||||
|
return Settings.from_environment()
|
||||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from config.settings import Settings, get_settings
|
||||||
|
from services.security import JWTSettings
|
||||||
from services.unit_of_work import UnitOfWork
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
@@ -10,3 +12,15 @@ def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
|
|||||||
|
|
||||||
with UnitOfWork() as uow:
|
with UnitOfWork() as uow:
|
||||||
yield uow
|
yield uow
|
||||||
|
|
||||||
|
|
||||||
|
def get_application_settings() -> Settings:
|
||||||
|
"""Provide cached application settings instance."""
|
||||||
|
|
||||||
|
return get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
def get_jwt_settings() -> JWTSettings:
|
||||||
|
"""Provide JWT runtime configuration derived from settings."""
|
||||||
|
|
||||||
|
return get_settings().jwt_settings()
|
||||||
|
|||||||
2
main.py
2
main.py
@@ -10,6 +10,7 @@ from models import (
|
|||||||
Scenario,
|
Scenario,
|
||||||
SimulationParameter,
|
SimulationParameter,
|
||||||
)
|
)
|
||||||
|
from routes.auth import router as auth_router
|
||||||
from routes.dashboard import router as dashboard_router
|
from routes.dashboard import router as dashboard_router
|
||||||
from routes.projects import router as projects_router
|
from routes.projects import router as projects_router
|
||||||
from routes.scenarios import router as scenarios_router
|
from routes.scenarios import router as scenarios_router
|
||||||
@@ -33,6 +34,7 @@ async def health() -> dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
app.include_router(dashboard_router)
|
app.include_router(dashboard_router)
|
||||||
|
app.include_router(auth_router)
|
||||||
app.include_router(projects_router)
|
app.include_router(projects_router)
|
||||||
app.include_router(scenarios_router)
|
app.include_router(scenarios_router)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,14 @@ from datetime import datetime
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
try: # pragma: no cover - defensive compatibility shim
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import argon2 # type: ignore
|
||||||
|
|
||||||
|
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
|
|||||||
473
routes/auth.py
Normal file
473
routes/auth.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||||
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from starlette.datastructures import FormData
|
||||||
|
|
||||||
|
from dependencies import get_jwt_settings, get_unit_of_work
|
||||||
|
from models import Role, User
|
||||||
|
from schemas.auth import (
|
||||||
|
LoginForm,
|
||||||
|
PasswordResetForm,
|
||||||
|
PasswordResetRequestForm,
|
||||||
|
RegistrationForm,
|
||||||
|
)
|
||||||
|
from services.exceptions import EntityConflictError
|
||||||
|
from services.security import (
|
||||||
|
JWTSettings,
|
||||||
|
TokenDecodeError,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenTypeMismatchError,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_access_token,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
from services.repositories import RoleRepository, UserRepository
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Authentication"])
|
||||||
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
_PASSWORD_RESET_SCOPE = "password-reset"
|
||||||
|
_AUTH_SCOPE = "auth"
|
||||||
|
|
||||||
|
|
||||||
|
def _template(
|
||||||
|
request: Request,
|
||||||
|
template_name: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
*,
|
||||||
|
status_code: int = status.HTTP_200_OK,
|
||||||
|
) -> HTMLResponse:
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
template_name,
|
||||||
|
context,
|
||||||
|
status_code=status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validation_errors(exc: ValidationError) -> list[str]:
|
||||||
|
return [error.get("msg", "Invalid input.") for error in exc.errors()]
|
||||||
|
|
||||||
|
|
||||||
|
def _scopes(include: Iterable[str]) -> list[str]:
|
||||||
|
return list(include)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalise_form_data(form_data: FormData) -> dict[str, str]:
|
||||||
|
normalised: dict[str, str] = {}
|
||||||
|
for key, value in form_data.multi_items():
|
||||||
|
if isinstance(value, UploadFile):
|
||||||
|
str_value = value.filename or ""
|
||||||
|
else:
|
||||||
|
str_value = str(value)
|
||||||
|
normalised[key] = str_value
|
||||||
|
return normalised
|
||||||
|
|
||||||
|
|
||||||
|
def _require_users_repo(uow: UnitOfWork) -> UserRepository:
|
||||||
|
if not uow.users:
|
||||||
|
raise RuntimeError("User repository is not initialised")
|
||||||
|
return uow.users
|
||||||
|
|
||||||
|
|
||||||
|
def _require_roles_repo(uow: UnitOfWork) -> RoleRepository:
|
||||||
|
if not uow.roles:
|
||||||
|
raise RuntimeError("Role repository is not initialised")
|
||||||
|
return uow.roles
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form")
|
||||||
|
def login_form(request: Request) -> HTMLResponse:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.login_submit"),
|
||||||
|
"errors": [],
|
||||||
|
"username": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", include_in_schema=False, name="auth.login_submit")
|
||||||
|
async def login_submit(
|
||||||
|
request: Request,
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||||
|
):
|
||||||
|
form_data = _normalise_form_data(await request.form())
|
||||||
|
try:
|
||||||
|
form = LoginForm(**form_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.login_submit"),
|
||||||
|
"errors": _validation_errors(exc),
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
identifier = form.username
|
||||||
|
users_repo = _require_users_repo(uow)
|
||||||
|
user = _lookup_user(users_repo, identifier)
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
if not user or not verify_password(form.password, user.password_hash):
|
||||||
|
errors.append("Invalid username or password.")
|
||||||
|
elif not user.is_active:
|
||||||
|
errors.append("Account is inactive. Contact an administrator.")
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.login_submit"),
|
||||||
|
"errors": errors,
|
||||||
|
"username": identifier,
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user is not None # mypy hint - guarded above
|
||||||
|
user.last_login_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
access_token = create_access_token(
|
||||||
|
str(user.id),
|
||||||
|
jwt_settings,
|
||||||
|
scopes=_scopes((_AUTH_SCOPE,)),
|
||||||
|
)
|
||||||
|
refresh_token = create_refresh_token(
|
||||||
|
str(user.id),
|
||||||
|
jwt_settings,
|
||||||
|
scopes=_scopes((_AUTH_SCOPE,)),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = RedirectResponse(
|
||||||
|
request.url_for("dashboard.home"),
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
_set_auth_cookies(response, access_token, refresh_token, jwt_settings)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
|
||||||
|
if "@" in identifier:
|
||||||
|
return users_repo.get_by_email(identifier.lower(), with_roles=True)
|
||||||
|
return users_repo.get_by_username(identifier, with_roles=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_auth_cookies(
|
||||||
|
response: RedirectResponse,
|
||||||
|
access_token: str,
|
||||||
|
refresh_token: str,
|
||||||
|
jwt_settings: JWTSettings,
|
||||||
|
) -> None:
|
||||||
|
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
||||||
|
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
||||||
|
response.set_cookie(
|
||||||
|
"calminer_access_token",
|
||||||
|
access_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=max(access_ttl, 0) or None,
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
"calminer_refresh_token",
|
||||||
|
refresh_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=max(refresh_ttl, 0) or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
|
||||||
|
def register_form(request: Request) -> HTMLResponse:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"register.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.register_submit"),
|
||||||
|
"errors": [],
|
||||||
|
"form_data": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", include_in_schema=False, name="auth.register_submit")
|
||||||
|
async def register_submit(
|
||||||
|
request: Request,
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
):
|
||||||
|
form_data = _normalise_form_data(await request.form())
|
||||||
|
try:
|
||||||
|
form = RegistrationForm(**form_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
return _registration_error_response(request, _validation_errors(exc))
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
users_repo = _require_users_repo(uow)
|
||||||
|
roles_repo = _require_roles_repo(uow)
|
||||||
|
uow.ensure_default_roles()
|
||||||
|
|
||||||
|
if users_repo.get_by_email(form.email):
|
||||||
|
errors.append("Email is already registered.")
|
||||||
|
if users_repo.get_by_username(form.username):
|
||||||
|
errors.append("Username is already taken.")
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
return _registration_error_response(request, errors, form)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
email=form.email,
|
||||||
|
username=form.username,
|
||||||
|
password_hash=hash_password(form.password),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
created = users_repo.create(user)
|
||||||
|
except EntityConflictError:
|
||||||
|
return _registration_error_response(
|
||||||
|
request,
|
||||||
|
["An account with this username or email already exists."],
|
||||||
|
form,
|
||||||
|
)
|
||||||
|
|
||||||
|
viewer_role = _ensure_viewer_role(roles_repo)
|
||||||
|
if viewer_role is not None:
|
||||||
|
users_repo.assign_role(
|
||||||
|
user_id=created.id,
|
||||||
|
role_id=viewer_role.id,
|
||||||
|
granted_by=created.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
redirect_url = request.url_for(
|
||||||
|
"auth.login_form").include_query_params(registered="1")
|
||||||
|
return RedirectResponse(
|
||||||
|
redirect_url,
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _registration_error_response(
|
||||||
|
request: Request,
|
||||||
|
errors: list[str],
|
||||||
|
form: RegistrationForm | None = None,
|
||||||
|
) -> HTMLResponse:
|
||||||
|
context = {
|
||||||
|
"form_action": request.url_for("auth.register_submit"),
|
||||||
|
"errors": errors,
|
||||||
|
"form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None,
|
||||||
|
}
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"register.html",
|
||||||
|
context,
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None:
|
||||||
|
viewer = roles_repo.get_by_name("viewer")
|
||||||
|
if viewer:
|
||||||
|
return viewer
|
||||||
|
return roles_repo.get_by_name("viewer")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/forgot-password",
|
||||||
|
response_class=HTMLResponse,
|
||||||
|
include_in_schema=False,
|
||||||
|
name="auth.password_reset_request_form",
|
||||||
|
)
|
||||||
|
def password_reset_request_form(request: Request) -> HTMLResponse:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"forgot_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||||
|
"errors": [],
|
||||||
|
"message": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/forgot-password",
|
||||||
|
include_in_schema=False,
|
||||||
|
name="auth.password_reset_request_submit",
|
||||||
|
)
|
||||||
|
async def password_reset_request_submit(
|
||||||
|
request: Request,
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||||
|
):
|
||||||
|
form_data = _normalise_form_data(await request.form())
|
||||||
|
try:
|
||||||
|
form = PasswordResetRequestForm(**form_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"forgot_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||||
|
"errors": _validation_errors(exc),
|
||||||
|
"message": None,
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
users_repo = _require_users_repo(uow)
|
||||||
|
user = users_repo.get_by_email(form.email)
|
||||||
|
if not user:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"forgot_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||||
|
"errors": [],
|
||||||
|
"message": "If an account exists, a reset link has been sent.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_access_token(
|
||||||
|
str(user.id),
|
||||||
|
jwt_settings,
|
||||||
|
scopes=_scopes((_PASSWORD_RESET_SCOPE,)),
|
||||||
|
expires_delta=timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
reset_url = request.url_for(
|
||||||
|
"auth.password_reset_form").include_query_params(token=token)
|
||||||
|
return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/reset-password",
|
||||||
|
response_class=HTMLResponse,
|
||||||
|
include_in_schema=False,
|
||||||
|
name="auth.password_reset_form",
|
||||||
|
)
|
||||||
|
def password_reset_form(
|
||||||
|
request: Request,
|
||||||
|
token: str | None = None,
|
||||||
|
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||||
|
) -> HTMLResponse:
|
||||||
|
errors: list[str] = []
|
||||||
|
if not token:
|
||||||
|
errors.append("Missing password reset token.")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
payload = decode_access_token(token, jwt_settings)
|
||||||
|
if _PASSWORD_RESET_SCOPE not in payload.scopes:
|
||||||
|
errors.append("Invalid token scope.")
|
||||||
|
except TokenExpiredError:
|
||||||
|
errors.append(
|
||||||
|
"Token has expired. Please request a new password reset.")
|
||||||
|
except (TokenDecodeError, TokenTypeMismatchError):
|
||||||
|
errors.append("Invalid password reset token.")
|
||||||
|
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"reset_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_submit"),
|
||||||
|
"token": token,
|
||||||
|
"errors": errors,
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/reset-password",
|
||||||
|
include_in_schema=False,
|
||||||
|
name="auth.password_reset_submit",
|
||||||
|
)
|
||||||
|
async def password_reset_submit(
|
||||||
|
request: Request,
|
||||||
|
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||||
|
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||||
|
):
|
||||||
|
form_data = _normalise_form_data(await request.form())
|
||||||
|
try:
|
||||||
|
form = PasswordResetForm(**form_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"reset_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_submit"),
|
||||||
|
"token": form_data.get("token"),
|
||||||
|
"errors": _validation_errors(exc),
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = decode_access_token(form.token, jwt_settings)
|
||||||
|
except TokenExpiredError:
|
||||||
|
return _reset_error_response(
|
||||||
|
request,
|
||||||
|
form.token,
|
||||||
|
"Token has expired. Please request a new password reset.",
|
||||||
|
)
|
||||||
|
except (TokenDecodeError, TokenTypeMismatchError):
|
||||||
|
return _reset_error_response(
|
||||||
|
request,
|
||||||
|
form.token,
|
||||||
|
"Invalid password reset token.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PASSWORD_RESET_SCOPE not in payload.scopes:
|
||||||
|
return _reset_error_response(
|
||||||
|
request,
|
||||||
|
form.token,
|
||||||
|
"Invalid password reset token scope.",
|
||||||
|
)
|
||||||
|
|
||||||
|
users_repo = _require_users_repo(uow)
|
||||||
|
user_id = int(payload.sub)
|
||||||
|
user = users_repo.get(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||||
|
|
||||||
|
user.set_password(form.password)
|
||||||
|
if not user.is_active:
|
||||||
|
user.is_active = True
|
||||||
|
|
||||||
|
redirect_url = request.url_for(
|
||||||
|
"auth.login_form").include_query_params(reset="1")
|
||||||
|
return RedirectResponse(
|
||||||
|
redirect_url,
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse:
|
||||||
|
return _template(
|
||||||
|
request,
|
||||||
|
"reset_password.html",
|
||||||
|
{
|
||||||
|
"form_action": request.url_for("auth.password_reset_submit"),
|
||||||
|
"token": token,
|
||||||
|
"errors": [message],
|
||||||
|
},
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
67
schemas/auth.py
Normal file
67
schemas/auth.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class FormModel(BaseModel):
|
||||||
|
"""Base Pydantic model for HTML form submissions."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationForm(FormModel):
|
||||||
|
username: str = Field(min_length=3, max_length=128)
|
||||||
|
email: str = Field(min_length=5, max_length=255)
|
||||||
|
password: str = Field(min_length=8, max_length=256)
|
||||||
|
confirm_password: str
|
||||||
|
|
||||||
|
@field_validator("email")
|
||||||
|
@classmethod
|
||||||
|
def validate_email(cls, value: str) -> str:
|
||||||
|
if "@" not in value or value.startswith("@") or value.endswith("@"):
|
||||||
|
raise ValueError("Invalid email address.")
|
||||||
|
local, domain = value.split("@", 1)
|
||||||
|
if not local or "." not in domain:
|
||||||
|
raise ValueError("Invalid email address.")
|
||||||
|
return value.lower()
|
||||||
|
|
||||||
|
@field_validator("confirm_password")
|
||||||
|
@classmethod
|
||||||
|
def passwords_match(cls, value: str, info: ValidationInfo) -> str:
|
||||||
|
password = info.data.get("password")
|
||||||
|
if password != value:
|
||||||
|
raise ValueError("Passwords do not match.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class LoginForm(FormModel):
|
||||||
|
username: str = Field(min_length=1, max_length=255)
|
||||||
|
password: str = Field(min_length=1, max_length=256)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordResetRequestForm(FormModel):
|
||||||
|
email: str = Field(min_length=5, max_length=255)
|
||||||
|
|
||||||
|
@field_validator("email")
|
||||||
|
@classmethod
|
||||||
|
def validate_email(cls, value: str) -> str:
|
||||||
|
if "@" not in value or value.startswith("@") or value.endswith("@"):
|
||||||
|
raise ValueError("Invalid email address.")
|
||||||
|
local, domain = value.split("@", 1)
|
||||||
|
if not local or "." not in domain:
|
||||||
|
raise ValueError("Invalid email address.")
|
||||||
|
return value.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordResetForm(FormModel):
|
||||||
|
token: str = Field(min_length=1)
|
||||||
|
password: str = Field(min_length=8, max_length=256)
|
||||||
|
confirm_password: str
|
||||||
|
|
||||||
|
@field_validator("confirm_password")
|
||||||
|
@classmethod
|
||||||
|
def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str:
|
||||||
|
password = info.data.get("password")
|
||||||
|
if password != value:
|
||||||
|
raise ValueError("Passwords do not match.")
|
||||||
|
return value
|
||||||
@@ -8,7 +8,16 @@ from sqlalchemy import select, func
|
|||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||||
|
|
||||||
from models import FinancialInput, Project, Scenario, ScenarioStatus, SimulationParameter
|
from models import (
|
||||||
|
FinancialInput,
|
||||||
|
Project,
|
||||||
|
Role,
|
||||||
|
Scenario,
|
||||||
|
ScenarioStatus,
|
||||||
|
SimulationParameter,
|
||||||
|
User,
|
||||||
|
UserRole,
|
||||||
|
)
|
||||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||||
|
|
||||||
|
|
||||||
@@ -211,3 +220,220 @@ class SimulationParameterRepository:
|
|||||||
raise EntityNotFoundError(
|
raise EntityNotFoundError(
|
||||||
f"Simulation parameter {parameter_id} not found")
|
f"Simulation parameter {parameter_id} not found")
|
||||||
self.session.delete(entity)
|
self.session.delete(entity)
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRepository:
|
||||||
|
"""Persistence operations for Role entities."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def list(self) -> Sequence[Role]:
|
||||||
|
stmt = select(Role).order_by(Role.name)
|
||||||
|
return self.session.execute(stmt).scalars().all()
|
||||||
|
|
||||||
|
def get(self, role_id: int) -> Role:
|
||||||
|
stmt = select(Role).where(Role.id == role_id)
|
||||||
|
role = self.session.execute(stmt).scalar_one_or_none()
|
||||||
|
if role is None:
|
||||||
|
raise EntityNotFoundError(f"Role {role_id} not found")
|
||||||
|
return role
|
||||||
|
|
||||||
|
def get_by_name(self, name: str) -> Role | None:
|
||||||
|
stmt = select(Role).where(Role.name == name)
|
||||||
|
return self.session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
|
def create(self, role: Role) -> Role:
|
||||||
|
self.session.add(role)
|
||||||
|
try:
|
||||||
|
self.session.flush()
|
||||||
|
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||||
|
raise EntityConflictError(
|
||||||
|
"Role violates uniqueness constraints") from exc
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository:
|
||||||
|
"""Persistence operations for User entities and their role assignments."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def list(self, *, with_roles: bool = False) -> Sequence[User]:
|
||||||
|
stmt = select(User).order_by(User.created_at)
|
||||||
|
if with_roles:
|
||||||
|
stmt = stmt.options(selectinload(User.roles))
|
||||||
|
return self.session.execute(stmt).scalars().all()
|
||||||
|
|
||||||
|
def _apply_role_option(self, stmt, with_roles: bool):
|
||||||
|
if with_roles:
|
||||||
|
stmt = stmt.options(
|
||||||
|
joinedload(User.role_assignments).joinedload(UserRole.role),
|
||||||
|
selectinload(User.roles),
|
||||||
|
)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
def get(self, user_id: int, *, with_roles: bool = False) -> User:
|
||||||
|
stmt = select(User).where(User.id == user_id).execution_options(
|
||||||
|
populate_existing=True)
|
||||||
|
stmt = self._apply_role_option(stmt, with_roles)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
if with_roles:
|
||||||
|
result = result.unique()
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise EntityNotFoundError(f"User {user_id} not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
|
||||||
|
stmt = select(User).where(User.email == email).execution_options(
|
||||||
|
populate_existing=True)
|
||||||
|
stmt = self._apply_role_option(stmt, with_roles)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
if with_roles:
|
||||||
|
result = result.unique()
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
|
||||||
|
stmt = select(User).where(User.username ==
|
||||||
|
username).execution_options(populate_existing=True)
|
||||||
|
stmt = self._apply_role_option(stmt, with_roles)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
if with_roles:
|
||||||
|
result = result.unique()
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
def create(self, user: User) -> User:
|
||||||
|
self.session.add(user)
|
||||||
|
try:
|
||||||
|
self.session.flush()
|
||||||
|
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||||
|
raise EntityConflictError(
|
||||||
|
"User violates uniqueness constraints") from exc
|
||||||
|
return user
|
||||||
|
|
||||||
|
def assign_role(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
role_id: int,
|
||||||
|
granted_by: int | None = None,
|
||||||
|
) -> UserRole:
|
||||||
|
stmt = select(UserRole).where(
|
||||||
|
UserRole.user_id == user_id,
|
||||||
|
UserRole.role_id == role_id,
|
||||||
|
)
|
||||||
|
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||||
|
if assignment:
|
||||||
|
return assignment
|
||||||
|
|
||||||
|
assignment = UserRole(
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
granted_by=granted_by,
|
||||||
|
)
|
||||||
|
self.session.add(assignment)
|
||||||
|
try:
|
||||||
|
self.session.flush()
|
||||||
|
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||||
|
raise EntityConflictError(
|
||||||
|
"Assignment violates constraints") from exc
|
||||||
|
return assignment
|
||||||
|
|
||||||
|
def revoke_role(self, *, user_id: int, role_id: int) -> None:
|
||||||
|
stmt = select(UserRole).where(
|
||||||
|
UserRole.user_id == user_id,
|
||||||
|
UserRole.role_id == role_id,
|
||||||
|
)
|
||||||
|
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||||
|
if assignment is None:
|
||||||
|
raise EntityNotFoundError(
|
||||||
|
f"Role {role_id} not assigned to user {user_id}")
|
||||||
|
self.session.delete(assignment)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
|
||||||
|
{
|
||||||
|
"name": "admin",
|
||||||
|
"display_name": "Administrator",
|
||||||
|
"description": "Full platform access with user management rights.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "project_manager",
|
||||||
|
"display_name": "Project Manager",
|
||||||
|
"description": "Manage projects, scenarios, and associated data.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "analyst",
|
||||||
|
"display_name": "Analyst",
|
||||||
|
"description": "Review dashboards and scenario outputs.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "viewer",
|
||||||
|
"display_name": "Viewer",
|
||||||
|
"description": "Read-only access to assigned projects and reports.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
|
||||||
|
"""Ensure standard roles exist, creating missing ones.
|
||||||
|
|
||||||
|
Returns all current role records in creation order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
roles: list[Role] = []
|
||||||
|
for definition in DEFAULT_ROLE_DEFINITIONS:
|
||||||
|
existing = role_repo.get_by_name(definition["name"])
|
||||||
|
if existing:
|
||||||
|
roles.append(existing)
|
||||||
|
continue
|
||||||
|
role = Role(**definition)
|
||||||
|
roles.append(role_repo.create(role))
|
||||||
|
return roles
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_admin_user(
|
||||||
|
user_repo: UserRepository,
|
||||||
|
role_repo: RoleRepository,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
) -> User:
|
||||||
|
"""Ensure an administrator user exists and holds the admin role."""
|
||||||
|
|
||||||
|
user = user_repo.get_by_email(email, with_roles=True)
|
||||||
|
if user is None:
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
username=username,
|
||||||
|
password_hash=User.hash_password(password),
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=True,
|
||||||
|
)
|
||||||
|
user_repo.create(user)
|
||||||
|
else:
|
||||||
|
if not user.is_active:
|
||||||
|
user.is_active = True
|
||||||
|
if not user.is_superuser:
|
||||||
|
user.is_superuser = True
|
||||||
|
user_repo.session.flush()
|
||||||
|
|
||||||
|
admin_role = role_repo.get_by_name("admin")
|
||||||
|
if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called
|
||||||
|
admin_role = role_repo.create(
|
||||||
|
Role(
|
||||||
|
name="admin",
|
||||||
|
display_name="Administrator",
|
||||||
|
description="Full platform access with user management rights.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
user_repo.assign_role(
|
||||||
|
user_id=user.id,
|
||||||
|
role_id=admin_role.id,
|
||||||
|
granted_by=user.id,
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|||||||
213
services/security.py
Normal file
213
services/security.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, Iterable, Literal, Type
|
||||||
|
|
||||||
|
from jose import ExpiredSignatureError, JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
try: # pragma: no cover - compatibility shim for passlib/argon2 warning
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import argon2 # type: ignore
|
||||||
|
|
||||||
|
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
|
||||||
|
except Exception: # pragma: no cover - executed only when metadata lookup fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""Derive a secure hash for a plain-text password."""
|
||||||
|
|
||||||
|
return password_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(candidate: str, hashed: str) -> bool:
|
||||||
|
"""Verify that a candidate password matches a stored hash."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return password_context.verify(candidate, hashed)
|
||||||
|
except ValueError:
|
||||||
|
# Raised when the stored hash is malformed or uses an unknown scheme.
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TokenError(Exception):
|
||||||
|
"""Base class for token encoding/decoding issues."""
|
||||||
|
|
||||||
|
|
||||||
|
class TokenDecodeError(TokenError):
|
||||||
|
"""Raised when a token cannot be decoded or validated."""
|
||||||
|
|
||||||
|
|
||||||
|
class TokenExpiredError(TokenError):
|
||||||
|
"""Raised when a token has expired."""
|
||||||
|
|
||||||
|
|
||||||
|
class TokenTypeMismatchError(TokenError):
|
||||||
|
"""Raised when a token type does not match the expected flavour."""
|
||||||
|
|
||||||
|
|
||||||
|
TokenKind = Literal["access", "refresh"]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
"""Shared fields for CalMiner JWT payloads."""
|
||||||
|
|
||||||
|
sub: str
|
||||||
|
exp: int
|
||||||
|
type: TokenKind
|
||||||
|
scopes: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expires_at(self) -> datetime:
|
||||||
|
return datetime.fromtimestamp(self.exp, tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class JWTSettings:
|
||||||
|
"""Runtime configuration for JWT encoding and validation."""
|
||||||
|
|
||||||
|
secret_key: str
|
||||||
|
algorithm: str = "HS256"
|
||||||
|
access_token_ttl: timedelta = field(
|
||||||
|
default_factory=lambda: timedelta(minutes=15))
|
||||||
|
refresh_token_ttl: timedelta = field(
|
||||||
|
default_factory=lambda: timedelta(days=7))
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
subject: str,
|
||||||
|
settings: JWTSettings,
|
||||||
|
*,
|
||||||
|
scopes: Iterable[str] | None = None,
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
extra_claims: Dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Issue a signed access token for the provided subject."""
|
||||||
|
|
||||||
|
lifetime = expires_delta or settings.access_token_ttl
|
||||||
|
return _create_token(
|
||||||
|
subject=subject,
|
||||||
|
token_type="access",
|
||||||
|
settings=settings,
|
||||||
|
lifetime=lifetime,
|
||||||
|
scopes=scopes,
|
||||||
|
extra_claims=extra_claims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(
|
||||||
|
subject: str,
|
||||||
|
settings: JWTSettings,
|
||||||
|
*,
|
||||||
|
scopes: Iterable[str] | None = None,
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
extra_claims: Dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Issue a signed refresh token for the provided subject."""
|
||||||
|
|
||||||
|
lifetime = expires_delta or settings.refresh_token_ttl
|
||||||
|
return _create_token(
|
||||||
|
subject=subject,
|
||||||
|
token_type="refresh",
|
||||||
|
settings=settings,
|
||||||
|
lifetime=lifetime,
|
||||||
|
scopes=scopes,
|
||||||
|
extra_claims=extra_claims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||||
|
"""Validate and decode an access token."""
|
||||||
|
|
||||||
|
return _decode_token(token, settings, expected_type="access")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||||
|
"""Validate and decode a refresh token."""
|
||||||
|
|
||||||
|
return _decode_token(token, settings, expected_type="refresh")
|
||||||
|
|
||||||
|
|
||||||
|
def _create_token(
|
||||||
|
*,
|
||||||
|
subject: str,
|
||||||
|
token_type: TokenKind,
|
||||||
|
settings: JWTSettings,
|
||||||
|
lifetime: timedelta,
|
||||||
|
scopes: Iterable[str] | None,
|
||||||
|
extra_claims: Dict[str, Any] | None,
|
||||||
|
) -> str:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expire = now + lifetime
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"sub": subject,
|
||||||
|
"type": token_type,
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"exp": int(expire.timestamp()),
|
||||||
|
}
|
||||||
|
if scopes:
|
||||||
|
payload["scopes"] = list(scopes)
|
||||||
|
if extra_claims:
|
||||||
|
payload.update(extra_claims)
|
||||||
|
|
||||||
|
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_token(
|
||||||
|
token: str,
|
||||||
|
settings: JWTSettings,
|
||||||
|
expected_type: TokenKind,
|
||||||
|
) -> TokenPayload:
|
||||||
|
try:
|
||||||
|
decoded = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key,
|
||||||
|
algorithms=[settings.algorithm],
|
||||||
|
options={"verify_aud": False},
|
||||||
|
)
|
||||||
|
except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path
|
||||||
|
raise TokenExpiredError("Token has expired") from exc
|
||||||
|
except JWTError as exc: # pragma: no cover - jose error bubble
|
||||||
|
raise TokenDecodeError("Unable to decode token") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = _model_validate(TokenPayload, decoded)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise TokenDecodeError("Token payload validation failed") from exc
|
||||||
|
|
||||||
|
if payload.type != expected_type:
|
||||||
|
raise TokenTypeMismatchError(
|
||||||
|
f"Expected a {expected_type} token but received '{payload.type}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
|
||||||
|
if hasattr(model, "model_validate"):
|
||||||
|
return model.model_validate(data) # type: ignore[attr-defined]
|
||||||
|
return model.parse_obj(data) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"JWTSettings",
|
||||||
|
"TokenDecodeError",
|
||||||
|
"TokenError",
|
||||||
|
"TokenExpiredError",
|
||||||
|
"TokenKind",
|
||||||
|
"TokenPayload",
|
||||||
|
"TokenTypeMismatchError",
|
||||||
|
"create_access_token",
|
||||||
|
"create_refresh_token",
|
||||||
|
"decode_access_token",
|
||||||
|
"decode_refresh_token",
|
||||||
|
"hash_password",
|
||||||
|
"password_context",
|
||||||
|
"verify_password",
|
||||||
|
]
|
||||||
@@ -6,12 +6,16 @@ from typing import Callable, Sequence
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from config.database import SessionLocal
|
from config.database import SessionLocal
|
||||||
from models import Scenario
|
from models import Role, Scenario
|
||||||
from services.repositories import (
|
from services.repositories import (
|
||||||
FinancialInputRepository,
|
FinancialInputRepository,
|
||||||
ProjectRepository,
|
ProjectRepository,
|
||||||
|
RoleRepository,
|
||||||
ScenarioRepository,
|
ScenarioRepository,
|
||||||
SimulationParameterRepository,
|
SimulationParameterRepository,
|
||||||
|
UserRepository,
|
||||||
|
ensure_admin_user as ensure_admin_user_record,
|
||||||
|
ensure_default_roles,
|
||||||
)
|
)
|
||||||
from services.scenario_validation import ScenarioComparisonValidator
|
from services.scenario_validation import ScenarioComparisonValidator
|
||||||
|
|
||||||
@@ -23,6 +27,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
self.session: Session | None = None
|
self.session: Session | None = None
|
||||||
self._scenario_validator: ScenarioComparisonValidator | None = None
|
self._scenario_validator: ScenarioComparisonValidator | None = None
|
||||||
|
self.projects: ProjectRepository | None = None
|
||||||
|
self.scenarios: ScenarioRepository | None = None
|
||||||
|
self.financial_inputs: FinancialInputRepository | None = None
|
||||||
|
self.simulation_parameters: SimulationParameterRepository | None = None
|
||||||
|
self.users: UserRepository | None = None
|
||||||
|
self.roles: RoleRepository | None = None
|
||||||
|
|
||||||
def __enter__(self) -> "UnitOfWork":
|
def __enter__(self) -> "UnitOfWork":
|
||||||
self.session = self._session_factory()
|
self.session = self._session_factory()
|
||||||
@@ -31,6 +41,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
self.financial_inputs = FinancialInputRepository(self.session)
|
self.financial_inputs = FinancialInputRepository(self.session)
|
||||||
self.simulation_parameters = SimulationParameterRepository(
|
self.simulation_parameters = SimulationParameterRepository(
|
||||||
self.session)
|
self.session)
|
||||||
|
self.users = UserRepository(self.session)
|
||||||
|
self.roles = RoleRepository(self.session)
|
||||||
self._scenario_validator = ScenarioComparisonValidator()
|
self._scenario_validator = ScenarioComparisonValidator()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -42,6 +54,12 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
self.session.close()
|
self.session.close()
|
||||||
self._scenario_validator = None
|
self._scenario_validator = None
|
||||||
|
self.projects = None
|
||||||
|
self.scenarios = None
|
||||||
|
self.financial_inputs = None
|
||||||
|
self.simulation_parameters = None
|
||||||
|
self.users = None
|
||||||
|
self.roles = None
|
||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
if not self.session:
|
if not self.session:
|
||||||
@@ -61,7 +79,7 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
def validate_scenarios_for_comparison(
|
def validate_scenarios_for_comparison(
|
||||||
self, scenario_ids: Sequence[int]
|
self, scenario_ids: Sequence[int]
|
||||||
) -> list[Scenario]:
|
) -> list[Scenario]:
|
||||||
if not self.session or not self._scenario_validator:
|
if not self.session or not self._scenario_validator or not self.scenarios:
|
||||||
raise RuntimeError("UnitOfWork session is not initialised")
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
|
|
||||||
scenarios = [self.scenarios.get(scenario_id)
|
scenarios = [self.scenarios.get(scenario_id)
|
||||||
@@ -75,3 +93,26 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
|||||||
if not self._scenario_validator:
|
if not self._scenario_validator:
|
||||||
raise RuntimeError("UnitOfWork session is not initialised")
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
self._scenario_validator.validate(scenarios)
|
self._scenario_validator.validate(scenarios)
|
||||||
|
|
||||||
|
def ensure_default_roles(self) -> list[Role]:
|
||||||
|
if not self.roles:
|
||||||
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
|
return ensure_default_roles(self.roles)
|
||||||
|
|
||||||
|
def ensure_admin_user(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
) -> None:
|
||||||
|
if not self.users or not self.roles:
|
||||||
|
raise RuntimeError("UnitOfWork session is not initialised")
|
||||||
|
ensure_default_roles(self.roles)
|
||||||
|
ensure_admin_user_record(
|
||||||
|
self.users,
|
||||||
|
self.roles,
|
||||||
|
email=email,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
|||||||
BIN
static/img/logo_big.png
Normal file
BIN
static/img/logo_big.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
@@ -1,14 +1,22 @@
|
|||||||
{% extends "base.html" %}
|
{% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {%
|
||||||
|
block content %}
|
||||||
{% block title %}Forgot Password{% endblock %}
|
|
||||||
|
|
||||||
{% block content %}
|
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<h1>Forgot Password</h1>
|
<h1>Forgot Password</h1>
|
||||||
<form id="forgot-password-form">
|
{% if errors %}
|
||||||
|
<div class="alert alert-error">
|
||||||
|
<ul>
|
||||||
|
{% for error in errors %}
|
||||||
|
<li>{{ error }}</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
{% endif %} {% if message %}
|
||||||
|
<div class="alert alert-info">{{ message }}</div>
|
||||||
|
{% endif %}
|
||||||
|
<form id="forgot-password-form" method="post" action="{{ form_action }}">
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="email">Email:</label>
|
<label for="email">Email:</label>
|
||||||
<input type="email" id="email" name="email" required>
|
<input type="email" id="email" name="email" required />
|
||||||
</div>
|
</div>
|
||||||
<button type="submit">Reset Password</button>
|
<button type="submit">Reset Password</button>
|
||||||
</form>
|
</form>
|
||||||
|
|||||||
@@ -1,18 +1,30 @@
|
|||||||
{% extends "base.html" %}
|
{% extends "base.html" %} {% block title %}Login{% endblock %} {% block content
|
||||||
|
%}
|
||||||
{% block title %}Login{% endblock %}
|
|
||||||
|
|
||||||
{% block content %}
|
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<h1>Login</h1>
|
<h1>Login</h1>
|
||||||
<form id="login-form">
|
{% if errors %}
|
||||||
|
<div class="alert alert-error">
|
||||||
|
<ul>
|
||||||
|
{% for error in errors %}
|
||||||
|
<li>{{ error }}</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<form id="login-form" method="post" action="{{ form_action }}">
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="username">Username:</label>
|
<label for="username">Username:</label>
|
||||||
<input type="text" id="username" name="username" required>
|
<input
|
||||||
|
type="text"
|
||||||
|
id="username"
|
||||||
|
name="username"
|
||||||
|
value="{{ username | default('') }}"
|
||||||
|
required
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="password">Password:</label>
|
<label for="password">Password:</label>
|
||||||
<input type="password" id="password" name="password" required>
|
<input type="password" id="password" name="password" required />
|
||||||
</div>
|
</div>
|
||||||
<button type="submit">Login</button>
|
<button type="submit">Login</button>
|
||||||
</form>
|
</form>
|
||||||
|
|||||||
@@ -1,52 +1,35 @@
|
|||||||
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
||||||
{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %}
|
{% set projects_href = request.url_for('projects.project_list_page') if request
|
||||||
{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %}
|
else '/projects/ui' %} {% set project_create_href =
|
||||||
|
request.url_for('projects.create_project_form') if request else
|
||||||
{% set nav_groups = [
|
'/projects/create' %} {% set login_href = request.url_for('auth.login_form') if
|
||||||
{
|
request else '/login' %} {% set register_href =
|
||||||
"label": "Workspace",
|
request.url_for('auth.register_form') if request else '/register' %} {% set
|
||||||
"links": [
|
forgot_href = request.url_for('auth.password_reset_request_form') if request
|
||||||
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
else '/forgot-password' %} {% set nav_groups = [ { "label": "Workspace",
|
||||||
|
"links": [ {"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
||||||
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
||||||
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"},
|
{"href": project_create_href, "label": "New Project", "match_prefix":
|
||||||
],
|
"/projects/create"}, ], }, { "label": "Insights", "links": [ {"href":
|
||||||
},
|
"/ui/simulations", "label": "Simulations"}, {"href": "/ui/reporting", "label":
|
||||||
{
|
"Reporting"}, ], }, { "label": "Configuration", "links": [ { "href":
|
||||||
"label": "Insights",
|
"/ui/settings", "label": "Settings", "children": [ {"href": "/theme-settings",
|
||||||
"links": [
|
"label": "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"},
|
||||||
{"href": "/ui/simulations", "label": "Simulations"},
|
], }, ], }, { "label": "Account", "links": [ {"href": login_href, "label":
|
||||||
{"href": "/ui/reporting", "label": "Reporting"},
|
"Login", "match_prefix": "/login"}, {"href": register_href, "label": "Register",
|
||||||
],
|
"match_prefix": "/register"}, {"href": forgot_href, "label": "Forgot Password",
|
||||||
},
|
"match_prefix": "/forgot-password"}, ], }, ] %}
|
||||||
{
|
|
||||||
"label": "Configuration",
|
|
||||||
"links": [
|
|
||||||
{
|
|
||||||
"href": "/ui/settings",
|
|
||||||
"label": "Settings",
|
|
||||||
"children": [
|
|
||||||
{"href": "/theme-settings", "label": "Themes"},
|
|
||||||
{"href": "/ui/currencies", "label": "Currency Management"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
] %}
|
|
||||||
|
|
||||||
<nav class="sidebar-nav" aria-label="Primary navigation">
|
<nav class="sidebar-nav" aria-label="Primary navigation">
|
||||||
{% set current_path = request.url.path if request else "" %}
|
{% set current_path = request.url.path if request else "" %} {% for group in
|
||||||
{% for group in nav_groups %}
|
nav_groups %}
|
||||||
<div class="sidebar-section">
|
<div class="sidebar-section">
|
||||||
<div class="sidebar-section-label">{{ group.label }}</div>
|
<div class="sidebar-section-label">{{ group.label }}</div>
|
||||||
<div class="sidebar-section-links">
|
<div class="sidebar-section-links">
|
||||||
{% for link in group.links %}
|
{% for link in group.links %} {% set href = link.href %} {% set
|
||||||
{% set href = link.href %}
|
match_prefix = link.get('match_prefix', href) %} {% if match_prefix == '/'
|
||||||
{% set match_prefix = link.get('match_prefix', href) %}
|
%} {% set is_active = current_path == '/' %} {% else %} {% set is_active =
|
||||||
{% if match_prefix == '/' %}
|
current_path.startswith(match_prefix) %} {% endif %}
|
||||||
{% set is_active = current_path == '/' %}
|
|
||||||
{% else %}
|
|
||||||
{% set is_active = current_path.startswith(match_prefix) %}
|
|
||||||
{% endif %}
|
|
||||||
<div class="sidebar-link-block">
|
<div class="sidebar-link-block">
|
||||||
<a
|
<a
|
||||||
href="{{ href }}"
|
href="{{ href }}"
|
||||||
@@ -56,13 +39,10 @@
|
|||||||
</a>
|
</a>
|
||||||
{% if link.children %}
|
{% if link.children %}
|
||||||
<div class="sidebar-sublinks">
|
<div class="sidebar-sublinks">
|
||||||
{% for child in link.children %}
|
{% for child in link.children %} {% set child_prefix =
|
||||||
{% set child_prefix = child.get('match_prefix', child.href) %}
|
child.get('match_prefix', child.href) %} {% if child_prefix == '/' %}
|
||||||
{% if child_prefix == '/' %}
|
{% set child_active = current_path == '/' %} {% else %} {% set
|
||||||
{% set child_active = current_path == '/' %}
|
child_active = current_path.startswith(child_prefix) %} {% endif %}
|
||||||
{% else %}
|
|
||||||
{% set child_active = current_path.startswith(child_prefix) %}
|
|
||||||
{% endif %}
|
|
||||||
<a
|
<a
|
||||||
href="{{ child.href }}"
|
href="{{ child.href }}"
|
||||||
class="sidebar-sublink{% if child_active %} is-active{% endif %}"
|
class="sidebar-sublink{% if child_active %} is-active{% endif %}"
|
||||||
|
|||||||
@@ -1,22 +1,40 @@
|
|||||||
{% extends "base.html" %}
|
{% extends "base.html" %} {% block title %}Register{% endblock %} {% block
|
||||||
|
content %}
|
||||||
{% block title %}Register{% endblock %}
|
|
||||||
|
|
||||||
{% block content %}
|
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<h1>Register</h1>
|
<h1>Register</h1>
|
||||||
<form id="register-form">
|
{% if errors %}
|
||||||
|
<div class="alert alert-error">
|
||||||
|
<ul>
|
||||||
|
{% for error in errors %}
|
||||||
|
<li>{{ error }}</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<form id="register-form" method="post" action="{{ form_action }}">
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="username">Username:</label>
|
<label for="username">Username:</label>
|
||||||
<input type="text" id="username" name="username" required>
|
<input
|
||||||
|
type="text"
|
||||||
|
id="username"
|
||||||
|
name="username"
|
||||||
|
value="{{ form_data.username if form_data else '' }}"
|
||||||
|
required
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="email">Email:</label>
|
<label for="email">Email:</label>
|
||||||
<input type="email" id="email" name="email" required>
|
<input
|
||||||
|
type="email"
|
||||||
|
id="email"
|
||||||
|
name="email"
|
||||||
|
value="{{ form_data.email if form_data else '' }}"
|
||||||
|
required
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="password">Password:</label>
|
<label for="password">Password:</label>
|
||||||
<input type="password" id="password" name="password" required>
|
<input type="password" id="password" name="password" required />
|
||||||
</div>
|
</div>
|
||||||
<button type="submit">Register</button>
|
<button type="submit">Register</button>
|
||||||
</form>
|
</form>
|
||||||
|
|||||||
36
templates/reset_password.html
Normal file
36
templates/reset_password.html
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
{% extends "base.html" %} {% block title %}Reset Password{% endblock %} {% block
|
||||||
|
content %}
|
||||||
|
<div class="container">
|
||||||
|
<h1>Reset Password</h1>
|
||||||
|
{% if errors %}
|
||||||
|
<div class="alert alert-error">
|
||||||
|
<ul>
|
||||||
|
{% for error in errors %}
|
||||||
|
<li>{{ error }}</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<form id="reset-password-form" method="post" action="{{ form_action }}">
|
||||||
|
<input type="hidden" name="token" value="{{ token | default('') }}" />
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="password">New Password:</label>
|
||||||
|
<input type="password" id="password" name="password" required />
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="confirm_password">Confirm Password:</label>
|
||||||
|
<input
|
||||||
|
type="password"
|
||||||
|
id="confirm_password"
|
||||||
|
name="confirm_password"
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<button type="submit">Update Password</button>
|
||||||
|
</form>
|
||||||
|
<p>
|
||||||
|
Remembered your password?
|
||||||
|
<a href="{{ request.url_for('auth.login_form') }}">Return to login</a>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -12,6 +12,7 @@ from sqlalchemy.pool import StaticPool
|
|||||||
|
|
||||||
from config.database import Base
|
from config.database import Base
|
||||||
from dependencies import get_unit_of_work
|
from dependencies import get_unit_of_work
|
||||||
|
from routes.auth import router as auth_router
|
||||||
from routes.dashboard import router as dashboard_router
|
from routes.dashboard import router as dashboard_router
|
||||||
from routes.projects import router as projects_router
|
from routes.projects import router as projects_router
|
||||||
from routes.scenarios import router as scenarios_router
|
from routes.scenarios import router as scenarios_router
|
||||||
@@ -36,13 +37,15 @@ def engine() -> Iterator[Engine]:
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||||
testing_session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
testing_session = sessionmaker(
|
||||||
|
bind=engine, expire_on_commit=False, future=True)
|
||||||
yield testing_session
|
yield testing_session
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def app(session_factory: sessionmaker) -> FastAPI:
|
def app(session_factory: sessionmaker) -> FastAPI:
|
||||||
application = FastAPI()
|
application = FastAPI()
|
||||||
|
application.include_router(auth_router)
|
||||||
application.include_router(dashboard_router)
|
application.include_router(dashboard_router)
|
||||||
application.include_router(projects_router)
|
application.include_router(projects_router)
|
||||||
application.include_router(scenarios_router)
|
application.include_router(scenarios_router)
|
||||||
|
|||||||
135
tests/test_auth_repositories.py
Normal file
135
tests/test_auth_repositories.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from config.database import Base
|
||||||
|
from models import Role, User
|
||||||
|
from services.repositories import (
|
||||||
|
RoleRepository,
|
||||||
|
UserRepository,
|
||||||
|
ensure_admin_user,
|
||||||
|
ensure_default_roles,
|
||||||
|
)
|
||||||
|
from services.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def engine() -> Iterator:
|
||||||
|
engine = create_engine("sqlite:///:memory:", future=True)
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
try:
|
||||||
|
yield engine
|
||||||
|
finally:
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session(engine) -> Iterator[Session]:
|
||||||
|
TestingSession = sessionmaker(
|
||||||
|
bind=engine, expire_on_commit=False, future=True)
|
||||||
|
db = TestingSession()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_role_repository_create_and_lookup(session: Session) -> None:
|
||||||
|
repo = RoleRepository(session)
|
||||||
|
role = Role(name="custom", display_name="Custom",
|
||||||
|
description="Custom role")
|
||||||
|
repo.create(role)
|
||||||
|
|
||||||
|
retrieved = repo.get(role.id)
|
||||||
|
assert retrieved.name == "custom"
|
||||||
|
assert repo.get_by_name("custom") is retrieved
|
||||||
|
assert repo.list()[0].name == "custom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_repository_assign_and_revoke_role(session: Session) -> None:
|
||||||
|
role_repo = RoleRepository(session)
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
|
||||||
|
analyst = role_repo.create(
|
||||||
|
Role(name="analyst", display_name="Analyst", description="Analyzes data")
|
||||||
|
)
|
||||||
|
user = User(
|
||||||
|
email="user@example.com",
|
||||||
|
username="user",
|
||||||
|
password_hash=User.hash_password("secret"),
|
||||||
|
)
|
||||||
|
user_repo.create(user)
|
||||||
|
|
||||||
|
assignment = user_repo.assign_role(
|
||||||
|
user_id=user.id, role_id=analyst.id, granted_by=None)
|
||||||
|
assert assignment.role_id == analyst.id
|
||||||
|
|
||||||
|
refreshed = user_repo.get(user.id, with_roles=True)
|
||||||
|
assert refreshed.roles[0].name == "analyst"
|
||||||
|
|
||||||
|
user_repo.revoke_role(user_id=user.id, role_id=analyst.id)
|
||||||
|
refreshed = user_repo.get(user.id, with_roles=True)
|
||||||
|
assert refreshed.roles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_role_and_admin_helpers(session: Session) -> None:
|
||||||
|
role_repo = RoleRepository(session)
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
|
||||||
|
roles = ensure_default_roles(role_repo)
|
||||||
|
assert {role.name for role in roles} == {
|
||||||
|
"admin", "project_manager", "analyst", "viewer"}
|
||||||
|
|
||||||
|
ensure_admin_user(
|
||||||
|
user_repo,
|
||||||
|
role_repo,
|
||||||
|
email="admin@example.com",
|
||||||
|
username="admin",
|
||||||
|
password="SecurePass1!",
|
||||||
|
)
|
||||||
|
|
||||||
|
admin = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||||
|
assert admin is not None
|
||||||
|
assert admin.is_superuser
|
||||||
|
assert {role.name for role in admin.roles} >= {"admin"}
|
||||||
|
|
||||||
|
# Idempotent behaviour on subsequent calls
|
||||||
|
ensure_admin_user(
|
||||||
|
user_repo,
|
||||||
|
role_repo,
|
||||||
|
email="admin@example.com",
|
||||||
|
username="admin",
|
||||||
|
password="SecurePass1!",
|
||||||
|
)
|
||||||
|
admin_again = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||||
|
assert admin_again is not None
|
||||||
|
assert len(admin_again.roles) == len(
|
||||||
|
{role.name for role in admin_again.roles})
|
||||||
|
|
||||||
|
|
||||||
|
def test_unit_of_work_exposes_auth_repositories(engine) -> None:
|
||||||
|
TestingSession = sessionmaker(
|
||||||
|
bind=engine, expire_on_commit=False, future=True)
|
||||||
|
|
||||||
|
with UnitOfWork(session_factory=TestingSession) as uow:
|
||||||
|
assert uow.users is not None
|
||||||
|
assert uow.roles is not None
|
||||||
|
|
||||||
|
roles = uow.ensure_default_roles()
|
||||||
|
assert any(role.name == "admin" for role in roles)
|
||||||
|
|
||||||
|
uow.ensure_admin_user(
|
||||||
|
email="uow-admin@example.com",
|
||||||
|
username="uow-admin",
|
||||||
|
password="AnotherSecret1!",
|
||||||
|
)
|
||||||
|
|
||||||
|
admin = uow.users.get_by_email(
|
||||||
|
"uow-admin@example.com", with_roles=True)
|
||||||
|
assert admin is not None
|
||||||
|
assert admin.is_superuser
|
||||||
|
assert any(role.name == "admin" for role in admin.roles)
|
||||||
239
tests/test_auth_routes.py
Normal file
239
tests/test_auth_routes.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from models import Role, User, UserRole
|
||||||
|
from services.security import hash_password
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
|
||||||
|
session = session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None:
|
||||||
|
stmt = select(User)
|
||||||
|
if email is not None:
|
||||||
|
stmt = stmt.where(User.email == email)
|
||||||
|
if username is not None:
|
||||||
|
stmt = stmt.where(User.username == username)
|
||||||
|
return session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistrationFlow:
|
||||||
|
def test_register_creates_user_and_assigns_role(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/register",
|
||||||
|
data={
|
||||||
|
"username": "newuser",
|
||||||
|
"email": "newuser@example.com",
|
||||||
|
"password": "ComplexP@ss1",
|
||||||
|
"confirm_password": "ComplexP@ss1",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 303
|
||||||
|
location = response.headers.get("location")
|
||||||
|
assert location
|
||||||
|
parsed = urlparse(location)
|
||||||
|
assert parsed.path == "/login"
|
||||||
|
assert parse_qs(parsed.query).get("registered") == ["1"]
|
||||||
|
|
||||||
|
created = _get_user(db_session, email="newuser@example.com")
|
||||||
|
assert created is not None
|
||||||
|
assert created.is_active
|
||||||
|
|
||||||
|
role_stmt = select(Role).where(Role.name == "viewer")
|
||||||
|
viewer_role = db_session.execute(role_stmt).scalar_one_or_none()
|
||||||
|
assert viewer_role is not None
|
||||||
|
|
||||||
|
assignments = db_session.execute(
|
||||||
|
select(UserRole).where(
|
||||||
|
UserRole.user_id == created.id,
|
||||||
|
UserRole.role_id == viewer_role.id,
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
assert len(assignments) == 1
|
||||||
|
|
||||||
|
def test_register_duplicate_email_shows_error(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
) -> None:
|
||||||
|
first = client.post(
|
||||||
|
"/register",
|
||||||
|
data={
|
||||||
|
"username": "existing",
|
||||||
|
"email": "existing@example.com",
|
||||||
|
"password": "ComplexP@ss1",
|
||||||
|
"confirm_password": "ComplexP@ss1",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
assert first.status_code == 303
|
||||||
|
|
||||||
|
second = client.post(
|
||||||
|
"/register",
|
||||||
|
data={
|
||||||
|
"username": "existing",
|
||||||
|
"email": "existing@example.com",
|
||||||
|
"password": "ComplexP@ss1",
|
||||||
|
"confirm_password": "ComplexP@ss1",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert second.status_code == 400
|
||||||
|
assert "Email is already registered" in second.text
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoginFlow:
|
||||||
|
def test_login_sets_tokens_and_updates_last_login(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
password = "MySecur3Pass!"
|
||||||
|
user = User(
|
||||||
|
email="login@example.com",
|
||||||
|
username="loginuser",
|
||||||
|
password_hash=hash_password(password),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/login",
|
||||||
|
data={"username": "loginuser", "password": password},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers.get("location") == "http://testserver/"
|
||||||
|
set_cookie_header = response.headers.get("set-cookie", "")
|
||||||
|
assert "calminer_access_token=" in set_cookie_header
|
||||||
|
assert "calminer_refresh_token=" in set_cookie_header
|
||||||
|
|
||||||
|
updated = _get_user(db_session, username="loginuser")
|
||||||
|
assert updated is not None and updated.last_login_at is not None
|
||||||
|
|
||||||
|
def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/login",
|
||||||
|
data={"username": "unknown", "password": "bad"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid username or password" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
class TestPasswordResetFlow:
|
||||||
|
def test_password_reset_round_trip(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
user = User(
|
||||||
|
email="reset@example.com",
|
||||||
|
username="resetuser",
|
||||||
|
password_hash=hash_password("OldP@ssword1"),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
request_response = client.post(
|
||||||
|
"/forgot-password",
|
||||||
|
data={"email": "reset@example.com"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert request_response.status_code == 303
|
||||||
|
reset_location = request_response.headers.get("location")
|
||||||
|
assert reset_location is not None
|
||||||
|
parsed = urlparse(reset_location)
|
||||||
|
assert parsed.path == "/reset-password"
|
||||||
|
token = parse_qs(parsed.query).get("token", [None])[0]
|
||||||
|
assert token
|
||||||
|
|
||||||
|
form_response = client.get(reset_location)
|
||||||
|
assert form_response.status_code == 200
|
||||||
|
|
||||||
|
submit_response = client.post(
|
||||||
|
"/reset-password",
|
||||||
|
data={
|
||||||
|
"token": token,
|
||||||
|
"password": "N3wP@ssword!",
|
||||||
|
"confirm_password": "N3wP@ssword!",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert submit_response.status_code == 303
|
||||||
|
assert "reset=1" in (submit_response.headers.get("location") or "")
|
||||||
|
|
||||||
|
db_session.refresh(user)
|
||||||
|
assert user.verify_password("N3wP@ssword!")
|
||||||
|
|
||||||
|
def test_password_reset_with_unknown_email_shows_generic_message(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
) -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/forgot-password",
|
||||||
|
data={"email": "doesnotexist@example.com"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "If an account exists" in response.text
|
||||||
|
|
||||||
|
def test_password_reset_mismatched_passwords_return_error(
|
||||||
|
self,
|
||||||
|
client: TestClient,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
user = User(
|
||||||
|
email="mismatch@example.com",
|
||||||
|
username="mismatch",
|
||||||
|
password_hash=hash_password("OldP@ssword1"),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
request_response = client.post(
|
||||||
|
"/forgot-password",
|
||||||
|
data={"email": "mismatch@example.com"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
token = parse_qs(urlparse(request_response.headers["location"]).query)["token"][0]
|
||||||
|
|
||||||
|
submit_response = client.post(
|
||||||
|
"/reset-password",
|
||||||
|
data={
|
||||||
|
"token": token,
|
||||||
|
"password": "NewPass123!",
|
||||||
|
"confirm_password": "Different123!",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert submit_response.status_code == 400
|
||||||
|
assert "Passwords do not match" in submit_response.text
|
||||||
76
tests/test_security.py
Normal file
76
tests/test_security.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.security import (
|
||||||
|
JWTSettings,
|
||||||
|
TokenDecodeError,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenTypeMismatchError,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_access_token,
|
||||||
|
decode_refresh_token,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_password_round_trip() -> None:
|
||||||
|
hashed = hash_password("super-secret")
|
||||||
|
|
||||||
|
assert hashed != "super-secret"
|
||||||
|
assert verify_password("super-secret", hashed)
|
||||||
|
assert not verify_password("incorrect", hashed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_password_handles_malformed_hash() -> None:
|
||||||
|
assert not verify_password("secret", "not-a-valid-hash")
|
||||||
|
|
||||||
|
|
||||||
|
def test_access_token_roundtrip() -> None:
|
||||||
|
settings = JWTSettings(secret_key="unit-test-secret")
|
||||||
|
|
||||||
|
token = create_access_token(
|
||||||
|
"user-id-123",
|
||||||
|
settings,
|
||||||
|
scopes=("read", "write"),
|
||||||
|
extra_claims={"custom": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = decode_access_token(token, settings)
|
||||||
|
|
||||||
|
assert payload.sub == "user-id-123"
|
||||||
|
assert payload.type == "access"
|
||||||
|
assert payload.scopes == ["read", "write"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_token_type_mismatch() -> None:
|
||||||
|
settings = JWTSettings(secret_key="unit-test-secret")
|
||||||
|
token = create_refresh_token("user-id-456", settings)
|
||||||
|
|
||||||
|
with pytest.raises(TokenTypeMismatchError):
|
||||||
|
decode_access_token(token, settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_expired_token() -> None:
|
||||||
|
settings = JWTSettings(secret_key="unit-test-secret")
|
||||||
|
expired_token = create_access_token(
|
||||||
|
"user-id-789",
|
||||||
|
settings,
|
||||||
|
expires_delta=timedelta(seconds=-5),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TokenExpiredError):
|
||||||
|
decode_access_token(expired_token, settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_tampered_token() -> None:
|
||||||
|
settings = JWTSettings(secret_key="unit-test-secret")
|
||||||
|
token = create_access_token("user-id-321", settings)
|
||||||
|
tampered = token[:-1] + ("a" if token[-1] != "a" else "b")
|
||||||
|
|
||||||
|
with pytest.raises(TokenDecodeError):
|
||||||
|
decode_access_token(tampered, settings)
|
||||||
Reference in New Issue
Block a user