from __future__ import annotations

from collections.abc import Callable, Generator, Iterable, Iterator
from contextlib import contextmanager

from fastapi import Depends, HTTPException, Request, status

from config.settings import Settings, get_settings
from models import Project, Role, Scenario, User
from services.authorization import (
    ensure_project_access as ensure_project_access_helper,
    ensure_scenario_access as ensure_scenario_access_helper,
    ensure_scenario_in_project as ensure_scenario_in_project_helper,
)
from services.exceptions import AuthorizationError, EntityNotFoundError
from services.security import JWTSettings
from services.session import (
    AuthSession,
    SessionStrategy,
    SessionTokens,
    build_session_strategy,
    extract_session_tokens,
)
from services.unit_of_work import UnitOfWork


def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
    """FastAPI dependency yielding a unit-of-work instance.

    The ``UnitOfWork`` context manager is entered before the instance is
    handed to the endpoint and exited when FastAPI finalizes this
    generator dependency.
    """
    with UnitOfWork() as uow:
        yield uow


def get_application_settings() -> Settings:
    """Return the application settings provided by ``get_settings()``.

    NOTE(review): the original docstring describes this instance as cached;
    that depends on ``get_settings`` itself — confirm there.
    """
    return get_settings()


def get_jwt_settings() -> JWTSettings:
    """Provide JWT runtime configuration derived from ``get_settings()``."""
    return get_settings().jwt_settings()


def get_session_strategy(
    settings: Settings = Depends(get_application_settings),
) -> SessionStrategy:
    """Build the session transport strategy from the session settings."""
    return build_session_strategy(settings.session_settings())


def _session_from_tokens(tokens: SessionTokens) -> AuthSession:
    """Build the request's session context from raw ``tokens``.

    Empty token sets map to ``AuthSession.anonymous()``; non-empty ones are
    wrapped in a new ``AuthSession``. Both session dependencies below build
    their session through this helper so that the object stored on
    ``request.state.auth_session`` is the same whichever of them runs first.
    """
    if tokens.is_empty:
        return AuthSession.anonymous()
    return AuthSession(tokens=tokens)


def get_session_tokens(
    request: Request,
    strategy: SessionStrategy = Depends(get_session_strategy),
) -> SessionTokens:
    """Extract raw session tokens from the incoming request.

    If an ``AuthSession`` is already stored on ``request.state.auth_session``
    its tokens are reused. Otherwise the tokens are extracted with
    ``strategy`` and a session built from them is stored on the request.
    """
    existing = getattr(request.state, "auth_session", None)
    if isinstance(existing, AuthSession):
        return existing.tokens

    tokens = extract_session_tokens(request, strategy)
    # get_auth_session depends on this function and returns whatever session
    # is stored here, so the stored session must already be the anonymous one
    # for empty tokens — otherwise the anonymous branch could never apply.
    request.state.auth_session = _session_from_tokens(tokens)
    return tokens


def get_auth_session(
    request: Request,
    tokens: SessionTokens = Depends(get_session_tokens),
) -> AuthSession:
    """Provide authentication session context for the current request.

    Returns the ``AuthSession`` already stored on ``request.state`` when one
    exists; otherwise builds one from ``tokens`` (anonymous when the tokens
    are empty), stores it on the request and returns it.
    """
    existing = getattr(request.state, "auth_session", None)
    if isinstance(existing, AuthSession):
        return existing

    session = _session_from_tokens(tokens)
    request.state.auth_session = session
    return session


def get_current_user(
    session: AuthSession = Depends(get_auth_session),
) -> User | None:
    """Return the current authenticated user if present, else ``None``."""
    return session.user


def require_current_user(
    session: AuthSession = Depends(get_auth_session),
) -> User:
    """Ensure that a request is authenticated and return the user context.

    Raises:
        HTTPException: 401 when the session has no user or no tokens.
    """
    if session.user is None or session.tokens.is_empty:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
        )
    return session.user


