From 7b622016168267be441db5e7b65d9a0a86b9a613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Tue, 5 Nov 2024 09:42:14 +0100 Subject: [PATCH] feat: added support for streaming in API (#247) * feat: added support for streaming in API * fix: fixed tests for streaming generate api * fix: fixed streaming api tests, again --- gptme/server/api.py | 109 ++++++++++++++++++++++++++++++------ gptme/server/static/main.js | 96 +++++++++++++++++++++++-------- tests/test_server.py | 51 ++++++++++++++--- 3 files changed, 209 insertions(+), 47 deletions(-) diff --git a/gptme/server/api.py b/gptme/server/api.py index 424b6099..135f40a0 100644 --- a/gptme/server/api.py +++ b/gptme/server/api.py @@ -7,6 +7,7 @@ import atexit import io +import logging from contextlib import redirect_stdout from datetime import datetime from importlib import resources @@ -18,11 +19,14 @@ from ..commands import execute_cmd from ..dirs import get_logs_dir -from ..llm import reply +from ..llm import _stream from ..logmanager import LogManager, get_user_conversations, prepare_messages from ..message import Message from ..models import get_model from ..tools import execute_msg +from ..tools.base import ToolUse + +logger = logging.getLogger(__name__) api = flask.Blueprint("api", __name__) @@ -94,6 +98,7 @@ def confirm_func(msg: str) -> bool: def api_conversation_generate(logfile: str): # get model or use server default req_json = flask.request.json or {} + stream = req_json.get("stream", False) # Default to no streaming (backward compat) model = req_json.get("model", get_model().model) # load conversation @@ -101,8 +106,6 @@ def api_conversation_generate(logfile: str): # if prompt is a user-command, execute it if manager.log[-1].role == "user": - # TODO: capture output of command and return it - f = io.StringIO() print("Begin capturing stdout, to pass along command output.") with redirect_stdout(f): @@ -118,21 +121,91 @@ def api_conversation_generate(logfile: str): # performs reduction/context trimming, if necessary msgs = prepare_messages(manager.log.messages) - # generate response - # TODO: add support for streaming - msg = reply(msgs, model=model, stream=True) - msg = msg.replace(quiet=True) - - # log response and run tools - resp_msgs = [] - manager.append(msg) - resp_msgs.append(msg) - for reply_msg in execute_msg(msg, confirm_func): - manager.append(reply_msg) - resp_msgs.append(reply_msg) - - return flask.jsonify( - [{"role": msg.role, "content": msg.content} for msg in resp_msgs] + if not msgs: + logger.error("No messages to process") + return flask.jsonify({"error": "No messages to process"}) + + if not stream: + # Non-streaming response + try: + # Get complete response + output = "".join(_stream(msgs, model)) + + # Store the message + msg = Message("assistant", output) + msg = msg.replace(quiet=True) + manager.append(msg) + + # Execute any tools + reply_msgs = list(execute_msg(msg, confirm_func)) + for reply_msg in reply_msgs: + manager.append(reply_msg) + + # Return all messages + response = [{"role": "assistant", "content": output, "stored": True}] + response.extend( + {"role": msg.role, "content": msg.content, "stored": True} + for msg in reply_msgs + ) + return flask.jsonify(response) + + except Exception as e: + logger.exception("Error during generation") + return flask.jsonify({"error": str(e)}) + + # Streaming response + def generate(): + # Start with an empty message + output = "" + try: + logger.info(f"Starting generation for conversation {logfile}") + + # Prepare messages for the model + if not msgs: + logger.error("No messages to process") + yield f"data: {flask.json.dumps({'error': 'No messages to process'})}\n\n" + return + + # Stream tokens from the model + logger.debug(f"Starting token stream with model {model}") + for char in (char for chunk in _stream(msgs, model) for char in chunk): + output += char + # Send each token as a JSON event + yield f"data: {flask.json.dumps({'role': 'assistant', 'content': char, 'stored': False})}\n\n" + + # Check for complete tool uses + tooluses = list(ToolUse.iter_from_content(output)) + if tooluses and any(tooluse.is_runnable for tooluse in tooluses): + logger.debug("Found runnable tool use, breaking stream") + break + + # Store the complete message + logger.debug(f"Storing complete message: {output[:100]}...") + msg = Message("assistant", output) + msg = msg.replace(quiet=True) + manager.append(msg) + + # Execute any tools and stream their output + for reply_msg in execute_msg(msg, confirm_func): + logger.debug( + f"Tool output: {reply_msg.role} - {reply_msg.content[:100]}..." + ) + manager.append(reply_msg) + yield f"data: {flask.json.dumps({'role': reply_msg.role, 'content': reply_msg.content, 'stored': True})}\n\n" + + except Exception as e: + logger.exception("Error during generation") + yield f"data: {flask.json.dumps({'error': str(e)})}\n\n" + finally: + logger.info("Generation complete") + + return flask.Response( + generate(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", # Disable buffering in nginx + }, ) diff --git a/gptme/server/static/main.js b/gptme/server/static/main.js index 91bddada..a6f3c048 100644 --- a/gptme/server/static/main.js +++ b/gptme/server/static/main.js @@ -239,30 +239,82 @@ new Vue({ }, async generate() { this.generating = true; - const req = await fetch( - `${apiRoot}/${this.selectedConversation}/generate`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ branch: this.branch }), + let currentMessage = { + role: "assistant", + content: "", + timestamp: new Date().toISOString(), + }; + this.chatLog.push(currentMessage); + + try { + // Create EventSource with POST method using fetch + const response = await fetch( + `${apiRoot}/${this.selectedConversation}/generate`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ branch: this.branch }), + } + ); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); } - ); - this.generating = false; - if (!req.ok) { - this.error = req.statusText; - return; - } - // req.json() can contain (not stored) responses to /commands, - // or the result of the generation. - // if it's unsaved results of a command, we need to display it - const data = await req.json(); - if (data.length == 1 && data[0].stored === false) { - this.cmdout = data[0].content; + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + + while (true) { + const {value, done} = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + // Parse SSE data + const lines = chunk.split('\n'); + for (const line of lines) { + if (line.startsWith('data: ')) { + const data = JSON.parse(line.slice(6)); + + if (data.error) { + this.error = data.error; + break; + } + + if (data.stored === false) { + // Streaming token from assistant + currentMessage.content += data.content; + currentMessage.html = this.mdToHtml(currentMessage.content); + this.scrollToBottom(); + } else { + // Tool output or stored message + if (data.role === "system") { + this.cmdout = data.content; + } else { + // Add as a new message + const newMsg = { + role: data.role, + content: data.content, + timestamp: new Date().toISOString(), + html: this.mdToHtml(data.content), + }; + this.chatLog.push(newMsg); + } + } + } + } + } + + // After streaming is complete, reload to ensure we have the server's state + this.generating = false; + await this.selectConversation(this.selectedConversation, this.branch); + } catch (error) { + this.error = error.toString(); + this.generating = false; + // Remove the temporary message on error + this.chatLog.pop(); } - // reload conversation - await this.selectConversation(this.selectedConversation, this.branch); }, changeBranch(branch) { this.branch = branch; diff --git a/tests/test_server.py b/tests/test_server.py index d2856c1c..f8fae99a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -71,14 +71,51 @@ def test_api_conversation_generate(conv: str, client: FlaskClient): ) assert response.status_code == 200 + # Test regular (non-streaming) response response = client.post( f"/api/conversations/{conv}/generate", - json={"model": get_model().model}, + json={"model": get_model().model, "stream": False}, ) assert response.status_code == 200 - msgs = response.get_json() - assert len(msgs) >= 1 - assert len(msgs) <= 2 - assert msgs[0]["role"] == "assistant" - if len(msgs) == 2: - assert msgs[1]["role"] == "system" + data = response.get_data(as_text=True) + assert data # Ensure we got some response + msgs_resps = response.get_json() + assert msgs_resps is not None # Ensure we got valid JSON + # Assistant message + possible tool output + assert len(msgs_resps) >= 1 + + # First message should be the assistant's response + assert msgs_resps[0]["role"] == "assistant" + + +@pytest.mark.slow +def test_api_conversation_generate_stream(conv: str, client: FlaskClient): + # Ask the assistant to generate a test response + response = client.post( + f"/api/conversations/{conv}", + json={"role": "user", "content": "hello, just testing"}, + ) + assert response.status_code == 200 + + # Test streaming response + response = client.post( + f"/api/conversations/{conv}/generate", + json={"model": get_model().model, "stream": True}, + headers={"Accept": "text/event-stream"}, + ) + assert response.status_code == 200 + assert "text/event-stream" in response.headers["Content-Type"] + + # Read and validate the streamed response + chunks = list(response.iter_encoded()) + assert len(chunks) > 0 + + # Each chunk should be a Server-Sent Event + for chunk in chunks: + chunk_str = chunk.decode("utf-8") + assert chunk_str.startswith("data: ") + # Skip empty chunks (heartbeats) + if chunk_str.strip() == "data: ": + continue + data = chunk_str.replace("data: ", "").strip() + assert data # Non-empty data