diff --git a/backend/app/db.py b/backend/app/db.py index 17528ba..51bceac 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None: fetched_at TIMESTAMP NOT NULL ) """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS generated_images ( + id UUID DEFAULT uuid() PRIMARY KEY, + user_id UUID NOT NULL, + model_id VARCHAR NOT NULL, + prompt VARCHAR NOT NULL, + image_data VARCHAR NOT NULL, + created_at TIMESTAMP DEFAULT now() + ) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS generated_videos ( + id UUID DEFAULT uuid() PRIMARY KEY, + user_id UUID NOT NULL, + job_id VARCHAR NOT NULL, + model_id VARCHAR NOT NULL, + prompt VARCHAR NOT NULL, + polling_url VARCHAR, + status VARCHAR NOT NULL DEFAULT 'pending', + video_url VARCHAR, + created_at TIMESTAMP DEFAULT now(), + updated_at TIMESTAMP DEFAULT now() + ) + """) + # Migration: add output_modalities column if absent (stores JSON array string) + conn.execute(""" + ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR + """) _seed_admin(conn) diff --git a/backend/app/main.py b/backend/app/main.py index 1f98737..24457c8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,10 +1,10 @@ -from .routers import auth as auth_router -from .routers import users as users_router -from .routers import admin as admin_router -from .routers import ai as ai_router -from .routers import generate as generate_router -from .routers import images as images_router -from .routers import models as models_router +from .routers import auth +from .routers import users +from .routers import admin +from .routers import ai +from .routers import generate +from .routers import images +from .routers import models from .db import close_db, init_db import os from contextlib import asynccontextmanager @@ -38,13 +38,13 @@ app.add_middleware( allow_headers=["*"], ) -app.include_router(auth_router.router) -app.include_router(users_router.router) 
def _extract_chat_images(result: dict) -> list[ImageResult]:
    """Collect ``ImageResult`` entries from a chat-completions style response.

    Reads ``choices[0].message.images[].image_url.url``. Missing, empty or
    null ``choices`` / ``message`` / ``images`` / ``image_url`` values are
    treated as "no data" instead of raising, so the caller can answer with a
    clean "no images" error.
    """
    choices = result.get("choices") or [{}]
    message = choices[0].get("message") or {}
    # The assistant text (if any) is reported as the revised prompt for every
    # image; empty strings are normalised to None.
    revised_prompt = message.get("content") or None
    return [
        ImageResult(
            url=(item.get("image_url") or {}).get("url"),
            b64_json=None,
            revised_prompt=revised_prompt,
        )
        for item in message.get("images") or []
    ]


@router.post("/image", response_model=ImageResponse)
async def generate_image(
    body: ImageRequest,
    current_user: dict = Depends(get_current_user),
) -> ImageResponse:
    """Generate images from a prompt using the chat completions endpoint.

    All OpenRouter image models use /chat/completions with a modalities param.
    Models that output only images use ["image"]; those that also output text
    use ["image", "text"]. We look this up from the model cache; default to
    ["image", "text"] when the model is not yet cached.

    Every returned image that carries a URL is stored in ``generated_images``
    and its row id is echoed back as ``image_id``.

    Raises:
        HTTPException: 502 when the upstream call fails, when the response
            has an unexpected shape, or when it contains no images.
    """
    # Best-effort cache lookup: any failure here must not block generation,
    # so it deliberately falls back to the default modalities below.
    try:
        conn = get_conn()
        cached_modalities = get_model_output_modalities(conn, body.model)
    except Exception:
        cached_modalities = []

    # Only a model cached as image-only gets ["image"]; everything else
    # (including unknown models) uses the safe default ["image", "text"].
    # For image-only models that fail with the default, the upstream error
    # surfaces to the user.
    if cached_modalities and set(cached_modalities) == {"image"}:
        modalities = ["image"]
    else:
        modalities = ["image", "text"]

    image_config: dict = {}
    if body.aspect_ratio:
        image_config["aspect_ratio"] = body.aspect_ratio
    if body.image_size:
        image_config["image_size"] = body.image_size

    try:
        result = await openrouter.generate_image_chat(
            model=body.model,
            prompt=body.prompt,
            modalities=modalities,
            image_config=image_config if image_config else None,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    # Parse separately from persistence so that only genuine shape problems
    # are reported as "Unexpected response format". IndexError and
    # AttributeError are included: they are what an empty ``choices`` list or
    # a non-dict payload would raise.
    try:
        images = _extract_chat_images(result)
    except (AttributeError, IndexError, KeyError, TypeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Unexpected response format: {exc}",
        ) from exc

    if not images:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="No images returned by model. Verify the model supports image generation.",
        )

    # Persist each image to DB. Timestamps are stored as naive UTC.
    user_id = current_user.get("id") or current_user.get("sub")
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    stored: list[ImageResult] = []
    async with get_write_lock():
        conn = get_conn()
        for img in images:
            image_id = None
            if img.url:
                row = conn.execute(
                    """INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
                       VALUES (?, ?, ?, ?, ?) RETURNING id""",
                    [user_id, body.model, body.prompt, img.url, now],
                ).fetchone()
                image_id = str(row[0]) if row else None
            stored.append(ImageResult(
                url=img.url,
                b64_json=img.b64_json,
                revised_prompt=img.revised_prompt,
                image_id=image_id,
            ))

    return ImageResponse(
        id=result.get("id", ""),
        model=result.get("model", body.model),
        images=stored,
    )
