137 lines
4.3 KiB
Python
137 lines
4.3 KiB
Python
"""Flask frontend application."""
|
|
import functools
|
|
|
|
import httpx
|
|
from flask import (
|
|
Flask,
|
|
flash,
|
|
redirect,
|
|
render_template,
|
|
request,
|
|
session,
|
|
url_for,
|
|
)
|
|
|
|
from frontend.app.config import Config
|
|
|
|
# The Flask application object. Its configuration (including BACKEND_URL,
# read by the helpers below) is loaded from the attributes of ``Config``.
app = Flask(import_name=__name__)
app.config.from_object(obj=Config)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _backend(path: str) -> str:
    """Return the absolute backend URL for *path*.

    *path* is appended to the ``BACKEND_URL`` config value. When *path*
    starts with ``/``, any trailing slashes on ``BACKEND_URL`` are dropped
    first. This means a base like ``"http://api/"`` yields
    ``"http://api/auth/login"`` rather than ``"http://api//auth/login"``,
    and many routers treat the latter as a different path.
    """
    base = app.config["BACKEND_URL"]
    if path.startswith("/"):
        base = base.rstrip("/")
    return f"{base}{path}"
|
|
|
|
|
|
def _api(method: str, path: str, *, token: str | None = None, **kwargs) -> "httpx.Response":
    """Send an HTTP request to the backend and return the raw response.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        path: Path appended to the configured backend base URL.
        token: Optional access token. When truthy, it is sent as a
            ``Authorization: Bearer <token>`` header.
        **kwargs: Forwarded to ``httpx.request`` (``json``, ``params``,
            ``headers``, ``timeout``, ...). ``timeout`` defaults to 30 seconds.

    Returns:
        The ``httpx.Response``. Status codes are not checked here.

    Raises:
        httpx.RequestError: Transport failures (e.g. connection errors,
            timeouts) are not caught here.
    """
    # Copy the headers so the caller's mapping is never mutated when the
    # Authorization header is added.
    headers = dict(kwargs.pop("headers", None) or {})
    if token:
        headers["Authorization"] = f"Bearer {token}"
    # Callers may override the timeout. Previously, passing ``timeout``
    # raised a duplicate-keyword TypeError.
    timeout = kwargs.pop("timeout", 30)
    return httpx.request(method, _backend(path), headers=headers, timeout=timeout, **kwargs)
|
|
|
|
|
|
def login_required(view):
    """Decorate *view* so that visitors without a session token go to login.

    "Logged in" only means that an ``access_token`` key is present in the
    session. The token itself is not validated here.
    """

    @functools.wraps(view)
    def guarded_view(*args, **kwargs):
        if "access_token" in session:
            return view(*args, **kwargs)
        return redirect(url_for("login"))

    return guarded_view
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth routes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/")
def index():
    """Redirect to the dashboard when a session token exists, else to login."""
    target = "dashboard" if "access_token" in session else "login"
    return redirect(url_for(target))
