Merge pull request #12 from SolaceDev:better_streaming
Better_streaming
efunneko authored Jul 25, 2024
2 parents 8e64032 + fc02b88 commit b0b2548
Showing 5 changed files with 66 additions and 14 deletions.
2 changes: 2 additions & 0 deletions docs/components/langchain_chat_model_with_history.md
@@ -21,6 +21,7 @@ component_config:
stream_to_flow: <string>
llm_mode: <string>
stream_batch_size: <string>
set_response_uuid_in_user_properties: <boolean>
```
| Parameter | Required | Default | Description |
@@ -38,6 +39,7 @@ component_config:
| stream_to_flow | False | | Name of the flow to stream the output to - this must be configured when llm_mode='stream'. |
| llm_mode | False | | The mode for streaming results: 'none' or 'stream'. 'stream' will stream the results to the named flow; 'none' will wait for the full response. |
| stream_batch_size | False | 15 | The minimum number of words in a single streaming result. Default: 15. |
| set_response_uuid_in_user_properties | False | False | Whether to set the response_uuid in the user_properties of the input_message. This will allow other components to correlate streaming chunks with the full response. |
## Component Input Schema
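For illustration only (not part of this commit): a minimal sketch of how a downstream component could use the `response_uuid`, `first_chunk`, and `last_chunk` fields carried by the streaming payload introduced later in this commit to correlate chunks with the full response. The `ChunkAggregator` class and the literal payloads are hypothetical.

```python
# Hypothetical downstream aggregator: groups streaming chunks by response_uuid
# and reassembles the full text once a payload arrives with last_chunk=True.
from collections import defaultdict
from typing import Optional


class ChunkAggregator:
    def __init__(self):
        # response_uuid -> list of chunk strings received so far
        self._chunks = defaultdict(list)

    def on_chunk(self, payload: dict) -> Optional[str]:
        """Collect one streaming payload; return the full text when the stream ends."""
        uuid = payload["response_uuid"]
        self._chunks[uuid].append(payload["chunk"])
        if payload.get("last_chunk"):
            return "".join(self._chunks.pop(uuid))
        return None


if __name__ == "__main__":
    agg = ChunkAggregator()
    payloads = [
        {"response_uuid": "r-1", "chunk": "Hello ", "first_chunk": True, "last_chunk": False},
        {"response_uuid": "r-1", "chunk": "world", "first_chunk": False, "last_chunk": True},
    ]
    for p in payloads:
        full = agg.on_chunk(p)
        if full is not None:
            print(full)  # -> Hello world
```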
4 changes: 4 additions & 0 deletions docs/components/timer_input.md
@@ -22,5 +22,9 @@ component_config:
## Component Output Schema
```
<<<<<<< HEAD
<None>
=======
<any>
>>>>>>> origin/main
```
11 changes: 10 additions & 1 deletion src/solace_ai_connector/components/component_base.py
@@ -46,6 +46,7 @@ def __init__(self, module_info, **kwargs):
self.need_acknowledgement = False
self.stop_thread_event = threading.Event()
self.current_message = None
self.current_message_has_been_discarded = False

self.log_identifier = f"[{self.instance_name}.{self.flow_name}.{self.name}] "

@@ -159,9 +160,13 @@ def process_message(self, message):
self.trace_data(data)

# Invoke the component
self.current_message_has_been_discarded = False
result = self.invoke(message, data)

if result is not None:
if self.current_message_has_been_discarded:
# Call the message acknowledgements
message.call_acknowledgements()
elif result is not None:
# Do all the things we need to do after invoking the component
# Note that there are times where we don't want to
# send the message to the next component
@@ -193,6 +198,10 @@ def process_post_invoke(self, result, message):
self.current_message = message
self.send_message(message)

def discard_current_message(self):
# If the message is to be discarded, we need to acknowledge any previous components
self.current_message_has_been_discarded = True

def get_acknowledgement_callback(self):
# This should be overridden by the component if it needs to acknowledge messages
return None
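A hedged sketch of how a component subclass might use the new `discard_current_message()` hook; the `MessageFilter` class and its abbreviated `info` dict are hypothetical, not part of this commit. When `invoke()` decides a message should go no further, it calls `discard_current_message()` and returns `None`, and `process_message()` then calls the message acknowledgements instead of forwarding a result.

```python
# Hypothetical component built on the new discard mechanism in ComponentBase.
from solace_ai_connector.components.component_base import ComponentBase

info = {
    "class_name": "MessageFilter",
    "description": "Drops messages with an empty 'text' field (illustrative only).",
    "config_parameters": [],
}


class MessageFilter(ComponentBase):
    def __init__(self, **kwargs):
        super().__init__(info, **kwargs)

    def invoke(self, message, data):
        if not data or not data.get("text"):
            # Discard: process_message() will call the message acknowledgements
            # instead of sending anything to the next component.
            self.discard_current_message()
            return None
        return data
```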
@@ -3,6 +3,7 @@
import threading
from collections import namedtuple
from copy import deepcopy
from uuid import uuid4

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
@@ -86,6 +87,13 @@
"description": "The minimum number of words in a single streaming result. Default: 15.",
"default": 15,
},
{
"name": "set_response_uuid_in_user_properties",
"required": False,
"description": "Whether to set the response_uuid in the user_properties of the input_message. This will allow other components to correlate streaming chunks with the full response.",
"default": False,
"type": "boolean",
},
]
)
info["input_schema"]["properties"]["session_id"] = {
@@ -114,6 +122,9 @@ def __init__(self, **kwargs):
self.stream_to_flow = self.get_config("stream_to_flow", "")
self.llm_mode = self.get_config("llm_mode", "none")
self.stream_batch_size = self.get_config("stream_batch_size", 15)
self.set_response_uuid_in_user_properties = self.get_config(
"set_response_uuid_in_user_properties", False
)

def invoke_model(
self, input_message, messages, session_id=None, clear_history=False
@@ -161,6 +172,8 @@ def invoke_model(

aggregate_result = ""
current_batch = ""
response_uuid = str(uuid4())
first_chunk = True
for chunk in runnable.stream(
{"input": human_message},
config={
@@ -172,25 +185,50 @@
if len(current_batch.split()) >= self.stream_batch_size:
if self.stream_to_flow:
self.send_streaming_message(
input_message, current_batch, aggregate_result
input_message,
current_batch,
aggregate_result,
response_uuid,
first_chunk,
)
current_batch = ""
first_chunk = False

if current_batch:
if self.stream_to_flow:
self.send_streaming_message(
input_message, current_batch, aggregate_result
)
if self.stream_to_flow:
self.send_streaming_message(
input_message,
current_batch,
aggregate_result,
response_uuid,
first_chunk,
True,
)

result = namedtuple("Result", ["content"])(aggregate_result)
result = namedtuple("Result", ["content", "uuid"])(
aggregate_result, response_uuid
)

self.prune_large_message_from_history(session_id)

return result

def send_streaming_message(self, input_message, chunk, aggregate_result):
def send_streaming_message(
self,
input_message,
chunk,
aggregate_result,
response_uuid,
first_chunk=False,
last_chunk=False,
):
message = Message(
payload={"chunk": chunk, "aggregate_result": aggregate_result},
payload={
"chunk": chunk,
"aggregate_result": aggregate_result,
"response_uuid": response_uuid,
"first_chunk": first_chunk,
"last_chunk": last_chunk,
},
user_properties=input_message.get_user_properties(),
)
self.send_to_flow(self.stream_to_flow, message)
@@ -205,9 +243,6 @@ def create_history(self):
)
config = self.get_config("history_config", {})
history = self.create_component(config, history_class)
# memory = ConversationTokenBufferMemory(
# chat_memory=history, llm=self.component, max_token_limit=history_max_tokens
# )
return history

def get_history(self, session_id: str) -> BaseChatMessageHistory:
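A dependency-free sketch of the batching behaviour in `invoke_model()` above (the `batch_stream()` helper is hypothetical, and plain strings stand in for the LangChain chunk objects): chunks are accumulated until the batch holds at least `stream_batch_size` words, every flush carries the shared `response_uuid` and a `first_chunk` flag, and a final flush is always emitted with `last_chunk=True`.

```python
# Standalone sketch of the word-count batching used when streaming LLM output.
from uuid import uuid4


def batch_stream(chunks, stream_batch_size=15):
    """Yield (batch, aggregate, response_uuid, first_chunk, last_chunk) tuples."""
    response_uuid = str(uuid4())
    aggregate_result = ""
    current_batch = ""
    first_chunk = True

    for chunk in chunks:  # plain strings stand in for LangChain chunk objects
        aggregate_result += chunk
        current_batch += chunk
        if len(current_batch.split()) >= stream_batch_size:
            yield current_batch, aggregate_result, response_uuid, first_chunk, False
            current_batch = ""
            first_chunk = False

    # A final batch is always emitted so consumers always see last_chunk=True.
    yield current_batch, aggregate_result, response_uuid, first_chunk, True


if __name__ == "__main__":
    words = ["The quick ", "brown fox ", "jumps over ", "the lazy dog."]
    for batch, agg, uuid, first, last in batch_stream(words, stream_batch_size=4):
        print(f"first={first} last={last} batch={batch!r}")
```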
@@ -29,7 +29,9 @@
"required": False,
},
],
"output_schema": {"type": "any"},
"output_schema": {
"type": "None",
},
}


