149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
"""Helper utilities for the frontend app."""
|
|
import functools
|
|
|
|
import httpx
|
|
from flask import redirect, session, url_for, flash
|
|
|
|
|
|
def _backend(path: str) -> str:
    """Return the absolute backend URL for ``path``.

    The base is read from ``current_app.config['BACKEND_URL']`` and ``path``
    is appended verbatim; no separator is inserted, so ``path`` is expected
    to carry its own leading slash (as the callers in this module do, e.g.
    ``"/models/"``).
    """
    from flask import current_app

    base_url = current_app.config["BACKEND_URL"]
    return f"{base_url}{path}"
|
|
|
|
|
|
def _api(method: str, path: str, *, token: str | None = None, **kwargs):
    """Send a request to the backend and return the ``httpx`` response.

    Args:
        method: HTTP verb, passed straight to ``httpx.request``.
        path: Backend path; resolved to a full URL by ``_backend``.
        token: Optional bearer token. When truthy it is sent as an
            ``Authorization: Bearer <token>`` header.
        **kwargs: Forwarded to ``httpx.request``. ``headers`` is merged with
            the authorization header, and ``timeout`` overrides the default
            of 30 seconds.

    Returns:
        Whatever ``httpx.request`` returns.

    Raises:
        Anything ``httpx.request`` raises; callers in this module catch
        ``httpx.RequestError``.
    """
    headers = kwargs.pop("headers", None)
    if token:
        # Build a fresh mapping so the caller's dict is never mutated;
        # ``or {}`` also tolerates an explicit ``headers=None``.
        headers = {**(headers or {}), "Authorization": f"Bearer {token}"}

    # Pop the timeout so a caller-supplied value replaces the default instead
    # of raising "got multiple values for keyword argument 'timeout'".
    timeout = kwargs.pop("timeout", 30)

    return httpx.request(
        method,
        _backend(path),
        headers=headers,
        timeout=timeout,
        **kwargs,
    )
|
|
|
|
|
|
def _model_matches_modality(model: dict, modality: str) -> bool:
    """Guess whether ``model`` belongs to ``modality``.

    Heuristic fallback for when the backend modality filter returns nothing.
    A model matches when its own ``modality`` field (lower-cased) equals
    ``modality``, or, failing that, when its ``id``/``name`` text contains
    one of the keywords associated with the requested modality. For
    ``"text"`` the test is inverted: the model matches when none of the
    image, video or audio keywords occur. Any other modality that is not
    declared by the model itself yields False.
    """
    declared = model.get("modality") or ""
    if declared.lower() == modality:
        return True

    # Searchable text: id and name joined by a space, case-folded.
    haystack = "{} {}".format(model.get("id", ""), model.get("name", "")).lower()

    image_terms = (
        "image", "dall-e", "flux", "stable-diffusion",
        "sdxl", "recraft", "ideogram", "gpt-image",
    )
    video_terms = ("video", "sora", "runway", "veo", "kling", "pika", "luma", "wan")
    audio_terms = ("audio", "speech", "voice", "tts", "transcribe", "whisper")
    terms_by_modality = {
        "image": image_terms,
        "video": video_terms,
        "audio": audio_terms,
    }

    def mentions(terms) -> bool:
        """Return True when any term is a substring of the searchable text."""
        return any(term in haystack for term in terms)

    if modality == "text":
        # Text is defined negatively: nothing hints at another modality.
        return not mentions(image_terms + video_terms + audio_terms)

    wanted = terms_by_modality.get(modality)
    if wanted is None:
        return False
    return mentions(wanted)
|
|
|
|
|
|
def _load_models(token: str, modality: str) -> list[dict]:
    """Fetch the models for ``modality``, falling back to local filtering.

    The backend is first asked for ``/models/`` filtered by ``modality``. If
    that answer is a 200 with a non-empty JSON body, the body is returned
    unchanged. Otherwise the unfiltered ``/models/`` listing is fetched and
    narrowed with ``_model_matches_modality``; when the heuristic matches
    nothing, the whole unfiltered listing is returned instead.

    Returns an empty list when a request raises ``httpx.RequestError``, when
    the unfiltered request is not a 200, or when its body is not valid JSON.
    """
    # Step 1: let the backend do the filtering.
    query = {"modality": modality}
    try:
        filtered_resp = _api("GET", "/models/", token=token, params=query)
    except httpx.RequestError:
        return []

    if filtered_resp.status_code == 200:
        try:
            backend_filtered = filtered_resp.json()
        except ValueError:
            # Unparseable body: treat it like an empty result and fall through.
            backend_filtered = []
        if backend_filtered:
            return backend_filtered

    # Step 2: fetch everything and filter on this side.
    try:
        catalog_resp = _api("GET", "/models/", token=token)
    except httpx.RequestError:
        return []
    if catalog_resp.status_code != 200:
        return []

    try:
        catalog = catalog_resp.json()
    except ValueError:
        return []

    matching = list(
        filter(lambda entry: _model_matches_modality(entry, modality), catalog)
    )
    # An empty heuristic result falls back to the full, unfiltered listing.
    return matching if matching else catalog
