207 lines
7.6 KiB
Python
207 lines
7.6 KiB
Python
"""Generate blueprint — text, image, video generation."""
|
|
import httpx
|
|
|
|
from flask import (
|
|
Blueprint, flash, jsonify, redirect, render_template, request, session, url_for,
|
|
)
|
|
|
|
from ..helpers import _api, _load_models, login_required
|
|
|
|
# Blueprint for the text/image/video generation pages and the video-job JSON proxy routes.
generate_bp = Blueprint("generate", __name__)
|
|
|
|
|
|
def _detail_or_default(resp, default: str) -> str:
|
|
try:
|
|
payload = resp.json()
|
|
except ValueError:
|
|
return default
|
|
if isinstance(payload, dict):
|
|
detail = payload.get("detail")
|
|
if detail:
|
|
return str(detail)
|
|
return default
|
|
|
|
|
|
@generate_bp.get("/generate")
@login_required
def index():
    """Send bare ``/generate`` requests to the text-generation page."""
    destination = url_for("generate.text")
    return redirect(destination)
|
|
|
|
|
|
@generate_bp.route("/generate/text", methods=["GET", "POST"])
@login_required
def text():
    """Multi-turn chat page backed by the backend ``/generate/text`` endpoint.

    The conversation (``chat_history``), system prompt and selected model live
    in the Flask session.

    * POST with ``action=clear``: drops all three keys and redirects back here.
    * Any other POST: stores the submitted model and system prompt. When
      ``prompt`` is non-empty, it sends the prior user/assistant turns plus the
      new prompt to the backend. On success it appends the exchange to the
      session history; on failure it renders an error message.

    Network failures and malformed backend replies are shown as ``error``
    instead of propagating as a 500.
    """
    error = None
    token = session["access_token"]
    chat_history: list[dict] = session.get("chat_history", [])
    system_prompt: str = session.get("chat_system_prompt", "")
    model: str = session.get("chat_model", "")

    if request.method == "POST":
        if request.form.get("action", "send") == "clear":
            for key in ("chat_history", "chat_system_prompt", "chat_model"):
                session.pop(key, None)
            return redirect(url_for("generate.text"))

        prompt = request.form.get("prompt", "").strip()
        model = request.form.get("model", "").strip()
        system_prompt = request.form.get("system_prompt", "").strip()

        # Persist the settings even when no prompt was sent so the form keeps them.
        session["chat_model"] = model
        session["chat_system_prompt"] = system_prompt

        if prompt:
            messages = [m for m in chat_history if m["role"] in ("user", "assistant")]
            messages.append({"role": "user", "content": prompt})

            # Only role/content go over the wire; stored assistant turns also
            # carry a "usage" entry that the backend does not need.
            payload: dict = {
                "model": model,
                "messages": [{"role": m["role"], "content": m["content"]} for m in messages],
            }
            if system_prompt:
                payload["system_prompt"] = system_prompt

            try:
                resp = _api("POST", "/generate/text", token=token, json=payload)
            except httpx.TimeoutException:
                error = "Text generation timed out. Please try again."
            except httpx.RequestError:
                error = "Cannot reach generation service. Please try again."
            else:
                if resp.status_code != 200:
                    error = _detail_or_default(resp, "Generation failed.")
                else:
                    try:
                        data = resp.json()
                        content = data["content"]
                    except (ValueError, KeyError, TypeError):
                        # Non-JSON body, or a JSON body without a "content" field.
                        error = "Generation failed."
                    else:
                        session["chat_history"] = [
                            *messages,
                            {"role": "assistant", "content": content, "usage": data.get("usage")},
                        ]

    models = _load_models(token, "text")
    return render_template(
        "generate_text.html",
        chat_history=session.get("chat_history", []),
        error=error,
        models=models,
        system_prompt=system_prompt,
        current_model=model,
    )
|
|
|
|
|
|
def _submit_image_form(token: str) -> tuple:
    """Run the image-generation POST flow for the current request.

    Returns ``(result, error)``. ``result`` is the decoded JSON body of a
    successful ``/generate/image`` call. ``error`` is a user-facing message
    when validation, the optional upload, or the generation call fails.
    """
    try:
        # An empty or missing field falls back to one image; anything
        # non-integer is a user error, not a 500.
        n = int(request.form.get("n", "").strip() or 1)
    except ValueError:
        return None, "Number of images must be a whole number."

    try:
        ref_file = request.files.get("reference_image")
        if ref_file and ref_file.filename:
            # NOTE(review): only the upload's status is checked; nothing from its
            # response is forwarded to /generate/image below. Confirm the backend
            # associates the reference image on its own.
            up_resp = _api(
                "POST", "/images/upload",
                token=token,
                files={"file": (ref_file.filename, ref_file.stream, ref_file.content_type)},
            )
            if up_resp.status_code not in (200, 201):
                return None, _detail_or_default(up_resp, "Image upload failed.")

        resp = _api("POST", "/generate/image", token=token, timeout=120, json={
            "model": request.form.get("model", "").strip(),
            "prompt": request.form.get("prompt", "").strip(),
            "n": n,
            "size": request.form.get("size", "1024x1024"),
            "aspect_ratio": request.form.get("aspect_ratio", "").strip() or None,
            "image_size": request.form.get("image_size", "").strip() or None,
        })
    except httpx.TimeoutException:
        return None, "Image generation timed out. Please try again."
    except httpx.RequestError:
        return None, "Cannot reach generation service. Please try again."

    if resp.status_code != 200:
        return None, _detail_or_default(resp, "Generation failed.")
    try:
        return resp.json(), None
    except ValueError:
        return None, "Generation failed."


