acc6991341
Co-authored-by: Copilot <copilot@github.com>
369 lines
13 KiB
Python
369 lines
13 KiB
Python
"""Flask frontend application."""
|
|
import functools
|
|
|
|
import httpx
|
|
from flask import (
|
|
Flask,
|
|
Response,
|
|
flash,
|
|
jsonify,
|
|
redirect,
|
|
render_template,
|
|
request,
|
|
session,
|
|
url_for,
|
|
)
|
|
|
|
from .config import Config
|
|
|
|
# Flask application object. Settings are loaded from the Config class; this
# presumably includes BACKEND_URL, which _backend() reads from app.config.
app = Flask(__name__)
app.config.from_object(Config)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _backend(path: str) -> str:
    """Build an absolute backend URL by appending *path* to ``BACKEND_URL``.

    *path* is appended verbatim (no slash normalisation), so callers pass
    paths that already start with "/".
    """
    base_url = app.config["BACKEND_URL"]
    return "{}{}".format(base_url, path)
|
|
|
|
|
|
def _api(method: str, path: str, *, token: str | None = None, **kwargs) -> httpx.Response:
    """Send a request to the backend API and return the raw httpx response.

    Args:
        method: HTTP verb, e.g. "GET", "POST", "PUT", "DELETE".
        path: Backend path, turned into an absolute URL by ``_backend``.
        token: Optional access token, sent as ``Authorization: Bearer <token>``.
        **kwargs: Forwarded to ``httpx.request`` (``json``, ``params``,
            ``files``, ``headers``, ``timeout``, ...). ``timeout`` defaults
            to 30 seconds.

    Returns:
        The ``httpx.Response``; status codes are NOT checked here.

    Raises:
        httpx.RequestError: on transport failures (connection errors,
            timeouts, ...).
    """
    # Copy the caller's headers so adding Authorization never mutates a
    # mapping the caller may reuse.
    headers = dict(kwargs.pop("headers", None) or {})
    if token:
        headers["Authorization"] = f"Bearer {token}"
    # Default timeout, overridable by the caller. Passing timeout=30 alongside
    # a caller-supplied timeout would raise "multiple values for keyword".
    kwargs.setdefault("timeout", 30)
    return httpx.request(method, _backend(path), headers=headers, **kwargs)
|
|
|
|
|
|
def _model_matches_modality(model: dict, modality: str) -> bool:
|
|
"""Heuristic fallback when backend modality filter returns empty."""
|
|
model_modality = (model.get("modality") or "").lower()
|
|
if model_modality == modality:
|
|
return True
|
|
|
|
text = f"{model.get('id', '')} {model.get('name', '')}".lower()
|
|
keywords = {
|
|
"image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"],
|
|
"video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"],
|
|
"audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"],
|
|
}
|
|
|
|
if modality in keywords:
|
|
return any(k in text for k in keywords[modality])
|
|
|
|
if modality == "text":
|
|
non_text_hits = any(
|
|
k in text for k in keywords["image"] + keywords["video"] + keywords["audio"])
|
|
return not non_text_hits
|
|
|
|
return False
|
|
|
|
|
|
def _fetch_model_list(token: str, params: dict | None = None) -> list[dict]:
    """GET ``/models/`` and return its body as a list of dicts ([] on any bad reply).

    Raises:
        httpx.RequestError: propagated so the caller can stop after one
            transport failure instead of waiting on another timeout.
    """
    resp = _api("GET", "/models/", token=token, params=params)
    if resp.status_code != 200:
        return []
    try:
        payload = resp.json()
    except ValueError:
        return []
    if not isinstance(payload, list):
        return []
    # Drop malformed entries so callers can rely on dict access (.get).
    return [m for m in payload if isinstance(m, dict)]


def _load_models(token: str, modality: str) -> list[dict]:
    """Return the models to offer for *modality* on the generate forms.

    Resolution order:
      1. The backend's own filter, ``/models/?modality=<modality>``.
      2. If that yields nothing, the unfiltered ``/models/`` list narrowed
         with the local ``_model_matches_modality`` heuristic.
      3. If the heuristic matches nothing either, the whole unfiltered list,
         so the model picker is not needlessly empty.

    Returns [] when the backend is unreachable or gives no usable list.
    """
    try:
        models = _fetch_model_list(token, params={"modality": modality})
        if models:
            return models
        all_models = _fetch_model_list(token)
    except httpx.RequestError:
        # Backend unreachable: give up instead of waiting on a second timeout.
        return []
    filtered = [m for m in all_models if _model_matches_modality(m, modality)]
    return filtered or all_models
