diff --git a/changelog.d/15462.misc b/changelog.d/15462.misc new file mode 100644 index 000000000000..36e4bffbc86b --- /dev/null +++ b/changelog.d/15462.misc @@ -0,0 +1 @@ +Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d7740eb3b448..c618f3d7a6cd 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,7 +1005,7 @@ async def on_query_user_devices( @trace async def on_claim_client_keys( - self, origin: str, content: JsonDict + self, origin: str, content: JsonDict, always_include_fallback_keys: bool ) -> Dict[str, Any]: query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): @@ -1013,7 +1013,9 @@ async def on_claim_client_keys( query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self._e2e_keys_handler.claim_local_one_time_keys(query) + results = await self._e2e_keys_handler.claim_local_one_time_keys( + query, always_include_fallback_keys=always_include_fallback_keys + ) json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for result in results: diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 753372fc5476..55d2cd0a9aa2 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -25,6 +25,7 @@ from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, + FederationUnstableClientKeysClaimServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -298,6 +299,11 @@ def register_servlets( and not hs.config.experimental.msc3720_enabled ): continue + if ( + servletclass == FederationUnstableClientKeysClaimServlet + and not hs.config.experimental.msc3983_appservice_otk_claims + ): + continue servletclass( hs=hs, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index ec5b5eeafa29..e2340d70d509 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - response = await self.handler.on_claim_client_keys(origin, content) + response = await self.handler.on_claim_client_keys( + origin, content, always_include_fallback_keys=False + ) + return 200, response + + +class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): + """ + Identical to the stable endpoint (FederationClientKeysClaimServlet) except it + always includes fallback keys in the response. + """ + + PREFIX = FEDERATION_UNSTABLE_PREFIX + PATH = "/user/keys/claim" + CATEGORY = "Federation requests" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + response = await self.handler.on_claim_client_keys( + origin, content, always_include_fallback_keys=True + ) return 200, response diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index da887647d4d8..4ca2bc04203b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -842,9 +842,7 @@ async def _check_user_exists(self, user_id: str) -> bool: async def claim_e2e_one_time_keys( self, query: Iterable[Tuple[str, str, str]] - ) -> Tuple[ - Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] - ]: + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: """Claim one time keys from application services. Users which are exclusively owned by an application service are sent a @@ -856,7 +854,7 @@ async def claim_e2e_one_time_keys( Returns: A tuple of: - An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + A map of user ID -> a map device ID -> a map of key ID -> JSON. A copy of the input which has not been fulfilled (either because they are not appservice users or the appservice does not support @@ -897,12 +895,11 @@ async def claim_e2e_one_time_keys( ) # Patch together the results -- they are all independent (since they - # require exclusive control over the users). They get returned as a list - # and the caller combines them. - claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + # require exclusive control over the users, which is the outermost key). + claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for success, result in results: if success: - claimed_keys.append(result[0]) + claimed_keys.update(result[0]) missing.extend(result[1]) return claimed_keys, missing diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 007366747014..d1ab95126c0b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -563,7 +563,9 @@ async def on_federation_query_client_keys( return ret async def claim_local_one_time_keys( - self, local_query: List[Tuple[str, str, str]] + self, + local_query: List[Tuple[str, str, str]], + always_include_fallback_keys: bool, ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: """Claim one time keys for local users. @@ -573,6 +575,7 @@ async def claim_local_one_time_keys( Args: local_query: An iterable of tuples of (user ID, device ID, algorithm). + always_include_fallback_keys: True to always include fallback keys. Returns: An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. @@ -583,24 +586,73 @@ async def claim_local_one_time_keys( # If the application services have not provided any keys via the C-S # API, query it directly for one-time keys. if self._query_appservices_for_otks: + # TODO Should this query for fallback keys of uploaded OTKs if + # always_include_fallback_keys is True? The MSC is ambiguous. ( appservice_results, not_found, ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) else: - appservice_results = [] + appservice_results = {} + + # Calculate which user ID / device ID / algorithm tuples to get fallback + # keys for. This can be either only missing results *or* all results + # (which don't already have a fallback key). + if always_include_fallback_keys: + # Build the fallback query as any part of the original query where + # the appservice didn't respond with a fallback key. + fallback_query = [] + + # Iterate each item in the original query and search the results + # from the appservice for that user ID / device ID. If it is found, + # check if any of the keys match the requested algorithm & are a + # fallback key. + for user_id, device_id, algorithm in local_query: + # Check if the appservice responded for this query. + as_result = appservice_results.get(user_id, {}).get(device_id, {}) + found_otk = False + for key_id, key_json in as_result.items(): + if key_id.startswith(f"{algorithm}:"): + # A OTK or fallback key was found for this query. + found_otk = True + # A fallback key was found for this query, no need to + # query further. + if key_json.get("fallback", False): + break + + else: + # No fallback key was found from appservices, query for it. + # Only mark the fallback key as used if no OTK was found + # (from either the database or appservices). + mark_as_used = not found_otk and not any( + key_id.startswith(f"{algorithm}:") + for key_id in otk_results.get(user_id, {}) + .get(device_id, {}) + .keys() + ) + fallback_query.append((user_id, device_id, algorithm, mark_as_used)) + + else: + # All fallback keys get marked as used. + fallback_query = [ + (user_id, device_id, algorithm, True) + for user_id, device_id, algorithm in not_found + ] # For each user that does not have a one-time keys available, see if # there is a fallback key. - fallback_results = await self.store.claim_e2e_fallback_keys(not_found) + fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) # Return the results in order, each item from the input query should # only appear once in the combined list. - return (otk_results, *appservice_results, fallback_results) + return (otk_results, appservice_results, fallback_results) @trace async def claim_one_time_keys( - self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] + self, + query: Dict[str, Dict[str, Dict[str, str]]], + timeout: Optional[int], + always_include_fallback_keys: bool, ) -> JsonDict: local_query: List[Tuple[str, str, str]] = [] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} @@ -617,7 +669,9 @@ async def claim_one_time_keys( set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.claim_local_one_time_keys(local_query) + results = await self.claim_local_one_time_keys( + local_query, always_include_fallback_keys + ) # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} @@ -625,7 +679,9 @@ async def claim_one_time_keys( for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): for key_id, key in keys.items(): - json_result.setdefault(user_id, {})[device_id] = {key_id: key} + json_result.setdefault(user_id, {}).setdefault( + device_id, {} + ).update({key_id: key}) # Remote failures. failures: Dict[str, JsonDict] = {} diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 6209b79b019e..2a2509410961 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Any, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError @@ -288,7 +289,33 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys( + body, timeout, always_include_fallback_keys=False + ) + return 200, result + + +class UnstableOneTimeKeyServlet(RestServlet): + """ + Identical to the stable endpoint (OneTimeKeyServlet) except it always includes + fallback keys in the response. + """ + + PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] + CATEGORY = "Encryption requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) + body = parse_json_object_from_request(request) + result = await self.e2e_keys_handler.claim_one_time_keys( + body, timeout, always_include_fallback_keys=True + ) return 200, result @@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + if hs.config.experimental.msc3983_appservice_otk_claims: + UnstableOneTimeKeyServlet(hs).register(http_server) if hs.config.worker.worker_app is None: SigningKeyUploadServlet(hs).register(http_server) SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index dc7768c50cab..1a4ae55304ba 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1149,18 +1149,19 @@ def _claim_e2e_one_time_key_returning( return results, missing async def claim_e2e_fallback_keys( - self, query_list: Iterable[Tuple[str, str, str]] + self, query_list: Iterable[Tuple[str, str, str, bool]] ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: """Take a list of fallback keys out of the database. Args: - query_list: An iterable of tuples of (user ID, device ID, algorithm). + query_list: An iterable of tuples of + (user ID, device ID, algorithm, whether the key should be marked as used). Returns: A map of user ID -> a map device ID -> a map of key ID -> JSON. """ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} - for user_id, device_id, algorithm in query_list: + for user_id, device_id, algorithm, mark_as_used in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", keyvalues={ @@ -1180,7 +1181,7 @@ async def claim_e2e_fallback_keys( used = row["used"] # Mark fallback key as used if not already. - if not used: + if not used and mark_as_used: await self.db_pool.simple_update_one( table="e2e_fallback_keys_json", keyvalues={ diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 013b9ee5504f..18edebd652fc 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -160,7 +160,9 @@ def test_claim_one_time_key(self) -> None: res2 = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -203,7 +205,9 @@ def test_fallback_key(self) -> None: # key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -220,7 +224,9 @@ def test_fallback_key(self) -> None: # claiming an OTK again should return the same fallback key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -267,7 +273,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -277,7 +285,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -296,7 +306,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -304,6 +316,75 @@ def test_fallback_key(self) -> None: {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) + def test_fallback_key_always_returned(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + otk = {"alg1:k2": "key2"} + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # Upload a OTK & fallback key. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": otk, "fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK and requesting to always return the fallback key should + # return both. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}}, + }, + ) + + # This should not mark the key as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK again should return only the fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # And mark it as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, []) + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname @@ -1004,6 +1085,7 @@ def test_query_appservice(self) -> None: } }, timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -1016,6 +1098,153 @@ def test_query_appservice(self) -> None: }, ) + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice_with_fallback(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}} + otk = {"alg1:k2": {"desc": "key2"}} + as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}} + as_otk = {"alg1:k4": {"desc": "key4"}} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) + ) + + # Claim OTKs, which will ask the appservice and do nothing else. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **as_fallback_key}} + }, + }, + ) + + # Now upload a fallback key. + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # The appservice will return only the OTK. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_otk}}, []) + ) + + # Claim OTKs, which should return the OTK from the appservice and the + # uploaded fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **fallback_key}} + }, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Now upload a OTK. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"one_time_keys": otk}, + ) + ) + + # Claim OTKs, which will return information only from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}}, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Finally, return only the fallback key from the appservice. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_fallback_key}}, []) + ) + + # Claim OTKs, which will return only the fallback key from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: as_fallback_key}}, + }, + ) + @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) def test_query_local_devices_appservice(self) -> None: """Test that querying of appservices for keys overrides responses from the database."""