Skip to content

Commit

Permalink
feat: allow editing the system prompt of an agent post-creation (#1585)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Jul 29, 2024
1 parent 9c9411e commit 257d4b8
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 18 deletions.
68 changes: 50 additions & 18 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,6 @@ def append_to_messages(self, added_messages: List[dict]):
]
self._append_to_messages(added_messages_objs)

def _swap_system_message(self, new_system_message: Message):
assert isinstance(new_system_message, Message)
assert new_system_message.role == "system", new_system_message
assert self._messages[0].role == "system", self._messages

self.persistence_manager.swap_system_message(new_system_message)

new_messages = [new_system_message] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages

def _get_ai_reply(
self,
message_sequence: List[Message],
Expand All @@ -401,6 +391,10 @@ def _get_ai_reply(
# putting inner thoughts in func args or not
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)

if len(response.choices) == 0:
raise Exception(f"API call didn't return a message: {response}")

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
Expand Down Expand Up @@ -929,21 +923,47 @@ def heartbeat_is_paused(self):
elapsed_time = get_utc_time() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60

def rebuild_memory(self):
def _swap_system_message_in_buffer(self, new_system_message: str):
"""Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
assert isinstance(new_system_message, str)
new_system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={"role": "system", "content": new_system_message},
)

assert new_system_message_obj.role == "system", new_system_message_obj
assert self._messages[0].role == "system", self._messages

self.persistence_manager.swap_system_message(new_system_message_obj)

new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages

def rebuild_memory(self, force=False, update_timestamp=True):
"""Rebuilds the system message with the latest memory object"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt

# NOTE: This is a hacky way to check if the memory has changed
memory_repr = str(self.memory)
if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
printd(f"Memory has not changed, not rebuilding system")
return

# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
memory_edit_timestamp = get_utc_time()
else:
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
memory_edit_timestamp = self._messages[0].created_at

# update memory (TODO: potentially update recall/archival stats seperately)
new_system_message_str = compile_system_message(
system_prompt=self.system,
in_context_memory=self.memory,
in_context_memory_last_edit=get_utc_time(), # NOTE: new timestamp
in_context_memory_last_edit=memory_edit_timestamp,
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
user_defined_variables=None,
Expand All @@ -959,16 +979,28 @@ def rebuild_memory(self):
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")

# Swap the system message out (only if there is a diff)
self._swap_system_message(
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
)
)
self._swap_system_message_in_buffer(new_system_message=new_system_message_str)
assert self.messages[0]["content"] == new_system_message["content"], (
self.messages[0]["content"],
new_system_message["content"],
)

def update_system_prompt(self, new_system_prompt: str):
"""Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
assert isinstance(new_system_prompt, str)

if new_system_prompt == self.system:
input("same???")
return

self.system = new_system_prompt

# updating the system prompt requires rebuilding the memory block inside the compiled system message
self.rebuild_memory(force=True, update_timestamp=False)

# make sure to persist the change
_ = self.update_state()

def add_function(self, function_name: str) -> str:
# TODO: refactor
raise NotImplementedError
Expand Down
35 changes: 35 additions & 0 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,41 @@ def run_agent_loop(
questionary.print(f" {desc}")
continue

elif user_input.lower().startswith("/systemswap"):
if len(user_input) < len("/systemswap "):
print("Missing new system prompt after the command")
continue
old_system_prompt = memgpt_agent.system
new_system_prompt = user_input[len("/systemswap ") :].strip()

# Show warning and prompts to user
typer.secho(
"\nWARNING: You are about to change the system prompt.",
# fg=typer.colors.BRIGHT_YELLOW,
bold=True,
)
typer.secho(
f"\nOld system prompt:\n{old_system_prompt}",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"\nNew system prompt:\n{new_system_prompt}",
fg=typer.colors.GREEN,
bold=True,
)

# Ask for confirmation
confirm = questionary.confirm("Do you want to proceed with the swap?").ask()

if confirm:
memgpt_agent.update_system_prompt(new_system_prompt=new_system_prompt)
print("System prompt updated successfully.")
else:
print("System prompt swap cancelled.")

continue

else:
print(f"Unrecognized command: {user_input}")
continue
Expand Down

0 comments on commit 257d4b8

Please sign in to comment.