From 3634dba7d9a7ea4e2273583f63a155209fc93aca Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sun, 8 Sep 2024 15:24:41 -0700 Subject: [PATCH] feat: allow updating `Message`s via the server, remap CLI commands that update messages to use `server.py` (#1715) Co-authored-by: Sarah Wooders --- memgpt/agent.py | 156 ++++++++++++++++++++++- memgpt/client/client.py | 51 +++++++- memgpt/main.py | 82 +++--------- memgpt/metadata.py | 10 +- memgpt/schemas/message.py | 18 +++ memgpt/server/rest_api/agents/message.py | 15 ++- memgpt/server/server.py | 66 +++++++++- tests/test_client.py | 21 ++- tests/test_server.py | 65 ++++++++++ 9 files changed, 412 insertions(+), 72 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 78a9d11fe4..c4a65c9df2 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -1,6 +1,7 @@ import datetime import inspect import traceback +import warnings from abc import ABC, abstractmethod from typing import List, Literal, Optional, Tuple, Union @@ -24,9 +25,9 @@ from memgpt.schemas.agent import AgentState from memgpt.schemas.block import Block from memgpt.schemas.embedding_config import EmbeddingConfig -from memgpt.schemas.enums import OptionState +from memgpt.schemas.enums import MessageRole, OptionState from memgpt.schemas.memory import Memory -from memgpt.schemas.message import Message +from memgpt.schemas.message import Message, UpdateMessage from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse from memgpt.schemas.openai.chat_completion_response import ( Message as ChatCompletionMessage, @@ -201,7 +202,7 @@ def update_state(self) -> AgentState: raise NotImplementedError -class Agent(object): +class Agent(BaseAgent): def __init__( self, interface: AgentInterface, @@ -369,6 +370,14 @@ def set_message_buffer(self, message_ids: List[str], force_utc: bool = True): # also sync the message IDs attribute self.agent_state.message_ids = message_ids + def refresh_message_buffer(self): + """Refresh the message buffer from the database""" + + messages_to_sync = self.agent_state.message_ids + assert messages_to_sync and all([isinstance(msg_id, str) for msg_id in messages_to_sync]) + + self.set_message_buffer(message_ids=messages_to_sync) + def _trim_messages(self, num): """Trim messages from the front, not including the system message""" self.persistence_manager.trim_messages(num) @@ -1209,6 +1218,147 @@ def attach_source(self, source_id: str, source_connector: StorageConnector, ms: f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.", ) + def update_message(self, request: UpdateMessage) -> Message: + """Update the details of a message associated with an agent""" + + message = self.persistence_manager.recall_memory.storage.get(id=request.id) + if message is None: + raise ValueError(f"Message with id {request.id} not found") + assert isinstance(message, Message), f"Message is not a Message object: {type(message)}" + + # Override fields + # NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof + if request.role: + message.role = request.role + if request.text: + message.text = request.text + if request.name: + message.name = request.name + if request.tool_calls: + assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages" + message.tool_calls = request.tool_calls + if request.tool_call_id: + assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages" + message.tool_call_id = request.tool_call_id + + # Save the updated message + self.persistence_manager.recall_memory.storage.update(record=message) + + # Return the updated message + updated_message = self.persistence_manager.recall_memory.storage.get(id=message.id) + if updated_message is None: + raise ValueError(f"Error persisting message - message with id {request.id} not found") + return updated_message + + # TODO(sarah): should we be creating a new message here, or just editing a message? + def rethink_message(self, new_thought: str) -> Message: + """Rethink / update the last message""" + for x in range(len(self.messages) - 1, 0, -1): + msg_obj = self._messages[x] + if msg_obj.role == MessageRole.assistant: + updated_message = self.update_message( + request=UpdateMessage( + id=msg_obj.id, + text=new_thought, + ) + ) + self.refresh_message_buffer() + return updated_message + raise ValueError(f"No assistant message found to update") + + # TODO(sarah): should we be creating a new message here, or just editing a message? + def rewrite_message(self, new_text: str) -> Message: + """Rewrite / update the send_message text on the last message""" + + # Walk backwards through the messages until we find an assistant message + for x in range(len(self._messages) - 1, 0, -1): + if self._messages[x].role == MessageRole.assistant: + # Get the current message content + message_obj = self._messages[x] + + # The rewrite target is the output of send_message + if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0: + + # Check that we hit an assistant send_message call + name_string = message_obj.tool_calls[0].function.name + if name_string is None or name_string != "send_message": + raise ValueError("Assistant missing send_message function call") + + args_string = message_obj.tool_calls[0].function.arguments + if args_string is None: + raise ValueError("Assistant missing send_message function arguments") + + args_json = json_loads(args_string) + if "message" not in args_json: + raise ValueError("Assistant missing send_message message argument") + + # Once we found our target, rewrite it + args_json["message"] = new_text + new_args_string = json_dumps(args_json) + message_obj.tool_calls[0].function.arguments = new_args_string + + # Write the update to the DB + updated_message = self.update_message( + request=UpdateMessage( + id=message_obj.id, + tool_calls=message_obj.tool_calls, + ) + ) + self.refresh_message_buffer() + return updated_message + + raise ValueError("No assistant message found to update") + + def pop_message(self, count: int = 1) -> List[Message]: + """Pop the last N messages from the agent's memory""" + n_messages = len(self._messages) + popped_messages = [] + MIN_MESSAGES = 2 + if n_messages <= MIN_MESSAGES: + raise ValueError(f"Agent only has {n_messages} messages in stack, none left to pop") + elif n_messages - count < MIN_MESSAGES: + raise ValueError(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") + else: + # print(f"Popping last {count} messages from stack") + for _ in range(min(count, len(self._messages))): + # remove the message from the internal state of the agent + deleted_message = self._messages.pop() + # then also remove it from recall storage + try: + self.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id}) + popped_messages.append(deleted_message) + except Exception as e: + warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}") + self._messages.append(deleted_message) + break + + return popped_messages + + def pop_until_user(self) -> List[Message]: + """Pop all messages until the last user message""" + if MessageRole.user not in [msg.role for msg in self._messages]: + raise ValueError("No user message found in buffer") + + popped_messages = [] + while len(self._messages) > 0: + if self._messages[-1].role == MessageRole.user: + # we want to pop up to the last user message + return popped_messages + else: + popped_messages.append(self.pop_message(count=1)) + + raise ValueError("No user message found in buffer") + + def retry_message(self) -> List[Message]: + """Retry / regenerate the last message""" + + self.pop_until_user() + user_message = self.pop_message(count=1)[0] + messages, _, _, _, _ = self.step(user_message=user_message.text, return_dicts=False) + + assert messages is not None and all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects" + return messages + def save_agent(agent: Agent, ms: MetadataStore): """Save agent to metadata store""" diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 584d23d7a9..5e1256c229 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -35,7 +35,8 @@ Memory, RecallMemorySummary, ) -from memgpt.schemas.message import Message, MessageCreate +from memgpt.schemas.message import Message, MessageCreate, UpdateMessage +from memgpt.schemas.openai.chat_completions import ToolCall from memgpt.schemas.passage import Passage from memgpt.schemas.source import Source, SourceCreate, SourceUpdate from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate @@ -377,6 +378,31 @@ def create_agent( raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}") return AgentState(**response.json()) + def update_message( + self, + agent_id: str, + message_id: str, + role: Optional[MessageRole] = None, + text: Optional[str] = None, + name: Optional[str] = None, + tool_calls: Optional[List[ToolCall]] = None, + tool_call_id: Optional[str] = None, + ) -> Message: + request = UpdateMessage( + id=message_id, + role=role, + text=text, + name=name, + tool_calls=tool_calls, + tool_call_id=tool_call_id, + ) + response = requests.patch( + f"{self.base_url}/api/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers + ) + if response.status_code != 200: + raise ValueError(f"Failed to update message: {response.text}") + return Message(**response.json()) + def update_agent( self, agent_id: str, @@ -1402,6 +1428,29 @@ def create_agent( ) return agent_state + def update_message( + self, + agent_id: str, + message_id: str, + role: Optional[MessageRole] = None, + text: Optional[str] = None, + name: Optional[str] = None, + tool_calls: Optional[List[ToolCall]] = None, + tool_call_id: Optional[str] = None, + ) -> Message: + message = self.server.update_agent_message( + agent_id=agent_id, + request=UpdateMessage( + id=message_id, + role=role, + text=text, + name=name, + tool_calls=tool_calls, + tool_call_id=tool_call_id, + ), + ) + return message + def update_agent( self, agent_id: str, diff --git a/memgpt/main.py b/memgpt/main.py index d9639ba55c..fceaff6b23 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -206,82 +206,40 @@ def run_agent_loop( # Check if there's an additional argument that's an integer command = user_input.strip().split() pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3 - n_messages = len(memgpt_agent._messages) - MIN_MESSAGES = 2 - if n_messages <= MIN_MESSAGES: - print(f"Agent only has {n_messages} messages in stack, none left to pop") - elif n_messages - pop_amount < MIN_MESSAGES: - print(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") - else: - print(f"Popping last {pop_amount} messages from stack") - for _ in range(min(pop_amount, len(memgpt_agent._messages))): - # remove the message from the internal state of the agent - deleted_message = memgpt_agent._messages.pop() - # then also remove it from recall storage - memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id}) + try: + popped_messages = memgpt_agent.pop_message(count=pop_amount) + except ValueError as e: + print(f"Error popping messages: {e}") continue elif user_input.lower() == "/retry": - print(f"Retrying for another answer") - while len(memgpt_agent._messages) > 0: - if memgpt_agent._messages[-1].role == "user": - # we want to pop up to the last user message and send it again - user_message = memgpt_agent._messages[-1].text - deleted_message = memgpt_agent._messages.pop() - # then also remove it from recall storage - memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id}) - break - deleted_message = memgpt_agent._messages.pop() - # then also remove it from recall storage - memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id}) + print(f"Retrying for another answer...") + try: + memgpt_agent.retry_message() + except Exception as e: + print(f"Error retrying message: {e}") + continue elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "): if len(user_input) < len("/rethink "): print("Missing text after the command") continue - for x in range(len(memgpt_agent.messages) - 1, 0, -1): - msg_obj = memgpt_agent._messages[x] - if msg_obj.role == "assistant": - clean_new_text = user_input[len("/rethink ") :].strip() - msg_obj.text = clean_new_text - # To persist to the database, all we need to do is "re-insert" into recall memory - memgpt_agent.persistence_manager.recall_memory.storage.update(record=msg_obj) - break + try: + memgpt_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip()) + except Exception as e: + print(f"Error rethinking message: {e}") continue elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "): if len(user_input) < len("/rewrite "): print("Missing text after the command") continue - for x in range(len(memgpt_agent.messages) - 1, 0, -1): - if memgpt_agent.messages[x].get("role") == "assistant": - text = user_input[len("/rewrite ") :].strip() - # Get the current message content - # The rewrite target is the output of send_message - message_obj = memgpt_agent._messages[x] - if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0: - # Check that we hit an assistant send_message call - name_string = message_obj.tool_calls[0].function.get("name") - if name_string is None or name_string != "send_message": - print("Assistant missing send_message function call") - break # cancel op - args_string = message_obj.tool_calls[0].function.get("arguments") - if args_string is None: - print("Assistant missing send_message function arguments") - break # cancel op - args_json = json_loads(args_string) - if "message" not in args_json: - print("Assistant missing send_message message argument") - break # cancel op - - # Once we found our target, rewrite it - args_json["message"] = text - new_args_string = json_dumps(args_json) - message_obj.tool_calls[0].function["arguments"] = new_args_string - - # To persist to the database, all we need to do is "re-insert" into recall memory - memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj) - break + + text = user_input[len("/rewrite ") :].strip() + try: + memgpt_agent.rewrite_message(new_text=text) + except Exception as e: + print(f"Error rewriting message: {e}") continue elif user_input.lower() == "/summarize": diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 5d22d9380e..83d81d5ef7 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -91,8 +91,14 @@ def load_dialect_impl(self, dialect): def process_bind_param(self, value, dialect): if value: - # return [vars(tool) for tool in value] - return value + values = [] + for v in value: + if isinstance(v, ToolCall): + values.append(v.model_dump()) + else: + values.append(v) + return values + return value def process_result_value(self, value, dialect): diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index 00b2756bdf..26a43645d6 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -57,6 +57,24 @@ class MessageCreate(BaseMessage): name: Optional[str] = Field(None, description="The name of the participant.") +class UpdateMessage(BaseMessage): + """Request to update a message""" + + id: str = Field(..., description="The id of the message.") + role: Optional[MessageRole] = Field(None, description="The role of the participant.") + text: Optional[str] = Field(None, description="The text of the message.") + # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message) + # user_id: Optional[str] = Field(None, description="The unique identifier of the user.") + # agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") + # NOTE: we probably shouldn't allow updating the model field, otherwise this loses meaning + # model: Optional[str] = Field(None, description="The model used to make the function call.") + name: Optional[str] = Field(None, description="The name of the participant.") + # NOTE: we probably shouldn't allow updating the created_at field, right? + # created_at: Optional[datetime] = Field(None, description="The time the message was created.") + tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.") + tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") + + class Message(BaseMessage): """ MemGPT's internal representation of a message. Includes methods to convert to/from LLM provider formats. diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index ee1e53ca95..896631838e 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -10,7 +10,7 @@ from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.memgpt_request import MemGPTRequest from memgpt.schemas.memgpt_response import MemGPTResponse -from memgpt.schemas.message import Message +from memgpt.schemas.message import Message, UpdateMessage from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface from memgpt.server.rest_api.utils import sse_async_generator @@ -181,4 +181,17 @@ async def send_message( return_message_object=request.return_message_object, ) + @router.patch("/agents/{agent_id}/messages/{message_id}", tags=["agents"], response_model=Message) + async def update_message( + agent_id: str, + message_id: str, + request: UpdateMessage = Body(...), + user_id: str = Depends(get_current_user_with_server), + ): + """ + Update the details of a message associated with an agent. + """ + assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}" + return server.update_agent_message(agent_id=agent_id, request=request) + return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index f5a2f23a17..0573e5472f 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -56,7 +56,7 @@ from memgpt.schemas.llm_config import LLMConfig from memgpt.schemas.memgpt_message import MemGPTMessage from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary -from memgpt.schemas.message import Message +from memgpt.schemas.message import Message, UpdateMessage from memgpt.schemas.openai.chat_completion_response import UsageStatistics from memgpt.schemas.passage import Passage from memgpt.schemas.source import Source, SourceCreate, SourceUpdate @@ -1702,9 +1702,71 @@ def add_default_blocks(self, user_id: str): name = os.path.basename(human_file).replace(".txt", "") self.create_block(CreateHuman(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True) - def get_agent_message(self, agent_id: str, message_id: str) -> Message: + def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]: """Get a single message from the agent's memory""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(agent_id=agent_id) message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message_id) return message + + def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message: + """Update the details of a message associated with an agent""" + + # Get the current message + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + return memgpt_agent.update_message(request=request) + + # TODO decide whether this should be done in the server.py or agent.py + # Reason to put it in agent.py: + # - we use the agent object's persistence_manager to update the message + # - it makes it easy to do things like `retry`, `rethink`, etc. + # Reason to put it in server.py: + # - fundamentally, we should be able to edit a message (without agent id) + # in the server by directly accessing the DB / message store + """ + message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=request.id) + if message is None: + raise ValueError(f"Message with id {request.id} not found") + + # Override fields + # NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof + if request.role: + message.role = request.role + if request.text: + message.text = request.text + if request.name: + message.name = request.name + if request.tool_calls: + assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages" + message.tool_calls = request.tool_calls + if request.tool_call_id: + assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages" + message.tool_call_id = request.tool_call_id + + # Save the updated message + memgpt_agent.persistence_manager.recall_memory.storage.update(record=message) + + # Return the updated message + updated_message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message.id) + if updated_message is None: + raise ValueError(f"Error persisting message - message with id {request.id} not found") + return updated_message + """ + + def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message: + + # Get the current message + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + return memgpt_agent.rewrite_message(new_text=new_text) + + def rethink_agent_message(self, agent_id: str, new_thought: str) -> Message: + + # Get the current message + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + return memgpt_agent.rethink_message(new_thought=new_thought) + + def retry_agent_message(self, agent_id: str) -> List[Message]: + + # Get the current message + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + return memgpt_agent.retry_message() diff --git a/tests/test_client.py b/tests/test_client.py index b496f7ba33..6111149fb9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,7 +13,7 @@ from memgpt.schemas.agent import AgentState from memgpt.schemas.enums import JobStatus, MessageStreamStatus from memgpt.schemas.memgpt_message import FunctionCallMessage, InternalMonologue -from memgpt.schemas.memgpt_response import MemGPTStreamingResponse +from memgpt.schemas.memgpt_response import MemGPTResponse, MemGPTStreamingResponse from memgpt.schemas.message import Message from memgpt.schemas.usage import MemGPTUsageStatistics @@ -395,3 +395,22 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): # delete the source client.delete_source(source.id) + + +def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can update the details of a message""" + + # create a message + message_response = client.send_message( + agent_id=agent.id, + message="Test message", + role="user", + ) + print("Messages=", message_response) + assert isinstance(message_response, MemGPTResponse) + assert isinstance(message_response.messages[-1], Message) + message = message_response.messages[-1] + + new_text = "This exact string would never show up in the message???" + new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) + assert new_message.text == new_text diff --git a/tests/test_server.py b/tests/test_server.py index 1046b28b17..e63f3c8872 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,3 +1,4 @@ +import json import uuid import pytest @@ -375,3 +376,67 @@ def _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False): def test_get_messages_memgpt_format(server, user_id, agent_id): _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False) _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=True) + + +def test_agent_rethink_rewrite_retry(server, user_id, agent_id): + """Test the /rethink, /rewrite, and /retry commands in the CLI + + - "rethink" replaces the inner thoughts of the last assistant message + - "rewrite" replaces the text of the last assistant message + - "retry" retries the last assistant message + """ + + # Send an initial message + server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") + + # Grab the raw Agent object + memgpt_agent = server._get_or_load_agent(agent_id=agent_id) + assert memgpt_agent._messages[-1].role == MessageRole.tool + assert memgpt_agent._messages[-2].role == MessageRole.assistant + last_agent_message = memgpt_agent._messages[-2] + + # Try "rethink" + new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?" + assert last_agent_message.text is not None and last_agent_message.text != new_thought + server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought) + + # Grab the agent object again (make sure it's live) + memgpt_agent = server._get_or_load_agent(agent_id=agent_id) + assert memgpt_agent._messages[-1].role == MessageRole.tool + assert memgpt_agent._messages[-2].role == MessageRole.assistant + last_agent_message = memgpt_agent._messages[-2] + assert last_agent_message.text == new_thought + + # Try "rewrite" + assert last_agent_message.tool_calls is not None + assert last_agent_message.tool_calls[0].function.name == "send_message" + assert last_agent_message.tool_calls[0].function.arguments is not None + args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) + assert "message" in args_json and args_json["message"] is not None and args_json["message"] != "" + + new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?" + server.rewrite_agent_message(agent_id=agent_id, new_text=new_text) + + # Grab the agent object again (make sure it's live) + memgpt_agent = server._get_or_load_agent(agent_id=agent_id) + assert memgpt_agent._messages[-1].role == MessageRole.tool + assert memgpt_agent._messages[-2].role == MessageRole.assistant + last_agent_message = memgpt_agent._messages[-2] + args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) + assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text + + # Try retry + server.retry_agent_message(agent_id=agent_id) + + # Grab the agent object again (make sure it's live) + memgpt_agent = server._get_or_load_agent(agent_id=agent_id) + assert memgpt_agent._messages[-1].role == MessageRole.tool + assert memgpt_agent._messages[-2].role == MessageRole.assistant + last_agent_message = memgpt_agent._messages[-2] + + # Make sure the inner thoughts changed + assert last_agent_message.text is not None and last_agent_message.text != new_thought + + # Make sure the message changed + args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) + assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text