feat: enhance database queries with error handling and improve SQL statement readability
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -20,10 +20,18 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
|
||||
sql_token_count = "SELECT COUNT(*) FROM refresh_tokens"
|
||||
sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?"
|
||||
now = datetime.now(timezone.utc)
|
||||
total_users = conn.execute(sql_user_count).fetchone()[0]
|
||||
|
||||
total_users_row = conn.execute(sql_user_count).fetchone()
|
||||
total_users = total_users_row[0] if total_users_row else 0
|
||||
|
||||
users_by_role = conn.execute(sql_user_counts).fetchall()
|
||||
total_tokens = conn.execute(sql_token_count).fetchone()[0]
|
||||
active_tokens = conn.execute(sql_tokens_active, [now]).fetchone()[0]
|
||||
|
||||
total_tokens_row = conn.execute(sql_token_count).fetchone()
|
||||
total_tokens = total_tokens_row[0] if total_tokens_row else 0
|
||||
|
||||
active_tokens_row = conn.execute(sql_tokens_active, [now]).fetchone()
|
||||
active_tokens = active_tokens_row[0] if active_tokens_row else 0
|
||||
|
||||
return {
|
||||
"users": {
|
||||
"total": total_users,
|
||||
@@ -41,7 +49,8 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
|
||||
async def db_health(_: dict = Depends(require_admin)) -> dict:
|
||||
"""Verify DuckDB is reachable."""
|
||||
conn = get_conn()
|
||||
result = conn.execute("SELECT 1").fetchone()[0]
|
||||
result_row = conn.execute("SELECT 1").fetchone()
|
||||
result = result_row[0] if result_row else 0
|
||||
return {"status": "ok" if result == 1 else "error"}
|
||||
|
||||
|
||||
@@ -54,9 +63,14 @@ async def purge_tokens(_: dict = Depends(require_admin)) -> dict:
|
||||
sql_count = "SELECT COUNT(*) FROM refresh_tokens"
|
||||
sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?"
|
||||
async with lock:
|
||||
before = conn.execute(sql_count).fetchone()[0]
|
||||
before_row = conn.execute(sql_count).fetchone()
|
||||
before = before_row[0] if before_row else 0
|
||||
|
||||
conn.execute(sql_delete, [now])
|
||||
after = conn.execute(sql_count).fetchone()[0]
|
||||
|
||||
after_row = conn.execute(sql_count).fetchone()
|
||||
after = after_row[0] if after_row else 0
|
||||
|
||||
return {"deleted": before - after, "remaining": after}
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ async def register(body: RegisterRequest) -> dict:
|
||||
try:
|
||||
user = await register_user(body.email, body.password)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc))
|
||||
return {"id": user["id"], "email": user["email"], "role": user["role"]}
|
||||
|
||||
|
||||
@@ -40,7 +41,8 @@ async def login(body: LoginRequest) -> TokenResponse:
|
||||
jti = str(uuid.uuid4())
|
||||
await store_refresh_token(user["id"], jti)
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user["id"], user["email"], user["role"]),
|
||||
access_token=create_access_token(
|
||||
user["id"], user["email"], user["role"]),
|
||||
refresh_token=create_refresh_token(user["id"], jti),
|
||||
)
|
||||
|
||||
@@ -73,9 +75,8 @@ async def refresh(body: RefreshRequest) -> TokenResponse:
|
||||
|
||||
from ..db import get_conn
|
||||
conn = get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT email, role FROM users WHERE id = ?", [user_id]
|
||||
).fetchone()
|
||||
sql_fetch = "SELECT email, role FROM users WHERE id = ?"
|
||||
row = conn.execute(sql_fetch, [user_id]).fetchone()
|
||||
if row is None:
|
||||
raise credentials_error
|
||||
|
||||
|
||||
@@ -129,15 +129,13 @@ async def generate_image(
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
stored: list[ImageResult] = []
|
||||
sql_insert = "INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) VALUES (?, ?, ?, ?, ?) RETURNING id"
|
||||
async with get_write_lock():
|
||||
conn = get_conn()
|
||||
for img in images:
|
||||
if img.url:
|
||||
row = conn.execute(
|
||||
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING id""",
|
||||
[user_id, body.model, body.prompt, img.url, now],
|
||||
).fetchone()
|
||||
sql_insert, [user_id, body.model, body.prompt, img.url, now],).fetchone()
|
||||
image_id = str(row[0]) if row else None
|
||||
else:
|
||||
image_id = None
|
||||
@@ -167,13 +165,8 @@ async def list_generated_images(
|
||||
"""Return all generated images for the current user, newest first."""
|
||||
user_id = current_user.get("id") or current_user.get("sub")
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT id, model_id, prompt, image_data, created_at
|
||||
FROM generated_images
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC""",
|
||||
[user_id],
|
||||
).fetchall()
|
||||
sql_fetch = "SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE user_id = ? ORDER BY created_at DESC"
|
||||
rows = conn.execute(sql_fetch, [user_id]).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r[0]),
|
||||
@@ -245,12 +238,12 @@ async def generate_video(
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
@@ -298,22 +291,22 @@ async def generate_video_from_image(
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[user_id, job_id, body.model, body.prompt,
|
||||
polling_url, job_status, now, now],
|
||||
)
|
||||
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
urls = result.get("unsigned_urls") or result.get("video_urls")
|
||||
return VideoResponse(
|
||||
id=job_id,
|
||||
model=body.model,
|
||||
status=job_status,
|
||||
polling_url=polling_url,
|
||||
video_urls=urls,
|
||||
video_url=(urls or [None])[0],
|
||||
error=result.get("error"),
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/video/status", response_model=VideoResponse)
|
||||
@@ -339,8 +332,8 @@ async def poll_video_status(
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"""UPDATE generated_videos
|
||||
SET status = ?, video_url = ?, updated_at = ?
|
||||
WHERE job_id = ?""",
|
||||
SET status = ?, video_url = ?, updated_at = ?
|
||||
WHERE job_id = ?""",
|
||||
[job_status, video_url, now, result.get("id", "")],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user