Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cleanups in the Keyring #5001

Merged
merged 6 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5001.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up some code in the server-key Keyring.
79 changes: 48 additions & 31 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from six import raise_from
from six.moves import urllib

import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
Expand Down Expand Up @@ -494,11 +495,11 @@ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
)

processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False
perspective_name, response
)
server_name = response["server_name"]

for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
keys.setdefault(server_name, {}).update(processed_response)

yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
Expand All @@ -517,7 +518,7 @@ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,

@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
keys = {} # type: dict[str, nacl.signing.VerifyKey]

for requested_key_id in key_ids:
if requested_key_id in keys:
Expand All @@ -542,6 +543,11 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server")

if response["server_name"] != server_name:
raise KeyLookupError("Expected a response for server %r not %r" % (
server_name, response["server_name"]
))

response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
Expand All @@ -550,24 +556,45 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):

keys.update(response_keys)

yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError))

defer.returnValue(keys)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
defer.returnValue({server_name: keys})

@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[], only_from_server=True):
def process_v2_response(
self, from_server, response_json, requested_ids=[],
):
"""Parse a 'Server Keys' structure from the result of a /key request

This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.

Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)

Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.

Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.

response_json (dict): the json-decoded Server Keys response object

requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response

Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
"""
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
Expand All @@ -589,15 +616,7 @@ def process_v2_response(self, from_server, response_json,
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key

results = {}
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise KeyLookupError(
Expand Down Expand Up @@ -643,9 +662,7 @@ def process_v2_response(self, from_server, response_json,
consumeErrors=True,
).addErrback(unwrapFirstError))

results[server_name] = response_keys

defer.returnValue(results)
defer.returnValue(response_keys)

def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def get_server_keys_json(self, server_keys):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
Dict mapping (server_name, key_id, source) triplets to dicts with
"ts_valid_until_ms" and "key_json" keys.
Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
Dict mapping (server_name, key_id, source) triplets to lists of dicts
"""

def _get_server_keys_json_txn(txn):
Expand Down
130 changes: 130 additions & 0 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

from mock import Mock

import canonicaljson
import signedjson.key
import signedjson.sign

from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext

Expand All @@ -48,6 +50,9 @@ def get_signed_key(self, server_name, verify_key):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
return self.get_signed_response(res)

def get_signed_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
return res

Expand Down Expand Up @@ -202,6 +207,131 @@ def test_verify_json_for_server(self):
self.assertFalse(d.called)
self.get_success(d)

def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)

SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 1000

# valid response
response = {
"server_name": SERVER_NAME,
"old_verify_keys": {},
"valid_until_ts": VALID_UNTIL_TS,
"verify_keys": {
testverifykey_id: {
"key": signedjson.key.encode_verify_key_base64(testverifykey)
}
},
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)

def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response

self.http_client.get_json.side_effect = get_json

server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastore().get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)

# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
)

def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)

SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 200 * 1000

# valid response
response = {
"server_name": SERVER_NAME,
"old_verify_keys": {},
"valid_until_ts": VALID_UNTIL_TS,
"verify_keys": {
testverifykey_id: {
"key": signedjson.key.encode_verify_key_base64(testverifykey)
}
},
}

persp_resp = {
"server_keys": [self.mock_perspective_server.get_signed_response(response)]
}

def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")

# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
return persp_resp

self.http_client.post_json.side_effect = post_json

server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastore().get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

self.assertEqual(
bytes(res["key_json"]),
canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
)


@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
Expand Down