From cb7249a3c48dfe4a2a972b3a66a37f680aebbcc2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2023 11:13:26 +0000 Subject: [PATCH] Merge device list replication streams --- docs/upgrade.md | 9 +-- synapse/replication/tcp/client.py | 8 ++- synapse/replication/tcp/streams/__init__.py | 3 - synapse/replication/tcp/streams/_base.py | 73 ++++++++++++++------- synapse/storage/databases/main/devices.py | 13 ++-- 5 files changed, 69 insertions(+), 37 deletions(-) diff --git a/docs/upgrade.md b/docs/upgrade.md index 8a76172e43cc..270c33b6562e 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -92,12 +92,13 @@ process, for example: ## Changes to the account data replication streams -Synapse has changed the format of the account data replication streams (between -workers). This is a forwards- and backwards-incompatible change: v1.75 workers -cannot process account data replicated by v1.76 workers, and vice versa. +Synapse has changed the format of the account data and devices replication +streams (between workers). This is a forwards- and backwards-incompatible +change: v1.75 workers cannot process account data or device lists replicated +by v1.76 workers, and vice versa. Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data -replication will resume as normal. +and device replication will resume as normal. 
# Upgrading to v1.74.0 diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7263bb2796da..31022ce5fb41 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -187,7 +187,7 @@ async def on_rdata( elif stream_name == DeviceListsStream.NAME: all_room_ids: Set[str] = set() for row in rows: - if row.entity.startswith("@"): + if row.entity.startswith("@") and not row.is_signature: room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) self.notifier.on_new_event( @@ -422,7 +422,11 @@ async def process_replication_rows( # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - hosts = {row.entity for row in rows if not row.entity.startswith("@")} + hosts = { + row.entity + for row in rows + if not row.entity.startswith("@") and not row.is_signature + } for host in hosts: self.federation_sender.send_device_messages(host, immediate=False) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index a7eadfa3c9f9..9c67f661a362 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -37,7 +37,6 @@ Stream, ToDeviceStream, TypingStream, - UserSignatureStream, ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream @@ -62,7 +61,6 @@ ToDeviceStream, FederationStream, AccountDataStream, - UserSignatureStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, ) @@ -82,7 +80,6 @@ "DeviceListsStream", "ToDeviceStream", "AccountDataStream", - "UserSignatureStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index fbf78da9c254..cb782ee01363 100644 --- a/synapse/replication/tcp/streams/_base.py +++ 
b/synapse/replication/tcp/streams/_base.py @@ -463,18 +463,66 @@ class DeviceListsStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: entity: str + is_signature: bool NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_device_list_changes_for_remotes, + current_token_without_instance(self.store.get_device_stream_token), + self._update_function, + ) + + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + ( + device_updates, + devices_to_token, + devices_limited, + ) = await self.store.get_all_device_list_changes_for_remotes( + instance_name, from_token, current_token, target_row_count ) + ( + signatures_updates, + signatures_to_token, + signatures_limited, + ) = await self.store.get_all_user_signature_changes_for_remotes( + instance_name, from_token, current_token, target_row_count + ) + + upper_limit_token = current_token + if devices_limited: + upper_limit_token = min(upper_limit_token, devices_to_token) + if signatures_limited: + upper_limit_token = min(upper_limit_token, signatures_to_token) + + device_updates = [ + (stream_id, (entity, False)) + for stream_id, (entity,) in device_updates + if stream_id <= upper_limit_token + ] + + signatures_updates = [ + (stream_id, (entity, True)) + for stream_id, (entity,) in signatures_updates + if stream_id <= upper_limit_token + ] + + updates = list( + heapq.merge(device_updates, signatures_updates, key=lambda row: row[0]) + ) + + return updates, upper_limit_token, devices_limited or signatures_limited + class ToDeviceStream(Stream): """New to_device messages for a client""" @@ -583,22 +631,3 @@ async def _update_function( 
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) ) return updates, to_token, limited - - -class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class UserSignatureStreamRow: - user_id: str - - NAME = "user_signature" - ROW_TYPE = UserSignatureStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_user_signature_changes_for_remotes, - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b06766447338..cd186c84726c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,7 @@ whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -163,9 +163,7 @@ def process_replication_rows( ) -> None: if stream_name == DeviceListsStream.NAME: self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) def process_replication_position( @@ -173,14 +171,17 @@ def process_replication_position( ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) + 
super().process_replication_position(stream_name, instance_name, token) def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: for row in rows: + if row.is_signature: + self._user_signature_stream_cache.entity_has_changed(row.entity, token) + continue + # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes.