Add models caching and management functionality with corresponding API endpoints and templates
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -75,6 +75,17 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS models_cache (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
model_id VARCHAR NOT NULL UNIQUE,
|
||||
name VARCHAR NOT NULL,
|
||||
modality VARCHAR NOT NULL,
|
||||
context_length BIGINT,
|
||||
pricing JSON,
|
||||
fetched_at TIMESTAMP NOT NULL
|
||||
)
|
||||
""")
|
||||
_seed_admin(conn)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from .routers import admin as admin_router
|
||||
from .routers import ai as ai_router
|
||||
from .routers import generate as generate_router
|
||||
from .routers import images as images_router
|
||||
from .routers import models as models_router
|
||||
from .db import close_db, init_db
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -43,6 +44,7 @@ app.include_router(admin_router.router)
|
||||
app.include_router(ai_router.router)
|
||||
app.include_router(generate_router.router)
|
||||
app.include_router(images_router.router)
|
||||
app.include_router(models_router.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Models router: list and refresh the OpenRouter model cache."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from ..db import get_conn, get_write_lock
|
||||
from ..dependencies import get_current_user, require_admin
|
||||
from ..services import models as models_service
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
@router.get("/")
async def list_models(
    modality: str | None = Query(
        None,
        description="Filter by output modality: text, image, video, audio",
    ),
    _: dict = Depends(get_current_user),
):
    """Return cached models. Auto-refreshes cache if stale (older than 24 h).

    Any authenticated user may call this. A refresh runs under the global
    write lock with a double-check so concurrent requests trigger at most
    one upstream fetch.

    Raises:
        HTTPException 502: if the OpenRouter refresh fails.
    """
    conn = get_conn()
    if models_service.is_cache_stale(conn):
        async with get_write_lock():
            # Re-check inside lock to avoid redundant parallel refreshes:
            # another request may have refreshed while we waited for the lock.
            if models_service.is_cache_stale(conn):
                try:
                    await models_service.refresh_models_cache(conn)
                except Exception as exc:
                    # Chain the cause (B904) so the upstream traceback is
                    # preserved in server logs instead of being swallowed.
                    raise HTTPException(
                        status_code=status.HTTP_502_BAD_GATEWAY,
                        detail=f"Failed to refresh model cache: {exc}",
                    ) from exc
    return models_service.get_cached_models(conn, modality)
|
||||
|
||||
|
||||
@router.post("/refresh", status_code=200)
async def refresh_models(_: dict = Depends(require_admin)):
    """Force-refresh the model cache from OpenRouter. Admin only.

    Returns:
        {"refreshed": <number of models stored>}

    Raises:
        HTTPException 502: if the OpenRouter fetch fails.
    """
    conn = get_conn()
    async with get_write_lock():
        try:
            count = await models_service.refresh_models_cache(conn)
        except Exception as exc:
            # Chain the cause (B904) so the original error is not lost.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"OpenRouter error: {exc}",
            ) from exc
    return {"refreshed": count}
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Model cache service: fetch from OpenRouter, store in DuckDB."""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import duckdb
|
||||
|
||||
from . import openrouter
|
||||
|
||||
CACHE_TTL_HOURS = 24
|
||||
|
||||
|
||||
def _parse_modality(raw_modality: str) -> str:
    """Reduce an OpenRouter architecture.modality string to its output modality.

    OpenRouter encodes modality as "inputs->output", e.g. "text->text",
    "text+image->text", "text->image", "text->video". Only the part after
    the first "->" matters; when there is no arrow the whole string is
    treated as the output. Known modalities are normalized to one of
    "text" / "image" / "video" / "audio"; anything else is returned
    lowercased as-is.
    """
    _, arrow, tail = raw_modality.partition("->")
    output = (tail if arrow else raw_modality).lower()
    # Order matters: "text" wins for combined outputs, matching callers'
    # expectations for multimodal strings.
    for known in ("text", "image", "video", "audio"):
        if known in output:
            return known
    return output
|
||||
|
||||
|
||||
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
    """Fetch all models from OpenRouter and replace the cache. Returns count stored.

    The upstream fetch happens *before* any write, so a network failure
    leaves the existing cache untouched. The delete-and-reinsert runs in a
    transaction so concurrent readers never observe a half-replaced cache
    and a failed insert rolls the whole refresh back (the original code
    deleted first without a transaction, so a mid-loop error wiped the
    cache).

    Raises:
        Whatever openrouter.list_models or DuckDB raises; the transaction
        is rolled back before re-raising.
    """
    raw = await openrouter.list_models()
    # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    conn.execute("BEGIN TRANSACTION")
    try:
        conn.execute("DELETE FROM models_cache")
        count = 0
        for m in raw:
            model_id = m.get("id", "")
            if not model_id:
                continue  # skip malformed upstream entries with no id
            # architecture may be absent or explicitly null; default to text.
            arch = m.get("architecture") or {}
            modality = _parse_modality(arch.get("modality", "text->text"))
            pricing = m.get("pricing")
            conn.execute(
                """
                INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
                VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT (model_id) DO UPDATE SET
                    name = excluded.name,
                    modality = excluded.modality,
                    context_length = excluded.context_length,
                    pricing = excluded.pricing,
                    fetched_at = excluded.fetched_at
                """,
                [
                    model_id,
                    m.get("name", model_id),
                    modality,
                    m.get("context_length"),
                    json.dumps(pricing) if pricing else None,
                    now,
                ],
            )
            count += 1
    except Exception:
        conn.execute("ROLLBACK")
        raise
    conn.execute("COMMIT")
    return count
