diff --git a/frontend/app/main.py b/frontend/app/main.py index 3c21f69..a5c380e 100644 --- a/frontend/app/main.py +++ b/frontend/app/main.py @@ -35,6 +35,60 @@ def _api(method: str, path: str, *, token: str | None = None, **kwargs): return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs) +def _model_matches_modality(model: dict, modality: str) -> bool: + """Heuristic fallback when backend modality filter returns empty.""" + model_modality = (model.get("modality") or "").lower() + if model_modality == modality: + return True + + text = f"{model.get('id', '')} {model.get('name', '')}".lower() + keywords = { + "image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"], + "video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"], + "audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"], + } + + if modality in keywords: + return any(k in text for k in keywords[modality]) + + if modality == "text": + non_text_hits = any( + k in text for k in keywords["image"] + keywords["video"] + keywords["audio"]) + return not non_text_hits + + return False + + +def _load_models(token: str, modality: str) -> list[dict]: + """Load models for modality; fallback to unfiltered cache if needed.""" + try: + models_resp = _api("GET", "/models/", token=token, + params={"modality": modality}) + except httpx.RequestError: + return [] + if models_resp.status_code == 200: + try: + models = models_resp.json() + except ValueError: + models = [] + if models: + return models + + try: + all_resp = _api("GET", "/models/", token=token) + except httpx.RequestError: + return [] + if all_resp.status_code != 200: + return [] + + try: + all_models = all_resp.json() + except ValueError: + return [] + filtered = [m for m in all_models if _model_matches_modality(m, modality)] + return filtered or all_models + + def login_required(view): @functools.wraps(view) def wrapped(*args, **kwargs): @@ -163,9 +217,7 @@ def generate_text(): result = resp.json() else: error = resp.json().get("detail", "Generation failed.") - models_resp = _api("GET", "/models/", token=token, - params={"modality": "text"}) - models = models_resp.json() if models_resp.status_code == 200 else [] + models = _load_models(token, "text") return render_template("generate_text.html", result=result, error=error, models=models) @@ -186,9 +238,7 @@ def generate_image(): ) if up_resp.status_code not in (200, 201): error = up_resp.json().get("detail", "Image upload failed.") - models_resp = _api("GET", "/models/", - token=token, params={"modality": "image"}) - models = models_resp.json() if models_resp.status_code == 200 else [] + models = _load_models(token, "image") return render_template("generate_image.html", result=result, error=error, models=models) resp = _api("POST", "/generate/image", token=token, json={ @@ -203,9 +253,7 @@ def generate_image(): result = resp.json() else: error = resp.json().get("detail", "Generation failed.") - models_resp = _api("GET", "/models/", token=token, - params={"modality": "image"}) - models = models_resp.json() if models_resp.status_code == 200 else [] + models = _load_models(token, "image") return render_template("generate_image.html", result=result, error=error, models=models) @@ -241,9 +289,7 @@ def generate_video(): result = resp.json() else: error = resp.json().get("detail", "Generation failed.") - models_resp = _api("GET", "/models/", token=token, - params={"modality": "video"}) - models = models_resp.json() if models_resp.status_code == 200 else [] + models = _load_models(token, "video") return render_template("generate_video.html", result=result, error=error, models=models)