Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

[WIP] device dehydration #7955

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7955.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for device dehydration.
uhoreg marked this conversation as resolved.
Show resolved Hide resolved
91 changes: 90 additions & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from synapse.api import errors
from synapse.api.constants import EventTypes
Expand All @@ -28,6 +29,7 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
JsonDict,
RoomStreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
Expand Down Expand Up @@ -489,6 +491,93 @@ async def user_left_room(self, user, room_id):
# receive device updates. Mark this in DB.
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)

async def store_dehydrated_device(
self,
user_id: str,
device_data: str,
initial_device_display_name: Optional[str] = None,
) -> str:
device_id = await self.check_device_registered(
user_id, None, initial_device_display_name,
)
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
)
if old_device_id is not None:
await self.delete_device(user_id, old_device_id)
return device_id

async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]:
return await self.store.get_dehydrated_device(user_id)

async def get_dehydration_token(
self, user_id: str, device_id: str, login_submission: JsonDict
) -> str:
return await self.store.create_dehydration_token(
user_id, device_id, json.dumps(login_submission)
)

async def rehydrate_device(self, token: str) -> dict:
# FIXME: if can't find token, return 404
token_info = await self.store.clear_dehydration_token(token, True)

# normally, the constructor would do self.registration_handler =
# self.hs.get_registration_handler(), but doing that results in a
# circular dependency in the handlers. So do this for now
registration_handler = self.hs.get_registration_handler()

if token_info["dehydrated"]:
# create access token for dehydrated device
initial_display_name = (
None # FIXME: get display name from login submission?
)
device_id, access_token = await registration_handler.register_device(
token_info.get("user_id"),
token_info.get("device_id"),
initial_display_name,
)

return {
"user_id": token_info.get("user_id"),
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}

else:
# create device and access token from original login submission
login_submission = token_info.get("login_submission")
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await registration_handler.register_device(
token_info.get("user_id"), device_id, initial_display_name
)

return {
"user_id": token.info.get("user_id"),
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}

async def cancel_rehydrate(self, token: str) -> dict:
# FIXME: if can't find token, return 404
token_info = await self.store.clear_dehydration_token(token, False)
# create device and access token from original login submission
login_submission = token_info.get("login_submission")
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
token_info.get("user_id"), device_id, initial_display_name
)

return {
"user_id": token_info.get("user_id"),
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}


def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
Expand Down
70 changes: 70 additions & 0 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self, hs):
self.oidc_enabled = hs.config.oidc_enabled

self.auth_handler = self.hs.get_auth_handler()
self.device_handler = hs.get_device_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs)
Expand Down Expand Up @@ -339,6 +340,27 @@ async def _complete_login(
)
user_id = canonical_uid

if login_submission.get("org.matrix.msc2697.restore_device"):
(
device_id,
dehydrated_device,
) = await self.device_handler.get_dehydrated_device(user_id)
if dehydrated_device:
token = await self.device_handler.get_dehydration_token(
user_id, device_id, login_submission
)
result = {
"user_id": user_id,
"home_server": self.hs.hostname,
"device_data": dehydrated_device,
"device_id": device_id,
"dehydration_token": token,
}

# FIXME: call callback?
Copy link
Member

@clokep clokep Aug 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation about what this callback might be (see the check_auth info in there) is less than ideal. It comes from password auth modules and might be None or a Callable that is "called with the result from the /login call (including access_token, device_id, etc.)"

I do not know of any auth providers which use this frankly, but I think calling it like below is the reasonable thing to do. Is there a particular reason you think it should not be called?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We talked a bit more about this and the reason we can't just call it here is that we don't yet have the access token (the client hasn't acknowledged that they've been able to use the dehydrated device).

It seems like we would ideally want to call this at the end of the call to /restore_device, but we no longer have the callback at that point. Will need to think about what the best option is here.


return result

device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
Expand Down Expand Up @@ -401,6 +423,52 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
return result


class RestoreDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697/restore_device")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is supposed to show up only in unstable or also r0 (the current code would do both).


def __init__(self, hs):
super(RestoreDeviceServlet, self).__init__()
self.hs = hs
self.device_handler = hs.get_device_handler()

async def on_POST(self, request: SynapseRequest):
submission = parse_json_object_from_request(request)

