diff --git a/changelog.d/7955.feature b/changelog.d/7955.feature new file mode 100644 index 000000000000..7d726046fe91 --- /dev/null +++ b/changelog.d/7955.feature @@ -0,0 +1 @@ +Add support for device dehydration. (MSC2697) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60deb4..e1fd39356c4c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,8 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -28,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + JsonDict, RoomStreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -489,6 +491,136 @@ async def user_left_room(self, user, room_id): # receive device updates. Mark this in DB. await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + async def store_dehydrated_device( + self, + user_id: str, + device_data: JsonDict, + initial_device_display_name: Optional[str] = None, + ) -> str: + """Store a dehydrated device for a user. If the user had a previous + dehydrated device, it is removed. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the dehydrated device + """ + device_id = await self.check_device_registered( + user_id, None, initial_device_display_name, + ) + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data + ) + if old_device_id is not None: + await self.delete_device(user_id, old_device_id) + return device_id + + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + + async def create_dehydration_token( + self, user_id: str, device_id: str, login_submission: JsonDict + ) -> str: + """Create a token for a client to fulfill a dehydration request. + + Args: + user_id: the user that we are creating the token for + device_id: the device ID for the dehydrated device. This is to + ensure that the device still exists when the user tells us + they want to use the dehydrated device. + login_submission: the contents of the login request. + Returns: + the dehydration token + """ + return await self.store.create_dehydration_token( + user_id, device_id, login_submission + ) + + async def rehydrate_device(self, token: str) -> dict: + """Process a rehydration request from the user. + + Args: + token: the dehydration token + Returns: + the login result, including the user's access token and device ID + """ + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token, True) + + # normally, the constructor would do self.registration_handler = + # self.hs.get_registration_handler(), but doing that results in a + # circular dependency in the handlers. So do this for now + registration_handler = self.hs.get_registration_handler() + + if token_info["dehydrated"]: + # create access token for dehydrated device + initial_display_name = ( + None # FIXME: get display name from login submission? + ) + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), + token_info.get("device_id"), + initial_display_name, + ) + + return { + "user_id": token_info["user_id"], + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + else: + # create device and access token from original login submission + login_submission = token_info["login_submission"] + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token.info["user_id"], + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + async def cancel_rehydrate(self, token: str) -> dict: + """Cancel a rehydration request from the user and complete the user's login. + + Args: + token: the dehydration token + Returns: + the login result, including the user's access token and device ID + """ + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token, False) + # create device and access token from original login submission + login_submission = token_info["login_submission"] + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = await self.registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token_info.get("user_id"), + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 379f668d6f8a..68fece986bcd 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -103,6 +103,7 @@ def __init__(self, hs): self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() + self.device_handler = hs.get_device_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) @@ -339,6 +340,29 @@ async def _complete_login( ) user_id = canonical_uid + if login_submission.get("org.matrix.msc2697.restore_device"): + # user requested to rehydrate a device, so check if there they have + # a dehydrated device, and if so, allow them to try to rehydrate it + ( + device_id, + dehydrated_device, + ) = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device: + token = await self.device_handler.create_dehydration_token( + user_id, device_id, login_submission + ) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_data": dehydrated_device, + "device_id": device_id, + "dehydration_token": token, + } + + # FIXME: call callback? + + return result + device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( @@ -401,6 +425,96 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: return result +class RestoreDeviceServlet(RestServlet): + """Complete a rehydration request, either by letting the client use the + dehydrated device, or by creating a new device for the user. + + POST /org.matrix.msc2697/restore_device + Content-Type: application/json + + { + "rehydrate": true, + "dehydration_token": "an_opaque_token" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { // same format as the result from a /login request + "user_id": "@alice:example.org", + "device_id": "dehydrated_device", + "access_token": "another_opaque_token" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697/restore_device") + + def __init__(self, hs): + super(RestoreDeviceServlet, self).__init__() + self.hs = hs + self.device_handler = hs.get_device_handler() + self._well_known_builder = WellKnownBuilder(hs) + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + + if submission.get("rehydrate"): + result = await self.device_handler.rehydrate_device( + submission["dehydration_token"] + ) + else: + result = await self.device_handler.cancel_rehydrate( + submission["dehydration_token"] + ) + well_known_data = self._well_known_builder.get_well_known() + if well_known_data: + result["well_known"] = well_known_data + return (200, result) + + +class StoreDeviceServlet(RestServlet): + """Store a dehydrated device. + + POST /org.matrix.msc2697/device/dehydrate + Content-Type: application/json + + { + "device_data": { + "algorithm": "m.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate") + + def __init__(self, hs): + super(StoreDeviceServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None) + ) + return 200, {"device_id": device_id} + + class BaseSSORedirectServlet(RestServlet): """Common base class for /login/sso/redirect impls""" @@ -499,6 +613,8 @@ async def get_sso_url( def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + RestoreDeviceServlet(hs).register(http_server) + StoreDeviceServlet(hs).register(http_server) if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822a7..b86c8f598bb8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -67,6 +67,7 @@ def __init__(self, hs): super(KeyUploadServlet, self).__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") async def on_POST(self, request, device_id): @@ -78,20 +79,25 @@ async def on_POST(self, request, device_id): # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. if requester.device_id is not None and device_id != requester.device_id: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) + ( + dehydrated_device_id, + _, + ) = await self.device_handler.get_dehydrated_device(user_id) + if device_id != dehydrated_device_id: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88a7aadfc6c8..d6c6f0ac3497 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -35,7 +35,7 @@ LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, cached, @@ -43,7 +43,7 @@ cachedList, ) from synapse.util.iterutils import batch_iter -from synapse.util.stringutils import shortstr +from synapse.util.stringutils import random_string, shortstr logger = logging.getLogger(__name__) @@ -728,6 +728,168 @@ def _mark_remote_user_device_list_as_unsubscribed_txn(txn): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return (row["device_id"], json.loads(row["device_data"])) if row else (None, None) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + if old_device_id is None: + self.db_pool.simple_insert_txn( + txn, + table="dehydrated_devices", + values={ + "user_id": user_id, + "device_id": device_id, + "device_data": device_data, + }, + ) + else: + self.db_pool.simple_update_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + updatevalues={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json.dumps(device_data), + ) + + async def create_dehydration_token( + self, user_id: str, device_id: str, login_submission: JsonDict + ) -> str: + """Create a token for a client to fulfill a dehydration request. + + Args: + user_id: the user that we are creating the token for + device_id: the device ID for the dehydrated device. This is to + ensure that the device still exists when the user tells us + they want to use the dehydrated device. + login_submission: the contents of the login request. + Returns: + the dehydration token + """ + # FIXME: expire any old tokens + + attempts = 0 + while attempts < 5: + token = random_string(24) + + try: + await self.db_pool.simple_insert( + table="dehydration_token", + values={ + "token": token, + "user_id": user_id, + "device_id": device_id, + "login_submission": json.dumps(login_submission), + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_dehydration_token", + ) + return token + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a token.") + + def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict: + token_info = self.db_pool.simple_select_one_txn( + txn, + "dehydration_token", + {"token": token}, + ["user_id", "device_id", "login_submission"], + ) + self.db_pool.simple_delete_one_txn( + txn, "dehydration_token", {"token": token}, + ) + token_info["login_submission"] = json.loads(token_info["login_submission"]) + + if dehydrate: + device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "dehydrated_devices", + keyvalues={"user_id": token_info["user_id"]}, + retcol="device_id", + allow_none=True, + ) + token_info["dehydrated"] = False + if device_id == token_info["device_id"]: + count = self.db_pool.simple_delete_txn( + txn, + "dehydrated_devices", + { + "user_id": token_info["user_id"], + "device_id": token_info["device_id"], + }, + ) + if count != 0: + token_info["dehydrated"] = True + + return token_info + + async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict: + """Use a dehydration token. If the client wishes to use the dehydrated + device, it will also remove the dehydrated device. + + Args: + token: the dehydration token + dehydrate: whether the client wishes to use the dehydrated device + Returns: + A dict giving the information related to the token. It will have + the following properties: + - user_id: the user associated from the token + - device_id: the ID of the dehydrated device + - login_submission: the original submission to /login + - dehydrated: (only present if the "dehydrate" parameter is True). + Whether the dehydrated device can be used by the client. + """ + return await self.db_pool.runInteraction( + "get_users_whose_devices_changed", + self._clear_dehydration_token_txn, + token, + dehydrate, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 40354b8304dd..23f04a4887e3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -641,6 +641,11 @@ def delete_e2e_keys_by_device_txn(txn): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) return self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 000000000000..be5e8a47129d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS dehydration_token( + token TEXT NOT NULL PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + login_submission TEXT NOT NULL, + creation_time BIGINT NOT NULL +); + +-- FIXME: index on creation_time to expire old tokens diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index db52725cfe73..d0c3f40e78be 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -754,3 +754,68 @@ def test_login_jwt_invalid_signature(self): channel.json_body["error"], "JWT validation failed: Signature verification failed", ) + + +class DehydrationTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + logout.register_servlets, + devices.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + def test_dehydrate_and_rehydrate_device(self): + self.register_user("kermit", "monkey") + access_token = self.login("kermit", "monkey") + + # dehydrate a device + params = json.dumps({"device_data": "foobar"}) + request, channel = self.make_request( + b"POST", + b"/_matrix/client/unstable/org.matrix.msc2697/device/dehydrate", + params, + access_token=access_token, + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + dehydrated_device_id = channel.json_body["device_id"] + + # Log out + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + + # log in, requesting a dehydrated device + params = json.dumps( + { + "type": "m.login.password", + "user": "kermit", + "password": "monkey", + "org.matrix.msc2697.restore_device": True, + } + ) + request, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_data"], "foobar") + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id) + dehydration_token = channel.json_body["dehydration_token"] + + params = json.dumps({"rehydrate": True, "dehydration_token": dehydration_token}) + request, channel = self.make_request( + "POST", "/_matrix/client/unstable/org.matrix.msc2697/restore_device", params + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id)