Files

588 lines
20 KiB
Python

"""Flask frontend application."""
import functools
from datetime import datetime, timezone
import httpx
from flask import (
Flask,
Response,
flash,
jsonify,
redirect,
render_template,
request,
session,
url_for,
)
from .config import Config
# Module-level Flask application object; its settings are loaded from Config.
app = Flask(__name__)
app.config.from_object(Config)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@app.template_filter("fromisoformat")
def from_iso_format(s: str) -> datetime:
    """Convert an ISO 8601 string to a ``datetime`` (template filter ``fromisoformat``).

    A trailing ``Z`` (the UTC designator) is rewritten to ``+00:00`` before
    parsing, because ``datetime.fromisoformat`` only understands ``Z`` from
    Python 3.11 onwards.

    Args:
        s: ISO 8601 timestamp string.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: if ``s`` is not a valid ISO 8601 timestamp.
    """
    if s.endswith("Z"):
        s = f"{s[:-1]}+00:00"
    return datetime.fromisoformat(s)
@app.template_filter("humantime")
def human_time(dt: datetime) -> str:
    """Describe how long ago *dt* was, relative to the current UTC time.

    Naive datetimes are treated as UTC. Anything less than a minute old
    (which includes timestamps in the future) is reported as "just now";
    older values are expressed in whole minutes, hours, days, 30-day months
    or 365-day years.
    """
    # (exclusive upper bound in seconds, seconds per unit, unit label)
    buckets = (
        (3600, 60, "minute"),
        (86400, 3600, "hour"),
        (2592000, 86400, "day"),
        (31536000, 2592000, "month"),
    )
    current = datetime.now(timezone.utc)
    if dt.tzinfo is None:
        # Make the value aware so it can be subtracted from an aware "now".
        dt = dt.replace(tzinfo=timezone.utc)
    elapsed = (current - dt).total_seconds()
    if elapsed < 60:
        return "just now"
    for limit, unit_seconds, label in buckets:
        if elapsed < limit:
            break
    else:
        # Older than every bucket: fall back to years.
        unit_seconds, label = 31536000, "year"
    count = int(elapsed / unit_seconds)
    suffix = "s" if count > 1 else ""
    return f"{count} {label}{suffix} ago"
def _backend(path: str) -> str:
    """Return the backend URL for *path* by prefixing the configured ``BACKEND_URL``."""
    base_url = app.config["BACKEND_URL"]
    return "{}{}".format(base_url, path)
def _api(method: str, path: str, *, token: str | None = None, **kwargs):
    """Send a request to the backend and return the ``httpx`` response.

    Args:
        method: HTTP verb, e.g. ``"GET"``.
        path: Backend path; it is appended to the configured ``BACKEND_URL``.
        token: Optional bearer token. When truthy it is sent in the
            ``Authorization`` header.
        **kwargs: Forwarded to ``httpx.request`` (``json``, ``params``,
            ``files``, ...). ``headers`` and ``timeout`` may also be given;
            ``timeout`` defaults to 30 seconds.

    Returns:
        The response object returned by ``httpx.request``.

    Raises:
        httpx.RequestError: transport failures are not caught here.
    """
    # Work on a copy so the caller's headers dict is never mutated, and
    # tolerate an explicit headers=None.
    headers = dict(kwargs.pop("headers", None) or {})
    # Popping avoids a "multiple values for keyword argument" TypeError when
    # a caller wants a different timeout.
    timeout = kwargs.pop("timeout", 30)
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return httpx.request(
        method, _backend(path), headers=headers, timeout=timeout, **kwargs
    )
