Add models caching and management functionality with corresponding API endpoints and templates
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -75,6 +75,17 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS models_cache (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
model_id VARCHAR NOT NULL UNIQUE,
|
||||
name VARCHAR NOT NULL,
|
||||
modality VARCHAR NOT NULL,
|
||||
context_length BIGINT,
|
||||
pricing JSON,
|
||||
fetched_at TIMESTAMP NOT NULL
|
||||
)
|
||||
""")
|
||||
_seed_admin(conn)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from .routers import admin as admin_router
|
||||
from .routers import ai as ai_router
|
||||
from .routers import generate as generate_router
|
||||
from .routers import images as images_router
|
||||
from .routers import models as models_router
|
||||
from .db import close_db, init_db
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -43,6 +44,7 @@ app.include_router(admin_router.router)
|
||||
app.include_router(ai_router.router)
|
||||
app.include_router(generate_router.router)
|
||||
app.include_router(images_router.router)
|
||||
app.include_router(models_router.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Models router: list and refresh the OpenRouter model cache."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from ..db import get_conn, get_write_lock
|
||||
from ..dependencies import get_current_user, require_admin
|
||||
from ..services import models as models_service
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
@router.get("/")
async def list_models(
    modality: str | None = Query(
        None,
        description="Filter by output modality: text, image, video, audio",
    ),
    _: dict = Depends(get_current_user),
):
    """Return the cached model list, refreshing it first when stale.

    A cache older than 24 h triggers a refresh behind the DB write lock;
    refresh failures surface as HTTP 502.
    """
    conn = get_conn()
    # Fast path: serve straight from the cache when it is fresh.
    if not models_service.is_cache_stale(conn):
        return models_service.get_cached_models(conn, modality)
    async with get_write_lock():
        # Double-checked: another request may have refreshed while we
        # were waiting for the lock.
        if models_service.is_cache_stale(conn):
            try:
                await models_service.refresh_models_cache(conn)
            except Exception as exc:
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY,
                    detail=f"Failed to refresh model cache: {exc}",
                )
    return models_service.get_cached_models(conn, modality)
|
||||
|
||||
|
||||
@router.post("/refresh", status_code=200)
async def refresh_models(_: dict = Depends(require_admin)):
    """Admin-only endpoint: force a model-cache refresh from OpenRouter.

    Serialises the refresh behind the DB write lock; any upstream failure
    is mapped to HTTP 502.
    """
    conn = get_conn()
    async with get_write_lock():
        try:
            stored = await models_service.refresh_models_cache(conn)
        except Exception as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"OpenRouter error: {exc}",
            )
    return {"refreshed": stored}
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Model cache service: fetch from OpenRouter, store in DuckDB."""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import duckdb
|
||||
|
||||
from . import openrouter
|
||||
|
||||
CACHE_TTL_HOURS = 24
|
||||
|
||||
|
||||
def _parse_modality(raw_modality: str) -> str:
    """Extract the output modality from an OpenRouter ``architecture.modality`` string.

    OpenRouter encodes modality as ``"<inputs>-><outputs>"``, e.g.
    ``"text->text"``, ``"text+image->text"``, ``"text->image"``.  A string
    without ``"->"`` is treated as the output part directly.

    Returns one of ``"text"``, ``"image"``, ``"video"``, ``"audio"``
    (matched in that priority order), or the lowercased output string when
    none of the known modalities appears in it.
    """
    # str.split with maxsplit=1 returns the whole string when the separator
    # is absent, so the original explicit `"->" in raw_modality` check was
    # redundant and has been dropped.
    output = raw_modality.split("->", 1)[-1].lower()
    for known in ("text", "image", "video", "audio"):
        if known in output:
            return known
    return output
|
||||
|
||||
|
||||
async def refresh_models_cache(conn: "duckdb.DuckDBPyConnection") -> int:
    """Fetch all models from OpenRouter and rebuild the cache table.

    Clears ``models_cache`` and inserts one row per model returned by the
    OpenRouter API; entries without an ``id`` are skipped.  Returns the
    number of rows stored.

    Propagates whatever ``openrouter.list_models`` raises on API failure;
    callers translate that into an HTTP error.
    """
    raw = await openrouter.list_models()
    # Use naive UTC: DuckDB TIMESTAMP strips tz info, so storing naive
    # values keeps writes consistent with what reads get back.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    conn.execute("DELETE FROM models_cache")
    count = 0
    for m in raw:
        model_id = m.get("id", "")
        if not model_id:
            continue  # malformed entry; nothing to key the row on
        # `or` fallbacks (rather than dict.get defaults) also cover the
        # payload carrying "architecture": null or "modality": null, which
        # previously crashed _parse_modality with an AttributeError.
        arch = m.get("architecture") or {}
        modality = _parse_modality(arch.get("modality") or "text->text")
        pricing = m.get("pricing")
        conn.execute(
            """
            INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT (model_id) DO UPDATE SET
                name = excluded.name,
                modality = excluded.modality,
                context_length = excluded.context_length,
                pricing = excluded.pricing,
                fetched_at = excluded.fetched_at
            """,
            [
                model_id,
                m.get("name", model_id),
                modality,
                m.get("context_length"),
                json.dumps(pricing) if pricing else None,
                now,
            ],
        )
        count += 1
    return count
|
||||
|
||||
|
||||
def is_cache_stale(conn: "duckdb.DuckDBPyConnection") -> bool:
    """Report whether the model cache needs a refresh.

    The cache is stale when the table is empty or when the newest
    ``fetched_at`` value is older than ``CACHE_TTL_HOURS``.
    """
    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    last_fetched = row[0] if row else None
    if last_fetched is None:
        # MAX over zero rows is NULL -> cache is empty, must refresh.
        return True
    # DuckDB TIMESTAMPs come back naive; normalise defensively before
    # subtracting from a naive-UTC "now".
    if last_fetched.tzinfo is not None:
        last_fetched = last_fetched.replace(tzinfo=None)
    age = datetime.now(timezone.utc).replace(tzinfo=None) - last_fetched
    return age > timedelta(hours=CACHE_TTL_HOURS)
|
||||
|
||||
|
||||
def get_cached_models(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
modality: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return cached models, optionally filtered by modality, ordered by name."""
|
||||
if modality:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT model_id, name, modality, context_length, pricing
|
||||
FROM models_cache
|
||||
WHERE modality = ?
|
||||
ORDER BY name
|
||||
""",
|
||||
[modality],
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT model_id, name, modality, context_length, pricing
|
||||
FROM models_cache
|
||||
ORDER BY name
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
result = []
|
||||
for row in rows:
|
||||
pricing = None
|
||||
if row[4]:
|
||||
try:
|
||||
pricing = json.loads(row[4])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pricing = None
|
||||
result.append({
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"modality": row[2],
|
||||
"context_length": row[3],
|
||||
"pricing": pricing,
|
||||
})
|
||||
return result
|
||||
Reference in New Issue
Block a user