Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add unstable /keys/claim endpoint which always returns fallback keys. (
Browse files Browse the repository at this point in the history
…#15462)

It can be useful to always return the fallback key when attempting to
claim keys. This adds an unstable endpoint for `/keys/claim` which
always returns fallback keys in addition to one-time-keys.

The fallback key(s) are not marked as "used" unless there are no
corresponding OTKs.

This is currently defined in MSC3983 (although likely to be split out
to a separate MSC). The endpoint shape may change or be requested
differently (i.e. a keyword parameter on the current endpoint), but the
core logic should be reasonable.
  • Loading branch information
clokep authored Apr 25, 2023
1 parent b39b02c commit 8e97394
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.d/15462.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
6 changes: 4 additions & 2 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,15 +1005,17 @@ async def on_query_user_devices(

@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))

log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
)

json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
Expand Down
6 changes: 6 additions & 0 deletions synapse/federation/transport/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
Expand Down Expand Up @@ -298,6 +299,11 @@ def register_servlets(
and not hs.config.experimental.msc3720_enabled
):
continue
if (
servletclass == FederationUnstableClientKeysClaimServlet
and not hs.config.experimental.msc3983_appservice_otk_claims
):
continue

servletclass(
hs=hs,
Expand Down
23 changes: 22 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content)
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
)
return 200, response


class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
"""

PREFIX = FEDERATION_UNSTABLE_PREFIX
PATH = "/user/keys/claim"
CATEGORY = "Federation requests"

async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
)
return 200, response


Expand Down
13 changes: 5 additions & 8 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,7 @@ async def _check_user_exists(self, user_id: str) -> bool:

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
Expand All @@ -856,7 +854,7 @@ async def claim_e2e_one_time_keys(
Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
Expand Down Expand Up @@ -897,12 +895,11 @@ async def claim_e2e_one_time_keys(
)

# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.append(result[0])
claimed_keys.update(result[0])
missing.extend(result[1])

return claimed_keys, missing
Expand Down
70 changes: 63 additions & 7 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ async def on_federation_query_client_keys(
return ret

async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
self,
local_query: List[Tuple[str, str, str]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
Expand All @@ -573,6 +575,7 @@ async def claim_local_one_time_keys(
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
Expand All @@ -583,24 +586,73 @@ async def claim_local_one_time_keys(
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
appservice_results = {}

# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []

# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break

else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
fallback_query.append((user_id, device_id, algorithm, mark_as_used))

else:
# All fallback keys get marked as used.
fallback_query = [
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
]

# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)

# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
return (otk_results, appservice_results, fallback_results)

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self,
query: Dict[str, Dict[str, Dict[str, str]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
Expand All @@ -617,15 +669,19 @@ async def claim_one_time_keys(
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.claim_local_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)

# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})

# Remote failures.
failures: Dict[str, JsonDict] = {}
Expand Down
31 changes: 30 additions & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple

from synapse.api.errors import InvalidAPICallError, SynapseError
Expand Down Expand Up @@ -288,7 +289,33 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False
)
return 200, result


class UnstableOneTimeKeyServlet(RestServlet):
"""
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
fallback keys in the response.
"""

PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
CATEGORY = "Encryption requests"

def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True
)
return 200, result


Expand Down Expand Up @@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
9 changes: 5 additions & 4 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,18 +1149,19 @@ def _claim_e2e_one_time_key_returning(
return results, missing

async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of
(user ID, device ID, algorithm, whether the key should be marked as used).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand All @@ -1180,7 +1181,7 @@ async def claim_e2e_fallback_keys(
used = row["used"]

# Mark fallback key as used if not already.
if not used:
if not used and mark_as_used:
await self.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand Down
Loading

0 comments on commit 8e97394

Please sign in to comment.