From 78b76dc33102f554984b218c4a5c28c983a1ffe4 Mon Sep 17 00:00:00 2001
From: zwitschi
Date: Wed, 29 Apr 2026 14:16:42 +0200
Subject: [PATCH] Enhance model handling by normalizing modalities and updating fetch logic; add tests for new functionality

Co-authored-by: Copilot
---
Review notes (free-form text in this position, between the three-dash
line and the first "diff --git" line, is not part of the commit):

* NOTE(review): line breaks and indentation below were reconstructed from
  a whitespace-collapsed copy of this patch. Hunk line counts match the
  hunk headers and the diffstat, but two details are assumptions to
  confirm against the repository before applying: the position of the
  single blank context line inside refresh_models_cache(), and the
  indentation of the context lines in the last test_models.py hunk.

* _fetch_models_for_cache(): a failure of any per-modality request is
  caught by "except Exception: continue" and nothing is recorded, so a
  partially populated cache cannot be told apart from a complete one.
  The primary "all" request is not guarded and still propagates errors.
  The function awaits up to six list_models() calls one after another,
  and each call opens its own httpx.AsyncClient(timeout=15).

* _extract_output_modality(): only output_modalities[0] is normalized.
  A list such as ["text", "image"] therefore yields "text", while the
  legacy string "text->text+image" yields "image" through
  _parse_modality() -> _normalize_modality(). The two paths disagree
  for multi-output models.

* _parse_modality(): results change for existing inputs. The removed
  code returned "text" for any output containing "text" and returned an
  unrecognized output unchanged; the new code returns "image" for
  "text+image" and "text" for anything unrecognized.

* list_models(): "output_modalities" is now always sent as a query
  parameter (default "all"), including for callers that pass no
  arguments.

* test_list_models_endpoint_auto_refreshes: the assertion is loosened
  from assert_awaited_once() to await_count >= 1, so it no longer pins
  how many upstream requests a refresh performs.

* _normalize_modality(raw: str): the body tolerates a falsy value via
  (raw or ""), so the annotation is narrower than what is accepted.

 backend/app/services/models.py     | 76 ++++++++++++++++++++++++------
 backend/app/services/openrouter.py | 19 ++++++--
 backend/tests/test_models.py       | 23 ++++++++-
 3 files changed, 99 insertions(+), 19 deletions(-)

diff --git a/backend/app/services/models.py b/backend/app/services/models.py
index 0f86e7b..2738e18 100644
--- a/backend/app/services/models.py
+++ b/backend/app/services/models.py
@@ -10,37 +10,83 @@ from . import openrouter
 CACHE_TTL_HOURS = 24
 
 
+def _normalize_modality(raw: str) -> str:
+    """Normalize OpenRouter modality labels to canonical values."""
+    value = (raw or "").strip().lower()
+    if value in {"text", "image", "video", "audio", "embeddings", "embedding"}:
+        return "embeddings" if value == "embedding" else value
+    if "image" in value:
+        return "image"
+    if "video" in value:
+        return "video"
+    if "audio" in value:
+        return "audio"
+    if "embed" in value:
+        return "embeddings"
+    return "text"
+
+
 def _parse_modality(raw_modality: str) -> str:
     """Extract output modality from OpenRouter architecture.modality string.
 
     Examples: "text->text", "text+image->text", "text->image", "text->video"
     """
     output = raw_modality.split(
-        "->", 1)[-1].lower() if "->" in raw_modality else raw_modality.lower()
-    if "text" in output:
-        return "text"
-    if "image" in output:
-        return "image"
-    if "video" in output:
-        return "video"
-    if "audio" in output:
-        return "audio"
-    return output
+        "->", 1)[-1] if "->" in raw_modality else raw_modality
+    return _normalize_modality(output)
+
+
+def _extract_output_modality(model: dict[str, Any]) -> str:
+    """Extract output modality using OpenRouter schema, fallback to legacy field."""
+    architecture = model.get("architecture") or {}
+
+    output_modalities = architecture.get(
+        "output_modalities") or model.get("output_modalities")
+    if isinstance(output_modalities, list) and output_modalities:
+        return _normalize_modality(str(output_modalities[0]))
+
+    raw_modality = architecture.get(
+        "modality") or model.get("modality") or "text->text"
+    if isinstance(raw_modality, str):
+        return _parse_modality(raw_modality)
+    return "text"
+
+
+async def _fetch_models_for_cache() -> list[dict[str, Any]]:
+    """Fetch broad + modality-specific lists and merge unique models by id."""
+    by_id: dict[str, dict[str, Any]] = {}
+
+    # Primary fetch: all modalities (per OpenRouter docs).
+    primary = await openrouter.list_models(output_modalities="all")
+    for model in primary:
+        model_id = model.get("id")
+        if model_id:
+            by_id[model_id] = model
+
+    # Warmup fetches: some providers surface better results with explicit modality filter.
+    for modality in ("image", "video", "audio", "embeddings", "text"):
+        try:
+            subset = await openrouter.list_models(output_modalities=modality)
+        except Exception:
+            continue
+        for model in subset:
+            model_id = model.get("id")
+            if model_id and model_id not in by_id:
+                by_id[model_id] = model
+
+    return list(by_id.values())
 
 
 async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
     """Fetch all models from OpenRouter and replace the cache. Returns count stored."""
