feat: enhance model caching and output modalities handling
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from app.services.models import (
|
||||
_normalize_modality,
|
||||
_parse_modality,
|
||||
get_cached_models,
|
||||
get_model_output_modalities,
|
||||
is_cache_stale,
|
||||
refresh_models_cache,
|
||||
)
|
||||
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
|
||||
"name": "GPT-4o",
|
||||
"context_length": 128000,
|
||||
"pricing": {"prompt": "0.000005"},
|
||||
"architecture": {"modality": "text->text"},
|
||||
"architecture": {"modality": "text->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-3-haiku",
|
||||
"name": "Claude 3 Haiku",
|
||||
"context_length": 200000,
|
||||
"pricing": {},
|
||||
"architecture": {"modality": "text+image->text"},
|
||||
"architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/dall-e-3",
|
||||
"name": "DALL-E 3",
|
||||
"context_length": None,
|
||||
"pricing": {"image": "0.04"},
|
||||
"architecture": {"modality": "text->image"},
|
||||
"architecture": {"modality": "text->image", "output_modalities": ["image"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/sora-2",
|
||||
"name": "Sora 2",
|
||||
"context_length": None,
|
||||
"pricing": {"video": "0.10"},
|
||||
"architecture": {"modality": "text->video"},
|
||||
"architecture": {"modality": "text->video", "output_modalities": ["video"]},
|
||||
},
|
||||
{
|
||||
"id": "google/gemini-2.5-flash-image",
|
||||
"name": "Gemini 2.5 Flash Image",
|
||||
"context_length": None,
|
||||
"pricing": {},
|
||||
"architecture": {"output_modalities": ["image", "text"]},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
count = await refresh_models_cache(conn)
|
||||
assert count == 4
|
||||
assert count == 5
|
||||
all_models = get_cached_models(conn)
|
||||
assert len(all_models) == 4
|
||||
assert len(all_models) == 5
|
||||
|
||||
|
||||
async def test_refresh_replaces_old_cache():
|
||||
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
|
||||
ids = [m["id"] for m in get_cached_models(conn)]
|
||||
assert "old/model" not in ids
|
||||
assert "openai/gpt-4o" in ids
|
||||
assert len(ids) == 5
|
||||
|
||||
|
||||
def test_get_cached_models_filter_by_modality():
|
||||
conn = db_module.get_conn()
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
for m in FAKE_MODELS_RAW:
|
||||
arch = m.get("architecture", {})
|
||||
modality = _parse_modality(arch.get("modality", "text->text"))
|
||||
modality = _extract_output_modality(m)
|
||||
conn.execute(
|
||||
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
|
||||
[m["id"], m["name"], modality, now],
|
||||
)
|
||||
text_models = get_cached_models(conn, modality="text")
|
||||
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
|
||||
assert len(text_models) == 2
|
||||
assert all(m["modality"] == "text" for m in text_models)
|
||||
|
||||
image_models = get_cached_models(conn, modality="image")
|
||||
assert len(image_models) == 1
|
||||
assert image_models[0]["id"] == "openai/dall-e-3"
|
||||
# dall-e-3 + gemini (output_modalities starts with image)
|
||||
assert len(image_models) == 2
|
||||
image_ids = [m["id"] for m in image_models]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
video_models = get_cached_models(conn, modality="video")
|
||||
assert len(video_models) == 1
|
||||
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
|
||||
"/models/", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 4
|
||||
assert len(resp.json()) == 5
|
||||
assert mock_fetch.await_count >= 1
|
||||
|
||||
|
||||
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "openai/dall-e-3"
|
||||
assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
|
||||
image_ids = [m["id"] for m in data]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["refreshed"] == 4
|
||||
assert resp.json()["refreshed"] == 5
|
||||
|
||||
|
||||
async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: get_model_output_modalities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_get_model_output_modalities_image_only():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
|
||||
assert modalities == ["image"]
|
||||
|
||||
|
||||
async def test_get_model_output_modalities_image_text():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(
|
||||
conn, "google/gemini-2.5-flash-image")
|
||||
assert set(modalities) == {"image", "text"}
|
||||
|
||||
|
||||
def test_get_model_output_modalities_unknown_model():
|
||||
conn = db_module.get_conn()
|
||||
result = get_model_output_modalities(conn, "unknown/model")
|
||||
assert result == []
|
||||
|
||||
Reference in New Issue
Block a user