Add models caching and management functionality with corresponding API endpoints and templates

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 13:51:43 +02:00
parent fe32c32726
commit 96d13fc440
9 changed files with 534 additions and 8 deletions
+11
View File
@@ -75,6 +75,17 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
created_at TIMESTAMP DEFAULT now()
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS models_cache (
id UUID DEFAULT uuid() PRIMARY KEY,
model_id VARCHAR NOT NULL UNIQUE,
name VARCHAR NOT NULL,
modality VARCHAR NOT NULL,
context_length BIGINT,
pricing JSON,
fetched_at TIMESTAMP NOT NULL
)
""")
_seed_admin(conn)
+2
View File
@@ -4,6 +4,7 @@ from .routers import admin as admin_router
from .routers import ai as ai_router
from .routers import generate as generate_router
from .routers import images as images_router
from .routers import models as models_router
from .db import close_db, init_db
import os
from contextlib import asynccontextmanager
@@ -43,6 +44,7 @@ app.include_router(admin_router.router)
app.include_router(ai_router.router)
app.include_router(generate_router.router)
app.include_router(images_router.router)
app.include_router(models_router.router)
@app.get("/health", tags=["health"])
+47
View File
@@ -0,0 +1,47 @@
"""Models router: list and refresh the OpenRouter model cache."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from ..db import get_conn, get_write_lock
from ..dependencies import get_current_user, require_admin
from ..services import models as models_service
router = APIRouter(prefix="/models", tags=["models"])
@router.get("/")
async def list_models(
    modality: str | None = Query(
        None,
        description="Filter by output modality: text, image, video, audio",
    ),
    _: dict = Depends(get_current_user),
):
    """Return cached models, auto-refreshing the cache if stale (older than 24 hours).

    Any authenticated user may call this. If the upstream refresh fails a
    502 is returned, so clients can distinguish upstream trouble from an
    empty-but-healthy cache.
    """
    conn = get_conn()
    if models_service.is_cache_stale(conn):
        async with get_write_lock():
            # Re-check inside the lock: another request may have refreshed
            # the cache while we were waiting, making ours redundant.
            if models_service.is_cache_stale(conn):
                try:
                    await models_service.refresh_models_cache(conn)
                except Exception as exc:
                    # Chain the cause so the original traceback is preserved
                    # instead of being masked by the HTTPException.
                    raise HTTPException(
                        status_code=status.HTTP_502_BAD_GATEWAY,
                        detail=f"Failed to refresh model cache: {exc}",
                    ) from exc
    return models_service.get_cached_models(conn, modality)
@router.post("/refresh", status_code=200)
async def refresh_models(_: dict = Depends(require_admin)):
    """Force-refresh the model cache from OpenRouter. Admin only.

    Returns ``{"refreshed": <count>}`` with the number of models stored.
    Raises 502 if the upstream fetch fails.
    """
    conn = get_conn()
    async with get_write_lock():
        try:
            count = await models_service.refresh_models_cache(conn)
        except Exception as exc:
            # Chain the cause so the upstream failure is not lost in logs.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"OpenRouter error: {exc}",
            ) from exc
    return {"refreshed": count}
