712c556032
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
471 lines
18 KiB
Python
471 lines
18 KiB
Python
"""Tests for generate endpoints — all OpenRouter calls mocked."""

import os
from unittest.mock import AsyncMock, patch

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport

# Test-only defaults must be in the environment BEFORE the application is
# imported: anything in the ``app`` import chain that reads these variables at
# import time would otherwise never see them.
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")

from app.main import app  # noqa: E402  (needs the env defaults above)
from app import db as db_module  # noqa: E402
|
|
|
|
# Canned chat-completions payload handed back by the mocked text endpoint.
FAKE_CHAT = dict(
    id="gen-text-1",
    model="openai/gpt-4o",
    choices=[
        dict(message=dict(role="assistant", content="Once upon a time...")),
    ],
    usage=dict(prompt_tokens=5, completion_tokens=10, total_tokens=15),
)

# Video job as returned right after submission: queued, no output URLs yet.
FAKE_VIDEO = dict(
    id="gen-vid-1",
    polling_url="https://openrouter.ai/api/v1/videos/gen-vid-1",
    status="queued",
)

# Video job that has already finished and carries one downloadable URL.
FAKE_VIDEO_DONE = dict(
    id="gen-vid-2",
    polling_url="https://openrouter.ai/api/v1/videos/gen-vid-2",
    status="completed",
    unsigned_urls=["https://example.com/video.mp4"],
)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def fresh_db():
    """Run every test against a brand-new in-memory database.

    The module-level connection handle is cleared before initialisation and
    again after ``close_db`` so each test starts from a clean slate.
    """
    db = db_module
    db._conn = None  # forget any handle left over from a previous test
    db.init_db(":memory:")
    yield
    db.close_db()
    db._conn = None
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client(fresh_db):
    """Yield an httpx ``AsyncClient`` routed in-process to the ASGI ``app``.

    Requesting ``fresh_db`` guarantees the database is initialised before the
    first request is sent.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
|
|
|
|
|
|
async def _user_token(client):
|
|
await client.post("/auth/register", json={"email": "user@example.com", "password": "secret123"})
|
|
resp = await client.post("/auth/login", json={"email": "user@example.com", "password": "secret123"})
|
|
return resp.json()["access_token"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /generate/text
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_generate_text(client):
    """A plain prompt returns 200 with the completion text, id and usage."""
    token = await _user_token(client)
    auth = {"Authorization": f"Bearer {token}"}
    payload = {"model": "openai/gpt-4o", "prompt": "Tell me a story"}
    target = "app.routers.generate.openrouter.chat_completion"

    with patch(target, new_callable=AsyncMock, return_value=FAKE_CHAT):
        resp = await client.post("/generate/text", json=payload, headers=auth)

    assert resp.status_code == 200
    body = resp.json()
    assert body["content"] == "Once upon a time..."
    assert body["id"] == "gen-text-1"
    assert body["usage"]["total_tokens"] == 15
|
|
|
|
|
|
async def test_generate_text_with_system_prompt(client):
    """system_prompt is sent as a leading system message, followed by the prompt."""
    token = await _user_token(client)
    mock = AsyncMock(return_value=FAKE_CHAT)
    with patch("app.routers.generate.openrouter.chat_completion", mock):
        resp = await client.post(
            "/generate/text",
            json={"model": "openai/gpt-4o", "prompt": "Hello",
                  "system_prompt": "Be concise."},
            headers={"Authorization": f"Bearer {token}"},
        )

    # Check the request itself first: if the endpoint failed, this reports the
    # HTTP error instead of a confusing failure on ``mock.call_args`` below.
    assert resp.status_code == 200
    call_messages = mock.call_args.kwargs["messages"]
    assert call_messages[0] == {"role": "system", "content": "Be concise."}
    assert call_messages[1] == {"role": "user", "content": "Hello"}
|
|
|
|
|
|
async def test_generate_text_with_messages_array(client):
    """A multi-turn ``messages`` array is forwarded to the model unchanged.

    Neither ``prompt`` nor ``system_prompt`` is sent, so the upstream call must
    receive exactly the supplied turns. (Precedence of ``messages`` over
    ``prompt`` when both are sent is not exercised by this test.)
    """
    token = await _user_token(client)
    mock = AsyncMock(return_value=FAKE_CHAT)
    messages = [
        {"role": "user", "content": "First message"},
        {"role": "assistant", "content": "Reply"},
        {"role": "user", "content": "Follow up"},
    ]
    with patch("app.routers.generate.openrouter.chat_completion", mock):
        resp = await client.post(
            "/generate/text",
            json={"model": "openai/gpt-4o", "messages": messages},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert resp.status_code == 200
    # Full equality pins count, order, roles and contents, not just the last turn.
    assert mock.call_args.kwargs["messages"] == messages
|
|
|
|
|
|
async def test_generate_text_messages_with_system_prompt(client):
    """system_prompt prepended when messages provided and no system msg present."""
    token = await _user_token(client)
    mock = AsyncMock(return_value=FAKE_CHAT)
    messages = [{"role": "user", "content": "Hi"}]
    with patch("app.routers.generate.openrouter.chat_completion", mock):
        resp = await client.post(
            "/generate/text",
            json={"model": "openai/gpt-4o", "messages": messages,
                  "system_prompt": "Be brief."},
            headers={"Authorization": f"Bearer {token}"},
        )

    # Fail on the HTTP result first so an endpoint error is reported as such.
    assert resp.status_code == 200
    call_messages = mock.call_args.kwargs["messages"]
    assert call_messages[0] == {"role": "system", "content": "Be brief."}
    assert call_messages[1] == {"role": "user", "content": "Hi"}
|
|
|
|
|
|
async def test_generate_text_unauthenticated(client):
    """Without a bearer token, /generate/text answers 401."""
    payload = {"model": "openai/gpt-4o", "prompt": "Hi"}
    response = await client.post("/generate/text", json=payload)
    assert response.status_code == 401
|
|
|
|
|
|
async def test_generate_text_upstream_error(client):
    """An exception raised by the mocked OpenRouter chat call maps to HTTP 502."""
    token = await _user_token(client)
    failing_call = AsyncMock(side_effect=Exception("timeout"))
    with patch("app.routers.generate.openrouter.chat_completion", failing_call):
        response = await client.post(
            "/generate/text",
            json={"model": "openai/gpt-4o", "prompt": "Hi"},
            headers={"Authorization": f"Bearer {token}"},
        )
    assert response.status_code == 502
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /generate/image
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _image_chat_response(gen_id, model, content, image_url):
    """Build a chat-completions payload whose message carries one image."""
    message = {
        "role": "assistant",
        "content": content,
        "images": [{"type": "image_url", "image_url": {"url": image_url}}],
    }
    return {"id": gen_id, "model": model, "choices": [{"message": message}]}


# Flux-style reply: ``content`` is None and only the image comes back.
FAKE_IMAGE_CHAT_FLUX = _image_chat_response(
    "gen-img-chat-1",
    "black-forest-labs/flux.2-klein-4b",
    None,
    "data:image/png;base64,abc123",
)

# GPT-5 image model reply: short text alongside the image.
FAKE_IMAGE_CHAT_GPT5 = _image_chat_response(
    "gen-img-chat-2",
    "openai/gpt-5-image-mini",
    "Generated image.",
    "data:image/png;base64,xyz789",
)

# Gemini image model reply: short text alongside the image.
FAKE_IMAGE_CHAT_GEMINI = _image_chat_response(
    "gen-img-chat-3",
    "google/gemini-2.5-flash-image",
    "Here is your image.",
    "data:image/png;base64,gemini123",
)
|
|
|
|
|
|
async def test_generate_image(client):
    """Image generation goes through ``generate_image_chat`` (chat completions).

    The single returned image comes back with its data URL and a non-null
    ``image_id`` for the stored row.
    """
    token = await _user_token(client)
    auth = {"Authorization": f"Bearer {token}"}
    request_body = {"model": "google/gemini-2.5-flash-image",
                    "prompt": "A cat on the moon"}

    with patch("app.routers.generate.openrouter.generate_image_chat",
               new_callable=AsyncMock,
               return_value=FAKE_IMAGE_CHAT_GEMINI):
        resp = await client.post("/generate/image", json=request_body, headers=auth)

    assert resp.status_code == 200
    result = resp.json()
    assert result["id"] == "gen-img-chat-3"
    images = result["images"]
    assert len(images) == 1
    first = images[0]
    assert first["url"] == "data:image/png;base64,gemini123"
    assert first["image_id"] is not None  # stored in DB
|
|
|
|
|
|
async def test_generate_image_unauthenticated(client):
    """Without a bearer token, /generate/image answers 401."""
    response = await client.post(
        "/generate/image",
        json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
    )
    assert response.status_code == 401
|
|
|
|
|
|
async def test_generate_image_upstream_error(client):
    """An exception from the mocked image call is reported to the client as 502."""
    token = await _user_token(client)
    boom = AsyncMock(side_effect=Exception("rate limit"))
    with patch("app.routers.generate.openrouter.generate_image_chat", boom):
        response = await client.post(
            "/generate/image",
            json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
            headers={"Authorization": f"Bearer {token}"},
        )
    assert response.status_code == 502
|
|
|
|
|
|
async def test_generate_image_with_image_config(client):
    """Passes aspect_ratio + image_size through to generate_image_chat."""
    token = await _user_token(client)
    mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
    with patch("app.routers.generate.openrouter.generate_image_chat", mock):
        resp = await client.post(
            "/generate/image",
            json={
                "model": "google/gemini-2.5-flash-image",
                "prompt": "A landscape",
                "aspect_ratio": "16:9",
                "image_size": "2K",
            },
            headers={"Authorization": f"Bearer {token}"},
        )

    # The request must succeed; otherwise inspecting call_args is meaningless.
    assert resp.status_code == 200
    call_kwargs = mock.call_args.kwargs
    assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
    assert call_kwargs["image_config"]["image_size"] == "2K"
|
|
|
|
|
|
async def test_generate_image_default_modalities_image_text(client):
    """Model not in cache → default modalities = ['image', 'text']."""
    token = await _user_token(client)
    mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
    with patch("app.routers.generate.openrouter.generate_image_chat", mock):
        resp = await client.post(
            "/generate/image",
            json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
            headers={"Authorization": f"Bearer {token}"},
        )

    # Fail on the HTTP result first so an endpoint error is reported as such.
    assert resp.status_code == 200
    assert mock.call_args.kwargs["modalities"] == ["image", "text"]
|
|
|
|
|
|
async def test_generate_image_image_only_modalities_from_cache(client):
    """Model cached with image-only output_modalities → modalities = ['image']."""
    # Only this test seeds the models cache by hand, so these stay local.
    # ``db_module`` is the module-level import; no need to re-import it here.
    import json as _json
    from datetime import datetime, timezone

    token = await _user_token(client)

    # Seed the cache with an image-only model. The timestamp is stored naive
    # (tzinfo stripped from a UTC "now"), presumably matching how the app
    # writes ``fetched_at`` — confirm against the models service.
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    # Clear any existing row for this model before inserting the fixture row.
    conn.execute(
        "DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
    )
    conn.execute(
        """INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
           VALUES (?, ?, ?, ?, ?, ?, ?)""",
        ["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
         _json.dumps(["image"])],
    )

    mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
    with patch("app.routers.generate.openrouter.generate_image_chat", mock):
        resp = await client.post(
            "/generate/image",
            json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert resp.status_code == 200
    assert mock.call_args.kwargs["modalities"] == ["image"]
|
|
|
|
|
|
async def test_generate_image_no_images_in_response(client):
    """502 when model returns no images."""
    token = await _user_token(client)
    message = {"role": "assistant", "content": "ok", "images": []}
    empty_response = {
        "id": "gen-empty",
        "model": "google/gemini-2.5-flash-image",
        "choices": [{"message": message}],
    }
    mocked = AsyncMock(return_value=empty_response)

    with patch("app.routers.generate.openrouter.generate_image_chat", mocked):
        response = await client.post(
            "/generate/image",
            json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 502
    detail = response.json()["detail"]
    assert "No images returned" in detail
|
|
|
|
|
|
async def test_generate_image_flux(client):
    """A Flux model also goes through chat completions and returns its image URL."""
    token = await _user_token(client)
    request_body = {"model": "black-forest-labs/flux.2-klein-4b",
                    "prompt": "A sunset"}
    stub = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)

    with patch("app.routers.generate.openrouter.generate_image_chat", stub):
        response = await client.post(
            "/generate/image",
            json=request_body,
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 200
    first_image = response.json()["images"][0]
    assert first_image["url"] == "data:image/png;base64,abc123"
|
|
|
|
|
|
async def test_generate_image_stored_in_db(client):
    """Generated image row persists in generated_images table.

    Uses the module-level ``db_module`` import (re-importing it locally only
    shadowed the same module object).
    """
    token = await _user_token(client)
    with patch("app.routers.generate.openrouter.generate_image_chat",
               new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
        resp = await client.post(
            "/generate/image",
            json={"model": "google/gemini-2.5-flash-image",
                  "prompt": "A mountain"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert resp.status_code == 200
    image_id = resp.json()["images"][0]["image_id"]
    assert image_id is not None

    # Read the row back directly to prove it was persisted with the request data.
    row = db_module.get_conn().execute(
        "SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
        [image_id],
    ).fetchone()
    assert row is not None
    assert row[0] == "google/gemini-2.5-flash-image"
    assert row[1] == "A mountain"
    assert row[2] == "data:image/png;base64,gemini123"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /generate/video
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_generate_video(client):
    """Text-to-video submission returns the queued job and its polling URL.

    ``video_url`` is null while the job is still queued.
    """
    token = await _user_token(client)
    with patch("app.routers.generate.openrouter.generate_video",
               new=AsyncMock(return_value=FAKE_VIDEO)):
        resp = await client.post(
            "/generate/video",
            json={"model": "stability/stable-video",
                  "prompt": "Ocean waves at sunset"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert resp.status_code == 200
    data = resp.json()
    expected = {
        "id": "gen-vid-1",
        "status": "queued",
        "polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
    }
    for field, value in expected.items():
        assert data[field] == value
    assert data["video_url"] is None
|
|
|
|
|
|
async def test_generate_video_unauthenticated(client):
    """Without a bearer token, /generate/video answers 401."""
    payload = {"model": "m", "prompt": "p"}
    response = await client.post("/generate/video", json=payload)
    assert response.status_code == 401
|
|
|
|
|
|
async def test_generate_video_upstream_error(client):
    """An exception from the mocked video submission maps to HTTP 502."""
    token = await _user_token(client)
    failing = AsyncMock(side_effect=Exception("503"))
    with patch("app.routers.generate.openrouter.generate_video", failing):
        response = await client.post(
            "/generate/video",
            json={"model": "stability/stable-video", "prompt": "Hi"},
            headers={"Authorization": f"Bearer {token}"},
        )
    assert response.status_code == 502
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /generate/video/from-image  and  GET /generate/video/status
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def test_generate_video_from_image(client):
    """Image-to-video: a completed job exposes both ``video_url`` and ``video_urls``."""
    token = await _user_token(client)
    request_body = {
        "model": "runway/gen-3",
        "image_url": "https://example.com/cat.jpg",
        "prompt": "Cat runs across the room",
    }
    stub = AsyncMock(return_value=FAKE_VIDEO_DONE)

    with patch("app.routers.generate.openrouter.generate_video_from_image", stub):
        response = await client.post(
            "/generate/video/from-image",
            json=request_body,
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "completed"
    assert payload["video_url"] == "https://example.com/video.mp4"
    assert payload["video_urls"] == ["https://example.com/video.mp4"]
|
|
|
|
|
|
async def test_poll_video_status(client):
    """Polling a finished job reports ``completed`` and its first video URL."""
    token = await _user_token(client)
    finished_job = dict(
        id="gen-vid-1",
        status="completed",
        unsigned_urls=["https://example.com/video.mp4"],
    )
    query = {"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1"}

    with patch("app.routers.generate.openrouter.poll_video_status",
               new_callable=AsyncMock, return_value=finished_job):
        response = await client.get(
            "/generate/video/status",
            params=query,
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "completed"
    assert payload["video_url"] == "https://example.com/video.mp4"
|
|
|
|
|
|
async def test_poll_video_status_unauthenticated(client):
    """Without a bearer token, the video status endpoint answers 401."""
    query = {"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1"}
    response = await client.get("/generate/video/status", params=query)
    assert response.status_code == 401
|
|
|
|
|
|
async def test_poll_video_status_upstream_error(client):
    """An exception while polling the mocked upstream job maps to HTTP 502."""
    token = await _user_token(client)
    failing_poll = AsyncMock(side_effect=Exception("timeout"))
    query = {"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1"}

    with patch("app.routers.generate.openrouter.poll_video_status", failing_poll):
        response = await client.get(
            "/generate/video/status",
            params=query,
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 502
|
|
|
|
|
|
async def test_generate_video_from_image_unauthenticated(client):
    """Without a bearer token, /generate/video/from-image answers 401."""
    payload = {"model": "m", "image_url": "https://example.com/img.jpg", "prompt": "p"}
    response = await client.post("/generate/video/from-image", json=payload)
    assert response.status_code == 401
|
|
|
|
|
|
async def test_generate_video_from_image_upstream_error(client):
    """An exception from the mocked image-to-video call maps to HTTP 502."""
    token = await _user_token(client)
    failing = AsyncMock(side_effect=Exception("error"))
    request_body = {"model": "runway/gen-3",
                    "image_url": "https://example.com/img.jpg", "prompt": "p"}

    with patch("app.routers.generate.openrouter.generate_video_from_image", failing):
        response = await client.post(
            "/generate/video/from-image",
            json=request_body,
            headers={"Authorization": f"Bearer {token}"},
        )

    assert response.status_code == 502
|