Skip to content

Commit

Permalink
Redis-backed Entity Memory (#2397)
Browse files Browse the repository at this point in the history
I wanted to be able to persist Entity Memory in a Redis database, so I
abstracted `ConversationEntityMemory` to allow for pluggable Entity
stores (d06f90d).

Then I implemented a Entity store that... erm... stores Entities in
Redis. By default, Entities will expire from memory after 24 hours, but
they'll be persisted for another 3 days every time they're recalled. The
idea is to give the AIs a bit of a spaced-repetition memory, but I have
yet to see if this is useful. The memory is partitioned by `session_id`
(user ID? chat channel? whatever, really) so entities from one user
don't leak to another.

While developing this, I did notice that the Entity summaries are kind
of buggy (they summarize AI-generated content and not just information
the human gave them, sometimes they add things like "No new information
provided. Existing summary remains: As stated previously, X", etc.), but
I'll tackle that later. First I wanted to get some input on this idea.
  • Loading branch information
alexiri authored Apr 7, 2023
1 parent aa439ac commit 5e83016
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 8 deletions.
6 changes: 5 additions & 1 deletion langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import ConversationEntityMemory
from langchain.memory.entity import (
ConversationEntityMemory,
ConversationEntityRedisMemory,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
Expand All @@ -23,6 +26,7 @@
"ConversationSummaryBufferMemory",
"ConversationKGMemory",
"ConversationEntityMemory",
"ConversationEntityRedisMemory",
"ConversationSummaryMemory",
"ChatMessageHistory",
"ConversationStringBufferMemory",
Expand Down
145 changes: 138 additions & 7 deletions langchain/memory/entity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict, List, Optional
import logging
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
Expand All @@ -10,20 +13,46 @@
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string

logger = logging.getLogger(__name__)

class ConversationEntityMemory(BaseChatMemory):

class BaseConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer to memory."""

human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
store: Dict[str, Optional[str]] = {}
entity_cache: List[str] = []
k: int = 3
chat_history_key: str = "history"

@abstractmethod
def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
"""Get entity value from store."""
pass

@abstractmethod
def store_set(self, key: str, value: Optional[str]) -> None:
"""Set entity value in store."""
pass

@abstractmethod
def store_del(self, key: str) -> None:
"""Delete entity value from store."""
pass

@abstractmethod
def store_exists(self, key: str) -> bool:
"""Check if entity exists in store."""
pass

@abstractmethod
def store_clear(self) -> None:
"""Delete all entities from store."""
pass

@property
def buffer(self) -> List[BaseMessage]:
return self.chat_memory.messages
Expand Down Expand Up @@ -58,7 +87,7 @@ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
entities = [w.strip() for w in output.split(",")]
entity_summaries = {}
for entity in entities:
entity_summaries[entity] = self.store.get(entity, "")
entity_summaries[entity] = self.store_get(entity, "")
self.entity_cache = entities
if self.return_messages:
buffer: Any = self.buffer[-self.k * 2 :]
Expand Down Expand Up @@ -87,16 +116,118 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)

for entity in self.entity_cache:
existing_summary = self.store.get(entity, "")
existing_summary = self.store_get(entity, "")
output = chain.predict(
summary=existing_summary,
entity=entity,
history=buffer_string,
input=input_data,
)
self.store[entity] = output.strip()
self.store_set(entity, output.strip())

def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.store = {}
self.store_clear()


class ConversationEntityMemory(BaseConversationEntityMemory):
"""Basic in-memory entity store."""

store: Dict[str, Optional[str]] = {}

def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
return self.store.get(key, default)

def store_set(self, key: str, value: Optional[str]) -> None:
self.store[key] = value

def store_del(self, key: str) -> None:
del self.store[key]

def store_exists(self, key: str) -> bool:
return key in self.store

def store_clear(self) -> None:
return self.store.clear()


class ConversationEntityRedisMemory(BaseConversationEntityMemory):
"""Redis-backed Entity store. Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""

redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3

def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)

super().__init__(*args, **kwargs)

try:
self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)

self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl

@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"

def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res

def store_set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.store_del(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)

def store_del(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")

def store_exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

def store_clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch

for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)

0 comments on commit 5e83016

Please sign in to comment.