feat: Separate out streaming route (#2111)
mattzh72 authored Nov 27, 2024
1 parent c2b41f7 commit f23d436
Showing 16 changed files with 292 additions and 402 deletions.
4 changes: 2 additions & 2 deletions examples/swarm/swarm.py
@@ -76,9 +76,9 @@ def run(self, agent_name: str, message: str):
         # print(self.client.get_agent(agent_id).tools)
         # TODO: implement with sending multiple messages
         if len(history) == 0:
-            response = self.client.send_message(agent_id=agent_id, message=message, role="user", include_full_message=True)
+            response = self.client.send_message(agent_id=agent_id, message=message, role="user")
         else:
-            response = self.client.send_messages(agent_id=agent_id, messages=history, include_full_message=True)
+            response = self.client.send_messages(agent_id=agent_id, messages=history)
 
         # update history
         history += response.messages
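Note: with include_full_message gone, send_message always returns a LettaResponse whose messages field holds API-style LettaMessage objects (see the letta_response.py change below) rather than raw Message records. A minimal sketch of reading such a response; the agent id and message text are placeholders, and create_client is assumed to be the package's top-level client factory:

    from letta import create_client  # assumes the letta package at this commit

    client = create_client()
    response = client.send_message(agent_id="agent-00000000", message="hello", role="user")
    for m in response.messages:
        # each item is a LettaMessage variant, e.g. InternalMonologue or FunctionCallMessage
        print(type(m).__name__, m)
    print(response.usage)  # LettaUsageStatistics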
47 changes: 13 additions & 34 deletions letta/client/client.py
@@ -154,11 +154,10 @@ def send_message(
         stream: Optional[bool] = False,
         stream_steps: bool = False,
         stream_tokens: bool = False,
-        include_full_message: Optional[bool] = False,
     ) -> LettaResponse:
         raise NotImplementedError
 
-    def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
+    def user_message(self, agent_id: str, message: str) -> LettaResponse:
         raise NotImplementedError
 
     def create_human(self, name: str, text: str) -> Human:
@@ -839,7 +838,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]:
 
     # agent interactions
 
-    def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
+    def user_message(self, agent_id: str, message: str) -> LettaResponse:
         """
         Send a message to an agent as a user
@@ -850,7 +849,7 @@
         Returns:
             response (LettaResponse): Response from the agent
         """
-        return self.send_message(agent_id, message, role="user", include_full_message=include_full_message)
+        return self.send_message(agent_id=agent_id, message=message, role="user")
 
     def save(self):
         raise NotImplementedError
@@ -937,13 +936,13 @@ def get_messages(
 
     def send_message(
         self,
-        agent_id: str,
         message: str,
         role: str,
+        agent_id: Optional[str] = None,
         name: Optional[str] = None,
         stream: Optional[bool] = False,
         stream_steps: bool = False,
         stream_tokens: bool = False,
-        include_full_message: bool = False,
     ) -> Union[LettaResponse, Generator[LettaStreamingResponse, None, None]]:
         """
         Send a message to an agent
@@ -964,17 +963,11 @@ def send_message(
         # TODO: figure out how to handle stream_steps and stream_tokens
 
         # When streaming steps is True, stream_tokens must be False
-        request = LettaRequest(
-            messages=messages,
-            stream_steps=stream_steps,
-            stream_tokens=stream_tokens,
-            return_message_object=include_full_message,
-        )
+        request = LettaRequest(messages=messages)
         if stream_tokens or stream_steps:
             from letta.client.streaming import _sse_post
 
-            request.return_message_object = False
-            return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", request.model_dump(), self.headers)
+            return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/stream", request.model_dump(), self.headers)
         else:
             response = requests.post(
                 f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers
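Note: streaming requests now go to a dedicated .../messages/stream route instead of toggling return_message_object on the shared route. A rough sketch of what _sse_post does against the new endpoint, using plain requests; the base URL, v1 prefix, agent id, and payload shape (MessageCreate with role/text) are assumptions based on this diff:

    import json

    import requests

    url = "http://localhost:8283/v1/agents/agent-00000000/messages/stream"
    payload = {"messages": [{"role": "user", "text": "hello"}]}

    # Server-sent events: keep the connection open and read "data: ..." lines
    with requests.post(url, json=payload, headers={"Accept": "text/event-stream"}, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if "[DONE" in data:  # stream-status sentinels, e.g. [DONE_STEP]/[DONE], per MessageStreamStatus
                continue
            print(json.loads(data))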
@@ -2250,7 +2243,6 @@ def send_messages(
         self,
         agent_id: str,
         messages: List[Union[Message | MessageCreate]],
-        include_full_message: Optional[bool] = False,
     ):
         """
         Send pre-packed messages to an agent.
@@ -2270,15 +2262,7 @@
         self.save()
 
         # format messages
-        messages = self.interface.to_list()
-        if include_full_message:
-            letta_messages = messages
-        else:
-            letta_messages = []
-            for m in messages:
-                letta_messages += m.to_letta_message()
-
-        return LettaResponse(messages=letta_messages, usage=usage)
+        return LettaResponse(messages=messages, usage=usage)
 
     def send_message(
         self,
@@ -2289,7 +2273,6 @@
         agent_name: Optional[str] = None,
         stream_steps: bool = False,
         stream_tokens: bool = False,
-        include_full_message: Optional[bool] = False,
     ) -> LettaResponse:
         """
         Send a message to an agent
@@ -2338,16 +2321,13 @@
 
         # format messages
         messages = self.interface.to_list()
-        if include_full_message:
-            letta_messages = messages
-        else:
-            letta_messages = []
-            for m in messages:
-                letta_messages += m.to_letta_message()
+        letta_messages = []
+        for m in messages:
+            letta_messages += m.to_letta_message()
 
         return LettaResponse(messages=letta_messages, usage=usage)
 
-    def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
+    def user_message(self, agent_id: str, message: str) -> LettaResponse:
         """
         Send a message to an agent as a user
@@ -2359,7 +2339,7 @@
             response (LettaResponse): Response from the agent
         """
         self.interface.clear()
-        return self.send_message(role="user", agent_id=agent_id, message=message, include_full_message=include_full_message)
+        return self.send_message(role="user", agent_id=agent_id, message=message)
 
     def run_command(self, agent_id: str, command: str) -> LettaResponse:
         """
@@ -2951,7 +2931,6 @@ def get_messages(
             after=after,
             limit=limit,
             reverse=True,
-            return_message_object=True,
         )
 
     def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
2 changes: 1 addition & 1 deletion letta/functions/function_sets/base.py
@@ -11,7 +11,7 @@
 # If the function fails, throw an exception
 
 
-def send_message(self: Agent, message: str) -> Optional[str]:
+def send_message(self: "Agent", message: str) -> Optional[str]:
     """
     Sends a message to the human user.
34 changes: 11 additions & 23 deletions letta/schemas/letta_request.py
@@ -8,33 +8,21 @@
 
 class LettaRequest(BaseModel):
     messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
-    run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.")  # TODO: implement
-
-    stream_steps: bool = Field(
-        default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
-    )
-    stream_tokens: bool = Field(
-        default=False,
-        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
-    )
-
-    return_message_object: bool = Field(
-        default=False,
-        description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.",
-    )
 
     # Flags to support the use of AssistantMessage message types
 
-    use_assistant_message: bool = Field(
-        default=False,
-        description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.",
-    )
-
-    assistant_message_function_name: str = Field(
+    assistant_message_tool_name: str = Field(
         default=DEFAULT_MESSAGE_TOOL,
-        description="[Only applicable if use_assistant_message is True] The name of the designated message tool.",
+        description="The name of the designated message tool.",
     )
-    assistant_message_function_kwarg: str = Field(
+    assistant_message_tool_kwarg: str = Field(
         default=DEFAULT_MESSAGE_TOOL_KWARG,
-        description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
+        description="The name of the message argument in the designated message tool.",
     )
 
 
+class LettaStreamingRequest(LettaRequest):
+    stream_tokens: bool = Field(
+        default=False,
+        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
+    )
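Note: the streaming flags now live on the request subclass, so both payloads are built the same way and the non-streaming route never sees stream_tokens. A sketch, assuming MessageCreate accepts role and text as used elsewhere in this commit:

    from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
    from letta.schemas.message import MessageCreate

    msgs = [MessageCreate(role="user", text="hello")]

    plain = LettaRequest(messages=msgs)  # POST .../messages
    streaming = LettaStreamingRequest(messages=msgs, stream_tokens=True)  # POST .../messages/stream
    print(streaming.model_dump())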
3 changes: 1 addition & 2 deletions letta/schemas/letta_response.py
@@ -7,7 +7,6 @@
 
 from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
-from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
 from letta.utils import json_dumps
 
@@ -24,7 +23,7 @@ class LettaResponse(BaseModel):
         usage (LettaUsageStatistics): The usage statistics
     """
 
-    messages: Union[List[Message], List[LettaMessageUnion]] = Field(..., description="The messages returned by the agent.")
+    messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.")
     usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")
 
     def __str__(self):
6 changes: 3 additions & 3 deletions letta/schemas/message.py
@@ -134,8 +134,8 @@ def to_json(self):
     def to_letta_message(
         self,
         assistant_message: bool = False,
-        assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
-        assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
+        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
+        assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
     ) -> List[LettaMessage]:
         """Convert message object (in DB format) to the style used by the original Letta API"""
 
@@ -156,7 +156,7 @@ def to_letta_message(
                 for tool_call in self.tool_calls:
                     # If we're supporting using assistant message,
                     # then we want to treat certain function calls as a special case
-                    if assistant_message and tool_call.function.name == assistant_message_function_name:
+                    if assistant_message and tool_call.function.name == assistant_message_tool_name:
                         # We need to unpack the actual message contents from the function call
                         try:
                             func_args = json.loads(tool_call.function.arguments)
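Note: callers of Message.to_letta_message simply pick up the renamed keyword arguments; behavior is unchanged. For example (msg here is a hypothetical Message row loaded from the DB; the values shown are the DEFAULT_MESSAGE_TOOL and DEFAULT_MESSAGE_TOOL_KWARG constants):

    # Convert a stored Message into API-style LettaMessage objects,
    # treating calls to the designated message tool as AssistantMessage chunks
    letta_messages = msg.to_letta_message(
        assistant_message=True,
        assistant_message_tool_name="send_message",  # DEFAULT_MESSAGE_TOOL
        assistant_message_tool_kwarg="message",      # DEFAULT_MESSAGE_TOOL_KWARG
    )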
31 changes: 12 additions & 19 deletions letta/server/rest_api/interface.py
@@ -271,9 +271,8 @@ def __init__(
         self,
         multi_step=True,
         # Related to if we want to try and pass back the AssistantMessage as a special case function
-        use_assistant_message=False,
-        assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
-        assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
+        assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
+        assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
         # Related to if we expect inner_thoughts to be in the kwargs
         inner_thoughts_in_kwargs=True,
         inner_thoughts_kwarg=INNER_THOUGHTS_KWARG,
@@ -287,7 +286,7 @@ def __init__(
         self.streaming_chat_completion_mode_function_name = None  # NOTE: sadly need to track state during stream
         # If chat completion mode, we need a special stream reader to
         # turn function argument to send_message into a normal text stream
-        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)
+        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
 
         self._chunks = deque()
         self._event = asyncio.Event()  # Use an event to notify when chunks are available
@@ -300,9 +299,9 @@ def __init__(
         self.multi_step_gen_indicator = MessageStreamStatus.done_generation
 
         # Support for AssistantMessage
-        self.use_assistant_message = use_assistant_message
-        self.assistant_message_function_name = assistant_message_function_name
-        self.assistant_message_function_kwarg = assistant_message_function_kwarg
+        self.use_assistant_message = False  # TODO: Remove this
+        self.assistant_message_tool_name = assistant_message_tool_name
+        self.assistant_message_tool_kwarg = assistant_message_tool_kwarg
 
         # Support for inner_thoughts_in_kwargs
         self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
@@ -455,17 +454,14 @@ def _process_chunk_to_letta_style(
 
                 # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
                 # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
-                # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
-                if tool_call.function.name == self.assistant_message_function_name:
+                # if self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
+                if tool_call.function.name == self.assistant_message_tool_name:
                     self.streaming_chat_completion_json_reader.reset()
                     # early exit to turn into content mode
                     return None
 
                 # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
-                if (
-                    tool_call.function.arguments
-                    and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
-                ):
+                if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
                     # Strip out any extras tokens
                     cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
                     # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
@@ -500,9 +496,6 @@ def _process_chunk_to_letta_style(
                 )
 
             elif self.inner_thoughts_in_kwargs and tool_call.function:
-                if self.use_assistant_message:
-                    raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
-
                 processed_chunk = None
 
                 if tool_call.function.name:
@@ -909,13 +902,13 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
 
             if (
                 self.use_assistant_message
-                and function_call.function.name == self.assistant_message_function_name
-                and self.assistant_message_function_kwarg in func_args
+                and function_call.function.name == self.assistant_message_tool_name
+                and self.assistant_message_tool_kwarg in func_args
             ):
                 processed_chunk = AssistantMessage(
                     id=msg_obj.id,
                     date=msg_obj.created_at,
-                    assistant_message=func_args[self.assistant_message_function_kwarg],
+                    assistant_message=func_args[self.assistant_message_tool_kwarg],
                 )
             else:
                 processed_chunk = FunctionCallMessage(
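Note: server code constructing this interface passes the renamed parameters, and use_assistant_message is now pinned to False internally (flagged TODO above). A sketch, assuming the class shown here is this module's StreamingServerInterface:

    from letta.server.rest_api.interface import StreamingServerInterface

    interface = StreamingServerInterface(
        multi_step=True,
        assistant_message_tool_name="send_message",  # DEFAULT_MESSAGE_TOOL
        assistant_message_tool_kwarg="message",      # DEFAULT_MESSAGE_TOOL_KWARG
    )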
1 change: 0 additions & 1 deletion letta/server/rest_api/routers/openai/assistants/threads.py
@@ -161,7 +161,6 @@ def list_messages(
         before=before_uuid,
         order_by="created_at",
         reverse=reverse,
-        return_message_object=True,
     )
     assert isinstance(json_messages, List)
     assert all([isinstance(message, Message) for message in json_messages])
@@ -68,7 +68,6 @@ async def create_chat_completion(
             stream_tokens=True,
             # Turn on ChatCompletion mode (eg remaps send_message to content)
             chat_completion_mode=True,
-            return_message_object=False,
         )
 
     else:
@@ -86,7 +85,6 @@
             # Turn streaming OFF
             stream_steps=False,
             stream_tokens=False,
-            return_message_object=False,
         )
         # print(response_messages)
 