feat: enhance model caching and output modalities handling

- Updated `refresh_models_cache` to include output modalities in the models cache.
- Added `get_model_output_modalities` function to retrieve output modalities for a specific model.
- Modified tests to cover new functionality for output modalities.
- Updated OpenRouter video generation functions to support audio generation and improved error handling.
- Enhanced dashboard to display generated images and videos.
- Refactored frontend templates to accommodate new data structures for generated content.
- Adjusted tests to validate changes in model handling and dashboard rendering.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 15:20:48 +02:00
parent 3d32e6df74
commit 712c556032
15 changed files with 618 additions and 219 deletions
+28
View File
@@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
fetched_at TIMESTAMP NOT NULL fetched_at TIMESTAMP NOT NULL
) )
""") """)
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_images (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
image_data VARCHAR NOT NULL,
created_at TIMESTAMP DEFAULT now()
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_videos (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
job_id VARCHAR NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
polling_url VARCHAR,
status VARCHAR NOT NULL DEFAULT 'pending',
video_url VARCHAR,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now()
)
""")
# Migration: add output_modalities column if absent (stores JSON array string)
conn.execute("""
ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR
""")
_seed_admin(conn) _seed_admin(conn)
+14 -14
View File
@@ -1,10 +1,10 @@
from .routers import auth as auth_router from .routers import auth
from .routers import users as users_router from .routers import users
from .routers import admin as admin_router from .routers import admin
from .routers import ai as ai_router from .routers import ai
from .routers import generate as generate_router from .routers import generate
from .routers import images as images_router from .routers import images
from .routers import models as models_router from .routers import models
from .db import close_db, init_db from .db import close_db, init_db
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -38,13 +38,13 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(auth_router.router) app.include_router(auth.router)
app.include_router(users_router.router) app.include_router(users.router)
app.include_router(admin_router.router) app.include_router(admin.router)
app.include_router(ai_router.router) app.include_router(ai.router)
app.include_router(generate_router.router) app.include_router(generate.router)
app.include_router(images_router.router) app.include_router(images.router)
app.include_router(models_router.router) app.include_router(models.router)
@app.get("/health", tags=["health"]) @app.get("/health", tags=["health"])
+1
View File
@@ -62,6 +62,7 @@ class ImageResult(BaseModel):
url: str | None = None url: str | None = None
b64_json: str | None = None b64_json: str | None = None
revised_prompt: str | None = None revised_prompt: str | None = None
image_id: str | None = None # UUID of stored row in generated_images
class ImageResponse(BaseModel): class ImageResponse(BaseModel):
+193 -69
View File
@@ -1,6 +1,9 @@
"""Generate router: text, image, video, and image-to-video generation.""" """Generate router: text, image, video, and image-to-video generation."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..db import get_conn, get_write_lock
from ..dependencies import get_current_user from ..dependencies import get_current_user
from ..models.ai import ( from ..models.ai import (
ImageRequest, ImageRequest,
@@ -13,6 +16,7 @@ from ..models.ai import (
VideoResponse, VideoResponse,
) )
from ..services import openrouter from ..services import openrouter
from ..services.models import get_model_output_modalities
router = APIRouter(prefix="/generate", tags=["generate"]) router = APIRouter(prefix="/generate", tags=["generate"])
@@ -62,81 +66,129 @@ async def generate_text(
@router.post("/image", response_model=ImageResponse) @router.post("/image", response_model=ImageResponse)
async def generate_image( async def generate_image(
body: ImageRequest, body: ImageRequest,
_: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
) -> ImageResponse: ) -> ImageResponse:
"""Generate images from a text prompt.""" """Generate images from a prompt using the chat completions endpoint.
# Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E)
chat_models = {"black-forest-labs/flux.2-klein-4b", All OpenRouter image models use /chat/completions with a modalities param.
"openai/gpt-5-image-mini"} Models that output only images use ["image"]; those that also output text
is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \ use ["image", "text"]. We look this up from the model cache; default to
any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"]) ["image", "text"] when the model is not yet cached.
"""
# Determine modalities from cache; default ["image", "text"] works for most models
try:
conn = get_conn()
cached_modalities = get_model_output_modalities(conn, body.model)
except Exception:
cached_modalities = []
if cached_modalities:
# If cache says model only outputs image (no text), use ["image"]
modalities = ["image"] if set(cached_modalities) == {
"image"} else ["image", "text"]
else:
# Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
# For image-only models that fail with this, the error surfaces to the user.
modalities = ["image", "text"]
image_config: dict = {}
if body.aspect_ratio:
image_config["aspect_ratio"] = body.aspect_ratio
if body.image_size:
image_config["image_size"] = body.image_size
try: try:
if is_chat_model: result = await openrouter.generate_image_chat(
image_config = {} model=body.model,
if body.aspect_ratio: prompt=body.prompt,
image_config["aspect_ratio"] = body.aspect_ratio modalities=modalities,
if body.image_size: image_config=image_config if image_config else None,
image_config["image_size"] = body.image_size )
result = await openrouter.generate_image_chat(
model=body.model,
prompt=body.prompt,
modalities=[
"image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"],
image_config=image_config if image_config else None,
)
else:
result = await openrouter.generate_image(
model=body.model,
prompt=body.prompt,
n=body.n,
size=body.size,
)
except Exception as exc: except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
try: try:
if is_chat_model: message = result.get("choices", [{}])[0].get("message", {})
# Chat completions response: choices[0].message.images[].image_url.url images = []
images = [] for item in message.get("images", []):
message = result.get("choices", [{}])[0].get("message", {}) img_url = item.get("image_url", {}).get("url")
for item in message.get("images", []): images.append(ImageResult(
img_url = item.get("image_url", {}).get("url") url=img_url,
images.append(ImageResult( b64_json=None,
url=img_url, revised_prompt=message.get("content") or None,
b64_json=None, ))
revised_prompt=message.get("content"), if not images:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="No images returned by model. Verify the model supports image generation.",
)
# Persist each image to DB
user_id = current_user.get("id") or current_user.get("sub")
now = datetime.now(timezone.utc).replace(tzinfo=None)
stored: list[ImageResult] = []
async with get_write_lock():
conn = get_conn()
for img in images:
if img.url:
row = conn.execute(
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
VALUES (?, ?, ?, ?, ?) RETURNING id""",
[user_id, body.model, body.prompt, img.url, now],
).fetchone()
image_id = str(row[0]) if row else None
else:
image_id = None
stored.append(ImageResult(
url=img.url,
b64_json=img.b64_json,
revised_prompt=img.revised_prompt,
image_id=image_id,
)) ))
return ImageResponse(
id=result.get("id", ""), return ImageResponse(
model=result.get("model", body.model), id=result.get("id", ""),
images=images, model=result.get("model", body.model),
) images=stored,
else: )
# /images/generations response: data[].url except HTTPException:
images = [ raise
ImageResult(
url=item.get("url"),
b64_json=item.get("b64_json"),
revised_prompt=item.get("revised_prompt"),
)
for item in result.get("data", [])
]
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=images,
)
except (KeyError, TypeError) as exc: except (KeyError, TypeError) as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Unexpected response format: {exc}") detail=f"Unexpected response format: {exc}")
@router.get("/images")
async def list_generated_images(
current_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Return all generated images for the current user, newest first."""
user_id = current_user.get("id") or current_user.get("sub")
conn = get_conn()
rows = conn.execute(
"""SELECT id, model_id, prompt, image_data, created_at
FROM generated_images
WHERE user_id = ?
ORDER BY created_at DESC""",
[user_id],
).fetchall()
return [
{
"id": str(r[0]),
"model_id": r[1],
"prompt": r[2],
"image_data": r[3],
"created_at": r[4].isoformat() if r[4] else None,
}
for r in rows
]
@router.post("/video", response_model=VideoResponse) @router.post("/video", response_model=VideoResponse)
async def generate_video( async def generate_video(
body: VideoRequest, body: VideoRequest,
_: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
) -> VideoResponse: ) -> VideoResponse:
"""Generate a video from a text prompt.""" """Generate a video from a text prompt."""
try: try:
@@ -151,12 +203,26 @@ async def generate_video(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls") urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse( return VideoResponse(
id=result.get("id", ""), id=job_id,
model=body.model, model=body.model,
status=result.get("status", "queued"), status=job_status,
polling_url=result.get("polling_url"), polling_url=polling_url,
video_urls=urls, video_urls=urls,
video_url=(urls or [None])[0], video_url=(urls or [None])[0],
error=result.get("error"), error=result.get("error"),
@@ -167,7 +233,7 @@ async def generate_video(
@router.post("/video/from-image", response_model=VideoResponse) @router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image( async def generate_video_from_image(
body: VideoFromImageRequest, body: VideoFromImageRequest,
_: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
) -> VideoResponse: ) -> VideoResponse:
"""Generate a video from an image and a text prompt.""" """Generate a video from an image and a text prompt."""
try: try:
@@ -183,12 +249,26 @@ async def generate_video_from_image(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls") urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse( return VideoResponse(
id=result.get("id", ""), id=job_id,
model=body.model, model=body.model,
status=result.get("status", "queued"), status=job_status,
polling_url=result.get("polling_url"), polling_url=polling_url,
video_urls=urls, video_urls=urls,
video_url=(urls or [None])[0], video_url=(urls or [None])[0],
error=result.get("error"), error=result.get("error"),
@@ -199,23 +279,67 @@ async def generate_video_from_image(
@router.get("/video/status", response_model=VideoResponse) @router.get("/video/status", response_model=VideoResponse)
async def poll_video_status( async def poll_video_status(
polling_url: str, polling_url: str,
_: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
) -> VideoResponse: ) -> VideoResponse:
"""Poll the status of a video generation job via its polling_url.""" """Poll status of a video generation job; updates DB row when completed/failed."""
try: try:
result = await openrouter.poll_video_status(polling_url) result = await openrouter.poll_video_status(polling_url)
except Exception as exc: except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}") status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
job_status = result.get("status", "processing")
urls = result.get("unsigned_urls") or result.get("video_urls") urls = result.get("unsigned_urls") or result.get("video_urls")
video_url = (urls or [None])[0]
# Update DB row for this job when terminal state reached
if job_status in ("completed", "failed"):
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""UPDATE generated_videos
SET status = ?, video_url = ?, updated_at = ?
WHERE job_id = ?""",
[job_status, video_url, now, result.get("id", "")],
)
return VideoResponse( return VideoResponse(
id=result.get("id", ""), id=result.get("id", ""),
model=result.get("model", ""), model=result.get("model", ""),
status=result.get("status", "processing"), status=job_status,
polling_url=result.get("polling_url"), polling_url=result.get("polling_url"),
video_urls=urls, video_urls=urls,
video_url=(urls or [None])[0], video_url=video_url,
error=result.get("error"), error=result.get("error"),
metadata=result.get("metadata"), metadata=result.get("metadata"),
) )
@router.get("/videos")
async def list_generated_videos(
current_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Return all generated video jobs for the current user, newest first."""
user_id = current_user.get("id") or current_user.get("sub")
conn = get_conn()
rows = conn.execute(
"""SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
FROM generated_videos
WHERE user_id = ?
ORDER BY created_at DESC""",
[user_id],
).fetchall()
return [
{
"id": str(r[0]),
"job_id": r[1],
"model_id": r[2],
"prompt": r[3],
"polling_url": r[4],
"status": r[5],
"video_url": r[6],
"created_at": r[7].isoformat() if r[7] else None,
}
for r in rows
]
+33 -3
View File
@@ -91,16 +91,28 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
model_id = m.get("id", "") model_id = m.get("id", "")
if not model_id: if not model_id:
continue continue
# Full output_modalities array from architecture (for proper modalities param in image gen)
architecture = m.get("architecture") or {}
raw_output_modalities: list | None = (
architecture.get("output_modalities") or m.get("output_modalities")
)
output_modalities_json: str | None = (
json.dumps([_normalize_modality(str(v))
for v in raw_output_modalities])
if isinstance(raw_output_modalities, list)
else None
)
conn.execute( conn.execute(
""" """
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at) INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (model_id) DO UPDATE SET ON CONFLICT (model_id) DO UPDATE SET
name = excluded.name, name = excluded.name,
modality = excluded.modality, modality = excluded.modality,
context_length = excluded.context_length, context_length = excluded.context_length,
pricing = excluded.pricing, pricing = excluded.pricing,
fetched_at = excluded.fetched_at fetched_at = excluded.fetched_at,
output_modalities = excluded.output_modalities
""", """,
[ [
model_id, model_id,
@@ -109,6 +121,7 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
m.get("context_length"), m.get("context_length"),
json.dumps(pricing) if pricing else None, json.dumps(pricing) if pricing else None,
now, now,
output_modalities_json,
], ],
) )
count += 1 count += 1
@@ -168,3 +181,20 @@ def get_cached_models(
"pricing": pricing, "pricing": pricing,
}) })
return result return result
def get_model_output_modalities(
conn: duckdb.DuckDBPyConnection,
model_id: str,
) -> list[str]:
"""Return output_modalities list for a model; empty list if not found."""
row = conn.execute(
"SELECT output_modalities FROM models_cache WHERE model_id = ?",
[model_id],
).fetchone()
if not row or not row[0]:
return []
try:
return json.loads(row[0])
except (json.JSONDecodeError, TypeError):
return []
+33 -5
View File
@@ -95,8 +95,9 @@ async def generate_video(
duration_seconds: int | None = None, duration_seconds: int | None = None,
aspect_ratio: str = "16:9", aspect_ratio: str = "16:9",
resolution: str | None = None, resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Request text-to-video generation via OpenRouter.""" """Request text-to-video generation via OpenRouter POST /videos."""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL) base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = { payload: dict[str, Any] = {
"model": model, "model": model,
@@ -104,9 +105,12 @@ async def generate_video(
"aspect_ratio": aspect_ratio, "aspect_ratio": aspect_ratio,
} }
if duration_seconds is not None: if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds # API uses 'duration' not 'duration_seconds'
payload["duration"] = duration_seconds
if resolution is not None: if resolution is not None:
payload["resolution"] = resolution payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client: async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request( resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload "POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -123,19 +127,31 @@ async def generate_video_from_image(
duration_seconds: int | None = None, duration_seconds: int | None = None,
aspect_ratio: str = "16:9", aspect_ratio: str = "16:9",
resolution: str | None = None, resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Request image-to-video generation via OpenRouter.""" """Request image-to-video generation via OpenRouter POST /videos.
Uses frame_images array with first_frame as per OpenRouter API spec.
"""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL) base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = { payload: dict[str, Any] = {
"model": model, "model": model,
"image_url": image_url,
"prompt": prompt, "prompt": prompt,
"aspect_ratio": aspect_ratio, "aspect_ratio": aspect_ratio,
"frame_images": [
{
"type": "image_url",
"image_url": {"url": image_url},
"frame_type": "first_frame",
}
],
} }
if duration_seconds is not None: if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds payload["duration"] = duration_seconds
if resolution is not None: if resolution is not None:
payload["resolution"] = resolution payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client: async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request( resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload "POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -154,6 +170,18 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]:
return response.json() return response.json()
async def list_video_models() -> list[dict[str, Any]]:
"""Return video generation models from the dedicated /videos/models endpoint."""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
async with httpx.AsyncClient(timeout=15) as client:
resp = client.build_request(
"GET", f"{base_url}/videos/models", headers=_headers()
)
response = await client.send(resp)
response.raise_for_status()
return response.json().get("data", [])
async def generate_image_chat( async def generate_image_chat(
model: str, model: str,
prompt: str, prompt: str,
+2 -1
View File
@@ -53,7 +53,8 @@ async def test_stats_as_admin(client):
resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"}) resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["users"]["total"] == 3 # 2 users + 1 admin # 2 users + 1 admin + 1 seeded admin (ai@allucanget.biz)
assert data["users"]["total"] == 4
assert "by_role" in data["users"] assert "by_role" in data["users"]
assert "refresh_tokens" in data assert "refresh_tokens" in data
+6 -6
View File
@@ -53,7 +53,7 @@ async def _user_token(client):
async def test_list_models(client): async def test_list_models(client):
token = await _user_token(client) token = await _user_token(client)
with patch( with patch(
"backend.app.routers.ai.openrouter.list_models", "app.routers.ai.openrouter.list_models",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=FAKE_MODELS, return_value=FAKE_MODELS,
): ):
@@ -74,7 +74,7 @@ async def test_list_models_unauthenticated(client):
async def test_list_models_upstream_error(client): async def test_list_models_upstream_error(client):
token = await _user_token(client) token = await _user_token(client)
with patch( with patch(
"backend.app.routers.ai.openrouter.list_models", "app.routers.ai.openrouter.list_models",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=Exception("Connection refused"), side_effect=Exception("Connection refused"),
): ):
@@ -91,7 +91,7 @@ async def test_list_models_upstream_error(client):
async def test_chat_success(client): async def test_chat_success(client):
token = await _user_token(client) token = await _user_token(client)
with patch( with patch(
"backend.app.routers.ai.openrouter.chat_completion", "app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=FAKE_CHAT_RESPONSE, return_value=FAKE_CHAT_RESPONSE,
): ):
@@ -115,7 +115,7 @@ async def test_chat_success(client):
async def test_chat_passes_parameters(client): async def test_chat_passes_parameters(client):
token = await _user_token(client) token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE) mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock: with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
await client.post( await client.post(
"/ai/chat", "/ai/chat",
json={ json={
@@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client):
async def test_chat_upstream_error(client): async def test_chat_upstream_error(client):
token = await _user_token(client) token = await _user_token(client)
with patch( with patch(
"backend.app.routers.ai.openrouter.chat_completion", "app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=Exception("timeout"), side_effect=Exception("timeout"),
): ):
@@ -160,7 +160,7 @@ async def test_chat_upstream_error(client):
async def test_chat_malformed_upstream_response(client): async def test_chat_malformed_upstream_response(client):
token = await _user_token(client) token = await _user_token(client)
with patch( with patch(
"backend.app.routers.ai.openrouter.chat_completion", "app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value={"id": "x", "choices": []}, # empty choices return_value={"id": "x", "choices": []}, # empty choices
): ):
+135 -69
View File
@@ -18,15 +18,6 @@ FAKE_CHAT = {
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
} }
FAKE_IMAGE = {
"id": "gen-img-1",
"model": "openai/dall-e-3",
"data": [
{"url": "https://example.com/image.png",
"revised_prompt": "A cat on the moon"},
],
}
FAKE_VIDEO = { FAKE_VIDEO = {
"id": "gen-vid-1", "id": "gen-vid-1",
"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1", "polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
@@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client):
# POST /generate/image # POST /generate/image
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def test_generate_image(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "gen-img-1"
assert len(data["images"]) == 1
assert data["images"][0]["url"] == "https://example.com/image.png"
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 502
# --- Chat-based image generation (FLUX, GPT-5 Image Mini) ---
FAKE_IMAGE_CHAT_FLUX = { FAKE_IMAGE_CHAT_FLUX = {
"id": "gen-img-chat-1", "id": "gen-img-chat-1",
"model": "black-forest-labs/flux.2-klein-4b", "model": "black-forest-labs/flux.2-klein-4b",
"choices": [{ "choices": [{
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "Here is your generated image.", "content": None,
"images": [{ "images": [{
"type": "image_url", "type": "image_url",
"image_url": {"url": "data:image/png;base64,abc123"}, "image_url": {"url": "data:image/png;base64,abc123"},
@@ -219,45 +176,65 @@ FAKE_IMAGE_CHAT_GPT5 = {
}], }],
} }
FAKE_IMAGE_CHAT_GEMINI = {
"id": "gen-img-chat-3",
"model": "google/gemini-2.5-flash-image",
"choices": [{
"message": {
"role": "assistant",
"content": "Here is your image.",
"images": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,gemini123"},
}],
}
}],
}
async def test_generate_image_chat_flux(client):
async def test_generate_image(client):
"""All models now use generate_image_chat (chat completions endpoint)."""
token = await _user_token(client) token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX): with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post( resp = await client.post(
"/generate/image", "/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b", json={"model": "google/gemini-2.5-flash-image",
"prompt": "A sunset"}, "prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["id"] == "gen-img-chat-1" assert data["id"] == "gen-img-chat-3"
assert len(data["images"]) == 1 assert len(data["images"]) == 1
assert data["images"][0]["url"] == "data:image/png;base64,abc123" assert data["images"][0]["url"] == "data:image/png;base64,gemini123"
assert data["images"][0]["image_id"] is not None # stored in DB
async def test_generate_image_chat_gpt5_image_mini(client): async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client) token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5): with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post( resp = await client.post(
"/generate/image", "/generate/image",
json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"}, json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert resp.status_code == 200 assert resp.status_code == 502
data = resp.json()
assert data["model"] == "openai/gpt-5-image-mini"
assert len(data["images"]) == 1
async def test_generate_image_chat_with_image_config(client): async def test_generate_image_with_image_config(client):
"""Passes aspect_ratio + image_size through to generate_image_chat."""
token = await _user_token(client) token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX) mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock): with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post( await client.post(
"/generate/image", "/generate/image",
json={ json={
"model": "black-forest-labs/flux.2-klein-4b", "model": "google/gemini-2.5-flash-image",
"prompt": "A landscape", "prompt": "A landscape",
"aspect_ratio": "16:9", "aspect_ratio": "16:9",
"image_size": "2K", "image_size": "2K",
@@ -267,23 +244,112 @@ async def test_generate_image_chat_with_image_config(client):
call_kwargs = mock.call_args.kwargs call_kwargs = mock.call_args.kwargs
assert call_kwargs["image_config"]["aspect_ratio"] == "16:9" assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
assert call_kwargs["image_config"]["image_size"] == "2K" assert call_kwargs["image_config"]["image_size"] == "2K"
assert call_kwargs["modalities"] == ["image"]
async def test_generate_image_chat_unauthenticated(client): async def test_generate_image_default_modalities_image_text(client):
resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"}) """Model not in cache → default modalities = ['image', 'text']."""
assert resp.status_code == 401
async def test_generate_image_chat_upstream_error(client):
token = await _user_token(client) token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")): mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert mock.call_args.kwargs["modalities"] == ["image", "text"]
async def test_generate_image_image_only_modalities_from_cache(client):
"""Model cached with image-only output_modalities → modalities = ['image']."""
from app import db as db_module
from app.services.models import get_model_output_modalities
import json as _json
token = await _user_token(client)
# Seed cache with image-only model
conn = db_module.get_conn()
from datetime import datetime, timezone
now = datetime.now(timezone.utc).replace(tzinfo=None)
conn.execute(
"DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
)
conn.execute(
"""INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
_json.dumps(["image"])],
)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
resp = await client.post( resp = await client.post(
"/generate/image", "/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"}, json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
assert mock.call_args.kwargs["modalities"] == ["image"]
async def test_generate_image_no_images_in_response(client):
"""502 when model returns no images."""
token = await _user_token(client)
empty_response = {
"id": "gen-empty",
"model": "google/gemini-2.5-flash-image",
"choices": [{"message": {"role": "assistant", "content": "ok", "images": []}}],
}
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=empty_response):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert resp.status_code == 502 assert resp.status_code == 502
assert "No images returned" in resp.json()["detail"]
async def test_generate_image_flux(client):
"""Flux model works correctly via chat completions."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b",
"prompt": "A sunset"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
async def test_generate_image_stored_in_db(client):
"""Generated image row persists in generated_images table."""
from app import db as db_module
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image",
"prompt": "A mountain"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
image_id = resp.json()["images"][0]["image_id"]
assert image_id is not None
row = db_module.get_conn().execute(
"SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
[image_id],
).fetchone()
assert row is not None
assert row[0] == "google/gemini-2.5-flash-image"
assert row[1] == "A mountain"
assert row[2] == "data:image/png;base64,gemini123"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+63 -14
View File
@@ -15,6 +15,7 @@ from app.services.models import (
_normalize_modality, _normalize_modality,
_parse_modality, _parse_modality,
get_cached_models, get_cached_models,
get_model_output_modalities,
is_cache_stale, is_cache_stale,
refresh_models_cache, refresh_models_cache,
) )
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
"name": "GPT-4o", "name": "GPT-4o",
"context_length": 128000, "context_length": 128000,
"pricing": {"prompt": "0.000005"}, "pricing": {"prompt": "0.000005"},
"architecture": {"modality": "text->text"}, "architecture": {"modality": "text->text", "output_modalities": ["text"]},
}, },
{ {
"id": "anthropic/claude-3-haiku", "id": "anthropic/claude-3-haiku",
"name": "Claude 3 Haiku", "name": "Claude 3 Haiku",
"context_length": 200000, "context_length": 200000,
"pricing": {}, "pricing": {},
"architecture": {"modality": "text+image->text"}, "architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
}, },
{ {
"id": "openai/dall-e-3", "id": "openai/dall-e-3",
"name": "DALL-E 3", "name": "DALL-E 3",
"context_length": None, "context_length": None,
"pricing": {"image": "0.04"}, "pricing": {"image": "0.04"},
"architecture": {"modality": "text->image"}, "architecture": {"modality": "text->image", "output_modalities": ["image"]},
}, },
{ {
"id": "openai/sora-2", "id": "openai/sora-2",
"name": "Sora 2", "name": "Sora 2",
"context_length": None, "context_length": None,
"pricing": {"video": "0.10"}, "pricing": {"video": "0.10"},
"architecture": {"modality": "text->video"}, "architecture": {"modality": "text->video", "output_modalities": ["video"]},
},
{
"id": "google/gemini-2.5-flash-image",
"name": "Gemini 2.5 Flash Image",
"context_length": None,
"pricing": {},
"architecture": {"output_modalities": ["image", "text"]},
}, },
] ]
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
return_value=FAKE_MODELS_RAW, return_value=FAKE_MODELS_RAW,
): ):
count = await refresh_models_cache(conn) count = await refresh_models_cache(conn)
assert count == 4 assert count == 5
all_models = get_cached_models(conn) all_models = get_cached_models(conn)
assert len(all_models) == 4 assert len(all_models) == 5
async def test_refresh_replaces_old_cache(): async def test_refresh_replaces_old_cache():
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
ids = [m["id"] for m in get_cached_models(conn)] ids = [m["id"] for m in get_cached_models(conn)]
assert "old/model" not in ids assert "old/model" not in ids
assert "openai/gpt-4o" in ids assert "openai/gpt-4o" in ids
assert len(ids) == 5
def test_get_cached_models_filter_by_modality(): def test_get_cached_models_filter_by_modality():
conn = db_module.get_conn() conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None) now = datetime.now(timezone.utc).replace(tzinfo=None)
for m in FAKE_MODELS_RAW: for m in FAKE_MODELS_RAW:
arch = m.get("architecture", {}) modality = _extract_output_modality(m)
modality = _parse_modality(arch.get("modality", "text->text"))
conn.execute( conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)", "INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
[m["id"], m["name"], modality, now], [m["id"], m["name"], modality, now],
) )
text_models = get_cached_models(conn, modality="text") text_models = get_cached_models(conn, modality="text")
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
assert len(text_models) == 2 assert len(text_models) == 2
assert all(m["modality"] == "text" for m in text_models) assert all(m["modality"] == "text" for m in text_models)
image_models = get_cached_models(conn, modality="image") image_models = get_cached_models(conn, modality="image")
assert len(image_models) == 1 # dall-e-3 + gemini (output_modalities starts with image)
assert image_models[0]["id"] == "openai/dall-e-3" assert len(image_models) == 2
image_ids = [m["id"] for m in image_models]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
video_models = get_cached_models(conn, modality="video") video_models = get_cached_models(conn, modality="video")
assert len(video_models) == 1 assert len(video_models) == 1
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
"/models/", headers={"Authorization": f"Bearer {token}"} "/models/", headers={"Authorization": f"Bearer {token}"}
) )
assert resp.status_code == 200 assert resp.status_code == 200
assert len(resp.json()) == 4 assert len(resp.json()) == 5
assert mock_fetch.await_count >= 1 assert mock_fetch.await_count >= 1
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert len(data) == 1 assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
assert data[0]["id"] == "openai/dall-e-3" image_ids = [m["id"] for m in data]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"} "/models/refresh", headers={"Authorization": f"Bearer {token}"}
) )
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["refreshed"] == 4 assert resp.json()["refreshed"] == 5
async def test_refresh_endpoint_502_on_openrouter_error(client): async def test_refresh_endpoint_502_on_openrouter_error(client):
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"} "/models/refresh", headers={"Authorization": f"Bearer {token}"}
) )
assert resp.status_code == 502 assert resp.status_code == 502
# ---------------------------------------------------------------------------
# Unit tests: get_model_output_modalities
# ---------------------------------------------------------------------------
async def test_get_model_output_modalities_image_only():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
assert modalities == ["image"]
async def test_get_model_output_modalities_image_text():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(
conn, "google/gemini-2.5-flash-image")
assert set(modalities) == {"image", "text"}
def test_get_model_output_modalities_unknown_model():
conn = db_module.get_conn()
result = get_model_output_modalities(conn, "unknown/model")
assert result == []
+3 -1
View File
@@ -115,7 +115,9 @@ async def test_list_users_as_admin(client):
resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"}) resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"})
assert resp.status_code == 200 assert resp.status_code == 200
assert isinstance(resp.json(), list) assert isinstance(resp.json(), list)
assert len(resp.json()) == 1 assert len(resp.json()) >= 1
emails = [u["email"] for u in resp.json()]
assert "user@example.com" in emails
async def test_list_users_as_regular_user(client): async def test_list_users_as_regular_user(client):
+7 -1
View File
@@ -178,7 +178,13 @@ def dashboard():
user = resp.json() if resp.status_code == 200 else {} user = resp.json() if resp.status_code == 200 else {}
img_resp = _api("GET", "/images/", token=token) img_resp = _api("GET", "/images/", token=token)
images = img_resp.json() if img_resp.status_code == 200 else [] images = img_resp.json() if img_resp.status_code == 200 else []
return render_template("dashboard.html", user=user, images=images) gen_resp = _api("GET", "/generate/images", token=token)
generated_images = gen_resp.json() if gen_resp.status_code == 200 else []
vid_resp = _api("GET", "/generate/videos", token=token)
generated_videos = vid_resp.json() if vid_resp.status_code == 200 else []
return render_template("dashboard.html", user=user, images=images,
generated_images=generated_images,
generated_videos=generated_videos)
# ── Generate ────────────────────────────────────────────────────────────── # ── Generate ──────────────────────────────────────────────────────────────
+53 -1
View File
@@ -6,7 +6,59 @@ endblock %} {% block content %}
<a href="{{ url_for('generate') }}" class="btn">Start generating</a> <a href="{{ url_for('generate') }}" class="btn">Start generating</a>
</div> </div>
{% if images %} {% if generated_images %}
<div class="card mt-2">
<h2>Generated images</h2>
<div class="image-grid">
{% for img in generated_images %}
<div class="image-grid-item">
<img
src="{{ img.image_data }}"
alt="{{ img.prompt }}"
class="generated-image"
loading="lazy"
/>
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
<strong>{{ img.model_id }}</strong><br />{{ img.prompt[:80] }}{% if
img.prompt|length > 80 %}…{% endif %}
</p>
</div>
{% endfor %}
</div>
</div>
{% endif %} {% if generated_videos %}
<div class="card mt-2">
<h2>Generated videos</h2>
<div class="image-grid">
{% for vid in generated_videos %}
<div class="image-grid-item">
{% if vid.video_url %}
<video controls style="max-width: 100%; border-radius: 6px">
<source src="{{ vid.video_url }}" />
Your browser does not support the video tag.
</video>
{% else %}
<div
style="
background: #1a1a1a;
border-radius: 6px;
padding: 2rem;
text-align: center;
"
>
<span class="text-muted">{{ vid.status | capitalize }} &hellip;</span>
</div>
{% endif %}
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
<strong>{{ vid.model_id }}</strong><br />{{ vid.prompt[:80] }}{% if
vid.prompt|length > 80 %}…{% endif %}<br />
<em>{{ vid.status }}</em>
</p>
</div>
{% endfor %}
</div>
</div>
{% endif %} {% if images %}
<div class="card mt-2"> <div class="card mt-2">
<h2>Uploaded reference images</h2> <h2>Uploaded reference images</h2>
<div class="image-grid"> <div class="image-grid">
+19 -32
View File
@@ -8,12 +8,12 @@
{% if models %} {% if models %}
<select id="model" name="model" required> <select id="model" name="model" required>
{% for m in models %} {% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option> <option value="{{ m.id }}" {{ "selected" if request.form.get('model', '') == m.id else "" }}>{{ m.name }}</option>
{% endfor %} {% endfor %}
</select> </select>
{% else %} {% else %}
<input id="model" name="model" type="text" required <input id="model" name="model" type="text" required
placeholder="e.g. openai/dall-e-3" placeholder="e.g. google/gemini-2.5-flash-image"
value="{{ request.form.get('model', '') }}"> value="{{ request.form.get('model', '') }}">
{% endif %} {% endif %}
@@ -21,40 +21,25 @@
<textarea id="prompt" name="prompt" rows="4" required <textarea id="prompt" name="prompt" rows="4" required
placeholder="Describe the image you want…">{{ request.form.get('prompt', '') }}</textarea> placeholder="Describe the image you want…">{{ request.form.get('prompt', '') }}</textarea>
<label for="size">Size</label>
<select id="size" name="size">
<option value="1024x1024" {% if request.form.get('size','1024x1024')=='1024x1024' %}selected{% endif %}>1024×1024</option>
<option value="1792x1024" {% if request.form.get('size')=='1792x1024' %}selected{% endif %}>1792×1024 (landscape)</option>
<option value="1024x1792" {% if request.form.get('size')=='1024x1792' %}selected{% endif %}>1024×1792 (portrait)</option>
<option value="512x512" {% if request.form.get('size')=='512x512' %}selected{% endif %}>512×512</option>
</select>
<label for="aspect_ratio">Aspect ratio</label> <label for="aspect_ratio">Aspect ratio</label>
<select id="aspect_ratio" name="aspect_ratio"> <select id="aspect_ratio" name="aspect_ratio">
<option value="">Auto (default)</option> <option value="">Auto (default 1:1)</option>
<option value="1:1" {% if request.form.get('aspect_ratio')=='1:1' %}selected{% endif %}>1:1 (square)</option> <option value="1:1" {{ "selected" if request.form.get('aspect_ratio')=='1:1' else "" }}>1:1 (square)</option>
<option value="16:9" {% if request.form.get('aspect_ratio')=='16:9' %}selected{% endif %}>16:9 (landscape)</option> <option value="16:9" {{ "selected" if request.form.get('aspect_ratio')=='16:9' else "" }}>16:9 (landscape)</option>
<option value="9:16" {% if request.form.get('aspect_ratio')=='9:16' %}selected{% endif %}>9:16 (portrait)</option> <option value="9:16" {{ "selected" if request.form.get('aspect_ratio')=='9:16' else "" }}>9:16 (portrait)</option>
<option value="4:3" {% if request.form.get('aspect_ratio')=='4:3' %}selected{% endif %}>4:3</option> <option value="4:3" {{ "selected" if request.form.get('aspect_ratio')=='4:3' else "" }}>4:3</option>
<option value="3:4" {% if request.form.get('aspect_ratio')=='3:4' %}selected{% endif %}>3:4</option> <option value="3:4" {{ "selected" if request.form.get('aspect_ratio')=='3:4' else "" }}>3:4</option>
<option value="3:2" {% if request.form.get('aspect_ratio')=='3:2' %}selected{% endif %}>3:2</option> <option value="3:2" {{ "selected" if request.form.get('aspect_ratio')=='3:2' else "" }}>3:2</option>
<option value="2:3" {% if request.form.get('aspect_ratio')=='2:3' %}selected{% endif %}>2:3</option> <option value="2:3" {{ "selected" if request.form.get('aspect_ratio')=='2:3' else "" }}>2:3</option>
</select> </select>
<label for="image_size">Resolution</label> <label for="image_size">Resolution</label>
<select id="image_size" name="image_size"> <select id="image_size" name="image_size">
<option value="">Auto (default)</option> <option value="">Auto (default 1K)</option>
<option value="0.5K" {% if request.form.get('image_size')=='0.5K' %}selected{% endif %}>0.5K (low)</option> <option value="0.5K" {{ "selected" if request.form.get('image_size')=='0.5K' else "" }}>0.5K (low)</option>
<option value="1K" {% if request.form.get('image_size')=='1K' %}selected{% endif %}>1K (standard)</option> <option value="1K" {{ "selected" if request.form.get('image_size')=='1K' else "" }}>1K (standard)</option>
<option value="2K" {% if request.form.get('image_size')=='2K' %}selected{% endif %}>2K (high)</option> <option value="2K" {{ "selected" if request.form.get('image_size')=='2K' else "" }}>2K (high)</option>
<option value="4K" {% if request.form.get('image_size')=='4K' %}selected{% endif %}>4K (ultra)</option> <option value="4K" {{ "selected" if request.form.get('image_size')=='4K' else "" }}>4K (ultra)</option>
</select>
<label for="n">Number of images</label>
<select id="n" name="n">
<option value="1" {% if request.form.get('n','1')=='1' %}selected{% endif %}>1</option>
<option value="2" {% if request.form.get('n')=='2' %}selected{% endif %}>2</option>
<option value="4" {% if request.form.get('n')=='4' %}selected{% endif %}>4</option>
</select> </select>
<label for="reference_image">Reference image (optional)</label> <label for="reference_image">Reference image (optional)</label>
@@ -65,7 +50,7 @@
accept="image/png,image/jpeg,image/webp,image/gif" accept="image/png,image/jpeg,image/webp,image/gif"
> >
<p class="text-muted mt-1" id="reference-image-help"> <p class="text-muted mt-1" id="reference-image-help">
Upload image for visual reference in upcoming image-to-image flow. Upload an image to use as visual reference (image-to-image).
</p> </p>
<div class="image-upload-preview" id="image-upload-preview" hidden> <div class="image-upload-preview" id="image-upload-preview" hidden>
<p class="text-muted" id="image-upload-filename"></p> <p class="text-muted" id="image-upload-filename"></p>
@@ -83,7 +68,9 @@
<div class="result"> <div class="result">
<h2>Generated image{{ 's' if result.images|length > 1 }}</h2> <h2>Generated image{{ 's' if result.images|length > 1 }}</h2>
{% for img in result.images %} {% for img in result.images %}
<img src="{{ img.url }}" alt="Generated image" class="generated-image"> {% if img.url %}
<img src="{{ img.url }}" alt="Generated image" class="generated-image">
{% endif %}
{% if img.revised_prompt %} {% if img.revised_prompt %}
<p class="text-muted mt-1" style="font-size:0.8rem;">{{ img.revised_prompt }}</p> <p class="text-muted mt-1" style="font-size:0.8rem;">{{ img.revised_prompt }}</p>
{% endif %} {% endif %}
+28 -3
View File
@@ -151,7 +151,8 @@ def test_dashboard_renders_user_info(client):
me_mock = _mock_response( me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"}) 200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, []) images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]): gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard") resp = client.get("/dashboard")
assert resp.status_code == 200 assert resp.status_code == 200
assert b"u@example.com" in resp.data assert b"u@example.com" in resp.data
@@ -531,19 +532,43 @@ def test_dashboard_shows_uploaded_images(client):
{"id": "img-1", "filename": "cat.png", "content_type": "image/png", {"id": "img-1", "filename": "cat.png", "content_type": "image/png",
"size_bytes": 1024, "created_at": "2026-04-29T10:00:00"}, "size_bytes": 1024, "created_at": "2026-04-29T10:00:00"},
]) ])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]): gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard") resp = client.get("/dashboard")
assert resp.status_code == 200 assert resp.status_code == 200
assert b"cat.png" in resp.data assert b"cat.png" in resp.data
assert b"img-1" in resp.data assert b"img-1" in resp.data
def test_dashboard_shows_generated_images(client):
_set_auth(client)
me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, [])
gen_images_mock = _mock_response(200, [
{
"id": "gen-1",
"model_id": "google/gemini-2.5-flash-image",
"prompt": "A cat on the moon",
"image_data": "data:image/png;base64,abc123",
"created_at": "2026-04-29T10:00:00",
}
])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard")
assert resp.status_code == 200
assert b"Generated images" in resp.data
assert b"A cat on the moon" in resp.data
assert b"data:image/png;base64,abc123" in resp.data
def test_dashboard_no_images_section_when_empty(client): def test_dashboard_no_images_section_when_empty(client):
_set_auth(client) _set_auth(client)
me_mock = _mock_response( me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"}) 200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, []) images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]): gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard") resp = client.get("/dashboard")
assert resp.status_code == 200 assert resp.status_code == 200
assert b"Uploaded reference images" not in resp.data assert b"Uploaded reference images" not in resp.data