Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rewrite cache list decorator #3384

Merged
merged 2 commits into from
Aug 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3384.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rewrite cache list decorator
131 changes: 64 additions & 67 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,105 +473,101 @@ def __get__(self, obj, objtype=None):

@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)

arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]

# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
results = {}
cached_defers = {}
missing = []

def update_results_dict(res, arg):
results[arg] = res

# list of deferreds to wait for
cached_defers = []

missing = set()

# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
def cache_get(arg):
return cache.get(arg, callback=invalidate_callback)
def arg_to_cache_key(arg):
return arg
else:
key = list(keyargs)
keylist = list(keyargs)

def cache_get(arg):
key[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback)
def arg_to_cache_key(arg):
keylist[self.list_pos] = arg
return tuple(keylist)

for arg in list_args:
try:
res = cache_get(arg)

res = cache.get(arg_to_cache_key(arg),
callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
results[arg] = res.get_result()
except KeyError:
missing.append(arg)
missing.add(arg)

if missing:
# we need an observable deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)

def complete_all(res):
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
for e in missing:
val = res.get(e, None)
deferreds_map[e].callback(val)
results[e] = val

def errback(f):
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
for e in missing:
key = arg_to_cache_key(e)
cache.invalidate(key)
deferreds_map[e].errback(f)

# return the failure, to propagate to our caller.
return f

args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
args_to_call[self.list_name] = list(missing)

ret_d = defer.maybeDeferred(
cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
)

ret_d = ObservableDeferred(ret_d)

# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)

observer = ObservableDeferred(observer)

if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)

def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)

def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))

res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)

cached_defers[arg] = res
).addCallbacks(complete_all, errback))

if cached_defers:
def update_results_dict(res):
results.update(res)
return results

return logcontext.make_deferred_yieldable(defer.gatherResults(
list(cached_defers.values()),
d = defer.gatherResults(
cached_defers,
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
).addCallbacks(
lambda _: results,
unwrapFirstError
))
)
return logcontext.make_deferred_yieldable(d)
else:
return results

Expand Down Expand Up @@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache.

Args:
cache (Cache): The underlying cache to use.
cached_method_name (str): The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
Expand Down
101 changes: 101 additions & 0 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,104 @@ def fn(self, arg1, arg2=2, arg3=3):
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()


class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
defer.returnValue(self.mock(args1, arg2))

with logcontext.LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
obj.mock.return_value = {10: 'fish', 20: 'chips'}
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
r = yield d1
self.assertEqual(
logcontext.LoggingContext.current_context(),
c1
)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = {30: 'peas'}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
self.assertEqual(r, {20: 'chips', 30: 'peas'})
obj.mock.reset_mock()

# all the values should now be cached
r = yield obj.fn(10, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(20, 2)
self.assertEqual(r, 'chips')
r = yield obj.fn(30, 2)
self.assertEqual(r, 'peas')
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})

@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
defer.returnValue(self.mock(args1, arg2))

obj = Cls()
invalidate0 = mock.Mock()
invalidate1 = mock.Mock()

# cache miss
obj.mock.return_value = {10: 'fish', 20: 'chips'}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()

# cache hit
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
obj.mock.assert_not_called()
self.assertEqual(r2, {10: 'fish', 20: 'chips'})

invalidate0.assert_not_called()
invalidate1.assert_not_called()

# now if we invalidate the keys, both invalidations should get called
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()