Add models caching and management functionality with corresponding API endpoints and templates
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -75,6 +75,17 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
|
|||||||
created_at TIMESTAMP DEFAULT now()
|
created_at TIMESTAMP DEFAULT now()
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS models_cache (
|
||||||
|
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||||
|
model_id VARCHAR NOT NULL UNIQUE,
|
||||||
|
name VARCHAR NOT NULL,
|
||||||
|
modality VARCHAR NOT NULL,
|
||||||
|
context_length BIGINT,
|
||||||
|
pricing JSON,
|
||||||
|
fetched_at TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
_seed_admin(conn)
|
_seed_admin(conn)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from .routers import admin as admin_router
|
|||||||
from .routers import ai as ai_router
|
from .routers import ai as ai_router
|
||||||
from .routers import generate as generate_router
|
from .routers import generate as generate_router
|
||||||
from .routers import images as images_router
|
from .routers import images as images_router
|
||||||
|
from .routers import models as models_router
|
||||||
from .db import close_db, init_db
|
from .db import close_db, init_db
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@@ -43,6 +44,7 @@ app.include_router(admin_router.router)
|
|||||||
app.include_router(ai_router.router)
|
app.include_router(ai_router.router)
|
||||||
app.include_router(generate_router.router)
|
app.include_router(generate_router.router)
|
||||||
app.include_router(images_router.router)
|
app.include_router(images_router.router)
|
||||||
|
app.include_router(models_router.router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health", tags=["health"])
|
@app.get("/health", tags=["health"])
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Models router: list and refresh the OpenRouter model cache."""
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
|
||||||
|
from ..db import get_conn, get_write_lock
|
||||||
|
from ..dependencies import get_current_user, require_admin
|
||||||
|
from ..services import models as models_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
async def list_models(
    modality: str | None = Query(
        None,
        description="Filter by output modality: text, image, video, audio",
    ),
    _: dict = Depends(get_current_user),
):
    """Return cached models. Auto-refreshes cache if stale (older than 24 h)."""
    conn = get_conn()

    # Fast path: a fresh cache is served without taking the write lock.
    if not models_service.is_cache_stale(conn):
        return models_service.get_cached_models(conn, modality)

    async with get_write_lock():
        # Staleness is checked again under the lock so that requests which
        # queued up behind a refresh do not each trigger another one.
        still_stale = models_service.is_cache_stale(conn)
        if still_stale:
            try:
                await models_service.refresh_models_cache(conn)
            except Exception as exc:
                # Any refresh failure is reported to the client as 502.
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY,
                    detail=f"Failed to refresh model cache: {exc}",
                )

    return models_service.get_cached_models(conn, modality)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", status_code=200)
async def refresh_models(_: dict = Depends(require_admin)):
    """Force-refresh the model cache from OpenRouter. Admin only."""
    conn = get_conn()
    async with get_write_lock():
        try:
            stored = await models_service.refresh_models_cache(conn)
        except Exception as exc:
            # Any refresh failure is reported to the client as 502.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"OpenRouter error: {exc}",
            )
        else:
            # Report how many models the refresh stored.
            return {"refreshed": stored}
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
"""Model cache service: fetch from OpenRouter, store in DuckDB."""
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
from . import openrouter
|
||||||
|
|
||||||
|
CACHE_TTL_HOURS = 24
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_modality(raw_modality: str) -> str:
|
||||||
|
"""Extract output modality from OpenRouter architecture.modality string.
|
||||||
|
|
||||||
|
Examples: "text->text", "text+image->text", "text->image", "text->video"
|
||||||
|
"""
|
||||||
|
output = raw_modality.split(
|
||||||
|
"->", 1)[-1].lower() if "->" in raw_modality else raw_modality.lower()
|
||||||
|
if "text" in output:
|
||||||
|
return "text"
|
||||||
|
if "image" in output:
|
||||||
|
return "image"
|
||||||
|
if "video" in output:
|
||||||
|
return "video"
|
||||||
|
if "audio" in output:
|
||||||
|
return "audio"
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
    """Fetch all models from OpenRouter and replace the cache. Returns count stored.

    Every parameter row is built before the table is touched, and rows from an
    earlier refresh are removed only after the new ones have been written. A
    malformed API entry or a failure part-way through therefore can no longer
    leave the cache empty or truncated.

    Args:
        conn: Open DuckDB connection holding the ``models_cache`` table.

    Returns:
        Number of entries written (entries without an ``id`` are skipped).

    Raises:
        Whatever ``openrouter.list_models`` or the database calls raise; the
        error is not handled here.
    """
    raw = await openrouter.list_models()
    # Use naive UTC to avoid DuckDB TIMESTAMP tz-stripping inconsistencies
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    upsert_sql = """
        INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (model_id) DO UPDATE SET
            name = excluded.name,
            modality = excluded.modality,
            context_length = excluded.context_length,
            pricing = excluded.pricing,
            fetched_at = excluded.fetched_at
    """

    # Phase 1: pure data preparation - nothing is written yet, so an
    # unexpected payload shape fails here and leaves the old cache intact.
    rows = []
    for m in raw:
        model_id = m.get("id", "")
        if not model_id:
            continue
        arch = m.get("architecture") or {}
        # A null modality is treated like a missing one.
        modality_raw = arch.get("modality") or "text->text"
        pricing = m.get("pricing")
        rows.append([
            model_id,
            # name is NOT NULL in the table; fall back to the id when absent/null
            m.get("name") or model_id,
            _parse_modality(modality_raw),
            m.get("context_length"),
            json.dumps(pricing) if pricing else None,
            now,
        ])

    # Phase 2: upsert one statement per row; ON CONFLICT also covers an id
    # that appears more than once in the API response.
    for row in rows:
        conn.execute(upsert_sql, row)

    # Phase 3: drop everything that was not (re)written by this refresh.
    conn.execute("DELETE FROM models_cache WHERE fetched_at <> ?", [now])

    return len(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cache_stale(conn: duckdb.DuckDBPyConnection) -> bool:
    """Report whether the model cache needs to be refreshed.

    Returns True when ``models_cache`` holds no rows, or when its newest
    ``fetched_at`` lies more than ``CACHE_TTL_HOURS`` hours in the past.
    """
    newest = conn.execute("SELECT MAX(fetched_at) FROM models_cache").fetchone()
    last_fetched = newest[0] if newest else None
    if last_fetched is None:
        return True

    # The comparison is done on naive UTC values; strip tzinfo should the
    # driver ever hand back an aware datetime.
    if last_fetched.tzinfo is not None:
        last_fetched = last_fetched.replace(tzinfo=None)

    age = datetime.now(timezone.utc).replace(tzinfo=None) - last_fetched
    return age > timedelta(hours=CACHE_TTL_HOURS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_models(
    conn: duckdb.DuckDBPyConnection,
    modality: str | None = None,
) -> list[dict[str, Any]]:
    """Read the cached models, sorted by name.

    Args:
        conn: Open DuckDB connection holding the ``models_cache`` table.
        modality: When given (and non-empty), only rows with exactly this
            modality are returned.

    Returns:
        One dict per row with the keys ``id``, ``name``, ``modality``,
        ``context_length`` and ``pricing``. ``pricing`` is the decoded JSON
        value, or ``None`` when the column is empty or cannot be decoded.
    """
    if modality:
        cursor = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            WHERE modality = ?
            ORDER BY name
            """,
            [modality],
        )
    else:
        cursor = conn.execute(
            """
            SELECT model_id, name, modality, context_length, pricing
            FROM models_cache
            ORDER BY name
            """
        )

    models = []
    for model_id, name, row_modality, context_length, pricing_raw in cursor.fetchall():
        try:
            pricing = json.loads(pricing_raw) if pricing_raw else None
        except (json.JSONDecodeError, TypeError):
            # Unreadable pricing is dropped rather than failing the listing.
            pricing = None
        models.append({
            "id": model_id,
            "name": name,
            "modality": row_modality,
            "context_length": context_length,
            "pricing": pricing,
        })
    return models
|
||||||
@@ -0,0 +1,296 @@
|
|||||||
|
"""Tests for model cache service and router."""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
from app.main import app
|
||||||
|
from app.services.models import (
|
||||||
|
_parse_modality,
|
||||||
|
get_cached_models,
|
||||||
|
is_cache_stale,
|
||||||
|
refresh_models_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||||
|
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||||
|
|
||||||
|
# Canned stand-in for the raw OpenRouter model list used throughout the tests:
# two text-output models, one image-output model and one video-output model.
FAKE_MODELS_RAW = [
    # text -> text
    {
        "id": "openai/gpt-4o",
        "name": "GPT-4o",
        "context_length": 128000,
        "pricing": {"prompt": "0.000005"},
        "architecture": {"modality": "text->text"},
    },
    # multimodal input, text output; note the empty pricing dict
    {
        "id": "anthropic/claude-3-haiku",
        "name": "Claude 3 Haiku",
        "context_length": 200000,
        "pricing": {},
        "architecture": {"modality": "text+image->text"},
    },
    # image output, no context length
    {
        "id": "openai/dall-e-3",
        "name": "DALL-E 3",
        "context_length": None,
        "pricing": {"image": "0.04"},
        "architecture": {"modality": "text->image"},
    },
    # video output, no context length
    {
        "id": "openai/sora-2",
        "name": "Sora 2",
        "context_length": None,
        "pricing": {"video": "0.10"},
        "architecture": {"modality": "text->video"},
    },
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def fresh_db():
    """Re-initialise the app database with ``":memory:"`` around every test."""
    # Setup: forget any connection left on the module, then initialise anew.
    db_module._conn = None
    db_module.init_db(":memory:")

    yield

    # Teardown: close the database and clear the module-level handle again.
    db_module.close_db()
    db_module._conn = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def client(fresh_db):
    """Yield an httpx ``AsyncClient`` wired directly to the ASGI app."""
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as ac:
        yield ac
|
||||||
|
|
||||||
|
|
||||||
|
async def _register_login(client, email, password, is_admin=False):
    """Register + login; optionally promote to admin directly in DB.

    Returns the ``access_token`` from the login response.
    """
    credentials = {"email": email, "password": password}
    await client.post("/auth/register", json=credentials)

    if is_admin:
        # Promotion happens before login, so the login below is performed
        # for a user whose role is already 'admin'.
        db_module.get_conn().execute(
            "UPDATE users SET role = 'admin' WHERE email = ?", [email]
        )

    login = await client.post("/auth/login", json=credentials)
    return login.json()["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: _parse_modality
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_parse_modality_text():
    """Plain text-to-text maps to "text"."""
    parsed = _parse_modality("text->text")
    assert parsed == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_modality_multimodal_input_text_output():
    """Only the output side counts: image input with text output is "text"."""
    parsed = _parse_modality("text+image->text")
    assert parsed == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_modality_image():
    """Image output maps to "image"."""
    parsed = _parse_modality("text->image")
    assert parsed == "image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_modality_video():
    """Video output maps to "video"."""
    parsed = _parse_modality("text->video")
    assert parsed == "video"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_modality_audio():
    """Audio output maps to "audio"."""
    parsed = _parse_modality("text->audio")
    assert parsed == "audio"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_modality_no_arrow_fallback():
    """A value without "->" is classified as a whole."""
    parsed = _parse_modality("text")
    assert parsed == "text"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: is_cache_stale
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_cache_stale_when_empty():
    """An empty cache table counts as stale."""
    stale = is_cache_stale(db_module.get_conn())
    assert stale is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_not_stale_after_fresh_insert():
    """A row fetched just now keeps the cache fresh."""
    conn = db_module.get_conn()
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    row = ["openai/gpt-4o", "GPT-4o", "text", fetched_at]
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        row,
    )
    assert is_cache_stale(conn) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_stale_after_ttl_exceeded():
    """A row fetched 25 hours ago makes the cache stale."""
    conn = db_module.get_conn()
    # Store naive UTC to match DuckDB TIMESTAMP behaviour
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
    fetched_at = utc_now - timedelta(hours=25)
    row = ["openai/gpt-4o", "GPT-4o", "text", fetched_at]
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        row,
    )
    assert is_cache_stale(conn) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: refresh_models_cache + get_cached_models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_refresh_stores_models():
    """Refreshing from the fake feed stores and reports all four models."""
    conn = db_module.get_conn()
    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )
    with fetch_patch:
        stored = await refresh_models_cache(conn)

    assert stored == 4
    cached = get_cached_models(conn)
    assert len(cached) == 4
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_replaces_old_cache():
    """Entries absent from the new feed disappear after a refresh."""
    conn = db_module.get_conn()
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
    # Seed one outdated entry that the fake feed does not contain.
    conn.execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["old/model", "Old Model", "text", utc_now - timedelta(hours=30)],
    )

    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )
    with fetch_patch:
        await refresh_models_cache(conn)

    cached_ids = [entry["id"] for entry in get_cached_models(conn)]
    assert "old/model" not in cached_ids
    assert "openai/gpt-4o" in cached_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cached_models_filter_by_modality():
    """The modality argument narrows the result to matching rows only."""
    conn = db_module.get_conn()
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)

    # Seed the table directly, classifying each fake model the same way the
    # service does.
    for entry in FAKE_MODELS_RAW:
        raw_modality = entry.get("architecture", {}).get("modality", "text->text")
        conn.execute(
            "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
            [entry["id"], entry["name"], _parse_modality(raw_modality), fetched_at],
        )

    text_models = get_cached_models(conn, modality="text")
    assert len(text_models) == 2
    assert all(m["modality"] == "text" for m in text_models)

    image_models = get_cached_models(conn, modality="image")
    assert len(image_models) == 1
    assert image_models[0]["id"] == "openai/dall-e-3"

    video_models = get_cached_models(conn, modality="video")
    assert len(video_models) == 1
    assert video_models[0]["id"] == "openai/sora-2"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests: GET /models/
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_list_models_endpoint_auto_refreshes(client):
    """GET /models/ on an empty cache fetches from the feed exactly once."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth = {"Authorization": f"Bearer {token}"}

    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )
    with fetch_patch as mock_fetch:
        resp = await client.get("/models/", headers=auth)

    assert resp.status_code == 200
    assert len(resp.json()) == 4
    mock_fetch.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_endpoint_uses_cache(client):
    """GET /models/ serves a fresh cache without calling the feed."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth = {"Authorization": f"Bearer {token}"}

    # Seed a single entry stamped "now" so the cache is not stale.
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)
    db_module.get_conn().execute(
        "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
        ["cached/model", "Cached Model", "text", fetched_at],
    )

    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
    )
    with fetch_patch as mock_fetch:
        resp = await client.get("/models/?modality=text", headers=auth)

    assert resp.status_code == 200
    assert resp.json()[0]["id"] == "cached/model"
    mock_fetch.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_endpoint_requires_auth(client):
    """GET /models/ without a bearer token is rejected with 401."""
    unauthenticated = await client.get("/models/")
    assert unauthenticated.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_filter_by_modality(client):
    """GET /models/?modality=image returns only the image model."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth = {"Authorization": f"Bearer {token}"}

    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )
    with fetch_patch:
        resp = await client.get("/models/?modality=image", headers=auth)

    assert resp.status_code == 200
    data = resp.json()
    assert len(data) == 1
    assert data[0]["id"] == "openai/dall-e-3"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests: POST /models/refresh
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_refresh_endpoint_requires_admin(client):
    """POST /models/refresh as a regular user is rejected with 403."""
    token = await _register_login(client, "user@example.com", "secret123")
    auth = {"Authorization": f"Bearer {token}"}
    resp = await client.post("/models/refresh", headers=auth)
    assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_endpoint_admin_succeeds(client):
    """POST /models/refresh as admin reports the number of stored models."""
    token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth = {"Authorization": f"Bearer {token}"}

    fetch_patch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        return_value=FAKE_MODELS_RAW,
    )
    with fetch_patch:
        resp = await client.post("/models/refresh", headers=auth)

    assert resp.status_code == 200
    assert resp.json()["refreshed"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_endpoint_502_on_openrouter_error(client):
    """A failing feed call makes POST /models/refresh answer 502."""
    token = await _register_login(
        client, "admin@example.com", "secret123", is_admin=True
    )
    auth = {"Authorization": f"Bearer {token}"}

    failing_fetch = patch(
        "app.services.models.openrouter.list_models",
        new_callable=AsyncMock,
        side_effect=RuntimeError("network error"),
    )
    with failing_fetch:
        resp = await client.post("/models/refresh", headers=auth)

    assert resp.status_code == 502
|
||||||
+22
-8
@@ -153,8 +153,9 @@ def generate():
|
|||||||
@login_required
|
@login_required
|
||||||
def generate_text():
|
def generate_text():
|
||||||
result = error = None
|
result = error = None
|
||||||
|
token = session["access_token"]
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
resp = _api("POST", "/generate/text", token=session["access_token"], json={
|
resp = _api("POST", "/generate/text", token=token, json={
|
||||||
"model": request.form.get("model", "").strip(),
|
"model": request.form.get("model", "").strip(),
|
||||||
"prompt": request.form.get("prompt", "").strip(),
|
"prompt": request.form.get("prompt", "").strip(),
|
||||||
})
|
})
|
||||||
@@ -162,28 +163,35 @@ def generate_text():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
return render_template("generate_text.html", result=result, error=error)
|
models_resp = _api("GET", "/models/", token=token,
|
||||||
|
params={"modality": "text"})
|
||||||
|
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||||
|
return render_template("generate_text.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/generate/image", methods=["GET", "POST"])
|
@app.route("/generate/image", methods=["GET", "POST"])
|
||||||
@login_required
|
@login_required
|
||||||
def generate_image():
|
def generate_image():
|
||||||
result = error = None
|
result = error = None
|
||||||
|
token = session["access_token"]
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
# Upload reference image if provided
|
# Upload reference image if provided
|
||||||
ref_file = request.files.get("reference_image")
|
ref_file = request.files.get("reference_image")
|
||||||
if ref_file and ref_file.filename:
|
if ref_file and ref_file.filename:
|
||||||
up_resp = _api(
|
up_resp = _api(
|
||||||
"POST", "/images/upload",
|
"POST", "/images/upload",
|
||||||
token=session["access_token"],
|
token=token,
|
||||||
files={"file": (ref_file.filename,
|
files={"file": (ref_file.filename,
|
||||||
ref_file.stream, ref_file.content_type)},
|
ref_file.stream, ref_file.content_type)},
|
||||||
)
|
)
|
||||||
if up_resp.status_code not in (200, 201):
|
if up_resp.status_code not in (200, 201):
|
||||||
error = up_resp.json().get("detail", "Image upload failed.")
|
error = up_resp.json().get("detail", "Image upload failed.")
|
||||||
return render_template("generate_image.html", result=result, error=error)
|
models_resp = _api("GET", "/models/",
|
||||||
|
token=token, params={"modality": "image"})
|
||||||
|
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||||
|
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
resp = _api("POST", "/generate/image", token=session["access_token"], json={
|
resp = _api("POST", "/generate/image", token=token, json={
|
||||||
"model": request.form.get("model", "").strip(),
|
"model": request.form.get("model", "").strip(),
|
||||||
"prompt": request.form.get("prompt", "").strip(),
|
"prompt": request.form.get("prompt", "").strip(),
|
||||||
"n": int(request.form.get("n", 1)),
|
"n": int(request.form.get("n", 1)),
|
||||||
@@ -195,16 +203,19 @@ def generate_image():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
return render_template("generate_image.html", result=result, error=error)
|
models_resp = _api("GET", "/models/", token=token,
|
||||||
|
params={"modality": "image"})
|
||||||
|
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||||
|
return render_template("generate_image.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/generate/video", methods=["GET", "POST"])
|
@app.route("/generate/video", methods=["GET", "POST"])
|
||||||
@login_required
|
@login_required
|
||||||
def generate_video():
|
def generate_video():
|
||||||
result = error = None
|
result = error = None
|
||||||
|
token = session["access_token"]
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
mode = request.form.get("mode", "text")
|
mode = request.form.get("mode", "text")
|
||||||
token = session["access_token"]
|
|
||||||
duration_raw = request.form.get("duration_seconds", "")
|
duration_raw = request.form.get("duration_seconds", "")
|
||||||
duration = int(
|
duration = int(
|
||||||
duration_raw) if duration_raw.strip().isdigit() else None
|
duration_raw) if duration_raw.strip().isdigit() else None
|
||||||
@@ -230,7 +241,10 @@ def generate_video():
|
|||||||
result = resp.json()
|
result = resp.json()
|
||||||
else:
|
else:
|
||||||
error = resp.json().get("detail", "Generation failed.")
|
error = resp.json().get("detail", "Generation failed.")
|
||||||
return render_template("generate_video.html", result=result, error=error)
|
models_resp = _api("GET", "/models/", token=token,
|
||||||
|
params={"modality": "video"})
|
||||||
|
models = models_resp.json() if models_resp.status_code == 200 else []
|
||||||
|
return render_template("generate_video.html", result=result, error=error, models=models)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/generate/video/status")
|
@app.get("/generate/video/status")
|
||||||
|
|||||||
@@ -5,9 +5,17 @@
|
|||||||
<h1>Image Generation</h1>
|
<h1>Image Generation</h1>
|
||||||
<form method="post" enctype="multipart/form-data">
|
<form method="post" enctype="multipart/form-data">
|
||||||
<label for="model">Model</label>
|
<label for="model">Model</label>
|
||||||
|
{% if models %}
|
||||||
|
<select id="model" name="model" required>
|
||||||
|
{% for m in models %}
|
||||||
|
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% else %}
|
||||||
<input id="model" name="model" type="text" required
|
<input id="model" name="model" type="text" required
|
||||||
placeholder="e.g. openai/dall-e-3"
|
placeholder="e.g. openai/dall-e-3"
|
||||||
value="{{ request.form.get('model', '') }}">
|
value="{{ request.form.get('model', '') }}">
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
<label for="prompt">Prompt</label>
|
<label for="prompt">Prompt</label>
|
||||||
<textarea id="prompt" name="prompt" rows="4" required
|
<textarea id="prompt" name="prompt" rows="4" required
|
||||||
|
|||||||
@@ -4,6 +4,13 @@ AI{% endblock %} {% block content %}
|
|||||||
<h1>Text Generation</h1>
|
<h1>Text Generation</h1>
|
||||||
<form method="post">
|
<form method="post">
|
||||||
<label for="model">Model</label>
|
<label for="model">Model</label>
|
||||||
|
{% if models %}
|
||||||
|
<select id="model" name="model" required>
|
||||||
|
{% for m in models %}
|
||||||
|
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% else %}
|
||||||
<input
|
<input
|
||||||
id="model"
|
id="model"
|
||||||
name="model"
|
name="model"
|
||||||
@@ -12,6 +19,7 @@ AI{% endblock %} {% block content %}
|
|||||||
placeholder="e.g. openai/gpt-4o"
|
placeholder="e.g. openai/gpt-4o"
|
||||||
value="{{ request.form.get('model', '') }}"
|
value="{{ request.form.get('model', '') }}"
|
||||||
/>
|
/>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
<label for="prompt">Prompt</label>
|
<label for="prompt">Prompt</label>
|
||||||
<textarea
|
<textarea
|
||||||
|
|||||||
@@ -19,6 +19,13 @@ AI{% endblock %} {% block content %}
|
|||||||
<input type="hidden" name="mode" value="text" />
|
<input type="hidden" name="mode" value="text" />
|
||||||
|
|
||||||
<label for="model-t">Model</label>
|
<label for="model-t">Model</label>
|
||||||
|
{% if models %}
|
||||||
|
<select id="model-t" name="model" required>
|
||||||
|
{% for m in models %}
|
||||||
|
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id and request.form.get('mode','text')=='text' %}selected{% endif %}>{{ m.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% else %}
|
||||||
<input
|
<input
|
||||||
id="model-t"
|
id="model-t"
|
||||||
name="model"
|
name="model"
|
||||||
@@ -27,6 +34,7 @@ AI{% endblock %} {% block content %}
|
|||||||
placeholder="e.g. openai/sora-2-pro"
|
placeholder="e.g. openai/sora-2-pro"
|
||||||
value="{{ request.form.get('model', '') if request.form.get('mode','text')=='text' else '' }}"
|
value="{{ request.form.get('model', '') if request.form.get('mode','text')=='text' else '' }}"
|
||||||
/>
|
/>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
<label for="prompt-t">Prompt</label>
|
<label for="prompt-t">Prompt</label>
|
||||||
<textarea
|
<textarea
|
||||||
@@ -80,6 +88,13 @@ AI{% endblock %} {% block content %}
|
|||||||
<input type="hidden" name="mode" value="image" />
|
<input type="hidden" name="mode" value="image" />
|
||||||
|
|
||||||
<label for="model-i">Model</label>
|
<label for="model-i">Model</label>
|
||||||
|
{% if models %}
|
||||||
|
<select id="model-i" name="model" required>
|
||||||
|
{% for m in models %}
|
||||||
|
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id and request.form.get('mode')=='image' %}selected{% endif %}>{{ m.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% else %}
|
||||||
<input
|
<input
|
||||||
id="model-i"
|
id="model-i"
|
||||||
name="model"
|
name="model"
|
||||||
@@ -88,6 +103,7 @@ AI{% endblock %} {% block content %}
|
|||||||
placeholder="e.g. openai/sora-2-pro"
|
placeholder="e.g. openai/sora-2-pro"
|
||||||
value="{{ request.form.get('model', '') if request.form.get('mode')=='image' else '' }}"
|
value="{{ request.form.get('model', '') if request.form.get('mode')=='image' else '' }}"
|
||||||
/>
|
/>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
<label for="image_url">Source image URL</label>
|
<label for="image_url">Source image URL</label>
|
||||||
<input
|
<input
|
||||||
|
|||||||
Reference in New Issue
Block a user