Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge device list replication streams
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Jan 13, 2023
1 parent 73ff493 commit cb7249a
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 37 deletions.
9 changes: 5 additions & 4 deletions docs/upgrade.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,13 @@ process, for example:
## Changes to the account data replication streams
Synapse has changed the format of the account data replication streams (between
workers). This is a forwards- and backwards-incompatible change: v1.75 workers
cannot process account data replicated by v1.76 workers, and vice versa.
Synapse has changed the format of the account data and devices replication
streams (between workers). This is a forwards- and backwards-incompatible
change: v1.75 workers cannot process account data replicated by v1.76 workers,
and vice versa.
Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data
replication will resume as normal.
and device replication will resume as normal.
# Upgrading to v1.74.0
Expand Down
8 changes: 6 additions & 2 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def on_rdata(
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@"):
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
Expand Down Expand Up @@ -422,7 +422,11 @@ async def process_replication_rows(
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
hosts = {
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False)

Expand Down
3 changes: 0 additions & 3 deletions synapse/replication/tcp/streams/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
Stream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
Expand All @@ -62,7 +61,6 @@
ToDeviceStream,
FederationStream,
AccountDataStream,
UserSignatureStream,
UnPartialStatedRoomStream,
UnPartialStatedEventStream,
)
Expand All @@ -82,7 +80,6 @@
"DeviceListsStream",
"ToDeviceStream",
"AccountDataStream",
"UserSignatureStream",
"UnPartialStatedRoomStream",
"UnPartialStatedEventStream",
]
73 changes: 51 additions & 22 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,66 @@ class DeviceListsStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
is_signature: bool

NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow

def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
store.get_all_device_list_changes_for_remotes,
current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
)

async def _update_function(
self,
instance_name: str,
from_token: Token,
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:
(
device_updates,
devices_to_token,
devices_limited,
) = await self.store.get_all_device_list_changes_for_remotes(
instance_name, from_token, current_token, target_row_count
)

(
signatures_updates,
signatures_to_token,
signatures_limited,
) = await self.store.get_all_user_signature_changes_for_remotes(
instance_name, from_token, current_token, target_row_count
)

upper_limit_token = current_token
if devices_limited:
upper_limit_token = min(upper_limit_token, devices_to_token)
if signatures_limited:
upper_limit_token = min(upper_limit_token, signatures_to_token)

device_updates = [
(stream_id, (entity, False))
for stream_id, (entity,) in device_updates
if stream_id <= upper_limit_token
]

signatures_updates = [
(stream_id, (entity, True))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]

updates = list(
heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
)

return updates, upper_limit_token, devices_limited or signatures_limited


class ToDeviceStream(Stream):
"""New to_device messages for a client"""
Expand Down Expand Up @@ -583,22 +631,3 @@ async def _update_function(
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
return updates, to_token, limited


class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserSignatureStreamRow:
user_id: str

NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow

def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
store.get_all_user_signature_changes_for_remotes,
)
13 changes: 7 additions & 6 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -163,24 +163,25 @@ def process_replication_rows(
) -> None:
if stream_name == DeviceListsStream.NAME:
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)

return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)

super().process_replication_position(stream_name, instance_name, token)

def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
continue

# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
Expand Down

0 comments on commit cb7249a

Please sign in to comment.