add AI and generation routers, models, and OpenRouter service integration with tests
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -0,0 +1,172 @@
|
||||
"""Tests for AI endpoints — OpenRouter HTTP calls are fully mocked."""
|
||||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from backend.app.main import app
|
||||
from backend.app import db as db_module
|
||||
|
||||
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||
|
||||
FAKE_MODELS = [
|
||||
{"id": "openai/gpt-4o", "name": "GPT-4o", "context_length": 128000, "pricing": {"prompt": "0.000005"}},
|
||||
{"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku", "context_length": 200000, "pricing": {}},
|
||||
]
|
||||
|
||||
FAKE_CHAT_RESPONSE = {
|
||||
"id": "gen-abc123",
|
||||
"model": "openai/gpt-4o",
|
||||
"choices": [{"message": {"role": "assistant", "content": "Hello! How can I help?"}}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fresh_db():
|
||||
db_module._conn = None
|
||||
db_module.init_db(":memory:")
|
||||
yield
|
||||
db_module.close_db()
|
||||
db_module._conn = None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(fresh_db):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
async def _user_token(client):
|
||||
await client.post("/auth/register", json={"email": "user@example.com", "password": "secret123"})
|
||||
resp = await client.post("/auth/login", json={"email": "user@example.com", "password": "secret123"})
|
||||
return resp.json()["access_token"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /ai/models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_list_models(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS,
|
||||
):
|
||||
resp = await client.get("/ai/models", headers={"Authorization": f"Bearer {token}"})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["id"] == "openai/gpt-4o"
|
||||
assert data[1]["name"] == "Claude 3 Haiku"
|
||||
|
||||
|
||||
async def test_list_models_unauthenticated(client):
|
||||
resp = await client.get("/ai/models")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_list_models_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Connection refused"),
|
||||
):
|
||||
resp = await client.get("/ai/models", headers={"Authorization": f"Bearer {token}"})
|
||||
|
||||
assert resp.status_code == 502
|
||||
assert "OpenRouter error" in resp.json()["detail"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /ai/chat
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_chat_success(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_CHAT_RESPONSE,
|
||||
):
|
||||
resp = await client.post(
|
||||
"/ai/chat",
|
||||
json={
|
||||
"model": "openai/gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-abc123"
|
||||
assert data["model"] == "openai/gpt-4o"
|
||||
assert data["content"] == "Hello! How can I help?"
|
||||
assert data["usage"]["total_tokens"] == 18
|
||||
|
||||
|
||||
async def test_chat_passes_parameters(client):
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
|
||||
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||
await client.post(
|
||||
"/ai/chat",
|
||||
json={
|
||||
"model": "anthropic/claude-3-haiku",
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 512,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
mock.assert_called_once_with(
|
||||
model="anthropic/claude-3-haiku",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
|
||||
async def test_chat_unauthenticated(client):
|
||||
resp = await client.post(
|
||||
"/ai/chat",
|
||||
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_chat_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("timeout"),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/ai/chat",
|
||||
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
async def test_chat_malformed_upstream_response(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"id": "x", "choices": []}, # empty choices
|
||||
):
|
||||
resp = await client.post(
|
||||
"/ai/chat",
|
||||
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
@@ -0,0 +1,231 @@
|
||||
"""Tests for generate endpoints — all OpenRouter calls mocked."""
|
||||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from backend.app.main import app
|
||||
from backend.app import db as db_module
|
||||
|
||||
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||
|
||||
FAKE_CHAT = {
|
||||
"id": "gen-text-1",
|
||||
"model": "openai/gpt-4o",
|
||||
"choices": [{"message": {"role": "assistant", "content": "Once upon a time..."}}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
|
||||
}
|
||||
|
||||
FAKE_IMAGE = {
|
||||
"id": "gen-img-1",
|
||||
"model": "openai/dall-e-3",
|
||||
"data": [
|
||||
{"url": "https://example.com/image.png",
|
||||
"revised_prompt": "A cat on the moon"},
|
||||
],
|
||||
}
|
||||
|
||||
FAKE_VIDEO = {
|
||||
"id": "gen-vid-1",
|
||||
"model": "stability/stable-video",
|
||||
"status": "queued",
|
||||
"video_url": None,
|
||||
"metadata": {"estimated_seconds": 30},
|
||||
}
|
||||
|
||||
FAKE_VIDEO_DONE = {
|
||||
"id": "gen-vid-2",
|
||||
"model": "runway/gen-3",
|
||||
"status": "completed",
|
||||
"video_url": "https://example.com/video.mp4",
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fresh_db():
|
||||
db_module._conn = None
|
||||
db_module.init_db(":memory:")
|
||||
yield
|
||||
db_module.close_db()
|
||||
db_module._conn = None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(fresh_db):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
async def _user_token(client):
|
||||
await client.post("/auth/register", json={"email": "user@example.com", "password": "secret123"})
|
||||
resp = await client.post("/auth/login", json={"email": "user@example.com", "password": "secret123"})
|
||||
return resp.json()["access_token"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /generate/text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_text(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT):
|
||||
resp = await client.post(
|
||||
"/generate/text",
|
||||
json={"model": "openai/gpt-4o", "prompt": "Tell me a story"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == "Once upon a time..."
|
||||
assert data["id"] == "gen-text-1"
|
||||
assert data["usage"]["total_tokens"] == 15
|
||||
|
||||
|
||||
async def test_generate_text_with_system_prompt(client):
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_CHAT)
|
||||
with patch("backend.app.routers.generate.openrouter.chat_completion", mock):
|
||||
await client.post(
|
||||
"/generate/text",
|
||||
json={"model": "openai/gpt-4o", "prompt": "Hello",
|
||||
"system_prompt": "Be concise."},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
call_messages = mock.call_args.kwargs["messages"]
|
||||
assert call_messages[0] == {"role": "system", "content": "Be concise."}
|
||||
assert call_messages[1] == {"role": "user", "content": "Hello"}
|
||||
|
||||
|
||||
async def test_generate_text_unauthenticated(client):
|
||||
resp = await client.post("/generate/text", json={"model": "openai/gpt-4o", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_text_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, side_effect=Exception("timeout")):
|
||||
resp = await client.post(
|
||||
"/generate/text",
|
||||
json={"model": "openai/gpt-4o", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /generate/image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_image(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-img-1"
|
||||
assert len(data["images"]) == 1
|
||||
assert data["images"][0]["url"] == "https://example.com/image.png"
|
||||
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
|
||||
|
||||
|
||||
async def test_generate_image_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /generate/video
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_video(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, return_value=FAKE_VIDEO):
|
||||
resp = await client.post(
|
||||
"/generate/video",
|
||||
json={"model": "stability/stable-video",
|
||||
"prompt": "Ocean waves at sunset"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-vid-1"
|
||||
assert data["status"] == "queued"
|
||||
assert data["video_url"] is None
|
||||
assert data["metadata"]["estimated_seconds"] == 30
|
||||
|
||||
|
||||
async def test_generate_video_unauthenticated(client):
|
||||
resp = await client.post("/generate/video", json={"model": "m", "prompt": "p"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_video_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, side_effect=Exception("503")):
|
||||
resp = await client.post(
|
||||
"/generate/video",
|
||||
json={"model": "stability/stable-video", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /generate/video/from-image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_video_from_image(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, return_value=FAKE_VIDEO_DONE):
|
||||
resp = await client.post(
|
||||
"/generate/video/from-image",
|
||||
json={
|
||||
"model": "runway/gen-3",
|
||||
"image_url": "https://example.com/cat.jpg",
|
||||
"prompt": "Cat runs across the room",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "completed"
|
||||
assert data["video_url"] == "https://example.com/video.mp4"
|
||||
|
||||
|
||||
async def test_generate_video_from_image_unauthenticated(client):
|
||||
resp = await client.post(
|
||||
"/generate/video/from-image",
|
||||
json={"model": "m", "image_url": "https://example.com/img.jpg", "prompt": "p"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_video_from_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, side_effect=Exception("error")):
|
||||
resp = await client.post(
|
||||
"/generate/video/from-image",
|
||||
json={"model": "runway/gen-3",
|
||||
"image_url": "https://example.com/img.jpg", "prompt": "p"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
Reference in New Issue
Block a user