diff --git a/synapse/storage/events.py b/synapse/storage/events.py index d38e622f3cab..dfda39bbe055 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -79,7 +79,7 @@ def encode_json(json_object): """ out = frozendict_json_encoder.encode(json_object) if isinstance(out, bytes): - out = out.decode('utf8') + out = out.decode("utf8") return out @@ -813,9 +813,10 @@ def _persist_events_txn( """ all_events_and_contexts = events_and_contexts + min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) self._update_forward_extremities_txn( txn, @@ -890,7 +891,7 @@ def _persist_events_txn( backfilled=backfilled, ) - def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): + def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): for room_id, current_state_tuple in iteritems(state_delta_by_room): to_delete, to_insert = current_state_tuple @@ -899,6 +900,12 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): # that we can use it to calculate the `prev_event_id`. (This # allows us to not have to pull out the existing state # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # sql = """ INSERT INTO current_state_delta_stream (stream_id, room_id, type, state_key, event_id, prev_event_id) @@ -911,7 +918,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): sql, ( ( - max_stream_order, + stream_id, room_id, etype, state_key, @@ -929,7 +936,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): sql, ( ( - max_stream_order, + stream_id, room_id, etype, state_key, @@ -970,7 +977,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, room_id, - max_stream_order, + stream_id, ) # Invalidate the various caches @@ -988,8 +995,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): for member in members_changed: txn.call_after( - self.get_rooms_for_user_with_stream_ordering.invalidate, - (member,), + self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) ) self._invalidate_state_caches_and_stream(txn, room_id, members_changed) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 524af4f8d156..1f72a2a04f8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -56,7 +56,9 @@ def prepare(self, reactor, clock, hs): client = client_factory.buildProtocol(None) client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(client, reactor)) + + self.server_to_client_transport = FakeTransport(client, reactor) + server.makeConnection(self.server_to_client_transport) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -69,6 +71,24 @@ def check(self, method, args, expected_result=None): master_result = self.get_success(getattr(self.master_store, method)(*args)) slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: - self.assertEqual(master_result, expected_result) - self.assertEqual(slaved_result, expected_result) - self.assertEqual(master_result, slaved_result) + self.assertEqual( + master_result, + expected_result, + "Expected master result to be %r but was %r" % ( + expected_result, master_result + ), + ) + self.assertEqual( + slaved_result, + expected_result, + "Expected slave result to be %r but was %r" % ( + expected_result, slaved_result + ), + ) + self.assertEqual( + master_result, + slaved_result, + "Slave result %r does not match master result %r" % ( + slaved_result, master_result + ), + ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 932611efda2d..65ecff3bd6eb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from canonicaljson import encode_canonical_json from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext +from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser @@ -26,6 +28,8 @@ OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" +logger = logging.getLogger(__name__) + def dict_equals(self, other): me = encode_canonical_json(self.get_pdu_json()) @@ -191,9 +195,116 @@ def test_get_rooms_for_user_with_stream_ordering(self): {(ROOM_ID, j2.internal_metadata.stream_ordering)}, ) + def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): + """Check that current_state invalidation happens correctly with multiple events + in the persistence batch. + + This test attempts to reproduce a race condition between the event persistence + loop and a worker-based Sync handler. + + The problem occurred when the master persisted several events in one batch. It + only updates the current_state at the end of each batch, so the obvious thing + to do is then to issue a current_state_delta stream update corresponding to the + last stream_id in the batch. + + However, that raises the possibility that a worker will see the replication + notification for a join event before the current_state caches are invalidated. + + The test involves: + * creating a join and a message event for a user, and persisting them in the + same batch + + * controlling the replication stream so that updates are sent gradually + + * between each bunch of replication updates, check that we see a consistent + snapshot of the state. + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + # limit the replication rate + repl_transport = self.server_to_client_transport + repl_transport.autoflush = False + + # build the join and message events and persist them in the same batch. + logger.info("----- build test events ------") + j2, j2ctx = self.build_event( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + msg, msgctx = self.build_event() + self.get_success(self.master_store.persist_events([ + (j2, j2ctx), + (msg, msgctx), + ])) + self.replicate() + + event_source = RoomEventSource(self.hs) + event_source.store = self.slaved_store + current_token = self.get_success(event_source.get_current_key()) + + # gradually stream out the replication + while repl_transport.buffer: + logger.info("------ flush ------") + repl_transport.flush(30) + self.pump(0) + + prev_token = current_token + current_token = self.get_success(event_source.get_current_key()) + + # attempt to replicate the behaviour of the sync handler. + # + # First, we get a list of the rooms we are joined to + joined_rooms = self.get_success( + self.slaved_store.get_rooms_for_user_with_stream_ordering( + USER_ID_2, + ), + ) + + # Then, we get a list of the events since the last sync + membership_changes = self.get_success( + self.slaved_store.get_membership_changes_for_user( + USER_ID_2, prev_token, current_token, + ) + ) + + logger.info( + "%s->%s: joined_rooms=%r membership_changes=%r", + prev_token, + current_token, + joined_rooms, + membership_changes, + ) + + # the membership change is only any use to us if the room is in the + # joined_rooms list. + if membership_changes: + self.assertEqual( + joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + ) + event_id = 0 - def persist( + def persist(self, backfill=False, **kwargs): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ + event, context = self.build_event(**kwargs) + + if backfill: + self.get_success( + self.master_store.persist_events([(event, context)], backfilled=True) + ) + else: + self.get_success( + self.master_store.persist_event(event, context) + ) + + return event + + def build_event( self, sender=USER_ID, room_id=ROOM_ID, @@ -201,7 +312,6 @@ def persist( key=None, internal={}, state=None, - backfill=False, depth=None, prev_events=[], auth_events=[], @@ -210,10 +320,7 @@ def persist( push_actions=[], **content ): - """ - Returns: - synapse.events.FrozenEvent: The event that was persisted. - """ + if depth is None: depth = self.event_id @@ -252,23 +359,11 @@ def persist( ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context(event)) + context = self.get_success(state_handler.compute_event_context( + event + )) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} ) - - ordering = None - if backfill: - self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) - ) - else: - ordering, _ = self.get_success( - self.master_store.persist_event(event, context) - ) - - if ordering: - event.internal_metadata.stream_ordering = ordering - - return event + return event, context diff --git a/tests/server.py b/tests/server.py index ea26dea6237e..8f89f4a83d44 100644 --- a/tests/server.py +++ b/tests/server.py @@ -365,6 +365,7 @@ class FakeTransport(object): disconnected = False buffer = attr.ib(default=b'') producer = attr.ib(default=None) + autoflush = attr.ib(default=True) def getPeer(self): return None @@ -415,31 +416,44 @@ def _produce(): def write(self, byt): self.buffer = self.buffer + byt - def _write(): - if not self.buffer: - # nothing to do. Don't write empty buffers: it upsets the - # TLSMemoryBIOProtocol - return - - if self.disconnected: - return - logger.info("%s->%s: %s", self._protocol, self.other, self.buffer) - - if getattr(self.other, "transport") is not None: - try: - self.other.dataReceived(self.buffer) - self.buffer = b"" - except Exception as e: - logger.warning("Exception writing to protocol: %s", e) - return - - self._reactor.callLater(0.0, _write) - # always actually do the write asynchronously. Some protocols (notably the # TLSMemoryBIOProtocol) get very confused if a read comes back while they are # still doing a write. Doing a callLater here breaks the cycle. - self._reactor.callLater(0.0, _write) + if self.autoflush: + self._reactor.callLater(0.0, self.flush) def writeSequence(self, seq): for x in seq: self.write(x) + + def flush(self, maxbytes=None): + if not self.buffer: + # nothing to do. Don't write empty buffers: it upsets the + # TLSMemoryBIOProtocol + return + + if self.disconnected: + return + + if getattr(self.other, "transport") is None: + # the other has no transport yet; reschedule + if self.autoflush: + self._reactor.callLater(0.0, self.flush) + return + + if maxbytes is not None: + to_write = self.buffer[:maxbytes] + else: + to_write = self.buffer + + logger.info("%s->%s: %s", self._protocol, self.other, to_write) + + try: + self.other.dataReceived(to_write) + except Exception as e: + logger.warning("Exception writing to protocol: %s", e) + return + + self.buffer = self.buffer[len(to_write):] + if self.buffer and self.autoflush: + self._reactor.callLater(0.0, self.flush)