Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Clarify that a method returns only unthreaded receipts. (#13937)
Browse files Browse the repository at this point in the history
By renaming it and updating the docstring.

Additionally, refactors a method which is used only by tests.
  • Loading branch information
clokep authored Sep 29, 2022
1 parent 99a7e7e commit 5680169
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 76 deletions.
1 change: 1 addition & 0 deletions changelog.d/13937.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
12 changes: 3 additions & 9 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,11 @@ def _get_unread_counts_by_receipt_txn(
user_id: str,
) -> NotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
user_id,
room_id,
receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)

if result:
Expand Down Expand Up @@ -574,10 +571,7 @@ def _get_receipts_by_room_txn(
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)

sql = f"""
Expand Down
36 changes: 5 additions & 31 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,48 +135,21 @@ def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()

async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]:
"""
Fetch the event ID for the latest receipt in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.
Returns:
The latest receipt, if one exists.
"""
result = await self.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.get_last_receipt_for_user_txn,
user_id,
room_id,
receipt_types,
)
if not result:
return None

event_id, _ = result
return event_id

def get_last_receipt_for_user_txn(
def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
user_id: str,
room_id: str,
receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
"""
Fetch the event ID and stream_ordering for the latest receipt in a room
with one of the given receipt types.
Fetch the event ID and stream_ordering for the latest unthreaded receipt
in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.
receipt_types: The receipt types to fetch.
Returns:
The event ID and stream ordering of the latest receipt, if one exists.
Expand All @@ -193,6 +166,7 @@ def get_last_receipt_for_user_txn(
WHERE {clause}
AND user_id = ?
AND room_id = ?
AND thread_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""
Expand Down
74 changes: 38 additions & 36 deletions tests/storage/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Optional

from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
Expand Down Expand Up @@ -84,6 +85,33 @@ def prepare(self, reactor, clock, homeserver) -> None:
)
)

def get_last_unthreaded_receipt(
self, receipt_types: Collection[str], room_id: Optional[str] = None
) -> Optional[str]:
"""
Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
Args:
receipt_types: The receipt types to fetch.
Returns:
The latest receipt, if one exists.
"""
result = self.get_success(
self.store.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.store.get_last_unthreaded_receipt_for_user_txn,
OUR_USER_ID,
room_id or self.room_id1,
receipt_types,
)
)
if not result:
return None

event_id, _ = result
return event_id

def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
Expand All @@ -107,16 +135,10 @@ def test_return_empty_with_no_data(self) -> None:
)
self.assertEqual(res, {})

res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)

self.assertEqual(res, None)

def test_get_receipts_for_user(self) -> None:
Expand Down Expand Up @@ -228,29 +250,17 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
)

# Test we get the latest event when we want both private and public receipts
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)

# Test we get the older event when we want only public receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)

# Test we get the latest event when we want only the private receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)

# Test receipt updating
Expand All @@ -259,11 +269,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)

# Send some events into the second room
Expand All @@ -282,11 +288,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
{},
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
)
self.assertEqual(res, event2_1_id)

0 comments on commit 5680169

Please sign in to comment.