Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
add caching
Browse files Browse the repository at this point in the history
  • Loading branch information
uhoreg committed Dec 10, 2019
1 parent 0cc7bd8 commit f1ea148
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 91 deletions.
25 changes: 12 additions & 13 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,25 +280,24 @@ def get_cross_signing_keys_from_cache(self, query, from_user_id):
defer.Deferred[dict[str, dict[str, dict]]]: map from
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
user_signing_keys = {}

users = list(query)
user_ids = list(query)

master_keys = yield self.store.get_e2e_cross_signing_keys_bulk(
users, "master", from_user_id
)
self_signing_keys = yield self.store.get_e2e_cross_signing_keys_bulk(
users, "self_signing", from_user_id
)
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)

if from_user_id in users:
for user_id, user_info in keys.items():
if "master" in user_info:
master_keys[user_id] = user_info["master"]
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]

if from_user_id in keys and "user_signing" in keys[from_user_id]:
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
key = yield self.store.get_e2e_cross_signing_key(
from_user_id, "user_signing", from_user_id
)
if key:
user_signing_keys[from_user_id] = key
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]

return {
"master_keys": master_keys,
Expand Down
234 changes: 157 additions & 77 deletions synapse/storage/data_stores/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable

from six import iteritems

Expand All @@ -25,7 +24,7 @@

from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedList


class EndToEndKeyWorkerStore(SQLBaseStore):
Expand Down Expand Up @@ -337,108 +336,180 @@ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
from_user_id,
)

def _get_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
users: Iterable[str],
key_type: str,
from_user_id: str = None,
) -> dict:
"""Returns the cross-signing keys for a set of users.
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
pass

@cachedList(
cached_method_name="_get_bare_e2e_cross_signing_keys",
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(self, user_ids: list) -> dict:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
users (iterable[str]): the users whose keys are being requested
key_type (str): the type of keys that are being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the keys will be included in the result
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict]: mapping from user ID to key data. If a user's
cross-signing key was not found, their user ID will not be in the
dict.
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.
"""
return self.db.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)

def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, txn: Connection, user_ids: list,
) -> dict:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
# convert to a list if needed, so that we can slice it
if not isinstance(users, list):
users = list(users)
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.
"""
result = {}

batch_size = 100
chunks = [users[i : i + batch_size] for i in range(0, len(users), batch_size)]
chunks = [
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT k.user_id, k.keydata, k.stream_id
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, MAX(stream_id) AS stream_id
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
WHERE keytype = ?
GROUP BY user_id) s
USING (user_id, stream_id)
WHERE k.user_id IN (%s) AND k.keytype = ?
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
WHERE k.user_id IN (%s)
""" % (
",".join("?" for u in user_chunk)
",".join("?" for u in user_chunk),
)
query_params = [key_type]
query_params = []
query_params.extend(user_chunk)
query_params.append(key_type)

txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)

devices = {}
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = json.loads(row["keydata"])
result[user_id] = key
user_info = result.setdefault(user_id, {})
user_info[key_type] = key

return result

def _get_e2e_cross_signing_signatures_txn(
self, txn: Connection, keys: dict, from_user_id: str,
) -> dict:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
key data. This dict will be modified to add signatures.
from_user_id (str): fetch the signatures made by this user
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. The return value will be the same as the keys argument,
with the modifications included.
"""

# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
devices = {}

for user_id, user_info in keys.items():
for key_type, key in user_info.items():
device_id = None
for k in key["keys"].values():
devices[user_id] = k

if devices and from_user_id:
# if we're asked to get signatures, and we have any devices to get
# signatures for, fetch the signatures
sql = """
SELECT target_user_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
device_id = k
devices[(user_id, device_id)] = key_type

device_list = list(devices)

# split into batches
batch_size = 100
chunks = [
device_list[i : i + batch_size]
for i in range(0, len(device_list), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
query_params = [from_user_id]
for item in devices.items():
# item is a (user_id, device_id) tuple
query_params.extend(item)

txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)

# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_user_key = result[target_user_id]
signatures = target_user_key.setdefault("signatures", {})
user_sigs = signatures.setdefault(from_user_id, {})
user_sigs[key_id] = row["signature"]
)
query_params = [from_user_id]
for item in devices:
# item is a (user_id, device_id) tuple
query_params.extend(item)

return result
txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)

# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
user_info = keys[target_user_id] = keys[target_user_id].copy()
target_user_key = user_info[key_type] = user_info[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
else:
signatures[from_user_id] = {key_id: row["signature"]}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}

return keys

@defer.inlineCallbacks
def get_e2e_cross_signing_keys_bulk(
self, users: Iterable[str], key_type: str, from_user_id: str = None
self, user_ids: list, from_user_id: str = None
) -> defer.Deferred:
"""Returns the cross-signing keys for a set of users.
Args:
users (iterable[str]): the users whose keys are being requested
key_type (str): the type of keys that are being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
user_ids (list[str]): the users whose keys are being requested
from_user_id (str): if specified, signatures made by this user on
the self-signing keys will be included in the result
Expand All @@ -447,13 +518,18 @@ def get_e2e_cross_signing_keys_bulk(
cross-signing key was not found, their user ID will not be in
the dict.
"""
return self.db.runInteraction(
"get_e2e_cross_signing_key",
self._get_e2e_cross_signing_keys_bulk_txn,
users,
key_type,
from_user_id,
)

result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

if from_user_id:
result = yield self.db.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
)

return result

def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Expand Down Expand Up @@ -643,6 +719,10 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
},
)

self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)

def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Expand Down
2 changes: 1 addition & 1 deletion synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(
else:
self.function_to_call = orig

arg_spec = inspect.getargspec(orig)
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args

if "cache_context" in all_args:
Expand Down

0 comments on commit f1ea148

Please sign in to comment.