Enhance model handling by normalizing modalities and updating fetch logic; add tests for new functionality
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -10,37 +10,83 @@ from . import openrouter
|
|||||||
CACHE_TTL_HOURS = 24
|
CACHE_TTL_HOURS = 24
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_modality(raw: str) -> str:
|
||||||
|
"""Normalize OpenRouter modality labels to canonical values."""
|
||||||
|
value = (raw or "").strip().lower()
|
||||||
|
if value in {"text", "image", "video", "audio", "embeddings", "embedding"}:
|
||||||
|
return "embeddings" if value == "embedding" else value
|
||||||
|
if "image" in value:
|
||||||
|
return "image"
|
||||||
|
if "video" in value:
|
||||||
|
return "video"
|
||||||
|
if "audio" in value:
|
||||||
|
return "audio"
|
||||||
|
if "embed" in value:
|
||||||
|
return "embeddings"
|
||||||
|
return "text"
|
||||||
|
|
||||||
|
|
||||||
def _parse_modality(raw_modality: str) -> str:
    """Extract output modality from OpenRouter architecture.modality string.

    Examples: "text->text", "text+image->text", "text->image", "text->video"

    Args:
        raw_modality: Modality string, optionally of the form
            ``"<input>-><output>"``.

    Returns:
        The canonical label produced by ``_normalize_modality`` for the part
        after the first ``"->"``, or for the whole string when it contains
        no arrow.
    """
    # With maxsplit=1 the last element is everything after the first arrow;
    # when there is no arrow, split() returns the whole string unchanged.
    output = raw_modality.split("->", 1)[-1]
    return _normalize_modality(output)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_output_modality(model: dict[str, Any]) -> str:
    """Extract output modality using OpenRouter schema, fallback to legacy field.

    Lookup order:
      1. ``output_modalities`` (under ``architecture``, then top level): when
         it is a non-empty list, its first entry is normalized and returned.
      2. ``modality`` (under ``architecture``, then top level), defaulting to
         ``"text->text"``: parsed with ``_parse_modality`` when it is a string.
      3. Otherwise ``"text"``.

    Args:
        model: A single model record as returned by the models listing.

    Returns:
        A canonical modality label.
    """
    architecture = model.get("architecture")
    if not isinstance(architecture, dict):
        # Missing, null or malformed architecture: rely on top-level fields
        # instead of failing on ``.get``.
        architecture = {}

    output_modalities = (
        architecture.get("output_modalities") or model.get("output_modalities")
    )
    if isinstance(output_modalities, list) and output_modalities:
        return _normalize_modality(str(output_modalities[0]))

    raw_modality = (
        architecture.get("modality") or model.get("modality") or "text->text"
    )
    if isinstance(raw_modality, str):
        return _parse_modality(raw_modality)
    return "text"
