feat: enhance model caching and output modalities handling

- Updated `refresh_models_cache` to include output modalities in the models cache.
- Added `get_model_output_modalities` function to retrieve output modalities for a specific model.
- Modified tests to cover new functionality for output modalities.
- Updated OpenRouter video generation functions to support audio generation and improved error handling.
- Enhanced dashboard to display generated images and videos.
- Refactored frontend templates to accommodate new data structures for generated content.
- Adjusted tests to validate changes in model handling and dashboard rendering.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 15:20:48 +02:00
parent 3d32e6df74
commit 712c556032
15 changed files with 618 additions and 219 deletions
+28
View File
@@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
fetched_at TIMESTAMP NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_images (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
image_data VARCHAR NOT NULL,
created_at TIMESTAMP DEFAULT now()
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_videos (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
job_id VARCHAR NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
polling_url VARCHAR,
status VARCHAR NOT NULL DEFAULT 'pending',
video_url VARCHAR,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now()
)
""")
# Migration: add output_modalities column if absent (stores JSON array string)
conn.execute("""
ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR
""")
_seed_admin(conn)
+14 -14
View File
@@ -1,10 +1,10 @@
from .routers import auth as auth_router
from .routers import users as users_router
from .routers import admin as admin_router
from .routers import ai as ai_router
from .routers import generate as generate_router
from .routers import images as images_router
from .routers import models as models_router
from .routers import auth
from .routers import users
from .routers import admin
from .routers import ai
from .routers import generate
from .routers import images
from .routers import models
from .db import close_db, init_db
import os
from contextlib import asynccontextmanager
@@ -38,13 +38,13 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(auth_router.router)
app.include_router(users_router.router)
app.include_router(admin_router.router)
app.include_router(ai_router.router)
app.include_router(generate_router.router)
app.include_router(images_router.router)
app.include_router(models_router.router)
app.include_router(auth.router)
app.include_router(users.router)
app.include_router(admin.router)
app.include_router(ai.router)
app.include_router(generate.router)
app.include_router(images.router)
app.include_router(models.router)
@app.get("/health", tags=["health"])
+1
View File
@@ -62,6 +62,7 @@ class ImageResult(BaseModel):
url: str | None = None
b64_json: str | None = None
revised_prompt: str | None = None
image_id: str | None = None # UUID of stored row in generated_images
class ImageResponse(BaseModel):
+193 -69
View File
@@ -1,6 +1,9 @@
"""Generate router: text, image, video, and image-to-video generation."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from ..db import get_conn, get_write_lock
from ..dependencies import get_current_user
from ..models.ai import (
ImageRequest,
@@ -13,6 +16,7 @@ from ..models.ai import (
VideoResponse,
)
from ..services import openrouter
from ..services.models import get_model_output_modalities
router = APIRouter(prefix="/generate", tags=["generate"])
@@ -62,81 +66,129 @@ async def generate_text(
@router.post("/image", response_model=ImageResponse)
async def generate_image(
body: ImageRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> ImageResponse:
"""Generate images from a text prompt."""
# Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E)
chat_models = {"black-forest-labs/flux.2-klein-4b",
"openai/gpt-5-image-mini"}
is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \
any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"])
"""Generate images from a prompt using the chat completions endpoint.
All OpenRouter image models use /chat/completions with a modalities param.
Models that output only images use ["image"]; those that also output text
use ["image", "text"]. We look this up from the model cache; default to
["image", "text"] when the model is not yet cached.
"""
# Determine modalities from cache; default ["image", "text"] works for most models
try:
conn = get_conn()
cached_modalities = get_model_output_modalities(conn, body.model)
except Exception:
cached_modalities = []
if cached_modalities:
# If cache says model only outputs image (no text), use ["image"]
modalities = ["image"] if set(cached_modalities) == {
"image"} else ["image", "text"]
else:
# Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
# For image-only models that fail with this, the error surfaces to the user.
modalities = ["image", "text"]
image_config: dict = {}
if body.aspect_ratio:
image_config["aspect_ratio"] = body.aspect_ratio
if body.image_size:
image_config["image_size"] = body.image_size
try:
if is_chat_model:
image_config = {}
if body.aspect_ratio:
image_config["aspect_ratio"] = body.aspect_ratio
if body.image_size:
image_config["image_size"] = body.image_size
result = await openrouter.generate_image_chat(
model=body.model,
prompt=body.prompt,
modalities=[
"image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"],
image_config=image_config if image_config else None,
)
else:
result = await openrouter.generate_image(
model=body.model,
prompt=body.prompt,
n=body.n,
size=body.size,
)
result = await openrouter.generate_image_chat(
model=body.model,
prompt=body.prompt,
modalities=modalities,
image_config=image_config if image_config else None,
)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
try:
if is_chat_model:
# Chat completions response: choices[0].message.images[].image_url.url
images = []
message = result.get("choices", [{}])[0].get("message", {})
for item in message.get("images", []):
img_url = item.get("image_url", {}).get("url")
images.append(ImageResult(
url=img_url,
b64_json=None,
revised_prompt=message.get("content"),
message = result.get("choices", [{}])[0].get("message", {})
images = []
for item in message.get("images", []):
img_url = item.get("image_url", {}).get("url")
images.append(ImageResult(
url=img_url,
b64_json=None,
revised_prompt=message.get("content") or None,
))
if not images:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="No images returned by model. Verify the model supports image generation.",
)
# Persist each image to DB
user_id = current_user.get("id") or current_user.get("sub")
now = datetime.now(timezone.utc).replace(tzinfo=None)
stored: list[ImageResult] = []
async with get_write_lock():
conn = get_conn()
for img in images:
if img.url:
row = conn.execute(
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
VALUES (?, ?, ?, ?, ?) RETURNING id""",
[user_id, body.model, body.prompt, img.url, now],
).fetchone()
image_id = str(row[0]) if row else None
else:
image_id = None
stored.append(ImageResult(
url=img.url,
b64_json=img.b64_json,
revised_prompt=img.revised_prompt,
image_id=image_id,
))
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=images,
)
else:
# /images/generations response: data[].url
images = [
ImageResult(
url=item.get("url"),
b64_json=item.get("b64_json"),
revised_prompt=item.get("revised_prompt"),
)
for item in result.get("data", [])
]
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=images,
)
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=stored,
)
except HTTPException:
raise
except (KeyError, TypeError) as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Unexpected response format: {exc}")
@router.get("/images")
async def list_generated_images(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """Return every generated image belonging to the current user.

    Results are ordered newest first. Each entry carries the row id (as a
    string), the generating model, the original prompt, the stored image
    payload, and an ISO-8601 creation timestamp (or None when absent).
    """
    user_id = current_user.get("id") or current_user.get("sub")
    conn = get_conn()
    rows = conn.execute(
        """SELECT id, model_id, prompt, image_data, created_at
           FROM generated_images
           WHERE user_id = ?
           ORDER BY created_at DESC""",
        [user_id],
    ).fetchall()
    images: list[dict] = []
    for img_id, model_id, prompt, image_data, created_at in rows:
        images.append({
            "id": str(img_id),
            "model_id": model_id,
            "prompt": prompt,
            "image_data": image_data,
            "created_at": created_at.isoformat() if created_at else None,
        })
    return images
@router.post("/video", response_model=VideoResponse)
async def generate_video(
body: VideoRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Generate a video from a text prompt."""
try:
@@ -151,12 +203,26 @@ async def generate_video(
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse(
id=result.get("id", ""),
id=job_id,
model=body.model,
status=result.get("status", "queued"),
polling_url=result.get("polling_url"),
status=job_status,
polling_url=polling_url,
video_urls=urls,
video_url=(urls or [None])[0],
error=result.get("error"),
@@ -167,7 +233,7 @@ async def generate_video(
@router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image(
body: VideoFromImageRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Generate a video from an image and a text prompt."""
try:
@@ -183,12 +249,26 @@ async def generate_video_from_image(
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse(
id=result.get("id", ""),
id=job_id,
model=body.model,
status=result.get("status", "queued"),
polling_url=result.get("polling_url"),
status=job_status,
polling_url=polling_url,
video_urls=urls,
video_url=(urls or [None])[0],
error=result.get("error"),
@@ -199,23 +279,67 @@ async def generate_video_from_image(
@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
polling_url: str,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Poll the status of a video generation job via its polling_url."""
"""Poll status of a video generation job; updates DB row when completed/failed."""
try:
result = await openrouter.poll_video_status(polling_url)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
job_status = result.get("status", "processing")
urls = result.get("unsigned_urls") or result.get("video_urls")
video_url = (urls or [None])[0]
# Update DB row for this job when terminal state reached
if job_status in ("completed", "failed"):
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""UPDATE generated_videos
SET status = ?, video_url = ?, updated_at = ?
WHERE job_id = ?""",
[job_status, video_url, now, result.get("id", "")],
)
return VideoResponse(
id=result.get("id", ""),
model=result.get("model", ""),
status=result.get("status", "processing"),
status=job_status,
polling_url=result.get("polling_url"),
video_urls=urls,
video_url=(urls or [None])[0],
video_url=video_url,
error=result.get("error"),
metadata=result.get("metadata"),
)
@router.get("/videos")
async def list_generated_videos(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """Return every video-generation job belonging to the current user.

    Results are ordered newest first. Each entry carries the row id (as a
    string), the OpenRouter job id, model, prompt, polling URL, current
    job status, the final video URL when available, and an ISO-8601
    creation timestamp (or None when absent).
    """
    user_id = current_user.get("id") or current_user.get("sub")
    conn = get_conn()
    rows = conn.execute(
        """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
           FROM generated_videos
           WHERE user_id = ?
           ORDER BY created_at DESC""",
        [user_id],
    ).fetchall()
    videos: list[dict] = []
    for row_id, job_id, model_id, prompt, polling_url, job_status, video_url, created_at in rows:
        videos.append({
            "id": str(row_id),
            "job_id": job_id,
            "model_id": model_id,
            "prompt": prompt,
            "polling_url": polling_url,
            "status": job_status,
            "video_url": video_url,
            "created_at": created_at.isoformat() if created_at else None,
        })
    return videos
+33 -3
View File
@@ -91,16 +91,28 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
model_id = m.get("id", "")
if not model_id:
continue
# Full output_modalities array from architecture (for proper modalities param in image gen)
architecture = m.get("architecture") or {}
raw_output_modalities: list | None = (
architecture.get("output_modalities") or m.get("output_modalities")
)
output_modalities_json: str | None = (
json.dumps([_normalize_modality(str(v))
for v in raw_output_modalities])
if isinstance(raw_output_modalities, list)
else None
)
conn.execute(
"""
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (model_id) DO UPDATE SET
name = excluded.name,
modality = excluded.modality,
context_length = excluded.context_length,
pricing = excluded.pricing,
fetched_at = excluded.fetched_at
fetched_at = excluded.fetched_at,
output_modalities = excluded.output_modalities
""",
[
model_id,
@@ -109,6 +121,7 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
m.get("context_length"),
json.dumps(pricing) if pricing else None,
now,
output_modalities_json,
],
)
count += 1
@@ -168,3 +181,20 @@ def get_cached_models(
"pricing": pricing,
})
return result
def get_model_output_modalities(
    conn: duckdb.DuckDBPyConnection,
    model_id: str,
) -> list[str]:
    """Look up the cached output modalities for *model_id*.

    The models_cache column stores a JSON-encoded array of modality
    strings. Returns an empty list when the model is not cached, the
    column is NULL/empty, or the stored value cannot be parsed as JSON.
    """
    result = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()
    raw = result[0] if result else None
    if not raw:
        return []
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return []
+33 -5
View File
@@ -95,8 +95,9 @@ async def generate_video(
duration_seconds: int | None = None,
aspect_ratio: str = "16:9",
resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]:
"""Request text-to-video generation via OpenRouter."""
"""Request text-to-video generation via OpenRouter POST /videos."""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = {
"model": model,
@@ -104,9 +105,12 @@ async def generate_video(
"aspect_ratio": aspect_ratio,
}
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
# API uses 'duration' not 'duration_seconds'
payload["duration"] = duration_seconds
if resolution is not None:
payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -123,19 +127,31 @@ async def generate_video_from_image(
duration_seconds: int | None = None,
aspect_ratio: str = "16:9",
resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]:
"""Request image-to-video generation via OpenRouter."""
"""Request image-to-video generation via OpenRouter POST /videos.
Uses frame_images array with first_frame as per OpenRouter API spec.
"""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = {
"model": model,
"image_url": image_url,
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"frame_images": [
{
"type": "image_url",
"image_url": {"url": image_url},
"frame_type": "first_frame",
}
],
}
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
payload["duration"] = duration_seconds
if resolution is not None:
payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -154,6 +170,18 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]:
return response.json()
async def list_video_models() -> list[dict[str, Any]]:
    """Fetch the catalogue of video-generation models.

    Queries OpenRouter's dedicated /videos/models endpoint (separate from
    the general /models listing) and returns its ``data`` array, or an
    empty list when that field is missing from the response.
    """
    base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
    async with httpx.AsyncClient(timeout=15) as client:
        request = client.build_request(
            "GET", f"{base_url}/videos/models", headers=_headers()
        )
        reply = await client.send(request)
        reply.raise_for_status()
        payload = reply.json()
    return payload.get("data", [])
async def generate_image_chat(
model: str,
prompt: str,