feat: added support for streaming in API (#247)
* feat: added support for streaming in API

* fix: fixed tests for streaming generate api

* fix: fixed streaming api tests, again
ErikBjare authored Nov 5, 2024
1 parent b2f2b47 commit 7b62201
Showing 3 changed files with 209 additions and 47 deletions.
109 changes: 91 additions & 18 deletions gptme/server/api.py
@@ -7,6 +7,7 @@
 
 import atexit
 import io
+import logging
 from contextlib import redirect_stdout
 from datetime import datetime
 from importlib import resources
@@ -18,11 +19,14 @@
 
 from ..commands import execute_cmd
 from ..dirs import get_logs_dir
-from ..llm import reply
+from ..llm import _stream
 from ..logmanager import LogManager, get_user_conversations, prepare_messages
 from ..message import Message
 from ..models import get_model
 from ..tools import execute_msg
+from ..tools.base import ToolUse
+
+logger = logging.getLogger(__name__)
 
 api = flask.Blueprint("api", __name__)
 
@@ -94,15 +98,14 @@ def confirm_func(msg: str) -> bool:
 def api_conversation_generate(logfile: str):
     # get model or use server default
     req_json = flask.request.json or {}
+    stream = req_json.get("stream", False)  # Default to no streaming (backward compat)
     model = req_json.get("model", get_model().model)
 
     # load conversation
     manager = LogManager.load(logfile, branch=req_json.get("branch", "main"))
 
     # if prompt is a user-command, execute it
     if manager.log[-1].role == "user":
-        # TODO: capture output of command and return it
-
         f = io.StringIO()
         print("Begin capturing stdout, to pass along command output.")
         with redirect_stdout(f):
@@ -118,21 +121,91 @@ def api_conversation_generate(logfile: str):
     # performs reduction/context trimming, if necessary
     msgs = prepare_messages(manager.log.messages)
 
-    # generate response
-    # TODO: add support for streaming
-    msg = reply(msgs, model=model, stream=True)
-    msg = msg.replace(quiet=True)
-
-    # log response and run tools
-    resp_msgs = []
-    manager.append(msg)
-    resp_msgs.append(msg)
-    for reply_msg in execute_msg(msg, confirm_func):
-        manager.append(reply_msg)
-        resp_msgs.append(reply_msg)
-
-    return flask.jsonify(
-        [{"role": msg.role, "content": msg.content} for msg in resp_msgs]
+    if not msgs:
+        logger.error("No messages to process")
+        return flask.jsonify({"error": "No messages to process"})
+
+    if not stream:
+        # Non-streaming response
+        try:
+            # Get complete response
+            output = "".join(_stream(msgs, model))
+
+            # Store the message
+            msg = Message("assistant", output)
+            msg = msg.replace(quiet=True)
+            manager.append(msg)
+
+            # Execute any tools
+            reply_msgs = list(execute_msg(msg, confirm_func))
+            for reply_msg in reply_msgs:
+                manager.append(reply_msg)
+
+            # Return all messages
+            response = [{"role": "assistant", "content": output, "stored": True}]
+            response.extend(
+                {"role": msg.role, "content": msg.content, "stored": True}
+                for msg in reply_msgs
+            )
+            return flask.jsonify(response)
+
+        except Exception as e:
+            logger.exception("Error during generation")
+            return flask.jsonify({"error": str(e)})
+
+    # Streaming response
+    def generate():
+        # Start with an empty message
+        output = ""
+        try:
+            logger.info(f"Starting generation for conversation {logfile}")
+
+            # Prepare messages for the model
+            if not msgs:
+                logger.error("No messages to process")
+                yield f"data: {flask.json.dumps({'error': 'No messages to process'})}\n\n"
+                return
+
+            # Stream tokens from the model
+            logger.debug(f"Starting token stream with model {model}")
+            for char in (char for chunk in _stream(msgs, model) for char in chunk):
+                output += char
+                # Send each token as a JSON event
+                yield f"data: {flask.json.dumps({'role': 'assistant', 'content': char, 'stored': False})}\n\n"
+
+                # Check for complete tool uses
+                tooluses = list(ToolUse.iter_from_content(output))
+                if tooluses and any(tooluse.is_runnable for tooluse in tooluses):
+                    logger.debug("Found runnable tool use, breaking stream")
+                    break
+
+            # Store the complete message
+            logger.debug(f"Storing complete message: {output[:100]}...")
+            msg = Message("assistant", output)
+            msg = msg.replace(quiet=True)
+            manager.append(msg)
+
+            # Execute any tools and stream their output
+            for reply_msg in execute_msg(msg, confirm_func):
+                logger.debug(
+                    f"Tool output: {reply_msg.role} - {reply_msg.content[:100]}..."
+                )
+                manager.append(reply_msg)
+                yield f"data: {flask.json.dumps({'role': reply_msg.role, 'content': reply_msg.content, 'stored': True})}\n\n"
+
+        except Exception as e:
+            logger.exception("Error during generation")
+            yield f"data: {flask.json.dumps({'error': str(e)})}\n\n"
+        finally:
+            logger.info("Generation complete")
+
+    return flask.Response(
+        generate(),
+        mimetype="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "X-Accel-Buffering": "no",  # Disable buffering in nginx
+        },
     )
 
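Note: a minimal Python client for the new streaming endpoint could look like the sketch below. The host/port and conversation name are made-up placeholders (Flask's default port is assumed), requests is just one convenient HTTP client, and the event fields ("stored", "error") follow the handler above.

import json

import requests

# Hypothetical server and conversation; adjust to your setup.
url = "http://localhost:5000/api/conversations/my-conversation/generate"

resp = requests.post(url, json={"stream": True}, stream=True)
resp.raise_for_status()

for line in resp.iter_lines(decode_unicode=True):
    # SSE events arrive as "data: {...}" lines separated by blank lines
    if not line or not line.startswith("data: "):
        continue
    event = json.loads(line[len("data: "):])
    if "error" in event:
        raise RuntimeError(event["error"])
    if event.get("stored") is False:
        # incremental assistant token
        print(event["content"], end="", flush=True)
    else:
        # stored message, e.g. tool output
        print(f"\n[{event['role']}] {event['content']}")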
96 changes: 74 additions & 22 deletions gptme/server/static/main.js
@@ -239,30 +239,82 @@ new Vue({
     },
     async generate() {
       this.generating = true;
-      const req = await fetch(
-        `${apiRoot}/${this.selectedConversation}/generate`,
-        {
-          method: "POST",
-          headers: {
-            "Content-Type": "application/json",
-          },
-          body: JSON.stringify({ branch: this.branch }),
-        }
-      );
-      this.generating = false;
-      if (!req.ok) {
-        this.error = req.statusText;
-        return;
-      }
-      // req.json() can contain (not stored) responses to /commands,
-      // or the result of the generation.
-      // if it's unsaved results of a command, we need to display it
-      const data = await req.json();
-      if (data.length == 1 && data[0].stored === false) {
-        this.cmdout = data[0].content;
-      }
-      // reload conversation
-      await this.selectConversation(this.selectedConversation, this.branch);
+      let currentMessage = {
+        role: "assistant",
+        content: "",
+        timestamp: new Date().toISOString(),
+      };
+      this.chatLog.push(currentMessage);
+
+      try {
+        // Create EventSource with POST method using fetch
+        const response = await fetch(
+          `${apiRoot}/${this.selectedConversation}/generate`,
+          {
+            method: "POST",
+            headers: {
+              "Content-Type": "application/json",
+            },
+            body: JSON.stringify({ branch: this.branch }),
+          }
+        );
+
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder();
+
+        while (true) {
+          const {value, done} = await reader.read();
+          if (done) break;
+
+          const chunk = decoder.decode(value);
+          // Parse SSE data
+          const lines = chunk.split('\n');
+          for (const line of lines) {
+            if (line.startsWith('data: ')) {
+              const data = JSON.parse(line.slice(6));
+
+              if (data.error) {
+                this.error = data.error;
+                break;
+              }
+
+              if (data.stored === false) {
+                // Streaming token from assistant
+                currentMessage.content += data.content;
+                currentMessage.html = this.mdToHtml(currentMessage.content);
+                this.scrollToBottom();
+              } else {
+                // Tool output or stored message
+                if (data.role === "system") {
+                  this.cmdout = data.content;
+                } else {
+                  // Add as a new message
+                  const newMsg = {
+                    role: data.role,
+                    content: data.content,
+                    timestamp: new Date().toISOString(),
+                    html: this.mdToHtml(data.content),
+                  };
+                  this.chatLog.push(newMsg);
+                }
+              }
+            }
+          }
+        }
+
+        // After streaming is complete, reload to ensure we have the server's state
+        this.generating = false;
+        await this.selectConversation(this.selectedConversation, this.branch);
+      } catch (error) {
+        this.error = error.toString();
+        this.generating = false;
+        // Remove the temporary message on error
+        this.chatLog.pop();
+      }
     },
     changeBranch(branch) {
       this.branch = branch;
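One caveat worth noting: the client above decodes each network read and splits it on newlines, which assumes a "data: " line never straddles two reads. A stricter parser buffers input until the blank-line terminator that ends each SSE event. A minimal sketch of that buffering (in Python for brevity; the function name is illustrative):

import json


def iter_sse_events(chunks):
    """Yield JSON events from an iterable of text chunks, buffering
    across chunk boundaries so a split event still parses."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        # a complete SSE event ends with a blank line
        while "\n\n" in buffer:
            raw_event, buffer = buffer.split("\n\n", 1)
            for line in raw_event.splitlines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])


# an event split across two reads is still parsed correctly
assert list(iter_sse_events(['data: {"content": ', '"hi"}\n\n'])) == [
    {"content": "hi"}
]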
51 changes: 44 additions & 7 deletions tests/test_server.py
@@ -71,14 +71,51 @@ def test_api_conversation_generate(conv: str, client: FlaskClient):
     )
     assert response.status_code == 200
 
+    # Test regular (non-streaming) response
     response = client.post(
         f"/api/conversations/{conv}/generate",
-        json={"model": get_model().model},
+        json={"model": get_model().model, "stream": False},
     )
     assert response.status_code == 200
-    msgs = response.get_json()
-    assert len(msgs) >= 1
-    assert len(msgs) <= 2
-    assert msgs[0]["role"] == "assistant"
-    if len(msgs) == 2:
-        assert msgs[1]["role"] == "system"
+    data = response.get_data(as_text=True)
+    assert data  # Ensure we got some response
+    msgs_resps = response.get_json()
+    assert msgs_resps is not None  # Ensure we got valid JSON
+    # Assistant message + possible tool output
+    assert len(msgs_resps) >= 1
+
+    # First message should be the assistant's response
+    assert msgs_resps[0]["role"] == "assistant"
+
+
+@pytest.mark.slow
+def test_api_conversation_generate_stream(conv: str, client: FlaskClient):
+    # Ask the assistant to generate a test response
+    response = client.post(
+        f"/api/conversations/{conv}",
+        json={"role": "user", "content": "hello, just testing"},
+    )
+    assert response.status_code == 200
+
+    # Test streaming response
+    response = client.post(
+        f"/api/conversations/{conv}/generate",
+        json={"model": get_model().model, "stream": True},
+        headers={"Accept": "text/event-stream"},
+    )
+    assert response.status_code == 200
+    assert "text/event-stream" in response.headers["Content-Type"]
+
+    # Read and validate the streamed response
+    chunks = list(response.iter_encoded())
+    assert len(chunks) > 0
+
+    # Each chunk should be a Server-Sent Event
+    for chunk in chunks:
+        chunk_str = chunk.decode("utf-8")
+        assert chunk_str.startswith("data: ")
+        # Skip empty chunks (heartbeats)
+        if chunk_str.strip() == "data: ":
+            continue
+        data = chunk_str.replace("data: ", "").strip()
+        assert data  # Non-empty data
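When extending these tests, decoding the raw chunks into JSON events keeps assertions readable. A sketch of such a helper (the name parse_sse_chunks and the sample assertions are illustrative, not part of the commit):

import json


def parse_sse_chunks(chunks: list[bytes]) -> list[dict]:
    """Decode raw SSE chunks from the Flask test client into JSON events."""
    events = []
    for chunk in chunks:
        for line in chunk.decode("utf-8").splitlines():
            if line.startswith("data: ") and line.strip() != "data:":
                events.append(json.loads(line[len("data: "):]))
    return events


# e.g., after `chunks = list(response.iter_encoded())`:
#   events = parse_sse_chunks(chunks)
#   assert all(("content" in e) or ("error" in e) for e in events)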
