From f93e39ecd2e168e159b290924a8abae210206e84 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Thu, 1 Aug 2024 14:18:54 -0700
Subject: [PATCH] feat: added system prompt override to the CLI (#1602)

---
 memgpt/cli/cli.py       | 15 +++++++++++++++
 memgpt/client/client.py |  6 ++++++
 2 files changed, 21 insertions(+)

diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py
index d00b13918b..8e3752a848 100644
--- a/memgpt/cli/cli.py
+++ b/memgpt/cli/cli.py
@@ -395,6 +395,7 @@ def run(
     persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None,
     agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None,
     human: Annotated[Optional[str], typer.Option(help="Specify human")] = None,
+    system: Annotated[Optional[str], typer.Option(help="Specify system prompt (raw text)")] = None,
     # model flags
     model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None,
     model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None,
@@ -584,6 +585,16 @@ def run(
             )
             agent_state.llm_config.model_endpoint_type = model_endpoint_type
+        # NOTE: commented out because this seems dangerous - instead users should use /systemswap when in the CLI
+        # # user specified a new system prompt
+        # if system:
+        #     # NOTE: agent_state.system is the ORIGINAL system prompt,
+        #     #       whereas agent_state.state["system"] is the LATEST system prompt
+        #     existing_system_prompt = agent_state.state["system"] if "system" in agent_state.state else None
+        #     if existing_system_prompt != system:
+        #         # override
+        #         agent_state.state["system"] = system
+
         # Update the agent with any overrides
         ms.update_agent(agent_state)
 
         tools = []
@@ -638,6 +649,9 @@ def run(
         client = create_client()
         human_obj = ms.get_human(human, user.id)
         persona_obj = ms.get_persona(persona, user.id)
+        # TODO pull system prompts from the metadata store
+        # NOTE: will be overriden later to a default
+        system_prompt = system if system else None
         if human_obj is None:
             typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
         if persona_obj is None:
@@ -652,6 +666,7 @@ def run(
         # add tools
         agent_state = client.create_agent(
             name=agent_name,
+            system_prompt=system_prompt,
             embedding_config=embedding_config,
             llm_config=llm_config,
             memory=memory,
diff --git a/memgpt/client/client.py b/memgpt/client/client.py
index cfe0731379..c60cf3857b 100644
--- a/memgpt/client/client.py
+++ b/memgpt/client/client.py
@@ -260,6 +260,8 @@ def create_agent(
         llm_config: Optional[LLMConfig] = None,
         # memory
         memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+        # system prompt (can be templated)
+        system_prompt: Optional[str] = None,
         # tools
         tools: Optional[List[str]] = None,
         include_base_tools: Optional[bool] = True,
@@ -298,6 +300,7 @@ def create_agent(
             "config": {
                 "name": name,
                 "preset": preset,
+                "system": system_prompt,
                 "persona": memory.memory["persona"].value,
                 "human": memory.memory["human"].value,
                 "function_names": tool_names,
@@ -727,6 +730,8 @@ def create_agent(
         llm_config: Optional[LLMConfig] = None,
         # memory
         memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+        # system prompt (can be templated)
+        system_prompt: Optional[str] = None,
         # tools
         tools: Optional[List[str]] = None,
         include_base_tools: Optional[bool] = True,
@@ -756,6 +761,7 @@ def create_agent(
             user_id=self.user_id,
             name=name,
             memory=memory,
+            system=system_prompt,
             llm_config=llm_config,
             embedding_config=embedding_config,
             tools=tool_names,
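
Usage sketch (a minimal example, assuming the patch above is applied, that `create_client` is importable from the top-level `memgpt` package as it is called in `cli.py`, and that the agent name and prompt text are placeholders):

    from memgpt import create_client  # assumed import path

    client = create_client()

    # The new `system_prompt` kwarg sets the system prompt of a newly created agent
    # instead of the default prompt.
    agent_state = client.create_agent(
        name="example_agent",
        system_prompt="You are a terse, task-focused assistant.",
    )

On the command line, the same override is exposed through the new `--system` flag when `memgpt run` creates a new agent; for an existing agent the in-place override is intentionally left commented out in favor of `/systemswap`:

    memgpt run --agent example_agent --system "You are a terse, task-focused assistant."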