Add models caching and management functionality with corresponding API endpoints and templates

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 13:51:43 +02:00
parent fe32c32726
commit 96d13fc440
9 changed files with 534 additions and 8 deletions
+296
View File
@@ -0,0 +1,296 @@
"""Tests for model cache service and router."""
import json
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from app import db as db_module
from app.main import app
from app.services.models import (
_parse_modality,
get_cached_models,
is_cache_stale,
refresh_models_cache,
)
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
FAKE_MODELS_RAW = [
{
"id": "openai/gpt-4o",
"name": "GPT-4o",
"context_length": 128000,
"pricing": {"prompt": "0.000005"},
"architecture": {"modality": "text->text"},
},
{
"id": "anthropic/claude-3-haiku",
"name": "Claude 3 Haiku",
"context_length": 200000,
"pricing": {},
"architecture": {"modality": "text+image->text"},
},
{
"id": "openai/dall-e-3",
"name": "DALL-E 3",
"context_length": None,
"pricing": {"image": "0.04"},
"architecture": {"modality": "text->image"},
},
{
"id": "openai/sora-2",
"name": "Sora 2",
"context_length": None,
"pricing": {"video": "0.10"},
"architecture": {"modality": "text->video"},
},
]
@pytest.fixture(autouse=True)
def fresh_db():
db_module._conn = None
db_module.init_db(":memory:")
yield
db_module.close_db()
db_module._conn = None
@pytest_asyncio.fixture
async def client(fresh_db):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
async def _register_login(client, email, password, is_admin=False):
"""Register + login; optionally promote to admin directly in DB."""
await client.post("/auth/register", json={"email": email, "password": password})
if is_admin:
db_module.get_conn().execute(
"UPDATE users SET role = 'admin' WHERE email = ?", [email]
)
resp = await client.post("/auth/login", json={"email": email, "password": password})
return resp.json()["access_token"]
# ---------------------------------------------------------------------------
# Unit tests: _parse_modality
# ---------------------------------------------------------------------------
def test_parse_modality_text():
assert _parse_modality("text->text") == "text"
def test_parse_modality_multimodal_input_text_output():
assert _parse_modality("text+image->text") == "text"
def test_parse_modality_image():
assert _parse_modality("text->image") == "image"
def test_parse_modality_video():
assert _parse_modality("text->video") == "video"
def test_parse_modality_audio():
assert _parse_modality("text->audio") == "audio"
def test_parse_modality_no_arrow_fallback():
assert _parse_modality("text") == "text"
# ---------------------------------------------------------------------------
# Unit tests: is_cache_stale
# ---------------------------------------------------------------------------
def test_cache_stale_when_empty():
conn = db_module.get_conn()
assert is_cache_stale(conn) is True
def test_cache_not_stale_after_fresh_insert():
conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
["openai/gpt-4o", "GPT-4o", "text", now],
)
assert is_cache_stale(conn) is False
def test_cache_stale_after_ttl_exceeded():
conn = db_module.get_conn()
# Store naive UTC to match DuckDB TIMESTAMP behaviour
old_time = datetime.now(timezone.utc).replace(
tzinfo=None) - timedelta(hours=25)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
["openai/gpt-4o", "GPT-4o", "text", old_time],
)
assert is_cache_stale(conn) is True
# ---------------------------------------------------------------------------
# Unit tests: refresh_models_cache + get_cached_models
# ---------------------------------------------------------------------------
async def test_refresh_stores_models():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
count = await refresh_models_cache(conn)
assert count == 4
all_models = get_cached_models(conn)
assert len(all_models) == 4
async def test_refresh_replaces_old_cache():
conn = db_module.get_conn()
old_time = datetime.now(timezone.utc).replace(
tzinfo=None) - timedelta(hours=30)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
["old/model", "Old Model", "text", old_time],
)
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
ids = [m["id"] for m in get_cached_models(conn)]
assert "old/model" not in ids
assert "openai/gpt-4o" in ids
def test_get_cached_models_filter_by_modality():
conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None)
for m in FAKE_MODELS_RAW:
arch = m.get("architecture", {})
modality = _parse_modality(arch.get("modality", "text->text"))
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
[m["id"], m["name"], modality, now],
)
text_models = get_cached_models(conn, modality="text")
assert len(text_models) == 2
assert all(m["modality"] == "text" for m in text_models)
image_models = get_cached_models(conn, modality="image")
assert len(image_models) == 1
assert image_models[0]["id"] == "openai/dall-e-3"
video_models = get_cached_models(conn, modality="video")
assert len(video_models) == 1
assert video_models[0]["id"] == "openai/sora-2"
# ---------------------------------------------------------------------------
# Integration tests: GET /models/
# ---------------------------------------------------------------------------
async def test_list_models_endpoint_auto_refreshes(client):
token = await _register_login(client, "user@example.com", "secret123")
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
) as mock_fetch:
resp = await client.get(
"/models/", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert len(resp.json()) == 4
mock_fetch.assert_awaited_once()
async def test_list_models_endpoint_uses_cache(client):
token = await _register_login(client, "user@example.com", "secret123")
conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
["cached/model", "Cached Model", "text", now],
)
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
) as mock_fetch:
resp = await client.get(
"/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert resp.json()[0]["id"] == "cached/model"
mock_fetch.assert_not_awaited()
async def test_list_models_endpoint_requires_auth(client):
resp = await client.get("/models/")
assert resp.status_code == 401
async def test_list_models_filter_by_modality(client):
token = await _register_login(client, "user@example.com", "secret123")
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
resp = await client.get(
"/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "openai/dall-e-3"
# ---------------------------------------------------------------------------
# Integration tests: POST /models/refresh
# ---------------------------------------------------------------------------
async def test_refresh_endpoint_requires_admin(client):
token = await _register_login(client, "user@example.com", "secret123")
resp = await client.post(
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 403
async def test_refresh_endpoint_admin_succeeds(client):
token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
resp = await client.post(
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert resp.json()["refreshed"] == 4
async def test_refresh_endpoint_502_on_openrouter_error(client):
token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
side_effect=RuntimeError("network error"),
):
resp = await client.post(
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 502