+124
View File
@@ -0,0 +1,124 @@
"""Model cache service: fetch from OpenRouter, store in DuckDB."""
import json
from datetime import datetime, timedelta, timezone
from typing import Any
import duckdb
from . import openrouter
CACHE_TTL_HOURS = 24
def _parse_modality(raw_modality: str) -> str:
"""Extract output modality from OpenRouter architecture.modality string.
Examples: "text->text", "text+image->text", "text->image", "text->video"
"""
output = raw_modality.split(
"->", 1)[-1].lower() if "->" in raw_modality else raw_modality.lower()
if "text" in output:
return "text"
if "image" in output:
return "image"
if "video" in output:
return "video"
if "audio" in output:
return "audio"
return output
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
    """Replace the model cache with a fresh snapshot from OpenRouter.

    The whole table is cleared first, so models that disappeared upstream
    are dropped from the cache. Returns the number of models stored.
    """
    raw = await openrouter.list_models()
    # Store naive UTC: DuckDB TIMESTAMP strips tz info, so keeping values
    # naive everywhere avoids aware/naive comparison inconsistencies.
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute("DELETE FROM models_cache")
    stored = 0
    for entry in raw:
        model_id = entry.get("id", "")
        if not model_id:
            continue  # ignore entries without a usable identifier
        arch = entry.get("architecture", {}) or {}
        modality = _parse_modality(arch.get("modality", "text->text"))
        pricing = entry.get("pricing")
        conn.execute(
            """
            INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT (model_id) DO UPDATE SET
                name = excluded.name,
                modality = excluded.modality,
                context_length = excluded.context_length,
                pricing = excluded.pricing,
                fetched_at = excluded.fetched_at
            """,
            [
                model_id,
                entry.get("name", model_id),
                modality,
                entry.get("context_length"),
                # Empty/missing pricing dicts are stored as NULL, not "{}".
                json.dumps(pricing) if pricing else None,
                fetched_at,
            ],
        )
        stored += 1
    return stored
def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Return True if cache is empty or last fetched more than CACHE_TTL_HOURS ago."""
    result = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    last_fetched = result[0] if result else None
    if last_fetched is None:
        # Empty table (MAX over no rows is NULL) -> always stale.
        return True
    # DuckDB TIMESTAMP values are naive; normalise defensively and compare
    # against naive UTC so the subtraction never mixes aware and naive.
    age = datetime.now(timezone.utc).replace(tzinfo=None) - last_fetched.replace(tzinfo=None)
    return age > timedelta(hours=CACHE_TTL_HOURS)
def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """Return cached models, optionally filtered by modality, ordered by name.

    Args:
        conn: open DuckDB connection with the ``models_cache`` table.
        modality: optional output modality ("text", "image", ...) to filter by.

    Returns:
        A list of dicts with keys ``id``, ``name``, ``modality``,
        ``context_length`` and ``pricing`` (decoded JSON or None).
    """
    # Build the query once instead of duplicating it per filter branch.
    query = "SELECT model_id, name, modality, context_length, pricing FROM models_cache"
    params: list[Any] = []
    if modality:
        query += " WHERE modality = ?"
        params.append(modality)
    query += " ORDER BY name"
    rows = conn.execute(query, params).fetchall()
    result: list[dict[str, Any]] = []
    for model_id, name, mod, context_length, pricing_raw in rows:
        pricing = None
        if pricing_raw:
            try:
                pricing = json.loads(pricing_raw)
            except (json.JSONDecodeError, TypeError):
                # Tolerate a corrupt cache entry rather than failing the list.
                pricing = None
        result.append({
            "id": model_id,
            "name": name,
            "modality": mod,
            "context_length": context_length,
            "pricing": pricing,
        })
    return result
+296
View File
@@ -0,0 +1,296 @@
"""Tests for model cache service and router."""
import json
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from app import db as db_module
from app.main import app
from app.services.models import (
_parse_modality,
get_cached_models,
is_cache_stale,
refresh_models_cache,
)
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
FAKE_MODELS_RAW = [
{
"id": "openai/gpt-4o",
"name": "GPT-4o",
"context_length": 128000,
"pricing": {"prompt": "0.000005"},
"architecture": {"modality": "text->text"},
},
{
"id": "anthropic/claude-3-haiku",
"name": "Claude 3 Haiku",
"context_length": 200000,
"pricing": {},
"architecture": {"modality": "text+image->text"},
},
{
"id": "openai/dall-e-3",
"name": "DALL-E 3",
"context_length": None,
"pricing": {"image": "0.04"},
"architecture": {"modality": "text->image"},
},
{
"id": "openai/sora-2",
"name": "Sora 2",
"context_length": None,
"pricing": {"video": "0.10"},
"architecture": {"modality": "text->video"},
},
]
@pytest.fixture(autouse=True)
def fresh_db():
    """Give every test a pristine in-memory DuckDB; reset the module singleton."""
    db_module._conn = None
    db_module.init_db(":memory:")
    yield
    db_module.close_db()
    db_module._conn = None
@pytest_asyncio.fixture
async def client(fresh_db):
    """HTTPX client wired directly to the ASGI app (no real network)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
