"""Generate router: text, image, video, and image-to-video generation."""

from fastapi import APIRouter, Depends, HTTPException, status

from ..dependencies import get_current_user
from ..models.ai import (
    ImageRequest,
    ImageResponse,
    ImageResult,
    TextRequest,
    TextResponse,
    VideoFromImageRequest,
    VideoRequest,
    VideoResponse,
)
from ..services import openrouter

router = APIRouter(prefix="/generate", tags=["generate"])

# Case-insensitive substrings identifying image models served through the
# chat-completions endpoint (e.g. "black-forest-labs/flux.2-klein-4b",
# "openai/gpt-5-image-mini") instead of /images/generations (e.g. DALL-E).
_CHAT_IMAGE_MODEL_MARKERS = ("flux", "gpt-5-image-mini")

# Chat image models matching this substring are asked for both the "image"
# and "text" modalities; all other chat image models request "image" only.
_IMAGE_AND_TEXT_MODEL_MARKER = "gpt-5-image-mini"

# Exceptions that indicate an upstream payload did not have the expected shape
# (missing keys, empty lists, nulls or non-dicts where dicts were expected).
_MALFORMED_RESPONSE_ERRORS = (AttributeError, KeyError, IndexError, TypeError)


def _upstream_error(exc: Exception) -> HTTPException:
    """Build the 502 raised when the OpenRouter call itself fails."""
    return HTTPException(
        status_code=status.HTTP_502_BAD_GATEWAY,
        detail=f"OpenRouter error: {exc}",
    )


def _bad_format_error(exc: Exception) -> HTTPException:
    """Build the 502 raised when an OpenRouter response has an unexpected shape."""
    return HTTPException(
        status_code=status.HTTP_502_BAD_GATEWAY,
        detail=f"Unexpected response format: {exc}",
    )


def _is_chat_image_model(model: str) -> bool:
    """Return True if *model* must be called via chat completions."""
    lowered = model.lower()
    return any(marker in lowered for marker in _CHAT_IMAGE_MODEL_MARKERS)


def _build_messages(body: TextRequest) -> list:
    """Translate a TextRequest into an OpenAI-style chat ``messages`` list.

    If ``body.messages`` is non-empty it is used as the conversation, and
    ``body.prompt`` is not added (NOTE(review): confirm this is intended). In
    that case ``body.system_prompt``, when set, is prepended unless the
    conversation already starts with a system message. Otherwise the list is
    the optional system prompt followed by ``body.prompt`` as a user message.
    """
    system_message = {"role": "system", "content": body.system_prompt}
    if body.messages:
        messages = [{"role": m.role, "content": m.content} for m in body.messages]
        # body.messages is truthy here, so messages[0] exists.
        if body.system_prompt and messages[0]["role"] != "system":
            messages.insert(0, system_message)
        return messages

    messages = [system_message] if body.system_prompt else []
    messages.append({"role": "user", "content": body.prompt})
    return messages


def _video_response(result, *, model=None, default_status: str) -> VideoResponse:
    """Convert an OpenRouter video-job payload into a VideoResponse.

    Args:
        result: Parsed upstream payload (expected to be a dict).
        model: Model name to report. When None, it is read from
            ``result["model"]`` (default ``""``).
        default_status: Status reported when the payload has no ``status``.

    Raises:
        HTTPException: 502 if the payload does not have the expected shape.
    """
    try:
        # Some payloads expose "unsigned_urls", others "video_urls".
        urls = result.get("unsigned_urls") or result.get("video_urls")
        return VideoResponse(
            id=result.get("id", ""),
            model=result.get("model", "") if model is None else model,
            status=result.get("status", default_status),
            polling_url=result.get("polling_url"),
            video_urls=urls,
            video_url=(urls or [None])[0],
            error=result.get("error"),
            metadata=result.get("metadata"),
        )
    except _MALFORMED_RESPONSE_ERRORS as exc:
        raise _bad_format_error(exc) from exc


