from __future__ import annotations

from collections.abc import Callable, Generator, Iterable, Iterator
from contextlib import contextmanager

from fastapi import Depends, HTTPException, Request, status

from config.settings import Settings, get_settings
from models import Project, Role, Scenario, User
from services.authorization import (
    ensure_project_access as ensure_project_access_helper,
    ensure_scenario_access as ensure_scenario_access_helper,
    ensure_scenario_in_project as ensure_scenario_in_project_helper,
)
from services.exceptions import AuthorizationError, EntityNotFoundError
from services.security import JWTSettings
from services.session import (
    AuthSession,
    SessionStrategy,
    SessionTokens,
    build_session_strategy,
    extract_session_tokens,
)
from services.unit_of_work import UnitOfWork
from services.importers import ImportIngestionService
from services.pricing import PricingMetadata
from services.navigation import NavigationService
from services.scenario_evaluation import ScenarioPricingConfig, ScenarioPricingEvaluator
from services.repositories import pricing_settings_to_metadata


# ---------------------------------------------------------------------------
# Infrastructure / service providers
# ---------------------------------------------------------------------------


def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
    """FastAPI dependency yielding a unit-of-work instance.

    The unit of work is used as a context manager. Its context is exited when
    FastAPI finalises this generator dependency.
    """
    with UnitOfWork() as uow:
        yield uow


# Module-level singleton. The service receives a factory rather than an
# instance. The lambda looks up ``UnitOfWork`` at call time, so every
# invocation builds a new unit of work.
_IMPORT_INGESTION_SERVICE = ImportIngestionService(lambda: UnitOfWork())


def get_import_ingestion_service() -> ImportIngestionService:
    """Provide the module-level singleton import ingestion service."""
    return _IMPORT_INGESTION_SERVICE


def get_application_settings() -> Settings:
    """Provide the application settings instance returned by ``get_settings``."""
    return get_settings()


def get_pricing_metadata(
    settings: Settings = Depends(get_application_settings),
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> PricingMetadata:
    """Return pricing metadata defaults sourced from persisted pricing settings.

    If nothing is stored yet, the defaults from ``settings.pricing_metadata()``
    seed the stored pricing settings, and the seeded record is converted back
    into ``PricingMetadata`` and returned.
    """
    stored = uow.get_pricing_metadata()
    if stored is not None:
        return stored
    fallback = settings.pricing_metadata()
    seed_result = uow.ensure_default_pricing_settings(metadata=fallback)
    return pricing_settings_to_metadata(seed_result.settings)


def get_navigation_service(
    uow: UnitOfWork = Depends(get_unit_of_work),
) -> NavigationService:
    """Provide a navigation service bound to the request's unit of work.

    Raises:
        RuntimeError: If the unit of work exposes no navigation repository.
    """
    navigation = uow.navigation
    # Explicit ``is None``: a repository object defining ``__len__`` or
    # ``__bool__`` must not be mistaken for a missing one.
    if navigation is None:
        raise RuntimeError("Navigation repository is not initialised")
    return NavigationService(navigation)


def get_pricing_evaluator(
    metadata: PricingMetadata = Depends(get_pricing_metadata),
) -> ScenarioPricingEvaluator:
    """Provide a scenario pricing evaluator configured with ``metadata``."""
    return ScenarioPricingEvaluator(ScenarioPricingConfig(metadata=metadata))


def get_jwt_settings() -> JWTSettings:
    """Provide JWT runtime configuration derived from settings."""
    return get_settings().jwt_settings()


# ---------------------------------------------------------------------------
# Session / authentication
# ---------------------------------------------------------------------------


def get_session_strategy(
    settings: Settings = Depends(get_application_settings),
) -> SessionStrategy:
    """Return the session transport strategy built from ``settings``."""
    return build_session_strategy(settings.session_settings())


def _session_from_tokens(tokens: SessionTokens) -> AuthSession:
    """Build the per-request session context for freshly extracted tokens.

    Requests that carry no tokens get ``AuthSession.anonymous()``.
    """
    if tokens.is_empty:
        return AuthSession.anonymous()
    return AuthSession(tokens=tokens)


def get_session_tokens(
    request: Request,
    strategy: SessionStrategy = Depends(get_session_strategy),
) -> SessionTokens:
    """Extract raw session tokens from the incoming request.

    If an ``AuthSession`` is already cached on ``request.state.auth_session``,
    its tokens are returned. Otherwise tokens are extracted with ``strategy``
    and a session built from them is cached on ``request.state``.
    """
    existing = getattr(request.state, "auth_session", None)
    if isinstance(existing, AuthSession):
        return existing.tokens
    tokens = extract_session_tokens(request, strategy)
    # Build the cached session exactly as get_auth_session would. That
    # dependency runs after this one and will return this cached instance,
    # so anonymous requests must already be represented correctly here.
    request.state.auth_session = _session_from_tokens(tokens)
    return tokens


def get_auth_session(
    request: Request,
    tokens: SessionTokens = Depends(get_session_tokens),
) -> AuthSession:
    """Provide authentication session context for the current request.

    Returns the ``AuthSession`` cached on ``request.state.auth_session`` when
    present. Otherwise a session is built from ``tokens`` (anonymous when the
    tokens are empty) and cached there.
    """
    existing = getattr(request.state, "auth_session", None)
    if isinstance(existing, AuthSession):
        return existing
    session = _session_from_tokens(tokens)
    request.state.auth_session = session
    return session


def get_current_user(
    session: AuthSession = Depends(get_auth_session),
) -> User | None:
    """Return the session's user, or ``None`` when there is none.

    No token or account-status checks are applied here; use the
    ``require_*`` dependencies for enforcement.
    """
    return session.user


def _login_redirect(request: Request) -> HTTPException:
    """Build a 303 ``HTTPException`` pointing at the ``auth.login_form`` route."""
    login_url = str(request.url_for("auth.login_form"))
    return HTTPException(
        status_code=status.HTTP_303_SEE_OTHER,
        headers={"Location": login_url},
    )


def _ensure_active(user: User) -> User:
    """Return ``user`` unchanged, or raise 403 if the account is disabled."""
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled.",
        )
    return user