async def _register_login(client, email, password, is_admin=False):
    """Register + login; optionally promote to admin directly in DB."""
    await client.post("/auth/register", json={"email": email, "password": password})
    if is_admin:
        # Promote via direct SQL because there is no public "make admin" endpoint.
        db_module.get_conn().execute(
            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
        )
    resp = await client.post("/auth/login", json={"email": email, "password": password})
    return resp.json()["access_token"]
# ---------------------------------------------------------------------------
# Unit tests: _parse_modality
# ---------------------------------------------------------------------------
def test_parse_modality_text():
    """Plain text-to-text maps to "text"."""
    assert _parse_modality("text->text") == "text"
def test_parse_modality_multimodal_input_text_output():
    """Only the output side of the arrow matters for classification."""
    assert _parse_modality("text+image->text") == "text"
def test_parse_modality_image():
    assert _parse_modality("text->image") == "image"
def test_parse_modality_video():
    assert _parse_modality("text->video") == "video"
def test_parse_modality_audio():
    assert _parse_modality("text->audio") == "audio"
def test_parse_modality_no_arrow_fallback():
    """A string without "->" is treated as the output modality itself."""
    assert _parse_modality("text") == "text"
# ---------------------------------------------------------------------------
# Unit tests: is_cache_stale
# ---------------------------------------------------------------------------
def test_cache_stale_when_empty():
    """An empty cache table must report stale so the first request refreshes."""
    conn = db_module.get_conn()
    assert is_cache_stale(conn) is True
def test_cache_not_stale_after_fresh_insert():
    """A row fetched "now" keeps the cache fresh."""
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", now],
    )
    assert is_cache_stale(conn) is False
def test_cache_stale_after_ttl_exceeded():
    """A row older than the 24h TTL (here: 25h) makes the cache stale."""
    conn = db_module.get_conn()
    # Store naive UTC to match DuckDB TIMESTAMP behaviour
    old_time = datetime.now(timezone.utc).replace(
        tzinfo=None) - timedelta(hours=25)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["openai/gpt-4o", "GPT-4o", "text", old_time],
    )
    assert is_cache_stale(conn) is True
# ---------------------------------------------------------------------------
# Unit tests: refresh_models_cache + get_cached_models
# ---------------------------------------------------------------------------
async def test_refresh_stores_models():
    """Refreshing with a patched OpenRouter client stores every fake model."""
    conn = db_module.get_conn()
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        count = await refresh_models_cache(conn)
    assert count == 4
    all_models = get_cached_models(conn)
    assert len(all_models) == 4
async def test_refresh_replaces_old_cache():
    """A refresh fully replaces prior rows — stale entries must disappear."""
    conn = db_module.get_conn()
    old_time = datetime.now(timezone.utc).replace(
        tzinfo=None) - timedelta(hours=30)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["old/model", "Old Model", "text", old_time],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        await refresh_models_cache(conn)
    ids = [m["id"] for m in get_cached_models(conn)]
    assert "old/model" not in ids
    assert "openai/gpt-4o" in ids
def test_get_cached_models_filter_by_modality():
    """Modality filter returns only matching rows; seed rows directly via SQL."""
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    for m in FAKE_MODELS_RAW:
        arch = m.get("architecture", {})
        modality = _parse_modality(arch.get("modality", "text->text"))
        conn.execute(
            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
            [m["id"], m["name"], modality, now],
        )
    text_models = get_cached_models(conn, modality="text")
    assert len(text_models) == 2
    assert all(m["modality"] == "text" for m in text_models)
    image_models = get_cached_models(conn, modality="image")
    assert len(image_models) == 1
    assert image_models[0]["id"] == "openai/dall-e-3"
    video_models = get_cached_models(conn, modality="video")
    assert len(video_models) == 1
    assert video_models[0]["id"] == "openai/sora-2"