|
|
|
|
|
|
def login_required(view):
    """Decorate *view* so visitors without a session are sent to the login page.

    A visitor counts as signed in when the session holds an ``access_token``.
    """

    @functools.wraps(view)
    def guarded_view(*args, **kwargs):
        if "access_token" in session:
            return view(*args, **kwargs)
        return redirect(url_for("login"))

    return guarded_view
|
|
|
|
|
|
def admin_required(view):
    """Decorate *view* so only signed-in admins reach it.

    - No ``access_token`` in the session: redirect to the login page.
    - Session role (``user_role``, cached at login) is not "admin": flash an
      error and redirect to the dashboard.
    """

    @functools.wraps(view)
    def guarded_view(*args, **kwargs):
        if "access_token" not in session:
            target = "login"
        elif session.get("user_role") != "admin":
            flash("Admin access required.", "error")
            target = "dashboard"
        else:
            return view(*args, **kwargs)
        return redirect(url_for(target))

    return guarded_view
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth routes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/")
def index():
    """Site root: signed-in users go to the dashboard, everyone else to login."""
    destination = "dashboard" if "access_token" in session else "login"
    return redirect(url_for(destination))
|
|
|
|
|
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Render the login form (GET) or authenticate against the backend (POST).

    On success the session is reset and then filled with:
      - the backend's access and refresh tokens;
      - if ``/users/me`` answers 200, the user's email and role.
    The user is then redirected to the dashboard. Rejected credentials or an
    unreachable backend re-render the form with an error flash.
    """
    if request.method == "POST":
        email = request.form["email"]
        password = request.form["password"]
        try:
            resp = _api("POST", "/auth/login",
                        json={"email": email, "password": password})
        except httpx.RequestError:
            flash("Login service is unavailable. Please try again later.", "error")
            return render_template("login.html")
        if resp.status_code == 200:
            data = resp.json()
            access_token = data["access_token"]
            refresh_token = data["refresh_token"]
            # Start from an empty session. Otherwise a previous user's
            # user_email / user_role would survive whenever the /users/me
            # lookup below fails, and admin_required trusts user_role.
            session.clear()
            session["access_token"] = access_token
            session["refresh_token"] = refresh_token
            try:
                me = _api("GET", "/users/me", token=access_token)
            except httpx.RequestError:
                me = None  # Profile details are optional; still log the user in.
            if me is not None and me.status_code == 200:
                u = me.json()
                session["user_email"] = u.get("email", "")
                session["user_role"] = u.get("role", "user")
            return redirect(url_for("dashboard"))
        flash("Invalid email or password.", "error")
    return render_template("login.html")
|
|
|
|
|
|
@app.route("/register", methods=["GET", "POST"])
def register():
    """Render the sign-up form (GET) or create an account via the backend (POST).

    A 201 from ``/auth/register`` redirects to the login page with a success
    flash. Any other outcome re-renders the form with an error flash: the
    backend's string ``detail`` if present, else a generic message.
    """
    if request.method == "POST":
        email = request.form["email"]
        password = request.form["password"]
        try:
            resp = _api("POST", "/auth/register",
                        json={"email": email, "password": password})
        except httpx.RequestError:
            flash("Registration service is unavailable. Please try again later.", "error")
            return render_template("register.html")
        if resp.status_code == 201:
            flash("Account created. Please log in.", "success")
            return redirect(url_for("login"))
        # Error bodies are not guaranteed to be JSON objects with a string
        # "detail" (e.g. an HTML proxy error page). Fall back to the generic text.
        try:
            body = resp.json()
        except ValueError:
            body = None
        detail = body.get("detail") if isinstance(body, dict) else None
        flash(detail if isinstance(detail, str) and detail
              else "Registration failed.", "error")
    return render_template("register.html")
|
|
|
|
|
|
@app.get("/logout")
def logout():
    """Sign the user out and redirect to the login page.

    If a refresh token is stored, it is posted to the backend's
    ``/auth/logout`` (presumably to invalidate it; backend behaviour is not
    visible here). That call is best effort. The local session is cleared
    regardless, so an unreachable backend cannot leave the user stuck
    signed in.
    """
    refresh_token = session.get("refresh_token")
    if refresh_token:
        try:
            _api("POST", "/auth/logout", json={"refresh_token": refresh_token})
        except httpx.RequestError:
            # Backend unreachable: the token is not invalidated server-side,
            # but the user is still logged out of this frontend.
            pass
    session.clear()
    return redirect(url_for("login"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Authenticated routes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/dashboard")
@login_required
def dashboard():
    """Render the dashboard with the current user's profile and their images.

    A non-200 reply from either backend call degrades to an empty user dict
    or an empty image list.
    """
    token = session["access_token"]

    me_reply = _api("GET", "/users/me", token=token)
    user = {}
    if me_reply.status_code == 200:
        user = me_reply.json()

    images_reply = _api("GET", "/images/", token=token)
    images = []
    if images_reply.status_code == 200:
        images = images_reply.json()

    return render_template("dashboard.html", user=user, images=images)
|
|
|
|
|
|
# ── Image files & generation ──────────────────────────────────────────────
|
|
|
|
@app.get("/images/<image_id>/file")
@login_required
def serve_uploaded_image(image_id: str):
    """Proxy an uploaded image's bytes from the backend to the browser.

    The backend's content type is passed through (defaulting to image/jpeg).
    Any non-200 backend status, including auth errors, becomes a plain 404.
    """
    upstream = _api("GET", f"/images/{image_id}/file",
                    token=session["access_token"])
    if upstream.status_code == 200:
        mime_type = upstream.headers.get("content-type", "image/jpeg")
        return Response(upstream.content, status=200, content_type=mime_type)
    return Response("Not found", status=404)
|
|
|
|
|
|
@app.get("/generate")
@login_required
def generate():
    """Bare /generate URL: forward to the text-generation page."""
    text_page = url_for("generate_text")
    return redirect(text_page)
|
|
|
|
|
|
@app.route("/generate/text", methods=["GET", "POST"])
@login_required
def generate_text():
    """Render the text-generation form (GET) or run a generation (POST).

    On POST the trimmed ``model`` and ``prompt`` fields go to the backend's
    ``/generate/text``.
      - 200: the JSON reply becomes ``result``.
      - Other status: ``error`` is the backend's string ``detail`` or a
        generic message.
      - Transport failure (including the request timeout): ``error`` is a
        "could not reach" message.
    The model picker is populated on every request.
    """
    result = error = None
    token = session["access_token"]
    if request.method == "POST":
        try:
            resp = _api("POST", "/generate/text", token=token, json={
                "model": request.form.get("model", "").strip(),
                "prompt": request.form.get("prompt", "").strip(),
            })
        except httpx.RequestError:
            error = "Could not reach the generation service (it may have timed out)."
        else:
            if resp.status_code == 200:
                result = resp.json()
            else:
                # Error bodies may be non-JSON or lack a string "detail".
                try:
                    body = resp.json()
                except ValueError:
                    body = None
                detail = body.get("detail") if isinstance(body, dict) else None
                error = detail if isinstance(detail, str) and detail else "Generation failed."
    models = _load_models(token, "text")
    return render_template("generate_text.html", result=result, error=error, models=models)
|
|
|
|
|
|
@app.route("/generate/image", methods=["GET", "POST"])
@login_required
def generate_image():
    """Render the image-generation form (GET) or run a generation (POST).

    On POST:
      1. ``n`` (number of images) must parse as an integer; otherwise an
         error is shown and nothing is sent.
      2. An optional ``reference_image`` file is uploaded to
         ``/images/upload``. A non-200/201 reply aborts with the upload error.
      3. The form fields are sent to ``/generate/image``. A 200 reply becomes
         ``result``; anything else becomes ``error``.
    Transport failures, including the request timeout, become ``error``.
    The model picker is populated on every request.
    """

    def backend_error(resp, default: str) -> str:
        """Return the backend's string ``detail`` from *resp*, else *default*."""
        try:
            body = resp.json()
        except ValueError:
            return default
        detail = body.get("detail") if isinstance(body, dict) else None
        return detail if isinstance(detail, str) and detail else default

    result = error = None
    token = session["access_token"]
    if request.method == "POST":
        # Validate before any backend call: a bad value used to raise a 500.
        try:
            n = int(request.form.get("n", 1))
        except ValueError:
            n = None
            error = "Number of images must be a whole number."
        ref_file = request.files.get("reference_image")
        try:
            if error is None and ref_file and ref_file.filename:
                up_resp = _api(
                    "POST", "/images/upload",
                    token=token,
                    files={"file": (ref_file.filename,
                                    ref_file.stream, ref_file.content_type)},
                )
                if up_resp.status_code not in (200, 201):
                    error = backend_error(up_resp, "Image upload failed.")
            # NOTE(review): nothing from the upload response is forwarded in the
            # generation request below; confirm how the backend associates the
            # reference image with this generation.
            if error is None:
                resp = _api("POST", "/generate/image", token=token, json={
                    "model": request.form.get("model", "").strip(),
                    "prompt": request.form.get("prompt", "").strip(),
                    "n": n,
                    "size": request.form.get("size", "1024x1024"),
                    "aspect_ratio": request.form.get("aspect_ratio", "").strip() or None,
                    "image_size": request.form.get("image_size", "").strip() or None,
                })
                if resp.status_code == 200:
                    result = resp.json()
                else:
                    error = backend_error(resp, "Generation failed.")
        except httpx.RequestError:
            error = "Could not reach the generation service (it may have timed out)."
    models = _load_models(token, "image")
    return render_template("generate_image.html", result=result, error=error, models=models)
