Skip to content

Commit

Permalink
Merge pull request #256 from andrewwhitehead/feature/cache-lock
Browse files Browse the repository at this point in the history
Add support for locking a cache key, to avoid overlapping expensive operations
  • Loading branch information
nrempel authored Nov 8, 2019
2 parents 5ee3ebc + da298d6 commit ca0e463
Show file tree
Hide file tree
Showing 6 changed files with 356 additions and 72 deletions.
122 changes: 122 additions & 0 deletions aries_cloudagent/cache/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
"""Abstract base classes for cache."""

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Sequence, Text, Union

from ..error import BaseError


class CacheError(BaseError):
"""Base class for cache-related errors."""


class BaseCache(ABC):
"""Abstract cache interface."""

def __init__(self):
"""Initialize the cache instance."""
self._key_locks = {}

@abstractmethod
async def get(self, key: Text):
"""
Expand Down Expand Up @@ -46,6 +57,117 @@ async def clear(self, key: Text):
async def flush(self):
"""Remove all items from the cache."""

def acquire(self, key: Text):
"""Acquire a lock on a given cache key."""
result = CacheKeyLock(self, key)
first = self._key_locks.setdefault(key, result)
if first is not result:
result.parent = first
return result

def release(self, key: Text):
"""Release the lock on a given cache key."""
if key in self._key_locks:
del self._key_locks[key]

def __repr__(self) -> str:
"""Human readable representation of `BaseStorageRecordSearch`."""
return "<{}>".format(self.__class__.__name__)


class CacheKeyLock:
"""
A lock on a particular cache key.
Used to prevent multiple async threads from generating
or querying the same semi-expensive data. Not thread safe.
"""

def __init__(self, cache: BaseCache, key: Text):
"""Initialize the key lock."""
self.cache = cache
self.exception: BaseException = None
self.key = key
self.released = False
self._future: asyncio.Future = asyncio.get_event_loop().create_future()
self._parent: "CacheKeyLock" = None

@property
def done(self) -> bool:
"""Accessor for the done state."""
return self._future.done()

@property
def future(self) -> asyncio.Future:
"""Fetch the result in the form of an awaitable future."""
return self._future

@property
def result(self) -> Any:
"""Fetch the current result, if any."""
if self.done:
return self._future.result()

@property
def parent(self) -> "CacheKeyLock":
"""Accessor for the parent key lock, if any."""
return self._parent

@parent.setter
def parent(self, parent: "CacheKeyLock"):
"""Setter for the parent lock."""
self._parent = parent
parent._future.add_done_callback(self._handle_parent_done)

def _handle_parent_done(self, fut: asyncio.Future):
"""Handle completion of parent's future."""
result = fut.result()
if result:
self._future.set_result(fut.result())

async def set_result(self, value: Any, ttl: int = None):
"""Set the result, updating the cache and any waiters."""
if self.done and value:
raise CacheError("Result already set")
self._future.set_result(value)
if not self._parent or self._parent.done:
await self.cache.set(self.key, value, ttl)

def __await__(self):
"""Wait for a result to be produced."""
return (yield from self._future)

async def __aenter__(self):
"""Async context manager entry."""
result = None
if self.parent:
result = await self.parent
if result:
await self # wait for parent's done handler to complete
if not result:
found = await self.cache.get(self.key)
if found:
self._future.set_result(found)
return self

def release(self):
"""Release the cache lock."""
if not self.parent and not self.released:
self.cache.release(self.key)
self.released = True

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Async context manager exit.
`None` is returned to any waiters if no value is produced.
"""
if exc_val:
self.exception = exc_val
if not self.done:
self._future.set_result(None)
self.release()

def __del__(self):
"""Handle deletion."""
self.release()
24 changes: 8 additions & 16 deletions aries_cloudagent/cache/basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Basic in-memory cache implementation."""

from datetime import datetime, timedelta

import time
from typing import Any, Sequence, Text, Union

from .base import BaseCache
Expand All @@ -12,17 +11,17 @@ class BasicCache(BaseCache):

def __init__(self):
"""Initialize a `BasicCache` instance."""

super().__init__()
# looks like { "key": { "expires": <epoch timestamp>, "value": <val> } }
self._cache = {}

def _remove_expired_cache_items(self):
"""Remove all expired items from cache."""
for key in self._cache.copy(): # iterate copy, del from original
cache_item_expiry = self._cache[key]["expires"]
for key, val in self._cache.copy().items(): # iterate copy, del from original
cache_item_expiry = val["expires"]
if cache_item_expiry is None:
continue
now = datetime.now().timestamp()
now = time.perf_counter()
if now >= cache_item_expiry:
del self._cache[key]

