Add model loading functionality with modality filtering and refactor generation routes to utilize it
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
+58
-12
@@ -35,6 +35,60 @@ def _api(method: str, path: str, *, token: str | None = None, **kwargs):
|
|||||||
return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs)
|
return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_matches_modality(model: dict, modality: str) -> bool:
|
||||||
|
"""Heuristic fallback when backend modality filter returns empty."""
|
||||||
|
model_modality = (model.get("modality") or "").lower()
|
||||||
|
if model_modality == modality:
|
||||||
|
return True
|
||||||
|
|
||||||
|
text = f"{model.get('id', '')} {model.get('name', '')}".lower()
|
||||||
|
keywords = {
|
||||||
|
"image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"],
|
||||||
|
"video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"],
|
||||||
|
"audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if modality in keywords:
|
||||||
|
return any(k in text for k in keywords[modality])
|
||||||
|
|
||||||
|
if modality == "text":
|
||||||
|
non_text_hits = any(
|
||||||
|
k in text for k in keywords["image"] + keywords["video"] + keywords["audio"])
|
||||||
|
return not non_text_hits
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _load_models(token: str, modality: str) -> list[dict]:
|
||||||
|
"""Load models for modality; fallback to unfiltered cache if needed."""
|
||||||
|
try:
|
||||||
|
models_resp = _api("GET", "/models/", token=token,
|
||||||
|
params={"modality": modality})
|
||||||
|
except httpx.RequestError:
|
||||||
|
return []
|
||||||
|
if models_resp.status_code == 200:
|
||||||
|
try:
|
||||||
|
models = models_resp.json()
|
||||||
|
except ValueError:
|
||||||
|
models = []
|
||||||
|
if models:
|
||||||
|
return models
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_resp = _api("GET", "/models/", token=token)
|
||||||
|
except httpx.RequestError:
|
||||||
|
return []
|
||||||
|
if all_resp.status_code != 200:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_models = all_resp.json()
|
||||||
|
except ValueError:
|
||||||
|
return []
|
||||||
|
filtered = [m for m in all_models if _model_matches_modality(m, modality)]
|
||||||
|
return filtered or all_models
|
||||||
|
|
||||||
|
|
||||||
def login_required(view):
|
def login_required(view):
|
||||||
@functools.wraps(view)
|
@functools.wraps(view)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
@@ -163,9 +217,7 @@ def generate_text():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
models_resp = _api("GET", "/models/", token=token,
|
models = _load_models(token, "text")
|
||||||
params={"modality": "text"})
|
|
||||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
|
||||||
return render_template("generate_text.html", result=result, error=error, models=models)
|
return render_template("generate_text.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
@@ -186,9 +238,7 @@ def generate_image():
|
|||||||
)
|
)
|
||||||
if up_resp.status_code not in (200, 201):
|
if up_resp.status_code not in (200, 201):
|
||||||
error = up_resp.json().get("detail", "Image upload failed.")
|
error = up_resp.json().get("detail", "Image upload failed.")
|
||||||
models_resp = _api("GET", "/models/",
|
models = _load_models(token, "image")
|
||||||
token=token, params={"modality": "image"})
|
|
||||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
|
||||||
return render_template("generate_image.html", result=result, error=error, models=models)
|
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
resp = _api("POST", "/generate/image", token=token, json={
|
resp = _api("POST", "/generate/image", token=token, json={
|
||||||
@@ -203,9 +253,7 @@ def generate_image():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
models_resp = _api("GET", "/models/", token=token,
|
models = _load_models(token, "image")
|
||||||
params={"modality": "image"})
|
|
||||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
|
||||||
return render_template("generate_image.html", result=result, error=error, models=models)
|
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
@@ -241,9 +289,7 @@ def generate_video():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
models_resp = _api("GET", "/models/", token=token,
|
models = _load_models(token, "video")
|
||||||
params={"modality": "video"})
|
|
||||||
models = models_resp.json() if models_resp.status_code == 200 else []
|
|
||||||
return render_template("generate_video.html", result=result, error=error, models=models)
|
return render_template("generate_video.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user