|
|
|
|
|
|
@app.route("/generate/video", methods=["GET", "POST"])
@login_required
def generate_video():
    """Render the video-generation form (GET) or start a generation (POST).

    On POST, ``mode`` selects the backend endpoint:
      - "image": ``/generate/video/from-image``, with ``image_url`` added.
      - anything else: ``/generate/video``.
    ``duration_seconds`` is sent only when it is a non-negative integer
    string (else null). A blank ``resolution`` is sent as null.
    A 200 reply becomes ``result``. Other statuses and transport failures
    become ``error``. The model picker is populated on every request.
    """
    result = error = None
    token = session["access_token"]
    if request.method == "POST":
        duration_raw = request.form.get("duration_seconds", "").strip()
        # isdecimal() rather than isdigit(): isdigit() is also true for
        # characters such as "²" that int() rejects with ValueError.
        duration = int(duration_raw) if duration_raw.isdecimal() else None
        payload = {
            "model": request.form.get("model", "").strip(),
            "prompt": request.form.get("prompt", "").strip(),
            "aspect_ratio": request.form.get("aspect_ratio", "16:9"),
            "duration_seconds": duration,
            "resolution": request.form.get("resolution", "").strip() or None,
        }
        endpoint = "/generate/video"
        if request.form.get("mode", "text") == "image":
            endpoint = "/generate/video/from-image"
            payload["image_url"] = request.form.get("image_url", "").strip()
        try:
            resp = _api("POST", endpoint, token=token, json=payload)
        except httpx.RequestError:
            error = "Could not reach the generation service (it may have timed out)."
        else:
            if resp.status_code == 200:
                result = resp.json()
            else:
                # Error bodies may be non-JSON or lack a string "detail".
                try:
                    body = resp.json()
                except ValueError:
                    body = None
                detail = body.get("detail") if isinstance(body, dict) else None
                error = detail if isinstance(detail, str) and detail else "Generation failed."
    models = _load_models(token, "video")
    return render_template("generate_video.html", result=result, error=error, models=models)
