Skip to content

Commit

Permalink
feat: allow updating Messages via the server, remap CLI commands th…
Browse files Browse the repository at this point in the history
…at update messages to use `server.py` (#1715)

Co-authored-by: Sarah Wooders <[email protected]>
  • Loading branch information
cpacker and sarahwooders authored Sep 8, 2024
1 parent 8d5bf31 commit 3634dba
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 72 deletions.
156 changes: 153 additions & 3 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import inspect
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Tuple, Union

Expand All @@ -24,9 +25,9 @@
from memgpt.schemas.agent import AgentState
from memgpt.schemas.block import Block
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import OptionState
from memgpt.schemas.enums import MessageRole, OptionState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
Expand Down Expand Up @@ -201,7 +202,7 @@ def update_state(self) -> AgentState:
raise NotImplementedError


class Agent(object):
class Agent(BaseAgent):
def __init__(
self,
interface: AgentInterface,
Expand Down Expand Up @@ -369,6 +370,14 @@ def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
# also sync the message IDs attribute
self.agent_state.message_ids = message_ids

def refresh_message_buffer(self):
"""Refresh the message buffer from the database"""

messages_to_sync = self.agent_state.message_ids
assert messages_to_sync and all([isinstance(msg_id, str) for msg_id in messages_to_sync])

self.set_message_buffer(message_ids=messages_to_sync)

def _trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
self.persistence_manager.trim_messages(num)
Expand Down Expand Up @@ -1209,6 +1218,147 @@ def attach_source(self, source_id: str, source_connector: StorageConnector, ms:
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
)

def update_message(self, request: UpdateMessage) -> Message:
"""Update the details of a message associated with an agent"""

message = self.persistence_manager.recall_memory.storage.get(id=request.id)
if message is None:
raise ValueError(f"Message with id {request.id} not found")
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"

# Override fields
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
if request.role:
message.role = request.role
if request.text:
message.text = request.text
if request.name:
message.name = request.name
if request.tool_calls:
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
message.tool_calls = request.tool_calls
if request.tool_call_id:
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
message.tool_call_id = request.tool_call_id

# Save the updated message
self.persistence_manager.recall_memory.storage.update(record=message)

# Return the updated message
updated_message = self.persistence_manager.recall_memory.storage.get(id=message.id)
if updated_message is None:
raise ValueError(f"Error persisting message - message with id {request.id} not found")
return updated_message

# TODO(sarah): should we be creating a new message here, or just editing a message?
def rethink_message(self, new_thought: str) -> Message:
"""Rethink / update the last message"""
for x in range(len(self.messages) - 1, 0, -1):
msg_obj = self._messages[x]
if msg_obj.role == MessageRole.assistant:
updated_message = self.update_message(
request=UpdateMessage(
id=msg_obj.id,
text=new_thought,
)
)
self.refresh_message_buffer()
return updated_message
raise ValueError(f"No assistant message found to update")

# TODO(sarah): should we be creating a new message here, or just editing a message?
def rewrite_message(self, new_text: str) -> Message:
"""Rewrite / update the send_message text on the last message"""

# Walk backwards through the messages until we find an assistant message
for x in range(len(self._messages) - 1, 0, -1):
if self._messages[x].role == MessageRole.assistant:
# Get the current message content
message_obj = self._messages[x]

# The rewrite target is the output of send_message
if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0:

# Check that we hit an assistant send_message call
name_string = message_obj.tool_calls[0].function.name
if name_string is None or name_string != "send_message":
raise ValueError("Assistant missing send_message function call")

args_string = message_obj.tool_calls[0].function.arguments
if args_string is None:
raise ValueError("Assistant missing send_message function arguments")

args_json = json_loads(args_string)
if "message" not in args_json:
raise ValueError("Assistant missing send_message message argument")

# Once we found our target, rewrite it
args_json["message"] = new_text
new_args_string = json_dumps(args_json)
message_obj.tool_calls[0].function.arguments = new_args_string

# Write the update to the DB
updated_message = self.update_message(
request=UpdateMessage(
id=message_obj.id,
tool_calls=message_obj.tool_calls,
)
)
self.refresh_message_buffer()
return updated_message

raise ValueError("No assistant message found to update")

def pop_message(self, count: int = 1) -> List[Message]:
"""Pop the last N messages from the agent's memory"""
n_messages = len(self._messages)
popped_messages = []
MIN_MESSAGES = 2
if n_messages <= MIN_MESSAGES:
raise ValueError(f"Agent only has {n_messages} messages in stack, none left to pop")
elif n_messages - count < MIN_MESSAGES:
raise ValueError(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
else:
# print(f"Popping last {count} messages from stack")
for _ in range(min(count, len(self._messages))):
# remove the message from the internal state of the agent
deleted_message = self._messages.pop()
# then also remove it from recall storage
try:
self.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
popped_messages.append(deleted_message)
except Exception as e:
warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}")
self._messages.append(deleted_message)
break

return popped_messages

def pop_until_user(self) -> List[Message]:
"""Pop all messages until the last user message"""
if MessageRole.user not in [msg.role for msg in self._messages]:
raise ValueError("No user message found in buffer")

popped_messages = []
while len(self._messages) > 0:
if self._messages[-1].role == MessageRole.user:
# we want to pop up to the last user message
return popped_messages
else:
popped_messages.append(self.pop_message(count=1))

raise ValueError("No user message found in buffer")

def retry_message(self) -> List[Message]:
"""Retry / regenerate the last message"""

self.pop_until_user()
user_message = self.pop_message(count=1)[0]
messages, _, _, _, _ = self.step(user_message=user_message.text, return_dicts=False)

assert messages is not None and all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
return messages


def save_agent(agent: Agent, ms: MetadataStore):
"""Save agent to metadata store"""
Expand Down
51 changes: 50 additions & 1 deletion memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
Memory,
RecallMemorySummary,
)
from memgpt.schemas.message import Message, MessageCreate
from memgpt.schemas.message import Message, MessageCreate, UpdateMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
Expand Down Expand Up @@ -377,6 +378,31 @@ def create_agent(
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
return AgentState(**response.json())

def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[ToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
request = UpdateMessage(
id=message_id,
role=role,
text=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
)
response = requests.patch(
f"{self.base_url}/api/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to update message: {response.text}")
return Message(**response.json())

def update_agent(
self,
agent_id: str,
Expand Down Expand Up @@ -1402,6 +1428,29 @@ def create_agent(
)
return agent_state

def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[ToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
message = self.server.update_agent_message(
agent_id=agent_id,
request=UpdateMessage(
id=message_id,
role=role,
text=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
),
)
return message

def update_agent(
self,
agent_id: str,
Expand Down
82 changes: 20 additions & 62 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,82 +206,40 @@ def run_agent_loop(
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
n_messages = len(memgpt_agent._messages)
MIN_MESSAGES = 2
if n_messages <= MIN_MESSAGES:
print(f"Agent only has {n_messages} messages in stack, none left to pop")
elif n_messages - pop_amount < MIN_MESSAGES:
print(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
else:
print(f"Popping last {pop_amount} messages from stack")
for _ in range(min(pop_amount, len(memgpt_agent._messages))):
# remove the message from the internal state of the agent
deleted_message = memgpt_agent._messages.pop()
# then also remove it from recall storage
memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
try:
popped_messages = memgpt_agent.pop_message(count=pop_amount)
except ValueError as e:
print(f"Error popping messages: {e}")
continue

elif user_input.lower() == "/retry":
print(f"Retrying for another answer")
while len(memgpt_agent._messages) > 0:
if memgpt_agent._messages[-1].role == "user":
# we want to pop up to the last user message and send it again
user_message = memgpt_agent._messages[-1].text
deleted_message = memgpt_agent._messages.pop()
# then also remove it from recall storage
memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
break
deleted_message = memgpt_agent._messages.pop()
# then also remove it from recall storage
memgpt_agent.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
print(f"Retrying for another answer...")
try:
memgpt_agent.retry_message()
except Exception as e:
print(f"Error retrying message: {e}")
continue

elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "):
if len(user_input) < len("/rethink "):
print("Missing text after the command")
continue
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
msg_obj = memgpt_agent._messages[x]
if msg_obj.role == "assistant":
clean_new_text = user_input[len("/rethink ") :].strip()
msg_obj.text = clean_new_text
# To persist to the database, all we need to do is "re-insert" into recall memory
memgpt_agent.persistence_manager.recall_memory.storage.update(record=msg_obj)
break
try:
memgpt_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip())
except Exception as e:
print(f"Error rethinking message: {e}")
continue

elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "):
if len(user_input) < len("/rewrite "):
print("Missing text after the command")
continue
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = user_input[len("/rewrite ") :].strip()
# Get the current message content
# The rewrite target is the output of send_message
message_obj = memgpt_agent._messages[x]
if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0:
# Check that we hit an assistant send_message call
name_string = message_obj.tool_calls[0].function.get("name")
if name_string is None or name_string != "send_message":
print("Assistant missing send_message function call")
break # cancel op
args_string = message_obj.tool_calls[0].function.get("arguments")
if args_string is None:
print("Assistant missing send_message function arguments")
break # cancel op
args_json = json_loads(args_string)
if "message" not in args_json:
print("Assistant missing send_message message argument")
break # cancel op

# Once we found our target, rewrite it
args_json["message"] = text
new_args_string = json_dumps(args_json)
message_obj.tool_calls[0].function["arguments"] = new_args_string

# To persist to the database, all we need to do is "re-insert" into recall memory
memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj)
break

text = user_input[len("/rewrite ") :].strip()
try:
memgpt_agent.rewrite_message(new_text=text)
except Exception as e:
print(f"Error rewriting message: {e}")
continue

elif user_input.lower() == "/summarize":
Expand Down
10 changes: 8 additions & 2 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,14 @@ def load_dialect_impl(self, dialect):

def process_bind_param(self, value, dialect):
if value:
# return [vars(tool) for tool in value]
return value
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values

return value

def process_result_value(self, value, dialect):
Expand Down
Loading

0 comments on commit 3634dba

Please sign in to comment.