From 5624c8b961ed6a8310a2c6723ae13e854721756b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2024 14:03:49 +0100 Subject: [PATCH] In sync wait for worker to catch up since token (#17215) Otherwise things will get confused. An alternative would be to make sure that for lagging stream we don't return anything (and make sure the returned next_batch token doesn't go backwards). But that is a faff. --- changelog.d/17215.bugfix | 1 + pyproject.toml | 6 +- synapse/handlers/sync.py | 35 +++++++++++ synapse/notifier.py | 23 ++++++++ synapse/storage/databases/main/events.py | 7 +++ .../storage/databases/main/events_worker.py | 11 +++- synapse/types/__init__.py | 58 ++++++++++++++++++- 7 files changed, 134 insertions(+), 7 deletions(-) create mode 100644 changelog.d/17215.bugfix diff --git a/changelog.d/17215.bugfix b/changelog.d/17215.bugfix new file mode 100644 index 00000000000..10981b798e0 --- /dev/null +++ b/changelog.d/17215.bugfix @@ -0,0 +1 @@ +Fix bug where duplicate events could be sent down sync when using workers that are overloaded. diff --git a/pyproject.toml b/pyproject.toml index ea14b981997..9a3348be497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,10 +200,8 @@ netaddr = ">=0.7.18" # add a lower bound to the Jinja2 dependency. Jinja2 = ">=3.0" bleach = ">=1.4.3" -# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0. -# Additionally we need https://github.com/python/typing/pull/817 to allow types to be -# generic over ParamSpecs. -typing-extensions = ">=3.10.0.1" +# We use `Self`, which were added in `typing-extensions` 4.0. +typing-extensions = ">=4.0" # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. cryptography = ">=3.4.7" diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ac5bddd52fc..1d7d9dfdd0f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -284,6 +284,23 @@ def __bool__(self) -> bool: or self.device_lists ) + @staticmethod + def empty(next_batch: StreamToken) -> "SyncResult": + "Return a new empty result" + return SyncResult( + next_batch=next_batch, + presence=[], + account_data=[], + joined=[], + invited=[], + knocked=[], + archived=[], + to_device=[], + device_lists=DeviceListUpdates(), + device_one_time_keys_count={}, + device_unused_fallback_key_types=[], + ) + @attr.s(slots=True, frozen=True, auto_attribs=True) class E2eeSyncResult: @@ -497,6 +514,24 @@ async def _wait_for_sync_for_user( if context: context.tag = sync_label + if since_token is not None: + # We need to make sure this worker has caught up with the token. If + # this returns false it means we timed out waiting, and we should + # just return an empty response. + start = self.clock.time_msec() + if not await self.notifier.wait_for_stream_token(since_token): + logger.warning( + "Timed out waiting for worker to catch up. Returning empty response" + ) + return SyncResult.empty(since_token) + + # If we've spent significant time waiting to catch up, take it off + # the timeout. + now = self.clock.time_msec() + if now - start > 1_000: + timeout -= now - start + timeout = max(timeout, 0) + # if we have a since token, delete any to-device messages before that token # (since we now know that the device has received them) if since_token is not None: diff --git a/synapse/notifier.py b/synapse/notifier.py index 7c1cd3b5f2f..ced9e9ad667 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -763,6 +763,29 @@ async def check_for_updates( return result + async def wait_for_stream_token(self, stream_token: StreamToken) -> bool: + """Wait for this worker to catch up with the given stream token.""" + + start = self.clock.time_msec() + while True: + current_token = self.event_sources.get_current_token() + if stream_token.is_before_or_eq(current_token): + return True + + now = self.clock.time_msec() + + if now - start > 10_000: + return False + + logger.info( + "Waiting for current token to reach %s; currently at %s", + stream_token, + current_token, + ) + + # TODO: be better + await self.clock.sleep(0.5) + async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] ) -> Tuple[StrCollection, bool]: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fd7167904dc..f1bd85aa276 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -95,6 +95,10 @@ class DeltaState: to_insert: StateMap[str] no_longer_in_room: bool = False + def is_noop(self) -> bool: + """Whether this state delta is actually empty""" + return not self.to_delete and not self.to_insert and not self.no_longer_in_room + class PersistEventsStore: """Contains all the functions for writing events to the database. @@ -1017,6 +1021,9 @@ async def update_current_state( ) -> None: """Update the current state stored in the datatabase for the given room""" + if state_delta.is_noop(): + return + async with self._stream_id_gen.get_next() as stream_ordering: await self.db_pool.runInteraction( "update_current_state", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 426df2a9d27..c06c44deb1f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -200,7 +200,11 @@ def __init__( notifier=hs.get_replication_notifier(), stream_name="events", instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], + tables=[ + ("events", "instance_name", "stream_ordering"), + ("current_state_delta_stream", "instance_name", "stream_id"), + ("ex_outlier_stream", "instance_name", "event_stream_ordering"), + ], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -210,7 +214,10 @@ def __init__( notifier=hs.get_replication_notifier(), stream_name="backfill", instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], + tables=[ + ("events", "instance_name", "stream_ordering"), + ("ex_outlier_stream", "instance_name", "event_stream_ordering"), + ], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 509a2d3a0f9..151658df534 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -48,7 +48,7 @@ from immutabledict import immutabledict from signedjson.key import decode_verify_key_bytes from signedjson.types import VerifyKey -from typing_extensions import TypedDict +from typing_extensions import Self, TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -515,6 +515,27 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int: # at `self.stream`. return self.instance_map.get(instance_name, self.stream) + def is_before_or_eq(self, other_token: Self) -> bool: + """Wether this token is before the other token, i.e. every constituent + part is before the other. + + Essentially it is `self <= other`. + + Note: if `self.is_before_or_eq(other_token) is False` then that does not + imply that the reverse is True. + """ + if self.stream > other_token.stream: + return False + + instances = self.instance_map.keys() | other_token.instance_map.keys() + for instance in instances: + if self.instance_map.get( + instance, self.stream + ) > other_token.instance_map.get(instance, other_token.stream): + return False + + return True + @attr.s(frozen=True, slots=True, order=False) class RoomStreamToken(AbstractMultiWriterStreamToken): @@ -1008,6 +1029,41 @@ def get_field( """Returns the stream ID for the given key.""" return getattr(self, key.value) + def is_before_or_eq(self, other_token: "StreamToken") -> bool: + """Wether this token is before the other token, i.e. every constituent + part is before the other. + + Essentially it is `self <= other`. + + Note: if `self.is_before_or_eq(other_token) is False` then that does not + imply that the reverse is True. + """ + + for _, key in StreamKeyType.__members__.items(): + if key == StreamKeyType.TYPING: + # Typing stream is allowed to "reset", and so comparisons don't + # really make sense as is. + # TODO: Figure out a better way of tracking resets. + continue + + self_value = self.get_field(key) + other_value = other_token.get_field(key) + + if isinstance(self_value, RoomStreamToken): + assert isinstance(other_value, RoomStreamToken) + if not self_value.is_before_or_eq(other_value): + return False + elif isinstance(self_value, MultiWriterStreamToken): + assert isinstance(other_value, MultiWriterStreamToken) + if not self_value.is_before_or_eq(other_value): + return False + else: + assert isinstance(other_value, int) + if self_value > other_value: + return False + + return True + StreamToken.START = StreamToken( RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0