-    raw = await openrouter.list_models()
+    raw = await _fetch_models_for_cache()
     # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
     now = datetime.now(timezone.utc).replace(tzinfo=None)
 
     conn.execute("DELETE FROM models_cache")
     count = 0
     for m in raw:
-        arch = m.get("architecture", {})
-        modality_raw = arch.get(
-            "modality", "text->text") if arch else "text->text"
-        modality = _parse_modality(modality_raw)
+        modality = _extract_output_modality(m)
         pricing = m.get("pricing")
         model_id = m.get("id", "")
         if not model_id:
diff --git a/backend/app/services/openrouter.py b/backend/app/services/openrouter.py
index 6778ee3..e9b66c0 100644
--- a/backend/app/services/openrouter.py
+++ b/backend/app/services/openrouter.py
@@ -24,12 +24,25 @@ def _headers() -> dict[str, str]:
     }
 
 
-async def list_models() -> list[dict[str, Any]]:
-    """Return available models from OpenRouter."""
+async def list_models(
+    output_modalities: str = "all",
+    category: str | None = None,
+    supported_parameters: str | None = None,
+) -> list[dict[str, Any]]:
+    """Return available models from OpenRouter.
+
+    Docs: GET /models supports query filters like output_modalities.
+    """
     base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
+    params: dict[str, str] = {"output_modalities": output_modalities}
+    if category:
+        params["category"] = category
+    if supported_parameters:
+        params["supported_parameters"] = supported_parameters
+
     async with httpx.AsyncClient(timeout=15) as client:
         resp = client.build_request(
-            "GET", f"{base_url}/models", headers=_headers())
+            "GET", f"{base_url}/models", headers=_headers(), params=params)
         response = await client.send(resp)
         response.raise_for_status()
         return response.json().get("data", [])
diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py
index c2924a2..db08f89 100644
--- a/backend/tests/test_models.py
+++ b/backend/tests/test_models.py
@@ -11,6 +11,8 @@ from httpx import ASGITransport, AsyncClient
 from app import db as db_module
 from app.main import app
 from app.services.models import (
+    _extract_output_modality,
+    _normalize_modality,
     _parse_modality,
     get_cached_models,
     is_cache_stale,
@@ -107,6 +109,25 @@ def test_parse_modality_no_arrow_fallback():
     assert _parse_modality("text") == "text"
 
 
+def test_normalize_embedding_alias():
+    assert _normalize_modality("embedding") == "embeddings"
+
+
+def test_extract_output_modality_prefers_output_modalities():
+    model = {
+        "architecture": {
+            "modality": "text->text",
+            "output_modalities": ["image"],
+        }
+    }
+    assert _extract_output_modality(model) == "image"
+
+
+def test_extract_output_modality_legacy_fallback():
+    model = {"architecture": {"modality": "text->audio"}}
+    assert _extract_output_modality(model) == "audio"
+
+
 # ---------------------------------------------------------------------------
 # Unit tests: is_cache_stale
 # ---------------------------------------------------------------------------
@@ -213,7 +234,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
         )
         assert resp.status_code == 200
         assert len(resp.json()) == 4
-        mock_fetch.assert_awaited_once()
+        assert mock_fetch.await_count >= 1
 
 
 async def test_list_models_endpoint_uses_cache(client):