Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Do not allow MSC3440 threads to fork threads (#11161)
Browse files Browse the repository at this point in the history
Adds validation to the Client-Server API to ensure that
the potential thread head does not relate to another event
already. This results in not allowing a thread to "fork" into
other threads.

If the target event is unknown for some reason (maybe it isn't
visible to your homeserver), but is the target of other events
it is assumed that the thread can be created from it. Otherwise,
it is rejected as an unknown event.
  • Loading branch information
clokep authored Nov 18, 2021
1 parent e2dabec commit 4bd54b2
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 8 deletions.
1 change: 1 addition & 0 deletions changelog.d/11161.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
54 changes: 48 additions & 6 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,13 +1001,52 @@ async def create_new_client_event(
)

self.validator.validate_new(event, self.config)
await self._validate_event_relation(event)
logger.debug("Created event %s", event.event_id)

return event, context

async def _validate_event_relation(self, event: EventBase) -> None:
"""
Ensure the relation data on a new event is not bogus.
Args:
event: The event being created.
Raises:
SynapseError if the event is invalid.
"""

relation = event.content.get("m.relates_to")
if not relation:
return

relation_type = relation.get("rel_type")
if not relation_type:
return

# Ensure the parent is real.
relates_to = relation.get("event_id")
if not relates_to:
return

parent_event = await self.store.get_event(relates_to, allow_none=True)
if parent_event:
# And in the same room.
if parent_event.room_id != event.room_id:
raise SynapseError(400, "Relations must be in the same room")

else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relates_to):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")

# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
relation = event.content.get("m.relates_to", {})
if relation.get("rel_type") == RelationTypes.ANNOTATION:
relates_to = relation["event_id"]
if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]

already_exists = await self.store.has_user_annotated_event(
Expand All @@ -1016,9 +1055,12 @@ async def create_new_client_event(
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")

logger.debug("Created event %s", event.event_id)

return event, context
# Don't attempt to start a thread if the parent event is a relation.
elif relation_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relates_to):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)

@measure_func("handle_new_client_event")
async def handle_new_client_event(
Expand Down
67 changes: 65 additions & 2 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,69 @@ def _get_recent_references_for_event_txn(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)

async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event.
An event has a relation if it has a valid m.relates_to with a rel_type
and event_id in the content:
{
"content": {
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$other_event_id"
}
}
}
Args:
event_id: The event to check.
Returns:
True if the event includes a valid relation.
"""

result = await self.db_pool.simple_select_one_onecol(
table="event_relations",
keyvalues={"event_id": event_id},
retcol="event_id",
allow_none=True,
desc="event_includes_relation",
)
return result is not None

async def event_is_target_of_relation(self, parent_id: str) -> bool:
"""Check if the given event is the target of another event's relation.
An event is the target of an event relation if it has a valid
m.relates_to with a rel_type and event_id pointing to parent_id in the
content:
{
"content": {
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$parent_id"
}
}
}
Args:
parent_id: The event to check.
Returns:
True if the event is the target of another event's relation.
"""

result = await self.db_pool.simple_select_one_onecol(
table="event_relations",
keyvalues={"relates_to_id": parent_id},
retcol="event_id",
allow_none=True,
desc="event_is_target_of_relation",
)
return result is not None

@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
Expand Down Expand Up @@ -362,7 +425,7 @@ async def events_have_relations(
%s;
"""

def _get_if_event_has_relations(txn) -> List[str]:
def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
Expand All @@ -387,7 +450,7 @@ def _get_if_event_has_relations(txn) -> List[str]:
return [row[0] for row in txn]

return await self.db_pool.runInteraction(
"get_if_event_has_relations", _get_if_event_has_relations
"get_if_events_have_relations", _get_if_events_have_relations
)

async def has_user_annotated_event(
Expand Down
62 changes: 62 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,49 @@ def test_deny_membership(self):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)

def test_deny_invalid_event(self):
"""Test that we deny relations on non-existant events"""
channel = self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
self.assertEquals(400, channel.code, channel.json_body)

# Unless that event is referenced from another event!
self.get_success(
self.hs.get_datastore().db_pool.simple_insert(
table="event_relations",
values={
"event_id": "bar",
"relates_to_id": "foo",
"relation_type": RelationTypes.THREAD,
},
desc="test_deny_invalid_event",
)
)
channel = self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
self.assertEquals(200, channel.code, channel.json_body)

def test_deny_invalid_room(self):
"""Test that we deny relations on non-existant events"""
# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
res = self.helper.send(room2, body="Hi!", tok=self.user_token)
parent_id = res["event_id"]

# Attempt to send an annotation to that event.
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
)
self.assertEquals(400, channel.code, channel.json_body)

def test_deny_double_react(self):
"""Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
Expand All @@ -99,6 +142,25 @@ def test_deny_double_react(self):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)

def test_deny_forked_thread(self):
"""It is invalid to start a thread off a thread."""
channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
self.assertEquals(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
)
self.assertEquals(400, channel.code, channel.json_body)

def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
Expand Down

0 comments on commit 4bd54b2

Please sign in to comment.