"""Flask frontend application.""" import functools import httpx from flask import ( Flask, Response, flash, jsonify, redirect, render_template, request, session, url_for, ) from .config import Config app = Flask(__name__) app.config.from_object(Config) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _backend(path: str) -> str: return f"{app.config['BACKEND_URL']}{path}" def _api(method: str, path: str, *, token: str | None = None, **kwargs): headers = kwargs.pop("headers", {}) if token: headers["Authorization"] = f"Bearer {token}" return httpx.request(method, _backend(path), headers=headers, timeout=30, **kwargs) def _model_matches_modality(model: dict, modality: str) -> bool: """Heuristic fallback when backend modality filter returns empty.""" model_modality = (model.get("modality") or "").lower() if model_modality == modality: return True text = f"{model.get('id', '')} {model.get('name', '')}".lower() keywords = { "image": ["image", "dall-e", "flux", "stable-diffusion", "sdxl", "recraft", "ideogram", "gpt-image"], "video": ["video", "sora", "runway", "veo", "kling", "pika", "luma", "wan"], "audio": ["audio", "speech", "voice", "tts", "transcribe", "whisper"], } if modality in keywords: return any(k in text for k in keywords[modality]) if modality == "text": non_text_hits = any( k in text for k in keywords["image"] + keywords["video"] + keywords["audio"]) return not non_text_hits return False def _load_models(token: str, modality: str) -> list[dict]: """Load models for modality; fallback to unfiltered cache if needed.""" try: models_resp = _api("GET", "/models/", token=token, params={"modality": modality}) except httpx.RequestError: return [] if models_resp.status_code == 200: try: models = models_resp.json() except ValueError: models = [] if models: return models try: all_resp = _api("GET", "/models/", token=token) except httpx.RequestError: return [] if all_resp.status_code != 200: return [] try: all_models = all_resp.json() except ValueError: return [] filtered = [m for m in all_models if _model_matches_modality(m, modality)] return filtered or all_models def login_required(view): @functools.wraps(view) def wrapped(*args, **kwargs): if "access_token" not in session: return redirect(url_for("login")) return view(*args, **kwargs) return wrapped def admin_required(view): @functools.wraps(view) def wrapped(*args, **kwargs): if "access_token" not in session: return redirect(url_for("login")) if session.get("user_role") != "admin": flash("Admin access required.", "error") return redirect(url_for("dashboard")) return view(*args, **kwargs) return wrapped # --------------------------------------------------------------------------- # Auth routes # --------------------------------------------------------------------------- @app.get("/") def index(): if "access_token" in session: return redirect(url_for("dashboard")) return redirect(url_for("login")) @app.route("/login", methods=["GET", "POST"]) def login(): if request.method == "POST": email = request.form["email"] password = request.form["password"] resp = _api("POST", "/auth/login", json={"email": email, "password": password}) if resp.status_code == 200: data = resp.json() session["access_token"] = data["access_token"] session["refresh_token"] = data["refresh_token"] me = _api("GET", "/users/me", token=data["access_token"]) if me.status_code == 200: u = me.json() session["user_email"] = u.get("email", "") session["user_role"] = u.get("role", "user") return redirect(url_for("dashboard")) flash("Invalid email or password.", "error") return render_template("login.html") @app.route("/register", methods=["GET", "POST"]) def register(): if request.method == "POST": email = request.form["email"] password = request.form["password"] resp = _api("POST", "/auth/register", json={"email": email, "password": password}) if resp.status_code == 201: flash("Account created. Please log in.", "success") return redirect(url_for("login")) detail = resp.json().get("detail", "Registration failed.") flash(detail, "error") return render_template("register.html") @app.get("/logout") def logout(): refresh_token = session.get("refresh_token") if refresh_token: _api("POST", "/auth/logout", json={"refresh_token": refresh_token}) session.clear() return redirect(url_for("login")) # --------------------------------------------------------------------------- # Authenticated routes # --------------------------------------------------------------------------- @app.get("/dashboard") @login_required def dashboard(): token = session["access_token"] resp = _api("GET", "/users/me", token=token) user = resp.json() if resp.status_code == 200 else {} img_resp = _api("GET", "/images/", token=token) images = img_resp.json() if img_resp.status_code == 200 else [] gen_resp = _api("GET", "/generate/images", token=token) generated_images = gen_resp.json() if gen_resp.status_code == 200 else [] vid_resp = _api("GET", "/generate/videos", token=token) generated_videos = vid_resp.json() if vid_resp.status_code == 200 else [] return render_template("dashboard.html", user=user, images=images, generated_images=generated_images, generated_videos=generated_videos) # ── Generate ────────────────────────────────────────────────────────────── @app.get("/images//file") @login_required def serve_uploaded_image(image_id: str): resp = _api("GET", f"/images/{image_id}/file", token=session["access_token"]) if resp.status_code != 200: return Response("Not found", status=404) return Response( resp.content, status=200, content_type=resp.headers.get("content-type", "image/jpeg"), ) @app.get("/generate") @login_required def generate(): return redirect(url_for("generate_text")) @app.route("/generate/text", methods=["GET", "POST"]) @login_required def generate_text(): error = None token = session["access_token"] chat_history: list[dict] = session.get("chat_history", []) system_prompt: str = session.get("chat_system_prompt", "") model: str = session.get("chat_model", "") if request.method == "POST": action = request.form.get("action", "send") if action == "clear": session.pop("chat_history", None) session.pop("chat_system_prompt", None) session.pop("chat_model", None) return redirect(url_for("generate_text")) prompt = request.form.get("prompt", "").strip() model = request.form.get("model", "").strip() system_prompt = request.form.get("system_prompt", "").strip() # Persist model + system_prompt across turns session["chat_model"] = model session["chat_system_prompt"] = system_prompt if prompt: # Build messages: history (user/assistant only) + new user msg messages = [m for m in chat_history if m["role"] in ("user", "assistant")] messages.append({"role": "user", "content": prompt}) payload: dict = { "model": model, "messages": [{"role": m["role"], "content": m["content"]} for m in messages], } if system_prompt: payload["system_prompt"] = system_prompt resp = _api("POST", "/generate/text", token=token, json=payload) if resp.status_code == 200: data = resp.json() chat_history = list(messages) chat_history.append({"role": "assistant", "content": data["content"], "usage": data.get("usage")}) session["chat_history"] = chat_history else: try: error = resp.json().get("detail", "Generation failed.") except Exception: error = "Generation failed." models = _load_models(token, "text") return render_template( "generate_text.html", chat_history=session.get("chat_history", []), error=error, models=models, system_prompt=system_prompt, current_model=model, ) @app.route("/generate/image", methods=["GET", "POST"]) @login_required def generate_image(): result = error = None token = session["access_token"] if request.method == "POST": # Upload reference image if provided ref_file = request.files.get("reference_image") if ref_file and ref_file.filename: up_resp = _api( "POST", "/images/upload", token=token, files={"file": (ref_file.filename, ref_file.stream, ref_file.content_type)}, ) if up_resp.status_code not in (200, 201): error = up_resp.json().get("detail", "Image upload failed.") models = _load_models(token, "image") return render_template("generate_image.html", result=result, error=error, models=models) resp = _api("POST", "/generate/image", token=token, json={ "model": request.form.get("model", "").strip(), "prompt": request.form.get("prompt", "").strip(), "n": int(request.form.get("n", 1)), "size": request.form.get("size", "1024x1024"), "aspect_ratio": request.form.get("aspect_ratio", "").strip() or None, "image_size": request.form.get("image_size", "").strip() or None, }) if resp.status_code == 200: result = resp.json() else: error = resp.json().get("detail", "Generation failed.") models = _load_models(token, "image") return render_template("generate_image.html", result=result, error=error, models=models) @app.route("/generate/video", methods=["GET", "POST"]) @login_required def generate_video(): result = error = None token = session["access_token"] if request.method == "POST": mode = request.form.get("mode", "text") duration_raw = request.form.get("duration_seconds", "") duration = int( duration_raw) if duration_raw.strip().isdigit() else None resolution = request.form.get("resolution", "").strip() or None if mode == "image": resp = _api("POST", "/generate/video/from-image", token=token, json={ "model": request.form.get("model", "").strip(), "image_url": request.form.get("image_url", "").strip(), "prompt": request.form.get("prompt", "").strip(), "aspect_ratio": request.form.get("aspect_ratio", "16:9"), "duration_seconds": duration, "resolution": resolution, }) else: resp = _api("POST", "/generate/video", token=token, json={ "model": request.form.get("model", "").strip(), "prompt": request.form.get("prompt", "").strip(), "aspect_ratio": request.form.get("aspect_ratio", "16:9"), "duration_seconds": duration, "resolution": resolution, }) if resp.status_code == 200: result = resp.json() else: error = resp.json().get("detail", "Generation failed.") models = _load_models(token, "video") return render_template("generate_video.html", result=result, error=error, models=models) @app.get("/generate/video/status") @login_required def generate_video_status(): """Proxy video status polling to the backend.""" polling_url = request.args.get("polling_url", "") if not polling_url: return jsonify({"error": "polling_url required"}), 400 resp = _api( "GET", "/generate/video/status", token=session["access_token"], params={"polling_url": polling_url}, ) return jsonify(resp.json()), resp.status_code # ── Admin ───────────────────────────────────────────────────────────────── @app.get("/admin") @admin_required def admin(): token = session["access_token"] stats_resp = _api("GET", "/admin/stats", token=token) users_resp = _api("GET", "/users", token=token) stats = stats_resp.json() if stats_resp.status_code == 200 else {} users = users_resp.json() if users_resp.status_code == 200 else [] return render_template("admin.html", stats=stats, users=users) @app.post("/admin/users//role") @admin_required def admin_set_role(user_id: str): role = request.form.get("role", "user") _api("PUT", f"/users/{user_id}/role", token=session["access_token"], json={"role": role}) flash(f"Role updated to '{role}'.", "success") return redirect(url_for("admin")) @app.post("/admin/users//delete") @admin_required def admin_delete_user(user_id: str): _api("DELETE", f"/users/{user_id}", token=session["access_token"]) flash("User deleted.", "success") return redirect(url_for("admin")) # ── Profile ─────────────────────────────────────────────────────────────── @app.route("/users/profile", methods=["GET", "POST"]) @login_required def profile(): token = session["access_token"] if request.method == "POST": payload: dict = {} new_email = request.form.get("email", "").strip() new_password = request.form.get("password", "").strip() if new_email: payload["email"] = new_email if new_password: payload["password"] = new_password if payload: resp = _api("PUT", "/users/me", token=token, json=payload) if resp.status_code == 200: updated = resp.json() session["user_email"] = updated.get( "email", session.get("user_email", "")) flash("Profile updated.", "success") else: flash(resp.json().get("detail", "Update failed."), "error") return redirect(url_for("profile")) resp = _api("GET", "/users/me", token=token) user = resp.json() if resp.status_code == 200 else {} return render_template("profile.html", user=user)