From 17fe52baf4cc56703c07edcb104fda9309aba164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Glauco=20Cust=C3=B3dio?= Date: Thu, 10 Aug 2023 19:31:28 +0100 Subject: [PATCH 1/2] add ttl to RedisCache --- libs/langchain/langchain/cache.py | 26 ++++++++++++++++--- .../cache/test_redis_cache.py | 12 +++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index 2136acaad18a6..fe036ad55089c 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -216,10 +216,25 @@ def __init__(self, database_path: str = ".langchain.db"): class RedisCache(BaseCache): """Cache that uses Redis as a backend.""" - # TODO - implement a TTL policy in Redis - - def __init__(self, redis_: Any): - """Initialize by passing in Redis instance.""" + def __init__(self, redis_: Any, ttl: int = 0): + """ + Initialize an instance of RedisCache. + + This method initializes an object with Redis caching capabilities. + It takes a `redis_` parameter, which should be an instance of a Redis + client class, allowing the object to interact with a Redis + server for caching purposes. + + Parameters: + redis_ (Any): An instance of a Redis client class + (e.g., redis.Redis) used for caching. + This allows the object to communicate with a + Redis server for caching operations. + ttl (int, optional): Time-to-live (TTL) for cached items in seconds. + If provided, it sets + the default time duration for how long cached items will remain valid. + Defaults to 0, indicating no automatic expiration. + """ try: from redis import Redis except ImportError: @@ -230,6 +245,7 @@ def __init__(self, redis_: Any): if not isinstance(redis_, Redis): raise ValueError("Please pass in Redis object.") self.redis = redis_ + self.ttl = ttl def _key(self, prompt: str, llm_string: str) -> str: """Compute key from prompt and llm_string""" @@ -267,6 +283,8 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N str(idx): generation.text for idx, generation in enumerate(return_val) }, ) + if self.ttl > 0: + self.redis.expire(key, self.ttl) def clear(self, **kwargs: Any) -> None: """Clear cache. If `asynchronous` is True, flush asynchronously.""" diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py index 7d43c9a4058e0..c38704c2de6c7 100644 --- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py @@ -11,6 +11,18 @@ REDIS_TEST_URL = "redis://localhost:6379" +def test_redis_cache_ttl() -> None: + import time + + import redis + + langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1) + langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")]) + key = langchain.llm_cache._key("foo", "bar") + time.sleep(1.1) + assert langchain.llm_cache.redis.hgetall(key) == {} + + def test_redis_cache() -> None: import redis From 95a4c4b291bc579e6eb17a8c09bf803aeeb4ade1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Glauco=20Cust=C3=B3dio?= Date: Fri, 11 Aug 2023 10:59:42 +0100 Subject: [PATCH 2/2] improvements --- libs/langchain/langchain/cache.py | 29 +++++++++++-------- .../cache/test_redis_cache.py | 5 +--- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index fe036ad55089c..a78c0608011df 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -216,7 +216,7 @@ def __init__(self, database_path: str = ".langchain.db"): class RedisCache(BaseCache): """Cache that uses Redis as a backend.""" - def __init__(self, redis_: Any, ttl: int = 0): + def __init__(self, redis_: Any, *, ttl: Optional[int] = None): """ Initialize an instance of RedisCache. @@ -231,9 +231,9 @@ def __init__(self, redis_: Any, ttl: int = 0): This allows the object to communicate with a Redis server for caching operations. ttl (int, optional): Time-to-live (TTL) for cached items in seconds. - If provided, it sets - the default time duration for how long cached items will remain valid. - Defaults to 0, indicating no automatic expiration. + If provided, it sets the time duration for how long cached + items will remain valid. If not provided, cached items will not + have an automatic expiration. """ try: from redis import Redis @@ -277,14 +277,19 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N return # Write to a Redis HASH key = self._key(prompt, llm_string) - self.redis.hset( - key, - mapping={ - str(idx): generation.text for idx, generation in enumerate(return_val) - }, - ) - if self.ttl > 0: - self.redis.expire(key, self.ttl) + + with self.redis.pipeline() as pipe: + pipe.hset( + key, + mapping={ + str(idx): generation.text + for idx, generation in enumerate(return_val) + }, + ) + if self.ttl is not None: + pipe.expire(key, self.ttl) + + pipe.execute() def clear(self, **kwargs: Any) -> None: """Clear cache. If `asynchronous` is True, flush asynchronously.""" diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py index c38704c2de6c7..5d51a12e8e1fb 100644 --- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py @@ -12,15 +12,12 @@ def test_redis_cache_ttl() -> None: - import time - import redis langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1) langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")]) key = langchain.llm_cache._key("foo", "bar") - time.sleep(1.1) - assert langchain.llm_cache.redis.hgetall(key) == {} + assert langchain.llm_cache.redis.pttl(key) > 0 def test_redis_cache() -> None: