Files
ai.allucanget.biz/backend/app/services/models.py
T

247 lines
8.0 KiB
Python

"""Model cache service: fetch from OpenRouter, store in DuckDB."""
import json
from datetime import datetime, timedelta, timezone
from typing import Any
import duckdb
from . import openrouter
CACHE_TTL_HOURS = 24
def _normalize_modality(raw: str) -> str:
"""Normalize OpenRouter modality labels to canonical values."""
value = (raw or "").strip().lower()
if value in {"text", "image", "video", "audio", "embeddings", "embedding"}:
return "embeddings" if value == "embedding" else value
if "image" in value:
return "image"
if "video" in value:
return "video"
if "audio" in value:
return "audio"
if "embed" in value:
return "embeddings"
return "text"
def _parse_modality(raw_modality: str) -> str:
    """Extract output modality from OpenRouter architecture.modality string.

    Examples: "text->text", "text+image->text", "text->image", "text->video"
    """
    # partition() splits on the first "->"; when the separator is absent the
    # whole string is treated as the output side, matching the legacy format.
    _, separator, tail = raw_modality.partition("->")
    return _normalize_modality(tail if separator else raw_modality)
def _extract_output_modality(model: dict[str, Any]) -> str:
    """Extract output modality using OpenRouter schema, fallback to legacy field.

    Prefers the first entry of ``architecture.output_modalities`` (or the
    top-level ``output_modalities``); otherwise parses the legacy
    ``modality`` arrow string; defaults to "text".
    """
    architecture = model.get("architecture") or {}
    outputs = architecture.get("output_modalities") or model.get("output_modalities")
    if isinstance(outputs, list) and outputs:
        return _normalize_modality(str(outputs[0]))
    legacy = architecture.get("modality") or model.get("modality") or "text->text"
    if not isinstance(legacy, str):
        # Defensive: an unexpected non-string modality value degrades to text.
        return "text"
    return _parse_modality(legacy)
async def _fetch_models_for_cache() -> list[dict[str, Any]]:
    """Fetch broad + modality-specific lists and merge unique models by id."""
    merged: dict[str, dict[str, Any]] = {}

    def _absorb(models: list[dict[str, Any]], *, overwrite: bool) -> None:
        # Keep only entries with a non-empty id; the primary fetch may
        # overwrite, warmup fetches only fill gaps.
        for entry in models:
            entry_id = entry.get("id")
            if entry_id and (overwrite or entry_id not in merged):
                merged[entry_id] = entry

    # Primary fetch: all modalities (per OpenRouter docs).
    _absorb(await openrouter.list_models(output_modalities="all"), overwrite=True)
    # Warmup fetches: some providers surface better results with explicit modality filter.
    for modality in ("image", "video", "audio", "embeddings", "text"):
        try:
            subset = await openrouter.list_models(output_modalities=modality)
        except Exception:
            # Best-effort: a failing warmup fetch must not abort the refresh.
            continue
        _absorb(subset, overwrite=False)
    return list(merged.values())
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
    """Fetch all models from OpenRouter and replace the cache. Returns count stored."""
    models = await _fetch_models_for_cache()
    # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute("DELETE FROM models_cache")
    stored = 0
    for model in models:
        model_id = model.get("id", "")
        if not model_id:
            # Entries without an id cannot be keyed in the cache; skip them.
            continue
        modality = _extract_output_modality(model)
        pricing = model.get("pricing")
        # Full output_modalities array from architecture (for proper modalities param in image gen)
        architecture = model.get("architecture") or {}
        raw_outputs = architecture.get("output_modalities") or model.get("output_modalities")
        if isinstance(raw_outputs, list):
            outputs_json: str | None = json.dumps(
                [_normalize_modality(str(v)) for v in raw_outputs]
            )
        else:
            outputs_json = None
        conn.execute(
            """
            INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT (model_id) DO UPDATE SET
                name = excluded.name,
                modality = excluded.modality,
                context_length = excluded.context_length,
                pricing = excluded.pricing,
                fetched_at = excluded.fetched_at,
                output_modalities = excluded.output_modalities
            """,
            [
                model_id,
                model.get("name", model_id),
                modality,
                model.get("context_length"),
                json.dumps(pricing) if pricing else None,
                fetched_at,
                outputs_json,
            ],
        )
        stored += 1
    return stored
