diff --git a/backend/app/models/ai.py b/backend/app/models/ai.py index 503862b..9894077 100644 --- a/backend/app/models/ai.py +++ b/backend/app/models/ai.py @@ -33,8 +33,9 @@ class ModelInfo(BaseModel): class TextRequest(BaseModel): model: str - prompt: str + prompt: str = "" system_prompt: str | None = None + messages: list[ChatMessage] | None = None temperature: float = 0.7 max_tokens: int = 1024 diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py index 3661d33..f801a07 100644 --- a/backend/app/routers/generate.py +++ b/backend/app/routers/generate.py @@ -23,10 +23,17 @@ async def generate_text( _: dict = Depends(get_current_user), ) -> TextResponse: """Generate text from a prompt using a chat model.""" - messages = [] - if body.system_prompt: - messages.append({"role": "system", "content": body.system_prompt}) - messages.append({"role": "user", "content": body.prompt}) + if body.messages: + messages = [{"role": m.role, "content": m.content} + for m in body.messages] + if body.system_prompt and (not messages or messages[0]["role"] != "system"): + messages.insert( + 0, {"role": "system", "content": body.system_prompt}) + else: + messages = [] + if body.system_prompt: + messages.append({"role": "system", "content": body.system_prompt}) + messages.append({"role": "user", "content": body.prompt}) try: result = await openrouter.chat_completion( diff --git a/backend/tests/test_generate.py b/backend/tests/test_generate.py index 2d92d5c..595326a 100644 --- a/backend/tests/test_generate.py +++ b/backend/tests/test_generate.py @@ -69,7 +69,7 @@ async def _user_token(client): async def test_generate_text(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT): + with patch("app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, return_value=FAKE_CHAT): resp = await client.post( "/generate/text", json={"model": "openai/gpt-4o", "prompt": "Tell me a story"}, @@ -85,7 +85,7 @@ async def test_generate_text(client): async def test_generate_text_with_system_prompt(client): token = await _user_token(client) mock = AsyncMock(return_value=FAKE_CHAT) - with patch("backend.app.routers.generate.openrouter.chat_completion", mock): + with patch("app.routers.generate.openrouter.chat_completion", mock): await client.post( "/generate/text", json={"model": "openai/gpt-4o", "prompt": "Hello", @@ -97,6 +97,44 @@ async def test_generate_text_with_system_prompt(client): assert call_messages[1] == {"role": "user", "content": "Hello"} +async def test_generate_text_with_messages_array(client): + """messages field takes precedence over prompt for multi-turn chat.""" + token = await _user_token(client) + mock = AsyncMock(return_value=FAKE_CHAT) + messages = [ + {"role": "user", "content": "First message"}, + {"role": "assistant", "content": "Reply"}, + {"role": "user", "content": "Follow up"}, + ] + with patch("app.routers.generate.openrouter.chat_completion", mock): + resp = await client.post( + "/generate/text", + json={"model": "openai/gpt-4o", "messages": messages}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + call_messages = mock.call_args.kwargs["messages"] + assert len(call_messages) == 3 + assert call_messages[2]["content"] == "Follow up" + + +async def test_generate_text_messages_with_system_prompt(client): + """system_prompt prepended when messages provided and no system msg present.""" + token = await _user_token(client) + mock = AsyncMock(return_value=FAKE_CHAT) + messages = [{"role": "user", "content": "Hi"}] + with patch("app.routers.generate.openrouter.chat_completion", mock): + await client.post( + "/generate/text", + json={"model": "openai/gpt-4o", "messages": messages, + "system_prompt": "Be brief."}, + headers={"Authorization": f"Bearer {token}"}, + ) + call_messages = mock.call_args.kwargs["messages"] + assert call_messages[0] == {"role": "system", "content": "Be brief."} + assert call_messages[1] == {"role": "user", "content": "Hi"} + + async def test_generate_text_unauthenticated(client): resp = await client.post("/generate/text", json={"model": "openai/gpt-4o", "prompt": "Hi"}) assert resp.status_code == 401 @@ -104,7 +142,7 @@ async def test_generate_text_unauthenticated(client): async def test_generate_text_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, side_effect=Exception("timeout")): + with patch("app.routers.generate.openrouter.chat_completion", new_callable=AsyncMock, side_effect=Exception("timeout")): resp = await client.post( "/generate/text", json={"model": "openai/gpt-4o", "prompt": "Hi"}, @@ -119,7 +157,7 @@ async def test_generate_text_upstream_error(client): async def test_generate_image(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE): + with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, return_value=FAKE_IMAGE): resp = await client.post( "/generate/image", json={"model": "openai/dall-e-3", "prompt": "A cat on the moon"}, @@ -140,7 +178,7 @@ async def test_generate_image_unauthenticated(client): async def test_generate_image_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")): + with patch("app.routers.generate.openrouter.generate_image", new_callable=AsyncMock, side_effect=Exception("rate limit")): resp = await client.post( "/generate/image", json={"model": "openai/dall-e-3", "prompt": "Hi"}, @@ -184,7 +222,7 @@ FAKE_IMAGE_CHAT_GPT5 = { async def test_generate_image_chat_flux(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX): + with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_FLUX): resp = await client.post( "/generate/image", json={"model": "black-forest-labs/flux.2-klein-4b", @@ -200,7 +238,7 @@ async def test_generate_image_chat_flux(client): async def test_generate_image_chat_gpt5_image_mini(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5): + with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, return_value=FAKE_IMAGE_CHAT_GPT5): resp = await client.post( "/generate/image", json={"model": "openai/gpt-5-image-mini", "prompt": "A cat"}, @@ -215,7 +253,7 @@ async def test_generate_image_chat_gpt5_image_mini(client): async def test_generate_image_chat_with_image_config(client): token = await _user_token(client) mock = AsyncMock(return_value=FAKE_IMAGE_CHAT_FLUX) - with patch("backend.app.routers.generate.openrouter.generate_image_chat", mock): + with patch("app.routers.generate.openrouter.generate_image_chat", mock): await client.post( "/generate/image", json={ @@ -239,7 +277,7 @@ async def test_generate_image_chat_unauthenticated(client): async def test_generate_image_chat_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")): + with patch("app.routers.generate.openrouter.generate_image_chat", new_callable=AsyncMock, side_effect=Exception("timeout")): resp = await client.post( "/generate/image", json={"model": "black-forest-labs/flux.2-klein-4b", "prompt": "Hi"}, @@ -254,7 +292,7 @@ async def test_generate_image_chat_upstream_error(client): async def test_generate_video(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, return_value=FAKE_VIDEO): + with patch("app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, return_value=FAKE_VIDEO): resp = await client.post( "/generate/video", json={"model": "stability/stable-video", @@ -276,7 +314,7 @@ async def test_generate_video_unauthenticated(client): async def test_generate_video_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, side_effect=Exception("503")): + with patch("app.routers.generate.openrouter.generate_video", new_callable=AsyncMock, side_effect=Exception("503")): resp = await client.post( "/generate/video", json={"model": "stability/stable-video", "prompt": "Hi"}, @@ -291,7 +329,7 @@ async def test_generate_video_upstream_error(client): async def test_generate_video_from_image(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, return_value=FAKE_VIDEO_DONE): + with patch("app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, return_value=FAKE_VIDEO_DONE): resp = await client.post( "/generate/video/from-image", json={ @@ -315,7 +353,7 @@ async def test_poll_video_status(client): "status": "completed", "unsigned_urls": ["https://example.com/video.mp4"], } - with patch("backend.app.routers.generate.openrouter.poll_video_status", new_callable=AsyncMock, return_value=mock_result): + with patch("app.routers.generate.openrouter.poll_video_status", new_callable=AsyncMock, return_value=mock_result): resp = await client.get( "/generate/video/status", params={"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1"}, @@ -337,7 +375,7 @@ async def test_poll_video_status_unauthenticated(client): async def test_poll_video_status_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.poll_video_status", new_callable=AsyncMock, side_effect=Exception("timeout")): + with patch("app.routers.generate.openrouter.poll_video_status", new_callable=AsyncMock, side_effect=Exception("timeout")): resp = await client.get( "/generate/video/status", params={"polling_url": "https://openrouter.ai/api/v1/videos/gen-vid-1"}, @@ -356,7 +394,7 @@ async def test_generate_video_from_image_unauthenticated(client): async def test_generate_video_from_image_upstream_error(client): token = await _user_token(client) - with patch("backend.app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, side_effect=Exception("error")): + with patch("app.routers.generate.openrouter.generate_video_from_image", new_callable=AsyncMock, side_effect=Exception("error")): resp = await client.post( "/generate/video/from-image", json={"model": "runway/gen-3", diff --git a/frontend/app/main.py b/frontend/app/main.py index a5c380e..77d82d3 100644 --- a/frontend/app/main.py +++ b/frontend/app/main.py @@ -206,19 +206,64 @@ def generate(): @app.route("/generate/text", methods=["GET", "POST"]) @login_required def generate_text(): - result = error = None + error = None token = session["access_token"] + chat_history: list[dict] = session.get("chat_history", []) + system_prompt: str = session.get("chat_system_prompt", "") + model: str = session.get("chat_model", "") + if request.method == "POST": - resp = _api("POST", "/generate/text", token=token, json={ - "model": request.form.get("model", "").strip(), - "prompt": request.form.get("prompt", "").strip(), - }) - if resp.status_code == 200: - result = resp.json() - else: - error = resp.json().get("detail", "Generation failed.") + action = request.form.get("action", "send") + + if action == "clear": + session.pop("chat_history", None) + session.pop("chat_system_prompt", None) + session.pop("chat_model", None) + return redirect(url_for("generate_text")) + + prompt = request.form.get("prompt", "").strip() + model = request.form.get("model", "").strip() + system_prompt = request.form.get("system_prompt", "").strip() + + # Persist model + system_prompt across turns + session["chat_model"] = model + session["chat_system_prompt"] = system_prompt + + if prompt: + # Build messages: history (user/assistant only) + new user msg + messages = [m for m in chat_history if m["role"] + in ("user", "assistant")] + messages.append({"role": "user", "content": prompt}) + + payload: dict = { + "model": model, + "messages": [{"role": m["role"], "content": m["content"]} for m in messages], + } + if system_prompt: + payload["system_prompt"] = system_prompt + + resp = _api("POST", "/generate/text", token=token, json=payload) + if resp.status_code == 200: + data = resp.json() + chat_history = list(messages) + chat_history.append({"role": "assistant", "content": data["content"], + "usage": data.get("usage")}) + session["chat_history"] = chat_history + else: + try: + error = resp.json().get("detail", "Generation failed.") + except Exception: + error = "Generation failed." + models = _load_models(token, "text") - return render_template("generate_text.html", result=result, error=error, models=models) + return render_template( + "generate_text.html", + chat_history=session.get("chat_history", []), + error=error, + models=models, + system_prompt=system_prompt, + current_model=model, + ) @app.route("/generate/image", methods=["GET", "POST"]) diff --git a/frontend/app/static/style.css b/frontend/app/static/style.css index 03d3667..8e212a9 100644 --- a/frontend/app/static/style.css +++ b/frontend/app/static/style.css @@ -695,3 +695,123 @@ pre { border-radius: 8px; margin-top: 0.5rem; } + +/* ─── Chat interface ─────────────────────────────────────────────────────── */ +.chat-page { + display: flex; + flex-direction: column; + height: calc(100vh - 100px); + max-height: 900px; +} + +.chat-header { + display: flex; + align-items: center; + justify-content: space-between; + margin-bottom: 0.75rem; +} + +.chat-config { + border: 1px solid var(--border, #ddd); + border-radius: 6px; + padding: 0.5rem 0.75rem; + margin-bottom: 0.75rem; + font-size: 0.9rem; +} + +.chat-config summary { + cursor: pointer; + font-weight: 500; + user-select: none; +} + +.chat-config-body { + display: flex; + flex-direction: column; + gap: 0.4rem; + margin-top: 0.5rem; +} + +.chat-history { + flex: 1; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 0.75rem; + padding: 0.5rem 0; + border-top: 1px solid var(--border, #ddd); + border-bottom: 1px solid var(--border, #ddd); + margin-bottom: 0.75rem; +} + +.chat-empty { + color: var(--text-muted, #888); + text-align: center; + margin: auto; + font-size: 0.9rem; +} + +.chat-bubble { + max-width: 80%; + padding: 0.6rem 0.9rem; + border-radius: 12px; + font-size: 0.9rem; + line-height: 1.5; +} + +.chat-bubble--user { + align-self: flex-end; + background: var(--accent, #7c6ff7); + color: #fff; + border-bottom-right-radius: 3px; +} + +.chat-bubble--assistant { + align-self: flex-start; + background: var(--surface-2, #f0f0f0); + color: var(--text, #222); + border-bottom-left-radius: 3px; +} + +.bubble-role { + display: block; + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; + opacity: 0.7; + margin-bottom: 0.25rem; +} + +.bubble-content { + white-space: pre-wrap; + word-break: break-word; +} + +.bubble-meta { + display: block; + font-size: 0.7rem; + opacity: 0.6; + margin-top: 0.3rem; + text-align: right; +} + +.chat-input-row { + display: flex; + gap: 0.5rem; + align-items: flex-end; +} + +.chat-input-textarea { + flex: 1; + resize: none; + border-radius: 8px; + padding: 0.5rem 0.75rem; + font-size: 0.95rem; + min-height: 2.5rem; + max-height: 8rem; +} + +.btn-sm { + padding: 0.3rem 0.7rem; + font-size: 0.8rem; +} diff --git a/frontend/app/templates/generate_text.html b/frontend/app/templates/generate_text.html index 82f907f..c4c1a7b 100644 --- a/frontend/app/templates/generate_text.html +++ b/frontend/app/templates/generate_text.html @@ -1,52 +1,92 @@ {% extends "base.html" %} {% block title %}Text Generation — All You Can GET AI{% endblock %} {% block content %} -