feat: enhance model caching and output modalities handling

- Updated `refresh_models_cache` to include output modalities in the models cache.
- Added `get_model_output_modalities` function to retrieve output modalities for a specific model.
- Modified tests to cover new functionality for output modalities.
- Updated OpenRouter video generation functions to support audio generation and improved error handling.
- Enhanced dashboard to display generated images and videos.
- Refactored frontend templates to accommodate new data structures for generated content.
- Adjusted tests to validate changes in model handling and dashboard rendering.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 15:20:48 +02:00
parent 3d32e6df74
commit 712c556032
15 changed files with 618 additions and 219 deletions
+2 -1
View File
@@ -53,7 +53,8 @@ async def test_stats_as_admin(client):
resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 200
data = resp.json()
assert data["users"]["total"] == 3 # 2 users + 1 admin
# 2 users + 1 admin + 1 seeded admin (ai@allucanget.biz)
assert data["users"]["total"] == 4
assert "by_role" in data["users"]
assert "refresh_tokens" in data
+6 -6
View File
@@ -53,7 +53,7 @@ async def _user_token(client):
async def test_list_models(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.list_models",
"app.routers.ai.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS,
):
@@ -74,7 +74,7 @@ async def test_list_models_unauthenticated(client):
async def test_list_models_upstream_error(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.list_models",
"app.routers.ai.openrouter.list_models",
new_callable=AsyncMock,
side_effect=Exception("Connection refused"),
):
@@ -91,7 +91,7 @@ async def test_list_models_upstream_error(client):
async def test_chat_success(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
return_value=FAKE_CHAT_RESPONSE,
):
@@ -115,7 +115,7 @@ async def test_chat_success(client):
async def test_chat_passes_parameters(client):
token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
await client.post(
"/ai/chat",
json={
@@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client):
async def test_chat_upstream_error(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
side_effect=Exception("timeout"),
):
@@ -160,7 +160,7 @@ async def test_chat_upstream_error(client):
async def test_chat_malformed_upstream_response(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
return_value={"id": "x", "choices": []}, # empty choices
):
+135 -69
View File
@@ -18,15 +18,6 @@ FAKE_CHAT = {
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
}
FAKE_IMAGE = {
"id": "gen-img-1",
"model": "openai/dall-e-3",
"data": [
{"url": "https://example.com/image.png",
"revised_prompt": "A cat on the moon"},
],
}
FAKE_VIDEO = {
"id": "gen-vid-1",
"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
@@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client):
# POST /generate/image
# ---------------------------------------------------------------------------
async def test_generate_image(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "gen-img-1"
assert len(data["images"]) == 1
assert data["images"][0]["url"] == "https://example.com/image.png"
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 502
# --- Chat-based image generation (FLUX, GPT-5 Image Mini) ---
FAKE_IMAGE_CHAT_FLUX = {
"id": "gen-img-chat-1",
"model": "black-forest-labs/flux.2-klein-4b",
"choices": [{
"message": {
"role": "assistant",
"content": "Here is your generated image.",
"content": None,
"images": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,abc123"},
@@ -219,45 +176,65 @@ FAKE_IMAGE_CHAT_GPT5 = {
}],
}
FAKE_IMAGE_CHAT_GEMINI = {
"id": "gen-img-chat-3",
"model": "google/gemini-2.5-flash-image",
"choices": [{
"message": {
"role": "assistant",
"content": "Here is your image.",
"images": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,gemini123"},
}],
}
}],
}
async def test_generate_image_chat_flux(client):
async def test_generate_image(client):
"""All models now use generate_image_chat (chat completions endpoint)."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b",
"prompt": "A sunset"},
json={"model": "google/gemini-2.5-flash-image",
"prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "gen-img-chat-1"
assert data["id"] == "gen-img-chat-3"
assert len(data["images"]) == 1
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
assert data["images"][0]["url"] == "data:image/png;base64,gemini123"
assert data["images"][0]["image_id"] is not None # stored in DB
async def test_generate_image_chat_gpt5_image_mini(client):
async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5):
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post(
"/generate/image",
json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"},
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["model"] == "openai/gpt-5-image-mini"
assert len(data["images"]) == 1
assert resp.status_code == 502
async def test_generate_image_chat_with_image_config(client):
async def test_generate_image_with_image_config(client):
"""Passes aspect_ratio + image_size through to generate_image_chat."""
token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post(
"/generate/image",
json={
"model": "black-forest-labs/flux.2-klein-4b",
"model": "google/gemini-2.5-flash-image",
"prompt": "A landscape",
"aspect_ratio": "16:9",
"image_size": "2K",
@@ -267,23 +244,112 @@ async def test_generate_image_chat_with_image_config(client):
call_kwargs = mock.call_args.kwargs
assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
assert call_kwargs["image_config"]["image_size"] == "2K"
assert call_kwargs["modalities"] == ["image"]
async def test_generate_image_chat_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_chat_upstream_error(client):
async def test_generate_image_default_modalities_image_text(client):
"""Model not in cache → default modalities = ['image', 'text']."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")):
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert mock.call_args.kwargs["modalities"] == ["image", "text"]
async def test_generate_image_image_only_modalities_from_cache(client):
"""Model cached with image-only output_modalities → modalities = ['image']."""
from app import db as db_module
from app.services.models import get_model_output_modalities
import json as _json
token = await _user_token(client)
# Seed cache with image-only model
conn = db_module.get_conn()
from datetime import datetime, timezone
now = datetime.now(timezone.utc).replace(tzinfo=None)
conn.execute(
"DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
)
conn.execute(
"""INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
_json.dumps(["image"])],
)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"},
json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
assert mock.call_args.kwargs["modalities"] == ["image"]
async def test_generate_image_no_images_in_response(client):
"""502 when model returns no images."""
token = await _user_token(client)
empty_response = {
"id": "gen-empty",
"model": "google/gemini-2.5-flash-image",
"choices": [{"message": {"role": "assistant", "content": "ok", "images": []}}],
}
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=empty_response):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 502
assert "No images returned" in resp.json()["detail"]
async def test_generate_image_flux(client):
"""Flux model works correctly via chat completions."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b",
"prompt": "A sunset"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
async def test_generate_image_stored_in_db(client):
"""Generated image row persists in generated_images table."""
from app import db as db_module
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image",
"prompt": "A mountain"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
image_id = resp.json()["images"][0]["image_id"]
assert image_id is not None
row = db_module.get_conn().execute(
"SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
[image_id],
).fetchone()
assert row is not None
assert row[0] == "google/gemini-2.5-flash-image"
assert row[1] == "A mountain"
assert row[2] == "data:image/png;base64,gemini123"
# ---------------------------------------------------------------------------
+63 -14
View File
@@ -15,6 +15,7 @@ from app.services.models import (
_normalize_modality,
_parse_modality,
get_cached_models,
get_model_output_modalities,
is_cache_stale,
refresh_models_cache,
)
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
"name": "GPT-4o",
"context_length": 128000,
"pricing": {"prompt": "0.000005"},
"architecture": {"modality": "text->text"},
"architecture": {"modality": "text->text", "output_modalities": ["text"]},
},
{
"id": "anthropic/claude-3-haiku",
"name": "Claude 3 Haiku",
"context_length": 200000,
"pricing": {},
"architecture": {"modality": "text+image->text"},
"architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
},
{
"id": "openai/dall-e-3",
"name": "DALL-E 3",
"context_length": None,
"pricing": {"image": "0.04"},
"architecture": {"modality": "text->image"},
"architecture": {"modality": "text->image", "output_modalities": ["image"]},
},
{
"id": "openai/sora-2",
"name": "Sora 2",
"context_length": None,
"pricing": {"video": "0.10"},
"architecture": {"modality": "text->video"},
"architecture": {"modality": "text->video", "output_modalities": ["video"]},
},
{
"id": "google/gemini-2.5-flash-image",
"name": "Gemini 2.5 Flash Image",
"context_length": None,
"pricing": {},
"architecture": {"output_modalities": ["image", "text"]},
},
]
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
return_value=FAKE_MODELS_RAW,
):
count = await refresh_models_cache(conn)
assert count == 4
assert count == 5
all_models = get_cached_models(conn)
assert len(all_models) == 4
assert len(all_models) == 5
async def test_refresh_replaces_old_cache():
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
ids = [m["id"] for m in get_cached_models(conn)]
assert "old/model" not in ids
assert "openai/gpt-4o" in ids
assert len(ids) == 5
def test_get_cached_models_filter_by_modality():
conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None)
for m in FAKE_MODELS_RAW:
arch = m.get("architecture", {})
modality = _parse_modality(arch.get("modality", "text->text"))
modality = _extract_output_modality(m)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
[m["id"], m["name"], modality, now],
)
text_models = get_cached_models(conn, modality="text")
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
assert len(text_models) == 2
assert all(m["modality"] == "text" for m in text_models)
image_models = get_cached_models(conn, modality="image")
assert len(image_models) == 1
assert image_models[0]["id"] == "openai/dall-e-3"
# dall-e-3 + gemini (output_modalities starts with image)
assert len(image_models) == 2
image_ids = [m["id"] for m in image_models]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
video_models = get_cached_models(conn, modality="video")
assert len(video_models) == 1
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
"/models/", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert len(resp.json()) == 4
assert len(resp.json()) == 5
assert mock_fetch.await_count >= 1
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "openai/dall-e-3"
assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
image_ids = [m["id"] for m in data]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
# ---------------------------------------------------------------------------
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert resp.json()["refreshed"] == 4
assert resp.json()["refreshed"] == 5
async def test_refresh_endpoint_502_on_openrouter_error(client):
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# Unit tests: get_model_output_modalities
# ---------------------------------------------------------------------------
async def test_get_model_output_modalities_image_only():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
assert modalities == ["image"]
async def test_get_model_output_modalities_image_text():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(
conn, "google/gemini-2.5-flash-image")
assert set(modalities) == {"image", "text"}
def test_get_model_output_modalities_unknown_model():
conn = db_module.get_conn()
result = get_model_output_modalities(conn, "unknown/model")
assert result == []
+3 -1
View File
@@ -115,7 +115,9 @@ async def test_list_users_as_admin(client):
resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"})
assert resp.status_code == 200
assert isinstance(resp.json(), list)
assert len(resp.json()) == 1
assert len(resp.json()) >= 1
emails = [u["email"] for u in resp.json()]
assert "user@example.com" in emails
async def test_list_users_as_regular_user(client):