feat: enhance model caching and output modalities handling
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
|
||||
fetched_at TIMESTAMP NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS generated_images (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
user_id UUID NOT NULL,
|
||||
model_id VARCHAR NOT NULL,
|
||||
prompt VARCHAR NOT NULL,
|
||||
image_data VARCHAR NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS generated_videos (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
user_id UUID NOT NULL,
|
||||
job_id VARCHAR NOT NULL,
|
||||
model_id VARCHAR NOT NULL,
|
||||
prompt VARCHAR NOT NULL,
|
||||
polling_url VARCHAR,
|
||||
status VARCHAR NOT NULL DEFAULT 'pending',
|
||||
video_url VARCHAR,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
# Migration: add output_modalities column if absent (stores JSON array string)
|
||||
conn.execute("""
|
||||
ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR
|
||||
""")
|
||||
_seed_admin(conn)
|
||||
|
||||
|
||||
|
||||
+14
-14
@@ -1,10 +1,10 @@
|
||||
from .routers import auth as auth_router
|
||||
from .routers import users as users_router
|
||||
from .routers import admin as admin_router
|
||||
from .routers import ai as ai_router
|
||||
from .routers import generate as generate_router
|
||||
from .routers import images as images_router
|
||||
from .routers import models as models_router
|
||||
from .routers import auth
|
||||
from .routers import users
|
||||
from .routers import admin
|
||||
from .routers import ai
|
||||
from .routers import generate
|
||||
from .routers import images
|
||||
from .routers import models
|
||||
from .db import close_db, init_db
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -38,13 +38,13 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth_router.router)
|
||||
app.include_router(users_router.router)
|
||||
app.include_router(admin_router.router)
|
||||
app.include_router(ai_router.router)
|
||||
app.include_router(generate_router.router)
|
||||
app.include_router(images_router.router)
|
||||
app.include_router(models_router.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(users.router)
|
||||
app.include_router(admin.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(generate.router)
|
||||
app.include_router(images.router)
|
||||
app.include_router(models.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
|
||||
@@ -62,6 +62,7 @@ class ImageResult(BaseModel):
|
||||
url: str | None = None
|
||||
b64_json: str | None = None
|
||||
revised_prompt: str | None = None
|
||||
image_id: str | None = None # UUID of stored row in generated_images
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
|
||||
+172
-48
@@ -1,6 +1,9 @@
|
||||
"""Generate router: text, image, video, and image-to-video generation."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from ..db import get_conn, get_write_lock
|
||||
from ..dependencies import get_current_user
|
||||
from ..models.ai import (
|
||||
ImageRequest,
|
||||
@@ -13,6 +16,7 @@ from ..models.ai import (
|
||||
VideoResponse,
|
||||
)
|
||||
from ..services import openrouter
|
||||
from ..services.models import get_model_output_modalities
|
||||
|
||||
router = APIRouter(prefix="/generate", tags=["generate"])
|
||||
|
||||
@@ -62,81 +66,129 @@ async def generate_text(
|
||||
@router.post("/image", response_model=ImageResponse)
|
||||
async def generate_image(
|
||||
body: ImageRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> ImageResponse:
|
||||
"""Generate images from a text prompt."""
|
||||
# Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E)
|
||||
chat_models = {"black-forest-labs/flux.2-klein-4b",
|
||||
"openai/gpt-5-image-mini"}
|
||||
is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \
|
||||
any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"])
|
||||
"""Generate images from a prompt using the chat completions endpoint.
|
||||
|
||||
All OpenRouter image models use /chat/completions with a modalities param.
|
||||
Models that output only images use ["image"]; those that also output text
|
||||
use ["image", "text"]. We look this up from the model cache; default to
|
||||
["image", "text"] when the model is not yet cached.
|
||||
"""
|
||||
# Determine modalities from cache; default ["image", "text"] works for most models
|
||||
try:
|
||||
if is_chat_model:
|
||||
image_config = {}
|
||||
conn = get_conn()
|
||||
cached_modalities = get_model_output_modalities(conn, body.model)
|
||||
except Exception:
|
||||
cached_modalities = []
|
||||
|
||||
if cached_modalities:
|
||||
# If cache says model only outputs image (no text), use ["image"]
|
||||
modalities = ["image"] if set(cached_modalities) == {
|
||||
"image"} else ["image", "text"]
|
||||
else:
|
||||
# Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
|
||||
# For image-only models that fail with this, the error surfaces to the user.
|
||||
modalities = ["image", "text"]
|
||||
|
||||
image_config: dict = {}
|
||||
if body.aspect_ratio:
|
||||
image_config["aspect_ratio"] = body.aspect_ratio
|
||||
if body.image_size:
|
||||
image_config["image_size"] = body.image_size
|
||||
|
||||
try:
|
||||
result = await openrouter.generate_image_chat(
|
||||
model=body.model,
|
||||
prompt=body.prompt,
|
||||
modalities=[
|
||||
"image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"],
|
||||
modalities=modalities,
|
||||
image_config=image_config if image_config else None,
|
||||
)
|
||||
else:
|
||||
result = await openrouter.generate_image(
|
||||
model=body.model,
|
||||
prompt=body.prompt,
|
||||
n=body.n,
|
||||
size=body.size,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
try:
|
||||
if is_chat_model:
|
||||
# Chat completions response: choices[0].message.images[].image_url.url
|
||||
images = []
|
||||
message = result.get("choices", [{}])[0].get("message", {})
|
||||
images = []
|
||||
for item in message.get("images", []):
|
||||
img_url = item.get("image_url", {}).get("url")
|
||||
images.append(ImageResult(
|
||||
url=img_url,
|
||||
b64_json=None,
|
||||
revised_prompt=message.get("content"),
|
||||
revised_prompt=message.get("content") or None,
|
||||
))
|
||||
return ImageResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", body.model),
|
||||
images=images,
|
||||
if not images:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="No images returned by model. Verify the model supports image generation.",
|
||||
)
|
||||
|
||||
# Persist each image to DB
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
stored: list[ImageResult] = []
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
for img in images:
|
||||
if img.url:
|
||||
row = conn.execute(
|
||||
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING id""",
|
||||
[user_id, body.model, body.prompt, img.url, now],
|
||||
).fetchone()
|
||||
image_id = str(row[0]) if row else None
|
||||
else:
|
||||
# /images/generations response: data[].url
|
||||
images = [
|
||||
ImageResult(
|
||||
url=item.get("url"),
|
||||
b64_json=item.get("b64_json"),
|
||||
revised_prompt=item.get("revised_prompt"),
|
||||
)
|
||||
for item in result.get("data", [])
|
||||
]
|
||||
image_id = None
|
||||
stored.append(ImageResult(
|
||||
url=img.url,
|
||||
b64_json=img.b64_json,
|
||||
revised_prompt=img.revised_prompt,
|
||||
image_id=image_id,
|
||||
))
|
||||
|
||||
return ImageResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", body.model),
|
||||
images=images,
|
||||
images=stored,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except (KeyError, TypeError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Unexpected response format: {exc}")
|
||||
|
||||
|
||||
@router.get("/images")
|
||||
async def list_generated_images(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Return all generated images for the current user, newest first."""
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT id, model_id, prompt, image_data, created_at
|
||||
FROM generated_images
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC""",
|
||||
[user_id],
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r[0]),
|
||||
"model_id": r[1],
|
||||
"prompt": r[2],
|
||||
"image_data": r[3],
|
||||
"created_at": r[4].isoformat() if r[4] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/video", response_model=VideoResponse)
|
||||
async def generate_video(
|
||||
body: VideoRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Generate a video from a text prompt."""
|
||||
try:
|
||||
@@ -151,12 +203,26 @@ async def generate_video(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
job_id = result.get("id", "")
|
||||
polling_url = result.get("polling_url")
|
||||
job_status = result.get("status", "pending")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=result.get("status", "queued"),
|
||||
polling_url=result.get("polling_url"),
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
@@ -167,7 +233,7 @@ async def generate_video(
|
||||
@router.post("/video/from-image", response_model=VideoResponse)
|
||||
async def generate_video_from_image(
|
||||
body: VideoFromImageRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Generate a video from an image and a text prompt."""
|
||||
try:
|
||||
@@ -183,12 +249,26 @@ async def generate_video_from_image(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
job_id = result.get("id", "")
|
||||
polling_url = result.get("polling_url")
|
||||
job_status = result.get("status", "pending")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=result.get("status", "queued"),
|
||||
polling_url=result.get("polling_url"),
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
@@ -199,23 +279,67 @@ async def generate_video_from_image(
|
||||
@router.get("/video/status", response_model=VideoResponse)
|
||||
async def poll_video_status(
|
||||
polling_url: str,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Poll the status of a video generation job via its polling_url."""
|
||||
"""Poll status of a video generation job; updates DB row when completed/failed."""
|
||||
try:
|
||||
result = await openrouter.poll_video_status(polling_url)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
job_status = result.get("status", "processing")
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
video_url = (urls or [None])[0]
|
||||
|
||||
# Update DB row for this job when terminal state reached
|
||||
if job_status in ("completed", "failed"):
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""UPDATE generated_videos
|
||||
SET status = ?, video_url = ?, updated_at = ?
|
||||
WHERE job_id = ?""",
|
||||
[job_status, video_url, now, result.get("id", "")],
|
||||
)
|
||||
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", ""),
|
||||
status=result.get("status", "processing"),
|
||||
status=job_status,
|
||||
polling_url=result.get("polling_url"),
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
video_url=video_url,
|
||||
error=result.get("error"),
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/videos")
|
||||
async def list_generated_videos(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Return all generated video jobs for the current user, newest first."""
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
|
||||
FROM generated_videos
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC""",
|
||||
[user_id],
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r[0]),
|
||||
"job_id": r[1],
|
||||
"model_id": r[2],
|
||||
"prompt": r[3],
|
||||
"polling_url": r[4],
|
||||
"status": r[5],
|
||||
"video_url": r[6],
|
||||
"created_at": r[7].isoformat() if r[7] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
@@ -91,16 +91,28 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
||||
model_id = m.get("id", "")
|
||||
if not model_id:
|
||||
continue
|
||||
# Full output_modalities array from architecture (for proper modalities param in image gen)
|
||||
architecture = m.get("architecture") or {}
|
||||
raw_output_modalities: list | None = (
|
||||
architecture.get("output_modalities") or m.get("output_modalities")
|
||||
)
|
||||
output_modalities_json: str | None = (
|
||||
json.dumps([_normalize_modality(str(v))
|
||||
for v in raw_output_modalities])
|
||||
if isinstance(raw_output_modalities, list)
|
||||
else None
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (model_id) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
modality = excluded.modality,
|
||||
context_length = excluded.context_length,
|
||||
pricing = excluded.pricing,
|
||||
fetched_at = excluded.fetched_at
|
||||
fetched_at = excluded.fetched_at,
|
||||
output_modalities = excluded.output_modalities
|
||||
""",
|
||||
[
|
||||
model_id,
|
||||
@@ -109,6 +121,7 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
||||
m.get("context_length"),
|
||||
json.dumps(pricing) if pricing else None,
|
||||
now,
|
||||
output_modalities_json,
|
||||
],
|
||||
)
|
||||
count += 1
|
||||
@@ -168,3 +181,20 @@ def get_cached_models(
|
||||
"pricing": pricing,
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def get_model_output_modalities(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
model_id: str,
|
||||
) -> list[str]:
|
||||
"""Return output_modalities list for a model; empty list if not found."""
|
||||
row = conn.execute(
|
||||
"SELECT output_modalities FROM models_cache WHERE model_id = ?",
|
||||
[model_id],
|
||||
).fetchone()
|
||||
if not row or not row[0]:
|
||||
return []
|
||||
try:
|
||||
return json.loads(row[0])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@@ -95,8 +95,9 @@ async def generate_video(
|
||||
duration_seconds: int | None = None,
|
||||
aspect_ratio: str = "16:9",
|
||||
resolution: str | None = None,
|
||||
generate_audio: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Request text-to-video generation via OpenRouter."""
|
||||
"""Request text-to-video generation via OpenRouter POST /videos."""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
@@ -104,9 +105,12 @@ async def generate_video(
|
||||
"aspect_ratio": aspect_ratio,
|
||||
}
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
# API uses 'duration' not 'duration_seconds'
|
||||
payload["duration"] = duration_seconds
|
||||
if resolution is not None:
|
||||
payload["resolution"] = resolution
|
||||
if generate_audio is not None:
|
||||
payload["generate_audio"] = generate_audio
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = client.build_request(
|
||||
"POST", f"{base_url}/videos", headers=_headers(), json=payload
|
||||
@@ -123,19 +127,31 @@ async def generate_video_from_image(
|
||||
duration_seconds: int | None = None,
|
||||
aspect_ratio: str = "16:9",
|
||||
resolution: str | None = None,
|
||||
generate_audio: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Request image-to-video generation via OpenRouter."""
|
||||
"""Request image-to-video generation via OpenRouter POST /videos.
|
||||
|
||||
Uses frame_images array with first_frame as per OpenRouter API spec.
|
||||
"""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"image_url": image_url,
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"frame_images": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
"frame_type": "first_frame",
|
||||
}
|
||||
],
|
||||
}
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
payload["duration"] = duration_seconds
|
||||
if resolution is not None:
|
||||
payload["resolution"] = resolution
|
||||
if generate_audio is not None:
|
||||
payload["generate_audio"] = generate_audio
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = client.build_request(
|
||||
"POST", f"{base_url}/videos", headers=_headers(), json=payload
|
||||
@@ -154,6 +170,18 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]:
|
||||
return response.json()
|
||||
|
||||
|
||||
async def list_video_models() -> list[dict[str, Any]]:
|
||||
"""Return video generation models from the dedicated /videos/models endpoint."""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = client.build_request(
|
||||
"GET", f"{base_url}/videos/models", headers=_headers()
|
||||
)
|
||||
response = await client.send(resp)
|
||||
response.raise_for_status()
|
||||
return response.json().get("data", [])
|
||||
|
||||
|
||||
async def generate_image_chat(
|
||||
model: str,
|
||||
prompt: str,
|
||||
|
||||
@@ -53,7 +53,8 @@ async def test_stats_as_admin(client):
|
||||
resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["users"]["total"] == 3 # 2 users + 1 admin
|
||||
# 2 users + 1 admin + 1 seeded admin (ai@allucanget.biz)
|
||||
assert data["users"]["total"] == 4
|
||||
assert "by_role" in data["users"]
|
||||
assert "refresh_tokens" in data
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ async def _user_token(client):
|
||||
async def test_list_models(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
"app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS,
|
||||
):
|
||||
@@ -74,7 +74,7 @@ async def test_list_models_unauthenticated(client):
|
||||
async def test_list_models_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
"app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Connection refused"),
|
||||
):
|
||||
@@ -91,7 +91,7 @@ async def test_list_models_upstream_error(client):
|
||||
async def test_chat_success(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_CHAT_RESPONSE,
|
||||
):
|
||||
@@ -115,7 +115,7 @@ async def test_chat_success(client):
|
||||
async def test_chat_passes_parameters(client):
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
|
||||
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||
with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||
await client.post(
|
||||
"/ai/chat",
|
||||
json={
|
||||
@@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client):
|
||||
async def test_chat_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("timeout"),
|
||||
):
|
||||
@@ -160,7 +160,7 @@ async def test_chat_upstream_error(client):
|
||||
async def test_chat_malformed_upstream_response(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"id": "x", "choices": []}, # empty choices
|
||||
):
|
||||
|
||||
+135
-69
@@ -18,15 +18,6 @@ FAKE_CHAT = {
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
|
||||
}
|
||||
|
||||
FAKE_IMAGE = {
|
||||
"id": "gen-img-1",
|
||||
"model": "openai/dall-e-3",
|
||||
"data": [
|
||||
{"url": "https://example.com/image.png",
|
||||
"revised_prompt": "A cat on the moon"},
|
||||
],
|
||||
}
|
||||
|
||||
FAKE_VIDEO = {
|
||||
"id": "gen-vid-1",
|
||||
"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
|
||||
@@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client):
|
||||
# POST /generate/image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_image(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-img-1"
|
||||
assert len(data["images"]) == 1
|
||||
assert data["images"][0]["url"] == "https://example.com/image.png"
|
||||
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
|
||||
|
||||
|
||||
async def test_generate_image_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# --- Chat-based image generation (FLUX, GPT-5 Image Mini) ---
|
||||
|
||||
FAKE_IMAGE_CHAT_FLUX = {
|
||||
"id": "gen-img-chat-1",
|
||||
"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Here is your generated image.",
|
||||
"content": None,
|
||||
"images": [{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,abc123"},
|
||||
@@ -219,45 +176,65 @@ FAKE_IMAGE_CHAT_GPT5 = {
|
||||
}],
|
||||
}
|
||||
|
||||
FAKE_IMAGE_CHAT_GEMINI = {
|
||||
"id": "gen-img-chat-3",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Here is your image.",
|
||||
"images": [{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,gemini123"},
|
||||
}],
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
||||
async def test_generate_image_chat_flux(client):
|
||||
|
||||
async def test_generate_image(client):
|
||||
"""All models now use generate_image_chat (chat completions endpoint)."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"prompt": "A sunset"},
|
||||
json={"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A cat on the moon"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-img-chat-1"
|
||||
assert data["id"] == "gen-img-chat-3"
|
||||
assert len(data["images"]) == 1
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,gemini123"
|
||||
assert data["images"][0]["image_id"] is not None # stored in DB
|
||||
|
||||
|
||||
async def test_generate_image_chat_gpt5_image_mini(client):
|
||||
async def test_generate_image_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5):
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"},
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["model"] == "openai/gpt-5-image-mini"
|
||||
assert len(data["images"]) == 1
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
async def test_generate_image_chat_with_image_config(client):
|
||||
async def test_generate_image_with_image_config(client):
|
||||
"""Passes aspect_ratio + image_size through to generate_image_chat."""
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
await client.post(
|
||||
"/generate/image",
|
||||
json={
|
||||
"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A landscape",
|
||||
"aspect_ratio": "16:9",
|
||||
"image_size": "2K",
|
||||
@@ -267,23 +244,112 @@ async def test_generate_image_chat_with_image_config(client):
|
||||
call_kwargs = mock.call_args.kwargs
|
||||
assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
|
||||
assert call_kwargs["image_config"]["image_size"] == "2K"
|
||||
assert call_kwargs["modalities"] == ["image"]
|
||||
|
||||
|
||||
async def test_generate_image_chat_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_chat_upstream_error(client):
|
||||
async def test_generate_image_default_modalities_image_text(client):
|
||||
"""Model not in cache → default modalities = ['image', 'text']."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")):
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert mock.call_args.kwargs["modalities"] == ["image", "text"]
|
||||
|
||||
|
||||
async def test_generate_image_image_only_modalities_from_cache(client):
|
||||
"""Model cached with image-only output_modalities → modalities = ['image']."""
|
||||
from app import db as db_module
|
||||
from app.services.models import get_model_output_modalities
|
||||
import json as _json
|
||||
token = await _user_token(client)
|
||||
|
||||
# Seed cache with image-only model
|
||||
conn = db_module.get_conn()
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
conn.execute(
|
||||
"DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
|
||||
_json.dumps(["image"])],
|
||||
)
|
||||
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"},
|
||||
json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert mock.call_args.kwargs["modalities"] == ["image"]
|
||||
|
||||
|
||||
async def test_generate_image_no_images_in_response(client):
|
||||
"""502 when model returns no images."""
|
||||
token = await _user_token(client)
|
||||
empty_response = {
|
||||
"id": "gen-empty",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"choices": [{"message": {"role": "assistant", "content": "ok", "images": []}}],
|
||||
}
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=empty_response):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
assert "No images returned" in resp.json()["detail"]
|
||||
|
||||
|
||||
async def test_generate_image_flux(client):
|
||||
"""Flux model works correctly via chat completions."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"prompt": "A sunset"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
|
||||
|
||||
|
||||
async def test_generate_image_stored_in_db(client):
|
||||
"""Generated image row persists in generated_images table."""
|
||||
from app import db as db_module
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A mountain"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
image_id = resp.json()["images"][0]["image_id"]
|
||||
assert image_id is not None
|
||||
|
||||
row = db_module.get_conn().execute(
|
||||
"SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
|
||||
[image_id],
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "google/gemini-2.5-flash-image"
|
||||
assert row[1] == "A mountain"
|
||||
assert row[2] == "data:image/png;base64,gemini123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.services.models import (
|
||||
_normalize_modality,
|
||||
_parse_modality,
|
||||
get_cached_models,
|
||||
get_model_output_modalities,
|
||||
is_cache_stale,
|
||||
refresh_models_cache,
|
||||
)
|
||||
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
|
||||
"name": "GPT-4o",
|
||||
"context_length": 128000,
|
||||
"pricing": {"prompt": "0.000005"},
|
||||
"architecture": {"modality": "text->text"},
|
||||
"architecture": {"modality": "text->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-3-haiku",
|
||||
"name": "Claude 3 Haiku",
|
||||
"context_length": 200000,
|
||||
"pricing": {},
|
||||
"architecture": {"modality": "text+image->text"},
|
||||
"architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/dall-e-3",
|
||||
"name": "DALL-E 3",
|
||||
"context_length": None,
|
||||
"pricing": {"image": "0.04"},
|
||||
"architecture": {"modality": "text->image"},
|
||||
"architecture": {"modality": "text->image", "output_modalities": ["image"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/sora-2",
|
||||
"name": "Sora 2",
|
||||
"context_length": None,
|
||||
"pricing": {"video": "0.10"},
|
||||
"architecture": {"modality": "text->video"},
|
||||
"architecture": {"modality": "text->video", "output_modalities": ["video"]},
|
||||
},
|
||||
{
|
||||
"id": "google/gemini-2.5-flash-image",
|
||||
"name": "Gemini 2.5 Flash Image",
|
||||
"context_length": None,
|
||||
"pricing": {},
|
||||
"architecture": {"output_modalities": ["image", "text"]},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
count = await refresh_models_cache(conn)
|
||||
assert count == 4
|
||||
assert count == 5
|
||||
all_models = get_cached_models(conn)
|
||||
assert len(all_models) == 4
|
||||
assert len(all_models) == 5
|
||||
|
||||
|
||||
async def test_refresh_replaces_old_cache():
|
||||
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
|
||||
ids = [m["id"] for m in get_cached_models(conn)]
|
||||
assert "old/model" not in ids
|
||||
assert "openai/gpt-4o" in ids
|
||||
assert len(ids) == 5
|
||||
|
||||
|
||||
def test_get_cached_models_filter_by_modality():
|
||||
conn = db_module.get_conn()
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
for m in FAKE_MODELS_RAW:
|
||||
arch = m.get("architecture", {})
|
||||
modality = _parse_modality(arch.get("modality", "text->text"))
|
||||
modality = _extract_output_modality(m)
|
||||
conn.execute(
|
||||
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
|
||||
[m["id"], m["name"], modality, now],
|
||||
)
|
||||
text_models = get_cached_models(conn, modality="text")
|
||||
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
|
||||
assert len(text_models) == 2
|
||||
assert all(m["modality"] == "text" for m in text_models)
|
||||
|
||||
image_models = get_cached_models(conn, modality="image")
|
||||
assert len(image_models) == 1
|
||||
assert image_models[0]["id"] == "openai/dall-e-3"
|
||||
# dall-e-3 + gemini (output_modalities starts with image)
|
||||
assert len(image_models) == 2
|
||||
image_ids = [m["id"] for m in image_models]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
video_models = get_cached_models(conn, modality="video")
|
||||
assert len(video_models) == 1
|
||||
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
|
||||
"/models/", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 4
|
||||
assert len(resp.json()) == 5
|
||||
assert mock_fetch.await_count >= 1
|
||||
|
||||
|
||||
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "openai/dall-e-3"
|
||||
assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
|
||||
image_ids = [m["id"] for m in data]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["refreshed"] == 4
|
||||
assert resp.json()["refreshed"] == 5
|
||||
|
||||
|
||||
async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: get_model_output_modalities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_get_model_output_modalities_image_only():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
|
||||
assert modalities == ["image"]
|
||||
|
||||
|
||||
async def test_get_model_output_modalities_image_text():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(
|
||||
conn, "google/gemini-2.5-flash-image")
|
||||
assert set(modalities) == {"image", "text"}
|
||||
|
||||
|
||||
def test_get_model_output_modalities_unknown_model():
|
||||
conn = db_module.get_conn()
|
||||
result = get_model_output_modalities(conn, "unknown/model")
|
||||
assert result == []
|
||||
|
||||
@@ -115,7 +115,9 @@ async def test_list_users_as_admin(client):
|
||||
resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"})
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
assert len(resp.json()) == 1
|
||||
assert len(resp.json()) >= 1
|
||||
emails = [u["email"] for u in resp.json()]
|
||||
assert "user@example.com" in emails
|
||||
|
||||
|
||||
async def test_list_users_as_regular_user(client):
|
||||
|
||||
@@ -178,7 +178,13 @@ def dashboard():
|
||||
user = resp.json() if resp.status_code == 200 else {}
|
||||
img_resp = _api("GET", "/images/", token=token)
|
||||
images = img_resp.json() if img_resp.status_code == 200 else []
|
||||
return render_template("dashboard.html", user=user, images=images)
|
||||
gen_resp = _api("GET", "/generate/images", token=token)
|
||||
generated_images = gen_resp.json() if gen_resp.status_code == 200 else []
|
||||
vid_resp = _api("GET", "/generate/videos", token=token)
|
||||
generated_videos = vid_resp.json() if vid_resp.status_code == 200 else []
|
||||
return render_template("dashboard.html", user=user, images=images,
|
||||
generated_images=generated_images,
|
||||
generated_videos=generated_videos)
|
||||
|
||||
|
||||
# ── Generate ──────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -6,7 +6,59 @@ endblock %} {% block content %}
|
||||
<a href="{{ url_for('generate') }}" class="btn">Start generating</a>
|
||||
</div>
|
||||
|
||||
{% if images %}
|
||||
{% if generated_images %}
|
||||
<div class="card mt-2">
|
||||
<h2>Generated images</h2>
|
||||
<div class="image-grid">
|
||||
{% for img in generated_images %}
|
||||
<div class="image-grid-item">
|
||||
<img
|
||||
src="{{ img.image_data }}"
|
||||
alt="{{ img.prompt }}"
|
||||
class="generated-image"
|
||||
loading="lazy"
|
||||
/>
|
||||
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
|
||||
<strong>{{ img.model_id }}</strong><br />{{ img.prompt[:80] }}{% if
|
||||
img.prompt|length > 80 %}…{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %} {% if generated_videos %}
|
||||
<div class="card mt-2">
|
||||
<h2>Generated videos</h2>
|
||||
<div class="image-grid">
|
||||
{% for vid in generated_videos %}
|
||||
<div class="image-grid-item">
|
||||
{% if vid.video_url %}
|
||||
<video controls style="max-width: 100%; border-radius: 6px">
|
||||
<source src="{{ vid.video_url }}" />
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
{% else %}
|
||||
<div
|
||||
style="
|
||||
background: #1a1a1a;
|
||||
border-radius: 6px;
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
"
|
||||
>
|
||||
<span class="text-muted">{{ vid.status | capitalize }} …</span>
|
||||
</div>
|
||||
{% endif %}
|
||||
<p class="text-muted" style="font-size: 0.75rem; margin-top: 0.25rem">
|
||||
<strong>{{ vid.model_id }}</strong><br />{{ vid.prompt[:80] }}{% if
|
||||
vid.prompt|length > 80 %}…{% endif %}<br />
|
||||
<em>{{ vid.status }}</em>
|
||||
</p>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %} {% if images %}
|
||||
<div class="card mt-2">
|
||||
<h2>Uploaded reference images</h2>
|
||||
<div class="image-grid">
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
{% if models %}
|
||||
<select id="model" name="model" required>
|
||||
{% for m in models %}
|
||||
<option value="{{ m.id }}" {% if request.form.get('model', '') == m.id %}selected{% endif %}>{{ m.name }}</option>
|
||||
<option value="{{ m.id }}" {{ "selected" if request.form.get('model', '') == m.id else "" }}>{{ m.name }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
{% else %}
|
||||
<input id="model" name="model" type="text" required
|
||||
placeholder="e.g. openai/dall-e-3"
|
||||
placeholder="e.g. google/gemini-2.5-flash-image"
|
||||
value="{{ request.form.get('model', '') }}">
|
||||
{% endif %}
|
||||
|
||||
@@ -21,40 +21,25 @@
|
||||
<textarea id="prompt" name="prompt" rows="4" required
|
||||
placeholder="Describe the image you want…">{{ request.form.get('prompt', '') }}</textarea>
|
||||
|
||||
<label for="size">Size</label>
|
||||
<select id="size" name="size">
|
||||
<option value="1024x1024" {% if request.form.get('size','1024x1024')=='1024x1024' %}selected{% endif %}>1024×1024</option>
|
||||
<option value="1792x1024" {% if request.form.get('size')=='1792x1024' %}selected{% endif %}>1792×1024 (landscape)</option>
|
||||
<option value="1024x1792" {% if request.form.get('size')=='1024x1792' %}selected{% endif %}>1024×1792 (portrait)</option>
|
||||
<option value="512x512" {% if request.form.get('size')=='512x512' %}selected{% endif %}>512×512</option>
|
||||
</select>
|
||||
|
||||
<label for="aspect_ratio">Aspect ratio</label>
|
||||
<select id="aspect_ratio" name="aspect_ratio">
|
||||
<option value="">Auto (default)</option>
|
||||
<option value="1:1" {% if request.form.get('aspect_ratio')=='1:1' %}selected{% endif %}>1:1 (square)</option>
|
||||
<option value="16:9" {% if request.form.get('aspect_ratio')=='16:9' %}selected{% endif %}>16:9 (landscape)</option>
|
||||
<option value="9:16" {% if request.form.get('aspect_ratio')=='9:16' %}selected{% endif %}>9:16 (portrait)</option>
|
||||
<option value="4:3" {% if request.form.get('aspect_ratio')=='4:3' %}selected{% endif %}>4:3</option>
|
||||
<option value="3:4" {% if request.form.get('aspect_ratio')=='3:4' %}selected{% endif %}>3:4</option>
|
||||
<option value="3:2" {% if request.form.get('aspect_ratio')=='3:2' %}selected{% endif %}>3:2</option>
|
||||
<option value="2:3" {% if request.form.get('aspect_ratio')=='2:3' %}selected{% endif %}>2:3</option>
|
||||
<option value="">Auto (default 1:1)</option>
|
||||
<option value="1:1" {{ "selected" if request.form.get('aspect_ratio')=='1:1' else "" }}>1:1 (square)</option>
|
||||
<option value="16:9" {{ "selected" if request.form.get('aspect_ratio')=='16:9' else "" }}>16:9 (landscape)</option>
|
||||
<option value="9:16" {{ "selected" if request.form.get('aspect_ratio')=='9:16' else "" }}>9:16 (portrait)</option>
|
||||
<option value="4:3" {{ "selected" if request.form.get('aspect_ratio')=='4:3' else "" }}>4:3</option>
|
||||
<option value="3:4" {{ "selected" if request.form.get('aspect_ratio')=='3:4' else "" }}>3:4</option>
|
||||
<option value="3:2" {{ "selected" if request.form.get('aspect_ratio')=='3:2' else "" }}>3:2</option>
|
||||
<option value="2:3" {{ "selected" if request.form.get('aspect_ratio')=='2:3' else "" }}>2:3</option>
|
||||
</select>
|
||||
|
||||
<label for="image_size">Resolution</label>
|
||||
<select id="image_size" name="image_size">
|
||||
<option value="">Auto (default)</option>
|
||||
<option value="0.5K" {% if request.form.get('image_size')=='0.5K' %}selected{% endif %}>0.5K (low)</option>
|
||||
<option value="1K" {% if request.form.get('image_size')=='1K' %}selected{% endif %}>1K (standard)</option>
|
||||
<option value="2K" {% if request.form.get('image_size')=='2K' %}selected{% endif %}>2K (high)</option>
|
||||
<option value="4K" {% if request.form.get('image_size')=='4K' %}selected{% endif %}>4K (ultra)</option>
|
||||
</select>
|
||||
|
||||
<label for="n">Number of images</label>
|
||||
<select id="n" name="n">
|
||||
<option value="1" {% if request.form.get('n','1')=='1' %}selected{% endif %}>1</option>
|
||||
<option value="2" {% if request.form.get('n')=='2' %}selected{% endif %}>2</option>
|
||||
<option value="4" {% if request.form.get('n')=='4' %}selected{% endif %}>4</option>
|
||||
<option value="">Auto (default 1K)</option>
|
||||
<option value="0.5K" {{ "selected" if request.form.get('image_size')=='0.5K' else "" }}>0.5K (low)</option>
|
||||
<option value="1K" {{ "selected" if request.form.get('image_size')=='1K' else "" }}>1K (standard)</option>
|
||||
<option value="2K" {{ "selected" if request.form.get('image_size')=='2K' else "" }}>2K (high)</option>
|
||||
<option value="4K" {{ "selected" if request.form.get('image_size')=='4K' else "" }}>4K (ultra)</option>
|
||||
</select>
|
||||
|
||||
<label for="reference_image">Reference image (optional)</label>
|
||||
@@ -65,7 +50,7 @@
|
||||
accept="image/png,image/jpeg,image/webp,image/gif"
|
||||
>
|
||||
<p class="text-muted mt-1" id="reference-image-help">
|
||||
Upload image for visual reference in upcoming image-to-image flow.
|
||||
Upload an image to use as visual reference (image-to-image).
|
||||
</p>
|
||||
<div class="image-upload-preview" id="image-upload-preview" hidden>
|
||||
<p class="text-muted" id="image-upload-filename"></p>
|
||||
@@ -83,7 +68,9 @@
|
||||
<div class="result">
|
||||
<h2>Generated image{{ 's' if result.images|length > 1 }}</h2>
|
||||
{% for img in result.images %}
|
||||
{% if img.url %}
|
||||
<img src="{{ img.url }}" alt="Generated image" class="generated-image">
|
||||
{% endif %}
|
||||
{% if img.revised_prompt %}
|
||||
<p class="text-muted mt-1" style="font-size:0.8rem;">{{ img.revised_prompt }}</p>
|
||||
{% endif %}
|
||||
|
||||
@@ -151,7 +151,8 @@ def test_dashboard_renders_user_info(client):
|
||||
me_mock = _mock_response(
|
||||
200, {"id": "1", "email": "u@example.com", "role": "user"})
|
||||
images_mock = _mock_response(200, [])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
|
||||
gen_images_mock = _mock_response(200, [])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
|
||||
resp = client.get("/dashboard")
|
||||
assert resp.status_code == 200
|
||||
assert b"u@example.com" in resp.data
|
||||
@@ -531,19 +532,43 @@ def test_dashboard_shows_uploaded_images(client):
|
||||
{"id": "img-1", "filename": "cat.png", "content_type": "image/png",
|
||||
"size_bytes": 1024, "created_at": "2026-04-29T10:00:00"},
|
||||
])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
|
||||
gen_images_mock = _mock_response(200, [])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
|
||||
resp = client.get("/dashboard")
|
||||
assert resp.status_code == 200
|
||||
assert b"cat.png" in resp.data
|
||||
assert b"img-1" in resp.data
|
||||
|
||||
|
||||
def test_dashboard_shows_generated_images(client):
|
||||
_set_auth(client)
|
||||
me_mock = _mock_response(
|
||||
200, {"id": "1", "email": "u@example.com", "role": "user"})
|
||||
images_mock = _mock_response(200, [])
|
||||
gen_images_mock = _mock_response(200, [
|
||||
{
|
||||
"id": "gen-1",
|
||||
"model_id": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A cat on the moon",
|
||||
"image_data": "data:image/png;base64,abc123",
|
||||
"created_at": "2026-04-29T10:00:00",
|
||||
}
|
||||
])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
|
||||
resp = client.get("/dashboard")
|
||||
assert resp.status_code == 200
|
||||
assert b"Generated images" in resp.data
|
||||
assert b"A cat on the moon" in resp.data
|
||||
assert b"data:image/png;base64,abc123" in resp.data
|
||||
|
||||
|
||||
def test_dashboard_no_images_section_when_empty(client):
|
||||
_set_auth(client)
|
||||
me_mock = _mock_response(
|
||||
200, {"id": "1", "email": "u@example.com", "role": "user"})
|
||||
images_mock = _mock_response(200, [])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock]):
|
||||
gen_images_mock = _mock_response(200, [])
|
||||
with patch("frontend.app.main.httpx.request", side_effect=[me_mock, images_mock, gen_images_mock]):
|
||||
resp = client.get("/dashboard")
|
||||
assert resp.status_code == 200
|
||||
assert b"Uploaded reference images" not in resp.data
|
||||
|
||||
Reference in New Issue
Block a user