|
||||||
if "image" in output:
|
|
||||||
return "image"
|
|
||||||
if "video" in output:
|
async def _fetch_models_for_cache() -> list[dict[str, Any]]:
    """Fetch broad + modality-specific lists and merge unique models by id.

    The primary ``output_modalities="all"`` request must succeed; any error
    it raises propagates to the caller. The per-modality requests are
    best-effort: a failure is logged and that modality is skipped.

    Returns:
        One record per distinct model ``id``. Records from the primary fetch
        take precedence over records from the per-modality fetches; records
        without an ``id`` are dropped.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)
    by_id: dict[str, dict[str, Any]] = {}

    # Primary fetch: all modalities (per OpenRouter docs).
    primary = await openrouter.list_models(output_modalities="all")
    for model in primary:
        model_id = model.get("id")
        if model_id:
            by_id[model_id] = model

    # Warmup fetches: some providers surface better results with explicit
    # modality filter.
    for modality in ("image", "video", "audio", "embeddings", "text"):
        try:
            subset = await openrouter.list_models(output_modalities=modality)
        except Exception:
            # Best-effort: keep what we already have, but leave a trace
            # instead of hiding the failure completely.
            logger.warning(
                "Fetching models for modality %r failed; skipping",
                modality,
                exc_info=True,
            )
            continue
        for model in subset:
            model_id = model.get("id")
            if model_id:
                # Never overwrite a record obtained from an earlier fetch.
                by_id.setdefault(model_id, model)

    return list(by_id.values())
|
||||||
|
|
||||||
|
|
||||||
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
||||||
"""Fetch all models from OpenRouter and replace the cache. Returns count stored."""
|
"""Fetch all models from OpenRouter and replace the cache. Returns count stored."""
|
||||||
raw = await openrouter.list_models()
|
raw = await _fetch_models_for_cache()
|
||||||
# Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
|
# Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
|
||||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
|
||||||
conn.execute("DELETE FROM models_cache")
|
conn.execute("DELETE FROM models_cache")
|
||||||
count = 0
|
count = 0
|
||||||
for m in raw:
|
for m in raw:
|
||||||
arch = m.get("architecture", {})
|
modality = _extract_output_modality(m)
|
||||||
modality_raw = arch.get(
|
|
||||||
"modality", "text->text") if arch else "text->text"
|
|
||||||
modality = _parse_modality(modality_raw)
|
|
||||||
pricing = m.get("pricing")
|
pricing = m.get("pricing")
|
||||||
model_id = m.get("id", "")
|
model_id = m.get("id", "")
|
||||||
if not model_id:
|
if not model_id:
|
||||||
|
|||||||
@@ -24,12 +24,25 @@ def _headers() -> dict[str, str]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def list_models(
    output_modalities: str = "all",
    category: str | None = None,
    supported_parameters: str | None = None,
) -> list[dict[str, Any]]:
    """Return available models from OpenRouter.

    Docs: GET /models supports query filters like output_modalities.

    Args:
        output_modalities: Value sent as the ``output_modalities`` query
            parameter; always included in the request.
        category: Optional ``category`` filter; omitted when empty.
        supported_parameters: Optional ``supported_parameters`` filter;
            omitted when empty.

    Returns:
        The ``data`` array of the JSON response, or an empty list when the
        key is missing or null.

    Raises:
        httpx.HTTPStatusError: Raised by ``raise_for_status()`` on a 4xx/5xx
            response.
    """
    base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)

    params: dict[str, str] = {"output_modalities": output_modalities}
    if category:
        params["category"] = category
    if supported_parameters:
        params["supported_parameters"] = supported_parameters

    async with httpx.AsyncClient(timeout=15) as client:
        request = client.build_request(
            "GET", f"{base_url}/models", headers=_headers(), params=params)
        response = await client.send(request)
        response.raise_for_status()
        # ``or []`` also covers an explicit ``"data": null`` payload so that
        # callers can always iterate over the result.
        return response.json().get("data") or []
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app import db as db_module
|
from app import db as db_module
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.services.models import (
|
from app.services.models import (
|
||||||
|
_extract_output_modality,
|
||||||
|
_normalize_modality,
|
||||||
_parse_modality,
|
_parse_modality,
|
||||||
get_cached_models,
|
get_cached_models,
|
||||||
is_cache_stale,
|
is_cache_stale,
|
||||||
@@ -107,6 +109,25 @@ def test_parse_modality_no_arrow_fallback():
|
|||||||
assert _parse_modality("text") == "text"
|
assert _parse_modality("text") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_embedding_alias():
    """The singular ``embedding`` label maps to canonical ``embeddings``."""
    normalized = _normalize_modality("embedding")
    assert normalized == "embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_output_modality_prefers_output_modalities():
    """``output_modalities`` takes precedence over the ``modality`` string."""
    architecture = {
        "modality": "text->text",
        "output_modalities": ["image"],
    }
    result = _extract_output_modality({"architecture": architecture})
    assert result == "image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_output_modality_legacy_fallback():
    """Without ``output_modalities`` the ``modality`` string is parsed."""
    legacy_only = {"architecture": {"modality": "text->audio"}}
    result = _extract_output_modality(legacy_only)
    assert result == "audio"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Unit tests: is_cache_stale
|
# Unit tests: is_cache_stale
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -213,7 +234,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
|
|||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(resp.json()) == 4
|
assert len(resp.json()) == 4
|
||||||
mock_fetch.assert_awaited_once()
|
assert mock_fetch.await_count >= 1
|
||||||
|
|
||||||
|
|
||||||
async def test_list_models_endpoint_uses_cache(client):
|
async def test_list_models_endpoint_uses_cache(client):
|
||||||
|
|||||||
Reference in New Issue
Block a user