|
|
|
|
|
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Show the login form and, on POST, exchange credentials for tokens.

    When the backend's ``/auth/login`` returns 200, the ``access_token`` and
    ``refresh_token`` from its JSON body are stored in the session and the
    user is redirected to the dashboard. Any other status flashes a generic
    error and re-renders the form.
    """
    if request.method != "POST":
        return render_template("login.html")

    credentials = {
        "email": request.form["email"],
        "password": request.form["password"],
    }
    backend_resp = _api("POST", "/auth/login", json=credentials)
    if backend_resp.status_code != 200:
        flash("Invalid email or password.", "error")
        return render_template("login.html")

    tokens = backend_resp.json()
    session["access_token"] = tokens["access_token"]
    session["refresh_token"] = tokens["refresh_token"]
    return redirect(url_for("dashboard"))
|
|
|
|
|
|
@app.route("/register", methods=["GET", "POST"])
def register():
    """Show the registration form and, on POST, create an account.

    A 201 from the backend's ``/auth/register`` flashes a success message
    and redirects to the login page. Any other status flashes the backend's
    ``detail`` field when it is a non-empty string. Otherwise it flashes the
    generic "Registration failed." message and re-renders the form.
    """
    if request.method == "POST":
        email = request.form["email"]
        password = request.form["password"]
        resp = _api("POST", "/auth/register", json={"email": email, "password": password})

        if resp.status_code == 201:
            flash("Account created. Please log in.", "success")
            return redirect(url_for("login"))

        detail = "Registration failed."
        try:
            payload = resp.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON (e.g. an HTML error
            # page from a proxy). json.JSONDecodeError subclasses ValueError.
            payload = None
        if isinstance(payload, dict):
            candidate = payload.get("detail")
            # ``detail`` may be structured (e.g. a list of validation errors)
            # rather than a message. Only flash plain, non-empty strings.
            if isinstance(candidate, str) and candidate:
                detail = candidate
        flash(detail, "error")

    return render_template("register.html")
|
|
|
|
|
|
@app.get("/logout")
def logout():
    """Log the user out and redirect to the login page.

    If a refresh token is in the session, it is POSTed to the backend's
    ``/auth/logout`` endpoint on a best-effort basis. The local session is
    always cleared, even when that backend call fails, so a backend outage
    cannot leave the user stuck logged in.
    """
    # NOTE(review): a GET logout can be triggered cross-site (e.g. via an
    # <img> tag); consider switching to POST with CSRF protection.
    refresh_token = session.get("refresh_token")
    if refresh_token:
        try:
            _api("POST", "/auth/logout", json={"refresh_token": refresh_token})
        except httpx.RequestError:
            # Backend unreachable or timed out: log it, but still clear the
            # local session below instead of failing with a 500.
            app.logger.warning("Backend logout request failed", exc_info=True)
    session.clear()
    return redirect(url_for("login"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Authenticated routes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/dashboard")
@login_required
def dashboard():
    """Render the dashboard for the current session's user.

    Profile data comes from the backend's ``/users/me``. If that request
    does not return 200, the page is rendered with an empty ``user`` dict.
    """
    me = _api("GET", "/users/me", token=session["access_token"])
    profile = {}
    if me.status_code == 200:
        profile = me.json()
    return render_template("dashboard.html", user=profile)
|
|
|
|
|
|
def _request_generation(gen_type: str, model: str, prompt: str, token: str):
    """Call the backend's ``/generate/<gen_type>`` endpoint.

    Returns a ``(result, error)`` pair: ``(parsed JSON body, None)`` on a
    200 response, otherwise ``(None, error message)``. Unknown generation
    types and transport failures are reported as errors instead of raising.
    """
    if gen_type not in ("text", "image", "video"):
        return None, "Unsupported generation type."

    try:
        resp = _api(
            "POST",
            f"/generate/{gen_type}",
            token=token,
            json={"model": model, "prompt": prompt},
        )
    except httpx.RequestError:
        # Connection failures and timeouts would otherwise surface as an
        # unhandled 500. Show them on the form instead.
        app.logger.warning("Generation request to backend failed", exc_info=True)
        return None, "Could not reach the generation service. Please try again."

    if resp.status_code == 200:
        return resp.json(), None

    try:
        payload = resp.json()
    except ValueError:
        # Non-JSON error body (json.JSONDecodeError subclasses ValueError).
        payload = None
    detail = payload.get("detail") if isinstance(payload, dict) else None
    # ``detail`` may be structured (e.g. a list of validation errors); only
    # surface plain, non-empty strings.
    if isinstance(detail, str) and detail:
        return None, detail
    return None, "Generation failed."


@app.route("/generate", methods=["GET", "POST"])
@login_required
def generate():
    """Show the generation form and, on POST, forward the request.

    The form's ``type`` field (default ``"text"``) selects the backend
    endpoint. ``model`` and ``prompt`` are stripped and sent as JSON. A
    successful response's JSON is passed to the template as ``result``;
    any failure is passed as an ``error`` string.
    """
    result = None
    error = None

    if request.method == "POST":
        result, error = _request_generation(
            request.form.get("type", "text"),
            request.form.get("model", "").strip(),
            request.form.get("prompt", "").strip(),
            session["access_token"],
        )

    return render_template("generate.html", result=result, error=error)
|