From c418d367052afaa9f471e1e6306f55cbc0c6572b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2023 11:55:21 +0000 Subject: [PATCH] Merge the two accound data streams --- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/handler.py | 3 +- synapse/replication/tcp/streams/__init__.py | 3 -- synapse/replication/tcp/streams/_base.py | 49 +++++++++-------- .../storage/databases/main/account_data.py | 6 +-- synapse/storage/databases/main/tags.py | 54 +++++-------------- 6 files changed, 41 insertions(+), 77 deletions(-) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b5e40da5337e..7263bb2796da 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,7 +33,6 @@ PushersStream, PushRulesStream, ReceiptsStream, - TagAccountDataStream, ToDeviceStream, TypingStream, UnPartialStatedEventStream, @@ -168,7 +167,7 @@ async def on_rdata( self.notifier.on_new_event( StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) - elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + elif stream_name in AccountDataStream.NAME: self.notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0f166d16aa5f..d03a53d76429 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -58,7 +58,6 @@ PresenceStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -145,7 +144,7 @@ def __init__(self, hs: "HomeServer"): continue - if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + if isinstance(stream, AccountDataStream): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. if hs.get_instance_name() in hs.config.worker.writers.account_data: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 110f10aab9a5..a7eadfa3c9f9 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -35,7 +35,6 @@ PushRulesStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, UserSignatureStream, @@ -62,7 +61,6 @@ DeviceListsStream, ToDeviceStream, FederationStream, - TagAccountDataStream, AccountDataStream, UserSignatureStream, UnPartialStatedRoomStream, @@ -83,7 +81,6 @@ "CachesStream", "DeviceListsStream", "ToDeviceStream", - "TagAccountDataStream", "AccountDataStream", "UserSignatureStream", "UnPartialStatedRoomStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e01155ad597b..fbf78da9c254 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,8 +28,8 @@ import attr +from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -495,27 +495,6 @@ def __init__(self, hs: "HomeServer"): ) -class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TagAccountDataStreamRow: - user_id: str - room_id: str - data: JsonDict - - NAME = "tag_account_data" - ROW_TYPE = TagAccountDataStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_max_account_data_stream_id), - store.get_all_updated_tags, - ) - - class AccountDataStream(Stream): """Global or per room account data was changed""" @@ -560,6 +539,19 @@ async def _update_function( to_token = room_results[-1][0] limited = True + tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags( + instance_name, + from_token, + to_token, + limit, + ) + + # again, if the tag results hit the limit, limit the global results to + # the same stream token. + if tags_limited: + to_token = tag_to_token + limited = True + # convert the global results to the right format, and limit them to the to_token # at the same time global_rows = ( @@ -568,11 +560,16 @@ async def _update_function( if stream_id <= to_token ) - # we know that the room_results are already limited to `to_token` so no need - # for a check on `stream_id` here. room_rows = ( (stream_id, (user_id, room_id, account_data_type)) for stream_id, user_id, room_id, account_data_type in room_results + if stream_id <= to_token + ) + + tag_rows = ( + (stream_id, (user_id, room_id, AccountDataTypes.TAG)) + for stream_id, user_id, room_id in tags + if stream_id <= to_token ) # We need to return a sorted list, so merge them together. @@ -582,7 +579,9 @@ async def _update_function( # leading to a comparison between the data tuples. The comparison could # fail due to attempting to compare the `room_id` which results in a # `TypeError` from comparing a `str` vs `None`. - updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + updates = list( + heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) + ) return updates, to_token, limited diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 86032897f53a..881d7089dbb2 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,7 @@ ) from synapse.api.constants import AccountDataTypes -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( DatabasePool, @@ -454,9 +454,7 @@ def process_replication_rows( def process_replication_position( self, stream_name: str, instance_name: str, token: int ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e23c927e02f2..d5500cdd470c 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,7 +17,8 @@ import logging from typing import Any, Dict, Iterable, List, Tuple, cast -from synapse.replication.tcp.streams import TagAccountDataStream +from synapse.api.constants import AccountDataTypes +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -54,7 +55,7 @@ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict] async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, str, str]], int, bool]: """Get updates for tags replication stream. Args: @@ -73,7 +74,7 @@ async def get_all_updated_tags( The token returned can be used in a subsequent call to this function to get further updatees. - The updates are a list of 2-tuples of stream ID and the row data + The updates are a list of tuples of stream ID, user ID and room ID """ if last_id == current_id: @@ -96,38 +97,13 @@ def get_all_updated_tags_txn( "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content( - txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]] - ) -> List[Tuple[int, Tuple[str, str, str]]]: - sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json_encoder.encode(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, (user_id, room_id, tag_json))) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = await self.db_pool.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - limited = False upto_token = current_id - if len(results) >= limit: - upto_token = results[-1][0] + if len(tag_ids) >= limit: + upto_token = tag_ids[-1][0] limited = True - return results, upto_token, limited + return tag_ids, upto_token, limited async def get_updated_tags( self, user_id: str, stream_id: int @@ -299,20 +275,16 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) + if row.data_type == AccountDataTypes.TAG: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed( + row.user_id, token + ) super().process_replication_rows(stream_name, instance_name, token, rows) - def process_replication_position( - self, stream_name: str, instance_name: str, token: int - ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - super().process_replication_position(stream_name, instance_name, token) - class TagsStore(TagsWorkerStore): pass