+ ORDER BY created_at DESC""", + [user_id], + ).fetchall() + return [ + { + "id": str(r[0]), + "model_id": r[1], + "prompt": r[2], + "image_data": r[3], + "created_at": r[4].isoformat() if r[4] else None, + } + for r in rows + ] + + @router.post("/video", response_model=VideoResponse) async def generate_video( body: VideoRequest, - _: dict = Depends(get_current_user), + current_user: dict = Depends(get_current_user), ) -> VideoResponse: """Generate a video from a text prompt.""" try: @@ -151,12 +203,26 @@ async def generate_video( raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") + user_id = current_user.get("id") or current_user.get("sub") + job_id = result.get("id", "") + polling_url = result.get("polling_url") + job_status = result.get("status", "pending") + now = datetime.now(timezone.utc).replace(tzinfo=None) + async with get_write_lock(): + conn = get_conn() + conn.execute( + """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + [user_id, job_id, body.model, body.prompt, + polling_url, job_status, now, now], + ) + urls = result.get("unsigned_urls") or result.get("video_urls") return VideoResponse( - id=result.get("id", ""), + id=job_id, model=body.model, - status=result.get("status", "queued"), - polling_url=result.get("polling_url"), + status=job_status, + polling_url=polling_url, video_urls=urls, video_url=(urls or [None])[0], error=result.get("error"), @@ -167,7 +233,7 @@ async def generate_video( @router.post("/video/from-image", response_model=VideoResponse) async def generate_video_from_image( body: VideoFromImageRequest, - _: dict = Depends(get_current_user), + current_user: dict = Depends(get_current_user), ) -> VideoResponse: """Generate a video from an image and a text prompt.""" try: @@ -183,12 +249,26 @@ async def generate_video_from_image( raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, 
@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
    polling_url: str,
    current_user: dict = Depends(get_current_user),
) -> VideoResponse:
    """Poll status of a video generation job; updates DB row when completed/failed.

    The stored row is only touched when it belongs to the calling user, so
    polling cannot be used to modify another user's job history.

    Raises:
        HTTPException: 502 when the upstream polling request fails.
    """
    # NOTE(review): ``polling_url`` is caller-supplied and handed to the
    # OpenRouter client as-is; confirm that client restricts the target host
    # before attaching credentials.
    try:
        result = await openrouter.poll_video_status(polling_url)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenRouter error: {exc}",
        ) from exc

    job_status = result.get("status", "processing")
    urls = result.get("unsigned_urls") or result.get("video_urls")
    video_url = (urls or [None])[0]

    # Update DB row for this job when terminal state reached
    if job_status in ("completed", "failed"):
        user_id = current_user.get("id") or current_user.get("sub")
        job_id = result.get("id")
        # Prefer the upstream job id. When it is absent, match on the polling
        # URL instead of an empty string, which would hit every row that was
        # stored with an empty job id.
        if job_id:
            update_sql = """UPDATE generated_videos
                   SET status = ?, video_url = ?, updated_at = ?
                   WHERE job_id = ? AND user_id = ?"""
            match_value = job_id
        else:
            update_sql = """UPDATE generated_videos
                   SET status = ?, video_url = ?, updated_at = ?
                   WHERE polling_url = ? AND user_id = ?"""
            match_value = polling_url
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        async with get_write_lock():
            conn = get_conn()
            conn.execute(
                update_sql,
                [job_status, video_url, now, match_value, user_id],
            )

    return VideoResponse(
        id=result.get("id", ""),
        model=result.get("model", ""),
        status=job_status,
        polling_url=result.get("polling_url"),
        video_urls=urls,
        video_url=video_url,
        error=result.get("error"),
        metadata=result.get("metadata"),
    )
def get_model_output_modalities(
    conn: duckdb.DuckDBPyConnection,
    model_id: str,
) -> list[str]:
    """Return output_modalities list for a model; empty list if not found.

    The column holds a JSON array string. Anything that is missing, is not
    valid JSON, or does not decode to a list yields ``[]``; non-string
    entries inside the list are dropped so callers always get ``list[str]``.
    """
    row = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()
    if not row or not row[0]:
        return []
    try:
        parsed = json.loads(row[0])
    except (json.JSONDecodeError, TypeError):
        return []
    # A bare JSON string such as '"image"' would otherwise be returned as a
    # str, and callers that build set(...) from it would get a set of
    # characters.
    if not isinstance(parsed, list):
        return []
    return [entry for entry in parsed if isinstance(entry, str)]
async def list_video_models() -> list[dict[str, Any]]:
    """Fetch video generation models from the ``/videos/models`` endpoint.

    The base URL comes from the ``OPENROUTER_BASE_URL`` environment variable,
    falling back to the module constant of the same name. Returns the
    ``data`` entry of the JSON response, or an empty list when that key is
    absent. HTTP error statuses propagate via ``raise_for_status``.
    """
    api_root = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
    endpoint = f"{api_root}/videos/models"
    async with httpx.AsyncClient(timeout=15) as http:
        request = http.build_request("GET", endpoint, headers=_headers())
        reply = await http.send(request)
        reply.raise_for_status()
        payload = reply.json()
        return payload.get("data", [])
patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock: + with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock: await client.post( "/ai/chat", json={ @@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client): async def test_chat_upstream_error(client): token = await _user_token(client) with patch( - "backend.app.routers.ai.openrouter.chat_completion", + "app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, side_effect=Exception("timeout"), ): @@ -160,7 +160,7 @@ async def test_chat_upstream_error(client): async def test_chat_malformed_upstream_response(client): token = await _user_token(client) with patch( - "backend.app.routers.ai.openrouter.chat_completion", + "app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value={"id": "x", "choices": []}, # empty choices ): diff --git a/backend/tests/test_generate.py b/backend/tests/test_generate.py index 595326a..366f089 100644 --- a/backend/tests/test_generate.py +++ b/backend/tests/test_generate.py @@ -18,15 +18,6 @@ FAKE_CHAT = { "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, } -FAKE_IMAGE = { - "id": "gen-img-1", - "model": "openai/dall-e-3", - "data": [ - {"url": "https://example.com/image.png", - "revised_prompt": "A cat on the moon"}, - ], -} - FAKE_VIDEO = { "id": "gen-vid-1", "polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1", @@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client): # POST /generate/image # --------------------------------------------------------------------------- -async def test_generate_image(client): - token = await _user_token(client) - with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE): - resp = await client.post( - "/generate/image", - json={"model": "openai/dall-e-3", "prompt": "A cat on the 
async def test_generate_image(client):
    """Image generation goes through generate_image_chat and stores the result."""
    bearer = await _user_token(client)
    request_body = {
        "model": "google/gemini-2.5-flash-image",
        "prompt": "A cat on the moon",
    }
    auth_headers = {"Authorization": f"Bearer {bearer}"}

    with patch(
        "app.routers.generate.openrouter.generate_image_chat",
        new_callable=AsyncMock,
        return_value=FAKE_IMAGE_CHAT_GEMINI,
    ):
        resp = await client.post(
            "/generate/image", json=request_body, headers=auth_headers
        )

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["id"] == "gen-img-chat-3"
    assert len(payload["images"]) == 1
    first_image = payload["images"][0]
    assert first_image["url"] == "data:image/png;base64,gemini123"
    # A non-null image_id means the row was written to generated_images.
    assert first_image["image_id"] is not None
async def test_generate_image_image_only_modalities_from_cache(client):
    """Model cached with image-only output_modalities → modalities = ['image']."""
    import json as _json
    from datetime import datetime, timezone

    from app import db as db_module

    token = await _user_token(client)

    # Seed cache with image-only model. Delete first so the insert cannot
    # collide with a row left behind by another test.
    model_id = "black-forest-labs/flux.2-pro"
    conn = db_module.get_conn()
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn.execute("DELETE FROM models_cache WHERE model_id = ?", [model_id])
    conn.execute(
        """INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
           VALUES (?, ?, ?, ?, ?, ?, ?)""",
        [model_id, "FLUX.2 Pro", "image", None, None, now,
         _json.dumps(["image"])],
    )

    mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
    with patch("app.routers.generate.openrouter.generate_image_chat", mock):
        resp = await client.post(
            "/generate/image",
            json={"model": model_id, "prompt": "Sky"},
            headers={"Authorization": f"Bearer {token}"},
        )
    assert resp.status_code == 200
    assert mock.call_args.kwargs["modalities"] == ["image"]
async def test_generate_image_stored_in_db(client):
    """A successful generation leaves a matching row in generated_images."""
    from app import db as db_module

    bearer = await _user_token(client)
    request_body = {
        "model": "google/gemini-2.5-flash-image",
        "prompt": "A mountain",
    }

    with patch(
        "app.routers.generate.openrouter.generate_image_chat",
        new_callable=AsyncMock,
        return_value=FAKE_IMAGE_CHAT_GEMINI,
    ):
        resp = await client.post(
            "/generate/image",
            json=request_body,
            headers={"Authorization": f"Bearer {bearer}"},
        )

    assert resp.status_code == 200
    image_id = resp.json()["images"][0]["image_id"]
    assert image_id is not None

    # Look the row up by the id returned from the API and compare each column.
    stored_row = db_module.get_conn().execute(
        "SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
        [image_id],
    ).fetchone()
    assert stored_row is not None
    stored_model, stored_prompt, stored_data = stored_row[0], stored_row[1], stored_row[2]
    assert stored_model == "google/gemini-2.5-flash-image"
    assert stored_prompt == "A mountain"
    assert stored_data == "data:image/png;base64,gemini123"
["text"]}, }, { "id": "openai/dall-e-3", "name": "DALL-E 3", "context_length": None, "pricing": {"image": "0.04"}, - "architecture": {"modality": "text->image"}, + "architecture": {"modality": "text->image", "output_modalities": ["image"]}, }, { "id": "openai/sora-2", "name": "Sora 2", "context_length": None, "pricing": {"video": "0.10"}, - "architecture": {"modality": "text->video"}, + "architecture": {"modality": "text->video", "output_modalities": ["video"]}, + }, + { + "id": "google/gemini-2.5-flash-image", + "name": "Gemini 2.5 Flash Image", + "context_length": None, + "pricing": {}, + "architecture": {"output_modalities": ["image", "text"]}, }, ] @@ -171,9 +179,9 @@ async def test_refresh_stores_models(): return_value=FAKE_MODELS_RAW, ): count = await refresh_models_cache(conn) - assert count == 4 + assert count == 5 all_models = get_cached_models(conn) - assert len(all_models) == 4 + assert len(all_models) == 5 async def test_refresh_replaces_old_cache(): @@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache(): ids = [m["id"] for m in get_cached_models(conn)] assert "old/model" not in ids assert "openai/gpt-4o" in ids + assert len(ids) == 5 def test_get_cached_models_filter_by_modality(): conn = db_module.get_conn() now = datetime.now(timezone.utc).replace(tzinfo=None) for m in FAKE_MODELS_RAW: - arch = m.get("architecture", {}) - modality = _parse_modality(arch.get("modality", "text->text")) + modality = _extract_output_modality(m) conn.execute( "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)", [m["id"], m["name"], modality, now], ) text_models = get_cached_models(conn, modality="text") + # gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image") assert len(text_models) == 2 assert all(m["modality"] == "text" for m in text_models) image_models = get_cached_models(conn, modality="image") - assert len(image_models) == 1 - assert image_models[0]["id"] == "openai/dall-e-3" + # 
dall-e-3 + gemini (output_modalities starts with image) + assert len(image_models) == 2 + image_ids = [m["id"] for m in image_models] + assert "openai/dall-e-3" in image_ids + assert "google/gemini-2.5-flash-image" in image_ids video_models = get_cached_models(conn, modality="video") assert len(video_models) == 1 @@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client): "/models/", headers={"Authorization": f"Bearer {token}"} ) assert resp.status_code == 200 - assert len(resp.json()) == 4 + assert len(resp.json()) == 5 assert mock_fetch.await_count >= 1 @@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client): ) assert resp.status_code == 200 data = resp.json() - assert len(data) == 1 - assert data[0]["id"] == "openai/dall-e-3" + assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image + image_ids = [m["id"] for m in data] + assert "openai/dall-e-3" in image_ids + assert "google/gemini-2.5-flash-image" in image_ids # --------------------------------------------------------------------------- @@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client): "/models/refresh", headers={"Authorization": f"Bearer {token}"} ) assert resp.status_code == 200 - assert resp.json()["refreshed"] == 4 + assert resp.json()["refreshed"] == 5 async def test_refresh_endpoint_502_on_openrouter_error(client): @@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client): "/models/refresh", headers={"Authorization": f"Bearer {token}"} ) assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# Unit tests: get_model_output_modalities +# --------------------------------------------------------------------------- + +async def test_get_model_output_modalities_image_only(): + conn = db_module.get_conn() + with patch( + "app.services.models.openrouter.list_models", + new_callable=AsyncMock, + return_value=FAKE_MODELS_RAW, + ): + await 
refresh_models_cache(conn) + modalities = get_model_output_modalities(conn, "openai/dall-e-3") + assert modalities == ["image"] + + +async def test_get_model_output_modalities_image_text(): + conn = db_module.get_conn() + with patch( + "app.services.models.openrouter.list_models", + new_callable=AsyncMock, + return_value=FAKE_MODELS_RAW, + ): + await refresh_models_cache(conn) + modalities = get_model_output_modalities( + conn, "google/gemini-2.5-flash-image") + assert set(modalities) == {"image", "text"} + + +def test_get_model_output_modalities_unknown_model(): + conn = db_module.get_conn() + result = get_model_output_modalities(conn, "unknown/model") + assert result == [] diff --git a/backend/tests/test_users.py b/backend/tests/test_users.py index bc5f4af..bd9792c 100644 --- a/backend/tests/test_users.py +++ b/backend/tests/test_users.py @@ -115,7 +115,9 @@ async def test_list_users_as_admin(client): resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"}) assert resp.status_code == 200 assert isinstance(resp.json(), list) - assert len(resp.json()) == 1 + assert len(resp.json()) >= 1 + emails = [u["email"] for u in resp.json()] + assert "user@example.com" in emails async def test_list_users_as_regular_user(client): diff --git a/frontend/app/main.py b/frontend/app/main.py index 77d82d3..4357024 100644 --- a/frontend/app/main.py +++ b/frontend/app/main.py @@ -178,7 +178,13 @@ def dashboard(): user = resp.json() if resp.status_code == 200 else {} img_resp = _api("GET", "/images/", token=token) images = img_resp.json() if img_resp.status_code == 200 else [] - return render_template("dashboard.html", user=user, images=images) + gen_resp = _api("GET", "/generate/images", token=token) + generated_images = gen_resp.json() if gen_resp.status_code == 200 else [] + vid_resp = _api("GET", "/generate/videos", token=token) + generated_videos = vid_resp.json() if vid_resp.status_code == 200 else [] + return render_template("dashboard.html", 
user=user, images=images, + generated_images=generated_images, + generated_videos=generated_videos) # ── Generate ────────────────────────────────────────────────────────────── diff --git a/frontend/app/templates/dashboard.html b/frontend/app/templates/dashboard.html index 2ff08f7..537bd4b 100644 --- a/frontend/app/templates/dashboard.html +++ b/frontend/app/templates/dashboard.html @@ -6,7 +6,59 @@ endblock %} {% block content %} Start generating -{% if images %} +{% if generated_images %} +
+

