Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement MSC3816, consider the root event for thread participation #12766

Merged
merged 7 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12766.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
58 changes: 37 additions & 21 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple

import attr

Expand Down Expand Up @@ -256,13 +247,19 @@ async def get_annotations_for_event(

return filtered_results

async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
relations_by_id: Dict[str, str],
user_id: str,
ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.

Args:
event_ids: Events to get aggregations for threads.
events_by_id: A map of event_id to events to get aggregations for threads.
relations_by_id: A map of event_id to the relation type, if one exists
for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.

Expand All @@ -273,16 +270,34 @@ async def get_threads_for_events(
"""
user = UserID.from_string(user_id)

# It is not valid to start a thread on an event which itself relates to another event.
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]

# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)

# Only fetch participated for a limited selection based on what had
# summaries.
# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id

# Pre-seed thread participation with whether the requester sent the event.
participated = {
event_id: events_by_id[event_id].sender == user_id
for event_id in thread_event_ids
}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[
event_id
for event_id in thread_event_ids
if not participated[event_id]
],
user_id,
)
)

# Then subtract off the results for any ignored users.
Expand Down Expand Up @@ -343,7 +358,8 @@ async def get_threads_for_events(
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
current_user_participated=events_by_id[event_id].sender == user_id
or participated[event_id],
)

return results
Expand Down Expand Up @@ -401,9 +417,9 @@ async def get_bundled_aggregations(
# events to be fetched. Thus, we check those first!

# Fetch thread summaries (but only for the directly requested events).
threads = await self.get_threads_for_events(
# It is not valid to start a thread on an event which itself relates to another event.
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
threads = await self._get_threads_for_events(
events_by_id,
relations_by_id,
user_id,
ignored_users,
)
Expand Down
85 changes: 59 additions & 26 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def _test_bundled_aggregations(
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
Expand All @@ -907,7 +908,9 @@ def _test_bundled_aggregations(
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
access_token: The access token to user, defaults to self.user_token.
"""
access_token = access_token or self.user_token

def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
Expand All @@ -921,7 +924,7 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
Expand All @@ -932,7 +935,7 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
Expand All @@ -941,15 +944,15 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])

# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
"GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
Expand All @@ -962,7 +965,7 @@ def assert_bundle(event_json: JsonDict) -> None:
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
Expand Down Expand Up @@ -1037,30 +1040,60 @@ def test_thread(self) -> None:
"""
Test that threads get correctly bundled.
"""
self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
# The root message is from "user", send replies as "user2".
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
channel = self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
thread_2 = channel.json_body["event_id"]

def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
# This needs two assertion functions which are identical except for whether
# the current_user_participated flag is True, create a factory for the
# two versions.
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertEqual(
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
},
"event_id": thread_2,
"sender": self.user2_id,
"type": "m.room.test",
},
"event_id": thread_2,
"sender": self.user_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
)
bundled_aggregations.get("latest_event"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
return assert_thread

# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
)

# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
)

def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
Expand Down Expand Up @@ -1106,7 +1139,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
bundled_aggregations["latest_event"].get("unsigned"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)

def test_nested_thread(self) -> None:
"""
Expand Down