def _model_matches_modality(model: dict, modality: str) -> bool:
"""Heuristic fallback when backend modality filter returns empty."""
model_modality = (model.get("modality") or "").lower()
if model_modality == modality:
return True
text = f"{model.get('id', '')} {model.get('name', '')}".lower()
keywords = {
"image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"],
"video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"],
"audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"],
}
if modality in keywords:
return any(k in text for k in keywords[modality])
if modality == "text":
non_text_hits = any(
k in text for k in keywords["image"] + keywords["video"] + keywords["audio"])
return not non_text_hits
return False
def _load_models(token: str, modality: str) -> list[dict]:
    """Return the models to offer for *modality*.

    The backend is first asked for models filtered by modality. If that
    answer is unusable (non-200 status, invalid JSON, not a JSON list, or an
    empty list), the unfiltered list is fetched and narrowed locally with
    ``_model_matches_modality``. When the heuristic matches nothing, the whole
    unfiltered list is returned.

    Args:
        token: Bearer token for the backend.
        modality: Modality name sent as the ``modality`` query parameter.

    Returns:
        A list of model dicts. It is empty when the filtered request hits a
        transport error or when the unfiltered request fails in any way.
    """
    try:
        filtered_resp = _api("GET", "/models/", token=token,
                             params={"modality": modality})
    except httpx.RequestError:
        return []
    if filtered_resp.status_code == 200:
        try:
            models = filtered_resp.json()
        except ValueError:
            models = []
        # Only a non-empty JSON array satisfies the declared return type.
        if isinstance(models, list) and models:
            return models

    try:
        all_resp = _api("GET", "/models/", token=token)
    except httpx.RequestError:
        return []
    if all_resp.status_code != 200:
        return []
    try:
        all_models = all_resp.json()
    except ValueError:
        return []
    if not isinstance(all_models, list):
        return []
    # Skip non-dict entries: the heuristic calls .get() on each model.
    matching = [
        m for m in all_models
        if isinstance(m, dict) and _model_matches_modality(m, modality)
    ]
    return matching or all_models
def login_required(view):
    """Decorator that sends visitors without an ``access_token`` in the session to the login page."""
    @functools.wraps(view)
    def wrapped(*args, **kwargs):
        logged_in = "access_token" in session
        if logged_in:
            return view(*args, **kwargs)
        return redirect(url_for("login"))

    return wrapped
def admin_required(view):
    """Decorator that restricts a view to sessions whose ``user_role`` is ``"admin"``.

    Visitors without an ``access_token`` are redirected to the login page.
    Logged-in non-admins get an error flash and are redirected to the dashboard.
    """
    @functools.wraps(view)
    def wrapped(*args, **kwargs):
        if "access_token" not in session:
            target = "login"
        elif session.get("user_role") != "admin":
            flash("Admin access required.", "error")
            target = "dashboard"
        else:
            return view(*args, **kwargs)
        return redirect(url_for(target))

    return wrapped
# ---------------------------------------------------------------------------
# Auth routes
# ---------------------------------------------------------------------------
@app.get("/")
def index():
    """Redirect to the dashboard when a session token exists, otherwise to the login page."""
    endpoint = "dashboard" if "access_token" in session else "login"
    return redirect(url_for(endpoint))
@app.route("/login", methods=["GET", "POST"])
def login():
    """Render the login form (GET) or authenticate against the backend (POST).

    On a 200 from ``/auth/login`` the session is reset and the access and
    refresh tokens are stored. The user's email and role are added when
    ``/users/me`` also succeeds. The user is then redirected to the dashboard.
    On any other status, or when the backend cannot be reached, an error is
    flashed and the form is rendered again.
    """
    if request.method == "POST":
        email = request.form["email"]
        password = request.form["password"]
        try:
            resp = _api("POST", "/auth/login",
                        json={"email": email, "password": password})
        except httpx.RequestError:
            flash("Login service is unavailable. Please try again later.", "error")
            return render_template("login.html")
        if resp.status_code == 200:
            data = resp.json()
            access_token = data["access_token"]
            refresh_token = data["refresh_token"]
            # Start from an empty session so a role or chat state left over
            # from a previous login cannot leak into this one (for example
            # when the /users/me lookup below fails).
            session.clear()
            session["access_token"] = access_token
            session["refresh_token"] = refresh_token
            try:
                me = _api("GET", "/users/me", token=access_token)
            except httpx.RequestError:
                me = None
            if me is not None and me.status_code == 200:
                u = me.json()
                session["user_email"] = u.get("email", "")
                session["user_role"] = u.get("role", "user")
            return redirect(url_for("dashboard"))
        flash("Invalid email or password.", "error")
    return render_template("login.html")