if submission.get("rehydrate"):
return (
200,
await self.device_handler.rehydrate_device(
submission.get("dehydration_token")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to use assert_params_in_dict for dehydration_token since it is required -- this will raise a sane error message. (This is true for any parameters that are required in the added endpoints.)

),
)
else:
return (
200,
await self.device_handler.cancel_rehydrate(
submission.get("dehydration_token")
),
)


class StoreDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate")

def __init__(self, hs):
super(StoreDeviceServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

async def on_POST(self, request: SynapseRequest):
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)

device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(), submission.get("device_data")
)
return 200, {"device_id": device_id}


class BaseSSORedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""

Expand Down Expand Up @@ -499,6 +567,8 @@ async def get_sso_url(

def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
RestoreDeviceServlet(hs).register(http_server)
StoreDeviceServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
Expand Down
34 changes: 20 additions & 14 deletions synapse/rest/client/v2_alpha/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, hs):
super(KeyUploadServlet, self).__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()

@trace(opname="upload_keys")
async def on_POST(self, request, device_id):
Expand All @@ -78,20 +79,25 @@ async def on_POST(self, request, device_id):
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
set_tag("error", True)
log_kv(
{
"message": "Client uploading keys for a different device",
"logged_in_id": requester.device_id,
"key_being_uploaded": device_id,
}
)
logger.warning(
"Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id,
device_id,
)
(
dehydrated_device_id,
_,
) = await self.device_handler.get_dehydrated_device(user_id)
if device_id != dehydrated_device_id:
set_tag("error", True)
log_kv(
{
"message": "Client uploading keys for a different device",
"logged_in_id": requester.device_id,
"key_being_uploaded": device_id,
}
)
logger.warning(
"Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id,
device_id,
)
else:
device_id = requester.device_id

Expand Down
119 changes: 118 additions & 1 deletion synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
cachedList,
)
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
from synapse.util.stringutils import random_string, shortstr

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -728,6 +728,123 @@ def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
_mark_remote_user_device_list_as_unsubscribed_txn,
)

async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]:
# FIXME: make sure device ID still exists in devices table
row = await self.db_pool.simple_select_one(
table="dehydrated_devices",
keyvalues={"user_id": user_id},
retcols=["device_id", "device_data"],
allow_none=True,
)
return (row["device_id"], row["device_data"]) if row else (None, None)

def _store_dehydrated_device_txn(
self, txn, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
keyvalues={"user_id": user_id},
retcol="device_id",
allow_none=True,
)
if old_device_id is None:
self.db_pool.simple_insert_txn(
txn,
table="dehydrated_devices",
values={
"user_id": user_id,
"device_id": device_id,
"device_data": device_data,
},
)
else:
self.db_pool.simple_update_txn(
txn,
table="dehydrated_devices",
keyvalues={"user_id": user_id},
updatevalues={"device_id": device_id, "device_data": device_data},
)
return old_device_id

async def store_dehydrated_device(
self, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
device_data,
)

async def create_dehydration_token(
self, user_id: str, device_id: str, login_submission: str
) -> str:
# FIXME: expire any old tokens

attempts = 0
while attempts < 5:
token = random_string(24)

try:
await self.db_pool.simple_insert(
table="dehydration_token",
values={
"token": token,
"user_id": user_id,
"device_id": device_id,
"login_submission": login_submission,
"creation_time": self.hs.get_clock().time_msec(),
},
desc="create_dehydration_token",
)
return token
except self.db_pool.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a token.")

def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict:
token_info = self.db_pool.simple_select_one_txn(
txn,
"dehydration_token",
{"token": token},
["user_id", "device_id", "login_submission"],
)
self.db_pool.simple_delete_one_txn(
txn, "dehydration_token", {"token": token},
)

if dehydrate:
device = self.db_pool.simple_select_one_txn(
txn,
"dehydrated_devices",
{"user_id": token_info["user_id"]},
["device_id", "device_data"],
allow_none=True,
)
if device and device["device_id"] == token_info["device_id"]:
count = self.db_pool.simple_delete_txn(
txn,
"dehydrated_devices",
{
"user_id": token_info["user_id"],
"device_id": token_info["device_id"],
},
)
if count != 0:
token_info["dehydrated"] = True

return token_info

async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict:
return await self.db_pool.runInteraction(
"get_users_whose_devices_changed",
self._clear_dehydration_token_txn,
token,
dehydrate,
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
Expand Down
Loading