Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for locking a cache key, to avoid overlapping expensive operations #256

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions aries_cloudagent/cache/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
"""Abstract base classes for cache."""

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Sequence, Text, Union

from ..error import BaseError


class CacheError(BaseError):
    """Common base exception for errors related to caching."""


class BaseCache(ABC):
"""Abstract cache interface."""

def __init__(self):
    """Initialize the cache instance."""
    # Maps a cache key to the lock currently registered for it:
    # `acquire` adds entries and `release` removes them.
    self._key_locks = {}

@abstractmethod
async def get(self, key: Text):
"""
Expand Down Expand Up @@ -46,6 +57,117 @@ async def clear(self, key: Text):
async def flush(self):
"""Remove all items from the cache."""

def acquire(self, key: Text):
    """
    Create and return a lock for the given cache key.

    The first lock requested for a key is registered as the holder for
    that key. Any lock requested while a holder is registered gets the
    holder assigned as its parent.

    Args:
        key: The cache key to lock

    Returns:
        The newly created `CacheKeyLock`

    """
    lock = CacheKeyLock(self, key)
    holder = self._key_locks.setdefault(key, lock)
    if holder is not lock:
        # another lock already holds this key: chain the new one to it
        lock.parent = holder
    return lock

def release(self, key: Text):
    """
    Drop the lock registered for a cache key.

    Does nothing when no lock is registered for the key.

    Args:
        key: The cache key whose lock entry is removed

    """
    self._key_locks.pop(key, None)

def __repr__(self) -> str:
    """
    Return a human readable representation of the cache.

    Returns:
        The concrete class name wrapped in angle brackets, e.g. `<BasicCache>`

    """
    return f"<{self.__class__.__name__}>"


class CacheKeyLock:
    """
    A lock on a particular cache key.

    Used to prevent multiple async threads from generating
    or querying the same semi-expensive data. Not thread safe.

    The lock wraps a future carrying the value produced for the key. A lock
    may be given a parent lock; it then receives the parent's value once the
    parent completes with a truthy result.
    """

    def __init__(self, cache: "BaseCache", key: Text):
        """
        Initialize the key lock.

        Args:
            cache: The cache owning this lock; its `get`, `set` and
                `release` methods are called by the lock
            key: The cache key being locked

        """
        self.cache = cache
        # Exception that escaped the `async with` body, if any.
        self.exception: BaseException = None
        self.key = key
        self.released = False
        self._future: asyncio.Future = asyncio.get_event_loop().create_future()
        # Assigned last: `__del__` uses its presence to detect a fully
        # initialized instance.
        self._parent: "CacheKeyLock" = None

    @property
    def done(self) -> bool:
        """Accessor for the done state."""
        return self._future.done()

    @property
    def future(self) -> asyncio.Future:
        """Fetch the result in the form of an awaitable future."""
        return self._future

    @property
    def result(self) -> Any:
        """Fetch the current result, or `None` if none has been set yet."""
        if self.done:
            return self._future.result()
        return None

    @property
    def parent(self) -> "CacheKeyLock":
        """Accessor for the parent key lock, if any."""
        return self._parent

    @parent.setter
    def parent(self, parent: "CacheKeyLock"):
        """Set the parent lock and subscribe to its completion."""
        self._parent = parent
        parent._future.add_done_callback(self._handle_parent_done)

    def _handle_parent_done(self, fut: asyncio.Future):
        """
        Handle completion of parent's future.

        A truthy parent result is copied to this lock. Nothing is done when
        this lock already has a result, or when the parent future was
        cancelled or failed, so the callback never raises.
        """
        if self._future.done() or fut.cancelled() or fut.exception() is not None:
            return
        result = fut.result()
        if result:
            self._future.set_result(result)

    async def set_result(self, value: Any, ttl: int = None):
        """
        Set the result, updating the cache and any waiters.

        Args:
            value: The value to hand to waiters and store in the cache
            ttl: Optional time-to-live passed through to the cache

        Raises:
            CacheError: If a result has already been set on this lock

        """
        # Checked regardless of `value`, so that a duplicate falsy value is
        # reported as CacheError rather than asyncio.InvalidStateError.
        if self.done:
            raise CacheError("Result already set")
        self._future.set_result(value)
        # Only write to the cache when no parent is still working on the key.
        if not self._parent or self._parent.done:
            await self.cache.set(self.key, value, ttl)

    def __await__(self):
        """
        Wait for a result to be produced.

        The wait is shielded so that cancelling one waiting task does not
        cancel the future shared with the lock owner and the other waiters.
        """
        return (yield from asyncio.shield(self._future))

    async def __aenter__(self):
        """
        Async context manager entry.

        Waits for the parent lock, if any. When the parent produced no
        value, the cache is queried and a found value becomes the result.

        Returns:
            This lock instance

        """
        result = None
        if self.parent:
            result = await self.parent
            if result:
                await self  # wait for parent's done handler to complete
        if not result:
            found = await self.cache.get(self.key)
            # Skip if a result is already present (e.g. the lock is re-entered).
            if found and not self.done:
                self._future.set_result(found)
        return self

    def release(self):
        """Release the cache lock."""
        # Only a lock without a parent releases the key on the cache.
        if not self.parent and not self.released:
            self.cache.release(self.key)
            self.released = True

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Async context manager exit.

        `None` is returned to any waiters if no value is produced.
        """
        if exc_val:
            self.exception = exc_val
        if not self.done:
            self._future.set_result(None)
        self.release()

    def __del__(self):
        """Handle deletion."""
        # If `__init__` failed part-way there is nothing to release.
        if hasattr(self, "_parent"):
            self.release()
