Skip to content

Commit

Permalink
Always include the user's receipts
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Aug 27, 2024
1 parent 3e36aff commit 2c86aba
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 11 deletions.
35 changes: 33 additions & 2 deletions synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#

import itertools
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Sequence, Set

from typing_extensions import assert_never

from synapse.api.constants import AccountDataTypes
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.sliding_sync.types import (
HaveSentRoomFlag,
MutablePerConnectionState,
PerConnectionState,
)
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.types import (
DeviceListUpdates,
JsonMapping,
Expand Down Expand Up @@ -555,7 +557,36 @@ async def get_receipts_extension_response(
initial_receipts = await self.store.get_linearized_receipts_for_events(
room_and_event_ids=initial_rooms_and_event_ids,
)
fetched_receipts.extend(initial_receipts)

# We also always send down receipts for the current user. This
# may add duplicate receipts if they were returned in the above
# query, but they will get deduplicated below.
user_receipts = (
await self.store.get_linearized_receipts_for_user_in_rooms(
user_id=sync_config.user.to_string(),
room_ids=initial_rooms,
to_key=to_token.receipt_key,
)
)

# Merge the two sets of receipts together
for room_id in initial_receipts.keys() | user_receipts.keys():
receipt_content = ReceiptInRoom.merge_to_content(
list(
itertools.chain(
initial_receipts.get(room_id, []),
user_receipts.get(room_id, []),
)
)
)

fetched_receipts.append(
{
"room_id": room_id,
"type": EduTypes.RECEIPT,
"content": receipt_content,
}
)

fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
fetched_receipts, sync_config.user.to_string()
Expand Down
79 changes: 70 additions & 9 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def f(
async def get_linearized_receipts_for_events(
self,
room_and_event_ids: Collection[Tuple[str, str]],
) -> Sequence[JsonMapping]:
) -> Mapping[str, Sequence[ReceiptInRoom]]:
"""Get all receipts for the given set of events.
Arguments:
Expand Down Expand Up @@ -590,14 +590,7 @@ def get_linearized_receipts_for_events_txn(
)
)

return [
{
"type": EduTypes.RECEIPT,
"room_id": room_id,
"content": ReceiptInRoom.merge_to_content(receipts),
}
for room_id, receipts in room_to_receipts.items()
]
return room_to_receipts

@cached(
num_args=2,
Expand Down Expand Up @@ -670,6 +663,74 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:

return results

async def get_linearized_receipts_for_user_in_rooms(
self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken
) -> Mapping[str, Sequence[ReceiptInRoom]]:
"""Fetch all receipts for the user in the given room.
Returns:
A dict from room ID to receipts in the room.
"""

def get_linearized_receipts_for_user_in_rooms_txn(
txn: LoggingTransaction,
batch_room_ids: StrCollection,
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", batch_room_ids
)

sql = f"""
SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE {clause} AND user_id = ? AND stream_id <= ?
"""

args.append(user_id)
args.append(to_key.get_max_stream_pos())

txn.execute(sql, args)

return [
(room_id, receipt_type, user_id, event_id, thread_id, data)
for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn
if MultiWriterStreamToken.is_stream_position_in_range(
low=None,
high=to_key,
instance_name=instance_name,
pos=stream_id,
)
]

# room_id -> receipts
room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
for batch in batch_iter(room_ids, 1000):
batch_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_events",
get_linearized_receipts_for_user_in_rooms_txn,
batch,
)

for (
room_id,
receipt_type,
user_id,
event_id,
thread_id,
data,
) in batch_results:
room_to_receipts.setdefault(room_id, []).append(
ReceiptInRoom(
receipt_type=receipt_type,
user_id=user_id,
event_id=event_id,
thread_id=thread_id,
data=db_to_json(data),
)
)

return room_to_receipts

async def get_rooms_with_receipts_between(
self,
room_ids: StrCollection,
Expand Down

0 comments on commit 2c86aba

Please sign in to comment.