Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor storing of server keys (#16261)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Sep 12, 2023
1 parent 9400dc0 commit 2b35626
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 365 deletions.
1 change: 1 addition & 0 deletions changelog.d/16261.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplify server key storage.
35 changes: 7 additions & 28 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
get_verify_key,
is_signing_algorithm_supported,
)
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
signature_ids,
verify_signed_json,
)
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64

Expand Down Expand Up @@ -596,24 +591,12 @@ async def process_v2_response(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)

key_json_bytes = encode_canonical_json(response_json)

await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
await self.store.store_server_keys_response(
server_name=server_name,
from_server=from_server,
ts_added_ms=time_added_ms,
verify_keys=verify_keys,
response_json=response_json,
)

return verify_keys
Expand Down Expand Up @@ -775,10 +758,6 @@ async def get_server_verify_key_v2_indirect(

keys.setdefault(server_name, {}).update(processed_response)

await self.store.store_server_signature_keys(
perspective_name, time_now_ms, added_keys
)

return keys

def _validate_perspectives_response(
Expand Down
219 changes: 72 additions & 147 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
import itertools
import json
import logging
from typing import Dict, Iterable, Mapping, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple

from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

Expand All @@ -36,162 +39,84 @@
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""

@cached()
def _get_server_signature_key(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_server_signature_key",
list_name="server_name_and_key_ids",
)
async def get_server_signature_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}

def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""

# batch_iter always returns tuples so it's safe to do len(batch)
sql = """
SELECT server_name, key_id, verify_key, ts_valid_until_ms
FROM server_signature_keys WHERE 1=0
""" + " OR (server_name=? AND key_id=?)" * len(
batch
)

txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

for row in txn:
server_name, key_id, key_bytes, ts_valid_until_ms = row

if ts_valid_until_ms is None:
# Old keys may be stored with a ts_valid_until_ms of null,
# in which case we treat this as if it was set to `0`, i.e.
# it won't match key requests that define a minimum
# `ts_valid_until_ms`.
ts_valid_until_ms = 0

keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)

def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys

return await self.db_pool.runInteraction("get_server_signature_keys", _txn)

async def store_server_signature_keys(
async def store_server_keys_response(
self,
server_name: str,
from_server: str,
ts_added_ms: int,
verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
verify_keys: Dict[str, FetchKeyResult],
response_json: JsonDict,
) -> None:
"""Stores NACL verification keys for remote servers.
"""Stores the keys for the given server that we got from `from_server`.
Args:
from_server: Where the verification keys were looked up
ts_added_ms: The time to record that the key was added
verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
server_name: The owner of the keys
from_server: Which server we got the keys from
ts_added_ms: When we're adding the keys
verify_keys: The decoded keys
response_json: The full *signed* response JSON that contains the keys.
"""
key_values = []
value_values = []
invalidations = []
for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_signature_key. _get_server_signature_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))

await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
desc="store_server_signature_keys",
)
key_json_bytes = encode_canonical_json(response_json)

def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_many_txn(
txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=[(server_name, key_id) for key_id in verify_keys],
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=[
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
for fetch_result in verify_keys.values()
],
)

invalidate = self._get_server_signature_key.invalidate
for i in invalidations:
invalidate((i,))
self.db_pool.simple_upsert_many_txn(
txn,
table="server_keys_json",
key_names=("server_name", "key_id", "from_server"),
key_values=[
(server_name, key_id, from_server) for key_id in verify_keys
],
value_names=(
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
value_values=[
(
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(key_json_bytes),
)
for fetch_result in verify_keys.values()
],
)

async def store_server_keys_json(
self,
server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name: The name of the server.
key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
key_json_bytes: The encoded JSON.
"""
await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
},
values={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
"key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
for key_id in verify_keys:
self._invalidate_cache_and_stream(
txn, self._get_server_keys_json, ((server_name, key_id),)
)
self._invalidate_cache_and_stream(
txn, self.get_server_key_json_for_remote, (server_name, key_id)
)

# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
await self.db_pool.runInteraction(
"store_server_keys_response", store_server_keys_response_txn
)

@cached()
Expand Down
53 changes: 13 additions & 40 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
from unittest.mock import AsyncMock, Mock
from unittest.mock import Mock

import attr
import canonicaljson
Expand Down Expand Up @@ -189,23 +189,24 @@ def test_verify_json_for_server(self) -> None:
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_keys_json(
r = self.hs.get_datastores().main.store_server_keys_response(
"server9",
get_key_id(key1),
from_server="test",
ts_now_ms=int(time.time() * 1000),
ts_expires_ms=1000,
ts_added_ms=int(time.time() * 1000),
verify_keys={
get_key_id(key1): FetchKeyResult(
verify_key=get_verify_key(key1), valid_until_ts=1000
)
},
# The entire response gets signed & stored, just include the bits we
# care about.
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
response_json={
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
}
),
},
)
self.get_success(r)

Expand Down Expand Up @@ -285,34 +286,6 @@ async def get_keys(
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)

def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
mock_fetcher = Mock()
mock_fetcher.get_keys = AsyncMock(return_value={})

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
"server9",
int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_signature_keys.
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
)
self.get_success(r)

json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)

# should succeed on a signed object with a 0 minimum_valid_until_ms
d = self.hs.get_datastores().main.get_server_signature_keys(
[("server9", get_key_id(key1))]
)
result = self.get_success(d)
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)

def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key("1")
Expand Down
Loading

0 comments on commit 2b35626

Please sign in to comment.