def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Return True if cache is empty or last fetched more than CACHE_TTL_HOURS ago."""
    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    last_fetched = row[0] if row else None
    if last_fetched is None:
        # Empty table (MAX over no rows is NULL) or no row at all.
        return True
    # DuckDB TIMESTAMP is always naive; compare against naive UTC
    if last_fetched.tzinfo is not None:
        last_fetched = last_fetched.replace(tzinfo=None)
    age = datetime.now(timezone.utc).replace(tzinfo=None) - last_fetched
    return age > timedelta(hours=CACHE_TTL_HOURS)
def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """Return cached models, optionally filtered by modality, ordered by name."""
    query = (
        "SELECT model_id, name, modality, context_length, pricing "
        "FROM models_cache "
    )
    if modality:
        query += "WHERE modality = ? ORDER BY name"
        rows = conn.execute(query, [modality]).fetchall()
    else:
        query += "ORDER BY name"
        rows = conn.execute(query).fetchall()

    def _decode_pricing(blob: Any) -> Any:
        # Pricing is stored as a JSON blob; treat anything undecodable as absent.
        if not blob:
            return None
        try:
            return json.loads(blob)
        except (json.JSONDecodeError, TypeError):
            return None

    return [
        {
            "id": model_id,
            "name": name,
            "modality": row_modality,
            "context_length": context_length,
            "pricing": _decode_pricing(raw_pricing),
        }
        for model_id, name, row_modality, context_length, raw_pricing in rows
    ]
def get_model_output_modalities(
    conn: duckdb.DuckDBPyConnection,
    model_id: str,
) -> list[str]:
    """Return output_modalities list for a model; empty list if not found."""
    row = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()
    payload = row[0] if row else None
    if not payload:
        # Unknown model, or a row stored with NULL output_modalities.
        return []
    try:
        return json.loads(payload)
    except (json.JSONDecodeError, TypeError):
        return []
def get_cache_status(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]:
    """Return cache last update time and model count."""
    row = conn.execute(
        "SELECT MAX(fetched_at), COUNT(*) FROM models_cache"
    ).fetchone()
    if not row:
        return {"last_updated": None, "model_count": 0}
    last_updated, model_count = row
    return {"last_updated": last_updated, "model_count": model_count}
def mark_timed_out_video_jobs(conn: duckdb.DuckDBPyConnection, timeout_minutes: int = 120) -> int:
    """Mark video jobs that have been in 'queued' or 'processing' status for too long as 'failed'.

    Args:
        conn: Open DuckDB connection.
        timeout_minutes: Age (measured against updated_at) after which a
            pending job is considered timed out.

    Returns:
        The number of jobs marked as timed out.
    """
    # Fix: use naive UTC like the rest of this module (see refresh_models_cache /
    # is_cache_stale). The previous code passed tz-AWARE datetimes, which mixes
    # aware values with DuckDB's naive TIMESTAMP column for both the age
    # comparison and the updated_at write.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    timeout_threshold = now - timedelta(minutes=timeout_minutes)
    # Find timed out jobs
    timed_out_rows = conn.execute(
        """
        SELECT id FROM generated_videos
        WHERE status IN ('queued', 'processing')
          AND updated_at < ?
        """,
        [timeout_threshold],
    ).fetchall()
    if not timed_out_rows:
        return 0
    job_ids = [row[0] for row in timed_out_rows]
    # Safe placeholder expansion: only "?" markers are interpolated, ids are bound.
    placeholders = ",".join(["?"] * len(job_ids))
    # Update them to failed
    conn.execute(
        f"""
        UPDATE generated_videos
        SET status = 'failed', updated_at = ?
        WHERE id IN ({placeholders})
        """,
        [now] + job_ids,
    )
    return len(job_ids)