diff --git a/changelog.d/3541.feature b/changelog.d/3541.feature new file mode 100644 index 000000000000..24524136eac9 --- /dev/null +++ b/changelog.d/3541.feature @@ -0,0 +1 @@ +Optimisation to make handling incoming federation requests more efficient. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3ecebd29f20..20fb46fc89a7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -43,7 +43,6 @@ add_hashes_and_signatures, compute_event_signature, ) -from synapse.events.utils import prune_event from synapse.events.validator import EventValidator from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id @@ -52,8 +51,8 @@ from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function -from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination +from synapse.visibility import filter_events_for_server from ._base import BaseHandler @@ -501,137 +500,6 @@ def _process_received_pdu(self, origin, pdu, state, auth_chain): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - @measure_func("_filter_events_for_server") - @defer.inlineCallbacks - def _filter_events_for_server(self, server_name, room_id, events): - """Filter the given events for the given server, redacting those the - server can't see. - - Assumes the server is currently in the room. - - Returns - list[FrozenEvent] - """ - # First lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't - # need to check membership (as we know the server is in the room). - event_to_state_ids = yield self.store.get_state_ids_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - ) - ) - - visibility_ids = set() - for sids in event_to_state_ids.itervalues(): - hist = sids.get((EventTypes.RoomHistoryVisibility, "")) - if hist: - visibility_ids.add(hist) - - # If we failed to find any history visibility events then the default - # is "shared" visiblity. - if not visibility_ids: - defer.returnValue(events) - - event_map = yield self.store.get_events(visibility_ids) - all_open = all( - e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in event_map.itervalues() - ) - - if all_open: - defer.returnValue(events) - - # Ok, so we're dealing with events that have non-trivial visibility - # rules, so we need to also get the memberships of the room. - - event_to_state_ids = yield self.store.get_state_ids_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ) - ) - - # We only want to pull out member events that correspond to the - # server's domain. - - def check_match(id): - try: - return server_name == get_domain_from_id(id) - except Exception: - return False - - # Parses mapping `event_id -> (type, state_key) -> state event_id` - # to get all state ids that we're interested in. - event_map = yield self.store.get_events([ - e_id - for key_to_eid in list(event_to_state_ids.values()) - for key, e_id in key_to_eid.items() - if key[0] != EventTypes.Member or check_match(key[1]) - ]) - - event_to_state = { - e_id: { - key: event_map[inner_e_id] - for key, inner_e_id in key_to_eid.iteritems() - if inner_e_id in event_map - } - for e_id, key_to_eid in event_to_state_ids.iteritems() - } - - erased_senders = yield self.store.are_users_erased( - e.sender for e in events, - ) - - def redact_disallowed(event, state): - # if the sender has been gdpr17ed, always return a redacted - # copy of the event. - if erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) - return prune_event(event) - - if not state: - return event - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - if visibility in ["invited", "joined"]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.itervalues(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except Exception: - continue - - if domain != server_name: - continue - - memtype = ev.membership - if memtype == Membership.JOIN: - return event - elif memtype == Membership.INVITE: - if visibility == "invited": - return event - else: - return prune_event(event) - - return event - - defer.returnValue([ - redact_disallowed(e, event_to_state[e.event_id]) - for e in events - ]) - @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit, extremities): @@ -1558,7 +1426,7 @@ def on_backfill_request(self, origin, room_id, pdu_list, limit): limit ) - events = yield self._filter_events_for_server(origin, room_id, events) + events = yield filter_events_for_server(self.store, origin, events) defer.returnValue(events) @@ -1605,8 +1473,8 @@ def get_persisted_pdu(self, origin, event_id): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self._filter_events_for_server( - origin, event.room_id, [event] + events = yield filter_events_for_server( + self.store, origin, [event], ) event = events[0] defer.returnValue(event) @@ -1896,8 +1764,8 @@ def on_get_missing_events(self, origin, room_id, earliest_events, min_depth=min_depth, ) - missing_events = yield self._filter_events_for_server( - origin, room_id, missing_events, + missing_events = yield filter_events_for_server( + self.store, origin, missing_events, ) defer.returnValue(missing_events) diff --git a/synapse/visibility.py b/synapse/visibility.py index 015c2bab3743..dc33f61d2bab 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,10 +16,13 @@ import logging import operator +import six + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.types import get_domain_from_id from synapse.util.logcontext import make_deferred_yieldable, preserve_fn logger = logging.getLogger(__name__) @@ -225,3 +228,141 @@ def allowed(event): # we turn it into a list before returning it. defer.returnValue(list(filtered_events)) + + +@defer.inlineCallbacks +def filter_events_for_server(store, server_name, events): + # First lets check to see if all the events have a history visibility + # of "shared" or "world_readable". If thats the case then we don't + # need to check membership (as we know the server is in the room). + event_to_state_ids = yield store.get_state_ids_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + ) + ) + + visibility_ids = set() + for sids in event_to_state_ids.itervalues(): + hist = sids.get((EventTypes.RoomHistoryVisibility, "")) + if hist: + visibility_ids.add(hist) + + # If we failed to find any history visibility events then the default + # is "shared" visiblity. + if not visibility_ids: + defer.returnValue(events) + + event_map = yield store.get_events(visibility_ids) + all_open = all( + e.content.get("history_visibility") in (None, "shared", "world_readable") + for e in event_map.itervalues() + ) + + if all_open: + defer.returnValue(events) + + # Ok, so we're dealing with events that have non-trivial visibility + # rules, so we need to also get the memberships of the room. + + # first, for each event we're wanting to return, get the event_ids + # of the history vis and membership state at those events. + event_to_state_ids = yield store.get_state_ids_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) + ) + + # We only want to pull out member events that correspond to the + # server's domain. + # + # event_to_state_ids contains lots of duplicates, so it turns out to be + # cheaper to build a complete set of unique + # ((type, state_key), event_id) tuples, and then filter out the ones we + # don't want. + # + state_key_to_event_id_set = { + e + for key_to_eid in six.itervalues(event_to_state_ids) + for e in key_to_eid.items() + } + + def include(typ, state_key): + if typ != EventTypes.Member: + return True + + # we avoid using get_domain_from_id here for efficiency. + idx = state_key.find(":") + if idx == -1: + return False + return state_key[idx + 1:] == server_name + + event_map = yield store.get_events([ + e_id + for key, e_id in state_key_to_event_id_set + if include(key[0], key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.iteritems() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.iteritems() + } + + erased_senders = yield store.are_users_erased( + e.sender for e in events, + ) + + def redact_disallowed(event, state): + # if the sender has been gdpr17ed, always return a redacted + # copy of the event. + if erased_senders[event.sender]: + logger.info( + "Sender of %s has been erased, redacting", + event.event_id, + ) + return prune_event(event) + + if not state: + return event + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + if visibility in ["invited", "joined"]: + # We now loop through all state events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in state.itervalues(): + if ev.type != EventTypes.Member: + continue + try: + domain = get_domain_from_id(ev.state_key) + except Exception: + continue + + if domain != server_name: + continue + + memtype = ev.membership + if memtype == Membership.JOIN: + return event + elif memtype == Membership.INVITE: + if visibility == "invited": + return event + else: + # server has no users in the room: redact + return prune_event(event) + + return event + + defer.returnValue([ + redact_disallowed(e, event_to_state[e.event_id]) + for e in events + ]) diff --git a/tests/test_visibility.py b/tests/test_visibility.py new file mode 100644 index 000000000000..8436c29fe879 --- /dev/null +++ b/tests/test_visibility.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer +from twisted.internet.defer import succeed + +from synapse.events import FrozenEvent +from synapse.visibility import filter_events_for_server + +import tests.unittest +from tests.utils import setup_test_homeserver + +logger = logging.getLogger(__name__) + +TEST_ROOM_ID = "!TEST:ROOM" + + +class FilterEventsForServerTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.event_creation_handler = self.hs.get_event_creation_handler() + self.event_builder_factory = self.hs.get_event_builder_factory() + self.store = self.hs.get_datastore() + + @defer.inlineCallbacks + def test_filtering(self): + # + # The events to be filtered consist of 10 membership events (it doesn't + # really matter if they are joins or leaves, so let's make them joins). + # One of those membership events is going to be for a user on the + # server we are filtering for (so we can check the filtering is doing + # the right thing). + # + + # before we do that, we persist some other events to act as state. + self.inject_visibility("@admin:hs", "joined") + for i in range(0, 10): + yield self.inject_room_member("@resident%i:hs" % i) + + events_to_filter = [] + + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = yield self.inject_room_member(user, extra_content={"a": "b"}) + events_to_filter.append(evt) + + filtered = yield filter_events_for_server( + self.store, "test_server", events_to_filter, + ) + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("a", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["a"], "b") + + @defer.inlineCallbacks + def inject_visibility(self, user_id, visibility): + content = {"history_visibility": visibility} + builder = self.event_builder_factory.new({ + "type": "m.room.history_visibility", + "sender": user_id, + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_room_member(self, user_id, membership="join", extra_content={}): + content = {"membership": membership} + content.update(extra_content) + builder = self.event_builder_factory.new({ + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def test_large_room(self): + # see what happens when we have a large room with hundreds of thousands + # of membership events + + # As above, the events to be filtered consist of 10 membership events, + # where one of them is for a user on the server we are filtering for. + + import cProfile + import pstats + import time + + # we stub out the store, because building up all that state the normal + # way is very slow. + test_store = _TestStore() + + # our initial state is 100000 membership events and one + # history_visibility event. + room_state = [] + + history_visibility_evt = FrozenEvent({ + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + }) + room_state.append(history_visibility_evt) + test_store.add_event(history_visibility_evt) + + for i in range(0, 100000): + user = "@resident_user_%i:test.com" % (i, ) + evt = FrozenEvent({ + "event_id": "$res_event_%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz," + }, + }) + room_state.append(evt) + test_store.add_event(evt) + + events_to_filter = [] + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = FrozenEvent({ + "event_id": "$evt%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz", + }, + }) + events_to_filter.append(evt) + room_state.append(evt) + + test_store.add_event(evt) + test_store.set_state_ids_for_event(evt, { + (e.type, e.state_key): e.event_id for e in room_state + }) + + pr = cProfile.Profile() + pr.enable() + + logger.info("Starting filtering") + start = time.time() + filtered = yield filter_events_for_server( + test_store, "test_server", events_to_filter, + ) + logger.info("Filtering took %f seconds", time.time() - start) + + pr.disable() + with open("filter_events_for_server.profile", "w+") as f: + ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps.print_stats() + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("extra", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["extra"], "zzz") + + test_large_room.skip = "Disabled by default because it's slow" + + +class _TestStore(object): + """Implements a few methods of the DataStore, so that we can test + filter_events_for_server + + """ + def __init__(self): + # data for get_events: a map from event_id to event + self.events = {} + + # data for get_state_ids_for_events mock: a map from event_id to + # a map from (type_state_key) -> event_id for the state at that + # event + self.state_ids_for_events = {} + + def add_event(self, event): + self.events[event.event_id] = event + + def set_state_ids_for_event(self, event, state): + self.state_ids_for_events[event.event_id] = state + + def get_state_ids_for_events(self, events, types): + res = {} + include_memberships = False + for (type, state_key) in types: + if type == "m.room.history_visibility": + continue + if type != "m.room.member" or state_key is not None: + raise RuntimeError( + "Unimplemented: get_state_ids with type (%s, %s)" % + (type, state_key), + ) + include_memberships = True + + if include_memberships: + for event_id in events: + res[event_id] = self.state_ids_for_events[event_id] + + else: + k = ("m.room.history_visibility", "") + for event_id in events: + hve = self.state_ids_for_events[event_id][k] + res[event_id] = {k: hve} + + return succeed(res) + + def get_events(self, events): + return succeed({ + event_id: self.events[event_id] for event_id in events + }) + + def are_users_erased(self, users): + return succeed({u: False for u in users})