feat: enhance model caching and output modalities handling
- Updated `refresh_models_cache` to include output modalities in the models cache. - Added `get_model_output_modalities` function to retrieve output modalities for a specific model. - Modified tests to cover new functionality for output modalities. - Updated OpenRouter video generation functions to support audio generation and improved error handling. - Enhanced dashboard to display generated images and videos. - Refactored frontend templates to accommodate new data structures for generated content. - Adjusted tests to validate changes in model handling and dashboard rendering. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -86,6 +86,34 @@ def _run_migrations(conn: duckdb.DuckDBPyConnection) -> None:
|
||||
fetched_at TIMESTAMP NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS generated_images (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
user_id UUID NOT NULL,
|
||||
model_id VARCHAR NOT NULL,
|
||||
prompt VARCHAR NOT NULL,
|
||||
image_data VARCHAR NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS generated_videos (
|
||||
id UUID DEFAULT uuid() PRIMARY KEY,
|
||||
user_id UUID NOT NULL,
|
||||
job_id VARCHAR NOT NULL,
|
||||
model_id VARCHAR NOT NULL,
|
||||
prompt VARCHAR NOT NULL,
|
||||
polling_url VARCHAR,
|
||||
status VARCHAR NOT NULL DEFAULT 'pending',
|
||||
video_url VARCHAR,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
""")
|
||||
# Migration: add output_modalities column if absent (stores JSON array string)
|
||||
conn.execute("""
|
||||
ALTER TABLE models_cache ADD COLUMN IF NOT EXISTS output_modalities VARCHAR
|
||||
""")
|
||||
_seed_admin(conn)
|
||||
|
||||
|
||||
|
||||
+14
-14
@@ -1,10 +1,10 @@
|
||||
from .routers import auth as auth_router
|
||||
from .routers import users as users_router
|
||||
from .routers import admin as admin_router
|
||||
from .routers import ai as ai_router
|
||||
from .routers import generate as generate_router
|
||||
from .routers import images as images_router
|
||||
from .routers import models as models_router
|
||||
from .routers import auth
|
||||
from .routers import users
|
||||
from .routers import admin
|
||||
from .routers import ai
|
||||
from .routers import generate
|
||||
from .routers import images
|
||||
from .routers import models
|
||||
from .db import close_db, init_db
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -38,13 +38,13 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth_router.router)
|
||||
app.include_router(users_router.router)
|
||||
app.include_router(admin_router.router)
|
||||
app.include_router(ai_router.router)
|
||||
app.include_router(generate_router.router)
|
||||
app.include_router(images_router.router)
|
||||
app.include_router(models_router.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(users.router)
|
||||
app.include_router(admin.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(generate.router)
|
||||
app.include_router(images.router)
|
||||
app.include_router(models.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
|
||||
@@ -62,6 +62,7 @@ class ImageResult(BaseModel):
|
||||
url: str | None = None
|
||||
b64_json: str | None = None
|
||||
revised_prompt: str | None = None
|
||||
image_id: str | None = None # UUID of stored row in generated_images
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
|
||||
+193
-69
@@ -1,6 +1,9 @@
|
||||
"""Generate router: text, image, video, and image-to-video generation."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from ..db import get_conn, get_write_lock
|
||||
from ..dependencies import get_current_user
|
||||
from ..models.ai import (
|
||||
ImageRequest,
|
||||
@@ -13,6 +16,7 @@ from ..models.ai import (
|
||||
VideoResponse,
|
||||
)
|
||||
from ..services import openrouter
|
||||
from ..services.models import get_model_output_modalities
|
||||
|
||||
router = APIRouter(prefix="/generate", tags=["generate"])
|
||||
|
||||
@@ -62,81 +66,129 @@ async def generate_text(
|
||||
@router.post("/image", response_model=ImageResponse)
|
||||
async def generate_image(
|
||||
body: ImageRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> ImageResponse:
|
||||
"""Generate images from a text prompt."""
|
||||
# Detect if model uses chat completions (FLUX, GPT-5 Image Mini) vs /images/generations (DALL-E)
|
||||
chat_models = {"black-forest-labs/flux.2-klein-4b",
|
||||
"openai/gpt-5-image-mini"}
|
||||
is_chat_model = body.model.lower() in {m.lower() for m in chat_models} or \
|
||||
any(m in body.model.lower() for m in ["flux", "gpt-5-image-mini"])
|
||||
"""Generate images from a prompt using the chat completions endpoint.
|
||||
|
||||
All OpenRouter image models use /chat/completions with a modalities param.
|
||||
Models that output only images use ["image"]; those that also output text
|
||||
use ["image", "text"]. We look this up from the model cache; default to
|
||||
["image", "text"] when the model is not yet cached.
|
||||
"""
|
||||
# Determine modalities from cache; default ["image", "text"] works for most models
|
||||
try:
|
||||
conn = get_conn()
|
||||
cached_modalities = get_model_output_modalities(conn, body.model)
|
||||
except Exception:
|
||||
cached_modalities = []
|
||||
|
||||
if cached_modalities:
|
||||
# If cache says model only outputs image (no text), use ["image"]
|
||||
modalities = ["image"] if set(cached_modalities) == {
|
||||
"image"} else ["image", "text"]
|
||||
else:
|
||||
# Safe default: ["image", "text"]; works for Gemini, GPT-image etc.
|
||||
# For image-only models that fail with this, the error surfaces to the user.
|
||||
modalities = ["image", "text"]
|
||||
|
||||
image_config: dict = {}
|
||||
if body.aspect_ratio:
|
||||
image_config["aspect_ratio"] = body.aspect_ratio
|
||||
if body.image_size:
|
||||
image_config["image_size"] = body.image_size
|
||||
|
||||
try:
|
||||
if is_chat_model:
|
||||
image_config = {}
|
||||
if body.aspect_ratio:
|
||||
image_config["aspect_ratio"] = body.aspect_ratio
|
||||
if body.image_size:
|
||||
image_config["image_size"] = body.image_size
|
||||
result = await openrouter.generate_image_chat(
|
||||
model=body.model,
|
||||
prompt=body.prompt,
|
||||
modalities=[
|
||||
"image", "text"] if "gpt-5-image-mini" in body.model.lower() else ["image"],
|
||||
image_config=image_config if image_config else None,
|
||||
)
|
||||
else:
|
||||
result = await openrouter.generate_image(
|
||||
model=body.model,
|
||||
prompt=body.prompt,
|
||||
n=body.n,
|
||||
size=body.size,
|
||||
)
|
||||
result = await openrouter.generate_image_chat(
|
||||
model=body.model,
|
||||
prompt=body.prompt,
|
||||
modalities=modalities,
|
||||
image_config=image_config if image_config else None,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
try:
|
||||
if is_chat_model:
|
||||
# Chat completions response: choices[0].message.images[].image_url.url
|
||||
images = []
|
||||
message = result.get("choices", [{}])[0].get("message", {})
|
||||
for item in message.get("images", []):
|
||||
img_url = item.get("image_url", {}).get("url")
|
||||
images.append(ImageResult(
|
||||
url=img_url,
|
||||
b64_json=None,
|
||||
revised_prompt=message.get("content"),
|
||||
message = result.get("choices", [{}])[0].get("message", {})
|
||||
images = []
|
||||
for item in message.get("images", []):
|
||||
img_url = item.get("image_url", {}).get("url")
|
||||
images.append(ImageResult(
|
||||
url=img_url,
|
||||
b64_json=None,
|
||||
revised_prompt=message.get("content") or None,
|
||||
))
|
||||
if not images:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="No images returned by model. Verify the model supports image generation.",
|
||||
)
|
||||
|
||||
# Persist each image to DB
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
stored: list[ImageResult] = []
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
for img in images:
|
||||
if img.url:
|
||||
row = conn.execute(
|
||||
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING id""",
|
||||
[user_id, body.model, body.prompt, img.url, now],
|
||||
).fetchone()
|
||||
image_id = str(row[0]) if row else None
|
||||
else:
|
||||
image_id = None
|
||||
stored.append(ImageResult(
|
||||
url=img.url,
|
||||
b64_json=img.b64_json,
|
||||
revised_prompt=img.revised_prompt,
|
||||
image_id=image_id,
|
||||
))
|
||||
return ImageResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", body.model),
|
||||
images=images,
|
||||
)
|
||||
else:
|
||||
# /images/generations response: data[].url
|
||||
images = [
|
||||
ImageResult(
|
||||
url=item.get("url"),
|
||||
b64_json=item.get("b64_json"),
|
||||
revised_prompt=item.get("revised_prompt"),
|
||||
)
|
||||
for item in result.get("data", [])
|
||||
]
|
||||
return ImageResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", body.model),
|
||||
images=images,
|
||||
)
|
||||
|
||||
return ImageResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", body.model),
|
||||
images=stored,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except (KeyError, TypeError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Unexpected response format: {exc}")
|
||||
|
||||
|
||||
@router.get("/images")
|
||||
async def list_generated_images(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Return all generated images for the current user, newest first."""
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT id, model_id, prompt, image_data, created_at
|
||||
FROM generated_images
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC""",
|
||||
[user_id],
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r[0]),
|
||||
"model_id": r[1],
|
||||
"prompt": r[2],
|
||||
"image_data": r[3],
|
||||
"created_at": r[4].isoformat() if r[4] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/video", response_model=VideoResponse)
|
||||
async def generate_video(
|
||||
body: VideoRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Generate a video from a text prompt."""
|
||||
try:
|
||||
@@ -151,12 +203,26 @@ async def generate_video(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
job_id = result.get("id", "")
|
||||
polling_url = result.get("polling_url")
|
||||
job_status = result.get("status", "pending")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=result.get("status", "queued"),
|
||||
polling_url=result.get("polling_url"),
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
@@ -167,7 +233,7 @@ async def generate_video(
|
||||
@router.post("/video/from-image", response_model=VideoResponse)
|
||||
async def generate_video_from_image(
|
||||
body: VideoFromImageRequest,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Generate a video from an image and a text prompt."""
|
||||
try:
|
||||
@@ -183,12 +249,26 @@ async def generate_video_from_image(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
job_id = result.get("id", "")
|
||||
polling_url = result.get("polling_url")
|
||||
job_status = result.get("status", "pending")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=result.get("status", "queued"),
|
||||
polling_url=result.get("polling_url"),
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
@@ -199,23 +279,67 @@ async def generate_video_from_image(
|
||||
@router.get("/video/status", response_model=VideoResponse)
|
||||
async def poll_video_status(
|
||||
polling_url: str,
|
||||
_: dict = Depends(get_current_user),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> VideoResponse:
|
||||
"""Poll the status of a video generation job via its polling_url."""
|
||||
"""Poll status of a video generation job; updates DB row when completed/failed."""
|
||||
try:
|
||||
result = await openrouter.poll_video_status(polling_url)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=f"OpenRouter error: {exc}")
|
||||
|
||||
job_status = result.get("status", "processing")
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
video_url = (urls or [None])[0]
|
||||
|
||||
# Update DB row for this job when terminal state reached
|
||||
if job_status in ("completed", "failed"):
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""UPDATE generated_videos
|
||||
SET status = ?, video_url = ?, updated_at = ?
|
||||
WHERE job_id = ?""",
|
||||
[job_status, video_url, now, result.get("id", "")],
|
||||
)
|
||||
|
||||
return VideoResponse(
|
||||
id=result.get("id", ""),
|
||||
model=result.get("model", ""),
|
||||
status=result.get("status", "processing"),
|
||||
status=job_status,
|
||||
polling_url=result.get("polling_url"),
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
video_url=video_url,
|
||||
error=result.get("error"),
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/videos")
|
||||
async def list_generated_videos(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Return all generated video jobs for the current user, newest first."""
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT id, job_id, model_id, prompt, polling_url, status, video_url, created_at
|
||||
FROM generated_videos
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC""",
|
||||
[user_id],
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r[0]),
|
||||
"job_id": r[1],
|
||||
"model_id": r[2],
|
||||
"prompt": r[3],
|
||||
"polling_url": r[4],
|
||||
"status": r[5],
|
||||
"video_url": r[6],
|
||||
"created_at": r[7].isoformat() if r[7] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
@@ -91,16 +91,28 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
||||
model_id = m.get("id", "")
|
||||
if not model_id:
|
||||
continue
|
||||
# Full output_modalities array from architecture (for proper modalities param in image gen)
|
||||
architecture = m.get("architecture") or {}
|
||||
raw_output_modalities: list | None = (
|
||||
architecture.get("output_modalities") or m.get("output_modalities")
|
||||
)
|
||||
output_modalities_json: str | None = (
|
||||
json.dumps([_normalize_modality(str(v))
|
||||
for v in raw_output_modalities])
|
||||
if isinstance(raw_output_modalities, list)
|
||||
else None
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (model_id) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
modality = excluded.modality,
|
||||
context_length = excluded.context_length,
|
||||
pricing = excluded.pricing,
|
||||
fetched_at = excluded.fetched_at
|
||||
fetched_at = excluded.fetched_at,
|
||||
output_modalities = excluded.output_modalities
|
||||
""",
|
||||
[
|
||||
model_id,
|
||||
@@ -109,6 +121,7 @@ async def refresh_models_cache(conn: duckdb.DuckDBPyConnection) -> int:
|
||||
m.get("context_length"),
|
||||
json.dumps(pricing) if pricing else None,
|
||||
now,
|
||||
output_modalities_json,
|
||||
],
|
||||
)
|
||||
count += 1
|
||||
@@ -168,3 +181,20 @@ def get_cached_models(
|
||||
"pricing": pricing,
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def get_model_output_modalities(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
model_id: str,
|
||||
) -> list[str]:
|
||||
"""Return output_modalities list for a model; empty list if not found."""
|
||||
row = conn.execute(
|
||||
"SELECT output_modalities FROM models_cache WHERE model_id = ?",
|
||||
[model_id],
|
||||
).fetchone()
|
||||
if not row or not row[0]:
|
||||
return []
|
||||
try:
|
||||
return json.loads(row[0])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
@@ -95,8 +95,9 @@ async def generate_video(
|
||||
duration_seconds: int | None = None,
|
||||
aspect_ratio: str = "16:9",
|
||||
resolution: str | None = None,
|
||||
generate_audio: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Request text-to-video generation via OpenRouter."""
|
||||
"""Request text-to-video generation via OpenRouter POST /videos."""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
@@ -104,9 +105,12 @@ async def generate_video(
|
||||
"aspect_ratio": aspect_ratio,
|
||||
}
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
# API uses 'duration' not 'duration_seconds'
|
||||
payload["duration"] = duration_seconds
|
||||
if resolution is not None:
|
||||
payload["resolution"] = resolution
|
||||
if generate_audio is not None:
|
||||
payload["generate_audio"] = generate_audio
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = client.build_request(
|
||||
"POST", f"{base_url}/videos", headers=_headers(), json=payload
|
||||
@@ -123,19 +127,31 @@ async def generate_video_from_image(
|
||||
duration_seconds: int | None = None,
|
||||
aspect_ratio: str = "16:9",
|
||||
resolution: str | None = None,
|
||||
generate_audio: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Request image-to-video generation via OpenRouter."""
|
||||
"""Request image-to-video generation via OpenRouter POST /videos.
|
||||
|
||||
Uses frame_images array with first_frame as per OpenRouter API spec.
|
||||
"""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"image_url": image_url,
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"frame_images": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
"frame_type": "first_frame",
|
||||
}
|
||||
],
|
||||
}
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
payload["duration"] = duration_seconds
|
||||
if resolution is not None:
|
||||
payload["resolution"] = resolution
|
||||
if generate_audio is not None:
|
||||
payload["generate_audio"] = generate_audio
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = client.build_request(
|
||||
"POST", f"{base_url}/videos", headers=_headers(), json=payload
|
||||
@@ -154,6 +170,18 @@ async def poll_video_status(polling_url: str) -> dict[str, Any]:
|
||||
return response.json()
|
||||
|
||||
|
||||
async def list_video_models() -> list[dict[str, Any]]:
|
||||
"""Return video generation models from the dedicated /videos/models endpoint."""
|
||||
base_url = os.getenv("OPENROUTER_BASE_URL", OPENROUTER_BASE_URL)
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = client.build_request(
|
||||
"GET", f"{base_url}/videos/models", headers=_headers()
|
||||
)
|
||||
response = await client.send(resp)
|
||||
response.raise_for_status()
|
||||
return response.json().get("data", [])
|
||||
|
||||
|
||||
async def generate_image_chat(
|
||||
model: str,
|
||||
prompt: str,
|
||||
|
||||
@@ -53,7 +53,8 @@ async def test_stats_as_admin(client):
|
||||
resp = await client.get("/admin/stats", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["users"]["total"] == 3 # 2 users + 1 admin
|
||||
# 2 users + 1 admin + 1 seeded admin (ai@allucanget.biz)
|
||||
assert data["users"]["total"] == 4
|
||||
assert "by_role" in data["users"]
|
||||
assert "refresh_tokens" in data
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ async def _user_token(client):
|
||||
async def test_list_models(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
"app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS,
|
||||
):
|
||||
@@ -74,7 +74,7 @@ async def test_list_models_unauthenticated(client):
|
||||
async def test_list_models_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.list_models",
|
||||
"app.routers.ai.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Connection refused"),
|
||||
):
|
||||
@@ -91,7 +91,7 @@ async def test_list_models_upstream_error(client):
|
||||
async def test_chat_success(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_CHAT_RESPONSE,
|
||||
):
|
||||
@@ -115,7 +115,7 @@ async def test_chat_success(client):
|
||||
async def test_chat_passes_parameters(client):
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_CHAT_RESPONSE)
|
||||
with patch("backend.app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||
with patch("app.routers.ai.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT_RESPONSE) as mock:
|
||||
await client.post(
|
||||
"/ai/chat",
|
||||
json={
|
||||
@@ -145,7 +145,7 @@ async def test_chat_unauthenticated(client):
|
||||
async def test_chat_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("timeout"),
|
||||
):
|
||||
@@ -160,7 +160,7 @@ async def test_chat_upstream_error(client):
|
||||
async def test_chat_malformed_upstream_response(client):
|
||||
token = await _user_token(client)
|
||||
with patch(
|
||||
"backend.app.routers.ai.openrouter.chat_completion",
|
||||
"app.routers.ai.openrouter.chat_completion",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"id": "x", "choices": []}, # empty choices
|
||||
):
|
||||
|
||||
+135
-69
@@ -18,15 +18,6 @@ FAKE_CHAT = {
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
|
||||
}
|
||||
|
||||
FAKE_IMAGE = {
|
||||
"id": "gen-img-1",
|
||||
"model": "openai/dall-e-3",
|
||||
"data": [
|
||||
{"url": "https://example.com/image.png",
|
||||
"revised_prompt": "A cat on the moon"},
|
||||
],
|
||||
}
|
||||
|
||||
FAKE_VIDEO = {
|
||||
"id": "gen-vid-1",
|
||||
"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1",
|
||||
@@ -155,47 +146,13 @@ async def test_generate_text_upstream_error(client):
|
||||
# POST /generate/image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_generate_image(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-img-1"
|
||||
assert len(data["images"]) == 1
|
||||
assert data["images"][0]["url"] == "https://example.com/image.png"
|
||||
assert data["images"][0]["revised_prompt"] == "A cat on the moon"
|
||||
|
||||
|
||||
async def test_generate_image_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/dall-e-3", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# --- Chat-based image generation (FLUX, GPT-5 Image Mini) ---
|
||||
|
||||
FAKE_IMAGE_CHAT_FLUX = {
|
||||
"id": "gen-img-chat-1",
|
||||
"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Here is your generated image.",
|
||||
"content": None,
|
||||
"images": [{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,abc123"},
|
||||
@@ -219,45 +176,65 @@ FAKE_IMAGE_CHAT_GPT5 = {
|
||||
}],
|
||||
}
|
||||
|
||||
FAKE_IMAGE_CHAT_GEMINI = {
|
||||
"id": "gen-img-chat-3",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Here is your image.",
|
||||
"images": [{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,gemini123"},
|
||||
}],
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
||||
async def test_generate_image_chat_flux(client):
|
||||
|
||||
async def test_generate_image(client):
|
||||
"""All models now use generate_image_chat (chat completions endpoint)."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"prompt": "A sunset"},
|
||||
json={"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A cat on the moon"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "gen-img-chat-1"
|
||||
assert data["id"] == "gen-img-chat-3"
|
||||
assert len(data["images"]) == 1
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,gemini123"
|
||||
assert data["images"][0]["image_id"] is not None # stored in DB
|
||||
|
||||
|
||||
async def test_generate_image_chat_gpt5_image_mini(client):
|
||||
async def test_generate_image_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_upstream_error(client):
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5):
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("rate limit")):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"},
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["model"] == "openai/gpt-5-image-mini"
|
||||
assert len(data["images"]) == 1
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
async def test_generate_image_chat_with_image_config(client):
|
||||
async def test_generate_image_with_image_config(client):
|
||||
"""Passes aspect_ratio + image_size through to generate_image_chat."""
|
||||
token = await _user_token(client)
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
await client.post(
|
||||
"/generate/image",
|
||||
json={
|
||||
"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A landscape",
|
||||
"aspect_ratio": "16:9",
|
||||
"image_size": "2K",
|
||||
@@ -267,23 +244,112 @@ async def test_generate_image_chat_with_image_config(client):
|
||||
call_kwargs = mock.call_args.kwargs
|
||||
assert call_kwargs["image_config"]["aspect_ratio"] == "16:9"
|
||||
assert call_kwargs["image_config"]["image_size"] == "2K"
|
||||
assert call_kwargs["modalities"] == ["image"]
|
||||
|
||||
|
||||
async def test_generate_image_chat_unauthenticated(client):
|
||||
resp = await client.post("/generate/image", json={"model": "flux.2-klein-4b", "prompt": "Hi"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_generate_image_chat_upstream_error(client):
|
||||
async def test_generate_image_default_modalities_image_text(client):
|
||||
"""Model not in cache → default modalities = ['image', 'text']."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")):
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_GEMINI)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert mock.call_args.kwargs["modalities"] == ["image", "text"]
|
||||
|
||||
|
||||
async def test_generate_image_image_only_modalities_from_cache(client):
|
||||
"""Model cached with image-only output_modalities → modalities = ['image']."""
|
||||
from app import db as db_module
|
||||
from app.services.models import get_model_output_modalities
|
||||
import json as _json
|
||||
token = await _user_token(client)
|
||||
|
||||
# Seed cache with image-only model
|
||||
conn = db_module.get_conn()
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
conn.execute(
|
||||
"DELETE FROM models_cache WHERE model_id = 'black-forest-labs/flux.2-pro'"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO models_cache (model_id, name, modality, context_length, pricing, fetched_at, output_modalities)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
["black-forest-labs/flux.2-pro", "FLUX.2 Pro", "image", None, None, now,
|
||||
_json.dumps(["image"])],
|
||||
)
|
||||
|
||||
mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat", mock):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"},
|
||||
json={"model": "black-forest-labs/flux.2-pro", "prompt": "Sky"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert mock.call_args.kwargs["modalities"] == ["image"]
|
||||
|
||||
|
||||
async def test_generate_image_no_images_in_response(client):
|
||||
"""502 when model returns no images."""
|
||||
token = await _user_token(client)
|
||||
empty_response = {
|
||||
"id": "gen-empty",
|
||||
"model": "google/gemini-2.5-flash-image",
|
||||
"choices": [{"message": {"role": "assistant", "content": "ok", "images": []}}],
|
||||
}
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=empty_response):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image", "prompt": "Hi"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
assert "No images returned" in resp.json()["detail"]
|
||||
|
||||
|
||||
async def test_generate_image_flux(client):
|
||||
"""Flux model works correctly via chat completions."""
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "black-forest-labs/flux.2-klein-4b",
|
||||
"prompt": "A sunset"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["images"][0]["url"] == "data:image/png;base64,abc123"
|
||||
|
||||
|
||||
async def test_generate_image_stored_in_db(client):
|
||||
"""Generated image row persists in generated_images table."""
|
||||
from app import db as db_module
|
||||
token = await _user_token(client)
|
||||
with patch("app.routers.generate.openrouter.generate_image_chat",
|
||||
new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GEMINI):
|
||||
resp = await client.post(
|
||||
"/generate/image",
|
||||
json={"model": "google/gemini-2.5-flash-image",
|
||||
"prompt": "A mountain"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
image_id = resp.json()["images"][0]["image_id"]
|
||||
assert image_id is not None
|
||||
|
||||
row = db_module.get_conn().execute(
|
||||
"SELECT model_id, prompt, image_data FROM generated_images WHERE id = ?",
|
||||
[image_id],
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "google/gemini-2.5-flash-image"
|
||||
assert row[1] == "A mountain"
|
||||
assert row[2] == "data:image/png;base64,gemini123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.services.models import (
|
||||
_normalize_modality,
|
||||
_parse_modality,
|
||||
get_cached_models,
|
||||
get_model_output_modalities,
|
||||
is_cache_stale,
|
||||
refresh_models_cache,
|
||||
)
|
||||
@@ -28,28 +29,35 @@ FAKE_MODELS_RAW = [
|
||||
"name": "GPT-4o",
|
||||
"context_length": 128000,
|
||||
"pricing": {"prompt": "0.000005"},
|
||||
"architecture": {"modality": "text->text"},
|
||||
"architecture": {"modality": "text->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-3-haiku",
|
||||
"name": "Claude 3 Haiku",
|
||||
"context_length": 200000,
|
||||
"pricing": {},
|
||||
"architecture": {"modality": "text+image->text"},
|
||||
"architecture": {"modality": "text+image->text", "output_modalities": ["text"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/dall-e-3",
|
||||
"name": "DALL-E 3",
|
||||
"context_length": None,
|
||||
"pricing": {"image": "0.04"},
|
||||
"architecture": {"modality": "text->image"},
|
||||
"architecture": {"modality": "text->image", "output_modalities": ["image"]},
|
||||
},
|
||||
{
|
||||
"id": "openai/sora-2",
|
||||
"name": "Sora 2",
|
||||
"context_length": None,
|
||||
"pricing": {"video": "0.10"},
|
||||
"architecture": {"modality": "text->video"},
|
||||
"architecture": {"modality": "text->video", "output_modalities": ["video"]},
|
||||
},
|
||||
{
|
||||
"id": "google/gemini-2.5-flash-image",
|
||||
"name": "Gemini 2.5 Flash Image",
|
||||
"context_length": None,
|
||||
"pricing": {},
|
||||
"architecture": {"output_modalities": ["image", "text"]},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -171,9 +179,9 @@ async def test_refresh_stores_models():
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
count = await refresh_models_cache(conn)
|
||||
assert count == 4
|
||||
assert count == 5
|
||||
all_models = get_cached_models(conn)
|
||||
assert len(all_models) == 4
|
||||
assert len(all_models) == 5
|
||||
|
||||
|
||||
async def test_refresh_replaces_old_cache():
|
||||
@@ -193,25 +201,29 @@ async def test_refresh_replaces_old_cache():
|
||||
ids = [m["id"] for m in get_cached_models(conn)]
|
||||
assert "old/model" not in ids
|
||||
assert "openai/gpt-4o" in ids
|
||||
assert len(ids) == 5
|
||||
|
||||
|
||||
def test_get_cached_models_filter_by_modality():
|
||||
conn = db_module.get_conn()
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
for m in FAKE_MODELS_RAW:
|
||||
arch = m.get("architecture", {})
|
||||
modality = _parse_modality(arch.get("modality", "text->text"))
|
||||
modality = _extract_output_modality(m)
|
||||
conn.execute(
|
||||
"INSERT INTO models_cache (model_id, name, modality, fetched_at) VALUES (?, ?, ?, ?)",
|
||||
[m["id"], m["name"], modality, now],
|
||||
)
|
||||
text_models = get_cached_models(conn, modality="text")
|
||||
# gpt-4o, claude-3-haiku (gemini has output_modalities=["image","text"] → classified as "image")
|
||||
assert len(text_models) == 2
|
||||
assert all(m["modality"] == "text" for m in text_models)
|
||||
|
||||
image_models = get_cached_models(conn, modality="image")
|
||||
assert len(image_models) == 1
|
||||
assert image_models[0]["id"] == "openai/dall-e-3"
|
||||
# dall-e-3 + gemini (output_modalities starts with image)
|
||||
assert len(image_models) == 2
|
||||
image_ids = [m["id"] for m in image_models]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
video_models = get_cached_models(conn, modality="video")
|
||||
assert len(video_models) == 1
|
||||
@@ -233,7 +245,7 @@ async def test_list_models_endpoint_auto_refreshes(client):
|
||||
"/models/", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 4
|
||||
assert len(resp.json()) == 5
|
||||
assert mock_fetch.await_count >= 1
|
||||
|
||||
|
||||
@@ -274,8 +286,10 @@ async def test_list_models_filter_by_modality(client):
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "openai/dall-e-3"
|
||||
assert len(data) == 2 # dall-e-3 + gemini-2.5-flash-image
|
||||
image_ids = [m["id"] for m in data]
|
||||
assert "openai/dall-e-3" in image_ids
|
||||
assert "google/gemini-2.5-flash-image" in image_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -301,7 +315,7 @@ async def test_refresh_endpoint_admin_succeeds(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["refreshed"] == 4
|
||||
assert resp.json()["refreshed"] == 5
|
||||
|
||||
|
||||
async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
@@ -315,3 +329,38 @@ async def test_refresh_endpoint_502_on_openrouter_error(client):
|
||||
"/models/refresh", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: get_model_output_modalities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_get_model_output_modalities_image_only():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(conn, "openai/dall-e-3")
|
||||
assert modalities == ["image"]
|
||||
|
||||
|
||||
async def test_get_model_output_modalities_image_text():
|
||||
conn = db_module.get_conn()
|
||||
with patch(
|
||||
"app.services.models.openrouter.list_models",
|
||||
new_callable=AsyncMock,
|
||||
return_value=FAKE_MODELS_RAW,
|
||||
):
|
||||
await refresh_models_cache(conn)
|
||||
modalities = get_model_output_modalities(
|
||||
conn, "google/gemini-2.5-flash-image")
|
||||
assert set(modalities) == {"image", "text"}
|
||||
|
||||
|
||||
def test_get_model_output_modalities_unknown_model():
|
||||
conn = db_module.get_conn()
|
||||
result = get_model_output_modalities(conn, "unknown/model")
|
||||
assert result == []
|
||||
|
||||
@@ -115,7 +115,9 @@ async def test_list_users_as_admin(client):
|
||||
resp = await client.get("/users", headers={"Authorization": f"Bearer {admin_token}"})
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
assert len(resp.json()) == 1
|
||||
assert len(resp.json()) >= 1
|
||||
emails = [u["email"] for u in resp.json()]
|
||||
assert "user@example.com" in emails
|
||||
|
||||
|
||||
async def test_list_users_as_regular_user(client):
|
||||
|
||||
Reference in New Issue
Block a user