|
||||
|
||||
|
||||
def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Return True if cache is empty or last fetched more than CACHE_TTL_HOURS ago.

    DuckDB TIMESTAMP values come back naive, so the age comparison is done
    in naive UTC on both sides.
    """
    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    latest = row[0] if row else None
    if latest is None:
        # MAX over an empty table is NULL -> treat as stale.
        return True
    # Strip any tzinfo so both operands of the subtraction are naive.
    if latest.tzinfo is not None:
        latest = latest.replace(tzinfo=None)
    age = datetime.now(timezone.utc).replace(tzinfo=None) - latest
    return age > timedelta(hours=CACHE_TTL_HOURS)
|
||||
|
||||
|
||||
def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """Return cached models, optionally filtered by modality, ordered by name.

    Args:
        conn: Open DuckDB connection holding the models_cache table.
        modality: Optional output-modality filter (e.g. "text", "image").

    Returns:
        List of dicts with keys id, name, modality, context_length and
        pricing (JSON column decoded to a dict, or None when absent or
        unparseable).
    """
    # Build the statement once instead of duplicating the full SQL per branch.
    query = "SELECT model_id, name, modality, context_length, pricing FROM models_cache"
    if modality:
        query += " WHERE modality = ? ORDER BY name"
        rows = conn.execute(query, [modality]).fetchall()
    else:
        query += " ORDER BY name"
        rows = conn.execute(query).fetchall()

    result: list[dict[str, Any]] = []
    for model_id, name, row_modality, context_length, pricing_raw in rows:
        pricing = None
        if pricing_raw:
            try:
                pricing = json.loads(pricing_raw)
            except (json.JSONDecodeError, TypeError):
                # Corrupt cache entry: still surface the model, drop pricing.
                pricing = None
        result.append({
            "id": model_id,
            "name": name,
            "modality": row_modality,
            "context_length": context_length,
            "pricing": pricing,
        })
    return result
|
||||
@@ -0,0 +1,296 @@
|
||||
"""Tests for model cache service and router."""
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app import db as db_module
|
||||
from app.main import app
|
||||
from app.services.models import (
|
||||
_parse_modality,
|
||||
get_cached_models,
|
||||
is_cache_stale,
|
||||
refresh_models_cache,
|
||||
)
|
||||
|
||||
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||
|
||||
# Canned OpenRouter /models payload covering one model per output modality
# (text, multimodal-input text, image, video) so unit and integration tests
# can exercise filtering and refresh without network access.
FAKE_MODELS_RAW = [
    {
        "id": "openai/gpt-4o",
        "name": "GPT-4o",
        "context_length": 128000,
        "pricing": {"prompt": "0.000005"},
        "architecture": {"modality": "text->text"},
    },
    {
        "id": "anthropic/claude-3-haiku",
        "name": "Claude 3 Haiku",
        "context_length": 200000,
        "pricing": {},
        "architecture": {"modality": "text+image->text"},
    },
    {
        "id": "openai/dall-e-3",
        "name": "DALL-E 3",
        "context_length": None,
        "pricing": {"image": "0.04"},
        "architecture": {"modality": "text->image"},
    },
    {
        "id": "openai/sora-2",
        "name": "Sora 2",
        "context_length": None,
        "pricing": {"video": "0.10"},
        "architecture": {"modality": "text->video"},
    },
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def fresh_db():
    """Give every test a pristine in-memory DuckDB (autouse: no state leaks)."""
    db_module._conn = None  # drop any connection left over from another module
    db_module.init_db(":memory:")
    yield
    db_module.close_db()
    db_module._conn = None


@pytest_asyncio.fixture
async def client(fresh_db):
    """Async HTTP client bound directly to the ASGI app (no real server)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac


