712c556032
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
201 lines
6.7 KiB
Python
201 lines
6.7 KiB
Python
"""Model cache service: fetch from OpenRouter, store in DuckDB."""
|
|
import json
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
import duckdb
|
|
|
|
from . import openrouter
|
|
|
|
# Maximum age of the models cache, in hours; is_cache_stale() reports the
# cache as stale once the newest fetched_at is older than this.
CACHE_TTL_HOURS = 24
|
|
|
|
|
|
def _normalize_modality(raw: str) -> str:
|
|
"""Normalize OpenRouter modality labels to canonical values."""
|
|
value = (raw or "").strip().lower()
|
|
if value in {"text", "image", "video", "audio", "embeddings", "embedding"}:
|
|
return "embeddings" if value == "embedding" else value
|
|
if "image" in value:
|
|
return "image"
|
|
if "video" in value:
|
|
return "video"
|
|
if "audio" in value:
|
|
return "audio"
|
|
if "embed" in value:
|
|
return "embeddings"
|
|
return "text"
|
|
|
|
|
|
def _parse_modality(raw_modality: str) -> str:
    """Return the canonical output modality of an ``architecture.modality`` string.

    The string has the form "<inputs>-><outputs>", for example "text->text",
    "text+image->text", "text->image" or "text->video". Everything after the
    first "->" is treated as the output side; a string without an arrow is
    treated as the output side in its entirety.
    """
    _inputs, arrow, outputs = raw_modality.partition("->")
    if not arrow:
        # No arrow present: partition() put the whole string in the head.
        outputs = raw_modality
    return _normalize_modality(outputs)
|
|
|
|
|
|
def _extract_output_modality(model: dict[str, Any]) -> str:
    """Determine a model's canonical output modality.

    Lookup order:
      1. ``output_modalities`` list (under ``architecture``, then top level);
         the first entry is normalized and returned.
      2. Legacy ``modality`` string (under ``architecture``, then top level),
         defaulting to "text->text" when absent.
      3. "text" when the legacy value is not a string.
    """
    arch = model.get("architecture") or {}

    listed = arch.get("output_modalities") or model.get("output_modalities")
    # NOTE(review): only the first listed modality is considered here, so a
    # model listing ["text", "image"] is classified as "text" - confirm this
    # is the intended primary-modality rule.
    if isinstance(listed, list) and listed:
        return _normalize_modality(str(listed[0]))

    legacy = arch.get("modality") or model.get("modality") or "text->text"
    if not isinstance(legacy, str):
        return "text"
    return _parse_modality(legacy)
|
|
|
|
|
|
async def _fetch_models_for_cache() -> list[dict[str, Any]]:
    """Fetch the broad model list plus per-modality lists, merged by model id.

    The primary ("all") request is mandatory: any exception it raises
    propagates to the caller. The per-modality requests are best-effort
    supplements; a failure there is logged and that modality is skipped.

    Returns:
        The unique models, primary results first, in first-seen order.
        Entries without a truthy "id" are dropped.
    """
    # Local import keeps this block self-contained without touching the
    # module's import section.
    import logging

    by_id: dict[str, dict[str, Any]] = {}

    # Primary fetch: all modalities (per OpenRouter docs).
    primary = await openrouter.list_models(output_modalities="all")
    for model in primary:
        model_id = model.get("id")
        if model_id:
            by_id[model_id] = model

    # Warmup fetches: some providers surface better results with an explicit
    # modality filter. These only add models the primary fetch did not return.
    for modality in ("image", "video", "audio", "embeddings", "text"):
        try:
            subset = await openrouter.list_models(output_modalities=modality)
        except Exception:
            # Stay best-effort, but leave a trace instead of failing silently.
            logging.getLogger(__name__).warning(
                "Skipping %r model listing: request failed",
                modality,
                exc_info=True,
            )
            continue
        for model in subset:
            model_id = model.get("id")
            if model_id and model_id not in by_id:
                by_id[model_id] = model

    return list(by_id.values())