@app.route("/register", methods=["GET", "POST"])
def register():
    """Render the registration form (GET) or create an account via the backend (POST).

    A 201 from ``/auth/register`` flashes a success message and redirects to
    the login page. Any other outcome flashes an error and renders the form
    again. The error text is the backend's ``detail`` field when the response
    body is a JSON object, otherwise a generic message.
    """
    if request.method == "POST":
        email = request.form["email"]
        password = request.form["password"]
        try:
            resp = _api("POST", "/auth/register",
                        json={"email": email, "password": password})
        except httpx.RequestError:
            flash("Registration service is unavailable. Please try again later.",
                  "error")
            return render_template("register.html")
        if resp.status_code == 201:
            flash("Account created. Please log in.", "success")
            return redirect(url_for("login"))
        try:
            detail = resp.json().get("detail", "Registration failed.")
        except (ValueError, AttributeError):
            # Body was not JSON, or not a JSON object.
            detail = "Registration failed."
        flash(detail, "error")
    return render_template("register.html")
@app.get("/logout")
def logout():
    """Log the user out and redirect to the login page.

    When the session holds a refresh token it is posted to the backend's
    ``/auth/logout``. That call is best-effort: the local session is cleared
    even if the backend cannot be reached.
    """
    refresh_token = session.get("refresh_token")
    if refresh_token:
        try:
            _api("POST", "/auth/logout", json={"refresh_token": refresh_token})
        except httpx.RequestError:
            # A backend outage must not prevent the local logout.
            app.logger.warning(
                "Backend logout request failed; clearing local session anyway.")
    session.clear()
    return redirect(url_for("login"))
# ---------------------------------------------------------------------------
# Authenticated routes
# ---------------------------------------------------------------------------
@app.get("/dashboard")
@login_required
def dashboard():
    """Render the dashboard with the user's profile, uploads, generated images and videos.

    Each backend list falls back to an empty value when its request does not
    return 200. Videos are split into completed ones and pending ones
    (neither completed nor failed).
    """
    token = session["access_token"]

    def fetch(path, fallback):
        # GET a backend resource, using the fallback on any non-200 status.
        reply = _api("GET", path, token=token)
        if reply.status_code == 200:
            return reply.json()
        return fallback

    user = fetch("/users/me", {})
    images = fetch("/images/", [])
    generated_images = fetch("/generate/images", [])
    videos = fetch("/generate/videos", [])

    pending_videos = []
    completed_videos = []
    for video in videos:
        status = video.get("status")
        if status == "completed":
            completed_videos.append(video)
        elif status != "failed":
            pending_videos.append(video)

    context = {
        "user": user,
        "images": images,
        "generated_images": generated_images,
        "pending_videos": pending_videos,
        "completed_videos": completed_videos,
    }
    return render_template("dashboard.html", **context)
@app.get("/gallery")
@login_required
def gallery():
    """Render the gallery of uploads, generated images and videos.

    A list is empty when its backend request does not return 200. Failed
    videos appear in neither the pending nor the completed list.
    """
    token = session["access_token"]

    # Fetch the three content types in a fixed order.
    payloads = []
    for path in ("/images/", "/generate/images", "/generate/videos"):
        reply = _api("GET", path, token=token)
        payloads.append(reply.json() if reply.status_code == 200 else [])
    uploads, generated_images, videos = payloads

    finished_states = ("completed", "failed")
    pending_videos = [
        video for video in videos if video.get("status") not in finished_states
    ]
    completed_videos = [
        video for video in videos if video.get("status") == "completed"
    ]
    return render_template(
        "gallery.html",
        uploads=uploads,
        generated_images=generated_images,
        pending_videos=pending_videos,
        completed_videos=completed_videos,
    )
