Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow editing the system prompt of an agent post-creation #1585

Merged
merged 5 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 50 additions & 18 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,6 @@ def append_to_messages(self, added_messages: List[dict]):
]
self._append_to_messages(added_messages_objs)

def _swap_system_message(self, new_system_message: Message):
assert isinstance(new_system_message, Message)
assert new_system_message.role == "system", new_system_message
assert self._messages[0].role == "system", self._messages

self.persistence_manager.swap_system_message(new_system_message)

new_messages = [new_system_message] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages

def _get_ai_reply(
self,
message_sequence: List[Message],
Expand All @@ -401,6 +391,10 @@ def _get_ai_reply(
# putting inner thoughts in func args or not
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)

if len(response.choices) == 0:
raise Exception(f"API call didn't return a message: {response}")

# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
Expand Down Expand Up @@ -929,21 +923,47 @@ def heartbeat_is_paused(self):
elapsed_time = get_utc_time() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60

def rebuild_memory(self):
def _swap_system_message_in_buffer(self, new_system_message: str):
"""Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
assert isinstance(new_system_message, str)
new_system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={"role": "system", "content": new_system_message},
)

assert new_system_message_obj.role == "system", new_system_message_obj
assert self._messages[0].role == "system", self._messages

self.persistence_manager.swap_system_message(new_system_message_obj)

new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages

def rebuild_memory(self, force=False, update_timestamp=True):
"""Rebuilds the system message with the latest memory object"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt

# NOTE: This is a hacky way to check if the memory has changed
memory_repr = str(self.memory)
if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
printd(f"Memory has not changed, not rebuilding system")
return

# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
memory_edit_timestamp = get_utc_time()
else:
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
memory_edit_timestamp = self._messages[0].created_at

# update memory (TODO: potentially update recall/archival stats seperately)
new_system_message_str = compile_system_message(
system_prompt=self.system,
in_context_memory=self.memory,
in_context_memory_last_edit=get_utc_time(), # NOTE: new timestamp
in_context_memory_last_edit=memory_edit_timestamp,
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
user_defined_variables=None,
Expand All @@ -959,16 +979,28 @@ def rebuild_memory(self):
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")

# Swap the system message out (only if there is a diff)
self._swap_system_message(
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
)
)
self._swap_system_message_in_buffer(new_system_message=new_system_message_str)
assert self.messages[0]["content"] == new_system_message["content"], (
self.messages[0]["content"],
new_system_message["content"],
)

def update_system_prompt(self, new_system_prompt: str):
"""Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
assert isinstance(new_system_prompt, str)

if new_system_prompt == self.system:
input("same???")
return

self.system = new_system_prompt

# updating the system prompt requires rebuilding the memory block inside the compiled system message
self.rebuild_memory(force=True, update_timestamp=False)

# make sure to persist the change
_ = self.update_state()

def add_function(self, function_name: str) -> str:
# TODO: refactor
raise NotImplementedError
Expand Down
35 changes: 35 additions & 0 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,41 @@ def run_agent_loop(
questionary.print(f" {desc}")
continue

elif user_input.lower().startswith("/systemswap"):
if len(user_input) < len("/systemswap "):
print("Missing new system prompt after the command")
continue
old_system_prompt = memgpt_agent.system
new_system_prompt = user_input[len("/systemswap ") :].strip()

# Show warning and prompts to user
typer.secho(
"\nWARNING: You are about to change the system prompt.",
# fg=typer.colors.BRIGHT_YELLOW,
bold=True,
)
typer.secho(
f"\nOld system prompt:\n{old_system_prompt}",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"\nNew system prompt:\n{new_system_prompt}",
fg=typer.colors.GREEN,
bold=True,
)

# Ask for confirmation
confirm = questionary.confirm("Do you want to proceed with the swap?").ask()

if confirm:
memgpt_agent.update_system_prompt(new_system_prompt=new_system_prompt)
print("System prompt updated successfully.")
else:
print("System prompt swap cancelled.")

continue

else:
print(f"Unrecognized command: {user_input}")
continue
Expand Down
Loading