"""Generate router: text, image, video, and image-to-video generation.""" from datetime import datetime, timezone import httpx from fastapi import APIRouter, Depends, HTTPException, status from ..db import get_conn, get_write_lock from ..dependencies import get_current_user from ..models.ai import ( ImageRequest, ImageResponse, ImageResult, TextRequest, TextResponse, VideoFromImageRequest, VideoRequest, VideoResponse, ) from ..services import openrouter from ..services.models import get_model_output_modalities router = APIRouter(prefix="/generate", tags=["generate"]) @router.post("/text", response_model=TextResponse) async def generate_text( body: TextRequest, _: dict = Depends(get_current_user), ) -> TextResponse: """Generate text from a prompt using a chat model.""" if body.messages: messages = [{"role": m.role, "content": m.content} for m in body.messages] if body.system_prompt and (not messages or messages[0]["role"] != "system"): messages.insert( 0, {"role": "system", "content": body.system_prompt}) else: messages = [] if body.system_prompt: messages.append({"role": "system", "content": body.system_prompt}) messages.append({"role": "user", "content": body.prompt}) try: result = await openrouter.chat_completion( model=body.model, messages=messages, temperature=body.temperature, max_tokens=body.max_tokens, ) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") try: choice = result["choices"][0] return TextResponse( id=result["id"], model=result.get("model", body.model), content=choice["message"]["content"], usage=result.get("usage"), ) except (KeyError, IndexError) as exc: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Unexpected response format: {exc}") @router.post("/image", response_model=ImageResponse) async def generate_image( body: ImageRequest, current_user: dict = Depends(get_current_user), ) -> ImageResponse: """Generate images from a prompt using the chat completions endpoint. All OpenRouter image models use /chat/completions with a modalities param. Models that output only images use ["image"]; those that also output text use ["image", "text"]. We look this up from the model cache; default to ["image", "text"] when the model is not yet cached. """ # Determine modalities from cache; default ["image", "text"] works for most models try: conn = get_conn() cached_modalities = get_model_output_modalities(conn, body.model) except Exception: cached_modalities = [] if cached_modalities: # If cache says model only outputs image (no text), use ["image"] modalities = ["image"] if set(cached_modalities) == { "image"} else ["image", "text"] else: # Safe default: ["image", "text"]; works for Gemini, GPT-image etc. # For image-only models that fail with this, the error surfaces to the user. modalities = ["image", "text"] image_config: dict = {} if body.aspect_ratio: image_config["aspect_ratio"] = body.aspect_ratio if body.image_size: image_config["image_size"] = body.image_size try: result = await openrouter.generate_image_chat( model=body.model, prompt=body.prompt, modalities=modalities, image_config=image_config if image_config else None, ) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") try: message = result.get("choices", [{}])[0].get("message", {}) images = [] for item in message.get("images", []): img_url = item.get("image_url", {}).get("url") images.append(ImageResult( url=img_url, b64_json=None, revised_prompt=message.get("content") or None, )) if not images: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="No images returned by model. Verify the model supports image generation.", ) # Persist each image to DB user_id = current_user.get("id") or current_user.get("sub") now = datetime.now(timezone.utc).replace(tzinfo=None) stored: list[ImageResult] = [] async with get_write_lock(): conn = get_conn() for img in images: if img.url: row = conn.execute( """INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) VALUES (?, ?, ?, ?, ?) RETURNING id""", [user_id, body.model, body.prompt, img.url, now], ).fetchone() image_id = str(row[0]) if row else None else: image_id = None stored.append(ImageResult( url=img.url, b64_json=img.b64_json, revised_prompt=img.revised_prompt, image_id=image_id, )) return ImageResponse( id=result.get("id", ""), model=result.get("model", body.model), images=stored, ) except HTTPException: raise except (KeyError, TypeError) as exc: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Unexpected response format: {exc}") @router.get("/images") async def list_generated_images( current_user: dict = Depends(get_current_user), ) -> list[dict]: """Return all generated images for the current user, newest first.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() rows = conn.execute( """SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE user_id = ? ORDER BY created_at DESC""", [user_id], ).fetchall() return [ { "id": str(r[0]), "model_id": r[1], "prompt": r[2], "image_data": r[3], "created_at": r[4].isoformat() if r[4] else None, } for r in rows ] @router.get("/images/{image_id}") async def get_generated_image( image_id: str, current_user: dict = Depends(get_current_user), ) -> dict: """Return details for a single generated image.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() row = conn.execute( """SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE id = ? AND user_id = ?""", [image_id, user_id], ).fetchone() if not row: raise HTTPException(status_code=404, detail="Image not found") return { "id": str(row[0]), "model_id": row[1], "prompt": row[2], "image_data": row[3], "created_at": row[4].isoformat() if row[4] else None, } @router.get("/images/{image_id}") async def get_generated_image( image_id: str, current_user: dict = Depends(get_current_user), ) -> dict: """Return details for a single generated image.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() row = conn.execute( """SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE id = ? AND user_id = ?""", [image_id, user_id], ).fetchone() if not row: raise HTTPException(status_code=404, detail="Image not found") return { "id": str(row[0]), "model_id": row[1], "prompt": row[2], "image_data": row[3], "created_at": row[4].isoformat() if row[4] else None, } @router.post("/video", response_model=VideoResponse) async def generate_video( body: VideoRequest, current_user: dict = Depends(get_current_user), ) -> VideoResponse: """Generate a video from a text prompt.""" try: result = await openrouter.generate_video( model=body.model, prompt=body.prompt, duration_seconds=body.duration_seconds, aspect_ratio=body.aspect_ratio, resolution=body.resolution, ) except httpx.HTTPStatusError as exc: detail = ( f"OpenRouter API error: {exc.response.status_code} - {exc.response.text}" ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=detail) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}" ) user_id = current_user.get("id") or current_user.get("sub") job_id = result.get("id", "") polling_url = result.get("polling_url") job_status = result.get("status", "pending") now = datetime.now(timezone.utc).replace(tzinfo=None) async with get_write_lock(): conn = get_conn() conn.execute( """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", [user_id, job_id, body.model, body.prompt, polling_url, job_status, now, now], ) urls = result.get("unsigned_urls") or result.get("video_urls") return VideoResponse( id=job_id, model=body.model, status=job_status, polling_url=polling_url, video_urls=urls, video_url=(urls or [None])[0], error=result.get("error"), metadata=result.get("metadata"), ) @router.post("/video/from-image", response_model=VideoResponse) async def generate_video_from_image( body: VideoFromImageRequest, current_user: dict = Depends(get_current_user), ) -> VideoResponse: """Generate a video from an image and a text prompt.""" try: result = await openrouter.generate_video_from_image( model=body.model, image_url=body.image_url, prompt=body.prompt, duration_seconds=body.duration_seconds, aspect_ratio=body.aspect_ratio, resolution=body.resolution, ) except httpx.HTTPStatusError as exc: detail = ( f"OpenRouter API error: {exc.response.status_code} - {exc.response.text}" ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=detail) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}" ) user_id = current_user.get("id") or current_user.get("sub") job_id = result.get("id", "") polling_url = result.get("polling_url") job_status = result.get("status", "pending") now = datetime.now(timezone.utc).replace(tzinfo=None) async with get_write_lock(): conn = get_conn() conn.execute( """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", [user_id, job_id, body.model, body.prompt, polling_url, job_status, now, now], ) urls = result.get("unsigned_urls") or result.get("video_urls") return VideoResponse( id=job_id, model=body.model, status=job_status, polling_url=polling_url, video_urls=urls, video_url=(urls or [None])[0], error=result.get("error"), metadata=result.get("metadata"), ) @router.get("/video/status", response_model=VideoResponse) async def poll_video_status( polling_url: str, current_user: dict = Depends(get_current_user), ) -> VideoResponse: """Poll status of a video generation job; updates DB row when completed/failed.""" try: result = await openrouter.poll_video_status(polling_url) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") job_status = result.get("status", "processing") urls = result.get("unsigned_urls") or result.get("video_urls") video_url = (urls or [None])[0] # Update DB row for this job when terminal state reached if job_status in ("completed", "failed"): now = datetime.now(timezone.utc).replace(tzinfo=None) async with get_write_lock(): conn = get_conn() conn.execute( """UPDATE generated_videos SET status = ?, video_url = ?, updated_at = ? WHERE job_id = ?""", [job_status, video_url, now, result.get("id", "")], ) return VideoResponse( id=result.get("id", ""), model=result.get("model", ""), status=job_status, polling_url=result.get("polling_url"), video_urls=urls, video_url=video_url, error=result.get("error"), metadata=result.get("metadata"), ) @router.get("/videos") async def list_generated_videos( current_user: dict = Depends(get_current_user), ) -> list[dict]: """Return all generated video jobs for the current user, newest first.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() rows = conn.execute( """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at FROM generated_videos WHERE user_id = ? ORDER BY created_at DESC""", [user_id], ).fetchall() return [ { "id": str(r[0]), "job_id": r[1], "model_id": r[2], "prompt": r[3], "polling_url": r[4], "status": r[5], "video_url": r[6], "created_at": r[7].isoformat() if r[7] else None, } for r in rows ] @router.get("/videos/{video_id}") async def get_generated_video( video_id: str, current_user: dict = Depends(get_current_user), ) -> dict: """Return details for a single video generation job.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() row = conn.execute( """SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at, updated_at FROM generated_videos WHERE id = ? AND user_id = ?""", [video_id, user_id], ).fetchone() if not row: raise HTTPException(status_code=404, detail="Video job not found") return { "id": str(row[0]), "job_id": row[1], "model_id": row[2], "prompt": row[3], "polling_url": row[4], "status": row[5], "video_url": row[6], "created_at": row[7].isoformat() if row[7] else None, "updated_at": row[8].isoformat() if row[8] else None, }