|
|
|
|
|
|
@app.get("/generate/video/status")
@login_required
def generate_video_status():
    """Proxy video status polling to the backend as JSON.

    Responses:
      - 400 ``{"error": ...}`` when the ``polling_url`` query arg is missing.
      - The backend's JSON body and status code, passed through.
      - 502 ``{"error": ...}`` when the backend is unreachable or its body is
        not JSON, instead of an HTML 500 page the poller cannot parse.
    """
    polling_url = request.args.get("polling_url", "")
    if not polling_url:
        return jsonify({"error": "polling_url required"}), 400
    try:
        resp = _api(
            "GET", "/generate/video/status",
            token=session["access_token"],
            params={"polling_url": polling_url},
        )
    except httpx.RequestError:
        return jsonify({"error": "backend unavailable"}), 502
    try:
        body = resp.json()
    except ValueError:
        return jsonify({"error": "invalid response from backend"}), 502
    return jsonify(body), resp.status_code
|
|
|
|
|
|
# ── Admin ─────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/admin")
@admin_required
def admin():
    """Admin overview: backend statistics plus the user list.

    Both backend calls are made up front. A non-200 reply degrades to an
    empty stats dict or an empty user list.
    """
    token = session["access_token"]
    stats_reply = _api("GET", "/admin/stats", token=token)
    users_reply = _api("GET", "/users", token=token)

    def decoded(reply, fallback):
        # Parse JSON only for a 200; otherwise use the empty fallback.
        return reply.json() if reply.status_code == 200 else fallback

    return render_template("admin.html",
                           stats=decoded(stats_reply, {}),
                           users=decoded(users_reply, []))
