"""Model cache service: fetch from OpenRouter, store in DuckDB.""" import json from datetime import datetime, timedelta, timezone from typing import Any import duckdb from . import openrouter CACHE_TTL_HOURS = 24 def _normalize_modality(raw: str) -> str: """Normalize OpenRouter modality labels to canonical values.""" value = (raw or "").strip().lower() if value in {"text", "image", "video", "audio", "embeddings", "embedding"}: return "embeddings" if value == "embedding" else value if "image" in value: return "image" if "video" in value: return "video" if "audio" in value: return "audio" if "embed" in value: return "embeddings" return "text" def _parse_modality(raw_modality: str) -> str: """Extract output modality from OpenRouter architecture.modality string. Examples: "text->text", "text+image->text", "text->image", "text->video" """ output = raw_modality.split( "->", 1)[-1] if "->" in raw_modality else raw_modality return _normalize_modality(output) def _extract_output_modality(model: dict[str, Any]) -> str: """Extract output modality using OpenRouter schema, fallback to legacy field.""" architecture = model.get("architecture") or {} output_modalities = architecture.get( "output_modalities") or model.get("output_modalities") if isinstance(output_modalities, list) and output_modalities: return _normalize_modality(str(output_modalities[0])) raw_modality = architecture.get( "modality") or model.get("modality") or "text->text" if isinstance(raw_modality, str): return _parse_modality(raw_modality) return "text" async def _fetch_models_for_cache() -> list[dict[str, Any]]: """Fetch broad + modality-specific lists and merge unique models by id.""" by_id: dict[str, dict[str, Any]] = {} # Primary fetch: all modalities (per OpenRouter docs). primary = await openrouter.list_models(output_modalities="all") for model in primary: model_id = model.get("id") if model_id: by_id[model_id] = model # Warmup fetches: some providers surface better results with explicit modality filter. for modality in ("image", "video", "audio", "embeddings", "text"): try: subset = await openrouter.list_models(output_modalities=modality) except Exception: continue for model in subset: model_id = model.get("id") if model_id and model_id not in by_id: by_id[model_id] = model return list(by_id.values()) async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int: """Fetch all models from OpenRouter and replace the cache. Returns count stored.""" raw = await _fetch_models_for_cache() # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies now = datetime.now(timezone.utc).replace(tzinfo=None) conn.execute("DELETE FROM models_cache") count = 0 for m in raw: modality = _extract_output_modality(m) pricing = m.get("pricing") model_id = m.get("id", "") if not model_id: continue conn.execute( """ INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT (model_id) DO UPDATE SET name = excluded.name, modality = excluded.modality, context_length = excluded.context_length, pricing = excluded.pricing, fetched_at = excluded.fetched_at """, [ model_id, m.get("name", model_id), modality, m.get("context_length"), json.dumps(pricing) if pricing else None, now, ], ) count += 1 return count def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool: """Return True if cache is empty or last fetched more than CACHE_TTL_HOURS ago.""" row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone() if not row or row[0] is None: return True last_fetched = row[0] # DuckDB TIMESTAMP is always naive; compare against naive UTC if last_fetched.tzinfo is not None: last_fetched = last_fetched.replace(tzinfo=None) now_naive = datetime.now(timezone.utc).replace(tzinfo=None) return now_naive - last_fetched > timedelta(hours=CACHE_TTL_HOURS) def get_cached_models( conn: duckdb.DuckDBPyConnection, modality: str | None = None, ) -> list[dict[str, Any]]: """Return cached models, optionally filtered by modality, ordered by name.""" if modality: rows = conn.execute( """ SELECT model_id, name, modality, context_length, pricing FROM models_cache WHERE modality = ? ORDER BY name """, [modality], ).fetchall() else: rows = conn.execute( """ SELECT model_id, name, modality, context_length, pricing FROM models_cache ORDER BY name """ ).fetchall() result = [] for row in rows: pricing = None if row[4]: try: pricing = json.loads(row[4]) except (json.JSONDecodeError, TypeError): pricing = None result.append({ "id": row[0], "name": row[1], "modality": row[2], "context_length": row[3], "pricing": pricing, }) return result