def require_authenticated_user(
    user: User = Depends(require_current_user),
) -> User:
    """Ensure the current user account is active.

    Raises:
        HTTPException: 403 when ``user.is_active`` is falsy.
    """
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled.",
        )
    return user


def _user_role_names(user: User) -> set[str]:
    """Return the names of ``user.roles`` (empty when missing or ``None``)."""
    roles: Iterable[Role] = getattr(user, "roles", []) or []
    return {role.name for role in roles}


def require_roles(*roles: str) -> Callable[[User], User]:
    """Dependency factory enforcing membership in one of the given roles.

    Role names are whitespace-stripped and blank names are dropped.
    Superusers bypass the role check.

    Raises:
        ValueError: at factory time, when no non-blank role name is given.
        HTTPException: 403 from the returned dependency when the user holds
            none of the required roles.
    """
    required = tuple(role.strip() for role in roles if role.strip())
    if not required:
        raise ValueError("require_roles requires at least one role name")

    def _dependency(user: User = Depends(require_authenticated_user)) -> User:
        if user.is_superuser:
            return user
        # Passes when at least one required role is among the user's roles.
        if _user_role_names(user).isdisjoint(required):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions for this action.",
            )
        return user

    return _dependency


def require_any_role(*roles: str) -> Callable[[User], User]:
    """Alias of ``require_roles`` for readability in some contexts."""
    return require_roles(*roles)


@contextmanager
def _service_errors_as_http() -> Iterator[None]:
    """Translate service-layer errors raised in the body into HTTP errors.

    ``EntityNotFoundError`` becomes a 404 and ``AuthorizationError`` a 403.
    The exception message is used as the response detail, and the original
    exception is chained as the cause.
    """
    try:
        yield
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except AuthorizationError as exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(exc),
        ) from exc


def require_project_resource(
    *, require_manage: bool = False
) -> Callable[[int], Project]:
    """Dependency factory that resolves a project with authorization checks.

    Args:
        require_manage: forwarded to ``ensure_project_access``.

    The returned dependency raises ``HTTPException`` 404/403 when the service
    raises ``EntityNotFoundError``/``AuthorizationError``.
    """

    def _dependency(
        project_id: int,
        user: User = Depends(require_authenticated_user),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Project:
        with _service_errors_as_http():
            return ensure_project_access_helper(
                uow,
                project_id=project_id,
                user=user,
                require_manage=require_manage,
            )

    return _dependency


def require_scenario_resource(
    *, require_manage: bool = False, with_children: bool = False
) -> Callable[[int], Scenario]:
    """Dependency factory that resolves a scenario with authorization checks.

    Args:
        require_manage: forwarded to ``ensure_scenario_access``.
        with_children: forwarded to ``ensure_scenario_access``.

    The returned dependency raises ``HTTPException`` 404/403 when the service
    raises ``EntityNotFoundError``/``AuthorizationError``.
    """

    def _dependency(
        scenario_id: int,
        user: User = Depends(require_authenticated_user),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Scenario:
        with _service_errors_as_http():
            return ensure_scenario_access_helper(
                uow,
                scenario_id=scenario_id,
                user=user,
                require_manage=require_manage,
                with_children=with_children,
            )

    return _dependency


def require_project_scenario_resource(
    *, require_manage: bool = False, with_children: bool = False
) -> Callable[[int, int], Scenario]:
    """Dependency factory ensuring a scenario belongs to the given project and is accessible.

    Args:
        require_manage: forwarded to ``ensure_scenario_in_project``.
        with_children: forwarded to ``ensure_scenario_in_project``.

    The returned dependency raises ``HTTPException`` 404/403 when the service
    raises ``EntityNotFoundError``/``AuthorizationError``.
    """

    def _dependency(
        project_id: int,
        scenario_id: int,
        user: User = Depends(require_authenticated_user),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Scenario:
        with _service_errors_as_http():
            return ensure_scenario_in_project_helper(
                uow,
                project_id=project_id,
                scenario_id=scenario_id,
                user=user,
                require_manage=require_manage,
                with_children=with_children,
            )

    return _dependency