feat: enhance database queries with error handling and improve SQL statement readability
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -20,10 +20,18 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
|
|||||||
sql_token_count = "SELECT COUNT(*) FROM refresh_tokens"
|
sql_token_count = "SELECT COUNT(*) FROM refresh_tokens"
|
||||||
sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?"
|
sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?"
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
total_users = conn.execute(sql_user_count).fetchone()[0]
|
|
||||||
|
total_users_row = conn.execute(sql_user_count).fetchone()
|
||||||
|
total_users = total_users_row[0] if total_users_row else 0
|
||||||
|
|
||||||
users_by_role = conn.execute(sql_user_counts).fetchall()
|
users_by_role = conn.execute(sql_user_counts).fetchall()
|
||||||
total_tokens = conn.execute(sql_token_count).fetchone()[0]
|
|
||||||
active_tokens = conn.execute(sql_tokens_active, [now]).fetchone()[0]
|
total_tokens_row = conn.execute(sql_token_count).fetchone()
|
||||||
|
total_tokens = total_tokens_row[0] if total_tokens_row else 0
|
||||||
|
|
||||||
|
active_tokens_row = conn.execute(sql_tokens_active, [now]).fetchone()
|
||||||
|
active_tokens = active_tokens_row[0] if active_tokens_row else 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"users": {
|
"users": {
|
||||||
"total": total_users,
|
"total": total_users,
|
||||||
@@ -41,7 +49,8 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict:
|
|||||||
async def db_health(_: dict = Depends(require_admin)) -> dict:
|
async def db_health(_: dict = Depends(require_admin)) -> dict:
|
||||||
"""Verify DuckDB is reachable."""
|
"""Verify DuckDB is reachable."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
result = conn.execute("SELECT 1").fetchone()[0]
|
result_row = conn.execute("SELECT 1").fetchone()
|
||||||
|
result = result_row[0] if result_row else 0
|
||||||
return {"status": "ok" if result == 1 else "error"}
|
return {"status": "ok" if result == 1 else "error"}
|
||||||
|
|
||||||
|
|
||||||
@@ -54,9 +63,14 @@ async def purge_tokens(_: dict = Depends(require_admin)) -> dict:
|
|||||||
sql_count = "SELECT COUNT(*) FROM refresh_tokens"
|
sql_count = "SELECT COUNT(*) FROM refresh_tokens"
|
||||||
sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?"
|
sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?"
|
||||||
async with lock:
|
async with lock:
|
||||||
before = conn.execute(sql_count).fetchone()[0]
|
before_row = conn.execute(sql_count).fetchone()
|
||||||
|
before = before_row[0] if before_row else 0
|
||||||
|
|
||||||
conn.execute(sql_delete, [now])
|
conn.execute(sql_delete, [now])
|
||||||
after = conn.execute(sql_count).fetchone()[0]
|
|
||||||
|
after_row = conn.execute(sql_count).fetchone()
|
||||||
|
after = after_row[0] if after_row else 0
|
||||||
|
|
||||||
return {"deleted": before - after, "remaining": after}
|
return {"deleted": before - after, "remaining": after}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ async def register(body: RegisterRequest) -> dict:
|
|||||||
try:
|
try:
|
||||||
user = await register_user(body.email, body.password)
|
user = await register_user(body.email, body.password)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc))
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT, detail=str(exc))
|
||||||
return {"id": user["id"], "email": user["email"], "role": user["role"]}
|
return {"id": user["id"], "email": user["email"], "role": user["role"]}
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +41,8 @@ async def login(body: LoginRequest) -> TokenResponse:
|
|||||||
jti = str(uuid.uuid4())
|
jti = str(uuid.uuid4())
|
||||||
await store_refresh_token(user["id"], jti)
|
await store_refresh_token(user["id"], jti)
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
access_token=create_access_token(user["id"], user["email"], user["role"]),
|
access_token=create_access_token(
|
||||||
|
user["id"], user["email"], user["role"]),
|
||||||
refresh_token=create_refresh_token(user["id"], jti),
|
refresh_token=create_refresh_token(user["id"], jti),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,9 +75,8 @@ async def refresh(body: RefreshRequest) -> TokenResponse:
|
|||||||
|
|
||||||
from ..db import get_conn
|
from ..db import get_conn
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
row = conn.execute(
|
sql_fetch = "SELECT email, role FROM users WHERE id = ?"
|
||||||
"SELECT email, role FROM users WHERE id = ?", [user_id]
|
row = conn.execute(sql_fetch, [user_id]).fetchone()
|
||||||
).fetchone()
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise credentials_error
|
raise credentials_error
|
||||||
|
|
||||||
|
|||||||
@@ -129,15 +129,13 @@ async def generate_image(
|
|||||||
user_id = current_user.get("id") or current_user.get("sub")
|
user_id = current_user.get("id") or current_user.get("sub")
|
||||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
stored: list[ImageResult] = []
|
stored: list[ImageResult] = []
|
||||||
|
sql_insert = "INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) VALUES (?, ?, ?, ?, ?) RETURNING id"
|
||||||
async with get_write_lock():
|
async with get_write_lock():
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
for img in images:
|
for img in images:
|
||||||
if img.url:
|
if img.url:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"""INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at)
|
sql_insert, [user_id, body.model, body.prompt, img.url, now],).fetchone()
|
||||||
VALUES (?, ?, ?, ?, ?) RETURNING id""",
|
|
||||||
[user_id, body.model, body.prompt, img.url, now],
|
|
||||||
).fetchone()
|
|
||||||
image_id = str(row[0]) if row else None
|
image_id = str(row[0]) if row else None
|
||||||
else:
|
else:
|
||||||
image_id = None
|
image_id = None
|
||||||
@@ -167,13 +165,8 @@ async def list_generated_images(
|
|||||||
"""Return all generated images for the current user, newest first."""
|
"""Return all generated images for the current user, newest first."""
|
||||||
user_id = current_user.get("id") or current_user.get("sub")
|
user_id = current_user.get("id") or current_user.get("sub")
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
rows = conn.execute(
|
sql_fetch = "SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE user_id = ? ORDER BY created_at DESC"
|
||||||
"""SELECT id, model_id, prompt, image_data, created_at
|
rows = conn.execute(sql_fetch, [user_id]).fetchall()
|
||||||
FROM generated_images
|
|
||||||
WHERE user_id = ?
|
|
||||||
ORDER BY created_at DESC""",
|
|
||||||
[user_id],
|
|
||||||
).fetchall()
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"id": str(r[0]),
|
"id": str(r[0]),
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ def verify_password(plain: str, hashed: str) -> bool:
|
|||||||
# --- Tokens ---
|
# --- Tokens ---
|
||||||
|
|
||||||
def create_access_token(user_id: str, email: str, role: str) -> str:
|
def create_access_token(user_id: str, email: str, role: str) -> str:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.now(timezone.utc) + \
|
||||||
|
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
payload = {
|
payload = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"email": email,
|
"email": email,
|
||||||
@@ -47,7 +48,8 @@ def create_access_token(user_id: str, email: str, role: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def create_refresh_token(user_id: str, jti: str) -> str:
|
def create_refresh_token(user_id: str, jti: str) -> str:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
expire = datetime.now(timezone.utc) + \
|
||||||
|
timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
payload = {
|
payload = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"jti": jti,
|
"jti": jti,
|
||||||
@@ -68,28 +70,25 @@ async def register_user(email: str, password: str) -> dict[str, Any]:
|
|||||||
"""Insert a new user. Returns the created user row."""
|
"""Insert a new user. Returns the created user row."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
lock = get_write_lock()
|
lock = get_write_lock()
|
||||||
|
sql_check = "SELECT id FROM users WHERE email = ?"
|
||||||
|
sql_insert = "INSERT INTO users (email, password_hash) VALUES (?, ?)"
|
||||||
|
sql_fetch = "SELECT id, email, role FROM users WHERE email = ?"
|
||||||
async with lock:
|
async with lock:
|
||||||
existing = conn.execute(
|
existing = conn.execute(sql_check, [email]).fetchone()
|
||||||
"SELECT id FROM users WHERE email = ?", [email]
|
|
||||||
).fetchone()
|
|
||||||
if existing:
|
if existing:
|
||||||
raise ValueError("Email already registered.")
|
raise ValueError("Email already registered.")
|
||||||
conn.execute(
|
conn.execute(sql_insert, [email, hash_password(password)],)
|
||||||
"INSERT INTO users (email, password_hash) VALUES (?, ?)",
|
row = conn.execute(sql_fetch, [email]).fetchone()
|
||||||
[email, hash_password(password)],
|
if row is None:
|
||||||
)
|
raise RuntimeError("Failed to fetch user after registration.")
|
||||||
row = conn.execute(
|
|
||||||
"SELECT id, email, role FROM users WHERE email = ?", [email]
|
|
||||||
).fetchone()
|
|
||||||
return {"id": str(row[0]), "email": row[1], "role": row[2]}
|
return {"id": str(row[0]), "email": row[1], "role": row[2]}
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_user(email: str, password: str) -> dict[str, Any] | None:
|
async def authenticate_user(email: str, password: str) -> dict[str, Any] | None:
|
||||||
"""Return user dict if credentials are valid, else None."""
|
"""Return user dict if credentials are valid, else None."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
row = conn.execute(
|
sql_fetch = "SELECT id, email, password_hash, role FROM users WHERE email = ?"
|
||||||
"SELECT id, email, password_hash, role FROM users WHERE email = ?", [email]
|
row = conn.execute(sql_fetch, [email]).fetchone()
|
||||||
).fetchone()
|
|
||||||
if row is None or not verify_password(password, row[2]):
|
if row is None or not verify_password(password, row[2]):
|
||||||
return None
|
return None
|
||||||
return {"id": str(row[0]), "email": row[1], "role": row[3]}
|
return {"id": str(row[0]), "email": row[1], "role": row[3]}
|
||||||
@@ -99,34 +98,30 @@ async def store_refresh_token(user_id: str, jti: str) -> None:
|
|||||||
"""Persist a refresh token JTI in the database."""
|
"""Persist a refresh token JTI in the database."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
lock = get_write_lock()
|
lock = get_write_lock()
|
||||||
|
sql_insert = "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)"
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
expires_at = datetime.now(timezone.utc) + \
|
||||||
|
timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
async with lock:
|
async with lock:
|
||||||
conn.execute(
|
conn.execute(sql_insert, [jti, user_id, expires_at])
|
||||||
"INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)",
|
|
||||||
[jti, user_id, expires_at],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def revoke_refresh_token(jti: str) -> None:
|
async def revoke_refresh_token(jti: str) -> None:
|
||||||
"""Mark a refresh token as revoked."""
|
"""Mark a refresh token as revoked."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
lock = get_write_lock()
|
lock = get_write_lock()
|
||||||
|
sql_update = "UPDATE refresh_tokens SET revoked = true WHERE jti = ?"
|
||||||
async with lock:
|
async with lock:
|
||||||
conn.execute(
|
conn.execute(sql_update, [jti])
|
||||||
"UPDATE refresh_tokens SET revoked = true WHERE jti = ?", [jti]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_refresh_token_jti(jti: str, user_id: str) -> bool:
|
async def validate_refresh_token_jti(jti: str, user_id: str) -> bool:
|
||||||
"""Return True if the JTI exists, is not revoked, and belongs to user_id."""
|
"""Return True if the JTI exists, is not revoked, and belongs to user_id."""
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
row = conn.execute(
|
sql_select = """
|
||||||
"""
|
|
||||||
SELECT 1 FROM refresh_tokens
|
SELECT 1 FROM refresh_tokens
|
||||||
WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ?
|
WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ?
|
||||||
""",
|
"""
|
||||||
[jti, user_id, now],
|
row = conn.execute(sql_select, [jti, user_id, now]).fetchone()
|
||||||
).fetchone()
|
|
||||||
return row is not None
|
return row is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user