feat: enhance database queries with error handling and improve SQL statement readability

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-29 16:28:22 +02:00
parent df85676fa2
commit 8e36f48527
4 changed files with 70 additions and 67 deletions
+20 -6
View File
@@ -20,10 +20,18 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
sql_token_count = "SELECT COUNT(*) FROM refresh_tokens" sql_token_count = "SELECT COUNT(*) FROM refresh_tokens"
sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?" sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?"
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
total_users = conn.execute(sql_user_count).fetchone()[0]
total_users_row = conn.execute(sql_user_count).fetchone()
total_users = total_users_row[0] if total_users_row else 0
users_by_role = conn.execute(sql_user_counts).fetchall() users_by_role = conn.execute(sql_user_counts).fetchall()
total_tokens = conn.execute(sql_token_count).fetchone()[0]
active_tokens = conn.execute(sql_tokens_active, [now]).fetchone()[0] total_tokens_row = conn.execute(sql_token_count).fetchone()
total_tokens = total_tokens_row[0] if total_tokens_row else 0
active_tokens_row = conn.execute(sql_tokens_active, [now]).fetchone()
active_tokens = active_tokens_row[0] if active_tokens_row else 0
return { return {
"users": { "users": {
"total": total_users, "total": total_users,
@@ -41,7 +49,8 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
async def db_health(_: dict = Depends(require_admin)) -> dict: async def db_health(_: dict = Depends(require_admin)) -> dict:
"""Verify DuckDB is reachable.""" """Verify DuckDB is reachable."""
conn = get_conn() conn = get_conn()
result = conn.execute("SELECT 1").fetchone()[0] result_row = conn.execute("SELECT 1").fetchone()
result = result_row[0] if result_row else 0
return {"status": "ok" if result == 1 else "error"} return {"status": "ok" if result == 1 else "error"}
@@ -54,9 +63,14 @@ async def purge_tokens(_: dict = Depends(require_admin)) -> dict:
sql_count = "SELECT COUNT(*) FROM refresh_tokens" sql_count = "SELECT COUNT(*) FROM refresh_tokens"
sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?" sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?"
async with lock: async with lock:
before = conn.execute(sql_count).fetchone()[0] before_row = conn.execute(sql_count).fetchone()
before = before_row[0] if before_row else 0
conn.execute(sql_delete, [now]) conn.execute(sql_delete, [now])
after = conn.execute(sql_count).fetchone()[0]
after_row = conn.execute(sql_count).fetchone()
after = after_row[0] if after_row else 0
return {"deleted": before - after, "remaining": after} return {"deleted": before - after, "remaining": after}
+6 -5
View File
@@ -24,7 +24,8 @@ async def register(body: RegisterRequest) -> dict:
try: try:
user = await register_user(body.email, body.password) user = await register_user(body.email, body.password)
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=str(exc))
return {"id": user["id"], "email": user["email"], "role": user["role"]} return {"id": user["id"], "email": user["email"], "role": user["role"]}
@@ -40,7 +41,8 @@ async def login(body: LoginRequest) -> TokenResponse:
jti = str(uuid.uuid4()) jti = str(uuid.uuid4())
await store_refresh_token(user["id"], jti) await store_refresh_token(user["id"], jti)
return TokenResponse( return TokenResponse(
access_token=create_access_token(user["id"], user["email"], user["role"]), access_token=create_access_token(
user["id"], user["email"], user["role"]),
refresh_token=create_refresh_token(user["id"], jti), refresh_token=create_refresh_token(user["id"], jti),
) )
@@ -73,9 +75,8 @@ async def refresh(body: RefreshRequest) -> TokenResponse:
from ..db import get_conn from ..db import get_conn
conn = get_conn() conn = get_conn()
row = conn.execute( sql_fetch = "SELECT email, role FROM users WHERE id = ?"
"SELECT email, role FROM users WHERE id = ?", [user_id] row = conn.execute(sql_fetch, [user_id]).fetchone()
).fetchone()
if row is None: if row is None:
raise credentials_error raise credentials_error
+21 -28
View File
@@ -129,15 +129,13 @@ async def generate_image(
user_id = current_user.get("id") or current_user.get("sub") user_id = current_user.get("id") or current_user.get("sub")
now = datetime.now(timezone.utc).replace(tzinfo=None) now = datetime.now(timezone.utc).replace(tzinfo=None)
stored: list[ImageResult] = [] stored: list[ImageResult] = []
sql_insert = "INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) VALUES (?, ?, ?, ?, ?) RETURNING id"
async with get_write_lock(): async with get_write_lock():
conn = get_conn() conn = get_conn()
for img in images: for img in images:
if img.url: if img.url:
row = conn.execute( row = conn.execute(
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) sql_insert, [user_id, body.model, body.prompt, img.url, now],).fetchone()
VALUES (?, ?, ?, ?, ?) RETURNING id""",
[user_id, body.model, body.prompt, img.url, now],
).fetchone()
image_id = str(row[0]) if row else None image_id = str(row[0]) if row else None
else: else:
image_id = None image_id = None
@@ -167,13 +165,8 @@ async def list_generated_images(
"""Return all generated images for the current user, newest first.""" """Return all generated images for the current user, newest first."""
user_id = current_user.get("id") or current_user.get("sub") user_id = current_user.get("id") or current_user.get("sub")
conn = get_conn() conn = get_conn()
rows = conn.execute( sql_fetch = "SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE user_id = ? ORDER BY created_at DESC"
"""SELECT id, model_id, prompt, image_data, created_at rows = conn.execute(sql_fetch, [user_id]).fetchall()
FROM generated_images
WHERE user_id = ?
ORDER BY created_at DESC""",
[user_id],
).fetchall()
return [ return [
{ {
"id": str(r[0]), "id": str(r[0]),
@@ -245,12 +238,12 @@ async def generate_video(
conn = get_conn() conn = get_conn()
conn.execute( conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt, [user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now], polling_url, job_status, now, now],
) )
urls = result.get("unsigned_urls") or result.get("video_urls") urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse( return VideoResponse(
id=job_id, id=job_id,
model=body.model, model=body.model,
@@ -298,22 +291,22 @@ async def generate_video_from_image(
conn = get_conn() conn = get_conn()
conn.execute( conn.execute(
"""INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
[user_id, job_id, body.model, body.prompt, [user_id, job_id, body.model, body.prompt,
polling_url, job_status, now, now], polling_url, job_status, now, now],
) )
urls = result.get("unsigned_urls") or result.get("video_urls") urls = result.get("unsigned_urls") or result.get("video_urls")
return VideoResponse( return VideoResponse(
id=job_id, id=job_id,
model=body.model, model=body.model,
status=job_status, status=job_status,
polling_url=polling_url, polling_url=polling_url,
video_urls=urls, video_urls=urls,
video_url=(urls or [None])[0], video_url=(urls or [None])[0],
error=result.get("error"), error=result.get("error"),
metadata=result.get("metadata"), metadata=result.get("metadata"),
) )
@router.get("/video/status", response_model=VideoResponse) @router.get("/video/status", response_model=VideoResponse)
@@ -339,8 +332,8 @@ async def poll_video_status(
conn = get_conn() conn = get_conn()
conn.execute( conn.execute(
"""UPDATE generated_videos """UPDATE generated_videos
SET status = ?, video_url = ?, updated_at = ? SET status = ?, video_url = ?, updated_at = ?
WHERE job_id = ?""", WHERE job_id = ?""",
[job_status, video_url, now, result.get("id", "")], [job_status, video_url, now, result.get("id", "")],
) )
+23 -28
View File
@@ -35,7 +35,8 @@ def verify_password(plain: str, hashed: str) -> bool:
# --- Tokens --- # --- Tokens ---
def create_access_token(user_id: str, email: str, role: str) -> str: def create_access_token(user_id: str, email: str, role: str) -> str:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + \
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
payload = { payload = {
"sub": user_id, "sub": user_id,
"email": email, "email": email,
@@ -47,7 +48,8 @@ def create_access_token(user_id: str, email: str, role: str) -> str:
def create_refresh_token(user_id: str, jti: str) -> str: def create_refresh_token(user_id: str, jti: str) -> str:
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now(timezone.utc) + \
timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
payload = { payload = {
"sub": user_id, "sub": user_id,
"jti": jti, "jti": jti,
@@ -68,28 +70,25 @@ async def register_user(email: str, password: str) -> dict[str, Any]:
"""Insert a new user. Returns the created user row.""" """Insert a new user. Returns the created user row."""
conn = get_conn() conn = get_conn()
lock = get_write_lock() lock = get_write_lock()
sql_check = "SELECT id FROM users WHERE email = ?"
sql_insert = "INSERT INTO users (email, password_hash) VALUES (?, ?)"
sql_fetch = "SELECT id, email, role FROM users WHERE email = ?"
async with lock: async with lock:
existing = conn.execute( existing = conn.execute(sql_check, [email]).fetchone()
"SELECT id FROM users WHERE email = ?", [email]
).fetchone()
if existing: if existing:
raise ValueError("Email already registered.") raise ValueError("Email already registered.")
conn.execute( conn.execute(sql_insert, [email, hash_password(password)],)
"INSERT INTO users (email, password_hash) VALUES (?, ?)", row = conn.execute(sql_fetch, [email]).fetchone()
[email, hash_password(password)], if row is None:
) raise RuntimeError("Failed to fetch user after registration.")
row = conn.execute(
"SELECT id, email, role FROM users WHERE email = ?", [email]
).fetchone()
return {"id": str(row[0]), "email": row[1], "role": row[2]} return {"id": str(row[0]), "email": row[1], "role": row[2]}
async def authenticate_user(email: str, password: str) -> dict[str, Any] | None: async def authenticate_user(email: str, password: str) -> dict[str, Any] | None:
"""Return user dict if credentials are valid, else None.""" """Return user dict if credentials are valid, else None."""
conn = get_conn() conn = get_conn()
row = conn.execute( sql_fetch = "SELECT id, email, password_hash, role FROM users WHERE email = ?"
"SELECT id, email, password_hash, role FROM users WHERE email = ?", [email] row = conn.execute(sql_fetch, [email]).fetchone()
).fetchone()
if row is None or not verify_password(password, row[2]): if row is None or not verify_password(password, row[2]):
return None return None
return {"id": str(row[0]), "email": row[1], "role": row[3]} return {"id": str(row[0]), "email": row[1], "role": row[3]}
@@ -99,34 +98,30 @@ async def store_refresh_token(user_id: str, jti: str) -> None:
"""Persist a refresh token JTI in the database.""" """Persist a refresh token JTI in the database."""
conn = get_conn() conn = get_conn()
lock = get_write_lock() lock = get_write_lock()
sql_insert = "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)"
from datetime import timedelta from datetime import timedelta
expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) expires_at = datetime.now(timezone.utc) + \
timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
async with lock: async with lock:
conn.execute( conn.execute(sql_insert, [jti, user_id, expires_at])
"INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)",
[jti, user_id, expires_at],
)
async def revoke_refresh_token(jti: str) -> None: async def revoke_refresh_token(jti: str) -> None:
"""Mark a refresh token as revoked.""" """Mark a refresh token as revoked."""
conn = get_conn() conn = get_conn()
lock = get_write_lock() lock = get_write_lock()
sql_update = "UPDATE refresh_tokens SET revoked = true WHERE jti = ?"
async with lock: async with lock:
conn.execute( conn.execute(sql_update, [jti])
"UPDATE refresh_tokens SET revoked = true WHERE jti = ?", [jti]
)
async def validate_refresh_token_jti(jti: str, user_id: str) -> bool: async def validate_refresh_token_jti(jti: str, user_id: str) -> bool:
"""Return True if the JTI exists, is not revoked, and belongs to user_id.""" """Return True if the JTI exists, is not revoked, and belongs to user_id."""
conn = get_conn() conn = get_conn()
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
row = conn.execute( sql_select = """
"""
SELECT 1 FROM refresh_tokens SELECT 1 FROM refresh_tokens
WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ? WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ?
""", """
[jti, user_id, now], row = conn.execute(sql_select, [jti, user_id, now]).fetchone()
).fetchone()
return row is not None return row is not None