From 257d4b82de2e6336f6043e9e27f4fc6d2143e9e4 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sun, 28 Jul 2024 20:14:34 -0700 Subject: [PATCH] feat: allow editing the system prompt of an agent post-creation (#1585) --- memgpt/agent.py | 68 ++++++++++++++++++++++++++++++++++++------------- memgpt/main.py | 35 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 18 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 9d60a91ade..7fa6e47d80 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -365,16 +365,6 @@ def append_to_messages(self, added_messages: List[dict]): ] self._append_to_messages(added_messages_objs) - def _swap_system_message(self, new_system_message: Message): - assert isinstance(new_system_message, Message) - assert new_system_message.role == "system", new_system_message - assert self._messages[0].role == "system", self._messages - - self.persistence_manager.swap_system_message(new_system_message) - - new_messages = [new_system_message] + self._messages[1:] # swap index 0 (system) - self._messages = new_messages - def _get_ai_reply( self, message_sequence: List[Message], @@ -401,6 +391,10 @@ def _get_ai_reply( # putting inner thoughts in func args or not inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, ) + + if len(response.choices) == 0: + raise Exception(f"API call didn't return a message: {response}") + # special case for 'length' if response.choices[0].finish_reason == "length": raise Exception("Finish reason was length (maximum context length)") @@ -929,21 +923,47 @@ def heartbeat_is_paused(self): elapsed_time = get_utc_time() - self.pause_heartbeats_start return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60 - def rebuild_memory(self): + def _swap_system_message_in_buffer(self, new_system_message: str): + """Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)""" + assert isinstance(new_system_message, str) + new_system_message_obj = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict={"role": "system", "content": new_system_message}, + ) + + assert new_system_message_obj.role == "system", new_system_message_obj + assert self._messages[0].role == "system", self._messages + + self.persistence_manager.swap_system_message(new_system_message_obj) + + new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system) + self._messages = new_messages + + def rebuild_memory(self, force=False, update_timestamp=True): """Rebuilds the system message with the latest memory object""" curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt # NOTE: This is a hacky way to check if the memory has changed memory_repr = str(self.memory) - if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]: + if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]: printd(f"Memory has not changed, not rebuilding system") return + # If the memory didn't update, we probably don't want to update the timestamp inside + # For example, if we're doing a system prompt swap, this should probably be False + if update_timestamp: + memory_edit_timestamp = get_utc_time() + else: + # NOTE: a bit of a hack - we pull the timestamp from the message created_by + memory_edit_timestamp = self._messages[0].created_at + # update memory (TODO: potentially update recall/archival stats seperately) new_system_message_str = compile_system_message( system_prompt=self.system, in_context_memory=self.memory, - in_context_memory_last_edit=get_utc_time(), # NOTE: new timestamp + in_context_memory_last_edit=memory_edit_timestamp, archival_memory=self.persistence_manager.archival_memory, recall_memory=self.persistence_manager.recall_memory, user_defined_variables=None, @@ -959,16 +979,28 @@ def rebuild_memory(self): printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") # Swap the system message out (only if there is a diff) - self._swap_system_message( - Message.dict_to_message( - agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message - ) - ) + self._swap_system_message_in_buffer(new_system_message=new_system_message_str) assert self.messages[0]["content"] == new_system_message["content"], ( self.messages[0]["content"], new_system_message["content"], ) + def update_system_prompt(self, new_system_prompt: str): + """Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)""" + assert isinstance(new_system_prompt, str) + + if new_system_prompt == self.system: + input("same???") + return + + self.system = new_system_prompt + + # updating the system prompt requires rebuilding the memory block inside the compiled system message + self.rebuild_memory(force=True, update_timestamp=False) + + # make sure to persist the change + _ = self.update_state() + def add_function(self, function_name: str) -> str: # TODO: refactor raise NotImplementedError diff --git a/memgpt/main.py b/memgpt/main.py index 1de91bd69a..fcf392ddb6 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -377,6 +377,41 @@ def run_agent_loop( questionary.print(f" {desc}") continue + elif user_input.lower().startswith("/systemswap"): + if len(user_input) < len("/systemswap "): + print("Missing new system prompt after the command") + continue + old_system_prompt = memgpt_agent.system + new_system_prompt = user_input[len("/systemswap ") :].strip() + + # Show warning and prompts to user + typer.secho( + "\nWARNING: You are about to change the system prompt.", + # fg=typer.colors.BRIGHT_YELLOW, + bold=True, + ) + typer.secho( + f"\nOld system prompt:\n{old_system_prompt}", + fg=typer.colors.RED, + bold=True, + ) + typer.secho( + f"\nNew system prompt:\n{new_system_prompt}", + fg=typer.colors.GREEN, + bold=True, + ) + + # Ask for confirmation + confirm = questionary.confirm("Do you want to proceed with the swap?").ask() + + if confirm: + memgpt_agent.update_system_prompt(new_system_prompt=new_system_prompt) + print("System prompt updated successfully.") + else: + print("System prompt swap cancelled.") + + continue + else: print(f"Unrecognized command: {user_input}") continue