@app.get("/gallery/image/<image_id>")
@login_required
def image_detail(image_id: str):
    """Render the detail page for a generated image; ``image`` is ``None`` unless the backend returns 200."""
    reply = _api("GET", f"/generate/images/{image_id}",
                 token=session["access_token"])
    image = None
    if reply.status_code == 200:
        image = reply.json()
    return render_template("image_detail.html", image=image)
@app.get("/gallery/video/<video_id>")
@login_required
def video_detail(video_id: str):
    """Render the detail page for a video job; ``video`` is ``None`` unless the backend returns 200."""
    reply = _api("GET", f"/generate/videos/{video_id}",
                 token=session["access_token"])
    video = None
    if reply.status_code == 200:
        video = reply.json()
    return render_template("video_detail.html", video=video)
@app.get("/gallery/upload/<image_id>")
@login_required
def upload_detail(image_id: str):
    """Render the detail page for an uploaded image; ``image`` is ``None`` unless the backend returns 200."""
    reply = _api("GET", f"/images/{image_id}", token=session["access_token"])
    image = None
    if reply.status_code == 200:
        image = reply.json()
    return render_template("upload_detail.html", image=image)
# ── Uploaded-image proxy & Generate ───────────────────────────────────────
@app.get("/images/<image_id>/file")
@login_required
def serve_uploaded_image(image_id: str):
    """Proxy the bytes of an uploaded image from the backend.

    Any non-200 backend status is answered with a plain 404. The content
    type is copied from the backend response, defaulting to ``image/jpeg``.
    """
    upstream = _api("GET", f"/images/{image_id}/file",
                    token=session["access_token"])
    if upstream.status_code == 200:
        mime = upstream.headers.get("content-type", "image/jpeg")
        return Response(upstream.content, status=200, content_type=mime)
    return Response("Not found", status=404)
@app.get("/generate")
@login_required
def generate():
    """Redirect the bare ``/generate`` URL to the text generation page."""
    target = url_for("generate_text")
    return redirect(target)
@app.route("/generate/text", methods=["GET", "POST"])
@login_required
def generate_text():
    """Chat-style text generation page.

    Conversation state (history, system prompt, model) is kept in the session.

    POST with ``action=clear`` wipes that state and redirects back here.
    Any other POST stores the submitted model and system prompt. If a prompt
    was given, it sends the user/assistant history plus the new message to
    the backend's ``/generate/text``. On success the assistant reply is
    appended to the stored history. On failure the page is rendered with an
    error message and the history is left unchanged.
    """
    error = None
    token = session["access_token"]
    chat_history: list[dict] = session.get("chat_history", [])
    system_prompt: str = session.get("chat_system_prompt", "")
    model: str = session.get("chat_model", "")
    if request.method == "POST":
        action = request.form.get("action", "send")
        if action == "clear":
            for key in ("chat_history", "chat_system_prompt", "chat_model"):
                session.pop(key, None)
            return redirect(url_for("generate_text"))
        prompt = request.form.get("prompt", "").strip()
        model = request.form.get("model", "").strip()
        system_prompt = request.form.get("system_prompt", "").strip()
        # Persist model + system_prompt across turns
        session["chat_model"] = model
        session["chat_system_prompt"] = system_prompt
        if prompt:
            # Build messages: history (user/assistant only) + new user msg
            messages = [m for m in chat_history
                        if m["role"] in ("user", "assistant")]
            messages.append({"role": "user", "content": prompt})
            payload: dict = {
                "model": model,
                # Only role/content are sent; extra keys such as "usage"
                # stay in the stored history.
                "messages": [{"role": m["role"], "content": m["content"]}
                             for m in messages],
            }
            if system_prompt:
                payload["system_prompt"] = system_prompt
            try:
                resp = _api("POST", "/generate/text", token=token, json=payload)
            except httpx.RequestError:
                resp = None
            if resp is None:
                error = "Generation service is unavailable."
            elif resp.status_code == 200:
                try:
                    data = resp.json()
                    content = data["content"]
                except (ValueError, KeyError, TypeError):
                    # Success status but unusable body: report it instead of
                    # crashing, and leave the history as it was.
                    error = "Generation failed."
                else:
                    chat_history = list(messages)
                    chat_history.append({"role": "assistant", "content": content,
                                         "usage": data.get("usage")})
                    session["chat_history"] = chat_history
            else:
                try:
                    error = resp.json().get("detail", "Generation failed.")
                except (ValueError, AttributeError):
                    # Body was not JSON, or not a JSON object.
                    error = "Generation failed."
    models = _load_models(token, "text")
    return render_template(
        "generate_text.html",
        chat_history=session.get("chat_history", []),
        error=error,
        models=models,
        system_prompt=system_prompt,
        current_model=model,
    )