async def _register_login(client, email, password, is_admin=False):
    """Register + login; optionally promote to admin directly in DB.

    Returns the JWT access token for use in Authorization headers.
    """
    await client.post("/auth/register", json={"email": email, "password": password})
    if is_admin:
        # No admin-promotion endpoint under test, so flip the role in the DB.
        db_module.get_conn().execute(
            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
        )
    resp = await client.post("/auth/login", json={"email": email, "password": password})
    return resp.json()["access_token"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _parse_modality
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# _parse_modality maps OpenRouter "inputs->output" modality strings to a
# single output-modality token (text / image / video / audio).
def test_parse_modality_text():
    assert _parse_modality("text->text") == "text"


def test_parse_modality_multimodal_input_text_output():
    # The input side may be multimodal; only the output side matters.
    assert _parse_modality("text+image->text") == "text"


def test_parse_modality_image():
    assert _parse_modality("text->image") == "image"


def test_parse_modality_video():
    assert _parse_modality("text->video") == "video"


def test_parse_modality_audio():
    assert _parse_modality("text->audio") == "audio"


def test_parse_modality_no_arrow_fallback():
    # Without a "->" separator the whole string is treated as the output.
    assert _parse_modality("text") == "text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: is_cache_stale
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_cache_stale_when_empty():
    # An empty cache must count as stale so the first request triggers a refresh.
    conn = db_module.get_conn()
    assert is_cache_stale(conn) is True


def test_cache_not_stale_after_fresh_insert():
    conn = db_module.get_conn()
    # Naive UTC matches how the service writes fetched_at into DuckDB.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", now],
    )
    assert is_cache_stale(conn) is False


def test_cache_stale_after_ttl_exceeded():
    conn = db_module.get_conn()
    # Store naive UTC to match DuckDB TIMESTAMP behaviour; 25 h > 24 h TTL.
    old_time = datetime.now(timezone.utc).replace(
        tzinfo=None) - timedelta(hours=25)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", old_time],
    )
    assert is_cache_stale(conn) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: refresh_models_cache + get_cached_models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_refresh_stores_models():
    """refresh_models_cache stores every fetched model and reports the count."""
    conn = db_module.get_conn()
    # Patch the OpenRouter client so no network call is made.
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        count = await refresh_models_cache(conn)
    assert count == 4
    all_models = get_cached_models(conn)
    assert len(all_models) == 4


async def test_refresh_replaces_old_cache():
    """A refresh drops previously cached rows, not just adds new ones."""
    conn = db_module.get_conn()
    old_time = datetime.now(timezone.utc).replace(
        tzinfo=None) - timedelta(hours=30)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["old/model", "Old Model", "text", old_time],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        await refresh_models_cache(conn)
    ids = [m["id"] for m in get_cached_models(conn)]
    assert "old/model" not in ids
    assert "openai/gpt-4o" in ids


def test_get_cached_models_filter_by_modality():
    """The modality filter returns only matching rows."""
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    # Seed the cache directly, deriving modality the same way the service does.
    for m in FAKE_MODELS_RAW:
        arch = m.get("architecture", {})
        modality = _parse_modality(arch.get("modality", "text->text"))
        conn.execute(
            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
            [m["id"], m["name"], modality, now],
        )
    text_models = get_cached_models(conn, modality="text")
    assert len(text_models) == 2
    assert all(m["modality"] == "text" for m in text_models)

    image_models = get_cached_models(conn, modality="image")
    assert len(image_models) == 1
    assert image_models[0]["id"] == "openai/dall-e-3"

    video_models = get_cached_models(conn, modality="video")
    assert len(video_models) == 1
    assert video_models[0]["id"] == "openai/sora-2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: GET /models/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_list_models_endpoint_auto_refreshes(client):
    # Empty cache: the endpoint must fetch from OpenRouter exactly once.
    token = await _register_login(client, "user@example.com", "secret123")
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ) as mock_fetch:
        resp = await client.get(
            "/models/", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert len(resp.json()) == 4
    mock_fetch.assert_awaited_once()


async def test_list_models_endpoint_uses_cache(client):
    # Fresh cache: the endpoint must serve it without calling OpenRouter.
    token = await _register_login(client, "user@example.com", "secret123")
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["cached/model", "Cached Model", "text", now],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
    ) as mock_fetch:
        resp = await client.get(
            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()[0]["id"] == "cached/model"
    mock_fetch.assert_not_awaited()


async def test_list_models_endpoint_requires_auth(client):
    # No Authorization header -> 401 from the auth dependency.
    resp = await client.get("/models/")
    assert resp.status_code == 401


async def test_list_models_filter_by_modality(client):
    # The ?modality= query param narrows the auto-refreshed result set.
    token = await _register_login(client, "user@example.com", "secret123")
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        resp = await client.get(
            "/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    data = resp.json()
    assert len(data) == 1
    assert data[0]["id"] == "openai/dall-e-3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: POST /models/refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_refresh_endpoint_requires_admin(client):
    # A regular user must get 403 from the admin-only refresh endpoint.
    token = await _register_login(client, "user@example.com", "secret123")
    resp = await client.post(
        "/models/refresh", headers={"Authorization": f"Bearer {token}"}
    )
    assert resp.status_code == 403


async def test_refresh_endpoint_admin_succeeds(client):
    # An admin gets 200 and the number of models stored.
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()["refreshed"] == 4


async def test_refresh_endpoint_502_on_openrouter_error(client):
    # An upstream failure surfaces as 502 Bad Gateway, not an unhandled crash.
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        side_effect=RuntimeError("network error"),
    ):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 502
|
||||
Reference in New Issue
Block a user