Skip to content

Commit

Permalink
fix: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker committed Sep 3, 2024
1 parent d0aec04 commit d8e2dcb
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
2 changes: 1 addition & 1 deletion memgpt/schemas/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Block(BaseBlock):
"""Block of the LLM context"""

id: str = BaseBlock.generate_id_field()
value: str = Field(..., description="Value of the block.")
value: Union[str, List[str]] = Field(..., description="Value of the block.")


class Human(Block):
Expand Down
39 changes: 22 additions & 17 deletions memgpt/schemas/memory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from jinja2 import Template, TemplateSyntaxError
from pydantic import BaseModel, Field

# Forward referencing to avoid circular import with Agent -> Memory -> Agent
if TYPE_CHECKING:
from memgpt.agent import Agent

from memgpt.schemas.block import Block


Expand All @@ -13,7 +17,7 @@ class Memory(BaseModel, validate_assignment=True):
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")

# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
template: str = Field(
prompt_template: str = Field(
default="{% for section, module in memory.items() %}"
'<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n'
"{{ module.value }}\n"
Expand All @@ -23,36 +27,36 @@ class Memory(BaseModel, validate_assignment=True):
description="Jinja2 template for compiling memory module into a prompt string",
)

def get_template(self) -> str:
def get_prompt_template(self) -> str:
"""Return the current Jinja2 template string."""
return str(self.template)
return str(self.prompt_template)

def set_template(self, template: str):
def set_prompt_template(self, prompt_template: str):
"""
Set a new Jinja2 template string.
Validates the template syntax and compatibility with current memory structure.
"""
try:
# Validate Jinja2 syntax
Template(template)
Template(prompt_template)

# Validate compatibility with current memory structure
test_render = Template(template).render(memory=self.memory)
test_render = Template(prompt_template).render(memory=self.memory)

# If we get here, the template is valid and compatible
self.template = template
self.prompt_template = prompt_template
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
except Exception as e:
raise ValueError(f"Template is not compatible with current memory structure: {str(e)}")
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")

@classmethod
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
if len(state.keys()) == 2 and "memory" in state and "template" in state:
if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state:
# New format
obj.template = state["template"]
obj.prompt_template = state["prompt_template"]
for key, value in state["memory"].items():
obj.memory[key] = Block(**value)
else:
Expand All @@ -63,14 +67,14 @@ def load(cls, state: dict):

def compile(self) -> str:
"""Generate a string representation of the memory in-context using the Jinja2 template"""
template = Template(self.template)
template = Template(self.prompt_template)
return template.render(memory=self.memory)

def to_dict(self):
"""Convert to dictionary representation"""
return {
"memory": {key: value.dict() for key, value in self.memory.items()},
"template": self.template,
"memory": {key: value.model_dump() for key, value in self.memory.items()},
"prompt_template": self.prompt_template,
}

def to_flat_dict(self):
Expand All @@ -84,7 +88,7 @@ def list_block_names(self) -> List[str]:
def get_block(self, name: str) -> Block:
"""Correct way to index into the memory.memory field, returns a Block"""
if name not in self.memory:
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
else:
return self.memory[name]

Expand All @@ -111,7 +115,8 @@ def update_block_value(self, name: str, value: Union[List[str], str]):

# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BaseChatMemory(Memory):
def core_memory_append(self, name: str, content: str) -> Optional[str]:

def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
"""
Append to the contents of core memory.
Expand All @@ -127,7 +132,7 @@ def core_memory_append(self, name: str, content: str) -> Optional[str]:
self.memory.update_block_value(name=name, value=new_value)
return None

def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def old_repr(self: Memory) -> str:
def test_memory_jinja2_set_template(sample_memory: Memory):
"""Test setting the template for the memory"""

example_template = sample_memory.get_template()
example_template = sample_memory.get_prompt_template()

# Try setting a valid template
sample_memory.set_template(template=example_template)
sample_memory.set_prompt_template(prompt_template=example_template)

# Try setting an invalid template (bad jinja2)
template_bad_jinja = (
Expand All @@ -115,7 +115,7 @@ def test_memory_jinja2_set_template(sample_memory: Memory):
"{% endfor %" # Missing closing curly brace
)
with pytest.raises(ValueError):
sample_memory.set_template(template=template_bad_jinja)
sample_memory.set_prompt_template(prompt_template=template_bad_jinja)

# Try setting an invalid template (not compatible with memory structure)
template_bad_memory_structure = (
Expand All @@ -127,4 +127,4 @@ def test_memory_jinja2_set_template(sample_memory: Memory):
"{% endfor %}"
)
with pytest.raises(ValueError):
sample_memory.set_template(template=template_bad_memory_structure)
sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure)

0 comments on commit d8e2dcb

Please sign in to comment.