From d8e2dcb9bcafca7a55ffb8113f04347c6eaddeaa Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 3 Sep 2024 14:48:45 -0700 Subject: [PATCH] fix: cleanup --- memgpt/schemas/block.py | 2 +- memgpt/schemas/memory.py | 39 ++++++++++++++++++++++----------------- tests/test_memory.py | 8 ++++---- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/memgpt/schemas/block.py b/memgpt/schemas/block.py index 79b8ef85d4..1d79dc497b 100644 --- a/memgpt/schemas/block.py +++ b/memgpt/schemas/block.py @@ -62,7 +62,7 @@ class Block(BaseBlock): """Block of the LLM context""" id: str = BaseBlock.generate_id_field() - value: str = Field(..., description="Value of the block.") + value: Union[str, List[str]] = Field(..., description="Value of the block.") class Human(Block): diff --git a/memgpt/schemas/memory.py b/memgpt/schemas/memory.py index 53e92c8c08..ad5ed27aa5 100644 --- a/memgpt/schemas/memory.py +++ b/memgpt/schemas/memory.py @@ -1,8 +1,12 @@ -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from jinja2 import Template, TemplateSyntaxError from pydantic import BaseModel, Field +# Forward referencing to avoid circular import with Agent -> Memory -> Agent +if TYPE_CHECKING: + from memgpt.agent import Agent + from memgpt.schemas.block import Block @@ -13,7 +17,7 @@ class Memory(BaseModel, validate_assignment=True): memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.") # Memory.template is a Jinja2 template for compiling memory module into a prompt string. - template: str = Field( + prompt_template: str = Field( default="{% for section, module in memory.items() %}" '<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n' "{{ module.value }}\n" @@ -23,36 +27,36 @@ class Memory(BaseModel, validate_assignment=True): description="Jinja2 template for compiling memory module into a prompt string", ) - def get_template(self) -> str: + def get_prompt_template(self) -> str: """Return the current Jinja2 template string.""" - return str(self.template) + return str(self.prompt_template) - def set_template(self, template: str): + def set_prompt_template(self, prompt_template: str): """ Set a new Jinja2 template string. Validates the template syntax and compatibility with current memory structure. """ try: # Validate Jinja2 syntax - Template(template) + Template(prompt_template) # Validate compatibility with current memory structure - test_render = Template(template).render(memory=self.memory) + test_render = Template(prompt_template).render(memory=self.memory) # If we get here, the template is valid and compatible - self.template = template + self.prompt_template = prompt_template except TemplateSyntaxError as e: raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") except Exception as e: - raise ValueError(f"Template is not compatible with current memory structure: {str(e)}") + raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") @classmethod def load(cls, state: dict): """Load memory from dictionary object""" obj = cls() - if len(state.keys()) == 2 and "memory" in state and "template" in state: + if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state: # New format - obj.template = state["template"] + obj.prompt_template = state["prompt_template"] for key, value in state["memory"].items(): obj.memory[key] = Block(**value) else: @@ -63,14 +67,14 @@ def load(cls, state: dict): def compile(self) -> str: """Generate a string representation of the memory in-context using the Jinja2 template""" - template = Template(self.template) + template = Template(self.prompt_template) return template.render(memory=self.memory) def to_dict(self): """Convert to dictionary representation""" return { - "memory": {key: value.dict() for key, value in self.memory.items()}, - "template": self.template, + "memory": {key: value.model_dump() for key, value in self.memory.items()}, + "prompt_template": self.prompt_template, } def to_flat_dict(self): @@ -84,7 +88,7 @@ def list_block_names(self) -> List[str]: def get_block(self, name: str) -> Block: """Correct way to index into the memory.memory field, returns a Block""" if name not in self.memory: - return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})") + raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})") else: return self.memory[name] @@ -111,7 +115,8 @@ def update_block_value(self, name: str, value: Union[List[str], str]): # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. class BaseChatMemory(Memory): - def core_memory_append(self, name: str, content: str) -> Optional[str]: + + def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore """ Append to the contents of core memory. @@ -127,7 +132,7 @@ def core_memory_append(self, name: str, content: str) -> Optional[str]: self.memory.update_block_value(name=name, value=new_value) return None - def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: + def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore """ Replace the contents of core memory. To delete memories, use an empty string for new_content. diff --git a/tests/test_memory.py b/tests/test_memory.py index d641fb55b1..6581fb3119 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -100,10 +100,10 @@ def old_repr(self: Memory) -> str: def test_memory_jinja2_set_template(sample_memory: Memory): """Test setting the template for the memory""" - example_template = sample_memory.get_template() + example_template = sample_memory.get_prompt_template() # Try setting a valid template - sample_memory.set_template(template=example_template) + sample_memory.set_prompt_template(prompt_template=example_template) # Try setting an invalid template (bad jinja2) template_bad_jinja = ( @@ -115,7 +115,7 @@ def test_memory_jinja2_set_template(sample_memory: Memory): "{% endfor %" # Missing closing curly brace ) with pytest.raises(ValueError): - sample_memory.set_template(template=template_bad_jinja) + sample_memory.set_prompt_template(prompt_template=template_bad_jinja) # Try setting an invalid template (not compatible with memory structure) template_bad_memory_structure = ( @@ -127,4 +127,4 @@ def test_memory_jinja2_set_template(sample_memory: Memory): "{% endfor %}" ) with pytest.raises(ValueError): - sample_memory.set_template(template=template_bad_memory_structure) + sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure)