diff --git a/backend/app/models/ai.py b/backend/app/models/ai.py index c5697ff..503862b 100644 --- a/backend/app/models/ai.py +++ b/backend/app/models/ai.py @@ -53,6 +53,8 @@ class ImageRequest(BaseModel): prompt: str n: int = 1 size: str = "1024x1024" + aspect_ratio: str | None = None # e.g. "1:1", "16:9", "9:16" + image_size: str | None = None # e.g. "0.5K", "1K", "2K", "4K" class ImageResult(BaseModel): diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py index 4a75ec2..94c2abc 100644 --- a/backend/app/routers/generate.py +++ b/backend/app/routers/generate.py @@ -58,31 +58,69 @@ async def generate_image( _: dict = Depends(get_current_user), ) -> ImageResponse: """Generate images from a text prompt.""" + # Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E) + chat_models = {"black-forest-labs/flux.2-klein-4b", + "openai/gpt-5-image-mini"} + is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \ + any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"]) + try: - result = await openrouter.generate_image( - model=body.model, - prompt=body.prompt, - n=body.n, - size=body.size, - ) + if is_chat_model: + image_config = {} + if body.aspect_ratio: + image_config["aspect_ratio"] = body.aspect_ratio + if body.image_size: + image_config["image_size"] = body.image_size + result = await openrouter.generate_image_chat( + model=body.model, + prompt=body.prompt, + modalities=[ + "image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"], + image_config=image_config if image_config else None, + ) + else: + result = await openrouter.generate_image( + model=body.model, + prompt=body.prompt, + n=body.n, + size=body.size, + ) except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") try: - images = [ - ImageResult( - url=item.get("url"), - b64_json=item.get("b64_json"), - revised_prompt=item.get("revised_prompt"), + if is_chat_model: + # Chat completions response: choices[0].message.images[].image_url.url + images = [] + message = result.get("choices", [{}])[0].get("message", {}) + for item in message.get("images", []): + img_url = item.get("image_url", {}).get("url") + images.append(ImageResult( + url=img_url, + b64_json=None, + revised_prompt=message.get("content"), + )) + return ImageResponse( + id=result.get("id", ""), + model=result.get("model", body.model), + images=images, + ) + else: + # /images/generations response: data[].url + images = [ + ImageResult( + url=item.get("url"), + b64_json=item.get("b64_json"), + revised_prompt=item.get("revised_prompt"), + ) + for item in result.get("data", []) + ] + return ImageResponse( + id=result.get("id", ""), + model=result.get("model", body.model), + images=images, ) - for item in result.get("data", []) - ] - return ImageResponse( - id=result.get("id", ""), - model=result.get("model", body.model), - images=images, - ) except (KeyError, TypeError) as exc: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Unexpected response format: {exc}") diff --git a/backend/app/services/openrouter.py b/backend/app/services/openrouter.py index 9b5e8f3..4843603 100644 --- a/backend/app/services/openrouter.py +++ b/backend/app/services/openrouter.py @@ -139,3 +139,34 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]: response = await client.send(resp) response.raise_for_status() return response.json() + + +async def generate_image_chat( + model: str, + prompt: str, + modalities: list[str] | None = None, + image_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Request image generation via Chat Completions with modalities. + + Used by models like FLUX.2 Klein 4B and GPT-5 Image Mini that output + images through the chat completions endpoint rather than /images/generations. + """ + base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL) + if modalities is None: + # Image-only models (FLUX) vs multimodal (GPT-5 Image Mini) + modalities = ["image"] + payload: dict[str, Any] = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "modalities": modalities, + } + if image_config: + payload["image_config"] = image_config + async with httpx.AsyncClient(timeout=120) as client: + resp = client.build_request( + "POST", f"{base_url}/chat/completions", headers=_headers(), json=payload + ) + response = await client.send(resp) + response.raise_for_status() + return response.json() diff --git a/backend/tests/test_generate.py b/backend/tests/test_generate.py index 4dd0b5c..1f717c7 100644 --- a/backend/tests/test_generate.py +++ b/backend/tests/test_generate.py @@ -149,6 +149,105 @@ async def test_generate_image_upstream_error(client): assert resp.status_code == 502 +# --- Chat-based image generation (FLUX, GPT-5 Image Mini) --- + +FAKE_IMAGE_CHAT_FLUX = { + "id": "gen-img-chat-1", + "model": "black-forest-labs/flux.2-klein-4b", + "choices": [{ + "message": { + "role": "assistant", + "content": "Here is your generated image.", + "images": [{ + "type": "image_url", + "image_url": {"url": "data:image/png;base64,abc123"}, + }], + } + }], +} + +FAKE_IMAGE_CHAT_GPT5 = { + "id": "gen-img-chat-2", + "model": "openai/gpt-5-image-mini", + "choices": [{ + "message": { + "role": "assistant", + "content": "Generated image.", + "images": [{ + "type": "image_url", + "image_url": {"url": "data:image/png;base64,xyz789"}, + }], + } + }], +} + + +async def test_generate_image_chat_flux(client): + token = await _user_token(client) + with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX): + resp = await client.post( + "/generate/image", + json={"model": "black-forest-labs/flux.2-klein-4b", + "prompt": "A sunset"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == "gen-img-chat-1" + assert len(data["images"]) == 1 + assert data["images"][0]["url"] == "data:image/png;base64,abc123" + + +async def test_generate_image_chat_gpt5_image_mini(client): + token = await _user_token(client) + with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5): + resp = await client.post( + "/generate/image", + json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["model"] == "openai/gpt-5-image-mini" + assert len(data["images"]) == 1 + + +async def test_generate_image_chat_with_image_config(client): + token = await _user_token(client) + mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX) + with patch("backend.app.routers.generate.openrouter.generate_image_chat", mock): + await client.post( + "/generate/image", + json={ + "model": "black-forest-labs/flux.2-klein-4b", + "prompt": "A landscape", + "aspect_ratio": "16:9", + "image_size": "2K", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + call_kwargs = mock.call_args.kwargs + assert call_kwargs["image_config"]["aspect_ratio"] == "16:9" + assert call_kwargs["image_config"]["image_size"] == "2K" + assert call_kwargs["modalities"] == ["image"] + + +async def test_generate_image_chat_unauthenticated(client): + resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"}) + assert resp.status_code == 401 + + +async def test_generate_image_chat_upstream_error(client): + token = await _user_token(client) + with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")): + resp = await client.post( + "/generate/image", + json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 502 + + # --------------------------------------------------------------------------- # POST /generate/video # --------------------------------------------------------------------------- diff --git a/frontend/app/main.py b/frontend/app/main.py index 5017329..b1da228 100644 --- a/frontend/app/main.py +++ b/frontend/app/main.py @@ -158,6 +158,8 @@ def generate_image(): "prompt": request.form.get("prompt", "").strip(), "n": int(request.form.get("n", 1)), "size": request.form.get("size", "1024x1024"), + "aspect_ratio": request.form.get("aspect_ratio", "").strip() or None, + "image_size": request.form.get("image_size", "").strip() or None, }) if resp.status_code == 200: result = resp.json() diff --git a/frontend/app/templates/generate_image.html b/frontend/app/templates/generate_image.html index 0a15db0..786b7e5 100644 --- a/frontend/app/templates/generate_image.html +++ b/frontend/app/templates/generate_image.html @@ -21,6 +21,27 @@ + + + + + +