Add model loading functionality with modality filtering and refactor generation routes to utilize it

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 14:04:29 +02:00
parent 96d13fc440
commit acc6991341
+58 -12
View File
@@ -35,6 +35,60 @@ def _api(method: str, path: str, *, token: str | None = None, **kwargs):
return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs) return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs)
def _model_matches_modality(model: dict, modality: str) -> bool:
"""Heuristic fallback when backend modality filter returns empty."""
model_modality = (model.get("modality") or "").lower()
if model_modality == modality:
return True
text = f"{model.get('id', '')} {model.get('name', '')}".lower()
keywords = {
"image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"],
"video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"],
"audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"],
}
if modality in keywords:
return any(k in text for k in keywords[modality])
if modality == "text":
non_text_hits = any(
k in text for k in keywords["image"] + keywords["video"] + keywords["audio"])
return not non_text_hits
return False
def _load_models(token: str, modality: str) -> list[dict]:
"""Load models for modality; fallback to unfiltered cache if needed."""
try:
models_resp = _api("GET", "/models/", token=token,
params={"modality": modality})
except httpx.RequestError:
return []
if models_resp.status_code == 200:
try:
models = models_resp.json()
except ValueError:
models = []
if models:
return models
try:
all_resp = _api("GET", "/models/", token=token)
except httpx.RequestError:
return []
if all_resp.status_code != 200:
return []
try:
all_models = all_resp.json()
except ValueError:
return []
filtered = [m for m in all_models if _model_matches_modality(m, modality)]
return filtered or all_models
def login_required(view): def login_required(view):
@functools.wraps(view) @functools.wraps(view)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
@@ -163,9 +217,7 @@ def generate_text():
result = resp.json() result = resp.json()
else: else:
error = resp.json().get("detail", "Generation failed.") error = resp.json().get("detail", "Generation failed.")
models_resp = _api("GET", "/models/", token=token, models = _load_models(token, "text")
params={"modality": "text"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_text.html", result=result, error=error, models=models) return render_template("generate_text.html", result=result, error=error, models=models)
@@ -186,9 +238,7 @@ def generate_image():
) )
if up_resp.status_code not in (200, 201): if up_resp.status_code not in (200, 201):
error = up_resp.json().get("detail", "Image upload failed.") error = up_resp.json().get("detail", "Image upload failed.")
models_resp = _api("GET", "/models/", models = _load_models(token, "image")
token=token, params={"modality": "image"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_image.html", result=result, error=error, models=models) return render_template("generate_image.html", result=result, error=error, models=models)
resp = _api("POST", "/generate/image", token=token, json={ resp = _api("POST", "/generate/image", token=token, json={
@@ -203,9 +253,7 @@ def generate_image():
result = resp.json() result = resp.json()
else: else:
error = resp.json().get("detail", "Generation failed.") error = resp.json().get("detail", "Generation failed.")
models_resp = _api("GET", "/models/", token=token, models = _load_models(token, "image")
params={"modality": "image"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_image.html", result=result, error=error, models=models) return render_template("generate_image.html", result=result, error=error, models=models)
@@ -241,9 +289,7 @@ def generate_video():
result = resp.json() result = resp.json()
else: else:
error = resp.json().get("detail", "Generation failed.") error = resp.json().get("detail", "Generation failed.")
models_resp = _api("GET", "/models/", token=token, models = _load_models(token, "video")
params={"modality": "video"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_video.html", result=result, error=error, models=models) return render_template("generate_video.html", result=result, error=error, models=models)