FEATURE: Enable stream overwrite for LLM Chat at the event level (#66)
* Allowing stream overwrite at event level for LLM Chat

* Added overwrite flag
cyrus2281 authored Nov 25, 2024
1 parent f1f5801 commit f4677f9
Showing 9 changed files with 252 additions and 100 deletions.
31 changes: 29 additions & 2 deletions examples/llm/anthropic_chat.yaml
@@ -5,7 +5,8 @@
#
# The input message has the following schema:
# {
# "text": "<question or request as text>"
# "query": "<question or request as text>",
# "stream": false
# }
#
# It will then send an event back to Solace with the topic: `demo/question/response`
@@ -66,17 +67,23 @@ flows:
base_url: ${ANTHROPIC_API_ENDPOINT}
model: ${MODEL_NAME}
temperature: 0.01
llm_mode: stream
allow_overwrite_llm_mode: true
stream_to_flow: stream_output
input_transforms:
- type: copy
source_expression: |
template:You are a helpful AI assistant. Please help with the user's request below:
<user-question>
{{text://input.payload:text}}
{{text://input.payload:query}}
</user-question>
dest_expression: user_data.llm_input:messages.0.content
- type: copy
source_expression: static:user
dest_expression: user_data.llm_input:messages.0.role
- type: copy
source_expression: input.payload:stream
dest_expression: user_data.llm_input:stream
input_selection:
source_expression: user_data.llm_input

@@ -97,3 +104,23 @@
dest_expression: user_data.output:topic
input_selection:
source_expression: user_data.output

- name: stream_output
components:
# Send response back to broker
- component_name: send_response
component_module: broker_output
component_config:
<<: *broker_connection
payload_encoding: utf-8
payload_format: json
copy_user_properties: true
input_transforms:
- type: copy
source_expression: input.payload
dest_expression: user_data.output:payload
- type: copy
source_value: demo/question/stream
dest_expression: user_data.output:topic
input_selection:
source_expression: user_data.output
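
As a usage note for the example above: a minimal sketch of the two request shapes this flow accepts after the change, plus a helper that mirrors the override rule added to LangChainChatModelBase.invoke further down. The query strings, helper name, and assertions are illustrative, not part of the commit.

# Hedged sketch: example request payloads for examples/llm/anthropic_chat.yaml.
# Field names follow the schema comment above; everything else is illustrative.
blocking_request = {
    "query": "What is the capital of France?",
    "stream": False,  # full response is published to demo/question/response
}

streaming_request = {
    "query": "Explain streaming in a few sentences.",
    "stream": True,  # chunks are routed to the stream_output flow (demo/question/stream)
}


def resolve_should_stream(configured_llm_mode, allow_overwrite_llm_mode, stream_flag):
    # Mirrors the decision added in LangChainChatModelBase.invoke below: the
    # per-event flag only wins when allow_overwrite_llm_mode is enabled and the
    # flag is an actual boolean.
    should_stream = configured_llm_mode == "stream"
    if allow_overwrite_llm_mode and isinstance(stream_flag, bool):
        should_stream = stream_flag
    return should_stream


assert resolve_should_stream("none", True, streaming_request["stream"]) is True
assert resolve_should_stream("stream", True, blocking_request["stream"]) is False
assert resolve_should_stream("stream", False, False) is True  # overwrite disabled: config wins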
@@ -1,6 +1,8 @@
# This is a wrapper around all the LangChain chat models
# The configuration will control dynamic loading of the chat models
from uuid import uuid4
from copy import deepcopy
from collections import namedtuple
from .langchain_chat_model_base import (
LangChainChatModelBase,
info_base,
@@ -17,6 +19,48 @@ def __init__(self, **kwargs):
super().__init__(info, **kwargs)

def invoke_model(
self, input_message, messages, session_id=None, clear_history=False
self,
input_message,
messages,
session_id=None,
clear_history=False,
stream=False,
):
return self.component.invoke(messages)
if not stream:
return self.component.invoke(messages)

aggregate_result = ""
current_batch = ""
response_uuid = str(uuid4())
first_chunk = True

for chunk in self.component.stream(messages):
aggregate_result += chunk.content
current_batch += chunk.content
if len(current_batch) >= self.stream_batch_size:
if self.stream_to_flow:
self.send_streaming_message(
input_message,
current_batch,
aggregate_result,
response_uuid,
first_chunk,
)
current_batch = ""
first_chunk = False

if self.stream_to_flow:
self.send_streaming_message(
input_message,
current_batch,
aggregate_result,
response_uuid,
first_chunk,
True,
)

result = namedtuple("Result", ["content", "response_uuid"])(
aggregate_result, response_uuid
)

return result
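
A self-contained sketch of the batch-and-flush streaming pattern used in invoke_model above, with a fake chunk source standing in for self.component.stream(). FakeChunk, fake_stream, and stream_in_batches are illustrative names, not part of the commit.

# Hedged sketch of the batching logic above: accumulate chunks until the current
# batch reaches stream_batch_size characters, emit it, then flush the remainder
# with last_chunk=True.
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class FakeChunk:
    content: str


def fake_stream(text, size=4):
    # Yield the text in small pieces, the way a streaming LLM client would.
    for i in range(0, len(text), size):
        yield FakeChunk(text[i:i + size])


def stream_in_batches(chunks, batch_size, emit):
    aggregate, batch = "", ""
    response_uuid = str(uuid4())
    first_chunk = True
    for chunk in chunks:
        aggregate += chunk.content
        batch += chunk.content
        if len(batch) >= batch_size:
            emit(batch, aggregate, response_uuid, first_chunk, False)
            batch, first_chunk = "", False
    # Final flush always fires so the consumer sees last_chunk=True.
    emit(batch, aggregate, response_uuid, first_chunk, True)
    return aggregate, response_uuid


emitted = []
stream_in_batches(
    fake_stream("The quick brown fox jumps over the lazy dog"),
    batch_size=15,
    emit=lambda *args: emitted.append(args),
)
for batch, aggregate, _uuid, first, last in emitted:
    print(f"first={first} last={last} batch={batch!r}")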
@@ -5,6 +5,7 @@
from abc import abstractmethod
from langchain_core.output_parsers import JsonOutputParser

from .....common.message import Message
from .....common.utils import get_obj_text
from langchain.schema.messages import (
HumanMessage,
@@ -39,6 +40,28 @@
"description": "Model specific configuration for the chat model. "
"See documentation for valid parameter names.",
},
{
"name": "llm_mode",
"required": False,
"description": "The mode for streaming results: 'none' or 'stream'. 'stream' will just stream the results to the named flow. 'none' will wait for the full response.",
},
{
"name": "allow_overwrite_llm_mode",
"required": False,
"description": "Whether to allow the llm_mode to be overwritten by the `stream` from the input message.",
},
{
"name": "stream_to_flow",
"required": False,
"description": "Name the flow to stream the output to - this must be configured for llm_mode='stream'.",
"default": "",
},
{
"name": "stream_batch_size",
"required": False,
"description": "The minimum number of words in a single streaming result. Default: 15.",
"default": 15,
},
{
"name": "llm_response_format",
"required": False,
Expand Down Expand Up @@ -88,10 +111,18 @@


class LangChainChatModelBase(LangChainBase):

def __init__(self, info, **kwargs):
super().__init__(info, **kwargs)
self.llm_mode = self.get_config("llm_mode", "none")
self.allow_overwrite_llm_mode = self.get_config("allow_overwrite_llm_mode")
self.stream_to_flow = self.get_config("stream_to_flow", "")
self.stream_batch_size = self.get_config("stream_batch_size", 15)

def invoke(self, message, data):
messages = []

for item in data["messages"]:
for item in data.get("messages"):
if item["role"] == "system":
messages.append(SystemMessage(content=item["content"]))
elif item["role"] == "user" or item["role"] == "human":
@@ -109,9 +140,22 @@ def invoke(self, message, data):

session_id = data.get("session_id", None)
clear_history = data.get("clear_history", False)
stream = data.get("stream")

should_stream = self.llm_mode == "stream"
if (
self.allow_overwrite_llm_mode
and stream is not None
and isinstance(stream, bool)
):
should_stream = stream

llm_res = self.invoke_model(
message, messages, session_id=session_id, clear_history=clear_history
message,
messages,
session_id=session_id,
clear_history=clear_history,
stream=should_stream,
)

res_format = self.get_config("llm_response_format", "text")
@@ -134,6 +178,32 @@ def invoke(self, message, data):

@abstractmethod
def invoke_model(
self, input_message, messages, session_id=None, clear_history=False
self,
input_message,
messages,
session_id=None,
clear_history=False,
stream=False,
):
pass

def send_streaming_message(
self,
input_message,
chunk,
aggregate_result,
response_uuid,
first_chunk=False,
last_chunk=False,
):
message = Message(
payload={
"chunk": chunk,
"content": aggregate_result,
"response_uuid": response_uuid,
"first_chunk": first_chunk,
"last_chunk": last_chunk,
},
user_properties=input_message.get_user_properties(),
)
self.send_to_flow(self.stream_to_flow, message)
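
For whichever flow is named in stream_to_flow (the stream_output flow in the YAML example above), each streamed message carries the payload built by send_streaming_message. A hedged sketch of that shape, with illustrative values:

# Only the keys come from the code above; the values and the reassemble helper
# are illustrative.
example_chunk_payloads = [
    {
        "chunk": "Paris is the ",
        "content": "Paris is the ",
        "response_uuid": "d2c1a7e0-0000-0000-0000-000000000000",
        "first_chunk": True,
        "last_chunk": False,
    },
    {
        "chunk": "capital of France.",
        "content": "Paris is the capital of France.",
        "response_uuid": "d2c1a7e0-0000-0000-0000-000000000000",
        "first_chunk": False,
        "last_chunk": True,
    },
]


def reassemble(payloads):
    # "content" always carries the aggregate text so far, so a consumer can either
    # concatenate the "chunk" fields or simply keep the last "content" it sees.
    return payloads[-1]["content"]


assert reassemble(example_chunk_payloads) == "Paris is the capital of France."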
@@ -16,7 +16,6 @@
SystemMessage,
)

from .....common.message import Message
from .langchain_chat_model_base import (
LangChainChatModelBase,
info_base,
@@ -78,23 +77,6 @@
"description": "The configuration for the history class.",
"type": "object",
},
{
"name": "stream_to_flow",
"required": False,
"description": "Name the flow to stream the output to - this must be configured for llm_mode='stream'.",
"default": "",
},
{
"name": "llm_mode",
"required": False,
"description": "The mode for streaming results: 'sync' or 'stream'. 'stream' will just stream the results to the named flow. 'none' will wait for the full response.",
},
{
"name": "stream_batch_size",
"required": False,
"description": "The minimum number of words in a single streaming result. Default: 15.",
"default": 15,
},
{
"name": "set_response_uuid_in_user_properties",
"required": False,
@@ -128,15 +110,17 @@ def __init__(self, **kwargs):
)
self.history_max_tokens = self.get_config("history_max_tokens", 8000)
self.history_max_time = self.get_config("history_max_time", None)
self.stream_to_flow = self.get_config("stream_to_flow", "")
self.llm_mode = self.get_config("llm_mode", "none")
self.stream_batch_size = self.get_config("stream_batch_size", 15)
self.set_response_uuid_in_user_properties = self.get_config(
"set_response_uuid_in_user_properties", False
)

def invoke_model(
self, input_message, messages, session_id=None, clear_history=False
self,
input_message,
messages,
session_id=None,
clear_history=False,
stream=False,
):

if clear_history:
@@ -171,7 +155,7 @@ def invoke_model(
history_messages_key="chat_history",
)

if self.llm_mode == "none":
if not stream:
return runnable.invoke(
{"input": human_message},
config={
@@ -221,27 +205,6 @@

return result

def send_streaming_message(
self,
input_message,
chunk,
aggregate_result,
response_uuid,
first_chunk=False,
last_chunk=False,
):
message = Message(
payload={
"chunk": chunk,
"content": aggregate_result,
"response_uuid": response_uuid,
"first_chunk": first_chunk,
"last_chunk": last_chunk,
},
user_properties=input_message.get_user_properties(),
)
self.send_to_flow(self.stream_to_flow, message)

def create_history(self):

history_class = self.load_component(
@@ -32,40 +32,6 @@
"description": "Sampling temperature to use",
"default": 0.7,
},
{
"name": "stream_to_flow",
"required": False,
"description": (
"Name the flow to stream the output to - this must be configured for "
"llm_mode='stream'. This is mutually exclusive with stream_to_next_component."
),
"default": "",
},
{
"name": "stream_to_next_component",
"required": False,
"description": (
"Whether to stream the output to the next component in the flow. "
"This is mutually exclusive with stream_to_flow."
),
"default": False,
},
{
"name": "llm_mode",
"required": False,
"description": (
"The mode for streaming results: 'sync' or 'stream'. 'stream' "
"will just stream the results to the named flow. 'none' will "
"wait for the full response."
),
"default": "none",
},
{
"name": "stream_batch_size",
"required": False,
"description": "The minimum number of words in a single streaming result. Default: 15.",
"default": 15,
},
{
"name": "set_response_uuid_in_user_properties",
"required": False,
@@ -91,10 +57,6 @@ def __init__(self, module_info, **kwargs):
def init(self):
litellm.suppress_debug_info = True
self.load_balancer = self.get_config("load_balancer")
self.stream_to_flow = self.get_config("stream_to_flow")
self.stream_to_next_component = self.get_config("stream_to_next_component")
self.llm_mode = self.get_config("llm_mode")
self.stream_batch_size = self.get_config("stream_batch_size")
self.set_response_uuid_in_user_properties = self.get_config(
"set_response_uuid_in_user_properties"
)