From 7c89612a175d917768bc2603dcdf1a4460d3cce8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 27 Aug 2024 09:57:22 +0100 Subject: [PATCH] Always include the user's receipts --- synapse/handlers/sliding_sync/extensions.py | 56 ++++++++++---- synapse/storage/databases/main/receipts.py | 81 ++++++++++++++++++--- 2 files changed, 115 insertions(+), 22 deletions(-) diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index e4b6f4e77f7..9df24c79d91 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -12,12 +12,13 @@ # . # +import itertools import logging from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Sequence, Set from typing_extensions import assert_never -from synapse.api.constants import AccountDataTypes +from synapse.api.constants import AccountDataTypes, EduTypes from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.sliding_sync.types import ( HaveSentRoomFlag, @@ -25,6 +26,7 @@ PerConnectionState, ) from synapse.logging.opentracing import trace +from synapse.storage.databases.main.receipts import ReceiptInRoom from synapse.types import ( DeviceListUpdates, JsonMapping, @@ -541,21 +543,49 @@ async def get_receipts_extension_response( ) fetched_receipts.extend(previously_receipts) - # For rooms we haven't previously sent down, we could send all receipts - # from that room but we only want to include receipts for events - # in the timeline to avoid bloating and blowing up the sync response - # as the number of users in the room increases. (this behavior is part of the spec) - initial_rooms_and_event_ids = [ - (room_id, event.event_id) - for room_id in initial_rooms - if room_id in actual_room_response_map - for event in actual_room_response_map[room_id].timeline_events - ] - if initial_rooms_and_event_ids: + if initial_rooms: + # We also always send down receipts for the current user. + user_receipts = ( + await self.store.get_linearized_receipts_for_user_in_rooms( + user_id=sync_config.user.to_string(), + room_ids=initial_rooms, + to_key=to_token.receipt_key, + ) + ) + + # For rooms we haven't previously sent down, we could send all receipts + # from that room but we only want to include receipts for events + # in the timeline to avoid bloating and blowing up the sync response + # as the number of users in the room increases. (this behavior is part of the spec) + initial_rooms_and_event_ids = [ + (room_id, event.event_id) + for room_id in initial_rooms + if room_id in actual_room_response_map + for event in actual_room_response_map[room_id].timeline_events + ] initial_receipts = await self.store.get_linearized_receipts_for_events( room_and_event_ids=initial_rooms_and_event_ids, ) - fetched_receipts.extend(initial_receipts) + + # Combine the receipts for a room and add them to + # `fetched_receipts` + for room_id in initial_receipts.keys() | user_receipts.keys(): + receipt_content = ReceiptInRoom.merge_to_content( + list( + itertools.chain( + initial_receipts.get(room_id, []), + user_receipts.get(room_id, []), + ) + ) + ) + + fetched_receipts.append( + { + "room_id": room_id, + "type": EduTypes.RECEIPT, + "content": receipt_content, + } + ) fetched_receipts = ReceiptEventSource.filter_out_private_receipts( fetched_receipts, sync_config.user.to_string() diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c6c66a4879f..bf107435741 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -534,7 +534,7 @@ def f( async def get_linearized_receipts_for_events( self, room_and_event_ids: Collection[Tuple[str, str]], - ) -> Sequence[JsonMapping]: + ) -> Mapping[str, Sequence[ReceiptInRoom]]: """Get all receipts for the given set of events. Arguments: @@ -544,6 +544,8 @@ async def get_linearized_receipts_for_events( Returns: A list of receipts, one per room. """ + if not room_and_event_ids: + return {} def get_linearized_receipts_for_events_txn( txn: LoggingTransaction, @@ -590,14 +592,7 @@ def get_linearized_receipts_for_events_txn( ) ) - return [ - { - "type": EduTypes.RECEIPT, - "room_id": room_id, - "content": ReceiptInRoom.merge_to_content(receipts), - } - for room_id, receipts in room_to_receipts.items() - ] + return room_to_receipts @cached( num_args=2, @@ -670,6 +665,74 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: return results + async def get_linearized_receipts_for_user_in_rooms( + self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Fetch all receipts for the user in the given room. + + Returns: + A dict from room ID to receipts in the room. + """ + + def get_linearized_receipts_for_user_in_rooms_txn( + txn: LoggingTransaction, + batch_room_ids: StrCollection, + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} AND user_id = ? AND stream_id <= ? + """ + + args.append(user_id) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=None, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + ] + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_user_in_rooms_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + async def get_rooms_with_receipts_between( self, room_ids: StrCollection,