Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Minor @cachedList enhancements #9975

Merged
merged 1 commit into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9975.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Minor enhancements to the `@cachedList` descriptor.
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ async def get_device_list_last_stream_id_for_remote(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
Expand All @@ -497,7 +497,7 @@ async def _get_bare_e2e_cross_signing_keys_bulk(
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
user_ids: List[str],
user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
Expand Down
13 changes: 5 additions & 8 deletions synapse/storage/databases/main/user_erasure_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Iterable

from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList

Expand All @@ -37,21 +39,16 @@ async def is_user_erased(self, user_id: str) -> bool:
return bool(result)

@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure

Args:
user_ids (iterable[str]): full user id to check
user_ids: full user ids to check

Returns:
dict[str, bool]:
for each user, whether the user has requested erasure.
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
Comment on lines -51 to -53
Copy link
Member

@anoadragon453 anoadragon453 May 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking that this was OK to remove because cachedList effectively does these two things for us?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh sorry, yes I meant to add a comment to say exactly that.


rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
Expand Down
14 changes: 8 additions & 6 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ def _wrapped(*args, **kwargs):
class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.

Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
Given an iterable of keys it looks in the cache to find any hits, then passes
the tuple of missing keys to the wrapped function.

Once wrapped, the function returns a Deferred which resolves to the list
of results.
Expand Down Expand Up @@ -437,7 +437,9 @@ def errback(f):
return f

args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
# copy the missing set before sending it to the callee, to guard against
# modification.
args_to_call[self.list_name] = tuple(missing)

cached_defers.append(
defer.maybeDeferred(
Expand Down Expand Up @@ -522,14 +524,14 @@ def cachedList(

Used to do batch lookups for an already created cache. A single argument
is specified as a list that is iterated through to lookup keys in the
original cache. A new list consisting of the keys that weren't in the cache
get passed to the original function, the result of which is stored in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, the result of which is stored in the
cache.

Args:
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
list_name: The name of the argument that is the list to use to
list_name: The name of the argument that is the iterable to use to
do batch lookups in the cache.
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
Expand Down
17 changes: 14 additions & 3 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,18 +666,20 @@ async def list_fn(self, args1, arg2):
with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}

# start the lookup off
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
obj.mock.assert_called_once_with((10, 20), 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = {30: "peas"}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
obj.mock.assert_called_once_with((30,), 2)
self.assertEqual(r, {20: "chips", 30: "peas"})
obj.mock.reset_mock()

Expand All @@ -692,6 +694,15 @@ async def list_fn(self, args1, arg2):
obj.mock.assert_not_called()
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})

# we should also be able to use a (single-use) iterable, and should
# deduplicate the keys
obj.mock.reset_mock()
obj.mock.return_value = {40: "gravy"}
iterable = (x for x in [10, 40, 40])
r = yield obj.list_fn(iterable, 2)
obj.mock.assert_called_once_with((40,), 2)
self.assertEqual(r, {10: "fish", 40: "gravy"})

@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
Expand All @@ -717,7 +728,7 @@ async def list_fn(self, args1, arg2):
# cache miss
obj.mock.return_value = {10: "fish", 20: "chips"}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
obj.mock.assert_called_once_with((10, 20), 2)
self.assertEqual(r1, {10: "fish", 20: "chips"})
obj.mock.reset_mock()

Expand Down