712c556032
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
367 lines
12 KiB
Python
367 lines
12 KiB
Python
"""Tests for model cache service and router."""
|
|
import json
|
|
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from app import db as db_module
|
|
from app.main import app
|
|
from app.services.models import (
|
|
_extract_output_modality,
|
|
_normalize_modality,
|
|
_parse_modality,
|
|
get_cached_models,
|
|
get_model_output_modalities,
|
|
is_cache_stale,
|
|
refresh_models_cache,
|
|
)
|
|
|
|
# Test-only environment defaults so the app can load without real secrets.
# NOTE(review): these run *after* `app.main` is imported above — if settings
# are read at import time the defaults may arrive too late; confirm the app
# resolves environment variables lazily.
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
|
|
|
# Canned OpenRouter /models payload covering every classification branch the
# service distinguishes: text-only, multimodal-input/text-output, image,
# video, and a model that has only `output_modalities` (no legacy "modality"
# arrow string) with more than one entry.
FAKE_MODELS_RAW = [
    {
        "id": "openai/gpt-4o",
        "name": "GPT-4o",
        "context_length": 128000,
        "pricing": {"prompt": "0.000005"},
        "architecture": {"modality": "text->text", "output_modalities": ["text"]},
    },
    {
        "id": "anthropic/claude-3-haiku",
        "name": "Claude 3 Haiku",
        "context_length": 200000,
        "pricing": {},
        "architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
    },
    {
        "id": "openai/dall-e-3",
        "name": "DALL-E 3",
        "context_length": None,
        "pricing": {"image": "0.04"},
        "architecture": {"modality": "text->image", "output_modalities": ["image"]},
    },
    {
        "id": "openai/sora-2",
        "name": "Sora 2",
        "context_length": None,
        "pricing": {"video": "0.10"},
        "architecture": {"modality": "text->video", "output_modalities": ["video"]},
    },
    {
        # Legacy "modality" key deliberately absent — exercises the
        # output_modalities-only path.
        "id": "google/gemini-2.5-flash-image",
        "name": "Gemini 2.5 Flash Image",
        "context_length": None,
        "pricing": {},
        "architecture": {"output_modalities": ["image", "text"]},
    },
]
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def fresh_db():
    """Give every test a pristine in-memory database, torn down afterwards."""
    db_module._conn = None  # drop any connection leaked by a previous test
    db_module.init_db(":memory:")
    yield
    # Teardown: close and clear the module-level connection handle.
    db_module.close_db()
    db_module._conn = None
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client(fresh_db):
    """In-process async HTTP client bound to the FastAPI app (no sockets)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
|
|
|
|
|
|
async def _register_login(client, email, password, is_admin=False):
    """Create an account and log in, returning the JWT access token.

    When *is_admin* is set, the freshly-registered user is promoted to the
    admin role directly in the database before logging in.
    """
    credentials = {"email": email, "password": password}
    await client.post("/auth/register", json=credentials)
    if is_admin:
        conn = db_module.get_conn()
        conn.execute(
            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
        )
    login_resp = await client.post("/auth/login", json=credentials)
    return login_resp.json()["access_token"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests: _parse_modality
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_parse_modality_text():
    """A plain text->text arrow string parses to "text"."""
    parsed = _parse_modality("text->text")
    assert parsed == "text"
|
|
|
|
|
|
def test_parse_modality_multimodal_input_text_output():
    """Multimodal input (text+image) still classifies by the output side."""
    parsed = _parse_modality("text+image->text")
    assert parsed == "text"
|
|
|
|
|
|
def test_parse_modality_image():
    """text->image parses to the "image" output modality."""
    parsed = _parse_modality("text->image")
    assert parsed == "image"
|
|
|
|
|
|
def test_parse_modality_video():
    """text->video parses to the "video" output modality."""
    parsed = _parse_modality("text->video")
    assert parsed == "video"
|
|
|
|
|
|
def test_parse_modality_audio():
    """text->audio parses to the "audio" output modality."""
    parsed = _parse_modality("text->audio")
    assert parsed == "audio"
|
|
|
|
|
|
def test_parse_modality_no_arrow_fallback():
    """A modality string without an arrow is used as-is."""
    parsed = _parse_modality("text")
    assert parsed == "text"
|
|
|
|
|
|
def test_normalize_embedding_alias():
    """The singular "embedding" alias normalizes to "embeddings"."""
    normalized = _normalize_modality("embedding")
    assert normalized == "embeddings"
|
|
|
|
|
|
def test_extract_output_modality_prefers_output_modalities():
    """Explicit output_modalities wins over the legacy arrow string."""
    architecture = {"modality": "text->text", "output_modalities": ["image"]}
    extracted = _extract_output_modality({"architecture": architecture})
    assert extracted == "image"
|
|
|
|
|
|
def test_extract_output_modality_legacy_fallback():
    """Without output_modalities, the legacy arrow string is parsed."""
    legacy_only = {"architecture": {"modality": "text->audio"}}
    assert _extract_output_modality(legacy_only) == "audio"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests: is_cache_stale
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_cache_stale_when_empty():
    """An empty models_cache table is always considered stale."""
    assert is_cache_stale(db_module.get_conn()) is True
|
|
|
|
|
|
def test_cache_not_stale_after_fresh_insert():
    """A row fetched just now keeps the cache fresh."""
    conn = db_module.get_conn()
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", fetched_at],
    )
    assert is_cache_stale(conn) is False
|
|
|
|
|
|
def test_cache_stale_after_ttl_exceeded():
    """A row older than the 24h TTL makes the cache stale again."""
    conn = db_module.get_conn()
    # Store naive UTC to match DuckDB TIMESTAMP behaviour.
    beyond_ttl = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=25)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", beyond_ttl],
    )
    assert is_cache_stale(conn) is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests: refresh_models_cache + get_cached_models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_refresh_stores_models():
    """refresh_models_cache persists every model OpenRouter returns."""
    conn = db_module.get_conn()
    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        stored = await refresh_models_cache(conn)
    assert stored == 5
    assert len(get_cached_models(conn)) == 5
|
|
|
|
|
|
async def test_refresh_replaces_old_cache():
    """Refreshing wipes previously cached rows before inserting new ones."""
    conn = db_module.get_conn()
    expired = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=30)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["old/model", "Old Model", "text", expired],
    )

    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        await refresh_models_cache(conn)

    cached_ids = [entry["id"] for entry in get_cached_models(conn)]
    assert "old/model" not in cached_ids
    assert "openai/gpt-4o" in cached_ids
    assert len(cached_ids) == 5
|
|
|
|
|
|
def test_get_cached_models_filter_by_modality():
    """The modality filter returns only rows classified under that modality."""
    conn = db_module.get_conn()
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    for raw in FAKE_MODELS_RAW:
        conn.execute(
            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
            [raw["id"], raw["name"], _extract_output_modality(raw), fetched_at],
        )

    # gpt-4o and claude-3-haiku only (gemini has output_modalities
    # ["image", "text"] and is therefore classified as "image").
    text_models = get_cached_models(conn, modality="text")
    assert len(text_models) == 2
    assert all(entry["modality"] == "text" for entry in text_models)

    # dall-e-3 plus gemini (its output_modalities list starts with image).
    image_models = get_cached_models(conn, modality="image")
    assert len(image_models) == 2
    image_ids = [entry["id"] for entry in image_models]
    assert "openai/dall-e-3" in image_ids
    assert "google/gemini-2.5-flash-image" in image_ids

    video_models = get_cached_models(conn, modality="video")
    assert len(video_models) == 1
    assert video_models[0]["id"] == "openai/sora-2"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration tests: GET /models/
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_list_models_endpoint_auto_refreshes(client):
    """GET /models/ with an empty cache triggers an upstream fetch."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth = {"Authorization": f"Bearer {token}"}
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ) as mock_fetch:
        resp = await client.get("/models/", headers=auth)
    assert resp.status_code == 200
    assert len(resp.json()) == 5
    assert mock_fetch.await_count >= 1
