Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Federation for end-to-end key requests. #208

Merged
merged 4 commits into from
Aug 18, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def make_query(self, destination, query_type, args,
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)

@log_function
def query_client_keys(self, destination, content):
"""Query device keys for a device hosted on a remote server.

Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.

Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)

@log_function
def claim_client_keys(self, destination, content):
"""Claims one-time keys for a device hosted on a remote server.

Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.

Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)

@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
Expand Down
43 changes: 43 additions & 0 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from synapse.crypto.event_signing import compute_event_signature

import simplejson as json
import logging


Expand Down Expand Up @@ -312,6 +313,48 @@ def on_query_auth_request(self, origin, content, event_id):
(200, send_content)
)

@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))

results = yield self.store.get_e2e_device_keys(query)

json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)

defer.returnValue({"device_keys": json_result})

@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))

results = yield self.store.claim_e2e_one_time_keys(query)

json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}

defer.returnValue({"one_time_keys": json_result})

@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
Expand Down
70 changes: 70 additions & 0 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,76 @@ def send_query_auth(self, destination, room_id, event_id, content):

defer.returnValue(content)

@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.

Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }

Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }

Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"

content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)

@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.

Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }

Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }

Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""

path = PREFIX + "/user/keys/claim"

content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)

@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
Expand Down
20 changes: 20 additions & 0 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,24 @@ def on_PUT(self, origin, content, query, context, event_id):
defer.returnValue((200, content))


class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"

@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))


class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"

@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_claim_client_keys(origin, content)
defer.returnValue((200, response))


class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"

Expand Down Expand Up @@ -373,4 +391,6 @@ def on_POST(self, origin, content, query, room_id):
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
)
100 changes: 70 additions & 30 deletions synapse/rest/client/v2_alpha/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json

from ._base import client_v2_pattern
Expand Down Expand Up @@ -164,45 +165,63 @@ def __init__(self, hs):
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine

@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
logger.debug("onPOST")
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
for user_id, device_ids in body.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
defer.returnValue(self.json_result(request, results))
result = yield self.handle_request(body)
defer.returnValue(result)

@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
if not user_id:
user_id = auth_user_id
if not device_id:
device_id = None
# Returns a map of user_id->device_id->json_bytes.
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
defer.returnValue(self.json_result(request, results))

def json_result(self, request, results):
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)

@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)

json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
return (200, {"device_keys": json_result})

for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))


class OneTimeKeyServlet(RestServlet):
Expand Down Expand Up @@ -236,14 +255,16 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine

@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
results = yield self.store.claim_e2e_one_time_keys(
[(user_id, device_id, algorithm)]
result = yield self.handle_request(
{"one_time_keys": {user_id: {device_id: algorithm}}}
)
defer.returnValue(self.json_result(request, results))
defer.returnValue(result)

@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
Expand All @@ -252,22 +273,41 @@ def on_POST(self, request, user_id, device_id, algorithm):
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
result = yield self.handle_request(body)
defer.returnValue(result)

@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
defer.returnValue(self.json_result(request, results))
user = UserID.from_string(user_id)
if self.is_mine(user):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.setdefault(user.domain, {})[user_id] = (
device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query)

def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
return (200, {"one_time_keys": json_result})

for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys

defer.returnValue((200, {"one_time_keys": json_result}))


def register_servlets(hs, http_server):
Expand Down