Files
zwitschi 712c556032 feat: enhance model caching and output modalities handling
- Updated `refresh_models_cache` to include output modalities in the models cache.
- Added `get_model_output_modalities` function to retrieve output modalities for a specific model.
- Modified tests to cover new functionality for output modalities.
- Updated OpenRouter video generation functions to support audio generation and improved error handling.
- Enhanced dashboard to display generated images and videos.
- Refactored frontend templates to accommodate new data structures for generated content.
- Adjusted tests to validate changes in model handling and dashboard rendering.

Co-authored-by: Copilot <copilot@github.com>
2026-04-29 15:20:48 +02:00

367 lines
12 KiB
Python

"""Tests for model cache service and router."""
import json
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from app import db as db_module
from app.main import app
from app.services.models import (
_extract_output_modality,
_normalize_modality,
_parse_modality,
get_cached_models,
get_model_output_modalities,
is_cache_stale,
refresh_models_cache,
)
# Fallback settings so the suite runs without a real environment; values
# already present in the environment are left untouched.
# NOTE(review): these execute after ``app.main`` has been imported above. If
# the app reads these settings at import time, the defaults arrive too late —
# confirm (a conftest may already set them earlier).
for _env_name, _env_default in (
    ("JWT_SECRET", "test-secret-key-for-testing-only"),
    ("OPENROUTER_API_KEY", "test-key"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
# Raw model entries shaped like the OpenRouter list response that the tests
# feed into the patched ``openrouter.list_models``.
FAKE_MODELS_RAW = [
    # Text-only chat model, with both legacy and explicit output modalities.
    {
        "id": "openai/gpt-4o",
        "name": "GPT-4o",
        "context_length": 128000,
        "pricing": {"prompt": "0.000005"},
        "architecture": {"modality": "text->text", "output_modalities": ["text"]},
    },
    # Multimodal input, text output; empty pricing.
    {
        "id": "anthropic/claude-3-haiku",
        "name": "Claude 3 Haiku",
        "context_length": 200000,
        "pricing": {},
        "architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
    },
    # Image generator.
    {
        "id": "openai/dall-e-3",
        "name": "DALL-E 3",
        "context_length": None,
        "pricing": {"image": "0.04"},
        "architecture": {"modality": "text->image", "output_modalities": ["image"]},
    },
    # Video generator.
    {
        "id": "openai/sora-2",
        "name": "Sora 2",
        "context_length": None,
        "pricing": {"video": "0.10"},
        "architecture": {"modality": "text->video", "output_modalities": ["video"]},
    },
    # No legacy ``modality`` key; image listed before text in output_modalities.
    {
        "id": "google/gemini-2.5-flash-image",
        "name": "Gemini 2.5 Flash Image",
        "context_length": None,
        "pricing": {},
        "architecture": {"output_modalities": ["image", "text"]},
    },
]
@pytest.fixture(autouse=True)
def fresh_db():
    """Run every test against its own in-memory database.

    Clears ``db_module._conn`` and calls ``init_db(":memory:")`` before the
    test. Afterwards it calls ``close_db()`` and clears ``_conn`` again.
    """

    def drop_connection_handle():
        db_module._conn = None

    drop_connection_handle()
    db_module.init_db(":memory:")
    yield
    db_module.close_db()
    drop_connection_handle()
@pytest_asyncio.fixture
async def client(fresh_db):
    """Yield an ``httpx.AsyncClient`` that calls the ASGI ``app`` in-process.

    Requests ``fresh_db`` so the database is initialised before the client
    is used.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
async def _register_login(client, email, password, is_admin=False):
    """Register a user, optionally promote them to admin, then log in.

    Args:
        client: HTTP client pointed at the app (e.g. the ``client`` fixture).
        email: Address used for both ``/auth/register`` and ``/auth/login``.
        password: Plain-text password for the new account.
        is_admin: When true, set ``users.role`` to ``'admin'`` directly in the
            database between registration and login (presumably so the login
            sees the promoted role — confirm against the auth service).

    Returns:
        The ``access_token`` value from the login response body.

    Raises:
        AssertionError: If registration or login returns an error status.
            The message includes the response body, so a broken auth flow
            fails clearly instead of with a ``KeyError`` on ``access_token``.
    """
    credentials = {"email": email, "password": password}

    reg = await client.post("/auth/register", json=credentials)
    assert reg.status_code < 400, f"register failed ({reg.status_code}): {reg.text}"

    if is_admin:
        db_module.get_conn().execute(
            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
        )

    resp = await client.post("/auth/login", json=credentials)
    assert resp.status_code < 400, f"login failed ({resp.status_code}): {resp.text}"
    return resp.json()["access_token"]
# ---------------------------------------------------------------------------
# Unit tests: _parse_modality
# ---------------------------------------------------------------------------
def test_parse_modality_text():
    """Text in, text out yields ``text``."""
    parsed = _parse_modality("text->text")
    assert parsed == "text"


def test_parse_modality_multimodal_input_text_output():
    """Extra input modalities do not change a text output."""
    parsed = _parse_modality("text+image->text")
    assert parsed == "text"


def test_parse_modality_image():
    """An image on the output side yields ``image``."""
    parsed = _parse_modality("text->image")
    assert parsed == "image"


def test_parse_modality_video():
    """A video on the output side yields ``video``."""
    parsed = _parse_modality("text->video")
    assert parsed == "video"


def test_parse_modality_audio():
    """Audio on the output side yields ``audio``."""
    parsed = _parse_modality("text->audio")
    assert parsed == "audio"


def test_parse_modality_no_arrow_fallback():
    """A bare modality with no ``->`` falls back to that value."""
    parsed = _parse_modality("text")
    assert parsed == "text"
def test_normalize_embedding_alias():
    """The singular ``embedding`` alias is normalized to ``embeddings``."""
    normalized = _normalize_modality("embedding")
    assert normalized == "embeddings"
def test_extract_output_modality_prefers_output_modalities():
    """``output_modalities`` wins when it disagrees with the legacy string."""
    architecture = {"modality": "text->text", "output_modalities": ["image"]}
    detected = _extract_output_modality({"architecture": architecture})
    assert detected == "image"
def test_extract_output_modality_legacy_fallback():
    """Without ``output_modalities`` the legacy ``modality`` string is used."""
    legacy_entry = {"architecture": {"modality": "text->audio"}}
    detected = _extract_output_modality(legacy_entry)
    assert detected == "audio"
# ---------------------------------------------------------------------------
# Unit tests: is_cache_stale
# ---------------------------------------------------------------------------
def test_cache_stale_when_empty():
    """An empty ``models_cache`` table counts as stale."""
    stale = is_cache_stale(db_module.get_conn())
    assert stale is True
def test_cache_not_stale_after_fresh_insert():
    """A row fetched just now keeps the cache fresh."""
    db = db_module.get_conn()
    # Naive UTC timestamp, the same form the other cache tests store.
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    row = ["openai/gpt-4o", "GPT-4o", "text", fetched_at]
    db.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        row,
    )
    assert is_cache_stale(db) is False
def test_cache_stale_after_ttl_exceeded():
    """A row older than the cache TTL makes the cache stale.

    NOTE(review): 25 hours is assumed to exceed the TTL configured in
    ``app.services.models``; update this age if the TTL changes.
    """
    db = db_module.get_conn()
    # Store naive UTC to match DuckDB TIMESTAMP behaviour
    now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    expired_at = now_naive - timedelta(hours=25)
    db.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", expired_at],
    )
    assert is_cache_stale(db) is True
# ---------------------------------------------------------------------------
# Unit tests: refresh_models_cache + get_cached_models
# ---------------------------------------------------------------------------
async def test_refresh_stores_models():
    """Refreshing stores every entry returned by ``openrouter.list_models``."""
    db = db_module.get_conn()
    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        stored = await refresh_models_cache(db)
    assert stored == 5
    assert len(get_cached_models(db)) == 5
async def test_refresh_replaces_old_cache():
    """A refresh discards previously cached rows and stores the new list."""
    db = db_module.get_conn()
    now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    stale_at = now_naive - timedelta(hours=30)
    db.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["old/model", "Old Model", "text", stale_at],
    )

    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        await refresh_models_cache(db)

    cached_ids = [entry["id"] for entry in get_cached_models(db)]
    assert "old/model" not in cached_ids
    assert "openai/gpt-4o" in cached_ids
    assert len(cached_ids) == 5
def test_get_cached_models_filter_by_modality():
    """Filtering the cache by modality returns exactly the matching models.

    Rows are inserted directly (no API fetch). Each row is tagged with the
    modality that ``_extract_output_modality`` derives from its raw entry.
    Exact sorted id lists are compared, so the test catches wrong members and
    duplicates, not only wrong counts.
    """
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    for m in FAKE_MODELS_RAW:
        conn.execute(
            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
            [m["id"], m["name"], _extract_output_modality(m), now],
        )

    text_models = get_cached_models(conn, modality="text")
    assert all(m["modality"] == "text" for m in text_models)
    # gemini has output_modalities=["image", "text"] and is classified as
    # "image", so it must not appear among the text models.
    assert sorted(m["id"] for m in text_models) == [
        "anthropic/claude-3-haiku",
        "openai/gpt-4o",
    ]

    # dall-e-3 + gemini (output_modalities starts with image)
    image_models = get_cached_models(conn, modality="image")
    assert sorted(m["id"] for m in image_models) == [
        "google/gemini-2.5-flash-image",
        "openai/dall-e-3",
    ]

    video_models = get_cached_models(conn, modality="video")
    assert [m["id"] for m in video_models] == ["openai/sora-2"]
# ---------------------------------------------------------------------------
# Integration tests: GET /models/
# ---------------------------------------------------------------------------
async def test_list_models_endpoint_auto_refreshes(client):
    """With an empty cache, ``GET /models/`` fetches from OpenRouter first."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth_headers = {"Authorization": f"Bearer {token}"}

    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        listing = await client.get("/models/", headers=auth_headers)

    assert listing.status_code == 200
    assert len(listing.json()) == 5
    assert fake_list_models.await_count >= 1
async def test_list_models_endpoint_uses_cache(client):
    """A fresh cache is served as-is, without calling OpenRouter."""
    token = await _register_login(client, "user@example.com", "secret123")
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["cached/model", "Cached Model", "text", now],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
    ) as mock_fetch:
        resp = await client.get(
            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    # Compare the whole id list: an empty response then fails with a readable
    # diff instead of an IndexError, and unexpected extra rows are caught.
    assert [m["id"] for m in resp.json()] == ["cached/model"]
    mock_fetch.assert_not_awaited()
async def test_list_models_endpoint_requires_auth(client):
    """Requests without a bearer token are rejected with 401."""
    anonymous = await client.get("/models/")
    assert anonymous.status_code == 401
async def test_list_models_filter_by_modality(client):
    """``?modality=image`` returns only the image-producing models."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth_headers = {"Authorization": f"Bearer {token}"}

    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        listing = await client.get("/models/?modality=image", headers=auth_headers)

    assert listing.status_code == 200
    # dall-e-3 + gemini-2.5-flash-image
    returned_ids = sorted(entry["id"] for entry in listing.json())
    assert returned_ids == ["google/gemini-2.5-flash-image", "openai/dall-e-3"]
# ---------------------------------------------------------------------------
# Integration tests: POST /models/refresh
# ---------------------------------------------------------------------------
async def test_refresh_endpoint_requires_admin(client):
    """Non-admin users receive 403 from ``POST /models/refresh``."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth_headers = {"Authorization": f"Bearer {token}"}
    denied = await client.post("/models/refresh", headers=auth_headers)
    assert denied.status_code == 403
async def test_refresh_endpoint_admin_succeeds(client):
    """Admins can refresh; the response reports how many models were stored."""
    token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth_headers = {"Authorization": f"Bearer {token}"}

    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        outcome = await client.post("/models/refresh", headers=auth_headers)

    assert outcome.status_code == 200
    assert outcome.json()["refreshed"] == 5
async def test_refresh_endpoint_502_on_openrouter_error(client):
    """An exception from the upstream fetch surfaces as HTTP 502."""
    token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth_headers = {"Authorization": f"Bearer {token}"}

    failing_list_models = AsyncMock(side_effect=RuntimeError("network error"))
    with patch("app.services.models.openrouter.list_models", failing_list_models):
        outcome = await client.post("/models/refresh", headers=auth_headers)

    assert outcome.status_code == 502
# ---------------------------------------------------------------------------
# Unit tests: get_model_output_modalities
# ---------------------------------------------------------------------------
async def test_get_model_output_modalities_image_only():
    """An image-only model reports exactly ``["image"]`` after a refresh."""
    db = db_module.get_conn()
    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        await refresh_models_cache(db)
    assert get_model_output_modalities(db, "openai/dall-e-3") == ["image"]
async def test_get_model_output_modalities_image_text():
    """A model emitting image and text keeps both modalities (any order)."""
    db = db_module.get_conn()
    fake_list_models = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_list_models):
        await refresh_models_cache(db)
    reported = get_model_output_modalities(db, "google/gemini-2.5-flash-image")
    assert set(reported) == {"image", "text"}
def test_get_model_output_modalities_unknown_model():
    """A model absent from the cache yields an empty list, not an error."""
    reported = get_model_output_modalities(db_module.get_conn(), "unknown/model")
    assert reported == []