def require_current_user(
    session: AuthSession = Depends(get_auth_session),
) -> User:
    """Ensure that a request is authenticated and return the user context.

    Raises:
        HTTPException: 401 when the session has no user or empty tokens.
    """
    if session.user is None or session.tokens.is_empty:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
        )
    return session.user


def require_authenticated_user(
    user: User = Depends(require_current_user),
) -> User:
    """Ensure the current user is authenticated and the account is active.

    Raises:
        HTTPException: 401 (via ``require_current_user``) when unauthenticated,
            403 when ``user.is_active`` is false.
    """
    return _ensure_active(user)


def require_authenticated_user_html(
    request: Request,
    session: AuthSession = Depends(get_auth_session),
) -> User:
    """HTML-aware authenticated dependency that redirects anonymous sessions.

    Raises:
        HTTPException: 303 redirect to the login form when the session has no
            user or empty tokens, 403 when the account is disabled.
    """
    user = session.user
    if user is None or session.tokens.is_empty:
        raise _login_redirect(request)
    return _ensure_active(user)


# ---------------------------------------------------------------------------
# Role checks
# ---------------------------------------------------------------------------


def _user_role_names(user: User) -> set[str]:
    """Return the names of ``user.roles``; a missing or ``None`` attribute counts as no roles."""
    roles: Iterable[Role] = getattr(user, "roles", []) or []
    return {role.name for role in roles}


def _required_roles(roles: Iterable[str], factory_name: str) -> tuple[str, ...]:
    """Strip role names, drop blanks, and reject an empty result.

    Raises:
        ValueError: If no non-blank role name remains.
    """
    required = tuple(role.strip() for role in roles if role.strip())
    if not required:
        raise ValueError(f"{factory_name} requires at least one role name")
    return required


def _ensure_roles(user: User, required: tuple[str, ...]) -> User:
    """Return ``user`` if a superuser or holding any ``required`` role, else raise 403."""
    if user.is_superuser:
        return user
    if _user_role_names(user).isdisjoint(required):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions for this action.",
        )
    return user


def require_roles(*roles: str) -> Callable[[User], User]:
    """Dependency factory enforcing membership in one of the given roles.

    The returned dependency builds on ``require_authenticated_user``, so
    unauthenticated (401) and disabled (403) users are rejected first.
    Superusers bypass the role check.

    Raises:
        ValueError: At factory time, if no non-blank role name is given.
    """
    required = _required_roles(roles, "require_roles")

    def _dependency(user: User = Depends(require_authenticated_user)) -> User:
        return _ensure_roles(user, required)

    return _dependency


def require_any_role(*roles: str) -> Callable[[User], User]:
    """Alias of require_roles for readability in some contexts."""
    return require_roles(*roles)


def require_roles_html(*roles: str) -> Callable[..., User]:
    """HTML-aware role check; redirects anonymous sessions to the login form.

    The returned dependency builds on ``require_authenticated_user_html``. It
    therefore applies the same rules as the HTML login guard: a 303 redirect
    when there is no user or the tokens are empty, and a 403 for disabled
    accounts. Only then are roles checked. Superusers bypass the role check.

    Raises:
        ValueError: At factory time, if no non-blank role name is given.
    """
    required = _required_roles(roles, "require_roles_html")

    def _dependency(
        user: User = Depends(require_authenticated_user_html),
    ) -> User:
        return _ensure_roles(user, required)

    return _dependency