24 changes: 8 additions & 16 deletions aries_cloudagent/cache/basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Basic in-memory cache implementation."""

from datetime import datetime, timedelta

import time
from typing import Any, Sequence, Text, Union

from .base import BaseCache
Expand All @@ -12,17 +11,17 @@ class BasicCache(BaseCache):

def __init__(self):
    """Initialize a `BasicCache` instance."""
    super().__init__()
    # looks like { "key": { "expires": <time.perf_counter() deadline or None>, "value": <val> } }
    self._cache = {}

def _remove_expired_cache_items(self):
    """
    Remove all expired items from cache.

    An item is expired when its `expires` value is not `None` and the
    current `time.perf_counter()` reading has reached it.
    """
    # Read the clock once so every item is judged against the same instant.
    now = time.perf_counter()
    # Collect keys first: the dict must not change size while iterated.
    expired_keys = [
        key
        for key, item in self._cache.items()
        if item["expires"] is not None and now >= item["expires"]
    ]
    for key in expired_keys:
        del self._cache[key]

Expand Down Expand Up @@ -53,16 +52,9 @@ async def set(self, keys: Union[Text, Sequence[Text]], value: Any, ttl: int = No

"""
self._remove_expired_cache_items()
now = datetime.now()
expires_ts = None
if ttl:
expires = now + timedelta(seconds=ttl)
expires_ts = expires.timestamp()
for key in ([keys] if isinstance(keys, Text) else keys):
self._cache[key] = {
"expires": expires_ts,
"value": value
}
expires_ts = time.perf_counter() + ttl if ttl else None
for key in [keys] if isinstance(keys, Text) else keys:
self._cache[key] = {"expires": expires_ts, "value": value}

