"""Tests for model cache service and router.""" import json import os from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient from app import db as db_module from app.main import app from app.services.models import ( _extract_output_modality, _normalize_modality, _parse_modality, get_cached_models, is_cache_stale, refresh_models_cache, ) os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only") os.environ.setdefault("OPENROUTER_API_KEY", "test-key") FAKE_MODELS_RAW = [ { "id": "openai/gpt-4o", "name": "GPT-4o", "context_length": 128000, "pricing": {"prompt": "0.000005"}, "architecture": {"modality": "text->text"}, }, { "id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku", "context_length": 200000, "pricing": {}, "architecture": {"modality": "text+image->text"}, }, { "id": "openai/dall-e-3", "name": "DALL-E 3", "context_length": None, "pricing": {"image": "0.04"}, "architecture": {"modality": "text->image"}, }, { "id": "openai/sora-2", "name": "Sora 2", "context_length": None, "pricing": {"video": "0.10"}, "architecture": {"modality": "text->video"}, }, ] @pytest.fixture(autouse=True) def fresh_db(): db_module._conn = None db_module.init_db(":memory:") yield db_module.close_db() db_module._conn = None @pytest_asyncio.fixture async def client(fresh_db): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac async def _register_login(client, email, password, is_admin=False): """Register + login; optionally promote to admin directly in DB.""" await client.post("/auth/register", json={"email": email, "password": password}) if is_admin: db_module.get_conn().execute( "UPDATE users SET role = 'admin' WHERE email = ?", [email] ) resp = await client.post("/auth/login", json={"email": email, "password": password}) return resp.json()["access_token"] # 
# ---------------------------------------------------------------------------
# Helpers for the tests below
# ---------------------------------------------------------------------------


def _naive_utcnow():
    """Return the current UTC time as a naive ``datetime``.

    Timestamps are stored naive (UTC) to match DuckDB ``TIMESTAMP`` behaviour.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _insert_cached_model(conn, model_id, name, modality, fetched_at):
    """Insert one row straight into ``models_cache``, bypassing the service."""
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) "
        "VALUES (?, ?, ?, ?)",
        [model_id, name, modality, fetched_at],
    )


def _patch_list_models(**mock_kwargs):
    """Patch the OpenRouter listing used by the models service with an AsyncMock.

    Extra keyword arguments (``return_value``, ``side_effect``, ...) are passed
    through to the ``AsyncMock`` constructor.
    """
    return patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        **mock_kwargs,
    )


# ---------------------------------------------------------------------------
# Unit tests: _parse_modality
# ---------------------------------------------------------------------------


def test_parse_modality_text():
    assert _parse_modality("text->text") == "text"


def test_parse_modality_multimodal_input_text_output():
    assert _parse_modality("text+image->text") == "text"


def test_parse_modality_image():
    assert _parse_modality("text->image") == "image"


def test_parse_modality_video():
    assert _parse_modality("text->video") == "video"


def test_parse_modality_audio():
    assert _parse_modality("text->audio") == "audio"


def test_parse_modality_no_arrow_fallback():
    assert _parse_modality("text") == "text"


def test_normalize_embedding_alias():
    assert _normalize_modality("embedding") == "embeddings"


def test_extract_output_modality_prefers_output_modalities():
    """``output_modalities`` wins over the legacy ``modality`` string."""
    model = {
        "architecture": {
            "modality": "text->text",
            "output_modalities": ["image"],
        }
    }
    assert _extract_output_modality(model) == "image"


def test_extract_output_modality_legacy_fallback():
    """Without ``output_modalities`` the legacy ``modality`` string is used."""
    model = {"architecture": {"modality": "text->audio"}}
    assert _extract_output_modality(model) == "audio"


# ---------------------------------------------------------------------------
# Unit tests: is_cache_stale
# ---------------------------------------------------------------------------


def test_cache_stale_when_empty():
    conn = db_module.get_conn()
    assert is_cache_stale(conn) is True


def test_cache_not_stale_after_fresh_insert():
    conn = db_module.get_conn()
    _insert_cached_model(conn, "openai/gpt-4o", "GPT-4o", "text", _naive_utcnow())
    assert is_cache_stale(conn) is False


def test_cache_stale_after_ttl_exceeded():
    conn = db_module.get_conn()
    # 25 hours old; presumably beyond the service's cache TTL (24h?).
    # NOTE(review): confirm against is_cache_stale's TTL constant.
    old_time = _naive_utcnow() - timedelta(hours=25)
    _insert_cached_model(conn, "openai/gpt-4o", "GPT-4o", "text", old_time)
    assert is_cache_stale(conn) is True


# ---------------------------------------------------------------------------
# Unit tests: refresh_models_cache + get_cached_models
# ---------------------------------------------------------------------------