|
|
|
|
|
|
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
    """Replace the contents of ``models_cache`` with a fresh OpenRouter fetch.

    Models are fetched before the table is cleared, so a failed fetch leaves
    the existing cache untouched. Entries without an id are skipped.

    Returns:
        The number of models written.
    """
    upsert_sql = """
        INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT (model_id) DO UPDATE SET
            name = excluded.name,
            modality = excluded.modality,
            context_length = excluded.context_length,
            pricing = excluded.pricing,
            fetched_at = excluded.fetched_at,
            output_modalities = excluded.output_modalities
        """

    models = await _fetch_models_for_cache()
    # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies.
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)

    conn.execute("DELETE FROM models_cache")

    stored = 0
    for entry in models:
        primary_modality = _extract_output_modality(entry)
        pricing = entry.get("pricing")
        model_id = entry.get("id", "")
        if not model_id:
            continue

        # Keep the full output_modalities array as well (normalized), so the
        # proper modalities param can be sent for image generation.
        arch = entry.get("architecture") or {}
        listed = arch.get("output_modalities") or entry.get("output_modalities")
        if isinstance(listed, list):
            modalities_json = json.dumps(
                [_normalize_modality(str(item)) for item in listed]
            )
        else:
            modalities_json = None

        if pricing:
            pricing_json = json.dumps(pricing)
        else:
            pricing_json = None

        params = [
            model_id,
            entry.get("name", model_id),
            primary_modality,
            entry.get("context_length"),
            pricing_json,
            fetched_at,
            modalities_json,
        ]
        conn.execute(upsert_sql, params)
        stored += 1

    return stored
|
|
|
|
|
|
def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Report whether the models cache needs refreshing.

    Returns:
        True when ``models_cache`` has no rows (or no ``fetched_at`` value),
        or when the newest ``fetched_at`` is more than ``CACHE_TTL_HOURS``
        hours old; False otherwise.
    """
    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    if not row or row[0] is None:
        return True

    last_fetched = row[0]
    # Timestamps are written as naive UTC. Should the driver ever hand back an
    # aware value, convert it to UTC *before* dropping tzinfo; stripping the
    # zone directly would keep the local wall-clock time and skew the age by
    # the UTC offset.
    if last_fetched.tzinfo is not None:
        last_fetched = last_fetched.astimezone(timezone.utc).replace(tzinfo=None)

    now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    return now_naive - last_fetched > timedelta(hours=CACHE_TTL_HOURS)
|
|
|
|
|
|
def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """List cached models sorted by name, optionally restricted to one modality.

    Args:
        conn: Connection to the database holding ``models_cache``.
        modality: When truthy, only rows whose ``modality`` equals this value
            are returned; otherwise every row is returned.

    Returns:
        One dict per model with keys "id", "name", "modality",
        "context_length" and "pricing". "pricing" is the decoded JSON value,
        or None when the stored value is empty or cannot be decoded.
    """
    if not modality:
        cursor = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            ORDER BY name
            """
        )
    else:
        cursor = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            WHERE modality = ?
            ORDER BY name
            """,
            [modality],
        )

    models = []
    for model_id, name, row_modality, context_length, pricing_raw in cursor.fetchall():
        try:
            pricing = json.loads(pricing_raw) if pricing_raw else None
        except (json.JSONDecodeError, TypeError):
            # Undecodable pricing is treated the same as missing pricing.
            pricing = None
        models.append({
            "id": model_id,
            "name": name,
            "modality": row_modality,
            "context_length": context_length,
            "pricing": pricing,
        })
    return models
|
|
|
|
|
|
def get_model_output_modalities(
    conn: duckdb.DuckDBPyConnection,
    model_id: str,
) -> list[str]:
    """Look up the stored output modalities of a single cached model.

    Returns:
        The JSON-decoded ``output_modalities`` value for ``model_id``. An
        empty list is returned when the model is not cached, when nothing is
        stored for it, or when the stored value cannot be decoded.
    """
    match = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()

    stored = match[0] if match else None
    if not stored:
        return []

    try:
        decoded = json.loads(stored)
    except (json.JSONDecodeError, TypeError):
        return []
    return decoded
|