From 8d3542a64e2689a00ed87f9bd58fe3e1d3b10ed8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 22 May 2019 16:42:00 -0400 Subject: [PATCH] implement federation parts of cross-signing --- .../sender/per_destination_queue.py | 4 +- synapse/handlers/device.py | 13 +- synapse/handlers/e2e_keys.py | 116 +++++++++++++++++- synapse/storage/devices.py | 56 ++++++++- 4 files changed, 179 insertions(+), 10 deletions(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index fad980b89307..0486af2dbfe0 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -366,10 +366,10 @@ def _get_device_update_edus(self, limit): Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type=edu_type, content=content, ) - for content in results + for (edu_type, content) in results ] assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee448876..cd6eb52316e9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -458,7 +458,18 @@ def notify_user_signature_update(self, from_user_id, user_ids): @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - return {"user_id": user_id, "stream_id": stream_id, "devices": devices} + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + + return { + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + "master_key": master_key, + "self_signing_key": self_signing_key + } @defer.inlineCallbacks def user_left_room(self, user, room_id): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5ea54f60be83..849ee04f93a0 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,6 +36,8 @@ get_verify_key_from_cross_signing_key, ) from synapse.util import unwrapFirstError +from synapse.util.async_helpers import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -49,10 +51,17 @@ def __init__(self, hs): self.is_mine = hs.is_mine self.clock = hs.get_clock() + self._edu_updater = SigningKeyEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( + "m.signing_key_update", self._edu_updater.incoming_signing_key_update, + ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - hs.get_federation_registry().register_query_handler( + federation_registry.register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -343,7 +352,15 @@ def on_federation_query_client_keys(self, query_body): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - return {"device_keys": res} + ret = {"device_keys": res} + + # add in the cross-signing keys + cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query) + + for key, value in iteritems(cross_signing_keys): + ret[key + "_keys"] = value + + return ret @trace @defer.inlineCallbacks @@ -1047,3 +1064,98 @@ class SignatureListItem: target_user_id = attr.ib() target_device_id = attr.ib() signature = attr.ib() + + +class SigningKeyEduUpdater(object): + "Handles incoming signing key updates from federation and updates the DB" + + def __init__(self, hs, e2e_keys_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.e2e_keys_handler = e2e_keys_handler + + self._remote_edu_linearizer = Linearizer(name="remote_signing_key") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="signing_key_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_signing_key_update(self, origin, edu_content): + """Called on incoming signing key update from federation. Responsible for + parsing the EDU and adding to pending updates list. + + Args: + origin (string): the server that sent the EDU + edu_content (dict): the contents of the EDU + """ + + user_id = edu_content.pop("user_id") + master_key = edu_content.pop("master_key", None) + self_signing_key = edu_content.pop("self_signing_key", None) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + logger.warning("Got signing key update edu for %r from %r", user_id, origin) + return + + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return + + self._pending_updates.setdefault(user_id, []).append( + (master_key, self_signing_key, edu_content) + ) + + yield self._handle_signing_key_updates(user_id) + + @defer.inlineCallbacks + def _handle_signing_key_updates(self, user_id): + """Actually handle pending updates. + + Args: + user_id (string): the user whose updates we are processing + """ + + device_handler = self.e2e_keys_handler.device_handler + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + device_ids = [] + + logger.info("pending updates: %r", pending_updates) + + for master_key, self_signing_key, edu_content in pending_updates: + if master_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "master", master_key + ) + device_id = \ + get_verify_key_from_cross_signing_key(master_key)[1].version + device_ids.append(device_id) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + device_id = \ + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + device_ids.append(device_id) + + yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index f7a35423481e..182e95fa216e 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -94,9 +94,10 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): """Get stream of updates to send to remote servers Returns: - Deferred[tuple[int, list[dict]]]: + Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the - response), and the list of updates + response), and the list of updates, where each update is a pair of EDU + type and EDU contents """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -129,6 +130,25 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): if not updates: return now_stream_id, [] + # get the cross-signing keys of the users the list + users = set(r[0] for r in updates) + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) + master_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } + + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "self_signing") + key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } + # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. if len(updates) > limit: @@ -158,6 +178,10 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): # Stop processing updates break + if update[1] == master_key_by_user[update[0]]["pubkey"] or \ + update[1] == self_signing_key_by_user[update[0]]["pubkey"]: + continue + key = (update[0], update[1]) update_context = update[3] @@ -172,16 +196,37 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): # means that there are more than limit updates all of which have the same # steam_id. + # figure out which cross-signing keys were changed by intersecting the + # update list with the master/self-signing key by user maps + cross_signing_keys_by_user = {} + for user_id, device_id, stream in updates: + if device_id == master_key_by_user[user_id]["pubkey"]: + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = \ + master_key_by_user[user_id]["key_info"] + elif device_id == self_signing_key_by_user[user_id]["pubkey"]: + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = \ + self_signing_key_by_user[user_id]["key_info"] + + cross_signing_results = [] + + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + cross_signing_results.append(("m.signing_key_update", result)) + # That should only happen if a client is spamming the server with new # devices, in which case E2E isn't going to work well anyway. We'll just # skip that stream_id and return an empty list, and continue with the next # stream_id next time. - if not query_map: + if not query_map and not cross_signing_results: return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) + results.extend(cross_signing_results) return now_stream_id, results @@ -200,6 +245,7 @@ def _get_devices_by_remote_txn( Returns: List: List of device updates """ + # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? @@ -231,7 +277,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m query_map.keys(), include_all_devices=True, include_deleted_devices=True, - ) + ) if query_map else {} results = [] for user_id, user_devices in iteritems(devices): @@ -262,7 +308,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m else: result["deleted"] = True - results.append(result) + results.append(("m.device_list_update", result)) return results