@app.route("/generate/image", methods=["GET", "POST"])
@login_required
def generate_image():
    """Image generation page.

    On POST an optional reference image is uploaded first; if that upload
    fails the page is rendered with the upload error. Otherwise the form
    values are sent to the backend's ``/generate/image`` and the page is
    rendered with either the result or an error message. A non-integer ``n``
    is reported as a form error instead of raising.
    """
    token = session["access_token"]

    def render(result, error):
        # Render the page with the model list for the "image" modality.
        models = _load_models(token, "image")
        return render_template("generate_image.html", result=result,
                               error=error, models=models)

    result = error = None
    if request.method == "POST":
        # Upload reference image if provided
        ref_file = request.files.get("reference_image")
        if ref_file and ref_file.filename:
            up_resp = _api(
                "POST", "/images/upload",
                token=token,
                files={"file": (ref_file.filename,
                                ref_file.stream, ref_file.content_type)},
            )
            # NOTE(review): the upload response body is not passed on to the
            # generation request below -- confirm how the backend associates
            # the reference image with the generation.
            if up_resp.status_code not in (200, 201):
                try:
                    error = up_resp.json().get("detail", "Image upload failed.")
                except (ValueError, AttributeError):
                    error = "Image upload failed."
                return render(result, error)
        try:
            count = int(request.form.get("n", 1))
        except ValueError:
            # Untrusted form input: report it instead of a 500.
            return render(result, "Number of images must be a whole number.")
        resp = _api("POST", "/generate/image", token=token, json={
            "model": request.form.get("model", "").strip(),
            "prompt": request.form.get("prompt", "").strip(),
            "n": count,
            "size": request.form.get("size", "1024x1024"),
            "aspect_ratio": request.form.get("aspect_ratio", "").strip() or None,
            "image_size": request.form.get("image_size", "").strip() or None,
        })
        if resp.status_code == 200:
            result = resp.json()
        else:
            try:
                error = resp.json().get("detail", "Generation failed.")
            except (ValueError, AttributeError):
                # Body was not JSON, or not a JSON object.
                error = "Generation failed."
    return render(result, error)