@router.post("/text", response_model=TextResponse)
async def generate_text(
    body: TextRequest,
    _: dict = Depends(get_current_user),
) -> TextResponse:
    """Generate text from a prompt using a chat model."""
    messages = _build_messages(body)

    try:
        result = await openrouter.chat_completion(
            model=body.model,
            messages=messages,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
    except Exception as exc:
        raise _upstream_error(exc) from exc

    try:
        choice = result["choices"][0]
        return TextResponse(
            id=result["id"],
            model=result.get("model", body.model),
            content=choice["message"]["content"],
            usage=result.get("usage"),
        )
    except _MALFORMED_RESPONSE_ERRORS as exc:
        raise _bad_format_error(exc) from exc


@router.post("/image", response_model=ImageResponse)
async def generate_image(
    body: ImageRequest,
    _: dict = Depends(get_current_user),
) -> ImageResponse:
    """Generate images from a text prompt."""
    # Chat-completion image models (FLUX, GPT-5 Image Mini) use a different
    # request and response shape than /images/generations models (DALL-E).
    is_chat_model = _is_chat_image_model(body.model)

    try:
        if is_chat_model:
            image_config = {}
            if body.aspect_ratio:
                image_config["aspect_ratio"] = body.aspect_ratio
            if body.image_size:
                image_config["image_size"] = body.image_size
            wants_text = _IMAGE_AND_TEXT_MODEL_MARKER in body.model.lower()
            result = await openrouter.generate_image_chat(
                model=body.model,
                prompt=body.prompt,
                modalities=["image", "text"] if wants_text else ["image"],
                image_config=image_config or None,
            )
        else:
            result = await openrouter.generate_image(
                model=body.model,
                prompt=body.prompt,
                n=body.n,
                size=body.size,
            )
    except Exception as exc:
        raise _upstream_error(exc) from exc

    try:
        if is_chat_model:
            # Chat completions: choices[0].message.images[].image_url.url
            message = result.get("choices", [{}])[0].get("message", {})
            images = [
                ImageResult(
                    url=item.get("image_url", {}).get("url"),
                    b64_json=None,
                    revised_prompt=message.get("content"),
                )
                for item in message.get("images", [])
            ]
        else:
            # /images/generations: data[].{url, b64_json, revised_prompt}
            images = [
                ImageResult(
                    url=item.get("url"),
                    b64_json=item.get("b64_json"),
                    revised_prompt=item.get("revised_prompt"),
                )
                for item in result.get("data", [])
            ]
        return ImageResponse(
            id=result.get("id", ""),
            model=result.get("model", body.model),
            images=images,
        )
    except _MALFORMED_RESPONSE_ERRORS as exc:
        raise _bad_format_error(exc) from exc


@router.post("/video", response_model=VideoResponse)
async def generate_video(
    body: VideoRequest,
    _: dict = Depends(get_current_user),
) -> VideoResponse:
    """Generate a video from a text prompt."""
    try:
        result = await openrouter.generate_video(
            model=body.model,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except Exception as exc:
        raise _upstream_error(exc) from exc

    return _video_response(result, model=body.model, default_status="queued")


@router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image(
    body: VideoFromImageRequest,
    _: dict = Depends(get_current_user),
) -> VideoResponse:
    """Generate a video from an image and a text prompt."""
    try:
        result = await openrouter.generate_video_from_image(
            model=body.model,
            image_url=body.image_url,
            prompt=body.prompt,
            duration_seconds=body.duration_seconds,
            aspect_ratio=body.aspect_ratio,
            resolution=body.resolution,
        )
    except Exception as exc:
        raise _upstream_error(exc) from exc

    return _video_response(result, model=body.model, default_status="queued")


@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
    polling_url: str,
    _: dict = Depends(get_current_user),
) -> VideoResponse:
    """Poll the status of a video generation job via its polling_url."""
    # NOTE(review): polling_url is taken verbatim from the client. Unless
    # openrouter.poll_video_status restricts the target host, any authenticated
    # user can make this server request arbitrary URLs (SSRF), possibly with
    # OpenRouter credentials attached. Confirm, and consider allow-listing the
    # OpenRouter host.
    try:
        result = await openrouter.poll_video_status(polling_url)
    except Exception as exc:
        raise _upstream_error(exc) from exc

    return _video_response(result, default_status="processing")