# Source: ai.allucanget.biz/backend/app/routers/generate.py (Python, 440 lines, 15 KiB)

"""Generate router: text, image, video, and image-to-video generation."""
from datetime import datetime, timezone
import httpx
from fastapi import APIRouter, Depends, HTTPException, status
from ..db import get_conn, get_write_lock
from ..dependencies import get_current_user
from ..models.ai import (
ImageRequest,
ImageResponse,
ImageResult,
TextRequest,
TextResponse,
VideoFromImageRequest,
VideoRequest,
VideoResponse,
)
from ..services import openrouter
from ..services.models import get_model_output_modalities
# Router collecting every generation endpoint defined in this module; all
# paths are registered under the "/generate" prefix and tagged "generate".
router = APIRouter(prefix="/generate", tags=["generate"])
@router.post("/text", response_model=TextResponse)
async def generate_text(
    body: TextRequest,
    _: dict = Depends(get_current_user),
) -> TextResponse:
    """Generate text from a prompt using a chat model.

    When ``body.messages`` is non-empty it is forwarded as the conversation,
    with ``body.system_prompt`` inserted first unless the leading message
    already has the ``system`` role. Otherwise the conversation is built from
    the optional system prompt followed by ``body.prompt`` as the user turn.

    Raises:
        HTTPException: 502 when the OpenRouter call fails or when its response
            does not have the expected ``choices[0].message.content`` shape.
    """
    if body.messages:
        messages = [{"role": m.role, "content": m.content}
                    for m in body.messages]
        if body.system_prompt and (not messages or messages[0]["role"] != "system"):
            messages.insert(
                0, {"role": "system", "content": body.system_prompt})
    else:
        messages = []
        if body.system_prompt:
            messages.append({"role": "system", "content": body.system_prompt})
        messages.append({"role": "user", "content": body.prompt})

    try:
        result = await openrouter.chat_completion(
            model=body.model,
            messages=messages,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    try:
        choice = result["choices"][0]
        return TextResponse(
            id=result["id"],
            model=result.get("model", body.model),
            content=choice["message"]["content"],
            usage=result.get("usage"),
        )
    # TypeError covers null/non-subscriptable fields (e.g. "choices": null),
    # which would otherwise surface as an unhandled 500.
    except (KeyError, IndexError, TypeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Unexpected response format: {exc}",
        ) from exc
@router.post("/image", response_model=ImageResponse)
async def generate_image(
    body: ImageRequest,
    current_user: dict = Depends(get_current_user),
) -> ImageResponse:
    """Generate images from a prompt using the chat completions endpoint.

    All OpenRouter image models use /chat/completions with a modalities param.
    Models that output only images use ["image"]; those that also output text
    use ["image", "text"]. We look this up from the model cache; default to
    ["image", "text"] when the model is not yet cached.

    Each returned image that carries a URL is inserted into
    ``generated_images`` for the calling user, and the new row id is reported
    back as ``image_id``.

    Raises:
        HTTPException: 502 when the OpenRouter call fails, when the response
            contains no images, or when the response is malformed.
    """
    # Best-effort cache lookup: any failure falls back to the default below.
    try:
        conn = get_conn()
        cached_modalities = get_model_output_modalities(conn, body.model)
    except Exception:
        cached_modalities = []

    if cached_modalities and set(cached_modalities) == {"image"}:
        # Cache says the model only outputs images (no text).
        modalities = ["image"]
    else:
        # Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
        # For image-only models that fail with this, the error surfaces to the user.
        modalities = ["image", "text"]

    image_config: dict = {}
    if body.aspect_ratio:
        image_config["aspect_ratio"] = body.aspect_ratio
    if body.image_size:
        image_config["image_size"] = body.image_size

    try:
        result = await openrouter.generate_image_chat(
            model=body.model,
            prompt=body.prompt,
            modalities=modalities,
            image_config=image_config or None,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    try:
        # ``or`` fallbacks tolerate both missing keys and explicit nulls, and
        # an empty "choices" list, so those cases reach the "no images" 502
        # instead of crashing with IndexError/AttributeError.
        choices = result.get("choices") or [{}]
        message = choices[0].get("message") or {}
        revised_prompt = message.get("content") or None
        images = [
            ImageResult(
                url=(item.get("image_url") or {}).get("url"),
                b64_json=None,
                revised_prompt=revised_prompt,
            )
            for item in message.get("images") or []
        ]
        if not images:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="No images returned by model. Verify the model supports image generation.",
            )

        # Persist each image to DB
        user_id = current_user.get("id") or current_user.get("sub")
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        stored: list[ImageResult] = []
        async with get_write_lock():
            conn = get_conn()
            for img in images:
                image_id = None
                # Images without a URL are returned but not stored.
                if img.url:
                    row = conn.execute(
                        """INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
                           VALUES (?, ?, ?, ?, ?) RETURNING id""",
                        [user_id, body.model, body.prompt, img.url, now],
                    ).fetchone()
                    if row:
                        image_id = str(row[0])
                stored.append(ImageResult(
                    url=img.url,
                    b64_json=img.b64_json,
                    revised_prompt=img.revised_prompt,
                    image_id=image_id,
                ))
        return ImageResponse(
            id=result.get("id", ""),
            model=result.get("model", body.model),
            images=stored,
        )
    except HTTPException:
        # Let the deliberate "no images" 502 above pass through untouched.
        raise
    except (KeyError, IndexError, TypeError, AttributeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Unexpected response format: {exc}",
        ) from exc
@router.get("/images")
async def list_generated_images(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """List the calling user's stored generated images, most recent first."""
    owner_id = current_user.get("id") or current_user.get("sub")
    cursor = get_conn().execute(
        """SELECT id, model_id, prompt, image_data, created_at
FROM generated_images
WHERE user_id = ?
ORDER BY created_at DESC""",
        [owner_id],
    )

    payload: list[dict] = []
    for record in cursor.fetchall():
        created_at = record[4]
        payload.append({
            "id": str(record[0]),
            "model_id": record[1],
            "prompt": record[2],
            "image_data": record[3],
            # A missing timestamp is reported as None rather than failing.
            "created_at": created_at.isoformat() if created_at else None,
        })
    return payload
@router.get("/images/{image_id}")
async def get_generated_image(
    image_id: str,
    current_user: dict = Depends(get_current_user),
) -> dict:
    """Return details for a single generated image.

    The lookup is restricted to rows belonging to the calling user, so an id
    owned by someone else is reported the same way as an unknown id.

    Raises:
        HTTPException: 404 when no matching image exists for this user.
    """
    user_id = current_user.get("id") or current_user.get("sub")
    conn = get_conn()
    row = conn.execute(
        """SELECT id, model_id, prompt, image_data, created_at
           FROM generated_images
           WHERE id = ? AND user_id = ?""",
        [image_id, user_id],
    ).fetchone()
    if not row:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Image not found")
    return {
        "id": str(row[0]),
        "model_id": row[1],
        "prompt": row[2],
        "image_data": row[3],
        "created_at": row[4].isoformat() if row[4] else None,
    }
@router.post("/video", response_model=VideoResponse)
async def generate_video(
    body: VideoRequest,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Start a text-to-video generation job and record it for the caller.

    The job returned by OpenRouter is inserted into ``generated_videos`` and
    echoed back as a ``VideoResponse``. Any failure of the OpenRouter call is
    reported as a 502 ``HTTPException``.
    """
    try:
        job = await openrouter.generate_video(
            model=body.model,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except httpx.HTTPStatusError as exc:
        # Upstream answered with an error status: relay its code and body.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter API error: {exc.response.status_code} - {exc.response.text}",
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        )

    owner_id = current_user.get("id") or current_user.get("sub")
    job_id = job.get("id", "")
    polling_url = job.get("polling_url")
    job_status = job.get("status", "pending")
    # Stored as a naive UTC timestamp; created_at and updated_at start equal.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None)
    row_values = [
        owner_id, job_id, body.model, body.prompt,
        polling_url, job_status, timestamp, timestamp,
    ]
    async with get_write_lock():
        get_conn().execute(
            """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            row_values,
        )

    video_urls = job.get("unsigned_urls") or job.get("video_urls")
    first_url = video_urls[0] if video_urls else None
    return VideoResponse(
        id=job_id,
        model=body.model,
        status=job_status,
        polling_url=polling_url,
        video_urls=video_urls,
        video_url=first_url,
        error=job.get("error"),
        metadata=job.get("metadata"),
    )
@router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image(
    body: VideoFromImageRequest,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Start an image-to-video generation job and record it for the caller.

    The source image URL and prompt are sent to OpenRouter; the returned job
    is inserted into ``generated_videos`` and echoed back as a
    ``VideoResponse``. Any failure of the OpenRouter call is reported as a
    502 ``HTTPException``.
    """
    try:
        job = await openrouter.generate_video_from_image(
            model=body.model,
            image_url=body.image_url,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except httpx.HTTPStatusError as exc:
        # Upstream answered with an error status: relay its code and body.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter API error: {exc.response.status_code} - {exc.response.text}",
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        )

    owner_id = current_user.get("id") or current_user.get("sub")
    job_id = job.get("id", "")
    polling_url = job.get("polling_url")
    job_status = job.get("status", "pending")
    # Stored as a naive UTC timestamp; created_at and updated_at start equal.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None)
    row_values = [
        owner_id, job_id, body.model, body.prompt,
        polling_url, job_status, timestamp, timestamp,
    ]
    async with get_write_lock():
        get_conn().execute(
            """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            row_values,
        )

    video_urls = job.get("unsigned_urls") or job.get("video_urls")
    first_url = video_urls[0] if video_urls else None
    return VideoResponse(
        id=job_id,
        model=body.model,
        status=job_status,
        polling_url=polling_url,
        video_urls=video_urls,
        video_url=first_url,
        error=job.get("error"),
        metadata=job.get("metadata"),
    )
@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
    polling_url: str,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Poll status of a video generation job; updates DB row when completed/failed.

    The stored row is only updated when it belongs to the calling user, in
    line with the per-user scoping used by the other handlers in this router.

    Raises:
        HTTPException: 502 when the OpenRouter status call fails.
    """
    try:
        result = await openrouter.poll_video_status(polling_url)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    job_id = result.get("id", "")
    job_status = result.get("status", "processing")
    urls = result.get("unsigned_urls") or result.get("video_urls")
    video_url = (urls or [None])[0]

    # Update DB row for this job when terminal state reached
    if job_status in ("completed", "failed"):
        user_id = current_user.get("id") or current_user.get("sub")
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        async with get_write_lock():
            conn = get_conn()
            # Scope by user_id so a caller cannot overwrite another user's job.
            conn.execute(
                """UPDATE generated_videos
                   SET status = ?, video_url = ?, updated_at = ?
                   WHERE job_id = ? AND user_id = ?""",
                [job_status, video_url, now, job_id, user_id],
            )

    return VideoResponse(
        id=job_id,
        model=result.get("model", ""),
        status=job_status,
        polling_url=result.get("polling_url"),
        video_urls=urls,
        video_url=video_url,
        error=result.get("error"),
        metadata=result.get("metadata"),
    )
@router.get("/videos")
async def list_generated_videos(
    current_user: dict = Depends(get_current_user),
) -> list[dict]:
    """List the calling user's video generation jobs, most recent first."""
    owner_id = current_user.get("id") or current_user.get("sub")
    cursor = get_conn().execute(
        """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
FROM generated_videos
WHERE user_id = ?
ORDER BY created_at DESC""",
        [owner_id],
    )

    payload: list[dict] = []
    for record in cursor.fetchall():
        created_at = record[7]
        payload.append({
            "id": str(record[0]),
            "job_id": record[1],
            "model_id": record[2],
            "prompt": record[3],
            "polling_url": record[4],
            "status": record[5],
            "video_url": record[6],
            # A missing timestamp is reported as None rather than failing.
            "created_at": created_at.isoformat() if created_at else None,
        })
    return payload
@router.get("/videos/{video_id}")
async def get_generated_video(
    video_id: str,
    current_user: dict = Depends(get_current_user),
) -> dict:
    """Fetch one video generation job owned by the calling user.

    Responds with 404 when no row matches both the id and the user.
    """
    owner_id = current_user.get("id") or current_user.get("sub")
    record = get_conn().execute(
        """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at, updated_at
FROM generated_videos
WHERE id = ? AND user_id = ?""",
        [video_id, owner_id],
    ).fetchone()
    if not record:
        raise HTTPException(status_code=404, detail="Video job not found")

    created_at = record[7]
    updated_at = record[8]
    details = {
        "id": str(record[0]),
        "job_id": record[1],
        "model_id": record[2],
        "prompt": record[3],
        "polling_url": record[4],
        "status": record[5],
        "video_url": record[6],
    }
    # Timestamps are serialised to ISO 8601; absent values become None.
    details["created_at"] = created_at.isoformat() if created_at else None
    details["updated_at"] = updated_at.isoformat() if updated_at else None
    return details