@app.route("/generate/video", methods=["GET", "POST"])
@login_required
def generate_video():
    """Video generation page.

    On POST the form is sent to ``/generate/video/from-image`` when
    ``mode`` is ``"image"`` (adding ``image_url``), otherwise to
    ``/generate/video``. A 200 response redirects to the video's detail page
    when it carries a ``db_id``, or to the gallery when it does not. Any other
    status renders the page with an error message.
    """
    error = None
    token = session["access_token"]
    if request.method == "POST":
        mode = request.form.get("mode", "text")
        duration_raw = request.form.get("duration_seconds", "").strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() rejects with ValueError.
        duration = int(duration_raw) if duration_raw.isdecimal() else None
        resolution = request.form.get("resolution", "").strip() or None

        # Fields shared by both generation modes.
        payload = {
            "model": request.form.get("model", "").strip(),
            "prompt": request.form.get("prompt", "").strip(),
            "aspect_ratio": request.form.get("aspect_ratio", "16:9"),
            "duration_seconds": duration,
            "resolution": resolution,
        }
        if mode == "image":
            endpoint = "/generate/video/from-image"
            payload["image_url"] = request.form.get("image_url", "").strip()
        else:
            endpoint = "/generate/video"
        resp = _api("POST", endpoint, token=token, json=payload)

        if resp.status_code == 200:
            result = resp.json()
            # On success, redirect to the detail page to monitor progress
            db_id = result.get("db_id")
            if db_id:
                return redirect(url_for("video_detail", video_id=db_id))
            # Fallback for older backend versions
            flash("Video job started.", "success")
            return redirect(url_for("gallery"))
        try:
            error = resp.json().get("detail", "Generation failed.")
        except (ValueError, AttributeError):
            # Body was not JSON, or not a JSON object.
            error = "Generation failed."
    models = _load_models(token, "video")
    return render_template("generate_video.html", error=error, models=models)
@app.get("/generate/video/status")
@login_required
def generate_video_status():
    """Forward a video status poll to the backend and relay its JSON body and status code.

    Responds with 400 when the ``polling_url`` query parameter is missing or empty.
    """
    polling_url = request.args.get("polling_url", "")
    if polling_url:
        # NOTE(review): polling_url is client-supplied and forwarded as-is;
        # presumably the backend validates it -- confirm.
        upstream = _api(
            "GET", "/generate/video/status",
            token=session["access_token"],
            params={"polling_url": polling_url},
        )
        body = upstream.json()
        return jsonify(body), upstream.status_code
    return jsonify({"error": "polling_url required"}), 400
@app.get("/generate/video/<video_id>/status")
@login_required
def generate_video_db_status(video_id: str):
    """Relay the backend's record for a video job (polled by frontend JS)."""
    upstream = _api("GET", f"/generate/videos/{video_id}",
                    token=session["access_token"])
    body = upstream.json()
    return jsonify(body), upstream.status_code
@app.post("/generate/video/<video_id>/cancel")
@login_required
def cancel_video_job(video_id: str):
    """Forward a cancel request for a video job to the backend and relay its answer."""
    upstream = _api("POST", f"/generate/videos/{video_id}/cancel",
                    token=session["access_token"])
    body = upstream.json()
    return jsonify(body), upstream.status_code
# ── Admin ─────────────────────────────────────────────────────────────────
@app.get("/admin")
@admin_required
def admin():
    """Render the admin overview with backend statistics and the user list.

    Each value falls back to an empty container when its request does not return 200.
    """
    token = session["access_token"]
    # Issue both requests before decoding either body.
    replies = [_api("GET", path, token=token)
               for path in ("/admin/stats", "/users")]
    stats_reply, users_reply = replies

    stats = {}
    if stats_reply.status_code == 200:
        stats = stats_reply.json()
    users = []
    if users_reply.status_code == 200:
        users = users_reply.json()
    return render_template("admin.html", stats=stats, users=users)
@app.post("/admin/users/<user_id>/role")
@admin_required
def admin_set_role(user_id: str):
    """Change a user's role via the backend, then return to the admin page.

    The role comes from the ``role`` form field (default ``"user"``). A
    success message is flashed only when the backend answers with a 2xx
    status; otherwise an error is flashed.
    """
    role = request.form.get("role", "user")
    resp = _api("PUT", f"/users/{user_id}/role",
                token=session["access_token"], json={"role": role})
    if 200 <= resp.status_code < 300:
        flash(f"Role updated to '{role}'.", "success")
    else:
        # Do not claim success when the backend rejected the change.
        flash("Role update failed.", "error")
    return redirect(url_for("admin"))
