feat: add functions to get context window overview #1903

Merged 8 commits on Oct 18, 2024
28 changes: 26 additions & 2 deletions letta/agent.py
@@ -23,7 +23,7 @@
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
-from letta.local_llm.utils import num_tokens_from_messages
+from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
from letta.persistence_manager import LocalStateManager
@@ -33,6 +33,9 @@
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, UpdateMessage
+from letta.schemas.openai.chat_completion_request import (
+    Tool as ChatCompletionRequestTool,
+)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
@@ -1458,6 +1461,24 @@ def get_context_window(self) -> ContextWindowOverview:
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)

+        # tokens taken up by function definitions
+        if self.functions:
+            available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions]
+            num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model)
+        else:
+            available_functions_definitions = []
+            num_tokens_available_functions_definitions = 0
+
+        num_tokens_used_total = (
+            num_tokens_system  # system prompt
+            + num_tokens_available_functions_definitions  # function definitions
+            + num_tokens_core_memory  # core memory
+            + num_tokens_external_memory_summary  # metadata (statistics) about recall/archival
+            + num_tokens_summary_memory  # summary of ongoing conversation
+            + num_tokens_messages  # tokens taken by messages
+        )
+        assert isinstance(num_tokens_used_total, int)
+
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(self._messages),
@@ -1466,7 +1487,7 @@
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
# top-level information
context_window_size_max=self.agent_state.llm_config.context_window,
-            context_window_size_current=num_tokens_system + num_tokens_core_memory + num_tokens_summary_memory + num_tokens_messages,
+            context_window_size_current=num_tokens_used_total,
# context window breakdown (in tokens)
num_tokens_system=num_tokens_system,
system_prompt=system_prompt,
@@ -1476,6 +1497,9 @@
summary_memory=summary_memory,
num_tokens_messages=num_tokens_messages,
messages=self._messages,
+            # related to functions
+            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
+            functions_definitions=available_functions_definitions,
)


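The overview now also charges the context window for the serialized tool schemas via `num_tokens_from_functions`. For readers unfamiliar with that style of accounting, here is a minimal sketch assuming `tiktoken` and OpenAI-style function schemas; the helper name and the absence of per-field overhead constants are illustrative simplifications, not letta's exact implementation:

```python
import tiktoken

def count_function_definition_tokens(functions: list, model: str = "gpt-4") -> int:
    # Hypothetical helper mirroring num_tokens_from_functions: estimate how many
    # prompt tokens a list of OpenAI-style function schemas will consume.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # fallback for unknown models

    num_tokens = 0
    for fn in functions:
        num_tokens += len(encoding.encode(fn["name"]))
        num_tokens += len(encoding.encode(fn.get("description", "")))
        for prop_name, prop in fn.get("parameters", {}).get("properties", {}).items():
            num_tokens += len(encoding.encode(prop_name))
            num_tokens += len(encoding.encode(prop.get("type", "")))
            num_tokens += len(encoding.encode(prop.get("description", "")))
    return num_tokens
```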
21 changes: 16 additions & 5 deletions letta/llm_api/openai.py
@@ -18,8 +18,13 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
-from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
+from letta.schemas.openai.chat_completion_request import (
+    ChatCompletionRequest,
+    FunctionCall as ToolFunctionChoiceFunctionCall,
+)
from letta.schemas.openai.chat_completion_request import (
+    Tool,
+    ToolFunctionChoice,
    cast_message_to_subtype,
)
from letta.schemas.openai.chat_completion_response import (
@@ -100,10 +105,10 @@ def openai_get_model_list(

def build_openai_chat_completions_request(
llm_config: LLMConfig,
-    messages: List[Message],
+    messages: List[_Message],
user_id: Optional[str],
functions: Optional[list],
-    function_call: str,
+    function_call: Optional[str],
use_tool_naming: bool,
max_tokens: Optional[int],
) -> ChatCompletionRequest:
@@ -124,11 +129,17 @@ def build_openai_chat_completions_request(
        model = None

    if use_tool_naming:
+        if function_call is None:
+            tool_choice = None
+        elif function_call not in ["none", "auto", "required"]:
+            tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call))
+        else:
+            tool_choice = function_call
        data = ChatCompletionRequest(
            model=model,
            messages=openai_message_list,
-            tools=[{"type": "function", "function": f} for f in functions] if functions else None,
-            tool_choice=function_call,
+            tools=[Tool(type="function", function=f) for f in functions] if functions else None,
+            tool_choice=tool_choice,
            user=str(user_id),
            max_tokens=max_tokens,
        )
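The new branch maps the legacy `function_call` string onto OpenAI's `tool_choice` field: the three literals pass through, `None` omits the field, and any other string forces a specific function. A sketch of the mapping with plain dicts standing in for the pydantic models:

```python
from typing import Optional, Union

def to_tool_choice(function_call: Optional[str]) -> Union[None, str, dict]:
    # None -> omit the field entirely (the API then applies its own default);
    # "none"/"auto"/"required" -> pass the literal through;
    # anything else -> force the model to call that specific function.
    if function_call is None:
        return None
    if function_call not in ["none", "auto", "required"]:
        return {"type": "function", "function": {"name": function_call}}
    return function_call

assert to_tool_choice("auto") == "auto"
assert to_tool_choice("send_message") == {"type": "function", "function": {"name": "send_message"}}
```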
28 changes: 22 additions & 6 deletions letta/local_llm/utils.py
@@ -1,6 +1,6 @@
import os
import warnings
-from typing import List
+from typing import List, Union

import requests
import tiktoken
@@ -11,6 +11,7 @@
import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3
import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
+from letta.schemas.openai.chat_completion_request import ToolCall


def post_json_auth_request(uri, json_payload, auth_type, auth_key):
@@ -123,7 +124,7 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
return num_tokens


-def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):
+def num_tokens_from_tool_calls(tool_calls: Union[List[dict], List[ToolCall]], model: str = "gpt-4"):
"""Based on above code (num_tokens_from_functions).

Example to encode:
@@ -144,10 +145,25 @@ def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):

num_tokens = 0
for tool_call in tool_calls:
-        function_tokens = len(encoding.encode(tool_call["id"]))
-        function_tokens += 2 + len(encoding.encode(tool_call["type"]))
-        function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"]))
-        function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"]))
+        if isinstance(tool_call, dict):
+            tool_call_id = tool_call["id"]
+            tool_call_type = tool_call["type"]
+            tool_call_function = tool_call["function"]
+            tool_call_function_name = tool_call_function["name"]
+            tool_call_function_arguments = tool_call_function["arguments"]
+        elif isinstance(tool_call, ToolCall):
+            tool_call_id = tool_call.id
+            tool_call_type = tool_call.type
+            tool_call_function = tool_call.function
+            tool_call_function_name = tool_call_function.name
+            tool_call_function_arguments = tool_call_function.arguments
+        else:
+            raise ValueError(f"Unknown tool call type: {type(tool_call)}")
+
+        function_tokens = len(encoding.encode(tool_call_id))
+        function_tokens += 2 + len(encoding.encode(tool_call_type))
+        function_tokens += 2 + len(encoding.encode(tool_call_function_name))
+        function_tokens += 2 + len(encoding.encode(tool_call_function_arguments))

num_tokens += function_tokens

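With the widened signature, the counter accepts raw dicts and typed tool calls interchangeably. A usage sketch; the `ToolCallFunction` name and its constructor fields are assumptions inferred from the attribute accesses above, not confirmed API:

```python
from letta.local_llm.utils import num_tokens_from_tool_calls
from letta.schemas.openai.chat_completion_request import ToolCall, ToolCallFunction  # ToolCallFunction name assumed

as_dict = [{
    "id": "call_123",
    "type": "function",
    "function": {"name": "send_message", "arguments": '{"message": "hi"}'},
}]
as_typed = [ToolCall(
    id="call_123",
    type="function",
    function=ToolCallFunction(name="send_message", arguments='{"message": "hi"}'),
)]

# Both representations should produce the same token estimate
assert num_tokens_from_tool_calls(as_dict) == num_tokens_from_tool_calls(as_typed)
```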
4 changes: 4 additions & 0 deletions letta/schemas/memory.py
@@ -9,6 +9,7 @@

from letta.schemas.block import Block
from letta.schemas.message import Message
+from letta.schemas.openai.chat_completion_request import Tool


class ContextWindowOverview(BaseModel):
@@ -41,6 +42,9 @@ class ContextWindowOverview(BaseModel):
num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.")
summary_memory: Optional[str] = Field(None, description="The content of the summary memory.")

+    num_tokens_functions_definitions: int = Field(..., description="The number of tokens in the function definitions.")
+    functions_definitions: Optional[List[Tool]] = Field(..., description="The content of the function definitions.")

num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.")
# TODO make list of messages?
# messages: List[dict] = Field(..., description="The messages in the context window.")
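Note the `Field(...)` sentinel paired with an `Optional` annotation: the new fields are required but nullable, so callers must supply `functions_definitions` explicitly, possibly as `None`. A self-contained sketch of the pattern with a stub model (not letta's own classes):

```python
from typing import List, Optional
from pydantic import BaseModel, Field, ValidationError

class Overview(BaseModel):  # trimmed stand-in for ContextWindowOverview
    num_tokens_functions_definitions: int = Field(..., description="Tokens used by function definitions.")
    functions_definitions: Optional[List[dict]] = Field(..., description="The function definitions themselves.")

Overview(num_tokens_functions_definitions=0, functions_definitions=None)  # valid: explicit None
try:
    Overview(num_tokens_functions_definitions=0)  # invalid: the field may not be omitted
except ValidationError as e:
    print(e)
```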
4 changes: 2 additions & 2 deletions letta/schemas/openai/chat_completion_request.py
@@ -74,7 +74,7 @@ class ToolFunctionChoice(BaseModel):
function: FunctionCall


-ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
+ToolChoice = Union[Literal["none", "auto", "required"], ToolFunctionChoice]


## tools ##
@@ -117,7 +117,7 @@ class ChatCompletionRequest(BaseModel):

# function-calling related
tools: Optional[List[Tool]] = None
-    tool_choice: Optional[ToolChoice] = "none"
+    tool_choice: Optional[ToolChoice] = None  # "none" means don't call a tool
# deprecated scheme
functions: Optional[List[FunctionSchema]] = None
function_call: Optional[FunctionCallChoice] = None
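Changing the default from `"none"` to `None` matters once the request is serialized with `exclude_none` (assumed here, and typical for OpenAI-bound payloads): `None` drops the field so the API default applies, whereas the literal `"none"` actively disables tool calls. A stripped-down sketch using pydantic v2:

```python
from typing import Literal, Optional, Union
from pydantic import BaseModel

class Request(BaseModel):  # stand-in for ChatCompletionRequest
    model: str
    tool_choice: Optional[Union[Literal["none", "auto", "required"], dict]] = None

print(Request(model="gpt-4").model_dump(exclude_none=True))
# {'model': 'gpt-4'}  -- field omitted, API default applies
print(Request(model="gpt-4", tool_choice="none").model_dump(exclude_none=True))
# {'model': 'gpt-4', 'tool_choice': 'none'}  -- tool calls explicitly disabled
```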
16 changes: 15 additions & 1 deletion letta/server/server.py
@@ -73,7 +73,12 @@
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
-from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
+from letta.schemas.memory import (
+    ArchivalMemorySummary,
+    ContextWindowOverview,
+    Memory,
+    RecallMemorySummary,
+)
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
@@ -2166,3 +2171,12 @@ def add_llm_model(self, request: LLMConfig) -> LLMConfig:

def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""

+    def get_agent_context_window(
+        self,
+        user_id: str,
+        agent_id: str,
+    ) -> ContextWindowOverview:
+        # Load the agent and compute its current context window usage
+        letta_agent = self._get_or_load_agent(agent_id=agent_id)
+        return letta_agent.get_context_window()
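A minimal usage sketch of the new passthrough; the server construction and the IDs are illustrative assumptions, and the test below shows the fixture-based equivalent:

```python
from letta.server.server import SyncServer

server = SyncServer()  # assumes a default local configuration is available
overview = server.get_agent_context_window(user_id="user-123", agent_id="agent-456")
print(f"context window: {overview.context_window_size_current}/{overview.context_window_size_max} tokens")
```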
40 changes: 40 additions & 0 deletions tests/test_server.py
@@ -507,3 +507,43 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
print(args_json)
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text


+def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
+    """Test that the context window overview fetch works"""
+
+    overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id)
+    assert overview is not None
+
+    # Run some basic checks
+    assert overview.context_window_size_max is not None
+    assert overview.context_window_size_current is not None
+    assert overview.num_archival_memory is not None
+    assert overview.num_recall_memory is not None
+    assert overview.num_tokens_external_memory_summary is not None
+    assert overview.num_tokens_system is not None
+    assert overview.system_prompt is not None
+    assert overview.num_tokens_core_memory is not None
+    assert overview.core_memory is not None
+    assert overview.num_tokens_summary_memory is not None
+    if overview.num_tokens_summary_memory > 0:
+        assert overview.summary_memory is not None
+    else:
+        assert overview.summary_memory is None
+    assert overview.num_tokens_functions_definitions is not None
+    if overview.num_tokens_functions_definitions > 0:
+        assert overview.functions_definitions is not None
+    else:
+        assert overview.functions_definitions is None
+    assert overview.num_tokens_messages is not None
+    assert overview.messages is not None
+
+    assert overview.context_window_size_max >= overview.context_window_size_current
+    assert overview.context_window_size_current == (
+        overview.num_tokens_system
+        + overview.num_tokens_core_memory
+        + overview.num_tokens_summary_memory
+        + overview.num_tokens_messages
+        + overview.num_tokens_functions_definitions
+        + overview.num_tokens_external_memory_summary
+    )