From cf5d9347d08224755921f7b4922931aebef1d482 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 19 Oct 2024 21:47:48 -0700 Subject: [PATCH] feat: add function IDs to `LettaMessage` function calls and response (#1909) --- letta/agent.py | 3 ++- letta/llm_api/openai.py | 14 ++++++++++++-- letta/schemas/letta_message.py | 12 ++++++++---- letta/schemas/message.py | 3 +++ letta/server/rest_api/interface.py | 17 +++++++++++++++-- letta/utils.py | 3 +++ 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 55173cfb50..8ffb33ef3c 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -503,7 +503,7 @@ def _get_ai_reply( def _handle_ai_response( self, response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function? - override_tool_call_id: bool = True, + override_tool_call_id: bool = False, # If we are streaming, we needed to create a Message ID ahead of time, # and now we want to use it in the creation of the Message object # TODO figure out a cleaner way to do this @@ -530,6 +530,7 @@ def _handle_ai_response( # generate UUID for tool call if override_tool_call_id or response_message.function_call: + warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming") tool_call_id = get_tool_call_id() # needs to be a string for JSON response_message.tool_calls[0].id = tool_call_id else: diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 45768fbb15..29ba9cfeef 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -41,7 +41,7 @@ AgentChunkStreamingInterface, AgentRefreshStreamingInterface, ) -from letta.utils import smart_urljoin +from letta.utils import get_tool_call_id, smart_urljoin OPENAI_SSE_DONE = "[DONE]" @@ -174,6 +174,7 @@ def openai_chat_completions_process_stream( stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, create_message_id: bool = True, create_message_datetime: bool = True, + override_tool_call_id: bool = True, ) -> ChatCompletionResponse: """Process a streaming completion response, and return a ChatCompletionRequest at the end. @@ -244,6 +245,14 @@ def openai_chat_completions_process_stream( ): assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) + # NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream + if override_tool_call_id: + for choice in chat_completion_chunk.choices: + if choice.delta.tool_calls and len(choice.delta.tool_calls) > 0: + for tool_call in choice.delta.tool_calls: + if tool_call.id is not None: + tool_call.id = get_tool_call_id() + if stream_interface: if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.process_chunk( @@ -290,6 +299,7 @@ def openai_chat_completions_process_stream( else: accum_message.content += content_delta + # TODO(charles) make sure this works for parallel tool calling? if message_delta.tool_calls is not None: tool_calls_delta = message_delta.tool_calls @@ -340,7 +350,7 @@ def openai_chat_completions_process_stream( assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) assert all( [ - all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True + all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True for c in chat_completion_response.choices ] ) diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index b70878927b..b3f7bf9000 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -78,12 +78,14 @@ class FunctionCall(BaseModel): name: str arguments: str + function_call_id: str class FunctionCallDelta(BaseModel): name: Optional[str] arguments: Optional[str] + function_call_id: Optional[str] # NOTE: this is a workaround to exclude None values from the JSON dump, # since the OpenAI style of returning chunks doesn't include keys with null values @@ -129,10 +131,10 @@ class Config: @classmethod def validate_function_call(cls, v): if isinstance(v, dict): - if "name" in v and "arguments" in v: - return FunctionCall(name=v["name"], arguments=v["arguments"]) - elif "name" in v or "arguments" in v: - return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments")) + if "name" in v and "arguments" in v and "function_call_id" in v: + return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"]) + elif "name" in v or "arguments" in v or "function_call_id" in v: + return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id")) else: raise ValueError("function_call must contain either 'name' or 'arguments'") return v @@ -147,11 +149,13 @@ class FunctionReturn(LettaMessage): status (Literal["success", "error"]): The status of the function call id (str): The ID of the message date (datetime): The date the message was created in ISO format + function_call_id (str): A unique identifier for the function call that generated this message """ message_type: Literal["function_return"] = "function_return" function_return: str status: Literal["success", "error"] + function_call_id: str # Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string diff --git a/letta/schemas/message.py b/letta/schemas/message.py index fa7f0be8f5..4ddcc0c86d 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -178,6 +178,7 @@ def to_letta_message( function_call=FunctionCall( name=tool_call.function.name, arguments=tool_call.function.arguments, + function_call_id=tool_call.id, ), ) ) @@ -203,6 +204,7 @@ def to_letta_message( raise ValueError(f"Invalid status: {status}") except json.JSONDecodeError: raise ValueError(f"Failed to decode function return: {self.text}") + assert self.tool_call_id is not None messages.append( # TODO make sure this is what the API returns # function_return may not match exactly... @@ -211,6 +213,7 @@ def to_letta_message( date=self.created_at, function_return=self.text, status=status_enum, + function_call_id=self.tool_call_id, ) ) elif self.role == MessageRole.user: diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index b8b06d78cf..17731f1528 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -531,7 +531,11 @@ def _process_chunk_to_letta_style( processed_chunk = FunctionCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + function_call=FunctionCallDelta( + name=tool_call_delta.get("name"), + arguments=tool_call_delta.get("arguments"), + function_call_id=tool_call_delta.get("id"), + ), ) else: @@ -548,7 +552,11 @@ def _process_chunk_to_letta_style( processed_chunk = FunctionCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + function_call=FunctionCallDelta( + name=tool_call_delta.get("name"), + arguments=tool_call_delta.get("arguments"), + function_call_id=tool_call_delta.get("id"), + ), ) elif choice.finish_reason is not None: @@ -759,6 +767,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): function_call=FunctionCall( name=function_call.function.name, arguments=function_call.function.arguments, + function_call_id=function_call.id, ), ) @@ -786,21 +795,25 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): elif msg.startswith("Success: "): msg = msg.replace("Success: ", "") # new_message = {"function_return": msg, "status": "success"} + assert msg_obj.tool_call_id is not None new_message = FunctionReturn( id=msg_obj.id, date=msg_obj.created_at, function_return=msg, status="success", + function_call_id=msg_obj.tool_call_id, ) elif msg.startswith("Error: "): msg = msg.replace("Error: ", "") # new_message = {"function_return": msg, "status": "error"} + assert msg_obj.tool_call_id is not None new_message = FunctionReturn( id=msg_obj.id, date=msg_obj.created_at, function_return=msg, status="error", + function_call_id=msg_obj.tool_call_id, ) else: diff --git a/letta/utils.py b/letta/utils.py index 13a9c531da..c85f4ef818 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -488,6 +488,9 @@ def is_utc_datetime(dt: datetime) -> bool: def get_tool_call_id() -> str: + # TODO(sarah) make this a slug-style string? + # e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C" + # or similar to agents: "call-xlIfzR1HqAW7xJPa3ExJSg3C" return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]