|
|
|
|
|
|
def login_required(view):
    """Decorate ``view`` so that it only runs for a logged-in session.

    The wrapped view redirects to ``auth.login`` when the session holds no
    ``access_token`` or when ``_validate_and_refresh()`` reports that no
    valid session remains; otherwise the original view is called with the
    same arguments.
    """

    @functools.wraps(view)
    def guarded(*args, **kwargs):
        # Short-circuit keeps the original order: validation (with its
        # auto-refresh on expiry) only runs when a token key is present.
        if "access_token" not in session or not _validate_and_refresh():
            return redirect(url_for("auth.login"))
        return view(*args, **kwargs)

    return guarded
|
|
|
|
|
|
def admin_required(view):
    """Decorate ``view`` so that it only runs for a logged-in admin.

    The wrapped view redirects to ``auth.login`` when the session holds no
    ``access_token`` or ``_validate_and_refresh()`` fails. When the session
    is valid but ``session["user_role"]`` is not ``"admin"``, it flashes an
    error message and redirects to ``dashboard.index``. Otherwise the
    original view is called with the same arguments.
    """

    @functools.wraps(view)
    def admin_only(*args, **kwargs):
        # ``and`` short-circuits, so validation only runs with a token present.
        authenticated = "access_token" in session and _validate_and_refresh()
        if not authenticated:
            return redirect(url_for("auth.login"))

        if session.get("user_role") == "admin":
            return view(*args, **kwargs)

        flash("Admin access required.", "error")
        return redirect(url_for("dashboard.index"))

    return admin_only
|
|
|
|
|
|
# ── Token validation & refresh ────────────────────────────────────────────
|
|
|
|
def _validate_access_token(token: str) -> bool:
    """Return True if the backend accepts ``token`` as still valid.

    Sends ``GET /auth/validate`` with the token as bearer credential and
    reports whether the answer was a 200.

    NOTE(review): an ``httpx.RequestError`` is also reported as False, so an
    unreachable backend is indistinguishable from a rejected token here.
    """
    try:
        verdict = _api("GET", "/auth/validate", token=token)
    except httpx.RequestError:
        return False
    return verdict.status_code == 200
|
|
|
|
|
|
def _try_refresh() -> bool:
    """Attempt to refresh an expired access token using the stored refresh token.

    Posts the session's ``refresh_token`` to ``/auth/refresh``. On success the
    session tokens are updated in place; on any failure the session is left
    untouched.

    Returns:
        True if a new access token was stored in the session. False when the
        session has no refresh token, the request raises
        ``httpx.RequestError``, the backend answers with a non-200 status, or
        the response body is not a JSON object carrying an ``access_token``.
    """
    refresh_token = session.get("refresh_token")
    if not refresh_token:
        return False

    try:
        resp = _api("POST", "/auth/refresh", json={"refresh_token": refresh_token})
    except httpx.RequestError:
        return False
    if resp.status_code != 200:
        return False

    # A 200 with a malformed body counts as a failed refresh rather than an
    # unhandled exception inside the view decorators.
    try:
        data = resp.json()
    except ValueError:
        return False
    if not isinstance(data, dict):
        return False

    access_token = data.get("access_token")
    if not access_token:
        return False

    # Validate everything before writing, so the session is never left
    # half-updated.
    session["access_token"] = access_token
    # If the backend does not issue a new refresh token, keep the current one.
    session["refresh_token"] = data.get("refresh_token") or refresh_token
    return True
|
|
|
|
|
|
def _validate_and_refresh() -> bool:
    """Ensure the session holds a usable access token, refreshing if needed.

    Returns False right away when the session has no (or an empty)
    ``access_token``. Otherwise the token is validated; if validation fails a
    refresh is attempted. When both fail the whole session is cleared.

    Returns True if a valid session exists after the check.
    """
    current = session.get("access_token")
    if not current:
        return False

    # ``or`` short-circuits: the refresh is only attempted when the current
    # token did not validate.
    if _validate_access_token(current) or _try_refresh():
        return True

    # Neither token could be used: drop everything stored in the session.
    session.clear()
    return False
|