async def test_refresh_stores_models():
    conn = db_module.get_conn()
    with _patch_list_models(return_value=FAKE_MODELS_RAW):
        count = await refresh_models_cache(conn)
    assert count == 4
    all_models = get_cached_models(conn)
    assert len(all_models) == 4
    # Every fetched model is stored, not just the right number of rows.
    assert {m["id"] for m in all_models} == {m["id"] for m in FAKE_MODELS_RAW}


async def test_refresh_replaces_old_cache():
    """A refresh drops rows that are not part of the new listing."""
    conn = db_module.get_conn()
    old_time = _naive_utcnow() - timedelta(hours=30)
    _insert_cached_model(conn, "old/model", "Old Model", "text", old_time)
    with _patch_list_models(return_value=FAKE_MODELS_RAW):
        await refresh_models_cache(conn)
    ids = [m["id"] for m in get_cached_models(conn)]
    assert "old/model" not in ids
    assert "openai/gpt-4o" in ids


def test_get_cached_models_filter_by_modality():
    conn = db_module.get_conn()
    now = _naive_utcnow()
    for m in FAKE_MODELS_RAW:
        arch = m.get("architecture", {})
        modality = _parse_modality(arch.get("modality", "text->text"))
        _insert_cached_model(conn, m["id"], m["name"], modality, now)

    text_models = get_cached_models(conn, modality="text")
    assert len(text_models) == 2
    assert all(m["modality"] == "text" for m in text_models)

    image_models = get_cached_models(conn, modality="image")
    assert len(image_models) == 1
    assert image_models[0]["id"] == "openai/dall-e-3"

    video_models = get_cached_models(conn, modality="video")
    assert len(video_models) == 1
    assert video_models[0]["id"] == "openai/sora-2"


# ---------------------------------------------------------------------------
# Integration tests: GET /models/
# ---------------------------------------------------------------------------


async def test_list_models_endpoint_auto_refreshes(client):
    """With an empty cache the endpoint fetches from OpenRouter first."""
    token = await _register_login(client, "user@example.com", "secret123")
    with _patch_list_models(return_value=FAKE_MODELS_RAW) as mock_fetch:
        resp = await client.get(
            "/models/", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    assert len(resp.json()) == 4
    assert mock_fetch.await_count >= 1


async def test_list_models_endpoint_uses_cache(client):
    """A fresh cache is served as-is, without calling OpenRouter."""
    token = await _register_login(client, "user@example.com", "secret123")
    conn = db_module.get_conn()
    _insert_cached_model(
        conn, "cached/model", "Cached Model", "text", _naive_utcnow()
    )
    with _patch_list_models() as mock_fetch:
        resp = await client.get(
            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    # Exactly the seeded row: an empty list or extra rows are both failures
    # (indexing [0] alone would raise IndexError on empty and miss extras).
    assert [m["id"] for m in resp.json()] == ["cached/model"]
    mock_fetch.assert_not_awaited()


async def test_list_models_endpoint_requires_auth(client):
    resp = await client.get("/models/")
    assert resp.status_code == 401


async def test_list_models_filter_by_modality(client):
    token = await _register_login(client, "user@example.com", "secret123")
    with _patch_list_models(return_value=FAKE_MODELS_RAW):
        resp = await client.get(
            "/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
        )
    assert resp.status_code == 200
    data = resp.json()
    assert len(data) == 1
    assert data[0]["id"] == "openai/dall-e-3"
# ---------------------------------------------------------------------------
# Integration tests: POST /models/refresh
# ---------------------------------------------------------------------------


async def test_refresh_endpoint_requires_admin(client):
    """A regular (non-admin) user is refused with 403."""
    user_token = await _register_login(client, "user@example.com", "secret123")
    auth_headers = {"Authorization": f"Bearer {user_token}"}

    response = await client.post("/models/refresh", headers=auth_headers)

    assert response.status_code == 403


async def test_refresh_endpoint_admin_succeeds(client):
    """An admin can force a refresh; the body reports the refreshed count."""
    admin_token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth_headers = {"Authorization": f"Bearer {admin_token}"}
    fake_listing = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )

    with fake_listing:
        response = await client.post("/models/refresh", headers=auth_headers)

    assert response.status_code == 200
    assert response.json()["refreshed"] == 4


async def test_refresh_endpoint_502_on_openrouter_error(client):
    """An exception from the OpenRouter listing surfaces as 502 Bad Gateway."""
    admin_token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth_headers = {"Authorization": f"Bearer {admin_token}"}
    failing_listing = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        side_effect=RuntimeError("network error"),
    )

    with failing_listing:
        response = await client.post("/models/refresh", headers=auth_headers)

    assert response.status_code == 502