def require_any_role_html(*roles: str) -> Callable[..., User]:
    """Alias of require_roles_html for readability."""
    return require_roles_html(*roles)


# ---------------------------------------------------------------------------
# Resource loaders
# ---------------------------------------------------------------------------


@contextmanager
def _domain_errors_as_http() -> Iterator[None]:
    """Translate service-layer errors raised in the body into HTTP errors.

    ``EntityNotFoundError`` becomes 404 and ``AuthorizationError`` becomes 403.
    The exception message is used as the detail and the original is chained.
    """
    try:
        yield
    except EntityNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except AuthorizationError as exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(exc),
        ) from exc


def require_project_resource(
    *,
    require_manage: bool = False,
    user_dependency: Callable[..., User] = require_authenticated_user,
) -> Callable[[int], Project]:
    """Dependency factory that resolves a project with authorization checks.

    Args:
        require_manage: Forwarded to ``ensure_project_access``.
        user_dependency: Dependency used to resolve the acting user.

    Returns:
        A dependency taking a ``project_id`` request parameter and returning
        the project. It raises ``HTTPException`` 404 when the project is not
        found and 403 when access is denied.
    """

    def _dependency(
        project_id: int,
        user: User = Depends(user_dependency),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Project:
        with _domain_errors_as_http():
            return ensure_project_access_helper(
                uow,
                project_id=project_id,
                user=user,
                require_manage=require_manage,
            )

    return _dependency


def require_scenario_resource(
    *,
    require_manage: bool = False,
    with_children: bool = False,
    user_dependency: Callable[..., User] = require_authenticated_user,
) -> Callable[[int], Scenario]:
    """Dependency factory that resolves a scenario with authorization checks.

    Args:
        require_manage: Forwarded to ``ensure_scenario_access``.
        with_children: Forwarded to ``ensure_scenario_access``.
        user_dependency: Dependency used to resolve the acting user.

    Returns:
        A dependency taking a ``scenario_id`` request parameter and returning
        the scenario. It raises ``HTTPException`` 404 when the scenario is not
        found and 403 when access is denied.
    """

    def _dependency(
        scenario_id: int,
        user: User = Depends(user_dependency),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Scenario:
        with _domain_errors_as_http():
            return ensure_scenario_access_helper(
                uow,
                scenario_id=scenario_id,
                user=user,
                require_manage=require_manage,
                with_children=with_children,
            )

    return _dependency


def require_project_scenario_resource(
    *,
    require_manage: bool = False,
    with_children: bool = False,
    user_dependency: Callable[..., User] = require_authenticated_user,
) -> Callable[[int, int], Scenario]:
    """Dependency factory ensuring a scenario belongs to the given project and is accessible.

    Args:
        require_manage: Forwarded to ``ensure_scenario_in_project``.
        with_children: Forwarded to ``ensure_scenario_in_project``.
        user_dependency: Dependency used to resolve the acting user.

    Returns:
        A dependency taking ``project_id`` and ``scenario_id`` request
        parameters and returning the scenario. It raises ``HTTPException`` 404
        for ``EntityNotFoundError`` and 403 for ``AuthorizationError``.
    """

    def _dependency(
        project_id: int,
        scenario_id: int,
        user: User = Depends(user_dependency),
        uow: UnitOfWork = Depends(get_unit_of_work),
    ) -> Scenario:
        with _domain_errors_as_http():
            return ensure_scenario_in_project_helper(
                uow,
                project_id=project_id,
                scenario_id=scenario_id,
                user=user,
                require_manage=require_manage,
                with_children=with_children,
            )

    return _dependency


def require_project_resource_html(
    *, require_manage: bool = False
) -> Callable[[int], Project]:
    """HTML-aware project loader that redirects anonymous sessions."""
    return require_project_resource(
        require_manage=require_manage,
        user_dependency=require_authenticated_user_html,
    )


def require_scenario_resource_html(
    *,
    require_manage: bool = False,
    with_children: bool = False,
) -> Callable[[int], Scenario]:
    """HTML-aware scenario loader that redirects anonymous sessions."""
    return require_scenario_resource(
        require_manage=require_manage,
        with_children=with_children,
        user_dependency=require_authenticated_user_html,
    )


def require_project_scenario_resource_html(
    *,
    require_manage: bool = False,
    with_children: bool = False,
) -> Callable[[int, int], Scenario]:
    """HTML-aware project-scenario loader redirecting anonymous sessions."""
    return require_project_scenario_resource(
        require_manage=require_manage,
        with_children=with_children,
        user_dependency=require_authenticated_user_html,
    )