Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3706: partial state in /send_join response (#11967)
Browse files Browse the repository at this point in the history
* Make `get_auth_chain_ids` return a Set

It has a set internally, and a set is often useful where it gets used, so let's
avoid converting to an intermediate list.

* Minor refactors in `on_send_join_request`

A little bit of non-functional groundwork

* Implement MSC3706: partial state in /send_join response
  • Loading branch information
richvdh authored Feb 12, 2022
1 parent b2b971f commit 63c4634
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/11967.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
91 changes: 81 additions & 10 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -64,7 +65,7 @@
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
Expand Down Expand Up @@ -571,7 +572,7 @@ async def _on_state_ids_request_compute(
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}

async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
Expand Down Expand Up @@ -645,27 +646,61 @@ async def on_invite_request(
return {"event": ret_pdu.get_pdu_json(time_now)}

async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)

prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(room_id, state_ids)
state = await self.store.get_events(state_ids)

state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None

auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)

# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)

auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)

# we try to do all the async stuff before this point, so that time_now is as
# accurate as possible.
time_now = self._clock.time_msec()
event_json = event.get_pdu_json()
return {
event_json = event.get_pdu_json(time_now)
resp = {
# TODO Remove the unstable prefix when servers have updated.
"org.matrix.msc3083.v2.event": event_json,
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
}

if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)

return resp

async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]:
Expand Down Expand Up @@ -1339,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))


def _get_event_ids_for_partial_state_join(
    join_event: EventBase,
    prev_state_ids: StateMap[str],
) -> Collection[str]:
    """Calculate the state to be returned in a partial_state send_join.

    The result contains every non-membership state event from before the
    join, plus (if there is one) the joining user's own previous membership
    event.

    Args:
        join_event: the join event being send_joined
        prev_state_ids: the event ids of the state before the join, keyed by
            (event type, state key)

    Returns:
        the event ids to be returned
    """

    # the joining user's existing membership is going to be an auth event for
    # the new join, so we may as well return it alongside the other state.
    own_membership_key = (EventTypes.Member, join_event.state_key)
    own_membership_event_id = prev_state_ids.get(own_membership_key)

    # everything that is not a membership event is always included.
    selected_event_ids = {
        event_id
        for (event_type, _), event_id in prev_state_ids.items()
        if event_type != EventTypes.Member
    }

    if own_membership_event_id is not None:
        selected_event_ids.add(own_membership_event_id)

    # TODO: return a few more members:
    #   - those with invites
    #   - those that are kicked? / banned

    return selected_event_ids
20 changes: 19 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):

PREFIX = FEDERATION_V2_PREFIX

def __init__(
    self,
    hs: "HomeServer",
    authenticator: Authenticator,
    ratelimiter: FederationRateLimiter,
    server_name: str,
):
    """Initialise the servlet and remember whether MSC3706 is enabled.

    Args:
        hs: the homeserver; its experimental config supplies the
            `msc3706_enabled` flag.
        authenticator: passed through unchanged to the base servlet.
        ratelimiter: passed through unchanged to the base servlet.
        server_name: passed through unchanged to the base servlet.
    """
    super().__init__(hs, authenticator, ratelimiter, server_name)

    # read the experimental flag once, at construction time, so that request
    # handling only has to consult an attribute on the servlet.
    experimental_config = hs.config.experimental
    self._msc3706_enabled = experimental_config.msc3706_enabled

async def on_PUT(
self,
origin: str,
Expand All @@ -422,7 +432,15 @@ async def on_PUT(
) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)

partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result


Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def get_auth_chain_ids(
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
Expand All @@ -130,7 +130,7 @@ async def get_auth_chain_ids(
include_given: include the given events in result
Returns:
list of event_ids
set of event_ids
"""

# Check if we have indexed the room so we can use the chain cover
Expand Down Expand Up @@ -159,7 +159,7 @@ async def get_auth_chain_ids(

def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""

# First we look up the chain ID/sequence numbers for the given events.
Expand Down Expand Up @@ -272,11 +272,11 @@ def _get_auth_chain_ids_using_cover_index_txn(
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)

return list(results)
return results

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
Expand Down Expand Up @@ -331,7 +331,7 @@ def _get_auth_chain_ids_txn(
front = new_front
results.update(front)

return list(results)
return results

async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
Expand Down
148 changes: 148 additions & 0 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest
from tests.unittest import override_config


class FederationServerTests(unittest.FederatingHomeserverTestCase):
Expand Down Expand Up @@ -152,6 +161,145 @@ def test_needs_to_be_in_room(self):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")


class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
    """Tests for the federation `/send_join` endpoint.

    Each test performs a `make_join` / `send_join` handshake on behalf of a
    user on the remote test server, against a room which already has two
    local members.
    """

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
        """Create a room with two local members for the tests to join."""
        super().prepare(reactor, clock, hs)

        # create the room
        creator_user_id = self.register_user("kermit", "test")
        tok = self.login("kermit", "test")
        self._room_id = self.helper.create_room_as(
            room_creator=creator_user_id, tok=tok
        )

        # a second member on the origin HS
        second_member_user_id = self.register_user("fozzie", "bear")
        tok2 = self.login("fozzie", "bear")
        self.helper.join(self._room_id, second_member_user_id, tok=tok2)

    def _make_join(self, user_id: str) -> JsonDict:
        """Send a `make_join` request for `user_id` and return the response body."""
        channel = self.make_signed_federation_request(
            "GET",
            f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
            f"?ver={DEFAULT_ROOM_VERSION}",
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body

    def _send_join(self, joining_user: str, query_string: str = "") -> JsonDict:
        """Run the make_join / send_join handshake and return the response body.

        Args:
            joining_user: the remote user id which is joining the room
            query_string: appended verbatim to the `send_join` path; include
                the leading "?" if it is non-empty

        Returns:
            the JSON body of the (successful) `send_join` response
        """
        join_result = self._make_join(joining_user)

        # the template event must be hashed and signed by the joining server
        # before it is sent back.
        join_event_dict = join_result["event"]
        add_hashes_and_signatures(
            KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
            join_event_dict,
            signature_name=self.OTHER_SERVER_NAME,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
        )
        channel = self.make_signed_federation_request(
            "PUT",
            f"/_matrix/federation/v2/send_join/{self._room_id}/x{query_string}",
            content=join_event_dict,
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body

    @staticmethod
    def _type_and_state_key_pairs(events: list) -> list:
        """Reduce a list of event dicts to their (type, state_key) pairs."""
        return [(ev["type"], ev["state_key"]) for ev in events]

    def _assert_user_joined(self, user_id: str) -> None:
        """Check that the room's current state shows `user_id` as joined."""
        r = self.get_success(
            self.hs.get_state_handler().get_current_state(self._room_id)
        )
        self.assertEqual(r[("m.room.member", user_id)].membership, "join")

    def test_send_join(self):
        """happy-path test of send_join"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        response = self._send_join(joining_user)

        # we should get complete room state back
        self.assertCountEqual(
            self._type_and_state_key_pairs(response["state"]),
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.member", "@fozzie:test"),
                # nb: *not* the joining user
            ],
        )

        # also check the auth chain
        self.assertCountEqual(
            self._type_and_state_key_pairs(response["auth_chain"]),
            [
                ("m.room.create", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
            ],
        )

        # the room should show that the new user is a member
        self._assert_user_joined(joining_user)

    @override_config({"experimental_features": {"msc3706_enabled": True}})
    def test_send_join_partial_state(self):
        """When MSC3706 support is enabled, /send_join should return partial state"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        response = self._send_join(
            joining_user, "?org.matrix.msc3706.partial_state=true"
        )

        # expect a reduced room state
        self.assertCountEqual(
            self._type_and_state_key_pairs(response["state"]),
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
            ],
        )

        # the auth chain should not include anything already in "state"
        self.assertCountEqual(
            self._type_and_state_key_pairs(response["auth_chain"]),
            [
                ("m.room.member", "@kermit:test"),
            ],
        )

        # the room should show that the new user is a member
        self._assert_user_joined(joining_user)


def _create_acl_event(content):
return make_event_from_dict(
{
Expand Down
Loading

0 comments on commit 63c4634

Please sign in to comment.