# ---------------------------------------------------------------------------
# Integration tests: GET /models/
# ---------------------------------------------------------------------------
async def test_list_models_endpoint_auto_refreshes(client):
    """With an empty (stale) cache, the endpoint fetches from OpenRouter once."""
    token = await _register_login(client, "user@example.com", "secret123")
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ) as mock_fetch:
        resp = await client.get(
            "/models/", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert len(resp.json()) == 4
    # Exactly one upstream call: the stale-check + lock dedupes refreshes.
    mock_fetch.assert_awaited_once()
async def test_list_models_endpoint_uses_cache(client):
    """A fresh cache row must be served without touching OpenRouter at all."""
    token = await _register_login(client, "user@example.com", "secret123")
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["cached/model", "Cached Model", "text", now],
    )
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
    ) as mock_fetch:
        resp = await client.get(
            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()[0]["id"] == "cached/model"
    mock_fetch.assert_not_awaited()
async def test_list_models_endpoint_requires_auth(client):
    """Anonymous requests are rejected with 401."""
    resp = await client.get("/models/")
    assert resp.status_code == 401
async def test_list_models_filter_by_modality(client):
    """The modality query param filters the auto-refreshed result set."""
    token = await _register_login(client, "user@example.com", "secret123")
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        resp = await client.get(
            "/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    data = resp.json()
    assert len(data) == 1
    assert data[0]["id"] == "openai/dall-e-3"
# ---------------------------------------------------------------------------
# Integration tests: POST /models/refresh
# ---------------------------------------------------------------------------
async def test_refresh_endpoint_requires_admin(client):
    """A regular (non-admin) user must get 403 from the force-refresh route."""
    token = await _register_login(client, "user@example.com", "secret123")
    resp = await client.post(
        "/models/refresh", headers={"Authorization": f"Bearer {token}"}
    )
    assert resp.status_code == 403
async def test_refresh_endpoint_admin_succeeds(client):
    """An admin refresh returns the number of models stored."""
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    ):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert resp.json()["refreshed"] == 4
async def test_refresh_endpoint_502_on_openrouter_error(client):
    """Upstream failures surface as 502 Bad Gateway, not 500."""
    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
    with patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        side_effect=RuntimeError("network error"),
    ):
        resp = await client.post(
            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 502
+22 -8
View File
@@ -153,8 +153,9 @@ def generate():
@login_required
def generate_text():
result = error = None
token = session["access_token"]
if request.method == "POST":
resp = _api("POST", "/generate/text", token=session["access_token"], json={
resp = _api("POST", "/generate/text", token=token, json={
"model": request.form.get("model", "").strip(),
"prompt": request.form.get("prompt", "").strip(),
})
@@ -162,28 +163,35 @@ def generate_text():
result = resp.json()
else:
error = resp.json().get("detail", "Generation failed.")
return render_template("generate_text.html", result=result, error=error)
models_resp = _api("GET", "/models/", token=token,
params={"modality": "text"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_text.html", result=result, error=error, models=models)
@app.route("/generate/image", methods=["GET", "POST"])
@login_required
def generate_image():
result = error = None
token = session["access_token"]
if request.method == "POST":
# Upload reference image if provided
ref_file = request.files.get("reference_image")
if ref_file and ref_file.filename:
up_resp = _api(
"POST", "/images/upload",
token=session["access_token"],
token=token,
files={"file": (ref_file.filename,
ref_file.stream, ref_file.content_type)},
)
if up_resp.status_code not in (200, 201):
error = up_resp.json().get("detail", "Image upload failed.")
return render_template("generate_image.html", result=result, error=error)
models_resp = _api("GET", "/models/",
token=token, params={"modality": "image"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_image.html", result=result, error=error, models=models)
resp = _api("POST", "/generate/image", token=session["access_token"], json={
resp = _api("POST", "/generate/image", token=token, json={
"model": request.form.get("model", "").strip(),
"prompt": request.form.get("prompt", "").strip(),
"n": int(request.form.get("n", 1)),
@@ -195,16 +203,19 @@ def generate_image():
result = resp.json()
else:
error = resp.json().get("detail", "Generation failed.")
return render_template("generate_image.html", result=result, error=error)
models_resp = _api("GET", "/models/", token=token,
params={"modality": "image"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_image.html", result=result, error=error, models=models)
@app.route("/generate/video", methods=["GET", "POST"])
@login_required
def generate_video():
result = error = None
token = session["access_token"]
if request.method == "POST":
mode = request.form.get("mode", "text")
token = session["access_token"]
duration_raw = request.form.get("duration_seconds", "")
duration = int(
duration_raw) if duration_raw.strip().isdigit() else None
@@ -230,7 +241,10 @@ def generate_video():
result = resp.json()
else:
error = resp.json().get("detail", "Generation failed.")
return render_template("generate_video.html", result=result, error=error)
models_resp = _api("GET", "/models/", token=token,
params={"modality": "video"})
models = models_resp.json() if models_resp.status_code == 200 else []
return render_template("generate_video.html", result=result, error=error, models=models)
@app.get("/generate/video/status")
@@ -5,9 +5,17 @@
<h1>Image Generation</h1>
<form method="post" enctype="multipart/form-data">
<label for="model">Model</label>
{% if models %}
<select id="model" name="model" required>
{% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
{% endfor %}
</select>
{% else %}
<input id="model" name="model" type="text" required
placeholder="e.g. openai/dall-e-3"
value="{{ request.form.get('model', '') }}">
{% endif %}
<label for="prompt">Prompt</label>
<textarea id="prompt" name="prompt" rows="4" required
@@ -4,6 +4,13 @@ AI{% endblock %} {% block content %}
<h1>Text Generation</h1>
<form method="post">
<label for="model">Model</label>
{% if models %}
<select id="model" name="model" required>
{% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
{% endfor %}
</select>
{% else %}
<input
id="model"
name="model"
@@ -12,6 +19,7 @@ AI{% endblock %} {% block content %}
placeholder="e.g. openai/gpt-4o"
value="{{ request.form.get('model', '') }}"
/>
{% endif %}
<label for="prompt">Prompt</label>
<textarea
@@ -19,6 +19,13 @@ AI{% endblock %} {% block content %}
<input type="hidden" name="mode" value="text" />
<label for="model-t">Model</label>
{% if models %}
<select id="model-t" name="model" required>
{% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id and request.form.get('mode','text')=='text' %}selected{% endif %}>{{ m.name }}</option>
{% endfor %}
</select>
{% else %}
<input
id="model-t"
name="model"
@@ -27,6 +34,7 @@ AI{% endblock %} {% block content %}
placeholder="e.g. openai/sora-2-pro"
value="{{ request.form.get('model', '') if request.form.get('mode','text')=='text' else '' }}"
/>
{% endif %}
<label for="prompt-t">Prompt</label>
<textarea
@@ -80,6 +88,13 @@ AI{% endblock %} {% block content %}
<input type="hidden" name="mode" value="image" />
<label for="model-i">Model</label>
{% if models %}
<select id="model-i" name="model" required>
{% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id and request.form.get('mode')=='image' %}selected{% endif %}>{{ m.name }}</option>
{% endfor %}
</select>
{% else %}
<input
id="model-i"
name="model"
@@ -88,6 +103,7 @@ AI{% endblock %} {% block content %}
placeholder="e.g. openai/sora-2-pro"
value="{{ request.form.get('model', '') if request.form.get('mode')=='image' else '' }}"
/>
{% endif %}
<label for="image_url">Source image URL</label>
<input