|
|
|
|
|
|
async def test_list_models_endpoint_uses_cache(client):
    """A fresh cached row is served without contacting OpenRouter."""
    token = await _register_login(client, "user@example.com", "secret123")
    conn = db_module.get_conn()
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["cached/model", "Cached Model", "text", fetched_at],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
    ) as mock_fetch:
        resp = await client.get(
            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()[0]["id"] == "cached/model"
    mock_fetch.assert_not_awaited()
|
|
|
|
|
|
async def test_list_models_endpoint_requires_auth(client):
    """Listing models without a bearer token is rejected with 401."""
    response = await client.get("/models/")
    assert response.status_code == 401
|
|
|
|
|
|
async def test_list_models_filter_by_modality(client):
    """The ?modality= query filters the endpoint response."""
    token = await _register_login(client, "user@example.com", "secret123")
    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        resp = await client.get(
            "/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    payload = resp.json()
    assert len(payload) == 2  # dall-e-3 + gemini-2.5-flash-image
    returned_ids = [entry["id"] for entry in payload]
    assert "openai/dall-e-3" in returned_ids
    assert "google/gemini-2.5-flash-image" in returned_ids
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration tests: POST /models/refresh
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_refresh_endpoint_requires_admin(client):
    """A non-admin user gets 403 from POST /models/refresh."""
    token = await _register_login(client, "user@example.com", "secret123")
    response = await client.post(
        "/models/refresh", headers={"Authorization": f"Bearer {token}"}
    )
    assert response.status_code == 403
|
|
|
|
|
|
async def test_refresh_endpoint_admin_succeeds(client):
    """An admin can force a refresh and gets the stored-model count back."""
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()["refreshed"] == 5
|
|
|
|
|
|
async def test_refresh_endpoint_502_on_openrouter_error(client):
    """Upstream fetch failures surface as 502 Bad Gateway."""
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    failing_fetch = AsyncMock(side_effect=RuntimeError("network error"))
    with patch("app.services.models.openrouter.list_models", failing_fetch):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 502
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests: get_model_output_modalities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_get_model_output_modalities_image_only():
    """A pure image model reports exactly ["image"]."""
    conn = db_module.get_conn()
    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        await refresh_models_cache(conn)
    assert get_model_output_modalities(conn, "openai/dall-e-3") == ["image"]
|
|
|
|
|
|
async def test_get_model_output_modalities_image_text():
    """A multi-output model reports both of its output modalities."""
    conn = db_module.get_conn()
    fake_fetch = AsyncMock(return_value=FAKE_MODELS_RAW)
    with patch("app.services.models.openrouter.list_models", fake_fetch):
        await refresh_models_cache(conn)
    modalities = get_model_output_modalities(conn, "google/gemini-2.5-flash-image")
    assert set(modalities) == {"image", "text"}
|
|
|
|
|
|
def test_get_model_output_modalities_unknown_model():
    """An id absent from the cache yields an empty modality list."""
    connection = db_module.get_conn()
    assert get_model_output_modalities(connection, "unknown/model") == []
|