"""Model cache service: fetch from OpenRouter, store in DuckDB.""" import json from datetime import datetime, timedelta, timezone from typing import Any import duckdb from . import openrouter CACHE_TTL_HOURS = 24 def _normalize_modality(raw: str) -> str: """Normalize OpenRouter modality labels to canonical values.""" value = (raw or "").strip().lower() if value in {"text", "image", "video", "audio", "embeddings", "embedding"}: return "embeddings" if value == "embedding" else value if "image" in value: return "image" if "video" in value: return "video" if "audio" in value: return "audio" if "embed" in value: return "embeddings" return "text" def _parse_modality(raw_modality: str) -> str: """Extract output modality from OpenRouter architecture.modality string. Examples: "text->text", "text+image->text", "text->image", "text->video" """ output = raw_modality.split( "->", 1)[-1] if "->" in raw_modality else raw_modality return _normalize_modality(output) def _extract_output_modality(model: dict[str, Any]) -> str: """Extract output modality using OpenRouter schema, fallback to legacy field.""" architecture = model.get("architecture") or {} output_modalities = architecture.get( "output_modalities") or model.get("output_modalities") if isinstance(output_modalities, list) and output_modalities: return _normalize_modality(str(output_modalities[0])) raw_modality = architecture.get( "modality") or model.get("modality") or "text->text" if isinstance(raw_modality, str): return _parse_modality(raw_modality) return "text" async def _fetch_models_for_cache() -> list[dict[str, Any]]: """Fetch broad + modality-specific lists and merge unique models by id.""" by_id: dict[str, dict[str, Any]] = {} # Primary fetch: all modalities (per OpenRouter docs). primary = await openrouter.list_models(output_modalities="all") for model in primary: model_id = model.get("id") if model_id: by_id[model_id] = model # Warmup fetches: some providers surface better results with explicit modality filter. for modality in ("image", "video", "audio", "embeddings", "text"): try: subset = await openrouter.list_models(output_modalities=modality) except Exception: continue for model in subset: model_id = model.get("id") if model_id and model_id not in by_id: by_id[model_id] = model return list(by_id.values()) async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int: """Fetch all models from OpenRouter and replace the cache. Returns count stored.""" raw = await _fetch_models_for_cache() # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies now = datetime.now(timezone.utc).replace(tzinfo=None) conn.execute("DELETE FROM models_cache") count = 0 for m in raw: modality = _extract_output_modality(m) pricing = m.get("pricing") model_id = m.get("id", "") if not model_id: continue # Full output_modalities array from architecture (for proper modalities param in image gen) architecture = m.get("architecture") or {} raw_output_modalities: list | None = ( architecture.get("output_modalities") or m.get("output_modalities") ) output_modalities_json: str | None = ( json.dumps([_normalize_modality(str(v)) for v in raw_output_modalities]) if isinstance(raw_output_modalities, list) else None ) conn.execute( """ INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities) VALUES (?, ?, ?, ?, ?, ?, ?) 
            ON CONFLICT (model_id) DO UPDATE SET
                name = excluded.name,
                modality = excluded.modality,
                context_length = excluded.context_length,
                pricing = excluded.pricing,
                fetched_at = excluded.fetched_at,
                output_modalities = excluded.output_modalities
            """,
            [
                model_id,
                m.get("name", model_id),
                modality,
                m.get("context_length"),
                json.dumps(pricing) if pricing else None,
                now,
                output_modalities_json,
            ],
        )
        count += 1
    return count


def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Return True if cache is empty or last fetched more than CACHE_TTL_HOURS ago."""
    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    if not row or row[0] is None:
        return True
    last_fetched = row[0]
    # DuckDB TIMESTAMP is always naive; compare against naive UTC
    if last_fetched.tzinfo is not None:
        last_fetched = last_fetched.replace(tzinfo=None)
    now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    return now_naive - last_fetched > timedelta(hours=CACHE_TTL_HOURS)


def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """Return cached models, optionally filtered by modality, ordered by name."""
    if modality:
        rows = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            WHERE modality = ?
            ORDER BY name
            """,
            [modality],
        ).fetchall()
    else:
        rows = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            ORDER BY name
            """
        ).fetchall()
    result = []
    for row in rows:
        pricing = None
        if row[4]:
            try:
                pricing = json.loads(row[4])
            except (json.JSONDecodeError, TypeError):
                pricing = None
        result.append({
            "id": row[0],
            "name": row[1],
            "modality": row[2],
            "context_length": row[3],
            "pricing": pricing,
        })
    return result


def get_model_output_modalities(
    conn: duckdb.DuckDBPyConnection,
    model_id: str,
) -> list[str]:
    """Return output_modalities list for a model; empty list if not found."""
    row = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()
    if not row or not row[0]:
        return []
    try:
        return json.loads(row[0])
    except (json.JSONDecodeError, TypeError):
        return []


def get_cache_status(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]:
    """Return cache last update time and model count."""
    row = conn.execute(
        "SELECT MAX(fetched_at), COUNT(*) FROM models_cache"
    ).fetchone()
    last_updated, model_count = (row[0], row[1]) if row else (None, 0)
    return {"last_updated": last_updated, "model_count": model_count}


def mark_timed_out_video_jobs(conn: duckdb.DuckDBPyConnection, timeout_minutes: int = 120) -> int:
    """Mark video jobs stuck in 'queued' or 'processing' for too long as 'failed'.

    Returns the number of jobs marked as timed out.
    """
    # Use naive UTC, consistent with the module's other DuckDB TIMESTAMP handling
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    timeout_threshold = now - timedelta(minutes=timeout_minutes)
    # Find timed-out jobs
    timed_out_rows = conn.execute(
        """
        SELECT id FROM generated_videos
        WHERE status IN ('queued', 'processing') AND updated_at < ?
        """,
        [timeout_threshold],
    ).fetchall()
    if not timed_out_rows:
        return 0
    job_ids = [row[0] for row in timed_out_rows]
    placeholders = ",".join(["?"] * len(job_ids))
    # Mark them as failed
    conn.execute(
        f"""
        UPDATE generated_videos
        SET status = 'failed', updated_at = ?
        WHERE id IN ({placeholders})
        """,
        [now] + job_ids,
    )
    return len(job_ids)
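

# Usage sketch (illustrative only): a caller would typically refresh the cache
# when stale and then read from it. The helper name `get_models_for_ui` and the
# chosen modality filter are assumptions for the example, not part of this
# module; the DuckDB connection and models_cache table are set up elsewhere.
#
#     async def get_models_for_ui(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
#         if is_cache_stale(conn):
#             await refresh_models_cache(conn)  # hits OpenRouter; requires API access
#         return get_cached_models(conn, modality="image")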