Expand Down Expand Up @@ -53,16 +52,9 @@ async def set(self, keys: Union[Text, Sequence[Text]], value: Any, ttl: int = No
"""
self._remove_expired_cache_items()
now = datetime.now()
expires_ts = None
if ttl:
expires = now + timedelta(seconds=ttl)
expires_ts = expires.timestamp()
for key in ([keys] if isinstance(keys, Text) else keys):
self._cache[key] = {
"expires": expires_ts,
"value": value
}
expires_ts = time.perf_counter() + ttl if ttl else None
for key in [keys] if isinstance(keys, Text) else keys:
self._cache[key] = {"expires": expires_ts, "value": value}

async def clear(self, key: Text):
"""
Expand Down
98 changes: 97 additions & 1 deletion aries_cloudagent/cache/tests/test_basic_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from asyncio import sleep
from asyncio import ensure_future, sleep, wait_for
import pytest

from ..base import CacheError
from ..basic import BasicCache


Expand Down Expand Up @@ -69,3 +70,98 @@ async def test_set_expires_multi(self, cache):
async def test_flush(self, cache):
await cache.flush()
assert cache._cache == {}

@pytest.mark.asyncio
async def test_clear(self, cache):
await cache.set("key", "value")
await cache.clear("key")
item = await cache.get("key")
assert item is None

@pytest.mark.asyncio
async def test_acquire_release(self, cache):
test_key = "test_key"
lock = cache.acquire(test_key)
await lock.__aenter__()
assert test_key in cache._key_locks
await lock.__aexit__(None, None, None)
assert test_key not in cache._key_locks
assert await cache.get(test_key) is None

@pytest.mark.asyncio
async def test_acquire_with_future(self, cache):
test_key = "test_key"
test_result = "test_result"
lock = cache.acquire(test_key)
await lock.__aenter__()
await lock.set_result(test_result)
await lock.__aexit__(None, None, None)
assert await wait_for(lock, 1) == test_result
assert lock.done
assert lock.result == test_result
assert lock.future.result() == test_result

@pytest.mark.asyncio
async def test_acquire_release_with_waiter(self, cache):
test_key = "test_key"
test_result = "test_result"
lock = cache.acquire(test_key)
await lock.__aenter__()

lock2 = cache.acquire(test_key)
assert lock.parent is None
assert lock2.parent is lock
await lock.set_result(test_result)
await lock.__aexit__(None, None, None)

assert await cache.get(test_key) == test_result
assert await wait_for(lock, 1) == test_result
assert await wait_for(lock2, 1) == test_result

@pytest.mark.asyncio
async def test_duplicate_set(self, cache):
test_key = "test_key"
test_result = "test_result"
lock = cache.acquire(test_key)
async with lock:
assert not lock.done
await lock.set_result(test_result)
with pytest.raises(CacheError):
await lock.set_result(test_result)
assert lock.done
assert test_key not in cache._key_locks

@pytest.mark.asyncio
async def test_populated(self, cache):
test_key = "test_key"
test_result = "test_result"
await cache.set(test_key, test_result)
lock = cache.acquire(test_key)
lock2 = cache.acquire(test_key)

async def check():
async with lock as entry:
async with lock2 as entry2:
assert entry2.done # parent value located
assert entry2.result == test_result
assert entry.done
assert entry.result == test_result
assert test_key not in cache._key_locks

await wait_for(check(), 1)

@pytest.mark.asyncio
async def test_acquire_exception(self, cache):
test_key = "test_key"
test_result = "test_result"
lock = cache.acquire(test_key)
with pytest.raises(ValueError):
async with lock:
raise ValueError
assert isinstance(lock.exception, ValueError)
assert lock.done
assert lock.result is None

@pytest.mark.asyncio
async def test_repr(self, cache):
assert isinstance(repr(cache), str)
6 changes: 5 additions & 1 deletion aries_cloudagent/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ async def setup(self):
# at the class level (!) should not be performed multiple times
collector.wrap(
ConnectionManager,
("get_connection_target", "fetch_did_document", "find_connection"),
(
"get_connection_target",
"fetch_did_document",
"find_message_connection",
),
)

async def start(self) -> None:
Expand Down
Loading

0 comments on commit ca0e463

Please sign in to comment.