|
|
|
|
|
|
@app.post("/admin/users/<user_id>/role")
@admin_required
def admin_set_role(user_id: str):
    """Change a user's role via the backend, then return to the admin page.

    The success flash is shown only for a 2xx backend reply. Otherwise the
    backend's string ``detail`` (or a generic message) is flashed as an
    error. The backend result used to be ignored, so failures were reported
    as successes.
    """
    role = request.form.get("role", "user")
    try:
        resp = _api("PUT", f"/users/{user_id}/role",
                    token=session["access_token"], json={"role": role})
    except httpx.RequestError:
        flash("Could not reach the backend; role not updated.", "error")
        return redirect(url_for("admin"))
    if 200 <= resp.status_code < 300:
        flash(f"Role updated to '{role}'.", "success")
    else:
        try:
            body = resp.json()
        except ValueError:
            body = None
        detail = body.get("detail") if isinstance(body, dict) else None
        flash(detail if isinstance(detail, str) and detail
              else "Role update failed.", "error")
    return redirect(url_for("admin"))
|
|
|
|
|
|
@app.post("/admin/users/<user_id>/delete")
@admin_required
def admin_delete_user(user_id: str):
    """Delete a user via the backend, then return to the admin page.

    "User deleted." is flashed only for a 2xx backend reply. Otherwise the
    backend's string ``detail`` (or a generic message) is flashed as an
    error. Previously, failures were reported as successes.
    """
    try:
        resp = _api("DELETE", f"/users/{user_id}", token=session["access_token"])
    except httpx.RequestError:
        flash("Could not reach the backend; user not deleted.", "error")
        return redirect(url_for("admin"))
    if 200 <= resp.status_code < 300:
        flash("User deleted.", "success")
    else:
        try:
            body = resp.json()
        except ValueError:
            body = None
        detail = body.get("detail") if isinstance(body, dict) else None
        flash(detail if isinstance(detail, str) and detail
              else "User deletion failed.", "error")
    return redirect(url_for("admin"))
|
|
|
|
|
|
# ── Profile ───────────────────────────────────────────────────────────────
|
|
|
|
@app.route("/users/profile", methods=["GET", "POST"])
@login_required
def profile():
    """Show (GET) or update (POST) the signed-in user's email and/or password.

    On POST, blank fields are left unchanged. A non-empty change is sent to
    ``PUT /users/me``, and the user is redirected back here with a success or
    error flash. On success the cached ``user_email`` in the session is
    refreshed.
    """
    token = session["access_token"]
    if request.method == "POST":
        payload: dict = {}
        new_email = request.form.get("email", "").strip()
        new_password = request.form.get("password", "")
        if new_email:
            payload["email"] = new_email
        # Whitespace-only input means "keep the current password", but a real
        # password is sent verbatim. /login and /register forward passwords
        # unstripped, so stripping here would store a password the user could
        # not log in with if it had leading or trailing spaces.
        if new_password.strip():
            payload["password"] = new_password
        if payload:
            try:
                resp = _api("PUT", "/users/me", token=token, json=payload)
            except httpx.RequestError:
                flash("Could not reach the backend; profile not updated.", "error")
            else:
                if resp.status_code == 200:
                    updated = resp.json()
                    session["user_email"] = updated.get(
                        "email", session.get("user_email", ""))
                    flash("Profile updated.", "success")
                else:
                    # Error bodies may be non-JSON or lack a string "detail".
                    try:
                        body = resp.json()
                    except ValueError:
                        body = None
                    detail = body.get("detail") if isinstance(body, dict) else None
                    flash(detail if isinstance(detail, str) and detail
                          else "Update failed.", "error")
        return redirect(url_for("profile"))
    resp = _api("GET", "/users/me", token=token)
    user = resp.json() if resp.status_code == 200 else {}
    return render_template("profile.html", user=user)
|