diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 700c0c1b4e69..c71386a214d1 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union @@ -215,16 +216,33 @@ async def action_for_event_by_user( ) = await self._get_power_levels_and_sender_level(event, context) # If the experimental feature is not enabled, skip fetching relations. - if self._relations_match_enabled: + relations = {} + if self._relations_match_enabled or True: + # If the event does not have a relation, then cannot have any mutual + # relations. relation = relation_from_event(event) if relation: - relations = await self.store.get_mutual_event_relations( - relation.parent_id - ) - else: - relations = set() - else: - relations = set() + # Pre-filter to figure out which relation types are interesting. + rel_types = set() + for rule in itertools.chain(*rules_by_user.values()): + # Skip disabled rules. + if not rule.get("enabled"): + continue + + for condition in rule["conditions"]: + if condition["kind"] != "org.matrix.msc3772.relation_match": + continue + + # rel_type is required. + rel_type = condition.get("rel_type") + if rel_type: + rel_types.add(rel_type) + + # If any valid rules were found, fetch the mutual relations. + if rel_types: + relations = await self.store.get_mutual_event_relations( + relation.parent_id, rel_types + ) evaluator = PushRuleEvaluatorForEvent( event, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 75706154d67d..2e8a017add34 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -120,7 +120,7 @@ def __init__( room_member_count: int, sender_power_level: int, power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Set[Tuple[str, str, str]], + relations: Dict[str, Set[Tuple[str, str]]], relations_match_enabled: bool, ): self._event = event @@ -293,12 +293,10 @@ def _relation_match(self, condition: dict, user_id: str) -> bool: type_pattern = condition.get("type") # If any other relations matches, return True. - for relation in self._relations: - if rel_type != relation[0]: + for sender, event_type in self._relations.get(rel_type, ()): + if sender_pattern and not _glob_matches(sender_pattern, sender): continue - if sender_pattern and not _glob_matches(sender_pattern, relation[1]): - continue - if type_pattern and not _glob_matches(type_pattern, relation[2]): + if type_pattern and not _glob_matches(type_pattern, event_type): continue # All values must have matched. return True diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e2098767afd6..17e35cf63e68 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1829,7 +1829,7 @@ def _handle_event_relations( (relation.parent_id,), ) txn.call_after( - self.store.get_mutual_event_relations.invalidate, + self.store.get_mutual_event_relations_for_rel_type.invalidate, (relation.parent_id,), ) @@ -2009,7 +2009,9 @@ def _handle_redact_relations( txn, self.store.get_thread_participated, (redacted_relates_to,) ) self.store._invalidate_cache_and_stream( - txn, self.store.get_mutual_event_relations, (redacted_relates_to,) + txn, + self.store.get_mutual_event_relations_for_rel_type, + (redacted_relates_to,), ) self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index cf194af75535..df6a0f484660 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from collections import defaultdict from typing import ( Collection, Dict, @@ -768,9 +769,18 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: ) @cached(iterable=True) + async def get_mutual_event_relations_for_rel_type( + self, event_id: str, relation_type: str + ) -> Set[Tuple[str, str]]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_mutual_event_relations_for_rel_type", + list_name="relation_types", + ) async def get_mutual_event_relations( - self, event_id: str - ) -> Set[Tuple[str, str, str]]: + self, event_id: str, relation_types: Collection[str] + ) -> Dict[str, Set[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -780,20 +790,29 @@ async def get_mutual_event_relations( event_id: The event ID which is targeted by relations. Returns: - A set of tuples of: - The relation type - The sender - The event type + A dictionary of relation type to: + A set of tuples of: + The sender + The event type """ - sql = """ + rel_type_sql, rel_type_args = make_in_list_sql_clause( + self.database_engine, "rel_type", relation_types + ) + + sql = f""" SELECT DISTINCT relation_type, sender, type FROM event_relations INNER JOIN events USING (event_id) - WHERE relates_to_id = ? + WHERE relates_to_id = ? AND {rel_type_sql} """ - def _get_event_relations(txn: LoggingTransaction) -> Set[Tuple[str, str, str]]: - txn.execute(sql, (event_id,)) - return set(cast(List[Tuple[str, str, str]], txn.fetchall())) + def _get_event_relations( + txn: LoggingTransaction, + ) -> Dict[str, Set[Tuple[str, str]]]: + txn.execute(sql, [event_id] + rel_type_args) + result = defaultdict(set) + for rel_type, sender, type in txn.fetchall(): + result[rel_type].add((sender, type)) + return result return await self.db_pool.runInteraction( "get_event_relations", _get_event_relations diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index bed82b9613ce..9b623d0033cd 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( self, content: JsonDict, - relations: Optional[Set[Tuple[str, str, str]]] = None, + relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, relations_match_enabled: bool = False, ) -> PushRuleEvaluatorForEvent: event = FrozenEvent( @@ -292,7 +292,7 @@ def test_relation_match(self) -> None: # Check if the experimental feature is disabled. evaluator = self._get_evaluator( - {}, {("m.annotation", "@user:test", "m.reaction")} + {}, {"m.annotation": {("@user:test", "m.reaction")}} ) condition = {"kind": "relation_match"} # Oddly, an unknown condition always matches. @@ -300,7 +300,7 @@ def test_relation_match(self) -> None: # A push rule evaluator with the experimental rule enabled. evaluator = self._get_evaluator( - {}, {("m.annotation", "@user:test", "m.reaction")}, True + {}, {"m.annotation": {("@user:test", "m.reaction")}}, True ) # Check just relation type.