diff --git a/alembic/versions/5987401b40ae_refactor_agent_memory.py b/alembic/versions/5987401b40ae_refactor_agent_memory.py new file mode 100644 index 0000000000..889e9425b5 --- /dev/null +++ b/alembic/versions/5987401b40ae_refactor_agent_memory.py @@ -0,0 +1,34 @@ +"""Refactor agent memory + +Revision ID: 5987401b40ae +Revises: 1c8880d671ee +Create Date: 2024-11-25 14:35:00.896507 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "5987401b40ae" +down_revision: Union[str, None] = "1c8880d671ee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("agents", "tools", new_column_name="tool_names") + op.drop_column("agents", "memory") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("agents", "tool_names", new_column_name="tools") + op.add_column("agents", sa.Column("memory", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True)) + # ### end Alembic commands ### diff --git a/examples/docs/tools.py b/examples/docs/tools.py index 0aa6cbaad1..b41fb501ad 100644 --- a/examples/docs/tools.py +++ b/examples/docs/tools.py @@ -45,7 +45,7 @@ def roll_d20() -> str: TerminalToolRule(tool_name="send_message"), ], ) -print(f"Created agent with name {agent_state.name} with tools {agent_state.tools}") +print(f"Created agent with name {agent_state.name} with tools {agent_state.tool_names}") # Message an agent response = client.send_message(agent_id=agent_state.id, role="user", message="roll a dice") diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index 5ea1add437..aca7c4f8f7 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -3,7 +3,7 @@ from letta import create_client from letta.schemas.letta_message import FunctionCallMessage -from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( assert_invoked_send_message_with_keyword, setup_agent, @@ -100,10 +100,10 @@ def main(): # 3. Create the tool rules. It must be called in this order, or there will be an error thrown. tool_rules = [ InitToolRule(tool_name="first_secret_word"), - ToolRule(tool_name="first_secret_word", children=["second_secret_word"]), - ToolRule(tool_name="second_secret_word", children=["third_secret_word"]), - ToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), - ToolRule(tool_name="fourth_secret_word", children=["send_message"]), + ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]), + ChildToolRule(tool_name="second_secret_word", children=["third_secret_word"]), + ChildToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), + ChildToolRule(tool_name="fourth_secret_word", children=["send_message"]), TerminalToolRule(tool_name="send_message"), ] diff --git a/letta/__init__.py b/letta/__init__.py index 83c5a692b2..fa5c1d4bb2 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -4,7 +4,7 @@ from letta.client.client import LocalClient, RESTClient, create_client # imports for easier access -from letta.schemas.agent import AgentState +from letta.schemas.agent import AgentState, PersistedAgentState from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus diff --git a/letta/agent.py b/letta/agent.py index 32dd6ad9fd..73f0199ce2 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -31,7 +31,7 @@ from letta.orm import User from letta.persistence_manager import LocalStateManager from letta.schemas.agent import AgentState, AgentStepResponse -from letta.schemas.block import Block +from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.memory import ContextWindowOverview, Memory @@ -235,11 +235,8 @@ class Agent(BaseAgent): def __init__( self, interface: Optional[Union[AgentInterface, StreamingRefreshCLIInterface]], - # agents can be created from providing agent_state - agent_state: AgentState, - tools: List[Tool], + agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables) user: User, - # memory: Memory, # extras messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? @@ -253,7 +250,7 @@ def __init__( self.user = user # link tools - self.link_tools(tools) + self.link_tools(agent_state.tools) # initialize a tool rules solver if agent_state.tool_rules: @@ -265,26 +262,14 @@ def __init__( # add default rule for having send_message be a terminal tool if agent_state.tool_rules is None: agent_state.tool_rules = [] - # Define the rule to add - send_message_terminal_rule = TerminalToolRule(tool_name="send_message") - # Check if an equivalent rule is already present - if not any( - isinstance(rule, TerminalToolRule) and rule.tool_name == send_message_terminal_rule.tool_name for rule in agent_state.tool_rules - ): - agent_state.tool_rules.append(send_message_terminal_rule) self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) # gpt-4, gpt-3.5-turbo, ... self.model = self.agent_state.llm_config.model - # Store the system instructions (used to rebuild memory) - self.system = self.agent_state.system - - # Initialize the memory object - self.memory = self.agent_state.memory - assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}" - printd("Initialized memory object", self.memory.compile()) + # state managers + self.block_manager = BlockManager() # Interface must implement: # - internal_monologue @@ -322,8 +307,8 @@ def __init__( # Generate a sequence of initial messages to put in the buffer init_messages = initialize_message_sequence( model=self.model, - system=self.system, - memory=self.memory, + system=self.agent_state.system, + memory=self.agent_state.memory, archival_memory=None, recall_memory=None, memory_edit_timestamp=get_utc_time(), @@ -345,8 +330,8 @@ def __init__( # Basic "more human than human" initial message sequence init_messages = initialize_message_sequence( model=self.model, - system=self.system, - memory=self.memory, + system=self.agent_state.system, + memory=self.agent_state.memory, archival_memory=None, recall_memory=None, memory_edit_timestamp=get_utc_time(), @@ -380,6 +365,76 @@ def __init__( # Create the agent in the DB self.update_state() + def update_memory_if_change(self, new_memory: Memory) -> bool: + """ + Update internal memory object and system prompt if there have been modifications. + + Args: + new_memory (Memory): the new memory object to compare to the current memory object + + Returns: + modified (bool): whether the memory was updated + """ + if self.agent_state.memory.compile() != new_memory.compile(): + # update the blocks (LRW) in the DB + for label in self.agent_state.memory.list_block_labels(): + updated_value = new_memory.get_block(label).value + if updated_value != self.agent_state.memory.get_block(label).value: + # update the block if it's changed + block_id = self.agent_state.memory.get_block(label).id + block = self.block_manager.update_block( + block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user + ) + + # refresh memory from DB (using block ids) + self.agent_state.memory = Memory( + blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()] + ) + + # NOTE: don't do this since re-buildin the memory is handled at the start of the step + # rebuild memory - this records the last edited timestamp of the memory + # TODO: pass in update timestamp from block edit time + self.rebuild_system_prompt() + + return True + return False + + def execute_tool_and_persist_state(self, function_name, function_to_call, function_args): + """ + Execute tool modifications and persist the state of the agent. + Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data + """ + # TODO: add agent manager here + orig_memory_str = self.agent_state.memory.compile() + + # TODO: need to have an AgentState object that actually has full access to the block data + # this is because the sandbox tools need to be able to access block.value to edit this data + try: + if function_name in BASE_TOOLS: + # base tools are allowed to access the `Agent` object and run on the database + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = function_to_call(**function_args) + else: + # execute tool in a sandbox + # TODO: allow agent_state to specify which sandbox to execute tools in + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run( + agent_state=self.agent_state.__deepcopy__() + ) + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + self.update_memory_if_change(updated_agent_state.memory) + except Exception as e: + # Need to catch error here, or else trunction wont happen + # TODO: modify to function execution error + from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT + + error_msg = f"Error executing tool {function_name}: {e}" + if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: + error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] + raise ValueError(error_msg) + + return function_response + @property def messages(self) -> List[dict]: """Getter method that converts the internal Message list into OpenAI-style dicts""" @@ -392,16 +447,6 @@ def messages(self, value): def link_tools(self, tools: List[Tool]): """Bind a tool object (schema + python function) to the agent object""" - # tools - for tool in tools: - assert tool, f"Tool is None - must be error in querying tool from DB" - assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools" - for tool_name in self.agent_state.tools: - assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list" - - # Update tools - self.tools = tools - # Store the functions schemas (this is passed as an argument to ChatCompletion) self.functions = [] self.functions_python = {} @@ -416,9 +461,8 @@ def link_tools(self, tools: List[Tool]): exec(tool.source_code, env) self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]] self.functions.append(tool.json_schema) - except Exception as e: + except Exception: warnings.warn(f"WARNING: tool {tool.name} failed to link") - print(e) assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]: @@ -727,27 +771,10 @@ def _handle_ai_response( if isinstance(function_args[name], dict): function_args[name] = spec[name](**function_args[name]) - # TODO: This needs to be rethought, how do we allow functions that modify agent state/db? - # TODO: There should probably be two types of tools: stateless/stateful - - if function_name in BASE_TOOLS: - function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = function_to_call(**function_args) - else: - # execute tool in a sandbox - # TODO: allow agent_state to specify which sandbox to execute tools in - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run( - agent_state=self.agent_state - ) - function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state - # update agent state - if self.agent_state != updated_agent_state and updated_agent_state is not None: - self.agent_state = updated_agent_state - self.memory = self.agent_state.memory # TODO: don't duplicate - - # rebuild memory - self.rebuild_memory() + # handle tool execution (sandbox) and state updates + function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args) + # handle trunction if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: # with certain functions we rely on the paging mechanism to handle overflow truncate = False @@ -820,7 +847,7 @@ def _handle_ai_response( # rebuild memory # TODO: @charles please check this - self.rebuild_memory() + self.rebuild_system_prompt() # Update ToolRulesSolver state with last called function self.tool_rules_solver.update_tool_usage(function_name) @@ -936,17 +963,10 @@ def inner_step( # Step 0: update core memory # only pulling latest block data if shared memory is being used - # TODO: ensure we're passing in metadata store from all surfaces - if ms is not None: - should_update = False - for block in self.agent_state.memory.to_dict()["memory"].values(): - if not block.get("template", False): - should_update = True - if should_update: - # TODO: the force=True can be optimized away - # once we ensure we're correctly comparing whether in-memory core - # data is different than persisted core data. - self.rebuild_memory(force=True, ms=ms) + current_persisted_memory = Memory( + blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()] + ) # read blocks from DB + self.update_memory_if_change(current_persisted_memory) # Step 1: add user message if isinstance(messages, Message): @@ -1229,43 +1249,10 @@ def _swap_system_message_in_buffer(self, new_system_message: str): new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system) self._messages = new_messages - def update_memory_blocks_from_db(self): - for block in self.memory.to_dict()["memory"].values(): - if block.get("templates", False): - # we don't expect to update shared memory blocks that - # are templates. this is something we could update in the - # future if we expect templates to change often. - continue - block_id = block.get("id") - - # TODO: This is really hacky and we should probably figure out how to - db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user) - if db_block is None: - # this case covers if someone has deleted a shared block by interacting - # with some other agent. - # in that case we should remove this shared block from the agent currently being - # evaluated. - printd(f"removing block: {block_id=}") - continue - if not isinstance(db_block.value, str): - printd(f"skipping block update, unexpected value: {block_id=}") - continue - # TODO: we may want to update which columns we're updating from shared memory e.g. the limit - self.memory.update_block_value(label=block.get("label", ""), value=db_block.value) - - def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None): + def rebuild_system_prompt(self, force=False, update_timestamp=True): """Rebuilds the system message with the latest memory object and any shared memory block updates""" curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt - # NOTE: This is a hacky way to check if the memory has changed - memory_repr = self.memory.compile() - if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]: - printd(f"Memory has not changed, not rebuilding system") - return - - if ms: - self.update_memory_blocks_from_db() - # If the memory didn't update, we probably don't want to update the timestamp inside # For example, if we're doing a system prompt swap, this should probably be False if update_timestamp: @@ -1276,8 +1263,8 @@ def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[Metada # update memory (TODO: potentially update recall/archival stats seperately) new_system_message_str = compile_system_message( - system_prompt=self.system, - in_context_memory=self.memory, + system_prompt=self.agent_state.system, + in_context_memory=self.agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, archival_memory=self.persistence_manager.archival_memory, recall_memory=self.persistence_manager.recall_memory, @@ -1304,13 +1291,13 @@ def update_system_prompt(self, new_system_prompt: str): """Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)""" assert isinstance(new_system_prompt, str) - if new_system_prompt == self.system: + if new_system_prompt == self.agent_state.system: return - self.system = new_system_prompt + self.agent_state.system = new_system_prompt # updating the system prompt requires rebuilding the memory block inside the compiled system message - self.rebuild_memory(force=True, update_timestamp=False) + self.rebuild_system_prompt(force=True, update_timestamp=False) # make sure to persist the change _ = self.update_state() @@ -1324,6 +1311,7 @@ def remove_function(self, function_name: str) -> str: raise NotImplementedError def update_state(self) -> AgentState: + # TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages message_ids = [msg.id for msg in self._messages] # Assert that these are all strings @@ -1331,12 +1319,8 @@ def update_state(self) -> AgentState: warnings.warn(f"Non-string message IDs found in agent state: {message_ids}") message_ids = [m_id for m_id in message_ids if isinstance(m_id, str)] - assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}" - # override any fields that may have been updated self.agent_state.message_ids = message_ids - self.agent_state.memory = self.memory - self.agent_state.system = self.system return self.agent_state @@ -1537,7 +1521,7 @@ def get_context_window(self) -> ContextWindowOverview: system_prompt = self.agent_state.system # TODO is this the current system or the initial system? num_tokens_system = count_tokens(system_prompt) - core_memory = self.memory.compile() + core_memory = self.agent_state.memory.compile() num_tokens_core_memory = count_tokens(core_memory) # conversion of messages to OpenAI dict format, which is passed to the token counter @@ -1629,37 +1613,15 @@ def save_agent(agent: Agent, ms: MetadataStore): agent.update_state() agent_state = agent.agent_state - agent_id = agent_state.id assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" - # NOTE: we're saving agent memory before persisting the agent to ensure - # that allocated block_ids for each memory block are present in the agent model - save_agent_memory(agent=agent) - - if ms.get_agent(agent_id=agent.agent_state.id): - ms.update_agent(agent_state) + # TODO: move this to agent manager + # convert to persisted model + persisted_agent_state = agent.agent_state.to_persisted_agent_state() + if ms.get_agent(agent_id=persisted_agent_state.id): + ms.update_agent(persisted_agent_state) else: - ms.create_agent(agent_state) - - agent.agent_state = ms.get_agent(agent_id=agent_id) - assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" - - -def save_agent_memory(agent: Agent): - """ - Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table. - - NOTE: we are assuming agent.update_state has already been called. - """ - - for block_dict in agent.memory.to_dict()["memory"].values(): - # TODO: block creation should happen in one place to enforce these sort of constraints consistently. - block = Block(**block_dict) - # FIXME: should we expect for block values to be None? If not, we need to figure out why that is - # the case in some tests, if so we should relax the DB constraint. - if block.value is None: - block.value = "" - BlockManager().create_or_update_block(block, actor=agent.user) + ms.create_agent(persisted_agent_state) def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: diff --git a/letta/agent_store/chroma.py b/letta/agent_store/chroma.py index e192a1543b..eace737b3d 100644 --- a/letta/agent_store/chroma.py +++ b/letta/agent_store/chroma.py @@ -125,6 +125,8 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None): ids, filters = self.get_filters(filters) if self.collection.count() == 0: return [] + if ids == []: + ids = None if limit: results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit) else: diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 076a179ab7..79c04c49b7 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -219,7 +219,7 @@ def run( ) # create agent - tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tools] + tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tool_names] letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools, user=client.user) else: # create new agent @@ -311,13 +311,11 @@ def run( metadata=metadata, ) assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" - typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE) - tools = [server.tool_manager.get_tool_by_name(tool_name, actor=client.user) for tool_name in agent_state.tools] + typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tool_names])}", fg=typer.colors.WHITE) letta_agent = Agent( interface=interface(), - agent_state=agent_state, - tools=tools, + agent_state=client.get_agent(agent_state.id), # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, user=client.user, diff --git a/letta/client/client.py b/letta/client/client.py index 1e6881c55c..0bf773c04a 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -5,20 +5,17 @@ import requests import letta.utils -from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA +from letta.constants import ( + ADMIN_PREFIX, + BASE_MEMORY_TOOLS, + BASE_TOOLS, + DEFAULT_HUMAN, + DEFAULT_PERSONA, +) from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code -from letta.memory import get_memory_functions from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState -from letta.schemas.block import ( - Block, - BlockCreate, - BlockUpdate, - Human, - Persona, - UpdateHuman, - UpdatePersona, -) +from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig # new schemas @@ -82,7 +79,7 @@ def create_agent( agent_type: Optional[AgentType] = AgentType.memgpt_agent, embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, - memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), + memory=None, system: Optional[str] = None, tools: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, @@ -522,27 +519,13 @@ def create_agent( Returns: agent_state (AgentState): State of the created agent """ - - # TODO: implement this check once name lookup works - # if name: - # exist_agent_id = self.get_agent_id(agent_name=name) - - # raise ValueError(f"Agent with name {name} already exists") - - # construct list of tools tool_names = [] if tools: tool_names += tools if include_base_tools: tool_names += BASE_TOOLS + tool_names += BASE_MEMORY_TOOLS - # add memory tools - memory_functions = get_memory_functions(memory) - for func_name, func in memory_functions.items(): - tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"]) - tool_names.append(tool.name) - - # check if default configs are provided assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" assert llm_config or self._default_llm_config, f"LLM config must be provided" @@ -551,7 +534,7 @@ def create_agent( name=name, description=description, metadata_=metadata, - memory=memory, + memory_blocks=[], tools=tool_names, tool_rules=tool_rules, system=system, @@ -573,7 +556,20 @@ def create_agent( if response.status_code != 200: raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}") - return AgentState(**response.json()) + + # gather agent state + agent_state = AgentState(**response.json()) + + # create and link blocks + for block in memory.get_blocks(): + if not self.get_block(block.id): + # note: this does not update existing blocks + # WARNING: this resets the block ID - this method is a hack for backwards compat, should eventually use CreateBlock not Memory + block = self.create_block(label=block.label, value=block.value, limit=block.limit) + self.link_agent_memory_block(agent_id=agent_state.id, block_id=block.id) + + # refresh and return agent + return self.get_agent(agent_state.id) def update_message( self, @@ -606,12 +602,11 @@ def update_agent( name: Optional[str] = None, description: Optional[str] = None, system: Optional[str] = None, - tools: Optional[List[str]] = None, + tool_names: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, message_ids: Optional[List[str]] = None, - memory: Optional[Memory] = None, tags: Optional[List[str]] = None, ): """ @@ -622,12 +617,11 @@ def update_agent( name (str): Name of the agent description (str): Description of the agent system (str): System configuration - tools (List[str]): List of tools + tool_names (List[str]): List of tools metadata (Dict): Metadata llm_config (LLMConfig): LLM configuration embedding_config (EmbeddingConfig): Embedding configuration message_ids (List[str]): List of message IDs - memory (Memory): Memory configuration tags (List[str]): Tags for filtering agents Returns: @@ -637,14 +631,13 @@ def update_agent( id=agent_id, name=name, system=system, - tools=tools, + tool_names=tool_names, tags=tags, description=description, metadata_=metadata, llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, - memory=memory, ) response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: @@ -1001,8 +994,12 @@ def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool else: return [Block(**block) for block in response.json()] - def create_block(self, label: str, value: str, template_name: Optional[str] = None, is_template: bool = False) -> Block: # - request = BlockCreate(label=label, value=value, template=is_template, template_name=template_name) + def create_block( + self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False + ) -> Block: # + request = CreateBlock(label=label, value=value, template=is_template, template_name=template_name) + if limit: + request.limit = limit response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create block: {response.text}") @@ -1772,21 +1769,32 @@ def list_sandbox_env_vars( raise ValueError(f"Failed to list environment variables for sandbox config ID '{sandbox_config_id}': {response.text}") return [SandboxEnvironmentVariable(**var_data) for var_data in response.json()] - def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: + def update_agent_memory_block_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: + """Rename a block in the agent's core memory - # @router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label") - response = requests.patch( - f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/label", - headers=self.headers, - json={"current_label": current_label, "new_label": new_label}, - ) - if response.status_code != 200: - raise ValueError(f"Failed to update agent memory label: {response.text}") - return Memory(**response.json()) + Args: + agent_id (str): The agent ID + current_label (str): The current label of the block + new_label (str): The new label of the block + + Returns: + memory (Memory): The updated memory + """ + block = self.get_agent_memory_block(agent_id, current_label) + return self.update_block(block.id, label=new_label) + + # TODO: remove this + def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory: + """ + Create and link a memory block to an agent's core memory - def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory: + Args: + agent_id (str): The agent ID + create_block (CreateBlock): The block to create - # @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block") + Returns: + memory (Memory): The updated memory + """ response = requests.post( f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block", headers=self.headers, @@ -1796,9 +1804,38 @@ def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Me raise ValueError(f"Failed to add agent memory block: {response.text}") return Memory(**response.json()) + def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory: + """ + Link a block to an agent's core memory + + Args: + agent_id (str): The agent ID + block_id (str): The block ID + + Returns: + memory (Memory): The updated memory + """ + params = {"agent_id": agent_id} + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/blocks/{block_id}/attach", + params=params, + headers=self.headers, + ) + if response.status_code != 200: + raise ValueError(f"Failed to link agent memory block: {response.text}") + return Block(**response.json()) + def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: + """ + Unlike a block from the agent's core memory + + Args: + agent_id (str): The agent ID + block_label (str): The block label - # @router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block") + Returns: + memory (Memory): The updated memory + """ response = requests.delete( f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{block_label}", headers=self.headers, @@ -1807,17 +1844,108 @@ def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: raise ValueError(f"Failed to remove agent memory block: {response.text}") return Memory(**response.json()) - def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: + def get_agent_memory_blocks(self, agent_id: str) -> List[Block]: + """ + Get all the blocks in the agent's core memory + + Args: + agent_id (str): The agent ID + + Returns: + blocks (List[Block]): The blocks in the agent's core memory + """ + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to get agent memory blocks: {response.text}") + return [Block(**block) for block in response.json()] + + def get_agent_memory_block(self, agent_id: str, label: str) -> Block: + """ + Get a block in the agent's core memory by its label + + Args: + agent_id (str): The agent ID + label (str): The label in the agent's core memory - # @router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") + Returns: + block (Block): The block corresponding to the label + """ + response = requests.get( + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{label}", + headers=self.headers, + ) + if response.status_code != 200: + raise ValueError(f"Failed to get agent memory block: {response.text}") + return Block(**response.json()) + + def update_agent_memory_block( + self, + agent_id: str, + label: str, + value: Optional[str] = None, + limit: Optional[int] = None, + ): + """ + Update a block in the agent's core memory by specifying its label + + Args: + agent_id (str): The agent ID + label (str): The label of the block + value (str): The new value of the block + limit (int): The new limit of the block + + Returns: + block (Block): The updated block + """ + # setup data + data = {} + if value: + data["value"] = value + if limit: + data["limit"] = limit response = requests.patch( - f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/limit", + f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{label}", headers=self.headers, - json={"label": block_label, "limit": limit}, + json=data, ) if response.status_code != 200: - raise ValueError(f"Failed to update agent memory limit: {response.text}") - return Memory(**response.json()) + raise ValueError(f"Failed to update agent memory block: {response.text}") + return Block(**response.json()) + + def update_block( + self, + block_id: str, + label: Optional[str] = None, + value: Optional[str] = None, + limit: Optional[int] = None, + ): + """ + Update a block given the ID with the provided fields + + Args: + block_id (str): ID of the block + label (str): Label to assign to the block + value (str): Value to assign to the block + limit (int): Token limit to assign to the block + + Returns: + block (Block): Updated block + """ + data = {} + if value: + data["value"] = value + if limit: + data["limit"] = limit + if label: + data["label"] = label + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", + headers=self.headers, + json=data, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update block: {response.text}") + return Block(**response.json()) class LocalClient(AbstractClient): @@ -1916,6 +2044,11 @@ def create_agent( llm_config: LLMConfig = None, # memory memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), + # TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API) + # memory_blocks=[ + # {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000}, + # {"label": "persona", "value": get_persona_text(DEFAULT_PERSONA), "limit": 5000}, + # ], # system system: Optional[str] = None, # tools @@ -1934,7 +2067,7 @@ def create_agent( name (str): Name of the agent embedding_config (EmbeddingConfig): Embedding configuration llm_config (LLMConfig): LLM configuration - memory (Memory): Memory configuration + memory_blocks (List[Dict]): List of configurations for the memory blocks (placed in core-memory) system (str): System configuration tools (List[str]): List of tools tool_rules (Optional[List[BaseToolRule]]): List of tool rules @@ -1956,14 +2089,7 @@ def create_agent( tool_names += tools if include_base_tools: tool_names += BASE_TOOLS - - # add memory tools - memory_functions = get_memory_functions(memory) - for func_name, func in memory_functions.items(): - tool = self.create_or_update_tool(func, name=func_name, tags=["memory", "letta-base"]) - tool_names.append(tool.name) - - self.interface.clear() + tool_names += BASE_MEMORY_TOOLS # check if default configs are provided assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" @@ -1975,7 +2101,10 @@ def create_agent( name=name, description=description, metadata_=metadata, - memory=memory, + # memory=memory, + memory_blocks=[], + # memory_blocks = memory.get_blocks(), + # memory_tools=memory_tools, tools=tool_names, tool_rules=tool_rules, system=system, @@ -1987,7 +2116,18 @@ def create_agent( ), actor=self.user, ) - return agent_state + + # TODO: remove when we fully migrate to block creation CreateAgent model + # Link additional blocks to the agent (block ids created on the client) + # This needs to happen since the create agent does not allow passing in blocks which have already been persisted and have an ID + # So we create the agent and then link the blocks afterwards + user = self.server.get_user_or_default(self.user_id) + for block in memory.get_blocks(): + self.server.block_manager.create_or_update_block(block, actor=user) + self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_state.id, block_id=block.id) + + # TODO: get full agent state + return self.server.get_agent(agent_state.id) def update_message( self, @@ -2024,7 +2164,6 @@ def update_agent( llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, message_ids: Optional[List[str]] = None, - memory: Optional[Memory] = None, ): """ Update an existing agent @@ -2039,26 +2178,25 @@ def update_agent( llm_config (LLMConfig): LLM configuration embedding_config (EmbeddingConfig): Embedding configuration message_ids (List[str]): List of message IDs - memory (Memory): Memory configuration tags (List[str]): Tags for filtering agents Returns: agent_state (AgentState): State of the updated agent """ + # TODO: add the abilitty to reset linked block_ids self.interface.clear() agent_state = self.server.update_agent( UpdateAgentState( id=agent_id, name=name, system=system, - tools=tools, + tool_names=tools, tags=tags, description=description, metadata_=metadata, llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, - memory=memory, ), actor=self.user, ) @@ -2197,7 +2335,7 @@ def update_in_context_memory(self, agent_id: str, section: str, value: Union[Lis """ # TODO: implement this (not sure what it should look like) - memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, new_memory_contents={section: value}) + memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, label=section, value=value) return memory def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: @@ -2946,7 +3084,9 @@ def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool """ return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only) - def create_block(self, label: str, value: str, template_name: Optional[str] = None, is_template: bool = False) -> Block: # + def create_block( + self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False + ) -> Block: # """ Create a block @@ -2954,13 +3094,15 @@ def create_block(self, label: str, value: str, template_name: Optional[str] = No label (str): Label of the block name (str): Name of the block text (str): Text of the block + limit (int): Character of the block Returns: block (Block): Created block """ - return self.server.block_manager.create_or_update_block( - Block(label=label, template_name=template_name, value=value, is_template=is_template), actor=self.user - ) + block = Block(label=label, template_name=template_name, value=value, is_template=is_template) + if limit: + block.limit = limit + return self.server.block_manager.create_or_update_block(block, actor=self.user) def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None, limit: Optional[int] = None) -> Block: """ @@ -3115,20 +3257,142 @@ def list_sandbox_env_vars( sandbox_config_id=sandbox_config_id, actor=self.user, limit=limit, cursor=cursor ) - def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: - return self.server.update_agent_memory_label( - user_id=self.user_id, agent_id=agent_id, current_block_label=current_label, new_block_label=new_label - ) + def update_agent_memory_block_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: + """Rename a block in the agent's core memory + + Args: + agent_id (str): The agent ID + current_label (str): The current label of the block + new_label (str): The new label of the block + + Returns: + memory (Memory): The updated memory + """ + block = self.get_agent_memory_block(agent_id, current_label) + return self.update_block(block.id, label=new_label) + + # TODO: remove this + def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory: + """ + Create and link a memory block to an agent's core memory + + Args: + agent_id (str): The agent ID + create_block (CreateBlock): The block to create - def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory: + Returns: + memory (Memory): The updated memory + """ block_req = Block(**create_block.model_dump()) block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req) # Link the block to the agent updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id) return updated_memory + def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory: + """ + Link a block to an agent's core memory + + Args: + agent_id (str): The agent ID + block_id (str): The block ID + + Returns: + memory (Memory): The updated memory + """ + return self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block_id) + def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory: + """ + Unlike a block from the agent's core memory + + Args: + agent_id (str): The agent ID + block_label (str): The block label + + Returns: + memory (Memory): The updated memory + """ return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label) - def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory: - return self.server.update_agent_memory_limit(user_id=self.user_id, agent_id=agent_id, block_label=block_label, limit=limit) + def get_agent_memory_blocks(self, agent_id: str) -> List[Block]: + """ + Get all the blocks in the agent's core memory + + Args: + agent_id (str): The agent ID + + Returns: + blocks (List[Block]): The blocks in the agent's core memory + """ + block_ids = self.server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) + return [self.server.block_manager.get_block_by_id(block_id, actor=self.user) for block_id in block_ids] + + def get_agent_memory_block(self, agent_id: str, label: str) -> Block: + """ + Get a block in the agent's core memory by its label + + Args: + agent_id (str): The agent ID + label (str): The label in the agent's core memory + + Returns: + block (Block): The block corresponding to the label + """ + block_id = self.server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=label) + return self.server.block_manager.get_block_by_id(block_id, actor=self.user) + + def update_agent_memory_block( + self, + agent_id: str, + label: str, + value: Optional[str] = None, + limit: Optional[int] = None, + ): + """ + Update a block in the agent's core memory by specifying its label + + Args: + agent_id (str): The agent ID + label (str): The label of the block + value (str): The new value of the block + limit (int): The new limit of the block + + Returns: + block (Block): The updated block + """ + block = self.get_agent_memory_block(agent_id, label) + data = {} + if value: + data["value"] = value + if limit: + data["limit"] = limit + return self.server.block_manager.update_block(block.id, actor=self.user, block_update=BlockUpdate(**data)) + + def update_block( + self, + block_id: str, + label: Optional[str] = None, + value: Optional[str] = None, + limit: Optional[int] = None, + ): + """ + Update a block given the ID with the provided fields + + Args: + block_id (str): ID of the block + label (str): Label to assign to the block + value (str): Value to assign to the block + limit (int): Token limit to assign to the block + + Returns: + block (Block): Updated block + """ + data = {} + if value: + data["value"] = value + if limit: + data["limit"] = limit + if label: + data["label"] = label + return self.server.block_manager.update_block(block_id, actor=self.user, block_update=BlockUpdate(**data)) diff --git a/letta/config.py b/letta/config.py index 51287e0091..cc0c5aa720 100644 --- a/letta/config.py +++ b/letta/config.py @@ -16,7 +16,7 @@ LETTA_DIR, ) from letta.log import get_logger -from letta.schemas.agent import AgentState +from letta.schemas.agent import PersistedAgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig @@ -434,7 +434,7 @@ def save(self): json.dump(vars(self), f, indent=4) def to_agent_state(self): - return AgentState( + return PersistedAgentState( name=self.name, preset=self.preset, persona=self.persona, diff --git a/letta/constants.py b/letta/constants.py index c16e241566..32a4946b8f 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -38,6 +38,8 @@ # Base tools that cannot be edited, as they access agent state directly BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"] +# Base memory tools CAN be edited, and are added by default by the server +BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) @@ -127,6 +129,9 @@ # These serve as in-context examples of how to use functions / what user messages look like MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3 +# Maximum length of an error message +MAX_ERROR_MESSAGE_CHAR_LIMIT = 500 + # Default memory limits CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 5000 CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 5000 diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 51cfab5daf..e7bd4a9d94 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -172,3 +172,40 @@ def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str + + +def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore + """ + Append to the contents of core memory. + + Args: + label (str): Section of the memory to be edited (persona or human). + content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + current_value = str(agent_state.memory.get_block(label).value) + new_value = current_value + "\n" + str(content) + agent_state.memory.update_block_value(label=label, value=new_value) + return None + + +def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore + """ + Replace the contents of core memory. To delete memories, use an empty string for new_content. + + Args: + label (str): Section of the memory to be edited (persona or human). + old_content (str): String to replace. Must be an exact match. + new_content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + current_value = str(agent_state.memory.get_block(label).value) + if old_content not in current_value: + raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") + new_value = current_value.replace(str(old_content), str(new_content)) + agent_state.memory.update_block_value(label=label, value=new_value) + return None diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 4c50686c38..ef4d9a9b37 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -2,11 +2,12 @@ from pydantic import BaseModel, Field +from letta.schemas.enums import ToolRuleType from letta.schemas.tool_rule import ( BaseToolRule, + ChildToolRule, InitToolRule, TerminalToolRule, - ToolRule, ) @@ -21,7 +22,7 @@ class ToolRulesSolver(BaseModel): init_tool_rules: List[InitToolRule] = Field( default_factory=list, description="Initial tool rules to be used at the start of tool execution." ) - tool_rules: List[ToolRule] = Field( + tool_rules: List[ChildToolRule] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) terminal_tool_rules: List[TerminalToolRule] = Field( @@ -33,11 +34,11 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs): super().__init__(**kwargs) # Separate the provided tool rules into init, standard, and terminal categories for rule in tool_rules: - if isinstance(rule, InitToolRule): + if rule.type == ToolRuleType.run_first: self.init_tool_rules.append(rule) - elif isinstance(rule, ToolRule): + elif rule.type == ToolRuleType.constrain_child_tools: self.tool_rules.append(rule) - elif isinstance(rule, TerminalToolRule): + elif rule.type == ToolRuleType.exit_loop: self.terminal_tool_rules.append(rule) # Validate the tool rules to ensure they form a DAG diff --git a/letta/main.py b/letta/main.py index abfd36ae9c..88d20e08b0 100644 --- a/letta/main.py +++ b/letta/main.py @@ -189,7 +189,7 @@ def run_agent_loop( elif user_input.lower() == "/memory": print(f"\nDumping memory contents:\n") - print(f"{letta_agent.memory.compile()}") + print(f"{letta_agent.agent_state.memory.compile()}") print(f"{letta_agent.persistence_manager.archival_memory.compile()}") print(f"{letta_agent.persistence_manager.recall_memory.compile()}") continue diff --git a/letta/metadata.py b/letta/metadata.py index 1b8f6a220d..3fdfa4038e 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -3,27 +3,21 @@ import os import secrets import warnings -from typing import List, Optional +from typing import List, Optional, Union from sqlalchemy import JSON, Column, DateTime, Index, String, TypeDecorator from sqlalchemy.sql import func from letta.config import LettaConfig from letta.orm.base import Base -from letta.schemas.agent import AgentState +from letta.schemas.agent import PersistedAgentState from letta.schemas.api_key import APIKey from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import JobStatus +from letta.schemas.enums import JobStatus, ToolRuleType from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import Memory from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.tool_rule import ( - BaseToolRule, - InitToolRule, - TerminalToolRule, - ToolRule, -) +from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.schemas.user import User from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.settings import settings @@ -165,28 +159,35 @@ class ToolRulesColumn(TypeDecorator): def load_dialect_impl(self, dialect): return dialect.type_descriptor(JSON()) - def process_bind_param(self, value: List[BaseToolRule], dialect): + def process_bind_param(self, value, dialect): """Convert a list of ToolRules to JSON-serializable format.""" if value: - return [rule.model_dump() for rule in value] + data = [rule.model_dump() for rule in value] + for d in data: + d["type"] = d["type"].value + + for d in data: + assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" + return data return value - def process_result_value(self, value, dialect) -> List[BaseToolRule]: + def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: """Convert JSON back to a list of ToolRules.""" if value: return [self.deserialize_tool_rule(rule_data) for rule_data in value] return value @staticmethod - def deserialize_tool_rule(data: dict) -> BaseToolRule: + def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" - rule_type = data.get("type") # Remove 'type' field if it exists since it is a class var - if rule_type == "InitToolRule": + rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var + if rule_type == ToolRuleType.run_first: return InitToolRule(**data) - elif rule_type == "TerminalToolRule": + elif rule_type == ToolRuleType.exit_loop: return TerminalToolRule(**data) - elif rule_type == "ToolRule": - return ToolRule(**data) + elif rule_type == ToolRuleType.constrain_child_tools: + rule = ChildToolRule(**data) + return rule else: raise ValueError(f"Unknown tool rule type: {rule_type}") @@ -205,7 +206,6 @@ class AgentModel(Base): # state (context compilation) message_ids = Column(JSON) - memory = Column(JSON) system = Column(String) # configs @@ -217,7 +217,7 @@ class AgentModel(Base): metadata_ = Column(JSON) # tools - tools = Column(JSON) + tool_names = Column(JSON) tool_rules = Column(ToolRulesColumn) Index(__tablename__ + "_idx_user", user_id), @@ -225,24 +225,22 @@ class AgentModel(Base): def __repr__(self) -> str: return f"" - def to_record(self) -> AgentState: - agent_state = AgentState( + def to_record(self) -> PersistedAgentState: + agent_state = PersistedAgentState( id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at, description=self.description, message_ids=self.message_ids, - memory=Memory.load(self.memory), # load dictionary system=self.system, - tools=self.tools, + tool_names=self.tool_names, tool_rules=self.tool_rules, agent_type=self.agent_type, llm_config=self.llm_config, embedding_config=self.embedding_config, metadata_=self.metadata_, ) - assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}" return agent_state @@ -347,18 +345,18 @@ def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]: return tokens @enforce_types - def create_agent(self, agent: AgentState): + def create_agent(self, agent: PersistedAgentState): # insert into agent table # make sure agent.name does not already exist for user user_id with self.session_maker() as session: if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0: raise ValueError(f"Agent with name {agent.name} already exists") fields = vars(agent) - fields["memory"] = agent.memory.to_dict() - if "_internal_memory" in fields: - del fields["_internal_memory"] - else: - warnings.warn(f"Agent {agent.id} has no _internal_memory field") + # fields["memory"] = agent.memory.to_dict() + # if "_internal_memory" in fields: + # del fields["_internal_memory"] + # else: + # warnings.warn(f"Agent {agent.id} has no _internal_memory field") if "tags" in fields: del fields["tags"] else: @@ -367,15 +365,15 @@ def create_agent(self, agent: AgentState): session.commit() @enforce_types - def update_agent(self, agent: AgentState): + def update_agent(self, agent: PersistedAgentState): with self.session_maker() as session: fields = vars(agent) - if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever - fields["memory"] = agent.memory.to_dict() - if "_internal_memory" in fields: - del fields["_internal_memory"] - else: - warnings.warn(f"Agent {agent.id} has no _internal_memory field") + # if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever + # fields["memory"] = agent.memory.to_dict() + # if "_internal_memory" in fields: + # del fields["_internal_memory"] + # else: + # warnings.warn(f"Agent {agent.id} has no _internal_memory field") if "tags" in fields: del fields["tags"] else: @@ -400,7 +398,7 @@ def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManage session.commit() @enforce_types - def list_agents(self, user_id: str) -> List[AgentState]: + def list_agents(self, user_id: str) -> List[PersistedAgentState]: with self.session_maker() as session: results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() return [r.to_record() for r in results] @@ -408,7 +406,7 @@ def list_agents(self, user_id: str) -> List[AgentState]: @enforce_types def get_agent( self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None - ) -> Optional[AgentState]: + ) -> Optional[PersistedAgentState]: with self.session_maker() as session: if agent_id: results = session.query(AgentModel).filter(AgentModel.id == agent_id).all() diff --git a/letta/o1_agent.py b/letta/o1_agent.py index 9539e4aff3..a6b70b595f 100644 --- a/letta/o1_agent.py +++ b/letta/o1_agent.py @@ -6,7 +6,6 @@ from letta.schemas.agent import AgentState from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics -from letta.schemas.tool import Tool from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -45,13 +44,11 @@ def __init__( interface: AgentInterface, agent_state: AgentState, user: User, - tools: List[Tool] = [], max_thinking_steps: int = 10, first_message_verify_mono: bool = False, ): - super().__init__(interface, agent_state, tools, user) + super().__init__(interface, agent_state, user) self.max_thinking_steps = max_thinking_steps - self.tools = tools self.first_message_verify_mono = first_message_verify_mono def step( diff --git a/letta/persistence_manager.py b/letta/persistence_manager.py index ca8c097bfa..7dd22a998f 100644 --- a/letta/persistence_manager.py +++ b/letta/persistence_manager.py @@ -121,6 +121,7 @@ def prepend_to_messages(self, added_messages: List[Message]): # self.messages = [self.messages[0]] + added_messages + self.messages[1:] # add to recall memory + self.recall_memory.insert_many([m for m in added_messages]) def append_to_messages(self, added_messages: List[Message]): # first tag with timestamps diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 648546ef06..8b5161eb32 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -2,15 +2,18 @@ from enum import Enum from typing import Dict, List, Optional -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator +from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics -from letta.schemas.tool_rule import BaseToolRule +from letta.schemas.source import Source +from letta.schemas.tool import Tool +from letta.schemas.tool_rule import ToolRule class BaseAgent(LettaBase, validate_assignment=True): @@ -32,23 +35,8 @@ class AgentType(str, Enum): o1_agent = "o1_agent" -class AgentState(BaseAgent, validate_assignment=True): - """ - Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. - - Parameters: - id (str): The unique identifier of the agent. - name (str): The name of the agent (must be unique to the user). - created_at (datetime): The datetime the agent was created. - message_ids (List[str]): The ids of the messages in the agent's in-context memory. - memory (Memory): The in-context memory of the agent. - tools (List[str]): The tools used by the agent. This includes any memory editing functions specified in `memory`. - system (str): The system prompt used by the agent. - llm_config (LLMConfig): The LLM configuration used by the agent. - embedding_config (EmbeddingConfig): The embedding configuration used by the agent. - - """ - +class PersistedAgentState(BaseAgent, validate_assignment=True): + # NOTE: this has been changed to represent the data stored in the ORM, NOT what is passed around internally or returned to the user id: str = BaseAgent.generate_id_field() name: str = Field(..., description="The name of the agent.") created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now) @@ -56,16 +44,12 @@ class AgentState(BaseAgent, validate_assignment=True): # in-context memory message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.") - memory: Memory = Field(default_factory=Memory, description="The in-context memory of the agent.") - # tools - tools: List[str] = Field(..., description="The tools used by the agent.") + # TODO: move to ORM mapping + tool_names: List[str] = Field(..., description="The tools used by the agent.") # tool rules - tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.") - - # tags - tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") + tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") # system prompt system: str = Field(..., description="The system prompt used by the agent.") @@ -77,40 +61,62 @@ class AgentState(BaseAgent, validate_assignment=True): llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") - def __init__(self, **data): - super().__init__(**data) - self._internal_memory = self.memory + class Config: + arbitrary_types_allowed = True + validate_assignment = True - @model_validator(mode="after") - def verify_memory_type(self): - try: - assert isinstance(self.memory, Memory) - except Exception as e: - raise e - return self - @property - def memory(self) -> Memory: - return self._internal_memory +class AgentState(PersistedAgentState): + """ + Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. + + Parameters: + id (str): The unique identifier of the agent. + name (str): The name of the agent (must be unique to the user). + created_at (datetime): The datetime the agent was created. + message_ids (List[str]): The ids of the messages in the agent's in-context memory. + memory (Memory): The in-context memory of the agent. + tools (List[str]): The tools used by the agent. This includes any memory editing functions specified in `memory`. + system (str): The system prompt used by the agent. + llm_config (LLMConfig): The LLM configuration used by the agent. + embedding_config (EmbeddingConfig): The embedding configuration used by the agent. - @memory.setter - def memory(self, value): - if not isinstance(value, Memory): - raise TypeError(f"Expected Memory, got {type(value).__name__}") - self._internal_memory = value + """ - class Config: - arbitrary_types_allowed = True - validate_assignment = True + # NOTE: this is what is returned to the client and also what is used to initialize `Agent` + + # This is an object representing the in-process state of a running `Agent` + # Field in this object can be theoretically edited by tools, and will be persisted by the ORM + memory: Memory = Field(..., description="The in-context memory of the agent.") + tools: List[Tool] = Field(..., description="The tools used by the agent.") + sources: List[Source] = Field(..., description="The sources used by the agent.") + tags: List[str] = Field(..., description="The tags associated with the agent.") + # TODO: add in context message objects + def to_persisted_agent_state(self) -> PersistedAgentState: + # turn back into persisted agent + data = self.model_dump() + del data["memory"] + del data["tools"] + del data["sources"] + del data["tags"] + return PersistedAgentState(**data) -class CreateAgent(BaseAgent): + +class CreateAgent(BaseAgent): # # all optional as server can generate defaults name: Optional[str] = Field(None, description="The name of the agent.") message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") - memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") + + # memory creation + memory_blocks: List[CreateBlock] = Field( + # [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory." + ..., + description="The blocks to create in the agent's in-context memory.", + ) + tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") - tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.") + tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.") tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") agent_type: Optional[AgentType] = Field(None, description="The type of agent.") @@ -151,7 +157,7 @@ def validate_name(cls, name: str) -> str: class UpdateAgentState(BaseAgent): id: str = Field(..., description="The id of the agent.") name: Optional[str] = Field(None, description="The name of the agent.") - tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") + tool_names: Optional[List[str]] = Field(None, description="The tools used by the agent.") tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") @@ -159,7 +165,6 @@ class UpdateAgentState(BaseAgent): # TODO: determine if these should be editable via this schema? message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") - memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") class AgentStepResponse(BaseModel): diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 6679d50357..b48bdbbce1 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -28,6 +28,12 @@ class BaseBlock(LettaBase, validate_assignment=True): description: Optional[str] = Field(None, description="Description of the block.") metadata_: Optional[dict] = Field({}, description="Metadata of the block.") + # def __len__(self): + # return len(self.value) + + class Config: + extra = "ignore" # Ignores extra fields + @model_validator(mode="after") def verify_char_limit(self) -> Self: if self.value and len(self.value) > self.limit: @@ -36,9 +42,6 @@ def verify_char_limit(self) -> Self: return self - # def __len__(self): - # return len(self.value) - def __setattr__(self, name, value): """Run validation if self.value is updated""" super().__setattr__(name, value) @@ -46,9 +49,6 @@ def __setattr__(self, name, value): # run validation self.__class__.model_validate(self.model_dump(exclude_unset=True)) - class Config: - extra = "ignore" # Ignores extra fields - class Block(BaseBlock): """ @@ -88,11 +88,11 @@ class Persona(Block): label: str = "persona" -class BlockCreate(BaseBlock): - """Create a block""" - - is_template: bool = True - label: str = Field(..., description="Label of the block.") +# class CreateBlock(BaseBlock): +# """Create a block""" +# +# is_template: bool = True +# label: str = Field(..., description="Label of the block.") class BlockLabelUpdate(BaseModel): @@ -102,16 +102,16 @@ class BlockLabelUpdate(BaseModel): new_label: str = Field(..., description="New label of the block.") -class CreatePersona(BlockCreate): - """Create a persona block""" - - label: str = "persona" - - -class CreateHuman(BlockCreate): - """Create a human block""" - - label: str = "human" +# class CreatePersona(CreateBlock): +# """Create a persona block""" +# +# label: str = "persona" +# +# +# class CreateHuman(CreateBlock): +# """Create a human block""" +# +# label: str = "human" class BlockUpdate(BaseBlock): @@ -131,13 +131,57 @@ class BlockLimitUpdate(BaseModel): limit: int = Field(..., description="New limit of the block.") -class UpdatePersona(BlockUpdate): - """Update a persona block""" +# class UpdatePersona(BlockUpdate): +# """Update a persona block""" +# +# label: str = "persona" +# +# +# class UpdateHuman(BlockUpdate): +# """Update a human block""" +# +# label: str = "human" + + +class CreateBlock(BaseBlock): + """Create a block""" + + label: str = Field(..., description="Label of the block.") + limit: int = Field(2000, description="Character limit of the block.") + value: str = Field(..., description="Value of the block.") + + # block templates + is_template: bool = False + template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name") + + +class CreateHuman(CreateBlock): + """Create a human block""" + + label: str = "human" + + +class CreatePersona(CreateBlock): + """Create a persona block""" label: str = "persona" -class UpdateHuman(BlockUpdate): - """Update a human block""" +class CreateBlockTemplate(CreateBlock): + """Create a block template""" + + is_template: bool = True + + +class CreateHumanBlockTemplate(CreateHuman): + """Create a human block template""" + is_template: bool = True label: str = "human" + + +class CreatePersonaBlockTemplate(CreatePersona): + """Create a persona block template""" + + is_template: bool = True + label: str = "persona" diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index ea1335a990..8b74b83732 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -33,3 +33,17 @@ class MessageStreamStatus(str, Enum): done_generation = "[DONE_GEN]" done_step = "[DONE_STEP]" done = "[DONE]" + + +class ToolRuleType(str, Enum): + """ + Type of tool rule. + """ + + # note: some of these should be renamed when we do the data migration + + run_first = "InitToolRule" + exit_loop = "TerminalToolRule" # reasoning loop should exit + continue_loop = "continue_loop" # reasoning loop should continue + constrain_child_tools = "ToolRule" + require_parent_tools = "require_parent_tools" diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index f38a18fac6..9084006dbe 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, List, Optional from jinja2 import Template, TemplateSyntaxError from pydantic import BaseModel, Field @@ -55,19 +55,16 @@ class ContextWindowOverview(BaseModel): class Memory(BaseModel, validate_assignment=True): """ - Represents the in-context memory of the agent. This includes both the `Block` objects (labelled by sections), as well as tools to edit the blocks. - - Attributes: - memory (Dict[str, Block]): Mapping from memory block section to memory block. + Represents the in-context memory (i.e. Core memory) of the agent. This includes both the `Block` objects (labelled by sections), as well as tools to edit the blocks. """ - # Memory.memory is a dict mapping from memory block label to memory block. - memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.") + # Memory.block contains the list of memory blocks in the core memory + blocks: List[Block] = Field(..., description="Memory blocks contained in the agent's in-context memory") # Memory.template is a Jinja2 template for compiling memory module into a prompt string. prompt_template: str = Field( - default="{% for block in memory.values() %}" + default="{% for block in blocks %}" '<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n' "{{ block.value }}\n" "" @@ -90,7 +87,7 @@ def set_prompt_template(self, prompt_template: str): Template(prompt_template) # Validate compatibility with current memory structure - test_render = Template(prompt_template).render(memory=self.memory) + test_render = Template(prompt_template).render(blocks=self.blocks) # If we get here, the template is valid and compatible self.prompt_template = prompt_template @@ -99,107 +96,49 @@ def set_prompt_template(self, prompt_template: str): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") - @classmethod - def load(cls, state: dict): - """Load memory from dictionary object""" - obj = cls() - if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state: - # New format - obj.prompt_template = state["prompt_template"] - for key, value in state["memory"].items(): - # TODO: This is migration code, please take a look at a later time to get rid of this - if "name" in value: - value["template_name"] = value["name"] - value.pop("name") - obj.memory[key] = Block(**value) - else: - # Old format (pre-template) - for key, value in state.items(): - obj.memory[key] = Block(**value) - return obj - def compile(self) -> str: """Generate a string representation of the memory in-context using the Jinja2 template""" template = Template(self.prompt_template) - return template.render(memory=self.memory) - - def to_dict(self): - """Convert to dictionary representation""" - return { - "memory": {key: value.model_dump() for key, value in self.memory.items()}, - "prompt_template": self.prompt_template, - } - - def to_flat_dict(self): - """Convert to a dictionary that maps directly from block label to values""" - return {k: v.value for k, v in self.memory.items() if v is not None} + return template.render(blocks=self.blocks) def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" - return list(self.memory.keys()) + # return list(self.memory.keys()) + return [block.label for block in self.blocks] # TODO: these should actually be label, not name def get_block(self, label: str) -> Block: """Correct way to index into the memory.memory field, returns a Block""" - if label not in self.memory: - raise KeyError(f"Block field {label} does not exist (available sections = {', '.join(list(self.memory.keys()))})") - else: - return self.memory[label] + keys = [] + for block in self.blocks: + if block.label == label: + return block + keys.append(block.label) + raise KeyError(f"Block field {label} does not exist (available sections = {', '.join(keys)})") def get_blocks(self) -> List[Block]: """Return a list of the blocks held inside the memory object""" - return list(self.memory.values()) - - def link_block(self, block: Block, override: Optional[bool] = False): - """Link a new block to the memory object""" - if not isinstance(block, Block): - raise ValueError(f"Param block must be type Block (not {type(block)})") - if not override and block.label in self.memory: - raise ValueError(f"Block with label {block.label} already exists") - - self.memory[block.label] = block + # return list(self.memory.values()) + return self.blocks - def unlink_block(self, block_label: str) -> Block: - """Unlink a block from the memory object""" - if block_label not in self.memory: - raise ValueError(f"Block with label {block_label} does not exist") - - return self.memory.pop(block_label) + def set_block(self, block: Block): + """Set a block in the memory object""" + for i, b in enumerate(self.blocks): + if b.label == block.label: + self.blocks[i] = block + return + self.blocks.append(block) def update_block_value(self, label: str, value: str): """Update the value of a block""" - if label not in self.memory: - raise ValueError(f"Block with label {label} does not exist") if not isinstance(value, str): raise ValueError(f"Provided value must be a string") - self.memory[label].value = value - - def update_block_label(self, current_label: str, new_label: str): - """Update the label of a block""" - if current_label not in self.memory: - raise ValueError(f"Block with label {current_label} does not exist") - if not isinstance(new_label, str): - raise ValueError(f"Provided new label must be a string") - - # First change the label of the block - self.memory[current_label].label = new_label - - # Then swap the block to the new label - self.memory[new_label] = self.memory.pop(current_label) - - def update_block_limit(self, label: str, limit: int): - """Update the limit of a block""" - if label not in self.memory: - raise ValueError(f"Block with label {label} does not exist") - if not isinstance(limit, int): - raise ValueError(f"Provided limit must be an integer") - - # Check to make sure the new limit is greater than the current length of the block - if len(self.memory[label].value) > limit: - raise ValueError(f"New limit {limit} is less than the current length of the block {len(self.memory[label].value)}") - - self.memory[label].limit = limit + for block in self.blocks: + if block.label == label: + block.value = value + return + raise ValueError(f"Block with label {label} does not exist") # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. @@ -222,13 +161,7 @@ def __init__(self, blocks: List[Block] = []): Args: blocks (List[Block]): List of blocks to be linked to the memory object. """ - super().__init__() - for block in blocks: - # TODO: centralize these internal schema validations - # assert block.name is not None and block.name != "", "each existing chat block must have a name" - # self.link_block(name=block.name, block=block) - assert block.label is not None and block.label != "", "each existing chat block must have a name" - self.link_block(block=block) + super().__init__(blocks=blocks) def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore """ @@ -280,9 +213,7 @@ def __init__(self, persona: str, human: str, limit: int = CORE_MEMORY_BLOCK_CHAR human (str): The starter value for the human block. limit (int): The character limit for each block. """ - super().__init__() - self.link_block(block=Block(value=persona, limit=limit, label="persona")) - self.link_block(block=Block(value=human, limit=limit, label="human")) + super().__init__(blocks=[Block(value=persona, limit=limit, label="persona"), Block(value=human, limit=limit, label="human")]) class UpdateMemory(BaseModel): diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index d1540a613a..42f460e467 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,21 +1,24 @@ -from typing import List +from typing import List, Union from pydantic import Field +from letta.schemas.enums import ToolRuleType from letta.schemas.letta_base import LettaBase class BaseToolRule(LettaBase): __id_prefix__ = "tool_rule" tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.") + type: ToolRuleType -class ToolRule(BaseToolRule): +class ChildToolRule(BaseToolRule): """ A ToolRule represents a tool that can be invoked by the agent. """ - type: str = Field("ToolRule") + # type: str = Field("ToolRule") + type: ToolRuleType = ToolRuleType.constrain_child_tools children: List[str] = Field(..., description="The children tools that can be invoked.") @@ -24,7 +27,8 @@ class InitToolRule(BaseToolRule): Represents the initial tool rule configuration. """ - type: str = Field("InitToolRule") + # type: str = Field("InitToolRule") + type: ToolRuleType = ToolRuleType.run_first class TerminalToolRule(BaseToolRule): @@ -32,4 +36,8 @@ class TerminalToolRule(BaseToolRule): Represents a terminal tool rule configuration where if this tool gets called, it must end the agent loop. """ - type: str = Field("TerminalToolRule") + # type: str = Field("TerminalToolRule") + type: ToolRuleType = ToolRuleType.exit_loop + + +ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule] diff --git a/letta/server/rest_api/routers/openai/assistants/threads.py b/letta/server/rest_api/routers/openai/assistants/threads.py index a8a072e5b1..8742aa4219 100644 --- a/letta/server/rest_api/routers/openai/assistants/threads.py +++ b/letta/server/rest_api/routers/openai/assistants/threads.py @@ -117,7 +117,7 @@ def create_message( tool_call_id=None, name=None, ) - agent = server._get_or_load_agent(agent_id=agent_id) + agent = server.load_agent(agent_id=agent_id) # add message to agent agent._append_to_messages([message]) @@ -246,7 +246,7 @@ def create_run( # TODO: add request.instructions as a message? agent_id = thread_id # TODO: override preset of agent with request.assistant_id - agent = server._get_or_load_agent(agent_id=agent_id) + agent = server.load_agent(agent_id=agent_id) agent.inner_step(messages=[]) # already has messages added run_id = str(uuid.uuid4()) create_time = int(get_utc_time().timestamp()) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 470c456bbc..1ed69dbbae 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1,13 +1,17 @@ import asyncio from datetime import datetime -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState -from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate, BlockLimitUpdate +from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate + Block, + BlockUpdate, + CreateBlock, +) from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( LegacyLettaMessage, @@ -18,7 +22,6 @@ from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ( ArchivalMemorySummary, - BasicBlockMemory, ContextWindowOverview, CreateArchivalMemory, Memory, @@ -82,13 +85,6 @@ def create_agent( Create a new agent with the specified configuration. """ actor = server.get_user_or_default(user_id=user_id) - agent.user_id = actor.id - # TODO: sarah make general - # TODO: eventually remove this - assert agent.memory is not None # TODO: dont force this, can be None (use default human/person) - blocks = agent.memory.get_blocks() - agent.memory = BasicBlockMemory(blocks=blocks) - return server.create_agent(agent, actor=actor) @@ -195,6 +191,7 @@ def get_agent_in_context_messages( return server.get_in_context_messages(agent_id=agent_id) +# TODO: remove? can also get with agent blocks @router.get("/{agent_id}/memory", response_model=Memory, operation_id="get_agent_memory") def get_agent_memory( agent_id: str, @@ -208,47 +205,40 @@ def get_agent_memory( return server.get_agent_memory(agent_id=agent_id) -@router.patch("/{agent_id}/memory", response_model=Memory, operation_id="update_agent_memory") -def update_agent_memory( +@router.get("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="get_agent_memory_block") +def get_agent_memory_block( agent_id: str, - request: Dict = Body(...), + block_label: str, server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Update the core memory of a specific agent. - This endpoint accepts new memory contents (labels as keys, and values as values) and updates the core memory of the agent identified by the user ID and agent ID. - This endpoint accepts new memory contents to update the core memory of the agent. - This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks. + Retrieve a memory block from an agent. """ actor = server.get_user_or_default(user_id=user_id) - memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request) - return memory + block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label) + return server.block_manager.get_block_by_id(block_id, actor=actor) -@router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label") -def update_agent_memory_label( +@router.get("/{agent_id}/memory/block", response_model=List[Block], operation_id="get_agent_memory_blocks") +def get_agent_memory_blocks( agent_id: str, - update_label: BlockLabelUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Update the label of a block in an agent's memory. + Retrieve the memory blocks of a specific agent. """ actor = server.get_user_or_default(user_id=user_id) - - memory = server.update_agent_memory_label( - user_id=actor.id, agent_id=agent_id, current_block_label=update_label.current_label, new_block_label=update_label.new_label - ) - return memory + block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) + return [server.block_manager.get_block_by_id(block_id, actor=actor) for block_id in block_ids] @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block") def add_agent_memory_block( agent_id: str, - create_block: BlockCreate = Body(...), + create_block: CreateBlock = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -267,7 +257,7 @@ def add_agent_memory_block( return updated_memory -@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block") +@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label") def remove_agent_memory_block( agent_id: str, # TODO should this be block_id, or the label? @@ -287,25 +277,24 @@ def remove_agent_memory_block( return updated_memory -@router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit") -def update_agent_memory_limit( +@router.patch("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label") +def update_agent_memory_block( agent_id: str, - update_label: BlockLimitUpdate = Body(...), + block_label: str, + update_block: BlockUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Update the limit of a block in an agent's memory. + Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted. """ actor = server.get_user_or_default(user_id=user_id) - memory = server.update_agent_memory_limit( - user_id=actor.id, - agent_id=agent_id, - block_label=update_label.label, - limit=update_label.limit, - ) - return memory + # get the block_id from the label + block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label) + + # update the block + return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor) @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") @@ -548,7 +537,8 @@ async def send_message_to_agent( # Get the generator object off of the agent's streaming interface # This will be attached to the POST SSE request used under-the-hood - letta_agent = server._get_or_load_agent(agent_id=agent_id) + # letta_agent = server.load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id) # Disable token streaming if not OpenAI # TODO: cleanup this logic @@ -590,6 +580,7 @@ async def send_message_to_agent( user_id=user_id, agent_id=agent_id, messages=messages, + interface=streaming_interface, ) ) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 6fee08dd69..798e7aa4ac 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -3,7 +3,8 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from letta.orm.errors import NoResultFound -from letta.schemas.block import Block, BlockCreate, BlockUpdate +from letta.schemas.block import Block, BlockUpdate, CreateBlock +from letta.schemas.memory import Memory from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -28,7 +29,7 @@ def list_blocks( @router.post("/", response_model=Block, operation_id="create_memory_block") def create_block( - create_block: BlockCreate = Body(...), + create_block: CreateBlock = Body(...), server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -40,12 +41,12 @@ def create_block( @router.patch("/{block_id}", response_model=Block, operation_id="update_memory_block") def update_block( block_id: str, - updated_block: BlockUpdate = Body(...), + update_block: BlockUpdate = Body(...), server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): actor = server.get_user_or_default(user_id=user_id) - return server.block_manager.update_block(block_id=block_id, block_update=updated_block, actor=actor) + return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor) @router.delete("/{block_id}", response_model=Block, operation_id="delete_memory_block") @@ -64,8 +65,52 @@ def get_block( server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): + print("call get block", block_id) actor = server.get_user_or_default(user_id=user_id) try: - return server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + if block is None: + raise HTTPException(status_code=404, detail="Block not found") + return block except NoResultFound: raise HTTPException(status_code=404, detail="Block not found") + + +@router.patch("/{block_id}/attach", response_model=Block, operation_id="update_agent_memory_block") +def link_agent_memory_block( + block_id: str, + agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Link a memory block to an agent. + """ + actor = server.get_user_or_default(user_id=user_id) + + block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + if block is None: + raise HTTPException(status_code=404, detail="Block not found") + + server.blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block.label) + return block + + +@router.patch("/{block_id}/detach", response_model=Memory, operation_id="update_agent_memory_block") +def unlink_agent_memory_block( + block_id: str, + agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Unlink a memory block from an agent + """ + actor = server.get_user_or_default(user_id=user_id) + + block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + if block is None: + raise HTTPException(status_code=404, detail="Block not found") + # Link the block to the agent + server.blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id) + return block diff --git a/letta/server/server.py b/letta/server/server.py index 6d54f42339..832753115c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -18,24 +18,10 @@ from letta.credentials import LettaCredentials from letta.data_sources.connectors import DataConnector, load_data -# from letta.data_types import ( -# AgentState, -# EmbeddingConfig, -# LLMConfig, -# Message, -# Preset, -# Source, -# Token, -# User, -# ) -from letta.functions.functions import generate_schema, parse_source_code -from letta.functions.schema_generator import generate_schema - # TODO use custom interface from letta.interface import AgentInterface # abstract from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger -from letta.memory import get_memory_functions from letta.metadata import MetadataStore from letta.o1_agent import O1Agent from letta.orm import Base @@ -54,8 +40,15 @@ VLLMChatCompletionsProvider, VLLMCompletionsProvider, ) -from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState +from letta.schemas.agent import ( + AgentState, + AgentType, + CreateAgent, + PersistedAgentState, + UpdateAgentState, +) from letta.schemas.api_key import APIKey, APIKeyCreate +from letta.schemas.block import Block, BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig # openai schemas @@ -87,9 +80,6 @@ from letta.services.user_manager import UserManager from letta.utils import create_random_username, json_dumps, json_loads -# from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin - - logger = get_logger(__name__) @@ -129,10 +119,11 @@ def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_conte @abstractmethod def create_agent( self, - user_id: str, - agent_config: Union[dict, AgentState], - interface: Union[AgentInterface, None], - ) -> str: + request: CreateAgent, + actor: User, + # interface + interface: Union[AgentInterface, None] = None, + ) -> AgentState: """Create a new agent using a config""" raise NotImplementedError @@ -254,8 +245,8 @@ def __init__( self.block_manager = BlockManager() self.source_manager = SourceManager() self.agents_tags_manager = AgentsTagsManager() - self.blocks_agents_manager = BlocksAgentsManager() self.sandbox_config_manager = SandboxConfigManager(tool_settings) + self.blocks_agents_manager = BlocksAgentsManager() # Managers that interface with parallelism self.per_agent_lock_manager = PerAgentLockManager() @@ -372,92 +363,23 @@ def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None: } ) - def _load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: - """Loads a saved agent into memory (if it doesn't exist, throw an error)""" - assert isinstance(agent_id, str), agent_id - user_id = actor.id + def load_agent(self, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: + """Updated method to load agents from persisted storage""" + agent_state = self.get_agent(agent_id=agent_id) + actor = self.user_manager.get_user_by_id(user_id=agent_state.user_id) - # If an interface isn't specified, use the default - if interface is None: - interface = self.default_interface_factory() - - try: - logger.debug(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database") - agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) - if not agent_state: - logger.exception(f"agent_id {agent_id} does not exist") - raise ValueError(f"agent_id {agent_id} does not exist") - - # Instantiate an agent object using the state retrieved - logger.debug(f"Creating an agent object") - tool_objs = [] - for name in agent_state.tools: - # TODO: This should be a hard failure, but for migration reasons, we patch it for now - tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor) - if tool_obj: - tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor) - tool_objs.append(tool_obj) - else: - warnings.warn(f"Tried to retrieve a tool with name {name} from the agent_state, but does not exist in tool db.") - - # set agent_state tools to only the names of the available tools - agent_state.tools = [t.name for t in tool_objs] - - # Make sure the memory is a memory object - assert isinstance(agent_state.memory, Memory) - - if agent_state.agent_type == AgentType.memgpt_agent: - letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs, user=actor) - elif agent_state.agent_type == AgentType.o1_agent: - letta_agent = O1Agent(agent_state=agent_state, interface=interface, tools=tool_objs, user=actor) - else: - raise NotImplementedError("Not a supported agent type") - - # Add the agent to the in-memory store and return its reference - logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}") - self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=letta_agent) - return letta_agent - - except Exception as e: - logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}") - raise - - def _get_or_load_agent(self, agent_id: str, caching: bool = True) -> Agent: - """Check if the agent is in-memory, then load""" - - # Gets the agent state - agent_state = self.ms.get_agent(agent_id=agent_id) - if not agent_state: - raise ValueError(f"Agent does not exist") - user_id = agent_state.user_id - actor = self.user_manager.get_user_by_id(user_id) - - logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}") - if caching: - # TODO: consider disabling loading cached agents due to potential concurrency issues - letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) - if not letta_agent: - logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") - letta_agent = self._load_agent(agent_id=agent_id, actor=actor) + interface = interface or self.default_interface_factory() + if agent_state.agent_type == AgentType.memgpt_agent: + return Agent(agent_state=agent_state, interface=interface, user=actor) else: - # This breaks unit tests in test_local_client.py - letta_agent = self._load_agent(agent_id=agent_id, actor=actor) - - # letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) - # if not letta_agent: - # logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") - - # NOTE: no longer caching, always forcing a lot from the database - # Loads the agent objects - # letta_agent = self._load_agent(agent_id=agent_id, actor=actor) - - return letta_agent + return O1Agent(agent_state=agent_state, interface=interface, user=actor) def _step( self, user_id: str, agent_id: str, input_messages: Union[Message, List[Message]], + interface: Union[AgentInterface, None] = None, # needed to getting responses # timestamp: Optional[datetime], ) -> LettaUsageStatistics: """Send the input message through the agent""" @@ -473,7 +395,8 @@ def _step( try: # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + # letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id, interface=interface) if letta_agent is None: raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") @@ -490,6 +413,9 @@ def _step( skip_verify=True, ) + # save agent after step + save_agent(letta_agent, self.ms) + except Exception as e: logger.error(f"Error in server._step: {e}") print(traceback.print_exc()) @@ -507,7 +433,7 @@ def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStati logger.debug(f"Got command: {command}") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) usage = None if command.lower() == "exit": @@ -544,7 +470,7 @@ def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStati elif command.lower() == "memory": ret_str = ( f"\nDumping memory contents:\n" - + f"\n{str(letta_agent.memory)}" + + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.persistence_manager.archival_memory)}" + f"\n{str(letta_agent.persistence_manager.recall_memory)}" ) @@ -736,6 +662,7 @@ def send_messages( # whether or not to wrap user and system message as MemGPT-style stringified JSON wrap_user_message: bool = True, wrap_system_message: bool = True, + interface: Union[AgentInterface, None] = None, # needed to getting responses ) -> LettaUsageStatistics: """Send a list of messages to the agent @@ -788,7 +715,7 @@ def send_messages( raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}") # Run the agent state forward - return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects) + return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects, interface=interface) # @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: @@ -836,129 +763,109 @@ def create_agent( else: raise ValueError(f"Invalid agent type: {request.agent_type}") + # create blocks (note: cannot be linked into the agent_id is created) + blocks = [] + for create_block in request.memory_blocks: + block = self.block_manager.create_or_update_block(Block(**create_block.model_dump()), actor=actor) + blocks.append(block) + + # get tools + only add if they exist + tool_objs = [] + if request.tools: + for tool_name in request.tools: + tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + if tool_obj: + tool_objs.append(tool_obj) + else: + warnings.warn(f"Attempted to add a nonexistent tool {tool_name} to agent {request.name}, skipping.") + # reset the request.tools to only valid tools + request.tools = [t.name for t in tool_objs] + + # get the user logger.debug(f"Attempting to find user: {user_id}") user = self.user_manager.get_user_by_id(user_id=user_id) if not user: raise ValueError(f"cannot find user with associated client id: {user_id}") - try: - # model configuration - llm_config = request.llm_config - embedding_config = request.embedding_config - - # get tools + only add if they exist - tool_objs = [] - if request.tools: - for tool_name in request.tools: - tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) - if tool_obj: - tool_objs.append(tool_obj) - else: - warnings.warn(f"Attempted to add a nonexistent tool {tool_name} to agent {request.name}, skipping.") - # reset the request.tools to only valid tools - request.tools = [t.name for t in tool_objs] - - assert request.memory is not None - memory_functions = get_memory_functions(request.memory) - for func_name, func in memory_functions.items(): - - if request.tools and func_name in request.tools: - # tool already added - continue - source_code = parse_source_code(func) - # memory functions are not terminal - json_schema = generate_schema(func, name=func_name) - source_type = "python" - tags = ["memory", "memgpt-base"] - tool = self.tool_manager.create_or_update_tool( - Tool( - source_code=source_code, - source_type=source_type, - tags=tags, - json_schema=json_schema, - ), - actor=actor, - ) - tool_objs.append(tool) - if not request.tools: - request.tools = [] - request.tools.append(tool.name) - - # TODO: save the agent state - agent_state = AgentState( - name=request.name, - user_id=user_id, - tools=request.tools if request.tools else [], - tool_rules=request.tool_rules if request.tool_rules else [], - agent_type=request.agent_type or AgentType.memgpt_agent, - llm_config=llm_config, - embedding_config=embedding_config, - system=request.system, - memory=request.memory, - description=request.description, - metadata_=request.metadata_, - tags=request.tags, - ) - if request.agent_type == AgentType.memgpt_agent: - agent = Agent( - interface=interface, - agent_state=agent_state, - tools=tool_objs, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=( - True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False - ), - user=actor, - initial_message_sequence=request.initial_message_sequence, - ) - elif request.agent_type == AgentType.o1_agent: - agent = O1Agent( - interface=interface, - agent_state=agent_state, - tools=tool_objs, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=( - True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False - ), - user=actor, - ) - # rebuilding agent memory on agent create in case shared memory blocks - # were specified in the new agent's memory config. we're doing this for two reasons: - # 1. if only the ID of the shared memory block was specified, we can fetch its most recent value - # 2. if the shared block state changed since this agent initialization started, we can be sure to have the latest value - agent.rebuild_memory(force=True, ms=self.ms) - # FIXME: this is a hacky way to get the system prompts injected into agent into the DB - # self.ms.update_agent(agent.agent_state) - except Exception as e: - logger.exception(e) - try: - if agent: - self.ms.delete_agent(agent_id=agent.agent_state.id, per_agent_lock_manager=self.per_agent_lock_manager) - except Exception as delete_e: - logger.exception(f"Failed to delete_agent:\n{delete_e}") - raise e + # TODO: create the message objects (NOTE: do this after we migrate to `CreateMessage`) - # save agent - save_agent(agent, self.ms) - logger.debug(f"Created new agent from config: {agent}") + # created and persist the agent state in the DB + agent_state = PersistedAgentState( + name=request.name, + user_id=user_id, + tool_names=request.tools if request.tools else [], + tool_rules=request.tool_rules, + agent_type=request.agent_type or AgentType.memgpt_agent, + llm_config=request.llm_config, + embedding_config=request.embedding_config, + system=request.system, + # other metadata + description=request.description, + metadata_=request.metadata_, + ) + # TODO: move this to agent ORM + # this saves the agent ID and state into the DB + self.ms.create_agent(agent_state) + + # Note: mappings (e.g. tags, blocks) are created after the agent is persisted + # TODO: add source mappings here as well - # TODO: move this into save_agent. save_agent should be moved to server.py + # create the tags if request.tags: for tag in request.tags: - self.agents_tags_manager.add_tag_to_agent(agent_id=agent.agent_state.id, tag=tag, actor=actor) + self.agents_tags_manager.add_tag_to_agent(agent_id=agent_state.id, tag=tag, actor=actor) - assert isinstance(agent.agent_state.memory, Memory), f"Invalid memory type: {type(agent_state.memory)}" + # create block mappins (now that agent is persisted) + for block in blocks: + # this links the created block to the agent + self.blocks_agents_manager.add_block_to_agent(block_id=block.id, agent_id=agent_state.id, block_label=block.label) - # TODO: remove (hacky) - agent.agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent.agent_state.id, actor=actor) + in_memory_agent_state = self.get_agent(agent_state.id) + return in_memory_agent_state + + def get_agent(self, agent_id: str) -> AgentState: + """ + Retrieve the full agent state from the DB. + This gathers data accross multiple tables to provide the full state of an agent, which is passed into the `Agent` object for creation. + """ + + # get data persisted from the DB + agent_state = self.ms.get_agent(agent_id=agent_id) + if agent_state is None: + # agent does not exist + return None + user = self.user_manager.get_user_by_id(user_id=agent_state.user_id) - return agent.agent_state + # construct the in-memory, full agent state - this gather data stored in different tables but that needs to be passed to `Agent` + # we also return this data to the user to provide all the state related to an agent + + # get `Memory` object by getting the linked block IDs and fetching the blocks, then putting that into a `Memory` object + # this is the "in memory" representation of the in-context memory + block_ids = self.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id) + blocks = [] + for block_id in block_ids: + block = self.block_manager.get_block_by_id(block_id=block_id, actor=user) + assert block, f"Block with ID {block_id} does not exist" + blocks.append(block) + memory = Memory(blocks=blocks) + + # get `Tool` objects + tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names] + + # get `Source` objects + sources = self.list_attached_sources(agent_id=agent_id) + + # get the tags + tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user) + + # return the full agent state - this contains all data needed to recreate the agent + return AgentState(**agent_state.model_dump(), memory=memory, tools=tools, sources=sources, tags=tags) def update_agent( self, request: UpdateAgentState, actor: User, - ): + ) -> AgentState: """Update the agents core memory block, return the new state""" try: self.user_manager.get_user_by_id(user_id=actor.id) @@ -969,13 +876,7 @@ def update_agent( raise ValueError(f"Agent agent_id={request.id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=request.id) - - # update the core memory of the agent - if request.memory: - assert isinstance(request.memory, Memory), type(request.memory) - new_memory_contents = request.memory.to_flat_dict() - _ = self.update_agent_core_memory(user_id=actor.id, agent_id=request.id, new_memory_contents=new_memory_contents) + letta_agent = self.load_agent(agent_id=request.id) # update the system prompt if request.system: @@ -989,13 +890,13 @@ def update_agent( letta_agent.set_message_buffer(message_ids=request.message_ids) # tools - if request.tools: + if request.tool_names: # Replace tools and also re-link # (1) get tools + make sure they exist # Current and target tools as sets of tool names - current_tools = set(letta_agent.agent_state.tools) - target_tools = set(request.tools) + current_tools = set(letta_agent.agent_state.tool_names) + target_tools = set(request.tool_names) # Calculate tools to add and remove tools_to_add = target_tools - current_tools @@ -1012,7 +913,7 @@ def update_agent( self.add_tool_to_agent(agent_id=request.id, tool_id=tool.id, user_id=actor.id) # reload agent - letta_agent = self._get_or_load_agent(agent_id=request.id) + letta_agent = self.load_agent(agent_id=request.id) # configs if request.llm_config: @@ -1040,7 +941,6 @@ def update_agent( self.agents_tags_manager.delete_tag_from_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor) # save the agent - assert isinstance(letta_agent.memory, Memory) save_agent(letta_agent, self.ms) # TODO: probably reload the agent somehow? return letta_agent.agent_state @@ -1053,8 +953,8 @@ def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[To raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent.tools + letta_agent = self.load_agent(agent_id=agent_id) + return letta_agent.agent_state.tools def add_tool_to_agent( self, @@ -1072,7 +972,7 @@ def add_tool_to_agent( raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # Get all the tool objects from the request tool_objs = [] @@ -1080,7 +980,7 @@ def add_tool_to_agent( assert tool_obj, f"Tool with id={tool_id} does not exist" tool_objs.append(tool_obj) - for tool in letta_agent.tools: + for tool in letta_agent.agent_state.tools: tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) assert tool_obj, f"Tool with id={tool.id} does not exist" @@ -1089,7 +989,7 @@ def add_tool_to_agent( tool_objs.append(tool_obj) # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tools = [tool.name for tool in tool_objs] + letta_agent.agent_state.tool_names = [tool.name for tool in tool_objs] # then attempt to link the tools modules letta_agent.link_tools(tool_objs) @@ -1114,11 +1014,11 @@ def remove_tool_from_agent( raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # Get all the tool_objs tool_objs = [] - for tool in letta_agent.tools: + for tool in letta_agent.agent_state.tools: tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) assert tool_obj, f"Tool with id={tool.id} does not exist" @@ -1127,7 +1027,7 @@ def remove_tool_from_agent( tool_objs.append(tool_obj) # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tools = [tool.name for tool in tool_objs] + letta_agent.agent_state.tool_names = [tool.name for tool in tool_objs] # then attempt to link the tools modules letta_agent.link_tools(tool_objs) @@ -1136,18 +1036,9 @@ def remove_tool_from_agent( save_agent(letta_agent, self.ms) return letta_agent.agent_state - def _agent_state_to_config(self, agent_state: AgentState) -> dict: - """Convert AgentState to a dict for a JSON response""" - assert agent_state is not None - - agent_config = { - "id": agent_state.id, - "name": agent_state.name, - "human": agent_state._metadata.get("human", None), - "persona": agent_state._metadata.get("persona", None), - "created_at": agent_state.created_at.isoformat(), - } - return agent_config + def get_agent_state(self, user_id: str, agent_id: str) -> AgentState: + # TODO: duplicate, remove + return self.get_agent(agent_id=agent_id) def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[AgentState]: """List all available agents to a user""" @@ -1155,13 +1046,13 @@ def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[Ag if tags is None: agents_states = self.ms.list_agents(user_id=user_id) - return agents_states + agent_ids = [agent.id for agent in agents_states] else: agent_ids = [] for tag in tags: agent_ids += self.agents_tags_manager.get_agents_by_tag(tag=tag, actor=user) - return [self.get_agent_state(user_id=user.id, agent_id=agent_id) for agent_id in agent_ids] + return [self.get_agent(agent_id=agent_id) for agent_id in agent_ids] # convert name->id @@ -1185,34 +1076,34 @@ def get_source_id(self, source_name: str, user_id: str) -> str: def get_agent_memory(self, agent_id: str) -> Memory: """Return the memory of an agent (core memory)""" - agent = self._get_or_load_agent(agent_id=agent_id) - return agent.memory + agent = self.load_agent(agent_id=agent_id) + return agent.agent_state.memory def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: - agent = self._get_or_load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id) return ArchivalMemorySummary(size=len(agent.persistence_manager.archival_memory)) def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: - agent = self._get_or_load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id) return RecallMemorySummary(size=len(agent.persistence_manager.recall_memory)) def get_in_context_message_ids(self, agent_id: str) -> List[str]: """Get the message ids of the in-context messages in the agent's memory""" # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return [m.id for m in letta_agent._messages] + agent = self.load_agent(agent_id=agent_id) + return [m.id for m in agent._messages] def get_in_context_messages(self, agent_id: str) -> List[Message]: """Get the in-context messages in the agent's memory""" # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent._messages + agent = self.load_agent(agent_id=agent_id) + return agent._messages def get_agent_message(self, agent_id: str, message_id: str) -> Message: """Get a single message from the agent's memory""" # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - message = letta_agent.persistence_manager.recall_memory.storage.get(id=message_id) + agent = self.load_agent(agent_id=agent_id) + message = agent.persistence_manager.recall_memory.storage.get(id=message_id) return message def get_agent_messages( @@ -1223,7 +1114,7 @@ def get_agent_messages( ) -> Union[List[Message], List[LettaMessage]]: """Paginated query of all messages in agent message queue""" # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) if start < 0 or count < 0: raise ValueError("Start and count values should be non-negative") @@ -1241,10 +1132,6 @@ def get_agent_messages( # Slice the list for pagination messages = reversed_messages[start:end_index] - ## Convert to json - ## Add a tag indicating in-context or not - # json_messages = [{**record.to_json(), "in_context": True} for record in messages] - else: # need to access persistence manager for additional messages db_iterator = letta_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start) @@ -1274,7 +1161,7 @@ def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # iterate over records db_iterator = letta_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start) @@ -1299,7 +1186,7 @@ def get_agent_archival_cursor( raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # iterate over recorde cursor, records = letta_agent.persistence_manager.archival_memory.storage.get_all_cursor( @@ -1314,11 +1201,15 @@ def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: s raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # Insert into archival memory passage_ids = letta_agent.persistence_manager.archival_memory.insert(memory_string=memory_contents, return_ids=True) + # Update the agent + # TODO: should this update the system prompt? + save_agent(letta_agent, self.ms) + # TODO: this is gross, fix return [letta_agent.persistence_manager.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids] @@ -1331,7 +1222,7 @@ def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str): # TODO: should return a passage # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # Delete by ID # TODO check if it exists first, and throw error if not @@ -1359,7 +1250,7 @@ def get_agent_recall_cursor( raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) # iterate over records cursor, records = letta_agent.persistence_manager.recall_memory.storage.get_all_cursor( @@ -1384,28 +1275,6 @@ def get_agent_recall_cursor( return records - def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]: - """Return the config of an agent""" - user = self.user_manager.get_user_by_id(user_id=user_id) - if agent_id: - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - return None - else: - agent_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id) - if agent_state is None: - raise ValueError(f"Agent agent_name={agent_name} does not exist") - agent_id = agent_state.id - - # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - assert isinstance(letta_agent.memory, Memory) - - letta_agent.update_memory_blocks_from_db() - agent_state = letta_agent.agent_state.model_copy(deep=True) - # Load the tags in for the agent_state - agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user) - return agent_state - def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" @@ -1429,39 +1298,23 @@ def clean_keys(config): return response - def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> Memory: - """Update the agents core memory block, return the new state""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + def update_agent_core_memory(self, user_id: str, agent_id: str, label: str, value: str) -> Memory: + """Update the value of a block in the agent's memory""" - # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) - - # old_core_memory = self.get_agent_memory(agent_id=agent_id) - - modified = False - for key, value in new_memory_contents.items(): - if letta_agent.memory.get_block(key) is None: - # raise ValueError(f"Key {key} not found in agent memory {list(letta_agent.memory.list_block_names())}") - raise ValueError(f"Key {key} not found in agent memory {str(letta_agent.memory.memory)}") - if value is None: - continue - if letta_agent.memory.get_block(key) != value: - letta_agent.memory.update_block_value(label=key, value=value) # update agent memory - modified = True - - # If we modified the memory contents, we need to rebuild the memory block inside the system message - if modified: - letta_agent.rebuild_memory() - # letta_agent.rebuild_memory(force=True, ms=self.ms) # This breaks unit tests in test_local_client.py - # save agent - save_agent(letta_agent, self.ms) + # get the block id + block = self.get_agent_block_by_label(user_id=user_id, agent_id=agent_id, label=label) + block_id = block.id - return self.ms.get_agent(agent_id=agent_id).memory + # update the block + self.block_manager.update_block( + block_id=block_id, block_update=BlockUpdate(value=value), actor=self.user_manager.get_user_by_id(user_id=user_id) + ) - def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> AgentState: + # load agent + letta_agent = self.load_agent(agent_id=agent_id) + return letta_agent.agent_state.memory + + def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> PersistedAgentState: """Update the name of the agent in the database""" if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -1469,7 +1322,7 @@ def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> Agen raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) current_name = letta_agent.agent_state.name if current_name == new_agent_name: @@ -1491,6 +1344,7 @@ def delete_agent(self, user_id: str, agent_id: str): # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL # TODO: EVENTUALLY WE GET AUTO-DELETES WHEN WE SPECIFY RELATIONSHIPS IN THE ORM self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor) + self.blocks_agents_manager.remove_all_agent_blocks(agent_id=agent_id) if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1658,9 +1512,10 @@ def attach_source_to_agent( raise ValueError(f"Need to provide at least source_id or source_name to find the source.") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + assert data_source, f"Data source with id={source_id} or name={source_name} does not exist" # load agent - agent = self._get_or_load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id) # attach source to agent agent.attach_source(data_source.id, source_connector, self.ms) @@ -1685,7 +1540,7 @@ def detach_source_from_agent( source_id = source.id # delete all Passage objects with source_id==source_id from agent's archival memory - agent = self._get_or_load_agent(agent_id=agent_id) + agent = self.load_agent(agent_id=agent_id) archival_memory = agent.persistence_manager.archival_memory archival_memory.storage.delete({"source_id": source_id}) @@ -1764,34 +1619,43 @@ def add_default_external_tools(self, actor: User) -> bool: def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]: """Get a single message from the agent's memory""" # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) message = letta_agent.persistence_manager.recall_memory.storage.get(id=message_id) + save_agent(letta_agent, self.ms) return message def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message: """Update the details of a message associated with an agent""" # Get the current message - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent.update_message(request=request) + letta_agent = self.load_agent(agent_id=agent_id) + response = letta_agent.update_message(request=request) + save_agent(letta_agent, self.ms) + return response def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message: # Get the current message - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent.rewrite_message(new_text=new_text) + letta_agent = self.load_agent(agent_id=agent_id) + response = letta_agent.rewrite_message(new_text=new_text) + save_agent(letta_agent, self.ms) + return response def rethink_agent_message(self, agent_id: str, new_thought: str) -> Message: # Get the current message - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent.rethink_message(new_thought=new_thought) + letta_agent = self.load_agent(agent_id=agent_id) + response = letta_agent.rethink_message(new_thought=new_thought) + save_agent(letta_agent, self.ms) + return response def retry_agent_message(self, agent_id: str) -> List[Message]: # Get the current message - letta_agent = self._get_or_load_agent(agent_id=agent_id) - return letta_agent.retry_message() + letta_agent = self.load_agent(agent_id=agent_id) + response = letta_agent.retry_message() + save_agent(letta_agent, self.ms) + return response def get_user_or_default(self, user_id: Optional[str]) -> User: """Get the user object for user_id if it exists, otherwise return the default user object""" @@ -1840,121 +1704,49 @@ def get_agent_context_window( agent_id: str, ) -> ContextWindowOverview: # Get the current message - letta_agent = self._get_or_load_agent(agent_id=agent_id) + letta_agent = self.load_agent(agent_id=agent_id) return letta_agent.get_context_window() - def update_agent_memory_label(self, user_id: str, agent_id: str, current_block_label: str, new_block_label: str) -> Memory: - """Update the label of a block in an agent's memory""" - - # Get the user - user = self.user_manager.get_user_by_id(user_id=user_id) - - # Link a block to an agent's memory - letta_agent = self._get_or_load_agent(agent_id=agent_id) - letta_agent.memory.update_block_label(current_label=current_block_label, new_label=new_block_label) - assert new_block_label in letta_agent.memory.list_block_labels() - self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(new_block_label), actor=user) - - # check that the block was updated - updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(new_block_label).id, actor=user) - - # Recompile the agent memory - letta_agent.rebuild_memory(force=True, ms=self.ms) - - # save agent - save_agent(letta_agent, self.ms) - - updated_agent = self.ms.get_agent(agent_id=agent_id) - if updated_agent is None: - raise ValueError(f"Agent with id {agent_id} not found after linking block") - assert new_block_label in updated_agent.memory.list_block_labels() - assert current_block_label not in updated_agent.memory.list_block_labels() - return updated_agent.memory - def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str) -> Memory: """Link a block to an agent's memory""" - - # Get the user - user = self.user_manager.get_user_by_id(user_id=user_id) - - # Get the block first - block = self.block_manager.get_block_by_id(block_id=block_id, actor=user) + block = self.block_manager.get_block_by_id(block_id=block_id, actor=self.user_manager.get_user_by_id(user_id=user_id)) if block is None: raise ValueError(f"Block with id {block_id} not found") + self.blocks_agents_manager.add_block_to_agent(agent_id, block_id, block_label=block.label) - # Link a block to an agent's memory - letta_agent = self._get_or_load_agent(agent_id=agent_id) - letta_agent.memory.link_block(block=block) - assert block.label in letta_agent.memory.list_block_labels() - - # Recompile the agent memory - letta_agent.rebuild_memory(force=True, ms=self.ms) - - # save agent - save_agent(letta_agent, self.ms) - - updated_agent = self.ms.get_agent(agent_id=agent_id) - if updated_agent is None: - raise ValueError(f"Agent with id {agent_id} not found after linking block") - assert block.label in updated_agent.memory.list_block_labels() - - return updated_agent.memory + # get agent memory + memory = self.load_agent(agent_id=agent_id).agent_state.memory + return memory def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory: """Unlink a block from an agent's memory. If the block is not linked to any agent, delete it.""" + self.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=agent_id, block_label=block_label) - # Get the user - user = self.user_manager.get_user_by_id(user_id=user_id) - - # Link a block to an agent's memory - letta_agent = self._get_or_load_agent(agent_id=agent_id) - unlinked_block = letta_agent.memory.unlink_block(block_label=block_label) - assert unlinked_block.label not in letta_agent.memory.list_block_labels() - - # Check if the block is linked to any other agent - # TODO needs reference counting GC to handle loose blocks - # block = self.block_manager.get_block_by_id(block_id=unlinked_block.id, actor=user) - # if block is None: - # raise ValueError(f"Block with id {block_id} not found") - - # Recompile the agent memory - letta_agent.rebuild_memory(force=True, ms=self.ms) - - # save agent - save_agent(letta_agent, self.ms) - - updated_agent = self.ms.get_agent(agent_id=agent_id) - if updated_agent is None: - raise ValueError(f"Agent with id {agent_id} not found after linking block") - assert unlinked_block.label not in updated_agent.memory.list_block_labels() - return updated_agent.memory + # get agent memory + memory = self.load_agent(agent_id=agent_id).agent_state.memory + return memory def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory: """Update the limit of a block in an agent's memory""" + block = self.get_agent_block_by_label(user_id=user_id, agent_id=agent_id, label=block_label) + self.block_manager.update_block( + block_id=block.id, block_update=BlockUpdate(limit=limit), actor=self.user_manager.get_user_by_id(user_id=user_id) + ) + # get agent memory + memory = self.load_agent(agent_id=agent_id).agent_state.memory + return memory + + def upate_block(self, user_id: str, block_id: str, block_update: BlockUpdate) -> Block: + """Update a block""" + return self.block_manager.update_block( + block_id=block_id, block_update=block_update, actor=self.user_manager.get_user_by_id(user_id=user_id) + ) - # Get the user - user = self.user_manager.get_user_by_id(user_id=user_id) - - # Link a block to an agent's memory - letta_agent = self._get_or_load_agent(agent_id=agent_id) - letta_agent.memory.update_block_limit(label=block_label, limit=limit) - assert block_label in letta_agent.memory.list_block_labels() - - # write out the update the database - self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(block_label), actor=user) - - # check that the block was updated - updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(block_label).id, actor=user) - assert updated_block and updated_block.limit == limit - - # Recompile the agent memory - letta_agent.rebuild_memory(force=True, ms=self.ms) - - # save agent - save_agent(letta_agent, self.ms) - - updated_agent = self.ms.get_agent(agent_id=agent_id) - if updated_agent is None: - raise ValueError(f"Agent with id {agent_id} not found after linking block") - assert updated_agent.memory.get_block(label=block_label).limit == limit - return updated_agent.memory + def get_agent_block_by_label(self, user_id: str, agent_id: str, label: str) -> Block: + """Get a block by label""" + # TODO: implement at ORM? + for block_id in self.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id): + block = self.block_manager.get_block_by_id(block_id=block_id, actor=self.user_manager.get_user_by_id(user_id=user_id)) + if block.label == label: + return block + return None diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index dcae5f5cd6..ac6e42b861 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -29,10 +29,8 @@ def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticB self.update_block(block.id, update_data, actor) else: with self.session_maker() as session: - # Always write the organization_id - block.organization_id = actor.organization_id data = block.model_dump(exclude_none=True) - block = BlockModel(**data) + block = BlockModel(**data, organization_id=actor.organization_id) block.create(session, actor=actor) return block.to_pydantic() @@ -53,6 +51,11 @@ def update_block(self, block_id: str, block_update: BlockUpdate, actor: Pydantic update_data = block_update.model_dump(exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(block, key, value) + try: + block.to_pydantic() + except Exception as e: + # invalid pydantic model + raise ValueError(f"Failed to create pydantic model: {e}") block.update(db_session=session, actor=actor) # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents diff --git a/letta/services/blocks_agents_manager.py b/letta/services/blocks_agents_manager.py index 586a581aa1..121db58649 100644 --- a/letta/services/blocks_agents_manager.py +++ b/letta/services/blocks_agents_manager.py @@ -89,3 +89,18 @@ def list_agent_ids_with_block(self, block_id: str) -> List[str]: with self.session_maker() as session: blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id) return [record.agent_id for record in blocks_agents_record] + + @enforce_types + def get_block_id_for_label(self, agent_id: str, block_label: str) -> str: + """Get the block ID for a specific block label for an agent.""" + with self.session_maker() as session: + try: + blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) + return blocks_agents_record.block_id + except NoResultFound: + raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") + + @enforce_types + def remove_all_agent_blocks(self, agent_id: str): + for block_id in self.list_block_ids_for_agent(agent_id): + self.remove_block_with_id_from_agent(agent_id, block_id) diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 905e63e088..5f2b428a1f 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -154,7 +154,7 @@ def run_e2b_sandbox(self, code: str) -> Optional[SandboxRunResult]: env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) execution = sbx.run_code(code, envs=env_vars) if execution.error is not None: - raise Exception(f"Executing tool {self.tool_name} failed with {execution.error}. Generated code: \n\n{code}") + raise Exception(f"Executing tool {self.tool_name} failed with {execution.error}") elif len(execution.results) == 0: return None else: diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 28bcabf3df..7acf4fa52f 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -24,6 +24,7 @@ class ToolManager: "archival_memory_insert", "archival_memory_search", ] + BASE_MEMORY_TOOL_NAMES = ["core_memory_append", "core_memory_replace"] def __init__(self): # Fetching the db_context similarly as in OrganizationManager @@ -155,7 +156,7 @@ def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: # create tool in db tools = [] for name, schema in functions_to_schema.items(): - if name in self.BASE_TOOL_NAMES: + if name in self.BASE_TOOL_NAMES + self.BASE_MEMORY_TOOL_NAMES: # print([str(inspect.getsource(line)) for line in schema["imports"]]) source_code = inspect.getsource(schema["python_function"]) tags = [module_name] diff --git a/locust_test.py b/locust_test.py index 570e6eef47..1e74d405ab 100644 --- a/locust_test.py +++ b/locust_test.py @@ -4,7 +4,7 @@ from locust import HttpUser, between, task from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA -from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.agent import CreateAgent, PersistedAgentState from letta.schemas.letta_request import LettaRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ChatMemory @@ -49,7 +49,7 @@ def on_start(self): response.failure(f"Failed to create agent: {response.text}") response_json = response.json() - agent_state = AgentState(**response_json) + agent_state = PersistedAgentState(**response_json) self.agent_id = agent_state.id print("Created agent", self.agent_id, agent_state.name) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 42575c79bc..37c2da18be 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -103,8 +103,9 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet cleanup(client=client, agent_uuid=agent_uuid) agent_state = setup_agent(client, filename) - tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] - agent = Agent(interface=None, tools=tools, agent_state=agent_state, user=client.user) + tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tool_names] + full_agent_state = client.get_agent(agent_state.id) + agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) response = create( llm_config=agent_state.llm_config, @@ -169,7 +170,8 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse: # Set up client client = create_client() cleanup(client=client, agent_uuid=agent_uuid) - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) + # tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) + tool = client.load_composio_tool(action=Action.WEBTOOL_SCRAPE_WEBSITE_CONTENT) tool_name = tool.name # Set up persona for tool usage @@ -211,6 +213,8 @@ def check_agent_recall_chat_memory(filename: str) -> LettaResponse: human_name = "BananaBoy" agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name}") + print("MEMORY", agent_state.memory.get_block("human").value) + response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") # Basic checks @@ -332,7 +336,7 @@ def check_agent_summarize_memory_simple(filename: str) -> LettaResponse: client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?") # Summarize - agent = client.server._get_or_load_agent(agent_id=agent_state.id) + agent = client.server.load_agent(agent_id=agent_state.id) agent.summarize_messages_inplace() print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n") diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 622ef4b6e7..eeb71af5a8 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -45,9 +45,12 @@ def test_summarizer(config_filename): # Create agent agent_state = client.create_agent(name=agent_name, llm_config=llm_config, embedding_config=embedding_config) - tools = [client.get_tool(client.get_tool_id(name=tool_name)) for tool_name in agent_state.tools] + full_agent_state = client.get_agent(agent_id=agent_state.id) letta_agent = Agent( - interface=StreamingRefreshCLIInterface(), agent_state=agent_state, tools=tools, first_message_verify_mono=False, user=client.user + interface=StreamingRefreshCLIInterface(), + agent_state=full_agent_state, + first_message_verify_mono=False, + user=client.user, ) # Make conversation diff --git a/tests/test_agent_tool_graph.py b/tests/test_agent_tool_graph.py index 850ffb64b1..1d5dbcdcc8 100644 --- a/tests/test_agent_tool_graph.py +++ b/tests/test_agent_tool_graph.py @@ -4,7 +4,7 @@ from letta import create_client from letta.schemas.letta_message import FunctionCallMessage -from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.settings import tool_settings from tests.helpers.endpoints_helper import ( assert_invoked_function_call, @@ -107,10 +107,10 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none): # Make tool rules tool_rules = [ InitToolRule(tool_name="first_secret_word"), - ToolRule(tool_name="first_secret_word", children=["second_secret_word"]), - ToolRule(tool_name="second_secret_word", children=["third_secret_word"]), - ToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), - ToolRule(tool_name="fourth_secret_word", children=["send_message"]), + ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]), + ChildToolRule(tool_name="second_secret_word", children=["third_secret_word"]), + ChildToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), + ChildToolRule(tool_name="fourth_secret_word", children=["send_message"]), TerminalToolRule(tool_name="send_message"), ] diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 0ce4cf91ff..0668f2fd0f 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -23,7 +23,7 @@ def agent_obj(): agent_state = client.create_agent() global agent_obj - agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id) yield agent_obj client.delete_agent(agent_obj.agent_state.id) diff --git a/tests/test_client.py b/tests/test_client.py index 710d74aab2..57fd670e4a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ from letta import LocalClient, RESTClient, create_client from letta.orm import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState -from letta.schemas.block import BlockCreate +from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType @@ -199,7 +199,7 @@ def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent example_new_label = "example_new_label" assert example_new_label not in current_labels - client.update_agent_memory_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label) + client.update_agent_memory_block_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label) updated_agent = client.get_agent(agent_id=agent.id) assert example_new_label in updated_agent.memory.list_block_labels() @@ -222,7 +222,7 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a # Link a new memory block client.add_agent_memory_block( agent_id=agent.id, - create_block=BlockCreate( + create_block=CreateBlock( label=example_new_label, value=example_new_value, limit=1000, @@ -284,12 +284,12 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent # We expect this to throw a value error with pytest.raises(ValueError): - client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) # Now try the same thing with a higher limit example_new_limit = current_block_length + 10000 assert example_new_limit > current_block_length - client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit) + client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) updated_agent = client.get_agent(agent_id=agent.id) assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index c2c92e6474..96e27ab6ef 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -10,13 +10,12 @@ from sqlalchemy import delete from letta import create_client -from letta.agent import initialize_message_sequence from letta.client.client import LocalClient, RESTClient from letta.constants import DEFAULT_PRESET from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, MessageStreamStatus +from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, FunctionCallMessage, @@ -31,8 +30,7 @@ from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager -from letta.settings import model_settings -from letta.utils import get_utc_time +from letta.settings import model_settings, tool_settings from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -57,6 +55,21 @@ def run_server(): start_server(debug=True) +@pytest.fixture +def mock_e2b_api_key_none(): + # Store the original value of e2b_api_key + original_api_key = tool_settings.e2b_api_key + + # Set e2b_api_key to None + tool_settings.e2b_api_key = None + + # Yield control to the test + yield + + # Restore the original value of e2b_api_key + tool_settings.e2b_api_key = original_api_key + + # Fixture to create clients with different configurations @pytest.fixture( # params=[{"server": True}, {"server": False}], # whether to use REST API server @@ -107,7 +120,7 @@ def agent(client: Union[LocalClient, RESTClient]): client.delete_agent(agent_state.id) -def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): # test client.rename_agent new_name = "RenamedTestAgent" @@ -126,15 +139,15 @@ def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState): assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" -def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() memory_response = client.get_in_context_memory(agent_id=agent.id) print("MEMORY", memory_response.compile()) updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"} - client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"]) - client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"]) + client.update_agent_memory_block(agent_id=agent.id, label="human", value=updated_memory["human"]) + client.update_agent_memory_block(agent_id=agent.id, label="persona", value=updated_memory["persona"]) updated_memory_response = client.get_in_context_memory(agent_id=agent.id) assert ( updated_memory_response.get_block("human").value == updated_memory["human"] @@ -142,7 +155,7 @@ def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): ), "Memory update failed" -def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent_interactions(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): # test that it is a LettaMessage message = "Hello again, agent!" print("Sending message", message) @@ -165,7 +178,7 @@ def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: Agent # TODO: add streaming tests -def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_archival_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() memory_content = "Archival memory content" @@ -199,7 +212,7 @@ def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentSta client.get_archival_memory(agent.id) -def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_core_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") print("Response", response) @@ -207,7 +220,7 @@ def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState): assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" -def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_streaming_send_message(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): if isinstance(client, LocalClient): pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") assert isinstance(client, RESTClient), client @@ -317,7 +330,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): def test_list_tools(client: Union[LocalClient, RESTClient]): tools = client.add_base_tools() tool_names = [t.name for t in tools] - expected = ToolManager.BASE_TOOL_NAMES + expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES assert sorted(tool_names) == sorted(expected) @@ -562,21 +575,25 @@ def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool assert has_model_endpoint_type(models, "anthropic") -def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() # create a block block = client.create_block(label="human", value="username: sarah") # create agents with shared block + from letta.schemas.block import Block from letta.schemas.memory import BasicBlockMemory - persona1_block = client.create_block(label="persona", value="you are agent 1") - persona2_block = client.create_block(label="persona", value="you are agent 2") - + # persona1_block = client.create_block(label="persona", value="you are agent 1") + # persona2_block = client.create_block(label="persona", value="you are agent 2") # create agnets - agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block])) - agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory(blocks=[block, persona2_block])) + agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1"), block])) + agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2"), block])) + + ## attach shared block to both agents + # client.link_agent_memory_block(agent_state1.id, block.id) + # client.link_agent_memory_block(agent_state2.id, block.id) # update memory response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles") @@ -607,64 +624,66 @@ def cleanup_agents(): print(f"Failed to delete agent {agent_id}: {e}") -def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): - """Test that we can set an initial message sequence - - If we pass in None, we should get a "default" message sequence - If we pass in a non-empty list, we should get that sequence - If we pass in an empty list, we should get an empty sequence - """ - - # The reference initial message sequence: - reference_init_messages = initialize_message_sequence( - model=agent.llm_config.model, - system=agent.system, - memory=agent.memory, - archival_memory=None, - recall_memory=None, - memory_edit_timestamp=get_utc_time(), - include_initial_boot_message=True, - ) - - # system, login message, send_message test, send_message receipt - assert len(reference_init_messages) > 0 - assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" - - # Test with default sequence - default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) - cleanup_agents.append(default_agent_state.id) - assert default_agent_state.message_ids is not None - assert len(default_agent_state.message_ids) > 0 - assert len(default_agent_state.message_ids) == len( - reference_init_messages - ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" - - # Test with empty sequence - empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) - cleanup_agents.append(empty_agent_state.id) - assert empty_agent_state.message_ids is not None - assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}" - - # Test with custom sequence - custom_sequence = [ - Message( - role=MessageRole.user, - text="Hello, how are you?", - user_id=agent.user_id, - agent_id=agent.id, - model=agent.llm_config.model, - name=None, - tool_calls=None, - tool_call_id=None, - ), - ] - custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) - cleanup_agents.append(custom_agent_state.id) - assert custom_agent_state.message_ids is not None - assert ( - len(custom_agent_state.message_ids) == len(custom_sequence) + 1 - ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" - assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] +## NOTE: we need to add this back once agents can also create blocks during agent creation +# def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): +# """Test that we can set an initial message sequence +# +# If we pass in None, we should get a "default" message sequence +# If we pass in a non-empty list, we should get that sequence +# If we pass in an empty list, we should get an empty sequence +# """ +# +# # The reference initial message sequence: +# reference_init_messages = initialize_message_sequence( +# model=agent.llm_config.model, +# system=agent.system, +# memory=agent.memory, +# archival_memory=None, +# recall_memory=None, +# memory_edit_timestamp=get_utc_time(), +# include_initial_boot_message=True, +# ) +# +# # system, login message, send_message test, send_message receipt +# assert len(reference_init_messages) > 0 +# assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" +# +# # Test with default sequence +# default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) +# cleanup_agents.append(default_agent_state.id) +# assert default_agent_state.message_ids is not None +# assert len(default_agent_state.message_ids) > 0 +# assert len(default_agent_state.message_ids) == len( +# reference_init_messages +# ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" +# +# # Test with empty sequence +# empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) +# cleanup_agents.append(empty_agent_state.id) +# # NOTE: allowed to be None initially +# #assert empty_agent_state.message_ids is not None +# #assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}" +# +# # Test with custom sequence +# custom_sequence = [ +# Message( +# role=MessageRole.user, +# text="Hello, how are you?", +# user_id=agent.user_id, +# agent_id=agent.id, +# model=agent.llm_config.model, +# name=None, +# tool_calls=None, +# tool_call_id=None, +# ), +# ] +# custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) +# cleanup_agents.append(custom_agent_state.id) +# assert custom_agent_state.message_ids is not None +# assert ( +# len(custom_agent_state.message_ids) == len(custom_sequence) + 1 +# ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" +# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py index 0e5e895696..58748339e1 100644 --- a/tests/test_different_embedding_size.py +++ b/tests/test_different_embedding_size.py @@ -66,7 +66,7 @@ # # # openai: add passages # passages, openai_embeddings = generate_passages(client.user, openai_agent) -# openai_agent_run = client.server._get_or_load_agent(user_id=client.user.id, agent_id=openai_agent.id) +# openai_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=openai_agent.id) # openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) # # # create client @@ -84,7 +84,7 @@ # # # hosted: add passages # passages, hosted_embeddings = generate_passages(client.user, hosted_agent) -# hosted_agent_run = client.server._get_or_load_agent(user_id=client.user.id, agent_id=hosted_agent.id) +# hosted_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=hosted_agent.id) # hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) # # # test passage dimentionality diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 8b667106cb..74ab87afef 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -4,7 +4,7 @@ from letta import create_client from letta.client.client import LocalClient -from letta.schemas.agent import AgentState +from letta.schemas.agent import PersistedAgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory @@ -51,10 +51,10 @@ def test_agent(client: LocalClient): print("TOOLS", [t.name for t in tools]) agent_state = client.get_agent(agent_state_test.id) assert agent_state.name == "test_agent2" - for block in agent_state.memory.to_dict()["memory"].values(): - db_block = client.server.block_manager.get_block_by_id(block.get("id"), actor=client.user) + for block in agent_state.memory.blocks: + db_block = client.server.block_manager.get_block_by_id(block.id, actor=client.user) assert db_block is not None, "memory block not persisted on agent create" - assert db_block.value == block.get("value"), "persisted block data does not match in-memory data" + assert db_block.value == block.value, "persisted block data does not match in-memory data" assert isinstance(agent_state.memory, Memory) # update agent: name @@ -68,6 +68,8 @@ def test_agent(client: LocalClient): client.update_agent(agent_state_test.id, system=new_system_prompt) assert client.get_agent(agent_state_test.id).system == new_system_prompt + response = client.user_message(agent_id=agent_state_test.id, message="Hello") + agent_state = client.get_agent(agent_state_test.id) assert isinstance(agent_state.memory, Memory) # update agent: message_ids old_message_ids = agent_state.message_ids @@ -79,10 +81,10 @@ def test_agent(client: LocalClient): assert isinstance(agent_state.memory, Memory) # update agent: tools tool_to_delete = "send_message" - assert tool_to_delete in agent_state.tools - new_agent_tools = [t_name for t_name in agent_state.tools if t_name != tool_to_delete] + assert tool_to_delete in agent_state.tool_names + new_agent_tools = [t_name for t_name in agent_state.tool_names if t_name != tool_to_delete] client.update_agent(agent_state_test.id, tools=new_agent_tools) - assert client.get_agent(agent_state_test.id).tools == new_agent_tools + assert client.get_agent(agent_state_test.id).tool_names == new_agent_tools assert isinstance(agent_state.memory, Memory) # update agent: memory @@ -92,7 +94,10 @@ def test_agent(client: LocalClient): assert agent_state.memory.get_block("human").value != new_human assert agent_state.memory.get_block("persona").value != new_persona - client.update_agent(agent_state_test.id, memory=new_memory) + # client.update_agent(agent_state_test.id, memory=new_memory) + # update blocks: + client.update_agent_memory_block(agent_state_test.id, label="human", value=new_human) + client.update_agent_memory_block(agent_state_test.id, label="persona", value=new_persona) assert client.get_agent(agent_state_test.id).memory.get_block("human").value == new_human assert client.get_agent(agent_state_test.id).memory.get_block("persona").value == new_persona @@ -181,11 +186,6 @@ def test_agent_with_shared_blocks(client: LocalClient): ) assert isinstance(first_agent_state_test.memory, Memory) - first_blocks_dict = first_agent_state_test.memory.to_dict()["memory"] - assert persona_block.id == first_blocks_dict.get("persona", {}).get("id") - assert human_block.id == first_blocks_dict.get("human", {}).get("id") - client.update_in_context_memory(first_agent_state_test.id, section="human", value="I'm an analyst therapist.") - # when this agent is created with the shared block references this agent's in-memory blocks should # have this latest value set by the other agent. second_agent_state_test = client.create_agent( @@ -194,11 +194,21 @@ def test_agent_with_shared_blocks(client: LocalClient): description="This is a test agent using shared memory blocks", ) + first_memory = first_agent_state_test.memory + assert persona_block.id == first_memory.get_block("persona").id + assert human_block.id == first_memory.get_block("human").id + client.update_agent_memory_block(first_agent_state_test.id, label="human", value="I'm an analyst therapist.") + print("Updated human block value:", client.get_agent_memory_block(first_agent_state_test.id, label="human").value) + + # refresh agent state + second_agent_state_test = client.get_agent(second_agent_state_test.id) + assert isinstance(second_agent_state_test.memory, Memory) - second_blocks_dict = second_agent_state_test.memory.to_dict()["memory"] - assert persona_block.id == second_blocks_dict.get("persona", {}).get("id") - assert human_block.id == second_blocks_dict.get("human", {}).get("id") - assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist." + second_memory = second_agent_state_test.memory + assert persona_block.id == second_memory.get_block("persona").id + assert human_block.id == second_memory.get_block("human").id + # assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist." + assert second_memory.get_block("human").value == "I'm an analyst therapist." finally: if first_agent_state_test: @@ -207,7 +217,7 @@ def test_agent_with_shared_blocks(client: LocalClient): client.delete_agent(second_agent_state_test.id) -def test_memory(client: LocalClient, agent: AgentState): +def test_memory(client: LocalClient, agent: PersistedAgentState): # get agent memory original_memory = client.get_in_context_memory(agent.id) assert original_memory is not None @@ -220,7 +230,7 @@ def test_memory(client: LocalClient, agent: AgentState): assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated -def test_archival_memory(client: LocalClient, agent: AgentState): +def test_archival_memory(client: LocalClient, agent: PersistedAgentState): """Test functions for interacting with archival memory store""" # add archival memory @@ -235,7 +245,7 @@ def test_archival_memory(client: LocalClient, agent: AgentState): client.delete_archival_memory(agent.id, passage.id) -def test_recall_memory(client: LocalClient, agent: AgentState): +def test_recall_memory(client: LocalClient, agent: PersistedAgentState): """Test functions for interacting with recall memory store""" # send message to the agent @@ -390,13 +400,9 @@ def test_shared_blocks_without_send_message(client: LocalClient): memory=memory, ) - agent_1.memory.update_block_value(label="shared_memory", value="I am no longer an [empty] memory") - block_id = agent_1.memory.get_block("shared_memory").id - client.update_block(block_id, text="I am no longer an [empty] memory") - client.update_agent(agent_id=agent_1.id, memory=agent_1.memory) + client.update_block(block_id, value="I am no longer an [empty] memory") agent_1 = client.get_agent(agent_1.id) agent_2 = client.get_agent(agent_2.id) - client.update_agent(agent_id=agent_2.id, memory=agent_2.memory) assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" diff --git a/tests/test_managers.py b/tests/test_managers.py index 218296ec80..6375290817 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -22,11 +22,10 @@ ) from letta.schemas.agent import CreateAgent from letta.schemas.block import Block as PydanticBlock -from letta.schemas.block import BlockUpdate +from letta.schemas.block import BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.sandbox_config import ( E2BSandboxConfig, @@ -120,10 +119,8 @@ def sarah_agent(server: SyncServer, default_user, default_organization): agent_state = server.create_agent( request=CreateAgent( name="sarah_agent", - memory=ChatMemory( - human="Charles", - persona="I am a helpful assistant", - ), + # memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], + memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), @@ -138,10 +135,7 @@ def charles_agent(server: SyncServer, default_user, default_organization): agent_state = server.create_agent( request=CreateAgent( name="charles_agent", - memory=ChatMemory( - human="Sarah", - persona="I am a helpful assistant", - ), + memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), diff --git a/tests/test_memory.py b/tests/test_memory.py index ba00c98330..85e12e8014 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,7 +1,6 @@ import pytest # Import the classes here, assuming the above definitions are in a module named memory_module -from letta.schemas.block import Block from letta.schemas.memory import ChatMemory, Memory @@ -17,23 +16,6 @@ def test_create_chat_memory(): assert chat_memory.get_block("human").value == "User" -def test_dump_memory_as_json(sample_memory: Memory): - """Test dumping ChatMemory as JSON compatible dictionary""" - memory_dict = sample_memory.to_dict()["memory"] - assert isinstance(memory_dict, dict) - assert "persona" in memory_dict - assert memory_dict["persona"]["value"] == "Chat Agent" - - -def test_load_memory_from_json(sample_memory: Memory): - """Test loading ChatMemory from a JSON compatible dictionary""" - memory_dict = sample_memory.to_dict()["memory"] - print(memory_dict) - new_memory = Memory.load(memory_dict) - assert new_memory.get_block("persona").value == "Chat Agent" - assert new_memory.get_block("human").value == "User" - - def test_memory_limit_validation(sample_memory: Memory): """Test exceeding memory limit""" with pytest.raises(ValueError): @@ -43,30 +25,15 @@ def test_memory_limit_validation(sample_memory: Memory): sample_memory.get_block("persona").value = "x " * 10000 -def test_memory_jinja2_template_load(sample_memory: Memory): - """Test loading a memory with and without a jinja2 template""" - - # Test loading a memory with a template - memory_dict = sample_memory.to_dict() - memory_dict["prompt_template"] = sample_memory.get_prompt_template() - new_memory = Memory.load(memory_dict) - assert new_memory.get_prompt_template() == sample_memory.get_prompt_template() - - # Test loading a memory without a template (old format) - memory_dict = sample_memory.to_dict() - memory_dict_old_format = memory_dict["memory"] - new_memory = Memory.load(memory_dict_old_format) - assert new_memory.get_prompt_template() is not None # Ensure a default template is set - assert new_memory.to_dict()["memory"] == memory_dict_old_format - - def test_memory_jinja2_template(sample_memory: Memory): """Test to make sure the jinja2 template string is equivalent to the old __repr__ method""" def old_repr(self: Memory) -> str: """Generate a string representation of the memory in-context""" section_strs = [] - for section, module in self.memory.items(): + for block in sample_memory.get_blocks(): + section = block.label + module = block section_strs.append(f'<{section} characters="{len(module.value)}/{module.limit}">\n{module.value}\n') return "\n".join(section_strs) @@ -106,50 +73,3 @@ def test_memory_jinja2_set_template(sample_memory: Memory): ) with pytest.raises(ValueError): sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure) - - -def test_link_unlink_block(sample_memory: Memory): - """Test linking and unlinking a block to the memory""" - - # Link a new block - - test_new_label = "test_new_label" - test_new_value = "test_new_value" - test_new_block = Block(label=test_new_label, value=test_new_value, limit=2000) - - current_labels = sample_memory.list_block_labels() - assert test_new_label not in current_labels - - sample_memory.link_block(block=test_new_block) - assert test_new_label in sample_memory.list_block_labels() - assert sample_memory.get_block(test_new_label).value == test_new_value - - # Unlink the block - sample_memory.unlink_block(block_label=test_new_label) - assert test_new_label not in sample_memory.list_block_labels() - - -def test_update_block_label(sample_memory: Memory): - """Test updating the label of a block""" - - test_new_label = "test_new_label" - current_labels = sample_memory.list_block_labels() - assert test_new_label not in current_labels - test_old_label = current_labels[0] - - sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label) - assert test_new_label in sample_memory.list_block_labels() - assert test_old_label not in sample_memory.list_block_labels() - - -def test_update_block_limit(sample_memory: Memory): - """Test updating the limit of a block""" - - test_new_limit = 1000 - current_labels = sample_memory.list_block_labels() - test_old_label = current_labels[0] - - assert sample_memory.get_block(label=test_old_label).limit != test_new_limit - - sample_memory.update_block_limit(label=test_old_label, limit=test_new_limit) - assert sample_memory.get_block(label=test_old_label).limit == test_new_limit diff --git a/tests/test_persistence.py b/tests/test_persistence.py index afc93e5e03..9b86f2b235 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -20,7 +20,7 @@ # else: # client2 = Letta(quickstart="letta_hosted", user_id=test_user_id) # print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!") -# client2_agent_obj = client2.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) +# client2_agent_obj = client2.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id) # client2_agent_state = client2_agent_obj.update_state() # print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}") # @@ -45,7 +45,7 @@ # client3 = Letta(quickstart="openai", user_id=test_user_id) # else: # client3 = Letta(quickstart="letta_hosted", user_id=test_user_id) -# client3_agent_obj = client3.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) +# client3_agent_obj = client3.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id) # client3_agent_state = client3_agent_obj.update_state() # # check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) diff --git a/tests/test_server.py b/tests/test_server.py index 01d6756be7..94a799869e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -25,6 +25,7 @@ from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message from letta.schemas.memory import ChatMemory from letta.schemas.source import Source from letta.server.server import SyncServer @@ -74,10 +75,7 @@ def agent_id(server, user_id): request=CreateAgent( name="test_agent", tools=BASE_TOOLS, - memory=ChatMemory( - human="Sarah", - persona="I am a helpful assistant", - ), + memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), @@ -134,9 +132,8 @@ def test_load_data(server, user_id, agent_id): connector = DummyDataConnector(archival_memories) server.load_data(user_id, connector, source.name) - -@pytest.mark.order(3) -def test_attach_source_to_agent(server, user_id, agent_id): + # @pytest.mark.order(3) + # def test_attach_source_to_agent(server, user_id, agent_id): # check archival memory size passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000) assert len(passages_before) == 0 @@ -245,7 +242,7 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") # Grab the raw Agent object - letta_agent = server._get_or_load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -256,7 +253,7 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought) # Grab the agent object again (make sure it's live) - letta_agent = server._get_or_load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -273,7 +270,7 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): server.rewrite_agent_message(agent_id=agent_id, new_text=new_text) # Grab the agent object again (make sure it's live) - letta_agent = server._get_or_load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -284,7 +281,7 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id): server.retry_agent_message(agent_id=agent_id) # Grab the agent object again (make sure it's live) - letta_agent = server._get_or_load_agent(agent_id=agent_id) + letta_agent = server.load_agent(agent_id=agent_id) assert letta_agent._messages[-1].role == MessageRole.tool assert letta_agent._messages[-2].role == MessageRole.assistant last_agent_message = letta_agent._messages[-2] @@ -345,10 +342,7 @@ def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServe request=CreateAgent( name="nonexistent_tools_agent", tools=tools, - memory=ChatMemory( - human="Sarah", - persona="I am a helpful assistant", - ), + memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), @@ -356,13 +350,13 @@ def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServe ) # Check that the tools in agent_state do NOT include the fake name - assert fake_tool_name not in agent_state.tools - assert set(BASE_TOOLS).issubset(set(agent_state.tools)) + assert fake_tool_name not in agent_state.tool_names + assert set(BASE_TOOLS).issubset(set(agent_state.tool_names)) # Load the agent from the database and check that it doesn't error / tools are correct saved_tools = server.get_tools_from_agent(agent_id=agent_state.id, user_id=user_id) - assert fake_tool_name not in agent_state.tools - assert set(BASE_TOOLS).issubset(set(agent_state.tools)) + assert fake_tool_name not in agent_state.tool_names + assert set(BASE_TOOLS).issubset(set(agent_state.tool_names)) # cleanup server.delete_agent(user_id, agent_state.id) @@ -372,10 +366,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): agent_state = server.create_agent( request=CreateAgent( name="nonexistent_tools_agent", - memory=ChatMemory( - human="Sarah", - persona="I am a helpful assistant", - ), + memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 31a8592912..90499d5fb3 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -1,11 +1,14 @@ import uuid from typing import List +import pytest + from letta import create_client from letta.client.client import LocalClient from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message +from letta.settings import tool_settings from .utils import wipe_config @@ -18,6 +21,21 @@ # TODO: these tests should add function calls into the summarized message sequence:W +@pytest.fixture +def mock_e2b_api_key_none(): + # Store the original value of e2b_api_key + original_api_key = tool_settings.e2b_api_key + + # Set e2b_api_key to None + tool_settings.e2b_api_key = None + + # Yield control to the test + yield + + # Restore the original value of e2b_api_key + tool_settings.e2b_api_key = original_api_key + + def create_test_agent(): """Create a test agent that we can call functions on""" wipe_config() @@ -33,10 +51,10 @@ def create_test_agent(): ) global agent_obj - agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id) -def test_summarize_messages_inplace(): +def test_summarize_messages_inplace(mock_e2b_api_key_none): """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" global client global agent_obj @@ -73,12 +91,15 @@ def test_summarize_messages_inplace(): assert response is not None and len(response) > 0 print(f"test_summarize: response={response}") + # reload agent object + agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id) + agent_obj.summarize_messages_inplace() print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}") # response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize") -def test_auto_summarize(): +def test_auto_summarize(mock_e2b_api_key_none): """Test that the summarizer triggers by itself""" client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4")) @@ -86,7 +107,7 @@ def test_auto_summarize(): small_context_llm_config = LLMConfig.default_config("gpt-4") # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens - SMALL_CONTEXT_WINDOW = 3000 + SMALL_CONTEXT_WINDOW = 4000 small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW agent_state = client.create_agent( @@ -98,7 +119,7 @@ def test_auto_summarize(): def summarize_message_exists(messages: List[Message]) -> bool: for message in messages: - if message.text and "have been hidden from view due to conversation memory constraints" in message.text: + if message.text and "The following is a summary of the previous" in message.text: print(f"Summarize message found after {message_count} messages: \n {message.text}") return True return False @@ -114,11 +135,12 @@ def summarize_message_exists(messages: List[Message]) -> bool: ) message_count += 1 - print(f"Message {message_count}: \n\n{response.messages}") + print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") # check if the summarize message is inside the messages assert isinstance(client, LocalClient), "Test only works with LocalClient" - agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id) + agent_obj = client.server.load_agent(agent_id=agent_state.id) + print("SUMMARY", summarize_message_exists(agent_obj._messages)) if summarize_message_exists(agent_obj._messages): break diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 1347811933..9de6a6302b 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -2,7 +2,7 @@ from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError -from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule # Constants for tool names used in the tests START_TOOL = "start_tool" @@ -30,7 +30,7 @@ def test_get_allowed_tool_names_with_init_rules(): def test_get_allowed_tool_names_with_subsequent_rule(): # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) - rule_1 = ToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) + rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) # Action: Update usage and get allowed tools @@ -84,8 +84,8 @@ def test_get_allowed_tool_names_no_matching_rule_error(): def test_update_tool_usage_and_get_allowed_tool_names_combined(): # Setup: More complex rule chaining init_rule = InitToolRule(tool_name=START_TOOL) - rule_1 = ToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) - rule_2 = ToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL]) + rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL]) terminal_rule = TerminalToolRule(tool_name=FINAL_TOOL) solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) @@ -107,10 +107,10 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined(): def test_tool_rules_with_cycle_detection(): # Setup: Define tool rules with both connected, disconnected nodes and a cycle init_rule = InitToolRule(tool_name=START_TOOL) - rule_1 = ToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) - rule_2 = ToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL]) - rule_3 = ToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) # This creates a cycle: start -> next -> helper -> start - rule_4 = ToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here + rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL]) + rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) # This creates a cycle: start -> next -> helper -> start + rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here terminal_rule = TerminalToolRule(tool_name=END_TOOL) # Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError @@ -118,7 +118,7 @@ def test_tool_rules_with_cycle_detection(): ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) # Extra setup: Define tool rules without a cycle but with hanging nodes - rule_5 = ToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool + rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool # Assert that a configuration without cycles does not raise an error try: