add AI and generation routers, models, and OpenRouter service integration with tests
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from backend.app.routers import auth as auth_router
|
from backend.app.routers import auth as auth_router
|
||||||
from backend.app.routers import users as users_router
|
from backend.app.routers import users as users_router
|
||||||
from backend.app.routers import admin as admin_router
|
from backend.app.routers import admin as admin_router
|
||||||
|
from backend.app.routers import ai as ai_router
|
||||||
|
from backend.app.routers import generate as generate_router
|
||||||
from backend.app.db import close_db, init_db
|
from backend.app.db import close_db, init_db
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@@ -37,6 +39,8 @@ app.add_middleware(
|
|||||||
app.include_router(auth_router.router)
|
app.include_router(auth_router.router)
|
||||||
app.include_router(users_router.router)
|
app.include_router(users_router.router)
|
||||||
app.include_router(admin_router.router)
|
app.include_router(admin_router.router)
|
||||||
|
app.include_router(ai_router.router)
|
||||||
|
app.include_router(generate_router.router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health", tags=["health"])
|
@app.get("/health", tags=["health"])
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Pydantic schemas for AI generation endpoints."""
|
||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str # "user" | "assistant" | "system"
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: list[ChatMessage]
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
content: str
|
||||||
|
usage: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
context_length: int | None = None
|
||||||
|
pricing: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Text generation ---
|
||||||
|
|
||||||
|
class TextRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
system_prompt: str | None = None
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
class TextResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
content: str
|
||||||
|
usage: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Image generation ---
|
||||||
|
|
||||||
|
class ImageRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
n: int = 1
|
||||||
|
size: str = "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResult(BaseModel):
|
||||||
|
url: str | None = None
|
||||||
|
b64_json: str | None = None
|
||||||
|
revised_prompt: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
images: list[ImageResult]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Video generation ---
|
||||||
|
|
||||||
|
class VideoRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
duration_seconds: int | None = None
|
||||||
|
aspect_ratio: str = "16:9"
|
||||||
|
|
||||||
|
|
||||||
|
class VideoFromImageRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
image_url: str
|
||||||
|
prompt: str
|
||||||
|
duration_seconds: int | None = None
|
||||||
|
aspect_ratio: str = "16:9"
|
||||||
|
|
||||||
|
|
||||||
|
class VideoResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
status: str # "queued" | "processing" | "completed"
|
||||||
|
video_url: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
"""AI router: model listing and chat completions via OpenRouter."""
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from backend.app.dependencies import get_current_user
|
||||||
|
from backend.app.models.ai import ChatRequest, ChatResponse, ModelInfo
|
||||||
|
from backend.app.services import openrouter
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ai", tags=["ai"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models", response_model=list[ModelInfo])
|
||||||
|
async def get_models(_: dict = Depends(get_current_user)) -> list[ModelInfo]:
|
||||||
|
"""List available AI models from OpenRouter."""
|
||||||
|
try:
|
||||||
|
raw = await openrouter.list_models()
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"OpenRouter error: {exc}",
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
ModelInfo(
|
||||||
|
id=m.get("id", ""),
|
||||||
|
name=m.get("name", m.get("id", "")),
|
||||||
|
context_length=m.get("context_length"),
|
||||||
|
pricing=m.get("pricing"),
|
||||||
|
)
|
||||||
|
for m in raw
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat", response_model=ChatResponse)
|
||||||
|
async def chat(
|
||||||
|
body: ChatRequest,
|
||||||
|
_: dict = Depends(get_current_user),
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Send a chat completion request through OpenRouter."""
|
||||||
|
try:
|
||||||
|
result = await openrouter.chat_completion(
|
||||||
|
model=body.model,
|
||||||
|
messages=[m.model_dump() for m in body.messages],
|
||||||
|
temperature=body.temperature,
|
||||||
|
max_tokens=body.max_tokens,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"OpenRouter error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
choice = result["choices"][0]
|
||||||
|
return ChatResponse(
|
||||||
|
id=result["id"],
|
||||||
|
model=result.get("model", body.model),
|
||||||
|
content=choice["message"]["content"],
|
||||||
|
usage=result.get("usage"),
|
||||||
|
)
|
||||||
|
except (KeyError, IndexError) as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Unexpected response format from OpenRouter: {exc}",
|
||||||
|
)
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""Generate router: text, image, video, and image-to-video generation."""
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from backend.app.dependencies import get_current_user
|
||||||
|
from backend.app.models.ai import (
|
||||||
|
ImageRequest,
|
||||||
|
ImageResponse,
|
||||||
|
ImageResult,
|
||||||
|
TextRequest,
|
||||||
|
TextResponse,
|
||||||
|
VideoFromImageRequest,
|
||||||
|
VideoRequest,
|
||||||
|
VideoResponse,
|
||||||
|
)
|
||||||
|
from backend.app.services import openrouter
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/generate", tags=["generate"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/text", response_model=TextResponse)
|
||||||
|
async def generate_text(
|
||||||
|
body: TextRequest,
|
||||||
|
_: dict = Depends(get_current_user),
|
||||||
|
) -> TextResponse:
|
||||||
|
"""Generate text from a prompt using a chat model."""
|
||||||
|
messages = []
|
||||||
|
if body.system_prompt:
|
||||||
|
messages.append({"role": "system", "content": body.system_prompt})
|
||||||
|
messages.append({"role": "user", "content": body.prompt})
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await openrouter.chat_completion(
|
||||||
|
model=body.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=body.temperature,
|
||||||
|
max_tokens=body.max_tokens,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
choice = result["choices"][0]
|
||||||
|
return TextResponse(
|
||||||
|
id=result["id"],
|
||||||
|
model=result.get("model", body.model),
|
||||||
|
content=choice["message"]["content"],
|
||||||
|
usage=result.get("usage"),
|
||||||
|
)
|
||||||
|
except (KeyError, IndexError) as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Unexpected response format: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/image", response_model=ImageResponse)
|
||||||
|
async def generate_image(
|
||||||
|
body: ImageRequest,
|
||||||
|
_: dict = Depends(get_current_user),
|
||||||
|
) -> ImageResponse:
|
||||||
|
"""Generate images from a text prompt."""
|
||||||
|
try:
|
||||||
|
result = await openrouter.generate_image(
|
||||||
|
model=body.model,
|
||||||
|
prompt=body.prompt,
|
||||||
|
n=body.n,
|
||||||
|
size=body.size,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
images = [
|
||||||
|
ImageResult(
|
||||||
|
url=item.get("url"),
|
||||||
|
b64_json=item.get("b64_json"),
|
||||||
|
revised_prompt=item.get("revised_prompt"),
|
||||||
|
)
|
||||||
|
for item in result.get("data", [])
|
||||||
|
]
|
||||||
|
return ImageResponse(
|
||||||
|
id=result.get("id", ""),
|
||||||
|
model=result.get("model", body.model),
|
||||||
|
images=images,
|
||||||
|
)
|
||||||
|
except (KeyError, TypeError) as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Unexpected response format: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/video", response_model=VideoResponse)
|
||||||
|
async def generate_video(
|
||||||
|
body: VideoRequest,
|
||||||
|
_: dict = Depends(get_current_user),
|
||||||
|
) -> VideoResponse:
|
||||||
|
"""Generate a video from a text prompt."""
|
||||||
|
try:
|
||||||
|
result = await openrouter.generate_video(
|
||||||
|
model=body.model,
|
||||||
|
prompt=body.prompt,
|
||||||
|
duration_seconds=body.duration_seconds,
|
||||||
|
aspect_ratio=body.aspect_ratio,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||||
|
|
||||||
|
return VideoResponse(
|
||||||
|
id=result.get("id", ""),
|
||||||
|
model=result.get("model", body.model),
|
||||||
|
status=result.get("status", "queued"),
|
||||||
|
video_url=result.get("video_url"),
|
||||||
|
metadata=result.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/video/from-image", response_model=VideoResponse)
|
||||||
|
async def generate_video_from_image(
|
||||||
|
body: VideoFromImageRequest,
|
||||||
|
_: dict = Depends(get_current_user),
|
||||||
|
) -> VideoResponse:
|
||||||
|
"""Generate a video from an image and a text prompt."""
|
||||||
|
try:
|
||||||
|
result = await openrouter.generate_video_from_image(
|
||||||
|
model=body.model,
|
||||||
|
image_url=body.image_url,
|
||||||
|
prompt=body.prompt,
|
||||||
|
duration_seconds=body.duration_seconds,
|
||||||
|
aspect_ratio=body.aspect_ratio,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||||
|
|
||||||
|
return VideoResponse(
|
||||||
|
id=result.get("id", ""),
|
||||||
|
model=result.get("model", body.model),
|
||||||
|
status=result.get("status", "queued"),
|
||||||
|
video_url=result.get("video_url"),
|
||||||
|
metadata=result.get("metadata"),
|
||||||
|
)
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
"""OpenRouter API client (OpenAI-compatible interface)."""
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key() -> str:
|
||||||
|
key = os.getenv("OPENROUTER_API_KEY")
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OPENROUTER_API_KEY environment variable is not set.")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def _headers() -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {_api_key()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"HTTP-Referer": os.getenv("APP_URL", "https://ai.allucanget.biz"),
|
||||||
|
"X-Title": os.getenv("APP_NAME", "AI Allucanget"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def list_models() -> list[dict[str, Any]]:
|
||||||
|
"""Return available models from OpenRouter."""
|
||||||
|
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = client.build_request(
|
||||||
|
"GET", f"{base_url}/models", headers=_headers())
|
||||||
|
response = await client.send(resp)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json().get("data", [])
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a chat completion request to OpenRouter."""
|
||||||
|
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.send(resp)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
n: int = 1,
|
||||||
|
size: str = "1024x1024",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Request image generation via OpenRouter /images/generations."""
|
||||||
|
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||||
|
payload = {"model": model, "prompt": prompt, "n": n, "size": size}
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = client.build_request(
|
||||||
|
"POST", f"{base_url}/images/generations", headers=_headers(), json=payload
|
||||||
|
)
|
||||||
|
response = await client.send(resp)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_video(
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
duration_seconds: int | None = None,
|
||||||
|
aspect_ratio: str = "16:9",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Request text-to-video generation via OpenRouter."""
|
||||||
|
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
}
|
||||||
|
if duration_seconds is not None:
|
||||||
|
payload["duration_seconds"] = duration_seconds
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = client.build_request(
|
||||||
|
"POST", f"{base_url}/video/generations", headers=_headers(), json=payload
|
||||||
|
)
|
||||||
|
response = await client.send(resp)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_video_from_image(
|
||||||
|
model: str,
|
||||||
|
image_url: str,
|
||||||
|
prompt: str,
|
||||||
|
duration_seconds: int | None = None,
|
||||||
|
aspect_ratio: str = "16:9",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Request image-to-video generation via OpenRouter."""
|
||||||
|
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"image_url": image_url,
|
||||||
|
"prompt": prompt,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
}
|
||||||
|
if duration_seconds is not None:
|
||||||
|
payload["duration_seconds"] = duration_seconds
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = client.build_request(
|
||||||
|
"POST", f"{base_url}/video/generations/from-image", headers=_headers(), json=payload
|
||||||
|
)
|
||||||
|
response = await client.send(resp)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
"""Tests for AI endpoints — OpenRouter HTTP calls are fully mocked."""
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from backend.app.main import app
|
||||||
|
from backend.app import db as db_module
|
||||||
|
|
||||||
|
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||||
|
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||||
|
|
||||||
|
FAKE_MODELS = [
|
||||||
|
{"id": "openai/gpt-4o", "name": "GPT-4o", "context_length": 128000, "pricing": {"prompt": "0.000005"}},
|
||||||
|
{"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku", "context_length": 200000, "pricing": {}},
|
||||||
|
]
|
||||||
|
|
||||||
|
FAKE_CHAT_RESPONSE = {
|
||||||
|
"id": "gen-abc123",
|
||||||
|
"model": "openai/gpt-4o",
|
||||||
|
"choices": [{"message": {"role": "assistant", "content": "Hello! How can I help?"}}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def fresh_db():
|
||||||
|
db_module._conn = None
|
||||||
|
db_module.init_db(":memory:")
|
||||||
|
yield
|
||||||
|
db_module.close_db()
|
||||||
|
db_module._conn = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(fresh_db):
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
async def _user_token(client):
|
||||||
|
await client.post("/auth/register", json={"email": "user@example.com", "password": "secret123"})
|
||||||
|
resp = await client.post("/auth/login", json={"email": "user@example.com", "password": "secret123"})
|
||||||
|
return resp.json()["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /ai/models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_list_models(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch(
|
||||||
|
"backend.app.routers.ai.openrouter.list_models",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=FAKE_MODELS,
|
||||||
|
):
|
||||||
|
resp = await client.get("/ai/models", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["id"] == "openai/gpt-4o"
|
||||||
|
assert data[1]["name"] == "Claude 3 Haiku"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_unauthenticated(client):
|
||||||
|
resp = await client.get("/ai/models")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch(
|
||||||
|
"backend.app.routers.ai.openrouter.list_models",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Connection refused"),
|
||||||
|
):
|
||||||
|
resp = await client.get("/ai/models", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert "OpenRouter error" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /ai/chat
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_chat_success(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch(
|
||||||
|
"backend.app.routers.ai.openrouter.chat_completion",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=FAKE_CHAT_RESPONSE,
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/ai/chat",
|
||||||
|
json={
|
||||||
|
"model": "openai/gpt-4o",
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == "gen-abc123"
|
||||||
|
assert data["model"] == "openai/gpt-4o"
|
||||||
|
assert data["content"] == "Hello! How can I help?"
|
||||||
|
assert data["usage"]["total_tokens"] == 18
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_passes_parameters(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
|
||||||
|
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||||
|
await client.post(
|
||||||
|
"/ai/chat",
|
||||||
|
json={
|
||||||
|
"model": "anthropic/claude-3-haiku",
|
||||||
|
"messages": [{"role": "user", "content": "Hi"}],
|
||||||
|
"temperature": 0.3,
|
||||||
|
"max_tokens": 512,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
model="anthropic/claude-3-haiku",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_unauthenticated(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/ai/chat",
|
||||||
|
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch(
|
||||||
|
"backend.app.routers.ai.openrouter.chat_completion",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("timeout"),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/ai/chat",
|
||||||
|
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_malformed_upstream_response(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch(
|
||||||
|
"backend.app.routers.ai.openrouter.chat_completion",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"id": "x", "choices": []}, # empty choices
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/ai/chat",
|
||||||
|
json={"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
"""Tests for generate endpoints — all OpenRouter calls mocked."""
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from backend.app.main import app
|
||||||
|
from backend.app import db as db_module
|
||||||
|
|
||||||
|
os.environ.setdefault("JWT_SECRET", "test-secret-key-for-testing-only")
|
||||||
|
os.environ.setdefault("OPENROUTER_API_KEY", "test-key")
|
||||||
|
|
||||||
|
FAKE_CHAT = {
|
||||||
|
"id": "gen-text-1",
|
||||||
|
"model": "openai/gpt-4o",
|
||||||
|
"choices": [{"message": {"role": "assistant", "content": "Once upon a time..."}}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
|
||||||
|
FAKE_IMAGE = {
|
||||||
|
"id": "gen-img-1",
|
||||||
|
"model": "openai/dall-e-3",
|
||||||
|
"data": [
|
||||||
|
{"url": "https://example.com/image.png",
|
||||||
|
"revised_prompt": "A cat on the moon"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
FAKE_VIDEO = {
|
||||||
|
"id": "gen-vid-1",
|
||||||
|
"model": "stability/stable-video",
|
||||||
|
"status": "queued",
|
||||||
|
"video_url": None,
|
||||||
|
"metadata": {"estimated_seconds": 30},
|
||||||
|
}
|
||||||
|
|
||||||
|
FAKE_VIDEO_DONE = {
|
||||||
|
"id": "gen-vid-2",
|
||||||
|
"model": "runway/gen-3",
|
||||||
|
"status": "completed",
|
||||||
|
"video_url": "https://example.com/video.mp4",
|
||||||
|
"metadata": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def fresh_db():
|
||||||
|
db_module._conn = None
|
||||||
|
db_module.init_db(":memory:")
|
||||||
|
yield
|
||||||
|
db_module.close_db()
|
||||||
|
db_module._conn = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(fresh_db):
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
async def _user_token(client):
|
||||||
|
await client.post("/auth/register", json={"email": "user@example.com", "password": "secret123"})
|
||||||
|
resp = await client.post("/auth/login", json={"email": "user@example.com", "password": "secret123"})
|
||||||
|
return resp.json()["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /generate/text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_generate_text(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/text",
|
||||||
|
json={"model": "openai/gpt-4o", "prompt": "Tell me a story"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["content"] == "Once upon a time..."
|
||||||
|
assert data["id"] == "gen-text-1"
|
||||||
|
assert data["usage"]["total_tokens"] == 15
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_text_with_system_prompt(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
mock = AsyncMock(return_value=FAKE_CHAT)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.chat_completion", mock):
|
||||||
|
await client.post(
|
||||||
|
"/generate/text",
|
||||||
|
json={"model": "openai/gpt-4o", "prompt": "Hello",
|
||||||
|
"system_prompt": "Be concise."},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
call_messages = mock.call_args.kwargs["messages"]
|
||||||
|
assert call_messages[0] == {"role": "system", "content": "Be concise."}
|
||||||
|
assert call_messages[1] == {"role": "user", "content": "Hello"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_text_unauthenticated(client):
|
||||||
|
resp = await client.post("/generate/text", json={"model": "openai/gpt-4o", "prompt": "Hi"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_text_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, side_effect=Exception("timeout")):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/text",
|
||||||
|
json={"model": "openai/gpt-4o", "prompt": "Hi"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /generate/image
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_generate_image(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/image",
|
||||||
|
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == "gen-img-1"
|
||||||
|
assert len(data["images"]) == 1
|
||||||
|
assert data["images"][0]["url"] == "https://example.com/image.png"
|
||||||
|
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_image_unauthenticated(client):
|
||||||
|
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_image_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/image",
|
||||||
|
json={"model": "openai/dall-e-3", "prompt": "Hi"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /generate/video
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_generate_video(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, return_value=FAKE_VIDEO):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/video",
|
||||||
|
json={"model": "stability/stable-video",
|
||||||
|
"prompt": "Ocean waves at sunset"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == "gen-vid-1"
|
||||||
|
assert data["status"] == "queued"
|
||||||
|
assert data["video_url"] is None
|
||||||
|
assert data["metadata"]["estimated_seconds"] == 30
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_video_unauthenticated(client):
|
||||||
|
resp = await client.post("/generate/video", json={"model": "m", "prompt": "p"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_video_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, side_effect=Exception("503")):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/video",
|
||||||
|
json={"model": "stability/stable-video", "prompt": "Hi"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /generate/video/from-image
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_generate_video_from_image(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, return_value=FAKE_VIDEO_DONE):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/video/from-image",
|
||||||
|
json={
|
||||||
|
"model": "runway/gen-3",
|
||||||
|
"image_url": "https://example.com/cat.jpg",
|
||||||
|
"prompt": "Cat runs across the room",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "completed"
|
||||||
|
assert data["video_url"] == "https://example.com/video.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_video_from_image_unauthenticated(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/video/from-image",
|
||||||
|
json={"model": "m", "image_url": "https://example.com/img.jpg", "prompt": "p"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_video_from_image_upstream_error(client):
|
||||||
|
token = await _user_token(client)
|
||||||
|
with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, side_effect=Exception("error")):
|
||||||
|
resp = await client.post(
|
||||||
|
"/generate/video/from-image",
|
||||||
|
json={"model": "runway/gen-3",
|
||||||
|
"image_url": "https://example.com/img.jpg", "prompt": "p"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
Reference in New Issue
Block a user