@app.post("/admin/users/<user_id>/delete")
@admin_required
def admin_delete_user(user_id: str):
    """Delete a user via the backend, then return to the admin page.

    A success message is flashed only when the backend answers with a 2xx
    status; otherwise an error is flashed.
    """
    resp = _api("DELETE", f"/users/{user_id}", token=session["access_token"])
    if 200 <= resp.status_code < 300:
        flash("User deleted.", "success")
    else:
        # Do not claim success when the backend rejected the deletion.
        flash("User deletion failed.", "error")
    return redirect(url_for("admin"))
@app.get("/admin/models")
@admin_required
def admin_models():
    """Render the admin models page (``admin/models.html``)."""
    template_name = "admin/models.html"
    return render_template(template_name)
# ── Admin API proxies (same-origin for browser JS, avoids mixed-content) ──
@app.get("/api/admin/videos")
@admin_required
def api_admin_list_videos():
    """Relay the backend's ``/admin/videos`` listing (JSON body and status code)."""
    upstream = _api("GET", "/admin/videos", token=session["access_token"])
    body = upstream.json()
    return jsonify(body), upstream.status_code
@app.post("/api/admin/videos/<job_id>/retry")
@admin_required
def api_admin_retry_video(job_id: str):
    """Forward a retry request for a video job to the backend and relay its answer."""
    upstream = _api("POST", f"/admin/videos/{job_id}/retry",
                    token=session["access_token"])
    body = upstream.json()
    return jsonify(body), upstream.status_code
@app.post("/api/admin/videos/<job_id>/cancel")
@admin_required
def api_admin_cancel_video(job_id: str):
    """Forward an admin cancel request for a video job to the backend and relay its answer."""
    upstream = _api("POST", f"/admin/videos/{job_id}/cancel",
                    token=session["access_token"])
    body = upstream.json()
    return jsonify(body), upstream.status_code
@app.delete("/api/admin/videos/<job_id>")
@admin_required
def api_admin_delete_video(job_id: str):
    """Forward a delete request for a video job to the backend and relay its answer.

    An empty backend body (for example a 204 reply) is relayed as an empty
    response with the same status code. A body that is not valid JSON is
    reported as a 502 with a JSON error object.
    """
    resp = _api(
        "DELETE", f"/admin/videos/{job_id}", token=session["access_token"])
    if not resp.content:
        # Nothing to decode; resp.json() would raise on an empty body.
        return "", resp.status_code
    try:
        body = resp.json()
    except ValueError:
        return jsonify({"error": "Invalid response from backend."}), 502
    return jsonify(body), resp.status_code
# ── Profile ───────────────────────────────────────────────────────────────
@app.route("/users/profile", methods=["GET", "POST"])
@login_required
def profile():
    """Show the profile page (GET) or update email/password via the backend (POST).

    Only non-empty fields are sent. The password is forwarded exactly as
    typed (not stripped), matching how login and registration submit it; a
    whitespace-only password is treated as "not provided". After a POST the
    user is redirected back to this page with a flash message describing
    the outcome.
    """
    token = session["access_token"]
    if request.method == "POST":
        payload: dict = {}
        new_email = request.form.get("email", "").strip()
        new_password = request.form.get("password", "")
        if new_email:
            payload["email"] = new_email
        if new_password.strip():
            payload["password"] = new_password
        if payload:
            resp = _api("PUT", "/users/me", token=token, json=payload)
            if resp.status_code == 200:
                updated = resp.json()
                session["user_email"] = updated.get(
                    "email", session.get("user_email", ""))
                flash("Profile updated.", "success")
            else:
                try:
                    detail = resp.json().get("detail", "Update failed.")
                except (ValueError, AttributeError):
                    # Body was not JSON, or not a JSON object.
                    detail = "Update failed."
                flash(detail, "error")
        return redirect(url_for("profile"))
    resp = _api("GET", "/users/me", token=token)
    user = resp.json() if resp.status_code == 200 else {}
    return render_template("profile.html", user=user)