712c556032
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""Generate router: text, image, video, and image-to-video generation."""
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
|
|
from ..db import get_conn, get_write_lock
|
|
from ..dependencies import get_current_user
|
|
from ..models.ai import (
|
|
ImageRequest,
|
|
ImageResponse,
|
|
ImageResult,
|
|
TextRequest,
|
|
TextResponse,
|
|
VideoFromImageRequest,
|
|
VideoRequest,
|
|
VideoResponse,
|
|
)
|
|
from ..services import openrouter
|
|
from ..services.models import get_model_output_modalities
|
|
|
|
# Single shared router; every endpoint below is mounted under the /generate prefix.
router = APIRouter(prefix="/generate", tags=["generate"])
|
|
|
|
|
|
@router.post("/text", response_model=TextResponse)
async def generate_text(
    body: TextRequest,
    _: dict = Depends(get_current_user),
) -> TextResponse:
    """Generate text from a prompt using a chat model.

    Builds the chat message list either from an explicit message history
    (``body.messages``) or from the single ``body.prompt``, prepending
    ``body.system_prompt`` when one is supplied and not already present.

    Raises:
        HTTPException: 502 when the upstream call fails or the response
            payload is missing the expected fields.
    """
    if body.messages:
        messages = [{"role": m.role, "content": m.content} for m in body.messages]
        # Only inject the system prompt if the caller did not already lead with one.
        if body.system_prompt and (not messages or messages[0]["role"] != "system"):
            messages.insert(0, {"role": "system", "content": body.system_prompt})
    else:
        messages = []
        if body.system_prompt:
            messages.append({"role": "system", "content": body.system_prompt})
        messages.append({"role": "user", "content": body.prompt})

    try:
        result = await openrouter.chat_completion(
            model=body.model,
            messages=messages,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
    except Exception as exc:
        # Surface upstream failures as 502; `from exc` preserves the cause chain.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    try:
        choice = result["choices"][0]
        return TextResponse(
            id=result["id"],
            model=result.get("model", body.model),
            content=choice["message"]["content"],
            usage=result.get("usage"),
        )
    except (KeyError, IndexError) as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Unexpected response format: {exc}",
        ) from exc
|
|
|
|
|
|
@router.post("/image", response_model=ImageResponse)
async def generate_image(
    body: ImageRequest,
    current_user: dict = Depends(get_current_user),
) -> ImageResponse:
    """Generate images from a prompt using the chat completions endpoint.

    All OpenRouter image models use /chat/completions with a modalities param.
    Models that output only images use ["image"]; those that also output text
    use ["image", "text"]. We look this up from the model cache; default to
    ["image", "text"] when the model is not yet cached.

    Raises:
        HTTPException: 502 when the upstream call fails, no images come back,
            or the response payload has an unexpected shape.
    """
    # Determine modalities from cache; best-effort — any cache failure simply
    # falls through to the safe default below.
    try:
        conn = get_conn()
        cached_modalities = get_model_output_modalities(conn, body.model)
    except Exception:
        cached_modalities = []

    if cached_modalities:
        # If cache says the model only outputs image (no text), use ["image"].
        if set(cached_modalities) == {"image"}:
            modalities = ["image"]
        else:
            modalities = ["image", "text"]
    else:
        # Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
        # For image-only models that fail with this, the error surfaces to the user.
        modalities = ["image", "text"]

    image_config: dict = {}
    if body.aspect_ratio:
        image_config["aspect_ratio"] = body.aspect_ratio
    if body.image_size:
        image_config["image_size"] = body.image_size

    try:
        result = await openrouter.generate_image_chat(
            model=body.model,
            prompt=body.prompt,
            modalities=modalities,
            image_config=image_config if image_config else None,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    try:
        message = result.get("choices", [{}])[0].get("message", {})
        # Any text content alongside the images is treated as the revised prompt.
        images = [
            ImageResult(
                url=item.get("image_url", {}).get("url"),
                b64_json=None,
                revised_prompt=message.get("content") or None,
            )
            for item in message.get("images", [])
        ]
        if not images:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="No images returned by model. Verify the model supports image generation.",
            )

        # Persist each image to DB under the write lock; only images with a URL
        # are stored (the URL itself is saved as image_data).
        user_id = current_user.get("id") or current_user.get("sub")
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        stored: list[ImageResult] = []
        async with get_write_lock():
            conn = get_conn()
            for img in images:
                if img.url:
                    row = conn.execute(
                        """INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
                        VALUES (?, ?, ?, ?, ?) RETURNING id""",
                        [user_id, body.model, body.prompt, img.url, now],
                    ).fetchone()
                    image_id = str(row[0]) if row else None
                else:
                    image_id = None
                stored.append(ImageResult(
                    url=img.url,
                    b64_json=img.b64_json,
                    revised_prompt=img.revised_prompt,
                    image_id=image_id,
                ))

        return ImageResponse(
            id=result.get("id", ""),
            model=result.get("model", body.model),
            images=stored,
        )
    except HTTPException:
        # Re-raise our own 502s unchanged (e.g. the "no images" case above).
        raise
    except (KeyError, TypeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Unexpected response format: {exc}",
        ) from exc
|
|
|
|
|
|
@router.get("/images")
async def list_generated_images(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """Return all generated images for the current user, newest first."""
    user_id = current_user.get("id") or current_user.get("sub")
    conn = get_conn()
    rows = conn.execute(
        """SELECT id, model_id, prompt, image_data, created_at
        FROM generated_images
        WHERE user_id = ?
        ORDER BY created_at DESC""",
        [user_id],
    ).fetchall()

    # Shape each row into a plain dict for JSON serialization.
    images: list[dict] = []
    for image_id, model_id, prompt, image_data, created_at in rows:
        images.append({
            "id": str(image_id),
            "model_id": model_id,
            "prompt": prompt,
            "image_data": image_data,
            "created_at": created_at.isoformat() if created_at else None,
        })
    return images
|
|
|
|
|
|
@router.post("/video", response_model=VideoResponse)
async def generate_video(
    body: VideoRequest,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Generate a video from a text prompt.

    Submits the job to OpenRouter, records it in ``generated_videos`` so it
    can be polled and listed later, and returns the initial job state.

    Raises:
        HTTPException: 502 when the upstream call fails.
    """
    try:
        result = await openrouter.generate_video(
            model=body.model,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    user_id = current_user.get("id") or current_user.get("sub")
    job_id = result.get("id", "")
    polling_url = result.get("polling_url")
    job_status = result.get("status", "pending")
    # Naive UTC timestamp to match how the rest of this module stores times.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    async with get_write_lock():
        conn = get_conn()
        conn.execute(
            """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            [user_id, job_id, body.model, body.prompt,
             polling_url, job_status, now, now],
        )

    # Upstream may report result URLs under either key.
    urls = result.get("unsigned_urls") or result.get("video_urls")
    return VideoResponse(
        id=job_id,
        model=body.model,
        status=job_status,
        polling_url=polling_url,
        video_urls=urls,
        video_url=(urls or [None])[0],
        error=result.get("error"),
        metadata=result.get("metadata"),
    )
|
|
|
|
|
|
@router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image(
    body: VideoFromImageRequest,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Generate a video from an image and a text prompt.

    Mirrors ``generate_video`` but passes the source image URL upstream.
    The job is recorded in ``generated_videos`` for later polling/listing.

    Raises:
        HTTPException: 502 when the upstream call fails.
    """
    try:
        result = await openrouter.generate_video_from_image(
            model=body.model,
            image_url=body.image_url,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    user_id = current_user.get("id") or current_user.get("sub")
    job_id = result.get("id", "")
    polling_url = result.get("polling_url")
    job_status = result.get("status", "pending")
    # Naive UTC timestamp to match how the rest of this module stores times.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    async with get_write_lock():
        conn = get_conn()
        conn.execute(
            """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            [user_id, job_id, body.model, body.prompt,
             polling_url, job_status, now, now],
        )

    # Upstream may report result URLs under either key.
    urls = result.get("unsigned_urls") or result.get("video_urls")
    return VideoResponse(
        id=job_id,
        model=body.model,
        status=job_status,
        polling_url=polling_url,
        video_urls=urls,
        video_url=(urls or [None])[0],
        error=result.get("error"),
        metadata=result.get("metadata"),
    )
|
|
|
|
|
|
@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
    polling_url: str,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Poll status of a video generation job; updates DB row when completed/failed.

    Raises:
        HTTPException: 502 when the upstream poll fails.
    """
    try:
        result = await openrouter.poll_video_status(polling_url)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    job_status = result.get("status", "processing")
    # Upstream may report result URLs under either key.
    urls = result.get("unsigned_urls") or result.get("video_urls")
    video_url = (urls or [None])[0]

    # Update DB row for this job when a terminal state is reached.
    # NOTE(review): if upstream omits "id" this matches job_id = "" — presumably
    # the poll response always echoes the job id; confirm against the service.
    if job_status in ("completed", "failed"):
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        async with get_write_lock():
            conn = get_conn()
            conn.execute(
                """UPDATE generated_videos
                SET status = ?, video_url = ?, updated_at = ?
                WHERE job_id = ?""",
                [job_status, video_url, now, result.get("id", "")],
            )

    return VideoResponse(
        id=result.get("id", ""),
        model=result.get("model", ""),
        status=job_status,
        polling_url=result.get("polling_url"),
        video_urls=urls,
        video_url=video_url,
        error=result.get("error"),
        metadata=result.get("metadata"),
    )
|
|
|
|
|
|
@router.get("/videos")
async def list_generated_videos(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """Return all generated video jobs for the current user, newest first."""
    user_id = current_user.get("id") or current_user.get("sub")
    conn = get_conn()
    rows = conn.execute(
        """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
        FROM generated_videos
        WHERE user_id = ?
        ORDER BY created_at DESC""",
        [user_id],
    ).fetchall()

    # Shape each row into a plain dict for JSON serialization.
    videos: list[dict] = []
    for row_id, job_id, model_id, prompt, polling_url, job_status, video_url, created_at in rows:
        videos.append({
            "id": str(row_id),
            "job_id": job_id,
            "model_id": model_id,
            "prompt": prompt,
            "polling_url": polling_url,
            "status": job_status,
            "video_url": video_url,
            "created_at": created_at.isoformat() if created_at else None,
        })
    return videos
|