feat: add function IDs to LettaMessage function calls and response (#…
cpacker authored Oct 20, 2024
1 parent 4c08015 commit cf5d934
Showing 6 changed files with 43 additions and 9 deletions.
3 changes: 2 additions & 1 deletion letta/agent.py
@@ -503,7 +503,7 @@ def _get_ai_reply(
def _handle_ai_response(
self,
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
override_tool_call_id: bool = True,
override_tool_call_id: bool = False,
# If we are streaming, we needed to create a Message ID ahead of time,
# and now we want to use it in the creation of the Message object
# TODO figure out a cleaner way to do this
@@ -530,6 +530,7 @@

# generate UUID for tool call
if override_tool_call_id or response_message.function_call:
warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming")
tool_call_id = get_tool_call_id() # needs to be a string for JSON
response_message.tool_calls[0].id = tool_call_id
else:
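The default flip above matters for streaming: with override_tool_call_id=True, _handle_ai_response would mint a fresh tool call ID even when a different ID had already been streamed to the client. A minimal standalone sketch of the keep-or-regenerate decision; resolve_tool_call_id and the TOOL_CALL_ID_MAX_LEN value are illustrative stand-ins, not letta's actual code:

import uuid
import warnings
from typing import Optional

TOOL_CALL_ID_MAX_LEN = 29  # assumption: stands in for letta's constant of the same name


def get_tool_call_id() -> str:
    # mirrors letta.utils.get_tool_call_id: a truncated UUID4 string
    return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]


def resolve_tool_call_id(existing_id: Optional[str], override: bool = False) -> str:
    # keep the ID already attached to the tool call unless explicitly told to override
    if override or existing_id is None:
        warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming")
        return get_tool_call_id()
    return existing_id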
14 changes: 12 additions & 2 deletions letta/llm_api/openai.py
@@ -41,7 +41,7 @@
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from letta.utils import smart_urljoin
from letta.utils import get_tool_call_id, smart_urljoin

OPENAI_SSE_DONE = "[DONE]"

@@ -174,6 +174,7 @@ def openai_chat_completions_process_stream(
stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
override_tool_call_id: bool = True,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
@@ -244,6 +245,14 @@ def openai_chat_completions_process_stream(
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)

# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
if override_tool_call_id:
for choice in chat_completion_chunk.choices:
if choice.delta.tool_calls and len(choice.delta.tool_calls) > 0:
for tool_call in choice.delta.tool_calls:
if tool_call.id is not None:
tool_call.id = get_tool_call_id()

if stream_interface:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.process_chunk(
@@ -290,6 +299,7 @@ def openai_chat_completions_process_stream(
else:
accum_message.content += content_delta

# TODO(charles) make sure this works for parallel tool calling?
if message_delta.tool_calls is not None:
tool_calls_delta = message_delta.tool_calls

@@ -340,7 +350,7 @@ def openai_chat_completions_process_stream(
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
assert all(
[
all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
for c in chat_completion_response.choices
]
)
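With override_tool_call_id defaulting to True on the streaming path, the provider-issued ID is swapped for a letta-generated one in the chunk where it first appears, so the streaming interface and the accumulated ChatCompletionResponse end up agreeing on a single ID. A rough self-contained sketch of that per-chunk rewrite, using a made-up ToolCallDelta type rather than letta's chunk models:

import uuid
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional


@dataclass
class ToolCallDelta:  # stand-in for the tool-call piece of a streamed chunk
    id: Optional[str]
    arguments: str


def rewrite_tool_call_ids(chunks: Iterable[ToolCallDelta]) -> Iterator[ToolCallDelta]:
    # the ID is assumed to arrive in exactly one chunk per tool call, so rewriting it
    # there is enough for every later argument fragment to accumulate under the new ID
    for delta in chunks:
        if delta.id is not None:
            delta.id = str(uuid.uuid4())[:29]  # assumption: same shape as get_tool_call_id()
        yield delta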
12 changes: 8 additions & 4 deletions letta/schemas/letta_message.py
@@ -78,12 +78,14 @@ class FunctionCall(BaseModel):

name: str
arguments: str
function_call_id: str


class FunctionCallDelta(BaseModel):

name: Optional[str]
arguments: Optional[str]
function_call_id: Optional[str]

# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
@@ -129,10 +131,10 @@ class Config:
@classmethod
def validate_function_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"])
elif "name" in v or "arguments" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"))
if "name" in v and "arguments" in v and "function_call_id" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
elif "name" in v or "arguments" in v or "function_call_id" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
return v
@@ -147,11 +149,13 @@ class FunctionReturn(LettaMessage):
status (Literal["success", "error"]): The status of the function call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
function_call_id (str): A unique identifier for the function call that generated this message
"""

message_type: Literal["function_return"] = "function_return"
function_return: str
status: Literal["success", "error"]
function_call_id: str


# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
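The schema change makes the ID part of the message payload itself: a dict is only promoted to a complete FunctionCall when name, arguments, and function_call_id are all present, and anything partial becomes a FunctionCallDelta. Illustrative constructions of the updated models (all field values below are made up):

from letta.schemas.letta_message import FunctionCall, FunctionCallDelta

complete = FunctionCall(
    name="send_message",
    arguments='{"message": "hi there"}',
    function_call_id="call_abc123",   # now a required field on complete calls
)

partial = FunctionCallDelta(
    name=None,
    arguments='{"mes',                # an argument fragment arriving mid-stream
    function_call_id=None,            # None except in the chunk that carries the ID
)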
3 changes: 3 additions & 0 deletions letta/schemas/message.py
@@ -178,6 +178,7 @@ def to_letta_message(
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
function_call_id=tool_call.id,
),
)
)
@@ -203,6 +204,7 @@
raise ValueError(f"Invalid status: {status}")
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {self.text}")
assert self.tool_call_id is not None
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
@@ -211,6 +213,7 @@
date=self.created_at,
function_return=self.text,
status=status_enum,
function_call_id=self.tool_call_id,
)
)
elif self.role == MessageRole.user:
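Because to_letta_message() now stamps the same tool_call_id onto both the FunctionCallMessage and the matching FunctionReturn, a consumer can join calls to their results without relying on message ordering. A hypothetical client-side helper (pair_calls_with_returns is not part of this commit):

from letta.schemas.letta_message import FunctionCallMessage, FunctionReturn


def pair_calls_with_returns(messages):
    # messages: fully materialized LettaMessage objects fetched from the server
    calls, returns = {}, {}
    for m in messages:
        if isinstance(m, FunctionCallMessage):
            calls[m.function_call.function_call_id] = m
        elif isinstance(m, FunctionReturn):
            returns[m.function_call_id] = m
    # pair each call with its return, or None if the return has not arrived yet
    return [(call, returns.get(call_id)) for call_id, call in calls.items()]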
17 changes: 15 additions & 2 deletions letta/server/rest_api/interface.py
@@ -531,7 +531,11 @@ def _process_chunk_to_letta_style(
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)

else:
@@ -548,7 +552,11 @@
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)

elif choice.finish_reason is not None:
@@ -759,6 +767,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
function_call_id=function_call.id,
),
)

@@ -786,21 +795,25 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="success",
function_call_id=msg_obj.tool_call_id,
)

elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="error",
function_call_id=msg_obj.tool_call_id,
)

else:
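On the streaming side, the ID now rides along inside the FunctionCallDelta chunks, so a client can stitch argument fragments back together per call instead of assuming a single in-flight tool call. A speculative client-side sketch (accumulate_function_calls is illustrative, not part of letta):

from collections import defaultdict


def accumulate_function_calls(chunks):
    # chunks: FunctionCallMessage objects whose .function_call is a FunctionCallDelta
    names, arguments = {}, defaultdict(str)
    current_id = None  # later deltas may omit the ID, so remember the last one seen
    for chunk in chunks:
        delta = chunk.function_call
        if getattr(delta, "function_call_id", None):
            current_id = delta.function_call_id
        if delta.name:
            names[current_id] = delta.name
        if delta.arguments:
            arguments[current_id] += delta.arguments
    return {cid: {"name": names.get(cid), "arguments": args} for cid, args in arguments.items()}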
3 changes: 3 additions & 0 deletions letta/utils.py
@@ -488,6 +488,9 @@ def is_utc_datetime(dt: datetime) -> bool:


def get_tool_call_id() -> str:
# TODO(sarah) make this a slug-style string?
# e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C"
# or similar to agents: "call-xlIfzR1HqAW7xJPa3ExJSg3C"
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]


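One possible shape for the slug-style ID floated in the TODO above, offered only as a sketch; get_tool_call_id() itself is unchanged and still returns a truncated UUID4:

import secrets
import string

_ALPHABET = string.ascii_letters + string.digits


def slug_tool_call_id(prefix: str = "call-", length: int = 24) -> str:
    # e.g. "call-xlIfzR1HqAW7xJPa3ExJSg3C", similar to the OpenAI example in the comment
    return prefix + "".join(secrets.choice(_ALPHABET) for _ in range(length))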
