feat: enhance model caching and output modalities handling

- Updated `refresh_models_cache` to store each model's output modalities in the models cache.
- Added `get_model_output_modalities` to look up a model's output modalities from the cache.
- Routed all image generation through the chat completions endpoint, choosing the `modalities` parameter from the cache.
- Persisted generated images and video jobs in new `generated_images` and `generated_videos` tables.
- Updated OpenRouter video generation to support audio (`generate_audio`), send the correct `duration` payload key, and pass image-to-video input via `frame_images`; a model returning no images now surfaces a clear 502.
- Enhanced the dashboard to display generated images and videos.
- Refactored frontend templates to accommodate the new data structures for generated content.
- Adjusted tests to cover output modalities, model handling, and dashboard rendering.

Co-authored-by: Copilot <copilot@github.com>
2026-04-29 15:20:48 +02:00
parent 3d32e6df74
commit 712c556032
15 changed files with 618 additions and 219 deletions
+28
@@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
fetched_at TIMESTAMP NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_images (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
image_data VARCHAR NOT NULL,
created_at TIMESTAMP DEFAULT now()
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS generated_videos (
id UUID DEFAULT uuid() PRIMARY KEY,
user_id UUID NOT NULL,
job_id VARCHAR NOT NULL,
model_id VARCHAR NOT NULL,
prompt VARCHAR NOT NULL,
polling_url VARCHAR,
status VARCHAR NOT NULL DEFAULT 'pending',
video_url VARCHAR,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now()
)
""")
# Migration: add output_modalities column if absent (stores JSON array string)
conn.execute("""
ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR
""")
_seed_admin(conn)
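
The migration block above relies on DuckDB's idempotent DDL, so `_run_migrations` can run unconditionally on every startup. A minimal, self-contained sketch of the same pattern (table and column names mirror the diff; the schema is trimmed for brevity):

```python
# Trimmed sketch of the idempotent-migration pattern: safe to run on every
# startup because both statements are no-ops once applied.
import duckdb

def run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
    conn.execute("""
        CREATE TABLE IF NOT EXISTS models_cache (
            model_id VARCHAR PRIMARY KEY,
            name VARCHAR,
            fetched_at TIMESTAMP NOT NULL
        )
    """)
    # Later migration: only adds the column when it is absent.
    conn.execute(
        "ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR"
    )

conn = duckdb.connect(":memory:")
run_migrations(conn)
run_migrations(conn)  # second run is a no-op, not an error
```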
+14 -14
@@ -1,10 +1,10 @@
from .routers import auth as auth_router
from .routers import users as users_router
from .routers import admin as admin_router
from .routers import ai as ai_router
from .routers import generate as generate_router
from .routers import images as images_router
from .routers import models as models_router
from .routers import auth
from .routers import users
from .routers import admin
from .routers import ai
from .routers import generate
from .routers import images
from .routers import models
from .db import close_db, init_db
import os
from contextlib import asynccontextmanager
@@ -38,13 +38,13 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(auth_router.router)
app.include_router(users_router.router)
app.include_router(admin_router.router)
app.include_router(ai_router.router)
app.include_router(generate_router.router)
app.include_router(images_router.router)
app.include_router(models_router.router)
app.include_router(auth.router)
app.include_router(users.router)
app.include_router(admin.router)
app.include_router(ai.router)
app.include_router(generate.router)
app.include_router(images.router)
app.include_router(models.router)
@app.get("/health", tags=["health"])
+1
@@ -62,6 +62,7 @@ class ImageResult(BaseModel):
url: str | None = None
b64_json: str | None = None
revised_prompt: str | None = None
image_id: str | None = None # UUID of stored row in generated_images
class ImageResponse(BaseModel):
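
The new `image_id` field is optional, so responses produced before this change still validate. A hypothetical round-trip (assuming Pydantic v2 for `model_validate`):

```python
# Hypothetical round-trip: image_id is optional, so older payloads without it
# still parse. Field names mirror the ImageResult model in the diff.
from pydantic import BaseModel

class ImageResult(BaseModel):
    url: str | None = None
    b64_json: str | None = None
    revised_prompt: str | None = None
    image_id: str | None = None  # UUID of stored row in generated_images

stored = ImageResult(url="data:image/png;base64,abc",
                     image_id="0e3f2a9c-5b1d-4c2e-9f3a-7d6e5c4b3a21")
legacy = ImageResult.model_validate({"url": "https://example.com/i.png"})
assert legacy.image_id is None  # backwards compatible
```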
+193 -69
@@ -1,6 +1,9 @@
"""Generate router: text, image, video, and image-to-video generation."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from ..db import get_conn, get_write_lock
from ..dependencies import get_current_user
from ..models.ai import (
ImageRequest,
@@ -13,6 +16,7 @@ from ..models.ai import (
VideoResponse,
)
from ..services import openrouter
from ..services.models import get_model_output_modalities
router = APIRouter(prefix="/generate", tags=["generate"])
@@ -62,81 +66,129 @@ async def generate_text(
@router.post("/image", response_model=ImageResponse)
async def generate_image(
body: ImageRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> ImageResponse:
"""Generate images from a text prompt."""
# Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E)
chat_models = {"black-forest-labs/flux.2-klein-4b",
"openai/gpt-5-image-mini"}
is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \
any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"])
"""Generate images from a prompt using the chat completions endpoint.
All OpenRouter image models use /chat/completions with a modalities param.
Models that output only images use ["image"]; those that also output text
use ["image", "text"]. We look this up from the model cache; default to
["image", "text"] when the model is not yet cached.
"""
# Determine modalities from cache; default ["image", "text"] works for most models
try:
conn = get_conn()
cached_modalities = get_model_output_modalities(conn, body.model)
except Exception:
cached_modalities = []
if cached_modalities:
# If cache says model only outputs image (no text), use ["image"]
modalities = ["image"] if set(cached_modalities) == {
"image"} else ["image", "text"]
else:
# Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
# For image-only models that fail with this, the error surfaces to the user.
modalities = ["image", "text"]
image_config: dict = {}
if body.aspect_ratio:
image_config["aspect_ratio"] = body.aspect_ratio
if body.image_size:
image_config["image_size"] = body.image_size
try:
if is_chat_model:
image_config = {}
if body.aspect_ratio:
image_config["aspect_ratio"] = body.aspect_ratio
if body.image_size:
image_config["image_size"] = body.image_size
result = await openrouter.generate_image_chat(
model=body.model,
prompt=body.prompt,
modalities=[
"image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"],
image_config=image_config if image_config else None,
)
else:
result = await openrouter.generate_image(
model=body.model,
prompt=body.prompt,
n=body.n,
size=body.size,
)
result = await openrouter.generate_image_chat(
model=body.model,
prompt=body.prompt,
modalities=modalities,
image_config=image_config if image_config else None,
)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
try:
if is_chat_model:
# Chat completions response: choices[0].message.images[].image_url.url
images = []
message = result.get("choices", [{}])[0].get("message", {})
for item in message.get("images", []):
img_url = item.get("image_url", {}).get("url")
images.append(ImageResult(
url=img_url,
b64_json=None,
revised_prompt=message.get("content"),
message = result.get("choices", [{}])[0].get("message", {})
images = []
for item in message.get("images", []):
img_url = item.get("image_url", {}).get("url")
images.append(ImageResult(
url=img_url,
b64_json=None,
revised_prompt=message.get("content") or None,
))
if not images:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="No images returned by model. Verify the model supports image generation.",
)
# Persist each image to DB
user_id = current_user.get("id") or current_user.get("sub")
now = datetime.now(timezone.utc).replace(tzinfo=None)
stored: list[ImageResult] = []
async with get_write_lock():
conn = get_conn()
for img in images:
if img.url:
row = conn.execute(
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
VALUES (?, ?, ?, ?, ?) RETURNING id""",
[user_id, body.model, body.prompt, img.url, now],
).fetchone()
image_id = str(row[0]) if row else None
else:
image_id = None
stored.append(ImageResult(
url=img.url,
b64_json=img.b64_json,
revised_prompt=img.revised_prompt,
image_id=image_id,
))
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=images,
)
else:
# /images/generations response: data[].url
images = [
ImageResult(
url=item.get("url"),
b64_json=item.get("b64_json"),
revised_prompt=item.get("revised_prompt"),
)
for item in result.get("data", [])
]
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=images,
)
return ImageResponse(
id=result.get("id", ""),
model=result.get("model", body.model),
images=stored,
)
except HTTPException:
raise
except (KeyError, TypeError) as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Unexpected response format: {exc}")
@router.get("/images")
async def list_generated_images(
current_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Return all generated images for the current user, newest first."""
user_id = current_user.get("id") or current_user.get("sub")
conn = get_conn()
rows = conn.execute(
"""SELECT id, model_id, prompt, image_data, created_at
FROM generated_images
WHERE user_id = ?
ORDER BY created_at DESC""",
[user_id],
).fetchall()
return [
{
"id": str(r[0]),
"model_id": r[1],
"prompt": r[2],
"image_data": r[3],
"created_at": r[4].isoformat() if r[4] else None,
}
for r in rows
]
@router.post("/video", response_model=VideoResponse)
async def generate_video(
body: VideoRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Generate a video from a text prompt."""
try:
@@ -151,12 +203,26 @@ async def generate_video(
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse(
id=result.get("id", ""),
id=job_id,
model=body.model,
status=result.get("status", "queued"),
polling_url=result.get("polling_url"),
status=job_status,
polling_url=polling_url,
video_urls=urls,
video_url=(urls or [None])[0],
error=result.get("error"),
@@ -167,7 +233,7 @@ async def generate_video(
@router.post("/video/from-image", response_model=VideoResponse)
async def generate_video_from_image(
body: VideoFromImageRequest,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Generate a video from an image and a text prompt."""
try:
@@ -183,12 +249,26 @@ async def generate_video_from_image(
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
user_id = current_user.get("id") or current_user.get("sub")
job_id = result.get("id", "")
polling_url = result.get("polling_url")
job_status = result.get("status", "pending")
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now],
)
urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse(
id=result.get("id", ""),
id=job_id,
model=body.model,
status=result.get("status", "queued"),
polling_url=result.get("polling_url"),
status=job_status,
polling_url=polling_url,
video_urls=urls,
video_url=(urls or [None])[0],
error=result.get("error"),
@@ -199,23 +279,67 @@ async def generate_video_from_image(
@router.get("/video/status", response_model=VideoResponse)
async def poll_video_status(
polling_url: str,
_: dict = Depends(get_current_user),
current_user: dict = Depends(get_current_user),
) -> VideoResponse:
"""Poll the status of a video generation job via its polling_url."""
"""Poll status of a video generation job; updates DB row when completed/failed."""
try:
result = await openrouter.poll_video_status(polling_url)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
job_status = result.get("status", "processing")
urls = result.get("unsigned_urls") or result.get("video_urls")
video_url = (urls or [None])[0]
# Update DB row for this job when terminal state reached
if job_status in ("completed", "failed"):
now = datetime.now(timezone.utc).replace(tzinfo=None)
async with get_write_lock():
conn = get_conn()
conn.execute(
"""UPDATE generated_videos
SET status = ?, video_url = ?, updated_at = ?
WHERE job_id = ?""",
[job_status, video_url, now, result.get("id", "")],
)
return VideoResponse(
id=result.get("id", ""),
model=result.get("model", ""),
status=result.get("status", "processing"),
status=job_status,
polling_url=result.get("polling_url"),
video_urls=urls,
video_url=(urls or [None])[0],
video_url=video_url,
error=result.get("error"),
metadata=result.get("metadata"),
)
@router.get("/videos")
async def list_generated_videos(
current_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Return all generated video jobs for the current user, newest first."""
user_id = current_user.get("id") or current_user.get("sub")
conn = get_conn()
rows = conn.execute(
"""SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
FROM generated_videos
WHERE user_id = ?
ORDER BY created_at DESC""",
[user_id],
).fetchall()
return [
{
"id": str(r[0]),
"job_id": r[1],
"model_id": r[2],
"prompt": r[3],
"polling_url": r[4],
"status": r[5],
"video_url": r[6],
"created_at": r[7].isoformat() if r[7] else None,
}
for r in rows
]
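
The modalities choice in `generate_image` reduces to one rule: a cache entry of exactly `["image"]` narrows the request to `["image"]`; everything else, including a cache miss, falls back to `["image", "text"]`. A standalone sketch of that rule:

```python
# Standalone restatement of the modalities-selection rule from generate_image.
def choose_modalities(cached: list[str]) -> list[str]:
    # Only an image-only cache entry narrows the request to ["image"].
    if cached and set(cached) == {"image"}:
        return ["image"]
    # Cache miss or mixed modalities: the safe default.
    return ["image", "text"]

assert choose_modalities(["image"]) == ["image"]
assert choose_modalities(["image", "text"]) == ["image", "text"]
assert choose_modalities([]) == ["image", "text"]  # model not cached yet
```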
+33 -3
@@ -91,16 +91,28 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
model_id = m.get("id", "")
if not model_id:
continue
# Full output_modalities array from architecture (for proper modalities param in image gen)
architecture = m.get("architecture") or {}
raw_output_modalities: list | None = (
architecture.get("output_modalities") or m.get("output_modalities")
)
output_modalities_json: str | None = (
json.dumps([_normalize_modality(str(v))
for v in raw_output_modalities])
if isinstance(raw_output_modalities, list)
else None
)
conn.execute(
"""
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (model_id) DO UPDATE SET
name = excluded.name,
modality = excluded.modality,
context_length = excluded.context_length,
pricing = excluded.pricing,
fetched_at = excluded.fetched_at
fetched_at = excluded.fetched_at,
output_modalities = excluded.output_modalities
""",
[
model_id,
@@ -109,6 +121,7 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
m.get("context_length"),
json.dumps(pricing) if pricing else None,
now,
output_modalities_json,
],
)
count += 1
@@ -168,3 +181,20 @@ def get_cached_models(
"pricing": pricing,
})
return result
def get_model_output_modalities(
conn: duckdb.DuckDBPyConnection,
model_id: str,
) -> list[str]:
"""Return output_modalities list for a model; empty list if not found."""
row = conn.execute(
"SELECT output_modalities FROM models_cache WHERE model_id = ?",
[model_id],
).fetchone()
if not row or not row[0]:
return []
try:
return json.loads(row[0])
except (json.JSONDecodeError, TypeError):
return []
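
`output_modalities` is stored as a JSON-encoded string in a VARCHAR column, so the read path has to tolerate missing rows, NULLs, and malformed JSON. A self-contained round-trip against an in-memory DuckDB:

```python
# Round-trip for the JSON-in-VARCHAR storage used by get_model_output_modalities.
import json
import duckdb

conn = duckdb.connect(":memory:")
conn.execute(
    "CREATE TABLE models_cache (model_id VARCHAR PRIMARY KEY, output_modalities VARCHAR)"
)
conn.execute(
    "INSERT INTO models_cache VALUES (?, ?)",
    ["openai/dall-e-3", json.dumps(["image"])],
)

def get_output_modalities(conn: duckdb.DuckDBPyConnection, model_id: str) -> list[str]:
    row = conn.execute(
        "SELECT output_modalities FROM models_cache WHERE model_id = ?",
        [model_id],
    ).fetchone()
    if not row or not row[0]:
        return []  # unknown model, or column never populated
    try:
        return json.loads(row[0])
    except (json.JSONDecodeError, TypeError):
        return []  # malformed cache entry degrades to "no data"

assert get_output_modalities(conn, "openai/dall-e-3") == ["image"]
assert get_output_modalities(conn, "unknown/model") == []
```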
+33 -5
@@ -95,8 +95,9 @@ async def generate_video(
duration_seconds: int | None = None,
aspect_ratio: str = "16:9",
resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]:
"""Request text-to-video generation via OpenRouter."""
"""Request text-to-video generation via OpenRouter POST /videos."""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = {
"model": model,
@@ -104,9 +105,12 @@ async def generate_video(
"aspect_ratio": aspect_ratio,
}
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
# API uses 'duration' not 'duration_seconds'
payload["duration"] = duration_seconds
if resolution is not None:
payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -123,19 +127,31 @@ async def generate_video_from_image(
duration_seconds: int | None = None,
aspect_ratio: str = "16:9",
resolution: str | None = None,
generate_audio: bool | None = None,
) -> dict[str, Any]:
"""Request image-to-video generation via OpenRouter."""
"""Request image-to-video generation via OpenRouter POST /videos.
Uses frame_images array with first_frame as per OpenRouter API spec.
"""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
payload: dict[str, Any] = {
"model": model,
"image_url": image_url,
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"frame_images": [
{
"type": "image_url",
"image_url": {"url": image_url},
"frame_type": "first_frame",
}
],
}
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
payload["duration"] = duration_seconds
if resolution is not None:
payload["resolution"] = resolution
if generate_audio is not None:
payload["generate_audio"] = generate_audio
async with httpx.AsyncClient(timeout=120) as client:
resp = client.build_request(
"POST", f"{base_url}/videos", headers=_headers(), json=payload
@@ -154,6 +170,18 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]:
return response.json()
async def list_video_models() -> list[dict[str, Any]]:
"""Return video generation models from the dedicated /videos/models endpoint."""
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
async with httpx.AsyncClient(timeout=15) as client:
resp = client.build_request(
"GET", f"{base_url}/videos/models", headers=_headers()
)
response = await client.send(resp)
response.raise_for_status()
return response.json().get("data", [])
async def generate_image_chat(
model: str,
prompt: str,
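
The two video functions now share one payload shape: `duration` instead of `duration_seconds`, `generate_audio` only when explicitly set, and image input expressed as a `frame_images` entry tagged `first_frame`. A sketch of that shape as a pure builder; the field names follow this diff's reading of the OpenRouter API, not a verified spec:

```python
# Payload builder mirroring the POST /videos body assembled above.
# Field names are taken from the diff; treat them as assumptions.
from typing import Any

def build_video_payload(
    model: str,
    prompt: str,
    image_url: str | None = None,
    duration_seconds: int | None = None,
    aspect_ratio: str = "16:9",
    generate_audio: bool | None = None,
) -> dict[str, Any]:
    payload: dict[str, Any] = {
        "model": model,
        "prompt": prompt,
        "aspect_ratio": aspect_ratio,
    }
    if image_url is not None:
        # Image-to-video: first frame supplied via frame_images.
        payload["frame_images"] = [{
            "type": "image_url",
            "image_url": {"url": image_url},
            "frame_type": "first_frame",
        }]
    if duration_seconds is not None:
        payload["duration"] = duration_seconds  # API key is 'duration'
    if generate_audio is not None:
        payload["generate_audio"] = generate_audio
    return payload

p = build_video_payload("openai/sora-2", "crashing waves",
                        image_url="https://example.com/frame.png",
                        duration_seconds=8, generate_audio=True)
assert p["frame_images"][0]["frame_type"] == "first_frame"
assert "duration_seconds" not in p
```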
+2 -1
@@ -53,7 +53,8 @@ async def test_stats_as_admin(client):
resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 200
data = resp.json()
assert data["users"]["total"] == 3 # 2 users + 1 admin
# 2 users + 1 admin + 1 seeded admin (ai@allucanget.biz)
assert data["users"]["total"] == 4
assert "by_role" in data["users"]
assert "refresh_tokens" in data
+6 -6
@@ -53,7 +53,7 @@ async def _user_token(client):
async def test_list_models(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.list_models",
"app.routers.ai.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS,
):
@@ -74,7 +74,7 @@ async def test_list_models_unauthenticated(client):
async def test_list_models_upstream_error(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.list_models",
"app.routers.ai.openrouter.list_models",
new_callable=AsyncMock,
side_effect=Exception("Connection refused"),
):
@@ -91,7 +91,7 @@ async def test_list_models_upstream_error(client):
async def test_chat_success(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
return_value=FAKE_CHAT_RESPONSE,
):
@@ -115,7 +115,7 @@ async def test_chat_success(client):
async def test_chat_passes_parameters(client):
token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
await client.post(
"/ai/chat",
json={
@@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client):
async def test_chat_upstream_error(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
side_effect=Exception("timeout"),
):
@@ -160,7 +160,7 @@ async def test_chat_upstream_error(client):
async def test_chat_malformed_upstream_response(client):
token = await _user_token(client)
with patch(
"backend.app.routers.ai.openrouter.chat_completion",
"app.routers.ai.openrouter.chat_completion",
new_callable=AsyncMock,
return_value={"id": "x", "choices": []}, # empty choices
):
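
The patch targets changed from `backend.app.routers...` to `app.routers...` because `unittest.mock.patch` replaces a name where it is looked up at runtime, and the package is importable as `app` in this test environment. A self-contained illustration of the where-to-patch rule (the throwaway module is hypothetical):

```python
# patch() resolves the dotted path at runtime, so the string must name the
# module as it is actually importable — here a throwaway module stands in.
import sys
import types
from unittest.mock import patch

mod = types.ModuleType("fake_router")
mod.list_models = lambda: ["real-model"]
sys.modules["fake_router"] = mod  # make the dotted path resolvable

with patch("fake_router.list_models", return_value=["mocked"]):
    assert mod.list_models() == ["mocked"]   # callers see the mock
assert mod.list_models() == ["real-model"]   # original restored on exit
```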
+135 -69
@@ -18,15 +18,6 @@ FAKE_CHAT = {
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
}
FAKE_IMAGE = {
"id": "gen-img-1",
"model": "openai/dall-e-3",
"data": [
{"url": "https://example.com/image.png",
"revised_prompt": "A cat on the moon"},
],
}
FAKE_VIDEO = {
"id": "gen-vid-1",
"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
@@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client):
# POST /generate/image
# ---------------------------------------------------------------------------
async def test_generate_image(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "gen-img-1"
assert len(data["images"]) == 1
assert data["images"][0]["url"] == "https://example.com/image.png"
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post(
"/generate/image",
json={"model": "openai/dall-e-3", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 502
# --- Chat-based image generation (FLUX, GPT-5 Image Mini) ---
FAKE_IMAGE_CHAT_FLUX = {
"id": "gen-img-chat-1",
"model": "black-forest-labs/flux.2-klein-4b",
"choices": [{
"message": {
"role": "assistant",
"content": "Here is your generated image.",
"content": None,
"images": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,abc123"},
@@ -219,45 +176,65 @@ FAKE_IMAGE_CHAT_GPT5 = {
}],
}
FAKE_IMAGE_CHAT_GEMINI = {
"id": "gen-img-chat-3",
"model": "google/gemini-2.5-flash-image",
"choices": [{
"message": {
"role": "assistant",
"content": "Here is your image.",
"images": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,gemini123"},
}],
}
}],
}
async def test_generate_image_chat_flux(client):
async def test_generate_image(client):
"""All models now use generate_image_chat (chat completions endpoint)."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b",
"prompt": "A sunset"},
json={"model": "google/gemini-2.5-flash-image",
"prompt": "A cat on the moon"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "gen-img-chat-1"
assert data["id"] == "gen-img-chat-3"
assert len(data["images"]) == 1
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
assert data["images"][0]["url"] == "data:image/png;base64,gemini123"
assert data["images"][0]["image_id"] is not None # stored in DB
async def test_generate_image_chat_gpt5_image_mini(client):
async def test_generate_image_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_upstream_error(client):
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5):
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("rate limit")):
resp = await client.post(
"/generate/image",
json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"},
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["model"] == "openai/gpt-5-image-mini"
assert len(data["images"]) == 1
assert resp.status_code == 502
async def test_generate_image_chat_with_image_config(client):
async def test_generate_image_with_image_config(client):
"""Passes aspect_ratio + image_size through to generate_image_chat."""
token = await _user_token(client)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post(
"/generate/image",
json={
"model": "black-forest-labs/flux.2-klein-4b",
"model": "google/gemini-2.5-flash-image",
"prompt": "A landscape",
"aspect_ratio": "16:9",
"image_size": "2K",
@@ -267,23 +244,112 @@ async def test_generate_image_chat_with_image_config(client):
call_kwargs = mock.call_args.kwargs
assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
assert call_kwargs["image_config"]["image_size"] == "2K"
assert call_kwargs["modalities"] == ["image"]
async def test_generate_image_chat_unauthenticated(client):
resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"})
assert resp.status_code == 401
async def test_generate_image_chat_upstream_error(client):
async def test_generate_image_default_modalities_image_text(client):
"""Model not in cache → default modalities = ['image', 'text']."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")):
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert mock.call_args.kwargs["modalities"] == ["image", "text"]
async def test_generate_image_image_only_modalities_from_cache(client):
"""Model cached with image-only output_modalities → modalities = ['image']."""
from app import db as db_module
from app.services.models import get_model_output_modalities
import json as _json
token = await _user_token(client)
# Seed cache with image-only model
conn = db_module.get_conn()
from datetime import datetime, timezone
now = datetime.now(timezone.utc).replace(tzinfo=None)
conn.execute(
"DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
)
conn.execute(
"""INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
_json.dumps(["image"])],
)
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"},
json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
assert mock.call_args.kwargs["modalities"] == ["image"]
async def test_generate_image_no_images_in_response(client):
"""502 when model returns no images."""
token = await _user_token(client)
empty_response = {
"id": "gen-empty",
"model": "google/gemini-2.5-flash-image",
"choices": [{"message": {"role": "assistant", "content": "ok", "images": []}}],
}
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=empty_response):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 502
assert "No images returned" in resp.json()["detail"]
async def test_generate_image_flux(client):
"""Flux model works correctly via chat completions."""
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
resp = await client.post(
"/generate/image",
json={"model": "black-forest-labs/flux.2-klein-4b",
"prompt": "A sunset"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
async def test_generate_image_stored_in_db(client):
"""Generated image row persists in generated_images table."""
from app import db as db_module
token = await _user_token(client)
with patch("app.routers.generate.openrouter.generate_image_chat",
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
resp = await client.post(
"/generate/image",
json={"model": "google/gemini-2.5-flash-image",
"prompt": "A mountain"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
image_id = resp.json()["images"][0]["image_id"]
assert image_id is not None
row = db_module.get_conn().execute(
"SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
[image_id],
).fetchone()
assert row is not None
assert row[0] == "google/gemini-2.5-flash-image"
assert row[1] == "A mountain"
assert row[2] == "data:image/png;base64,gemini123"
# ---------------------------------------------------------------------------
+63 -14
@@ -15,6 +15,7 @@ from app.services.models import (
_normalize_modality,
_parse_modality,
get_cached_models,
get_model_output_modalities,
is_cache_stale,
refresh_models_cache,
)
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
"name": "GPT-4o",
"context_length": 128000,
"pricing": {"prompt": "0.000005"},
"architecture": {"modality": "text->text"},
"architecture": {"modality": "text->text", "output_modalities": ["text"]},
},
{
"id": "anthropic/claude-3-haiku",
"name": "Claude 3 Haiku",
"context_length": 200000,
"pricing": {},
"architecture": {"modality": "text+image->text"},
"architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
},
{
"id": "openai/dall-e-3",
"name": "DALL-E 3",
"context_length": None,
"pricing": {"image": "0.04"},
"architecture": {"modality": "text->image"},
"architecture": {"modality": "text->image", "output_modalities": ["image"]},
},
{
"id": "openai/sora-2",
"name": "Sora 2",
"context_length": None,
"pricing": {"video": "0.10"},
"architecture": {"modality": "text->video"},
"architecture": {"modality": "text->video", "output_modalities": ["video"]},
},
{
"id": "google/gemini-2.5-flash-image",
"name": "Gemini 2.5 Flash Image",
"context_length": None,
"pricing": {},
"architecture": {"output_modalities": ["image", "text"]},
},
]
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
return_value=FAKE_MODELS_RAW,
):
count = await refresh_models_cache(conn)
assert count == 4
assert count == 5
all_models = get_cached_models(conn)
assert len(all_models) == 4
assert len(all_models) == 5
async def test_refresh_replaces_old_cache():
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
ids = [m["id"] for m in get_cached_models(conn)]
assert "old/model" not in ids
assert "openai/gpt-4o" in ids
assert len(ids) == 5
def test_get_cached_models_filter_by_modality():
conn = db_module.get_conn()
now = datetime.now(timezone.utc).replace(tzinfo=None)
for m in FAKE_MODELS_RAW:
arch = m.get("architecture", {})
modality = _parse_modality(arch.get("modality", "text->text"))
modality = _extract_output_modality(m)
conn.execute(
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
[m["id"], m["name"], modality, now],
)
text_models = get_cached_models(conn, modality="text")
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
assert len(text_models) == 2
assert all(m["modality"] == "text" for m in text_models)
image_models = get_cached_models(conn, modality="image")
assert len(image_models) == 1
assert image_models[0]["id"] == "openai/dall-e-3"
# dall-e-3 + gemini (output_modalities starts with image)
assert len(image_models) == 2
image_ids = [m["id"] for m in image_models]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
video_models = get_cached_models(conn, modality="video")
assert len(video_models) == 1
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
"/models/", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert len(resp.json()) == 4
assert len(resp.json()) == 5
assert mock_fetch.await_count >= 1
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "openai/dall-e-3"
assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
image_ids = [m["id"] for m in data]
assert "openai/dall-e-3" in image_ids
assert "google/gemini-2.5-flash-image" in image_ids
# ---------------------------------------------------------------------------
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
assert resp.json()["refreshed"] == 4
assert resp.json()["refreshed"] == 5
async def test_refresh_endpoint_502_on_openrouter_error(client):
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# Unit tests: get_model_output_modalities
# ---------------------------------------------------------------------------
async def test_get_model_output_modalities_image_only():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
assert modalities == ["image"]
async def test_get_model_output_modalities_image_text():
conn = db_module.get_conn()
with patch(
"app.services.models.openrouter.list_models",
new_callable=AsyncMock,
return_value=FAKE_MODELS_RAW,
):
await refresh_models_cache(conn)
modalities = get_model_output_modalities(
conn, "google/gemini-2.5-flash-image")
assert set(modalities) == {"image", "text"}
def test_get_model_output_modalities_unknown_model():
conn = db_module.get_conn()
result = get_model_output_modalities(conn, "unknown/model")
assert result == []
+3 -1
@@ -115,7 +115,9 @@ async def test_list_users_as_admin(client):
resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"})
assert resp.status_code == 200
assert isinstance(resp.json(), list)
assert len(resp.json()) == 1
assert len(resp.json()) >= 1
emails = [u["email"] for u in resp.json()]
assert "user@example.com" in emails
async def test_list_users_as_regular_user(client):
+7 -1
@@ -178,7 +178,13 @@ def dashboard():
user = resp.json() if resp.status_code == 200 else {}
img_resp = _api("GET", "/images/", token=token)
images = img_resp.json() if img_resp.status_code == 200 else []
return render_template("dashboard.html", user=user, images=images)
gen_resp = _api("GET", "/generate/images", token=token)
generated_images = gen_resp.json() if gen_resp.status_code == 200 else []
vid_resp = _api("GET", "/generate/videos", token=token)
generated_videos = vid_resp.json() if vid_resp.status_code == 200 else []
return render_template("dashboard.html", user=user, images=images,
generated_images=generated_images,
generated_videos=generated_videos)
# ── Generate ──────────────────────────────────────────────────────────────
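
Each of the dashboard's auxiliary fetches degrades to an empty list on any non-200, so a partial API outage still renders the page. A stand-in sketch of that pattern (the real code goes through the `_api` helper; the direct `httpx` usage here is illustrative):

```python
# Tolerant fetch: any non-200 or transport error renders as "no items"
# instead of breaking the dashboard.
import httpx

def fetch_or_empty(url: str, token: str) -> list:
    try:
        resp = httpx.get(url, headers={"Authorization": f"Bearer {token}"},
                         timeout=10)
    except httpx.HTTPError:
        return []
    return resp.json() if resp.status_code == 200 else []

# Hypothetical usage, assuming API holds the backend base URL:
# generated_images = fetch_or_empty(f"{API}/generate/images", token)
# generated_videos = fetch_or_empty(f"{API}/generate/videos", token)
```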
+53 -1
@@ -6,7 +6,59 @@ endblock %} {% block content %}
<a href="{{ url_for('generate') }}" class="btn">Start generating</a>
</div>
{% if images %}
{% if generated_images %}
<div class="card mt-2">
<h2>Generated images</h2>
<div class="image-grid">
{% for img in generated_images %}
<div class="image-grid-item">
<img
src="{{ img.image_data }}"
alt="{{ img.prompt }}"
class="generated-image"
loading="lazy"
/>
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
<strong>{{ img.model_id }}</strong><br />{{ img.prompt[:80] }}{% if
img.prompt|length > 80 %}…{% endif %}
</p>
</div>
{% endfor %}
</div>
</div>
{% endif %} {% if generated_videos %}
<div class="card mt-2">
<h2>Generated videos</h2>
<div class="image-grid">
{% for vid in generated_videos %}
<div class="image-grid-item">
{% if vid.video_url %}
<video controls style="max-width: 100%; border-radius: 6px">
<source src="{{ vid.video_url }}" />
Your browser does not support the video tag.
</video>
{% else %}
<div
style="
background: #1a1a1a;
border-radius: 6px;
padding: 2rem;
text-align: center;
"
>
<span class="text-muted">{{ vid.status | capitalize }} &hellip;</span>
</div>
{% endif %}
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
<strong>{{ vid.model_id }}</strong><br />{{ vid.prompt[:80] }}{% if
vid.prompt|length > 80 %}…{% endif %}<br />
<em>{{ vid.status }}</em>
</p>
</div>
{% endfor %}
</div>
</div>
{% endif %} {% if images %}
<div class="card mt-2">
<h2>Uploaded reference images</h2>
<div class="image-grid">
+19 -32
@@ -8,12 +8,12 @@
{% if models %}
<select id="model" name="model" required>
{% for m in models %}
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
<option value="{{ m.id }}" {{ "selected" if request.form.get('model', '') == m.id else "" }}>{{ m.name }}</option>
{% endfor %}
</select>
{% else %}
<input id="model" name="model" type="text" required
placeholder="e.g. openai/dall-e-3"
placeholder="e.g. google/gemini-2.5-flash-image"
value="{{ request.form.get('model', '') }}">
{% endif %}
@@ -21,40 +21,25 @@
<textarea id="prompt" name="prompt" rows="4" required
placeholder="Describe the image you want…">{{ request.form.get('prompt', '') }}</textarea>
<label for="size">Size</label>
<select id="size" name="size">
<option value="1024x1024" {% if request.form.get('size','1024x1024')=='1024x1024' %}selected{% endif %}>1024×1024</option>
<option value="1792x1024" {% if request.form.get('size')=='1792x1024' %}selected{% endif %}>1792×1024 (landscape)</option>
<option value="1024x1792" {% if request.form.get('size')=='1024x1792' %}selected{% endif %}>1024×1792 (portrait)</option>
<option value="512x512" {% if request.form.get('size')=='512x512' %}selected{% endif %}>512×512</option>
</select>
<label for="aspect_ratio">Aspect ratio</label>
<select id="aspect_ratio" name="aspect_ratio">
<option value="">Auto (default)</option>
<option value="1:1" {% if request.form.get('aspect_ratio')=='1:1' %}selected{% endif %}>1:1 (square)</option>
<option value="16:9" {% if request.form.get('aspect_ratio')=='16:9' %}selected{% endif %}>16:9 (landscape)</option>
<option value="9:16" {% if request.form.get('aspect_ratio')=='9:16' %}selected{% endif %}>9:16 (portrait)</option>
<option value="4:3" {% if request.form.get('aspect_ratio')=='4:3' %}selected{% endif %}>4:3</option>
<option value="3:4" {% if request.form.get('aspect_ratio')=='3:4' %}selected{% endif %}>3:4</option>
<option value="3:2" {% if request.form.get('aspect_ratio')=='3:2' %}selected{% endif %}>3:2</option>
<option value="2:3" {% if request.form.get('aspect_ratio')=='2:3' %}selected{% endif %}>2:3</option>
<option value="">Auto (default 1:1)</option>
<option value="1:1" {{ "selected" if request.form.get('aspect_ratio')=='1:1' else "" }}>1:1 (square)</option>
<option value="16:9" {{ "selected" if request.form.get('aspect_ratio')=='16:9' else "" }}>16:9 (landscape)</option>
<option value="9:16" {{ "selected" if request.form.get('aspect_ratio')=='9:16' else "" }}>9:16 (portrait)</option>
<option value="4:3" {{ "selected" if request.form.get('aspect_ratio')=='4:3' else "" }}>4:3</option>
<option value="3:4" {{ "selected" if request.form.get('aspect_ratio')=='3:4' else "" }}>3:4</option>
<option value="3:2" {{ "selected" if request.form.get('aspect_ratio')=='3:2' else "" }}>3:2</option>
<option value="2:3" {{ "selected" if request.form.get('aspect_ratio')=='2:3' else "" }}>2:3</option>
</select>
<label for="image_size">Resolution</label>
<select id="image_size" name="image_size">
<option value="">Auto (default)</option>
<option value="0.5K" {% if request.form.get('image_size')=='0.5K' %}selected{% endif %}>0.5K (low)</option>
<option value="1K" {% if request.form.get('image_size')=='1K' %}selected{% endif %}>1K (standard)</option>
<option value="2K" {% if request.form.get('image_size')=='2K' %}selected{% endif %}>2K (high)</option>
<option value="4K" {% if request.form.get('image_size')=='4K' %}selected{% endif %}>4K (ultra)</option>
</select>
<label for="n">Number of images</label>
<select id="n" name="n">
<option value="1" {% if request.form.get('n','1')=='1' %}selected{% endif %}>1</option>
<option value="2" {% if request.form.get('n')=='2' %}selected{% endif %}>2</option>
<option value="4" {% if request.form.get('n')=='4' %}selected{% endif %}>4</option>
<option value="">Auto (default 1K)</option>
<option value="0.5K" {{ "selected" if request.form.get('image_size')=='0.5K' else "" }}>0.5K (low)</option>
<option value="1K" {{ "selected" if request.form.get('image_size')=='1K' else "" }}>1K (standard)</option>
<option value="2K" {{ "selected" if request.form.get('image_size')=='2K' else "" }}>2K (high)</option>
<option value="4K" {{ "selected" if request.form.get('image_size')=='4K' else "" }}>4K (ultra)</option>
</select>
<label for="reference_image">Reference image (optional)</label>
@@ -65,7 +50,7 @@
accept="image/png,image/jpeg,image/webp,image/gif"
>
<p class="text-muted mt-1" id="reference-image-help">
Upload image for visual reference in upcoming image-to-image flow.
Upload an image to use as visual reference (image-to-image).
</p>
<div class="image-upload-preview" id="image-upload-preview" hidden>
<p class="text-muted" id="image-upload-filename"></p>
@@ -83,7 +68,9 @@
<div class="result">
<h2>Generated image{{ 's' if result.images|length > 1 }}</h2>
{% for img in result.images %}
<img src="{{ img.url }}" alt="Generated image" class="generated-image">
{% if img.url %}
<img src="{{ img.url }}" alt="Generated image" class="generated-image">
{% endif %}
{% if img.revised_prompt %}
<p class="text-muted mt-1" style="font-size:0.8rem;">{{ img.revised_prompt }}</p>
{% endif %}
+28 -3
@@ -151,7 +151,8 @@ def test_dashboard_renders_user_info(client):
me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard")
assert resp.status_code == 200
assert b"u@example.com" in resp.data
@@ -531,19 +532,43 @@ def test_dashboard_shows_uploaded_images(client):
{"id": "img-1", "filename": "cat.png", "content_type": "image/png",
"size_bytes": 1024, "created_at": "2026-04-29T10:00:00"},
])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard")
assert resp.status_code == 200
assert b"cat.png" in resp.data
assert b"img-1" in resp.data
def test_dashboard_shows_generated_images(client):
_set_auth(client)
me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, [])
gen_images_mock = _mock_response(200, [
{
"id": "gen-1",
"model_id": "google/gemini-2.5-flash-image",
"prompt": "A cat on the moon",
"image_data": "data:image/png;base64,abc123",
"created_at": "2026-04-29T10:00:00",
}
])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard")
assert resp.status_code == 200
assert b"Generated images" in resp.data
assert b"A cat on the moon" in resp.data
assert b"data:image/png;base64,abc123" in resp.data
def test_dashboard_no_images_section_when_empty(client):
_set_auth(client)
me_mock = _mock_response(
200, {"id": "1", "email": "u@example.com", "role": "user"})
images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
gen_images_mock = _mock_response(200, [])
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
resp = client.get("/dashboard")
assert resp.status_code == 200
assert b"Uploaded reference images" not in resp.data