@generate_bp.route("/generate/image", methods=["GET", "POST"])
@login_required
def image():
    """Image generation page.

    GET renders the form. POST optionally uploads ``reference_image`` and then
    requests generation. The page is re-rendered with either the backend's
    result or an error message.
    """
    token = session["access_token"]
    result = error = None
    if request.method == "POST":
        result, error = _submit_image_form(token)
    models = _load_models(token, "image")
    return render_template("generate_image.html", result=result, error=error, models=models)
|
|
|
|
|
|
@generate_bp.route("/generate/video", methods=["GET", "POST"])
@login_required
def video():
    """Video generation page.

    On POST, submits a text-to-video job (``/generate/video``) or, when
    ``mode=image``, an image-to-video job (``/generate/video/from-image``).

    On a 200 response, redirects to the job's gallery detail page if the
    backend returned a ``db_id``. Otherwise it flashes a message and redirects
    to the gallery index.

    Backend errors and network failures are rendered on the form page instead
    of raising.
    """
    error = None
    token = session["access_token"]
    if request.method == "POST":
        form = request.form
        duration_raw = form.get("duration_seconds", "").strip()
        # isdecimal(), not isdigit(): isdigit() also accepts characters such as
        # "²" that int() rejects, which previously surfaced as a 500.
        duration = int(duration_raw) if duration_raw.isdecimal() else None

        payload = {
            "model": form.get("model", "").strip(),
            "prompt": form.get("prompt", "").strip(),
            "aspect_ratio": form.get("aspect_ratio", "16:9"),
            "duration_seconds": duration,
            "resolution": form.get("resolution", "").strip() or None,
        }
        if form.get("mode", "text") == "image":
            endpoint = "/generate/video/from-image"
            payload["image_url"] = form.get("image_url", "").strip()
        else:
            endpoint = "/generate/video"

        try:
            resp = _api("POST", endpoint, token=token, json=payload)
        except httpx.TimeoutException:
            error = "Video generation request timed out. Please try again."
        except httpx.RequestError:
            error = "Cannot reach generation service. Please try again."
        else:
            if resp.status_code == 200:
                try:
                    body = resp.json()
                except ValueError:
                    body = None
                db_id = body.get("db_id") if isinstance(body, dict) else None
                if db_id:
                    return redirect(url_for("gallery.video_detail", video_id=db_id))
                flash("Video job started.", "success")
                return redirect(url_for("gallery.index"))
            error = _detail_or_default(resp, "Generation failed.")

    models = _load_models(token, "video")
    return render_template("generate_video.html", error=error, models=models)
|
|
|
|
|
|
def _proxy_json(method: str, path: str, **kwargs):
    """Call the backend API and relay its JSON body and status code to the browser.

    Extra keyword arguments are passed through to ``_api``. Transport failures
    become a 504 (timeout) or 502 (unreachable) JSON error. A backend reply
    that is not JSON is reported as ``{"error": ...}`` instead of raising.
    """
    try:
        resp = _api(method, path, token=session["access_token"], **kwargs)
    except httpx.TimeoutException:
        return jsonify({"error": "Backend request timed out."}), 504
    except httpx.RequestError:
        return jsonify({"error": "Cannot reach backend service."}), 502
    try:
        body = resp.json()
    except ValueError:
        # Keep the backend's error status. A "successful" reply we cannot
        # parse is reported as a bad gateway so the frontend does not treat it
        # as valid data.
        status = resp.status_code if resp.status_code >= 400 else 502
        return jsonify({"error": "Invalid response from backend."}), status
    return jsonify(body), resp.status_code


@generate_bp.get("/generate/video/status")
@login_required
def video_status():
    """Proxy video status polling to the backend."""
    polling_url = request.args.get("polling_url", "")
    if not polling_url:
        return jsonify({"error": "polling_url required"}), 400
    # NOTE(review): polling_url is client-controlled and forwarded unchanged.
    # Confirm the backend restricts it to the provider's host (SSRF risk).
    return _proxy_json("GET", "/generate/video/status", params={"polling_url": polling_url})


@generate_bp.get("/generate/video/<video_id>/status")
@login_required
def video_db_status(video_id: str):
    """Return current DB status for a video job (polled by frontend JS)."""
    return _proxy_json("GET", f"/generate/videos/{video_id}")


@generate_bp.post("/generate/video/<video_id>/cancel")
@login_required
def cancel_video_job(video_id: str):
    """Proxy cancel request to backend."""
    return _proxy_json("POST", f"/generate/videos/{video_id}/cancel")
|