diff --git a/backend/app/db.py b/backend/app/db.py
index 81aeb38..17528ba 100644
--- a/backend/app/db.py
+++ b/backend/app/db.py
@@ -75,6 +75,17 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
             created_at TIMESTAMP DEFAULT now()
         )
     """)
+    conn.execute("""
+        CREATE TABLE IF NOT EXISTS models_cache (
+            id UUID DEFAULT uuid() PRIMARY KEY,
+            model_id VARCHAR NOT NULL UNIQUE,
+            name VARCHAR NOT NULL,
+            modality VARCHAR NOT NULL,
+            context_length BIGINT,
+            pricing JSON,
+            fetched_at TIMESTAMP NOT NULL
+        )
+    """)
 
     _seed_admin(conn)
 
diff --git a/backend/app/main.py b/backend/app/main.py
index bc503f7..1f98737 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -4,6 +4,7 @@ from .routers import admin as admin_router
 from .routers import ai as ai_router
 from .routers import generate as generate_router
 from .routers import images as images_router
+from .routers import models as models_router
 from .db import close_db, init_db
 import os
 from contextlib import asynccontextmanager
@@ -43,6 +44,7 @@ app.include_router(admin_router.router)
 app.include_router(ai_router.router)
 app.include_router(generate_router.router)
 app.include_router(images_router.router)
+app.include_router(models_router.router)
 
 
 @app.get("/health", tags=["health"])
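
The UNIQUE constraint on model_id above is what the service layer's ON CONFLICT (model_id) upsert targets. Since refresh_models_cache clears the table before inserting, the upsert clause only comes into play if OpenRouter ever returns the same id twice within a single listing.
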
diff --git a/backend/app/routers/models.py b/backend/app/routers/models.py
new file mode 100644
index 0000000..682844a
--- /dev/null
+++ b/backend/app/routers/models.py
@@ -0,0 +1,47 @@
+"""Models router: list and refresh the OpenRouter model cache."""
+from fastapi import APIRouter, Depends, HTTPException, Query, status
+
+from ..db import get_conn, get_write_lock
+from ..dependencies import get_current_user, require_admin
+from ..services import models as models_service
+
+router = APIRouter(prefix="/models", tags=["models"])
+
+
+@router.get("/")
+async def list_models(
+    modality: str | None = Query(
+        None,
+        description="Filter by output modality: text, image, video, audio",
+    ),
+    _: dict = Depends(get_current_user),
+):
+    """Return cached models. Auto-refreshes the cache if stale (older than 24 h)."""
+    conn = get_conn()
+    if models_service.is_cache_stale(conn):
+        async with get_write_lock():
+            # Re-check inside the lock to avoid redundant parallel refreshes
+            if models_service.is_cache_stale(conn):
+                try:
+                    await models_service.refresh_models_cache(conn)
+                except Exception as exc:
+                    raise HTTPException(
+                        status_code=status.HTTP_502_BAD_GATEWAY,
+                        detail=f"Failed to refresh model cache: {exc}",
+                    ) from exc
+    return models_service.get_cached_models(conn, modality)
+
+
+@router.post("/refresh", status_code=200)
+async def refresh_models(_: dict = Depends(require_admin)):
+    """Force-refresh the model cache from OpenRouter. Admin only."""
+    conn = get_conn()
+    async with get_write_lock():
+        try:
+            count = await models_service.refresh_models_cache(conn)
+        except Exception as exc:
+            raise HTTPException(
+                status_code=status.HTTP_502_BAD_GATEWAY,
+                detail=f"OpenRouter error: {exc}",
+            ) from exc
+    return {"refreshed": count}
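
These two endpoints are the router's whole public surface. A typical exchange, sketched with httpx (host, token, and response values are illustrative, not from this patch):

    import httpx

    BASE = "http://localhost:8000"  # illustrative; wherever the backend runs
    headers = {"Authorization": "Bearer <access-token>"}

    # Any authenticated user can list cached models, optionally filtered by modality.
    models = httpx.get(f"{BASE}/models/", params={"modality": "image"}, headers=headers).json()
    # e.g. [{"id": "openai/dall-e-3", "name": "DALL-E 3", "modality": "image",
    #        "context_length": None, "pricing": {"image": "0.04"}}]

    # Admins can force a refresh; the response reports how many models were stored.
    refreshed = httpx.post(f"{BASE}/models/refresh", headers=headers).json()["refreshed"]
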
diff --git a/backend/app/services/models.py b/backend/app/services/models.py
new file mode 100644
index 0000000..0f86e7b
--- /dev/null
+++ b/backend/app/services/models.py
@@ -0,0 +1,121 @@
+"""Model cache service: fetch from OpenRouter, store in DuckDB."""
+import json
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import duckdb
+
+from . import openrouter
+
+CACHE_TTL_HOURS = 24
+
+
+def _parse_modality(raw_modality: str) -> str:
+    """Extract the output modality from an OpenRouter architecture.modality string.
+
+    Examples: "text->text", "text+image->text", "text->image", "text->video"
+    """
+    output = raw_modality.split("->", 1)[-1].lower()
+    if "text" in output:
+        return "text"
+    if "image" in output:
+        return "image"
+    if "video" in output:
+        return "video"
+    if "audio" in output:
+        return "audio"
+    return output
+
+
+async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
+    """Fetch all models from OpenRouter and replace the cache. Returns the count stored."""
+    raw = await openrouter.list_models()
+    # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
+    now = datetime.now(timezone.utc).replace(tzinfo=None)
+
+    conn.execute("DELETE FROM models_cache")
+    count = 0
+    for m in raw:
+        arch = m.get("architecture") or {}
+        modality = _parse_modality(arch.get("modality", "text->text"))
+        pricing = m.get("pricing")
+        model_id = m.get("id", "")
+        if not model_id:
+            continue
+        conn.execute(
+            """
+            INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
+            VALUES (?, ?, ?, ?, ?, ?)
+            ON CONFLICT (model_id) DO UPDATE SET
+                name = excluded.name,
+                modality = excluded.modality,
+                context_length = excluded.context_length,
+                pricing = excluded.pricing,
+                fetched_at = excluded.fetched_at
+            """,
+            [
+                model_id,
+                m.get("name", model_id),
+                modality,
+                m.get("context_length"),
+                json.dumps(pricing) if pricing else None,
+                now,
+            ],
+        )
+        count += 1
+    return count
+
+
+def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
+    """Return True if the cache is empty or last fetched more than CACHE_TTL_HOURS ago."""
+    row = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
+    if not row or row[0] is None:
+        return True
+    last_fetched = row[0]
+    # DuckDB TIMESTAMP is always naive; compare against naive UTC
+    if last_fetched.tzinfo is not None:
+        last_fetched = last_fetched.replace(tzinfo=None)
+    now_naive = datetime.now(timezone.utc).replace(tzinfo=None)
+    return now_naive - last_fetched > timedelta(hours=CACHE_TTL_HOURS)
+
+
+def get_cached_models(
+    conn: duckdb.DuckDBPyConnection,
+    modality: str | None = None,
+) -> list[dict[str, Any]]:
+    """Return cached models, optionally filtered by modality, ordered by name."""
+    if modality:
+        rows = conn.execute(
+            """
+            SELECT model_id, name, modality, context_length, pricing
+            FROM models_cache
+            WHERE modality = ?
+            ORDER BY name
+            """,
+            [modality],
+        ).fetchall()
+    else:
+        rows = conn.execute(
+            """
+            SELECT model_id, name, modality, context_length, pricing
+            FROM models_cache
+            ORDER BY name
+            """
+        ).fetchall()
+
+    result = []
+    for row in rows:
+        pricing = None
+        if row[4]:
+            try:
+                pricing = json.loads(row[4])
+            except (json.JSONDecodeError, TypeError):
+                pricing = None
+        result.append({
+            "id": row[0],
+            "name": row[1],
+            "modality": row[2],
+            "context_length": row[3],
+            "pricing": pricing,
+        })
+    return result
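
The service assumes a sibling openrouter module exposing an async list_models() that returns the raw model list from OpenRouter's GET /api/v1/models (that endpoint wraps the list in a "data" key). The module is not part of this patch; a minimal sketch of the expected contract, assuming httpx and an OPENROUTER_API_KEY environment variable — the real helper may differ:

    import os

    import httpx

    OPENROUTER_BASE = "https://openrouter.ai/api/v1"


    async def list_models() -> list[dict]:
        # Returns the raw "data" array; refresh_models_cache does the parsing.
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(
                f"{OPENROUTER_BASE}/models",
                headers={"Authorization": f"Bearer {os.environ['OPENROUTER_API_KEY']}"},
            )
            resp.raise_for_status()
            return resp.json().get("data", [])
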
diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py
new file mode 100644
index 0000000..c2924a2
--- /dev/null
+++ b/backend/tests/test_models.py
@@ -0,0 +1,293 @@
+"""Tests for the model cache service and router."""
+import os
+from datetime import datetime, timedelta, timezone
+from unittest.mock import AsyncMock, patch
+
+import pytest
+import pytest_asyncio
+from httpx import ASGITransport, AsyncClient
+
+from app import db as db_module
+from app.main import app
+from app.services.models import (
+    _parse_modality,
+    get_cached_models,
+    is_cache_stale,
+    refresh_models_cache,
+)
+
+os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
+os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
+
+FAKE_MODELS_RAW = [
+    {
+        "id": "openai/gpt-4o",
+        "name": "GPT-4o",
+        "context_length": 128000,
+        "pricing": {"prompt": "0.000005"},
+        "architecture": {"modality": "text->text"},
+    },
+    {
+        "id": "anthropic/claude-3-haiku",
+        "name": "Claude 3 Haiku",
+        "context_length": 200000,
+        "pricing": {},
+        "architecture": {"modality": "text+image->text"},
+    },
+    {
+        "id": "openai/dall-e-3",
+        "name": "DALL-E 3",
+        "context_length": None,
+        "pricing": {"image": "0.04"},
+        "architecture": {"modality": "text->image"},
+    },
+    {
+        "id": "openai/sora-2",
+        "name": "Sora 2",
+        "context_length": None,
+        "pricing": {"video": "0.10"},
+        "architecture": {"modality": "text->video"},
+    },
+]
+
+
+@pytest.fixture(autouse=True)
+def fresh_db():
+    db_module._conn = None
+    db_module.init_db(":memory:")
+    yield
+    db_module.close_db()
+    db_module._conn = None
+
+
+@pytest_asyncio.fixture
+async def client(fresh_db):
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as ac:
+        yield ac
+
+
+async def _register_login(client, email, password, is_admin=False):
+    """Register + login; optionally promote to admin directly in the DB."""
+    await client.post("/auth/register", json={"email": email, "password": password})
+    if is_admin:
+        db_module.get_conn().execute(
+            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
+        )
+    resp = await client.post("/auth/login", json={"email": email, "password": password})
+    return resp.json()["access_token"]
+
+
+# ---------------------------------------------------------------------------
+# Unit tests: _parse_modality
+# ---------------------------------------------------------------------------
+
+def test_parse_modality_text():
+    assert _parse_modality("text->text") == "text"
+
+
+def test_parse_modality_multimodal_input_text_output():
+    assert _parse_modality("text+image->text") == "text"
+
+
+def test_parse_modality_image():
+    assert _parse_modality("text->image") == "image"
+
+
+def test_parse_modality_video():
+    assert _parse_modality("text->video") == "video"
+
+
+def test_parse_modality_audio():
+    assert _parse_modality("text->audio") == "audio"
+
+
+def test_parse_modality_no_arrow_fallback():
+    assert _parse_modality("text") == "text"
+
+
+# ---------------------------------------------------------------------------
+# Unit tests: is_cache_stale
+# ---------------------------------------------------------------------------
+
+def test_cache_stale_when_empty():
+    conn = db_module.get_conn()
+    assert is_cache_stale(conn) is True
+
+
+def test_cache_not_stale_after_fresh_insert():
+    conn = db_module.get_conn()
+    now = datetime.now(timezone.utc).replace(tzinfo=None)
+    conn.execute(
+        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
+        ["openai/gpt-4o", "GPT-4o", "text", now],
+    )
+    assert is_cache_stale(conn) is False
+
+
+def test_cache_stale_after_ttl_exceeded():
+    conn = db_module.get_conn()
+    # Store naive UTC to match DuckDB TIMESTAMP behaviour
+    old_time = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=25)
+    conn.execute(
+        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
+        ["openai/gpt-4o", "GPT-4o", "text", old_time],
+    )
+    assert is_cache_stale(conn) is True
+
+
+# ---------------------------------------------------------------------------
+# Unit tests: refresh_models_cache + get_cached_models
+# ---------------------------------------------------------------------------
+
+async def test_refresh_stores_models():
+    conn = db_module.get_conn()
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        return_value=FAKE_MODELS_RAW,
+    ):
+        count = await refresh_models_cache(conn)
+    assert count == 4
+    all_models = get_cached_models(conn)
+    assert len(all_models) == 4
+
+
+async def test_refresh_replaces_old_cache():
+    conn = db_module.get_conn()
+    old_time = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=30)
+    conn.execute(
+        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
+        ["old/model", "Old Model", "text", old_time],
+    )
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        return_value=FAKE_MODELS_RAW,
+    ):
+        await refresh_models_cache(conn)
+    ids = [m["id"] for m in get_cached_models(conn)]
+    assert "old/model" not in ids
+    assert "openai/gpt-4o" in ids
+
+
+def test_get_cached_models_filter_by_modality():
+    conn = db_module.get_conn()
+    now = datetime.now(timezone.utc).replace(tzinfo=None)
+    for m in FAKE_MODELS_RAW:
+        arch = m.get("architecture", {})
+        modality = _parse_modality(arch.get("modality", "text->text"))
+        conn.execute(
+            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
+            [m["id"], m["name"], modality, now],
+        )
+    text_models = get_cached_models(conn, modality="text")
+    assert len(text_models) == 2
+    assert all(m["modality"] == "text" for m in text_models)
+
+    image_models = get_cached_models(conn, modality="image")
+    assert len(image_models) == 1
+    assert image_models[0]["id"] == "openai/dall-e-3"
+
+    video_models = get_cached_models(conn, modality="video")
+    assert len(video_models) == 1
+    assert video_models[0]["id"] == "openai/sora-2"
+
+
+# ---------------------------------------------------------------------------
+# Integration tests: GET /models/
+# ---------------------------------------------------------------------------
+
+async def test_list_models_endpoint_auto_refreshes(client):
+    token = await _register_login(client, "user@example.com", "secret123")
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        return_value=FAKE_MODELS_RAW,
+    ) as mock_fetch:
+        resp = await client.get(
+            "/models/", headers={"Authorization": f"Bearer {token}"}
+        )
+    assert resp.status_code == 200
+    assert len(resp.json()) == 4
+    mock_fetch.assert_awaited_once()
+
+
+async def test_list_models_endpoint_uses_cache(client):
+    token = await _register_login(client, "user@example.com", "secret123")
+    conn = db_module.get_conn()
+    now = datetime.now(timezone.utc).replace(tzinfo=None)
+    conn.execute(
+        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
+        ["cached/model", "Cached Model", "text", now],
+    )
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+    ) as mock_fetch:
+        resp = await client.get(
+            "/models/?modality=text", headers={"Authorization": f"Bearer {token}"}
+        )
+    assert resp.status_code == 200
+    assert resp.json()[0]["id"] == "cached/model"
+    mock_fetch.assert_not_awaited()
+
+
+async def test_list_models_endpoint_requires_auth(client):
+    resp = await client.get("/models/")
+    assert resp.status_code == 401
+
+
+async def test_list_models_filter_by_modality(client):
+    token = await _register_login(client, "user@example.com", "secret123")
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        return_value=FAKE_MODELS_RAW,
+    ):
+        resp = await client.get(
+            "/models/?modality=image", headers={"Authorization": f"Bearer {token}"}
+        )
+    assert resp.status_code == 200
+    data = resp.json()
+    assert len(data) == 1
+    assert data[0]["id"] == "openai/dall-e-3"
+
+
+# ---------------------------------------------------------------------------
+# Integration tests: POST /models/refresh
+# ---------------------------------------------------------------------------
+
+async def test_refresh_endpoint_requires_admin(client):
+    token = await _register_login(client, "user@example.com", "secret123")
+    resp = await client.post(
+        "/models/refresh", headers={"Authorization": f"Bearer {token}"}
+    )
+    assert resp.status_code == 403
+
+
+async def test_refresh_endpoint_admin_succeeds(client):
+    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        return_value=FAKE_MODELS_RAW,
+    ):
+        resp = await client.post(
+            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
+        )
+    assert resp.status_code == 200
+    assert resp.json()["refreshed"] == 4
+
+
+async def test_refresh_endpoint_502_on_openrouter_error(client):
+    token = await _register_login(client, "admin@example.com", "secret123", is_admin=True)
+    with patch(
+        "app.services.models.openrouter.list_models",
+        new_callable=AsyncMock,
+        side_effect=RuntimeError("network error"),
+    ):
+        resp = await client.post(
+            "/models/refresh", headers={"Authorization": f"Bearer {token}"}
+        )
+    assert resp.status_code == 502
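
A note on the test style: the async test functions above carry no @pytest.mark.asyncio markers, so the suite presumably relies on pytest-asyncio's auto mode (asyncio_mode = "auto" in the pytest configuration); if the project runs strict mode instead, each async test would need an explicit marker.
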
diff --git a/frontend/app/main.py b/frontend/app/main.py
index 18ae828..3c21f69 100644
--- a/frontend/app/main.py
+++ b/frontend/app/main.py
@@ -153,8 +153,9 @@ def generate():
 @login_required
 def generate_text():
     result = error = None
+    token = session["access_token"]
     if request.method == "POST":
-        resp = _api("POST", "/generate/text", token=session["access_token"], json={
+        resp = _api("POST", "/generate/text", token=token, json={
             "model": request.form.get("model", "").strip(),
             "prompt": request.form.get("prompt", "").strip(),
         })
@@ -162,28 +163,35 @@
             result = resp.json()
         else:
             error = resp.json().get("detail", "Generation failed.")
-    return render_template("generate_text.html", result=result, error=error)
+    models_resp = _api("GET", "/models/", token=token,
+                       params={"modality": "text"})
+    models = models_resp.json() if models_resp.status_code == 200 else []
+    return render_template("generate_text.html", result=result, error=error, models=models)
 
 
 @app.route("/generate/image", methods=["GET", "POST"])
 @login_required
 def generate_image():
     result = error = None
+    token = session["access_token"]
     if request.method == "POST":
         # Upload reference image if provided
         ref_file = request.files.get("reference_image")
         if ref_file and ref_file.filename:
             up_resp = _api(
                 "POST",
                 "/images/upload",
-                token=session["access_token"],
+                token=token,
                 files={"file": (ref_file.filename, ref_file.stream, ref_file.content_type)},
             )
             if up_resp.status_code not in (200, 201):
                 error = up_resp.json().get("detail", "Image upload failed.")
-                return render_template("generate_image.html", result=result, error=error)
+                models_resp = _api("GET", "/models/",
+                                   token=token, params={"modality": "image"})
+                models = models_resp.json() if models_resp.status_code == 200 else []
+                return render_template("generate_image.html", result=result, error=error, models=models)
 
-        resp = _api("POST", "/generate/image", token=session["access_token"], json={
+        resp = _api("POST", "/generate/image", token=token, json={
             "model": request.form.get("model", "").strip(),
             "prompt": request.form.get("prompt", "").strip(),
             "n": int(request.form.get("n", 1)),
@@ -195,16 +203,19 @@ def generate_image():
             result = resp.json()
         else:
             error = resp.json().get("detail", "Generation failed.")
-    return render_template("generate_image.html", result=result, error=error)
+    models_resp = _api("GET", "/models/", token=token,
+                       params={"modality": "image"})
+    models = models_resp.json() if models_resp.status_code == 200 else []
+    return render_template("generate_image.html", result=result, error=error, models=models)
 
 
 @app.route("/generate/video", methods=["GET", "POST"])
 @login_required
 def generate_video():
     result = error = None
+    token = session["access_token"]
     if request.method == "POST":
         mode = request.form.get("mode", "text")
-        token = session["access_token"]
         duration_raw = request.form.get("duration_seconds", "")
         duration = int(
             duration_raw) if duration_raw.strip().isdigit() else None
@@ -230,7 +241,10 @@
             result = resp.json()
         else:
             error = resp.json().get("detail", "Generation failed.")
-    return render_template("generate_video.html", result=result, error=error)
+    models_resp = _api("GET", "/models/", token=token,
+                       params={"modality": "video"})
+    models = models_resp.json() if models_resp.status_code == 200 else []
+    return render_template("generate_video.html", result=result, error=error, models=models)
 
 
 @app.get("/generate/video/status")
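
All three views now end with the same fetch-models-then-render tail. A small helper would keep them in sync; a sketch against the _api helper as it is called above (its exact signature is assumed from those calls, not shown in this patch):

    def _fetch_models(token: str, modality: str) -> list:
        # Best-effort: an empty list keeps the form usable when the backend or
        # OpenRouter is unavailable, since the template can fall back to a
        # manual model field.
        resp = _api("GET", "/models/", token=token, params={"modality": modality})
        return resp.json() if resp.status_code == 200 else []
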
diff --git a/frontend/app/templates/generate_image.html b/frontend/app/templates/generate_image.html
index 7ba8da0..ad6a3cc 100644
--- a/frontend/app/templates/generate_image.html
+++ b/frontend/app/templates/generate_image.html
@@ -5,9 +5,17 @@
 <h2>Image Generation</h2>
 <form method="post" enctype="multipart/form-data">
   <label for="model">Model</label>
+  {% if models %}
+  <select id="model" name="model">
+    {% for m in models %}
+    <option value="{{ m.id }}">{{ m.name }}</option>
+    {% endfor %}
+  </select>
+  {% else %}
   <input id="model" name="model">
+  {% endif %}
   <label for="prompt">Prompt</label>
   <textarea id="prompt" name="prompt"></textarea>
   <input type="number" name="n" value="1" min="1">
   <input type="file" name="reference_image">
   <button type="submit">Generate</button>