async def clear(self, key: Text):
"""
Expand Down
98 changes: 97 additions & 1 deletion aries_cloudagent/cache/tests/test_basic_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from asyncio import sleep
from asyncio import ensure_future, sleep, wait_for
import pytest

from ..base import CacheError
from ..basic import BasicCache


Expand Down Expand Up @@ -69,3 +70,98 @@ async def test_set_expires_multi(self, cache):
async def test_flush(self, cache):
await cache.flush()
assert cache._cache == {}

@pytest.mark.asyncio
async def test_clear(self, cache):
    """A key that was set and then cleared reads back as `None`."""
    await cache.set("key", "value")
    await cache.clear("key")
    assert await cache.get("key") is None

@pytest.mark.asyncio
async def test_acquire_release(self, cache):
    """Entering a lock registers its key; exiting unregisters it."""
    key = "test_key"
    key_lock = cache.acquire(key)

    await key_lock.__aenter__()
    assert key in cache._key_locks

    await key_lock.__aexit__(None, None, None)
    assert key not in cache._key_locks
    # no result was set, so nothing was written to the cache
    assert await cache.get(key) is None

@pytest.mark.asyncio
async def test_acquire_with_future(self, cache):
    """A result set on a lock is visible via await, `result` and `future`."""
    key, expected = "test_key", "test_result"
    key_lock = cache.acquire(key)

    await key_lock.__aenter__()
    await key_lock.set_result(expected)
    await key_lock.__aexit__(None, None, None)

    assert await wait_for(key_lock, 1) == expected
    assert key_lock.done
    assert key_lock.result == expected
    assert key_lock.future.result() == expected

@pytest.mark.asyncio
async def test_acquire_release_with_waiter(self, cache):
    """A second lock on the same key receives the first lock's result."""
    key, expected = "test_key", "test_result"
    first = cache.acquire(key)
    await first.__aenter__()

    second = cache.acquire(key)
    assert first.parent is None
    assert second.parent is first
    await first.set_result(expected)
    await first.__aexit__(None, None, None)

    assert await cache.get(key) == expected
    assert await wait_for(first, 1) == expected
    assert await wait_for(second, 1) == expected

@pytest.mark.asyncio
async def test_duplicate_set(self, cache):
    """Setting a result twice on the same lock raises `CacheError`."""
    # NOTE(review): nesting below is reconstructed from an unindented
    # listing; confirm against the original file.
    test_key = "test_key"
    test_result = "test_result"
    lock = cache.acquire(test_key)
    async with lock:
        assert not lock.done
        await lock.set_result(test_result)
        # second attempt with the same value must be rejected
        with pytest.raises(CacheError):
            await lock.set_result(test_result)
        assert lock.done
    assert test_key not in cache._key_locks

@pytest.mark.asyncio
async def test_populated(self, cache):
    """With the key already cached, both locks report the cached value on entry."""
    # NOTE(review): nesting below is reconstructed from an unindented
    # listing; confirm against the original file.
    test_key = "test_key"
    test_result = "test_result"
    await cache.set(test_key, test_result)
    lock = cache.acquire(test_key)
    lock2 = cache.acquire(test_key)

    async def check():
        async with lock as entry:
            async with lock2 as entry2:
                assert entry2.done  # parent value located
                assert entry2.result == test_result
            assert entry.done
            assert entry.result == test_result
        assert test_key not in cache._key_locks

    # bound the whole scenario to one second
    await wait_for(check(), 1)

@pytest.mark.asyncio
async def test_acquire_exception(self, cache):
    """An exception raised inside the lock context is recorded and re-raised.

    The lock ends up done with a `None` result.
    """
    test_key = "test_key"
    lock = cache.acquire(test_key)
    with pytest.raises(ValueError):
        async with lock:
            raise ValueError
    assert isinstance(lock.exception, ValueError)
    assert lock.done
    assert lock.result is None

@pytest.mark.asyncio
async def test_repr(self, cache):
    """`repr()` of a cache instance produces a string."""
    text = repr(cache)
    assert isinstance(text, str)
6 changes: 5 additions & 1 deletion aries_cloudagent/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ async def setup(self):
# at the class level (!) should not be performed multiple times
collector.wrap(
ConnectionManager,
("get_connection_target", "fetch_did_document", "find_connection"),
(
"get_connection_target",
"fetch_did_document",
"find_message_connection",
),
)

async def start(self) -> None:
Expand Down
Loading