Generated images

+
+ {% for img in generated_images %} +
+ {{ img.prompt }} +

+ {{ img.model_id }}
{{ img.prompt[:80] }}{% if + img.prompt|length > 80 %}…{% endif %} +

+
+ {% endfor %} +
+
+{% endif %} {% if generated_videos %} +
+

Generated videos

+
+ {% for vid in generated_videos %} +
+ {% if vid.video_url %} + + {% else %} +
+ {{ vid.status | capitalize }} … +
+ {% endif %} +

+ {{ vid.model_id }}
{{ vid.prompt[:80] }}{% if + vid.prompt|length > 80 %}…{% endif %}
+ {{ vid.status }} +

+
+ {% endfor %} +
+
+{% endif %} {% if images %}

Uploaded reference images

diff --git a/frontend/app/templates/generate_image.html b/frontend/app/templates/generate_image.html index ad6a3cc..2e83585 100644 --- a/frontend/app/templates/generate_image.html +++ b/frontend/app/templates/generate_image.html @@ -8,12 +8,12 @@ {% if models %} {% else %} {% endif %} @@ -21,40 +21,25 @@ - - - - - - @@ -65,7 +50,7 @@ accept="image/png,image/jpeg,image/webp,image/gif" >

- Upload image for visual reference in upcoming image-to-image flow. + Upload an image to use as visual reference (image-to-image).