Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use servlets for /key/ endpoints #14229

Merged
merged 8 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14229.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `/key/` endpoints to use `RestServlet` classes.
2 changes: 1 addition & 1 deletion synapse/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
SERVER_KEY_PREFIX = "/_matrix/key"
MEDIA_R0_PREFIX = "/_matrix/media/r0"
MEDIA_V3_PREFIX = "/_matrix/media/v3"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
Expand Down
18 changes: 7 additions & 11 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
LEGACY_MEDIA_PREFIX,
MEDIA_R0_PREFIX,
MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX,
SERVER_KEY_PREFIX,
)
from synapse.app import _base
from synapse.app._base import (
Expand Down Expand Up @@ -325,13 +325,13 @@ def _listen_http(self, listener_config: ListenerConfig) -> None:

presence.register_servlets(self, resource)

resources.update({CLIENT_API_PREFIX: resource})
resources[CLIENT_API_PREFIX] = resource

resources.update(build_synapse_client_resource_tree(self))
resources.update({"/.well-known": well_known_resource(self)})
resources["/.well-known"] = well_known_resource(self)

elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
resources[FEDERATION_PREFIX] = TransportLayerServer(self)
elif name == "media":
if self.config.media.can_load_media_repo:
media_repo = self.get_media_repository_resource()
Expand Down Expand Up @@ -359,16 +359,12 @@ def _listen_http(self, listener_config: ListenerConfig) -> None:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
resources[FEDERATION_PREFIX] = TransportLayerServer(
self, servlet_groups=["openid"]
)

if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
resources[SERVER_KEY_PREFIX] = KeyApiV2Resource(self)
clokep marked this conversation as resolved.
Show resolved Hide resolved

if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
Expand Down
24 changes: 8 additions & 16 deletions synapse/app/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
LEGACY_MEDIA_PREFIX,
MEDIA_R0_PREFIX,
MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX,
SERVER_KEY_PREFIX,
STATIC_PREFIX,
)
from synapse.app import _base
Expand Down Expand Up @@ -215,30 +215,22 @@ def _configure_named_resource(
consent_resource: Resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
resources.update({"/_matrix/consent": consent_resource})
resources["/_matrix/consent"] = consent_resource

if name == "federation":
federation_resource: Resource = TransportLayerServer(self)
if compress:
federation_resource = gz_wrap(federation_resource)
resources.update({FEDERATION_PREFIX: federation_resource})
resources[FEDERATION_PREFIX] = federation_resource

if name == "openid":
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
resources[FEDERATION_PREFIX] = TransportLayerServer(
self, servlet_groups=["openid"]
)

if name in ["static", "client"]:
resources.update(
{
STATIC_PREFIX: StaticResource(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
}
resources[STATIC_PREFIX] = StaticResource(
os.path.join(os.path.dirname(synapse.__file__), "static")
)

if name in ["media", "federation", "client"]:
Expand All @@ -257,7 +249,7 @@ def _configure_named_resource(
)

if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
resources[SERVER_KEY_PREFIX] = KeyApiV2Resource(self)

if name == "metrics" and self.config.metrics.enable_metrics:
metrics_resource: Resource = MetricsResource(RegistryProxy)
Expand Down
19 changes: 11 additions & 8 deletions synapse/rest/key/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@

from typing import TYPE_CHECKING

from twisted.web.resource import Resource

from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
from synapse.http.server import HttpServer, JsonResource
from synapse.rest.key.v2.local_key_resource import LocalKey
from synapse.rest.key.v2.remote_key_resource import RemoteKey

if TYPE_CHECKING:
from synapse.server import HomeServer


class KeyApiV2Resource(Resource):
class KeyApiV2Resource(JsonResource):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
super().__init__(hs, canonical_json=True)
self.register_servlets(self, hs)

@staticmethod
def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
LocalKey(hs).register(http_server)
RemoteKey(hs).register(http_server)
22 changes: 11 additions & 11 deletions synapse/rest/key/v2/local_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional
import re
from typing import TYPE_CHECKING, Optional, Tuple

from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64

from twisted.web.resource import Resource
from twisted.web.server import Request

from synapse.http.server import respond_with_json_bytes
from synapse.http.site import SynapseRequest
from synapse.http.servlet import RestServlet
from synapse.types import JsonDict

if TYPE_CHECKING:
Expand All @@ -31,7 +30,7 @@
logger = logging.getLogger(__name__)


class LocalKey(Resource):
class LocalKey(RestServlet):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::

Expand Down Expand Up @@ -61,18 +60,17 @@ class LocalKey(Resource):
}
"""

isLeaf = True
PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P<key_id>[^/]*))?$"),)

def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)

def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
self.response_body = self.response_json_object()

def response_json_object(self) -> JsonDict:
verify_keys = {}
Expand All @@ -99,9 +97,11 @@ def response_json_object(self) -> JsonDict:
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object

def render_GET(self, request: SynapseRequest) -> Optional[int]:
def on_GET(
self, request: Request, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
return respond_with_json_bytes(request, 200, self.response_body)
return 200, self.response_body
73 changes: 42 additions & 31 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Set
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

from signedjson.sign import sign_json

from synapse.api.errors import Codes, SynapseError
from twisted.web.server import Request

from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
Expand All @@ -32,7 +37,7 @@
logger = logging.getLogger(__name__)


class RemoteKey(DirectServeJsonResource):
class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
Expand Down Expand Up @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
}
"""

isLeaf = True

def __init__(self, hs: "HomeServer"):
super().__init__()

self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
Expand All @@ -101,36 +102,48 @@ def __init__(self, hs: "HomeServer"):
)
self.config = hs.config

async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"GET",
(
re.compile(
"^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
),
),
self.on_GET,
self.__class__.__name__,
)
http_server.register_paths(
"POST",
(re.compile("^/_matrix/key/v2/query$"),),
self.on_POST,
self.__class__.__name__,
)

async def on_GET(
self, request: Request, server: str, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
if server and key_id:
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
query = {server: {key_id: arguments}}
else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
query = {server: {}}

await self.query_keys(request, query, query_remote_on_cache_miss=True)
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)

async def _async_render_POST(self, request: SynapseRequest) -> None:
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

query = content["server_keys"]

await self.query_keys(request, query, query_remote_on_cache_miss=True)
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)

async def query_keys(
self,
request: SynapseRequest,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
self, query: JsonDict, query_remote_on_cache_miss: bool = False
) -> JsonDict:
logger.info("Handling query for keys %r", query)

store_queries = []
Expand Down Expand Up @@ -232,7 +245,7 @@ async def query_keys(
for server_name, keys in cache_misses.items()
),
)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
return await self.query_keys(query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json_raw in json_results:
Expand All @@ -244,6 +257,4 @@ async def query_keys(

signed_keys.append(key_json)

response = {"server_keys": signed_keys}

respond_with_json(request, 200, response, canonical_json=True)
return {"server_keys": signed_keys}