diff --git a/backend/app/routers/admin.py b/backend/app/routers/admin.py index 6436675..c4ea8da 100644 --- a/backend/app/routers/admin.py +++ b/backend/app/routers/admin.py @@ -20,10 +20,18 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict: sql_token_count = "SELECT COUNT(*) FROM refresh_tokens" sql_tokens_active = "SELECT COUNT(*) FROM refresh_tokens WHERE revoked = false AND expires_at > ?" now = datetime.now(timezone.utc) - total_users = conn.execute(sql_user_count).fetchone()[0] + + total_users_row = conn.execute(sql_user_count).fetchone() + total_users = total_users_row[0] if total_users_row else 0 + users_by_role = conn.execute(sql_user_counts).fetchall() - total_tokens = conn.execute(sql_token_count).fetchone()[0] - active_tokens = conn.execute(sql_tokens_active, [now]).fetchone()[0] + + total_tokens_row = conn.execute(sql_token_count).fetchone() + total_tokens = total_tokens_row[0] if total_tokens_row else 0 + + active_tokens_row = conn.execute(sql_tokens_active, [now]).fetchone() + active_tokens = active_tokens_row[0] if active_tokens_row else 0 + return { "users": { "total": total_users, @@ -41,7 +49,8 @@ async def get_stats(_: dict = Depends(require_admin)) -> dict: async def db_health(_: dict = Depends(require_admin)) -> dict: """Verify DuckDB is reachable.""" conn = get_conn() - result = conn.execute("SELECT 1").fetchone()[0] + result_row = conn.execute("SELECT 1").fetchone() + result = result_row[0] if result_row else 0 return {"status": "ok" if result == 1 else "error"} @@ -54,9 +63,14 @@ async def purge_tokens(_: dict = Depends(require_admin)) -> dict: sql_count = "SELECT COUNT(*) FROM refresh_tokens" sql_delete = "DELETE FROM refresh_tokens WHERE revoked = true OR expires_at <= ?" async with lock: - before = conn.execute(sql_count).fetchone()[0] + before_row = conn.execute(sql_count).fetchone() + before = before_row[0] if before_row else 0 + conn.execute(sql_delete, [now]) - after = conn.execute(sql_count).fetchone()[0] + + after_row = conn.execute(sql_count).fetchone() + after = after_row[0] if after_row else 0 + return {"deleted": before - after, "remaining": after} diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index f8907cc..6707897 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -24,7 +24,8 @@ async def register(body: RegisterRequest) -> dict: try: user = await register_user(body.email, body.password) except ValueError as exc: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=str(exc)) return {"id": user["id"], "email": user["email"], "role": user["role"]} @@ -40,7 +41,8 @@ async def login(body: LoginRequest) -> TokenResponse: jti = str(uuid.uuid4()) await store_refresh_token(user["id"], jti) return TokenResponse( - access_token=create_access_token(user["id"], user["email"], user["role"]), + access_token=create_access_token( + user["id"], user["email"], user["role"]), refresh_token=create_refresh_token(user["id"], jti), ) @@ -73,9 +75,8 @@ async def refresh(body: RefreshRequest) -> TokenResponse: from ..db import get_conn conn = get_conn() - row = conn.execute( - "SELECT email, role FROM users WHERE id = ?", [user_id] - ).fetchone() + sql_fetch = "SELECT email, role FROM users WHERE id = ?" + row = conn.execute(sql_fetch, [user_id]).fetchone() if row is None: raise credentials_error diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py index 093c243..490523e 100644 --- a/backend/app/routers/generate.py +++ b/backend/app/routers/generate.py @@ -129,15 +129,13 @@ async def generate_image( user_id = current_user.get("id") or current_user.get("sub") now = datetime.now(timezone.utc).replace(tzinfo=None) stored: list[ImageResult] = [] + sql_insert = "INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) VALUES (?, ?, ?, ?, ?) RETURNING id" async with get_write_lock(): conn = get_conn() for img in images: if img.url: row = conn.execute( - """INSERT INTO generated_images (user_id, model_id, prompt, image_data, created_at) - VALUES (?, ?, ?, ?, ?) RETURNING id""", - [user_id, body.model, body.prompt, img.url, now], - ).fetchone() + sql_insert, [user_id, body.model, body.prompt, img.url, now],).fetchone() image_id = str(row[0]) if row else None else: image_id = None @@ -167,13 +165,8 @@ async def list_generated_images( """Return all generated images for the current user, newest first.""" user_id = current_user.get("id") or current_user.get("sub") conn = get_conn() - rows = conn.execute( - """SELECT id, model_id, prompt, image_data, created_at - FROM generated_images - WHERE user_id = ? - ORDER BY created_at DESC""", - [user_id], - ).fetchall() + sql_fetch = "SELECT id, model_id, prompt, image_data, created_at FROM generated_images WHERE user_id = ? ORDER BY created_at DESC" + rows = conn.execute(sql_fetch, [user_id]).fetchall() return [ { "id": str(r[0]), @@ -245,12 +238,12 @@ async def generate_video( conn = get_conn() conn.execute( """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", [user_id, job_id, body.model, body.prompt, - polling_url, job_status, now, now], + polling_url, job_status, now, now], ) - urls = result.get("unsigned_urls") or result.get("video_urls") + urls = result.get("unsigned_urls") or result.get("video_urls") return VideoResponse( id=job_id, model=body.model, @@ -298,22 +291,22 @@ async def generate_video_from_image( conn = get_conn() conn.execute( """INSERT INTO generated_videos (user_id, job_id, model_id, prompt, polling_url, status, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", [user_id, job_id, body.model, body.prompt, polling_url, job_status, now, now], ) - urls = result.get("unsigned_urls") or result.get("video_urls") - return VideoResponse( - id=job_id, - model=body.model, - status=job_status, - polling_url=polling_url, - video_urls=urls, - video_url=(urls or [None])[0], - error=result.get("error"), - metadata=result.get("metadata"), - ) + urls = result.get("unsigned_urls") or result.get("video_urls") + return VideoResponse( + id=job_id, + model=body.model, + status=job_status, + polling_url=polling_url, + video_urls=urls, + video_url=(urls or [None])[0], + error=result.get("error"), + metadata=result.get("metadata"), + ) @router.get("/video/status", response_model=VideoResponse) @@ -339,8 +332,8 @@ async def poll_video_status( conn = get_conn() conn.execute( """UPDATE generated_videos - SET status = ?, video_url = ?, updated_at = ? - WHERE job_id = ?""", + SET status = ?, video_url = ?, updated_at = ? + WHERE job_id = ?""", [job_status, video_url, now, result.get("id", "")], ) diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py index c1e5ac0..c7e40e8 100644 --- a/backend/app/services/auth.py +++ b/backend/app/services/auth.py @@ -35,7 +35,8 @@ def verify_password(plain: str, hashed: str) -> bool: # --- Tokens --- def create_access_token(user_id: str, email: str, role: str) -> str: - expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + \ + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) payload = { "sub": user_id, "email": email, @@ -47,7 +48,8 @@ def create_access_token(user_id: str, email: str, role: str) -> str: def create_refresh_token(user_id: str, jti: str) -> str: - expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + expire = datetime.now(timezone.utc) + \ + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) payload = { "sub": user_id, "jti": jti, @@ -68,28 +70,25 @@ async def register_user(email: str, password: str) -> dict[str, Any]: """Insert a new user. Returns the created user row.""" conn = get_conn() lock = get_write_lock() + sql_check = "SELECT id FROM users WHERE email = ?" + sql_insert = "INSERT INTO users (email, password_hash) VALUES (?, ?)" + sql_fetch = "SELECT id, email, role FROM users WHERE email = ?" async with lock: - existing = conn.execute( - "SELECT id FROM users WHERE email = ?", [email] - ).fetchone() + existing = conn.execute(sql_check, [email]).fetchone() if existing: raise ValueError("Email already registered.") - conn.execute( - "INSERT INTO users (email, password_hash) VALUES (?, ?)", - [email, hash_password(password)], - ) - row = conn.execute( - "SELECT id, email, role FROM users WHERE email = ?", [email] - ).fetchone() + conn.execute(sql_insert, [email, hash_password(password)],) + row = conn.execute(sql_fetch, [email]).fetchone() + if row is None: + raise RuntimeError("Failed to fetch user after registration.") return {"id": str(row[0]), "email": row[1], "role": row[2]} async def authenticate_user(email: str, password: str) -> dict[str, Any] | None: """Return user dict if credentials are valid, else None.""" conn = get_conn() - row = conn.execute( - "SELECT id, email, password_hash, role FROM users WHERE email = ?", [email] - ).fetchone() + sql_fetch = "SELECT id, email, password_hash, role FROM users WHERE email = ?" + row = conn.execute(sql_fetch, [email]).fetchone() if row is None or not verify_password(password, row[2]): return None return {"id": str(row[0]), "email": row[1], "role": row[3]} @@ -99,34 +98,30 @@ async def store_refresh_token(user_id: str, jti: str) -> None: """Persist a refresh token JTI in the database.""" conn = get_conn() lock = get_write_lock() + sql_insert = "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)" from datetime import timedelta - expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + expires_at = datetime.now(timezone.utc) + \ + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) async with lock: - conn.execute( - "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)", - [jti, user_id, expires_at], - ) + conn.execute(sql_insert, [jti, user_id, expires_at]) async def revoke_refresh_token(jti: str) -> None: """Mark a refresh token as revoked.""" conn = get_conn() lock = get_write_lock() + sql_update = "UPDATE refresh_tokens SET revoked = true WHERE jti = ?" async with lock: - conn.execute( - "UPDATE refresh_tokens SET revoked = true WHERE jti = ?", [jti] - ) + conn.execute(sql_update, [jti]) async def validate_refresh_token_jti(jti: str, user_id: str) -> bool: """Return True if the JTI exists, is not revoked, and belongs to user_id.""" conn = get_conn() now = datetime.now(timezone.utc) - row = conn.execute( - """ + sql_select = """ SELECT 1 FROM refresh_tokens WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ? - """, - [jti, user_id, now], - ).fetchone() + """ + row = conn.execute(sql_select, [jti, user_id, now]).fetchone() return row is not None