485 lines
14 KiB
Python
485 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Iterable
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
from pydantic import ValidationError
|
|
from starlette.datastructures import FormData
|
|
|
|
from dependencies import (
|
|
get_auth_session,
|
|
get_jwt_settings,
|
|
get_session_strategy,
|
|
get_unit_of_work,
|
|
require_current_user,
|
|
)
|
|
from models import Role, User
|
|
from schemas.auth import (
|
|
LoginForm,
|
|
PasswordResetForm,
|
|
PasswordResetRequestForm,
|
|
RegistrationForm,
|
|
)
|
|
from services.exceptions import EntityConflictError
|
|
from services.security import (
|
|
JWTSettings,
|
|
TokenDecodeError,
|
|
TokenExpiredError,
|
|
TokenTypeMismatchError,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_access_token,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from services.session import (
|
|
AuthSession,
|
|
SessionStrategy,
|
|
clear_session_cookies,
|
|
set_session_cookies,
|
|
)
|
|
from services.repositories import RoleRepository, UserRepository
|
|
from services.unit_of_work import UnitOfWork
|
|
|
|
# Router for every authentication page (login, logout, registration and
# password reset). Each route opts out of the OpenAPI schema individually.
router = APIRouter(tags=["Authentication"])

# Jinja2 loader used by ``_template`` for all HTML pages in this module.
templates = Jinja2Templates(directory="templates")

# JWT scope names. Login issues tokens with ``_AUTH_SCOPE``; the reset
# handlers reject any token that lacks ``_PASSWORD_RESET_SCOPE``.
_PASSWORD_RESET_SCOPE: str = "password-reset"
_AUTH_SCOPE: str = "auth"
|
|
|
|
|
|
def _template(
    request: Request,
    template_name: str,
    context: dict[str, Any],
    *,
    status_code: int = status.HTTP_200_OK,
) -> HTMLResponse:
    """Render ``template_name`` with the module-level ``templates`` loader.

    Args:
        request: The incoming request, forwarded to ``TemplateResponse``.
        template_name: Name of the template file to render.
        context: Variables exposed to the template.
        status_code: HTTP status of the response (defaults to 200).

    Returns:
        The rendered HTML response.
    """
    rendered = templates.TemplateResponse(
        request, template_name, context, status_code=status_code
    )
    return rendered
|
|
|
|
|
|
def _validation_errors(exc: ValidationError) -> list[str]:
    """Turn a pydantic ``ValidationError`` into a list of display messages.

    Each error's ``msg`` entry is used as-is; an entry without one is
    reported as the generic ``"Invalid input."``.
    """
    fallback = "Invalid input."
    return [detail.get("msg", fallback) for detail in exc.errors()]
|
|
|
|
|
|
def _scopes(include: Iterable[str]) -> list[str]:
    """Collect the given scope names into a new list.

    The result is always a fresh list, even when ``include`` already is one.
    """
    return [*include]
|
|
|
|
|
|
def _normalise_form_data(form_data: FormData) -> dict[str, str]:
    """Flatten submitted form data into a plain ``str -> str`` mapping.

    If a key appears more than once, the last value wins. Text fields are
    kept unchanged. File uploads are reduced to their filename, or ``""``
    when the upload has none.

    Uploads are detected by their ``filename`` attribute rather than by
    ``isinstance(value, fastapi.UploadFile)``. Starlette's form parser
    produces ``starlette.datastructures.UploadFile`` objects, and FastAPI's
    ``UploadFile`` is a *subclass* of that type. The subclass check therefore
    never matched, and uploads were stored as their ``repr``.
    """
    normalised: dict[str, str] = {}
    for key, value in form_data.multi_items():
        if isinstance(value, str):
            normalised[key] = value
        elif hasattr(value, "filename"):
            normalised[key] = value.filename or ""
        else:
            # Defensive: form values should be str or uploads, but never
            # drop an unexpected value silently.
            normalised[key] = str(value)
    return normalised
|
|
|
|
|
|
def _require_users_repo(uow: UnitOfWork) -> UserRepository:
    """Return the user repository attached to ``uow``.

    Raises:
        RuntimeError: If ``uow.users`` has not been set (is ``None``).
    """
    users = uow.users
    # Identity check, not truthiness: a repository that defines
    # ``__len__``/``__bool__`` (e.g. an empty collection-like repo) must not
    # be mistaken for a missing one.
    if users is None:
        raise RuntimeError("User repository is not initialised")
    return users
|
|
|
|
|
|
def _require_roles_repo(uow: UnitOfWork) -> RoleRepository:
    """Return the role repository attached to ``uow``.

    Raises:
        RuntimeError: If ``uow.roles`` has not been set (is ``None``).
    """
    roles = uow.roles
    # Identity check, not truthiness: a repository that defines
    # ``__len__``/``__bool__`` must not be mistaken for a missing one.
    if roles is None:
        raise RuntimeError("Role repository is not initialised")
    return roles
|
|
|
|
|
|
@router.get(
    "/login",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="auth.login_form",
)
def login_form(request: Request) -> HTMLResponse:
    """Render the empty login page (no errors, blank username)."""
    page_context: dict[str, Any] = dict(
        form_action=request.url_for("auth.login_submit"),
        errors=[],
        username="",
    )
    return _template(request, "login.html", page_context)
|
|
|
|
|
|
@router.post("/login", include_in_schema=False, name="auth.login_submit")
async def login_submit(
    request: Request,
    uow: UnitOfWork = Depends(get_unit_of_work),
    jwt_settings: JWTSettings = Depends(get_jwt_settings),
    session_strategy: SessionStrategy = Depends(get_session_strategy),
):
    """Authenticate a submitted login form and start a session.

    On success:
    - the user's ``last_login_at`` is set to the current UTC time;
    - access and refresh tokens with the ``auth`` scope are issued;
    - session cookies are set, and the client is redirected (303) to
      ``dashboard.home``.

    A validation failure, bad credentials or an inactive account re-renders
    ``login.html`` with status 400.
    """
    form_data = _normalise_form_data(await request.form())
    try:
        form = LoginForm(**form_data)
    except ValidationError as exc:
        # Supply ``username`` here too, so the template receives the same keys
        # as on every other render of the login page.
        return _login_error_response(
            request, _validation_errors(exc), form_data.get("username", "")
        )

    identifier = form.username
    user = _lookup_user(_require_users_repo(uow), identifier)

    # Unknown user and wrong password share a single message so that the
    # response text does not reveal which identifiers exist.
    # NOTE(review): verify_password is skipped for unknown users, so response
    # timing can still differ between the two cases.
    if user is None or not verify_password(form.password, user.password_hash):
        return _login_error_response(
            request, ["Invalid username or password."], identifier
        )
    # Checked only after the password matched, so the inactive-account message
    # is never shown to someone without valid credentials.
    if not user.is_active:
        return _login_error_response(
            request, ["Account is inactive. Contact an administrator."], identifier
        )

    user.last_login_at = datetime.now(timezone.utc)

    subject = str(user.id)
    access_token = create_access_token(
        subject, jwt_settings, scopes=_scopes((_AUTH_SCOPE,))
    )
    refresh_token = create_refresh_token(
        subject, jwt_settings, scopes=_scopes((_AUTH_SCOPE,))
    )

    response = RedirectResponse(
        request.url_for("dashboard.home"),
        status_code=status.HTTP_303_SEE_OTHER,
    )
    set_session_cookies(
        response,
        access_token=access_token,
        refresh_token=refresh_token,
        strategy=session_strategy,
        jwt_settings=jwt_settings,
    )
    return response


def _login_error_response(
    request: Request, errors: list[str], username: str
) -> HTMLResponse:
    """Re-render the login page with ``errors`` and status 400.

    ``username`` is echoed back into the form field.
    """
    return _template(
        request,
        "login.html",
        {
            "form_action": request.url_for("auth.login_submit"),
            "errors": errors,
            "username": username,
        },
        status_code=status.HTTP_400_BAD_REQUEST,
    )
|
|
|
|
|
|
@router.get("/logout", include_in_schema=False, name="auth.logout")
async def logout(
    request: Request,
    _: User = Depends(require_current_user),
    session: AuthSession = Depends(get_auth_session),
    session_strategy: SessionStrategy = Depends(get_session_strategy),
) -> RedirectResponse:
    """End the current session and redirect to the login form.

    Requires an authenticated user (``require_current_user``).
    - The auth session is marked cleared.
    - The session cookies are removed from the response.
    - The client is redirected (303) to the login form with ``logout=1``.

    NOTE(review): logout is a GET route, so any cross-site link or image can
    trigger it; consider POST if that matters here.
    """
    session.mark_cleared()

    login_url = request.url_for("auth.login_form")
    target = login_url.include_query_params(logout="1")

    redirect = RedirectResponse(target, status_code=status.HTTP_303_SEE_OTHER)
    clear_session_cookies(redirect, session_strategy)
    return redirect
|
|
|
|
|
|
def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
    """Find a user by email or username, with roles eagerly requested.

    An identifier containing ``@`` is treated as an email address and
    lowercased before the lookup. Anything else is matched as a username
    exactly as given.
    """
    looks_like_email = "@" in identifier
    if not looks_like_email:
        return users_repo.get_by_username(identifier, with_roles=True)
    return users_repo.get_by_email(identifier.lower(), with_roles=True)
|
|
|
|
|
|
@router.get(
    "/register",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="auth.register_form",
)
def register_form(request: Request) -> HTMLResponse:
    """Render the empty registration page.

    ``form_data`` is ``None`` because nothing has been submitted yet.
    """
    action = request.url_for("auth.register_submit")
    return _template(
        request,
        "register.html",
        {"form_action": action, "errors": [], "form_data": None},
    )
|
|
|
|
|
|
@router.post("/register", include_in_schema=False, name="auth.register_submit")
async def register_submit(
    request: Request,
    uow: UnitOfWork = Depends(get_unit_of_work),
):
    """Create a new account from the submitted registration form.

    If validation fails or the email/username is already taken, the
    registration page is re-rendered with status 400. Otherwise an active,
    non-superuser account is created and the ``viewer`` role is granted
    (when that role exists). The client is then redirected (303) to the
    login form with ``registered=1``.
    """
    form_data = _normalise_form_data(await request.form())
    try:
        form = RegistrationForm(**form_data)
    except ValidationError as exc:
        return _registration_error_response(request, _validation_errors(exc))

    users_repo = _require_users_repo(uow)
    roles_repo = _require_roles_repo(uow)
    uow.ensure_default_roles()

    # Store and compare emails in lowercase. Login (``_lookup_user``)
    # lowercases any ``@`` identifier before querying, so a mixed-case stored
    # address could never be matched at login. This is a no-op if the form
    # schema already normalises the email.
    email = form.email.lower()

    errors: list[str] = []
    if users_repo.get_by_email(email):
        errors.append("Email is already registered.")
    if users_repo.get_by_username(form.username):
        errors.append("Username is already taken.")
    if errors:
        return _registration_error_response(request, errors, form)

    user = User(
        email=email,
        username=form.username,
        password_hash=hash_password(form.password),
        is_active=True,
        is_superuser=False,
    )

    try:
        created = users_repo.create(user)
    except EntityConflictError:
        # Presumably raised on a uniqueness conflict that slipped past the
        # checks above (e.g. a concurrent registration).
        return _registration_error_response(
            request,
            ["An account with this username or email already exists."],
            form,
        )

    viewer_role = _ensure_viewer_role(roles_repo)
    if viewer_role is not None:
        users_repo.assign_role(
            user_id=created.id,
            role_id=viewer_role.id,
            granted_by=created.id,
        )

    redirect_url = request.url_for("auth.login_form").include_query_params(
        registered="1"
    )
    return RedirectResponse(redirect_url, status_code=status.HTTP_303_SEE_OTHER)
|
|
|
|
|
|
def _registration_error_response(
    request: Request,
    errors: list[str],
    form: RegistrationForm | None = None,
) -> HTMLResponse:
    """Re-render the registration page with ``errors`` and status 400.

    When a parsed ``form`` is given, its fields are echoed back to the
    template, minus ``password`` and ``confirm_password``.
    """
    form_action = request.url_for("auth.register_submit")

    echoed = None
    if form:
        # Never send submitted passwords back into the page.
        echoed = form.model_dump(exclude={"password", "confirm_password"})

    return _template(
        request,
        "register.html",
        {"form_action": form_action, "errors": errors, "form_data": echoed},
        status_code=status.HTTP_400_BAD_REQUEST,
    )
|
|
|
|
|
|
def _ensure_viewer_role(
    roles_repo: RoleRepository, role_name: str = "viewer"
) -> Role | None:
    """Return the role called ``role_name`` (default ``"viewer"``), or ``None``.

    This only looks the role up; it never creates it. ``register_submit``
    calls ``uow.ensure_default_roles()`` before using this. The previous
    version repeated the identical query as a "fallback", which added a
    second database round trip without any chance of a different result.
    """
    return roles_repo.get_by_name(role_name)
|
|
|
|
|
|
@router.get(
    "/forgot-password",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="auth.password_reset_request_form",
)
def password_reset_request_form(request: Request) -> HTMLResponse:
    """Render the "forgot password" page with no errors and no message."""
    page_context: dict[str, Any] = dict(
        form_action=request.url_for("auth.password_reset_request_submit"),
        errors=[],
        message=None,
    )
    return _template(request, "forgot_password.html", page_context)
|
|
|
|
|
|
@router.post(
    "/forgot-password",
    include_in_schema=False,
    name="auth.password_reset_request_submit",
)
async def password_reset_request_submit(
    request: Request,
    uow: UnitOfWork = Depends(get_unit_of_work),
    jwt_settings: JWTSettings = Depends(get_jwt_settings),
):
    """Handle a "forgot password" submission.

    The response is the same whether or not the email belongs to an
    account: the page is re-rendered with a generic confirmation message.
    For a known account, a one-hour token with the password-reset scope is
    minted. The reset link is passed to ``_dispatch_password_reset_link``
    for out-of-band delivery.

    Security fix: the previous version redirected the *requester* straight
    to the reset URL, token included, whenever the email existed. That let
    anyone who knew an address set a new password for that account. The
    different response for known and unknown emails also revealed which
    addresses were registered.
    """
    form_data = _normalise_form_data(await request.form())
    try:
        form = PasswordResetRequestForm(**form_data)
    except ValidationError as exc:
        return _password_reset_request_response(
            request,
            errors=_validation_errors(exc),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    user = _require_users_repo(uow).get_by_email(form.email)
    if user:
        token = create_access_token(
            str(user.id),
            jwt_settings,
            scopes=_scopes((_PASSWORD_RESET_SCOPE,)),
            expires_delta=timedelta(hours=1),
        )
        reset_url = request.url_for("auth.password_reset_form").include_query_params(
            token=token
        )
        _dispatch_password_reset_link(user, str(reset_url))

    return _password_reset_request_response(
        request,
        message="If an account exists, a reset link has been sent.",
    )


def _password_reset_request_response(
    request: Request,
    *,
    errors: list[str] | None = None,
    message: str | None = None,
    status_code: int = status.HTTP_200_OK,
) -> HTMLResponse:
    """Render ``forgot_password.html`` with the given errors and/or message."""
    return _template(
        request,
        "forgot_password.html",
        {
            "form_action": request.url_for("auth.password_reset_request_submit"),
            "errors": errors or [],
            "message": message,
        },
        status_code=status_code,
    )


def _dispatch_password_reset_link(user: User, reset_url: str) -> None:
    """Deliver a password-reset link to ``user`` through a side channel.

    NOTE(review): this module has no visible mail/notification service, so
    this stop-gap only writes the link to the server log. That keeps reset
    usable by operators without ever exposing the token to the anonymous
    requester. The link is a bearer credential: replace this with real
    email delivery, and keep such logs restricted, before production use.
    """
    import logging

    logging.getLogger(__name__).info(
        "Password reset link for user id %s: %s", user.id, reset_url
    )
|
|
|
|
|
|
@router.get(
    "/reset-password",
    response_class=HTMLResponse,
    include_in_schema=False,
    name="auth.password_reset_form",
)
def password_reset_form(
    request: Request,
    token: str | None = None,
    jwt_settings: JWTSettings = Depends(get_jwt_settings),
) -> HTMLResponse:
    """Render the reset-password page for the ``token`` query parameter.

    The token is pre-checked: it must be present, decodable, unexpired, and
    carry the password-reset scope. Any problem is shown on the page with
    status 400; otherwise the page renders with status 200.
    """
    errors: list[str] = []
    if token:
        try:
            claims = decode_access_token(token, jwt_settings)
        except TokenExpiredError:
            errors.append("Token has expired. Please request a new password reset.")
        except (TokenDecodeError, TokenTypeMismatchError):
            errors.append("Invalid password reset token.")
        else:
            if _PASSWORD_RESET_SCOPE not in claims.scopes:
                errors.append("Invalid token scope.")
    else:
        errors.append("Missing password reset token.")

    status_code = status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK
    return _template(
        request,
        "reset_password.html",
        {
            "form_action": request.url_for("auth.password_reset_submit"),
            "token": token,
            "errors": errors,
        },
        status_code=status_code,
    )
|
|
|
|
|
|
@router.post(
    "/reset-password",
    include_in_schema=False,
    name="auth.password_reset_submit",
)
async def password_reset_submit(
    request: Request,
    uow: UnitOfWork = Depends(get_unit_of_work),
    jwt_settings: JWTSettings = Depends(get_jwt_settings),
):
    """Set a new password using a password-reset token.

    The form is validated, then the token is decoded; it must carry the
    password-reset scope and a numeric subject. Form or token problems
    re-render ``reset_password.html`` with status 400. An unknown user id
    raises a 404. On success the password is replaced, and the client is
    redirected (303) to the login form with ``reset=1``.
    """
    form_data = _normalise_form_data(await request.form())
    try:
        form = PasswordResetForm(**form_data)
    except ValidationError as exc:
        return _template(
            request,
            "reset_password.html",
            {
                "form_action": request.url_for("auth.password_reset_submit"),
                "token": form_data.get("token"),
                "errors": _validation_errors(exc),
            },
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    try:
        payload = decode_access_token(form.token, jwt_settings)
    except TokenExpiredError:
        return _reset_error_response(
            request,
            form.token,
            "Token has expired. Please request a new password reset.",
        )
    except (TokenDecodeError, TokenTypeMismatchError):
        return _reset_error_response(
            request, form.token, "Invalid password reset token."
        )

    if _PASSWORD_RESET_SCOPE not in payload.scopes:
        return _reset_error_response(
            request, form.token, "Invalid password reset token scope."
        )

    try:
        user_id = int(payload.sub)
    except (TypeError, ValueError):
        # A validly signed token with a missing or non-numeric subject is
        # still unusable; report it as a bad token rather than a 500.
        return _reset_error_response(
            request, form.token, "Invalid password reset token."
        )

    user = _require_users_repo(uow).get(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )

    user.set_password(form.password)
    # NOTE(review): this reactivates inactive accounts. login_submit tells
    # inactive users to "Contact an administrator", so if deactivation is an
    # admin decision, a password reset bypasses it. Left unchanged in case it
    # backs an invite/activation flow; confirm the intent.
    if not user.is_active:
        user.is_active = True
    # NOTE(review): nothing here invalidates the reset token. Unless
    # decode_access_token checks more than is visible, the token can be
    # replayed until it expires.

    redirect_url = request.url_for("auth.login_form").include_query_params(reset="1")
    return RedirectResponse(redirect_url, status_code=status.HTTP_303_SEE_OTHER)
|
|
|
|
|
|
def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse:
    """Re-render the reset-password page with one error and status 400.

    ``token`` is passed back to the template together with ``message``.
    """
    page_context = {
        "form_action": request.url_for("auth.password_reset_submit"),
        "token": token,
        "errors": [message],
    }
    return _template(
        request,
        "reset_password.html",
        page_context,
        status_code=status.HTTP_400_BAD_REQUEST,
    )
|