Add model loading functionality with modality filtering and refactor generation routes to utilize it
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
+58
-12
@@ -35,6 +35,60 @@ def _api(method: str, path: str, *, token: str | None = None, **kwargs):
|
||||
return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs)
|
||||
|
||||
|
||||
def _model_matches_modality(model: dict, modality: str) -> bool:
|
||||
"""Heuristic fallback when backend modality filter returns empty."""
|
||||
model_modality = (model.get("modality") or "").lower()
|
||||
if model_modality == modality:
|
||||
return True
|
||||
|
||||
text = f"{model.get('id', '')} {model.get('name', '')}".lower()
|
||||
keywords = {
|
||||
"image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"],
|
||||
"video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"],
|
||||
"audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"],
|
||||
}
|
||||
|
||||
if modality in keywords:
|
||||
return any(k in text for k in keywords[modality])
|
||||
|
||||
if modality == "text":
|
||||
non_text_hits = any(
|
||||
k in text for k in keywords["image"] + keywords["video"] + keywords["audio"])
|
||||
return not non_text_hits
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _load_models(token: str, modality: str) -> list[dict]:
|
||||
"""Load models for modality; fallback to unfiltered cache if needed."""
|
||||
try:
|
||||
models_resp = _api("GET", "/models/", token=token,
|
||||
params={"modality": modality})
|
||||
except httpx.RequestError:
|
||||
return []
|
||||
if models_resp.status_code == 200:
|
||||
try:
|
||||
models = models_resp.json()
|
||||
except ValueError:
|
||||
models = []
|
||||
if models:
|
||||
return models
|
||||
|
||||
try:
|
||||
all_resp = _api("GET", "/models/", token=token)
|
||||
except httpx.RequestError:
|
||||
return []
|
||||
if all_resp.status_code != 200:
|
||||
return []
|
||||
|
||||
try:
|
||||
all_models = all_resp.json()
|
||||
except ValueError:
|
||||
return []
|
||||
filtered = [m for m in all_models if _model_matches_modality(m, modality)]
|
||||
return filtered or all_models
|
||||
|
||||
|
||||
def login_required(view):
|
||||
@functools.wraps(view)
|
||||
def wrapped(*args, **kwargs):
|
||||
@@ -163,9 +217,7 @@ def generate_text():
|
||||
result = resp.json()
|
||||
else:
|
||||
error = resp.json().get("detail", "Generation failed.")
|
||||
models_resp = _api("GET", "/models/", token=token,
|
||||
params={"modality": "text"})
|
||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||
models = _load_models(token, "text")
|
||||
return render_template("generate_text.html", result=result, error=error, models=models)
|
||||
|
||||
|
||||
@@ -186,9 +238,7 @@ def generate_image():
|
||||
)
|
||||
if up_resp.status_code not in (200, 201):
|
||||
error = up_resp.json().get("detail", "Image upload failed.")
|
||||
models_resp = _api("GET", "/models/",
|
||||
token=token, params={"modality": "image"})
|
||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||
models = _load_models(token, "image")
|
||||
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||
|
||||
resp = _api("POST", "/generate/image", token=token, json={
|
||||
@@ -203,9 +253,7 @@ def generate_image():
|
||||
result = resp.json()
|
||||
else:
|
||||
error = resp.json().get("detail", "Generation failed.")
|
||||
models_resp = _api("GET", "/models/", token=token,
|
||||
params={"modality": "image"})
|
||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||
models = _load_models(token, "image")
|
||||
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||
|
||||
|
||||
@@ -241,9 +289,7 @@ def generate_video():
|
||||
result = resp.json()
|
||||
else:
|
||||
error = resp.json().get("detail", "Generation failed.")
|
||||
models_resp = _api("GET", "/models/", token=token,
|
||||
params={"modality": "video"})
|
||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||
models = _load_models(token, "video")
|
||||
return render_template("generate_video.html", result=result, error=error, models=models)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user