diff --git a/changelog.d/15131.misc b/changelog.d/15131.misc new file mode 100644 index 000000000000..441e77ba6532 --- /dev/null +++ b/changelog.d/15131.misc @@ -0,0 +1 @@ +Add a new third party callback `check_event_allowed_v2` that is compatible with new batch persisting mechanisms. \ No newline at end of file diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 4a27d976fb13..f301bfdcba82 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -10,6 +10,75 @@ The available third party rules callbacks are: ### `check_event_allowed` +_First introduced in Synapse v1.7x.x + +```python +async def check_event_allowed_v2( + event: "synapse.events.EventBase", + state_events: "synapse.types.StateMap", +) -> Tuple[bool, Optional[dict], Optional[dict]] +``` + +** +This callback is very experimental and can and will break without notice. Module developers +are encouraged to implement `check_event_for_spam` from the spam checker category instead. +** + +Returns: + +- A tuple consisting of: + + - a boolean representing whether or not the event is allowed + - an optional dict to form the basis of a replacement event for the event + - an optional dict to form the basis of an additional event to be sent into the + room + +Called when processing any incoming event, with the event and a `StateMap` +representing the current state of the room the event is being sent into. A `StateMap` is +a dictionary that maps tuples containing an event type and a state key to the +corresponding state event. For example retrieving the room's `m.room.create` event from +the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`. +The module must return a boolean indicating whether the event can be allowed. + +Note that this callback function processes incoming events coming via federation +traffic (on top of client traffic). This means denying an event might cause the local +copy of the room's history to diverge from that of remote servers. This may cause +federation issues in the room. It is strongly recommended to only deny events using this +callback function if the sender is a local user, or in a private federation in which all +servers are using the same module, with the same configuration. + +If the boolean returned by the module is `True`, it may tell Synapse to replace the +event with new data by returning the new event's data as a dictionary. In order to do +that, it is recommended the module calls `event.get_dict()` to get the current event as a +dictionary, and modify the returned dictionary accordingly. + +Module writers may also wish to use this check to send a second event into the room along +with the event being checked, if this is the case the module writer must provide a dict that +will form the basis of the event that is to be added to the room and it must be returned by `check_event_allowed_v2`. +This dict will then be turned into an event at the appropriate time and it will be persisted after the event +that triggered it, and if the event that triggered it is in a batch of events for persisting, it will be added to the +end of that batch. Note that the event MAY NOT be a membership event. + +If `check_event_allowed_v2` raises an exception, the module is assumed to have failed. +The event will not be accepted but is not treated as explicitly rejected, either. +An HTTP request causing the module check will likely result in a 500 Internal +Server Error. + +When the boolean returned by the module is `False`, the event is rejected. +(Module developers should not use exceptions for rejection.) + +Note that replacing the event or adding an event only works for events sent by local users, not for events +received over federation. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. This callback cannot be used in conjunction with `check_event_allowed`, +only one of these callbacks may be operational at a time - if both `check_event_allowed` and `check_event_allowed_v2` +active only `check_event_allowed` will be executed. + +### `check_event_allowed` + _First introduced in Synapse v1.39.0_ ```python diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 61d4530be784..79e2c994d626 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -32,6 +32,10 @@ CHECK_EVENT_ALLOWED_CALLBACK = Callable[ [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] ] +CHECK_EVENT_ALLOWED_V2_CALLBACK = Callable[ + [EventBase, StateMap[EventBase]], + Awaitable[Tuple[bool, Optional[dict], Optional[dict]]], +] ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ [str, str, StateMap[EventBase]], Awaitable[bool] @@ -155,6 +159,9 @@ def __init__(self, hs: "HomeServer"): self._storage_controllers = hs.get_storage_controllers() self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] + self._check_event_allowed_v2_callbacks: List[ + CHECK_EVENT_ALLOWED_V2_CALLBACK + ] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] self._check_threepid_can_be_invited_callbacks: List[ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK @@ -184,6 +191,7 @@ def __init__(self, hs: "HomeServer"): def register_third_party_rules_callbacks( self, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + check_event_allowed_v2: Optional[CHECK_EVENT_ALLOWED_V2_CALLBACK] = None, on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, check_threepid_can_be_invited: Optional[ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK @@ -210,6 +218,9 @@ def register_third_party_rules_callbacks( if check_event_allowed is not None: self._check_event_allowed_callbacks.append(check_event_allowed) + if check_event_allowed_v2 is not None: + self._check_event_allowed_v2_callbacks.append(check_event_allowed_v2) + if on_create_room is not None: self._on_create_room_callbacks.append(on_create_room) @@ -256,7 +267,7 @@ async def check_event_allowed( self, event: EventBase, context: UnpersistedEventContextBase, - ) -> Tuple[bool, Optional[dict]]: + ) -> Tuple[bool, Optional[dict], Optional[dict]]: """Check if a provided event should be allowed in the given context. The module can return: @@ -264,7 +275,8 @@ async def check_event_allowed( * False: the event is not allowed, and should be rejected with M_FORBIDDEN. If the event is allowed, the module can also return a dictionary to use as a - replacement for the event. + replacement for the event, and/or return a dictionary to use as the basis for + another event to be sent into the room. Args: event: The event to be checked. @@ -274,8 +286,11 @@ async def check_event_allowed( The result from the ThirdPartyRules module, as above. """ # Bail out early without hitting the store if we don't have any callbacks to run. - if len(self._check_event_allowed_callbacks) == 0: - return True, None + if ( + len(self._check_event_allowed_callbacks) == 0 + and len(self._check_event_allowed_v2_callbacks) == 0 + ): + return True, None, None prev_state_ids = await context.get_prev_state_ids() @@ -288,35 +303,63 @@ async def check_event_allowed( # the hashes and signatures. event.freeze() - for callback in self._check_event_allowed_callbacks: - try: - res, replacement_data = await delay_cancellation( - callback(event, state_events) - ) - except CancelledError: - raise - except SynapseError as e: - # FIXME: Being able to throw SynapseErrors is relied upon by - # some modules. PR #10386 accidentally broke this ability. - # That said, we aren't keen on exposing this implementation detail - # to modules and we should one day have a proper way to do what - # is wanted. - # This module callback needs a rework so that hacks such as - # this one are not necessary. - raise e - except Exception: - raise ModuleFailedException( - "Failed to run `check_event_allowed` module API callback" - ) + if len(self._check_event_allowed_callbacks) != 0: + for callback in self._check_event_allowed_callbacks: + try: + res, replacement_data = await delay_cancellation( + callback(event, state_events) + ) + except CancelledError: + raise + except SynapseError as e: + # FIXME: Being able to throw SynapseErrors is relied upon by + # some modules. PR #10386 accidentally broke this ability. + # That said, we aren't keen on exposing this implementation detail + # to modules and we should one day have a proper way to do what + # is wanted. + # This module callback needs a rework so that hacks such as + # this one are not necessary. + raise e + except Exception: + raise ModuleFailedException( + "Failed to run `check_event_allowed` module API callback" + ) - # Return if the event shouldn't be allowed or if the module came up with a - # replacement dict for the event. - if res is False: - return res, None - elif isinstance(replacement_data, dict): - return True, replacement_data + # Return if the event shouldn't be allowed or if the module came up with a + # replacement dict for the event. + if res is False: + return res, None, None + elif isinstance(replacement_data, dict): + return True, replacement_data, None + else: + for v2_callback in self._check_event_allowed_v2_callbacks: + try: + res, replacement_data, new_event = await delay_cancellation( + v2_callback(event, state_events) + ) + except CancelledError: + raise + except SynapseError as e: + # FIXME: Being able to throw SynapseErrors is relied upon by + # some modules. PR #10386 accidentally broke this ability. + # That said, we aren't keen on exposing this implementation detail + # to modules and we should one day have a proper way to do what + # is wanted. + # This module callback needs a rework so that hacks such as + # this one are not necessary. + raise e + except Exception: + raise ModuleFailedException( + "Failed to run `check_event_allowed_v2` module API callback" + ) - return True, None + # Return if the event shouldn't be allowed, if the module came up with a + # replacement dict for the event, or if the module wants to send a new event + if res is False: + return res, None, None + else: + return True, replacement_data, new_event + return True, None, None async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 80156ef343aa..dedcc620acab 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1007,6 +1007,7 @@ async def on_make_join_request( ( event, unpersisted_context, + _, ) = await self.event_creation_handler.create_new_client_event( builder=builder, prev_event_ids=prev_event_ids, @@ -1198,7 +1199,7 @@ async def on_make_leave_request( }, ) - event, _ = await self.event_creation_handler.create_new_client_event( + event, _, _ = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1251,9 +1252,10 @@ async def on_make_knock_request( ( event, unpersisted_context, + _, ) = await self.event_creation_handler.create_new_client_event(builder=builder) - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event_allowed, _, _ = await self.third_party_event_rules.check_event_allowed( event, unpersisted_context ) if not event_allowed: @@ -1446,6 +1448,7 @@ async def exchange_third_party_invite( ( event, unpersisted_context, + _, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1528,6 +1531,7 @@ async def on_exchange_third_party_invite_request( ( event, unpersisted_context, + _, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1610,6 +1614,7 @@ async def add_display_name_to_third_party_invite( ( event, unpersisted_context, + _, ) = await self.event_creation_handler.create_new_client_event(builder=builder) EventValidator().validate_new(event, self.config) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index b7136f8d1ccb..0b974afe4d80 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -404,9 +404,11 @@ async def on_send_membership_event( # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed, _ = await self._third_party_event_rules.check_event_allowed( - event, context - ) + ( + event_allowed, + _, + _, + ) = await self._third_party_event_rules.check_event_allowed(event, context) if not event_allowed: logger.info("Sending of knock %s forbidden by third-party rules", event) raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da129ec16a4a..d283a938c002 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +from builtins import dict from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple @@ -577,7 +578,7 @@ async def create_event( state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, UnpersistedEventContextBase]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -649,7 +650,9 @@ async def create_event( exceeded Returns: - Tuple of created event, Context + Tuple of created event, Context, and an optional event dict to form the basis + of a new event if third_party_rules would like to send an additional event as a + consequence of this event. """ await self.auth_blocking.check_auth_blocking(requester=requester) @@ -711,7 +714,7 @@ async def create_event( builder.internal_metadata.historical = historical - event, unpersisted_context = await self.create_new_client_event( + event, unpersisted_context, new_event = await self.create_new_client_event( builder=builder, requester=requester, allow_no_prev_events=allow_no_prev_events, @@ -765,7 +768,7 @@ async def create_event( ) self.validator.validate_new(event, self.config) - return event, unpersisted_context + return event, unpersisted_context, new_event async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1008,11 @@ async def create_and_send_nonmember_event( max_retries = 5 for i in range(max_retries): try: - event, unpersisted_context = await self.create_event( + ( + event, + unpersisted_context, + third_party_event_dict, + ) = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1054,9 +1061,24 @@ async def create_and_send_nonmember_event( Codes.FORBIDDEN, ) + events_and_context = [(event, context)] + if third_party_event_dict: + ( + third_party_event, + unpersisted_third_party_context, + _, + ) = await self.create_event( + requester, + third_party_event_dict, + ) + third_party_context = await unpersisted_third_party_context.persist( + third_party_event + ) + events_and_context.append((third_party_event, third_party_context)) + ev = await self.handle_new_client_event( requester=requester, - events_and_context=[(event, context)], + events_and_context=events_and_context, ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) @@ -1086,7 +1108,7 @@ async def create_new_client_event( state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, UnpersistedEventContextBase]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for the event using the parameters state_map and current_state_group, thus these parameters @@ -1135,7 +1157,9 @@ async def create_new_client_event( batch persisting Returns: - Tuple of created event, UnpersistedEventContext + Tuple of created event, UnpersistedEventContext, and an optional event dict + to form the basis of a new event if third_party_rules would like to send an + additional event as a consequence of this event. """ # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender @@ -1269,9 +1293,11 @@ async def create_new_client_event( if requester: context.app_service = requester.app_service - res, new_content = await self.third_party_event_rules.check_event_allowed( - event, context - ) + ( + res, + new_content, + new_event, + ) = await self.third_party_event_rules.check_event_allowed(event, context) if res is False: logger.info( "Event %s forbidden by third-party rules", @@ -1291,7 +1317,7 @@ async def create_new_client_event( await self._validate_event_relation(event) logger.debug("Created event %s", event.event_id) - return event, context + return event, context, new_event async def _validate_event_relation(self, event: EventBase) -> None: """ @@ -2046,7 +2072,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool: max_retries = 5 for i in range(max_retries): try: - event, unpersisted_context = await self.create_event( + event, unpersisted_context, _ = await self.create_event( requester, { "type": EventTypes.Dummy, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index be120cb12f36..e5bffa9e2558 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -213,6 +213,7 @@ async def upgrade_room( ( tombstone_event, tombstone_unpersisted_context, + _, ) = await self.event_creation_handler.create_event( requester, { @@ -1066,7 +1067,11 @@ async def create_event( content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: + ) -> Tuple[ + EventBase, + synapse.events.snapshot.UnpersistedEventContextBase, + Optional[dict], + ]: """ Creates an event and associated event context. Args: @@ -1088,6 +1093,7 @@ async def create_event( ( new_event, new_unpersisted_context, + third_party_event, ) = await self.event_creation_handler.create_event( creator, event_dict, @@ -1103,7 +1109,7 @@ async def create_event( prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_unpersisted_context + return new_event, new_unpersisted_context, third_party_event visibility = room_config.get("visibility", "private") preset_config = room_config.get( @@ -1121,7 +1127,7 @@ async def create_event( ) creation_content.update({"creator": creator_id}) - creation_event, unpersisted_creation_context = await create_event( + creation_event, unpersisted_creation_context, _ = await create_event( EventTypes.Create, creation_content, False ) creation_context = await unpersisted_creation_context.persist(creation_event) @@ -1161,14 +1167,17 @@ async def create_event( current_state_group = event_to_state[member_event_id] events_to_send = [] + third_party_events_to_append = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - power_event, power_context = await create_event( + power_event, power_context, power_tp_event = await create_event( EventTypes.PowerLevels, pl_content, True ) events_to_send.append((power_event, power_context)) + if power_tp_event: + third_party_events_to_append.append(power_tp_event) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1211,76 +1220,114 @@ async def create_event( # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) - pl_event, pl_context = await create_event( + pl_event, pl_context, pl_tp_event = await create_event( EventTypes.PowerLevels, power_level_content, True, ) events_to_send.append((pl_event, pl_context)) + if pl_tp_event: + third_party_events_to_append.append(pl_tp_event) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - room_alias_event, room_alias_context = await create_event( + room_alias_event, room_alias_context, ra_tp_event = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) events_to_send.append((room_alias_event, room_alias_context)) + if ra_tp_event: + third_party_events_to_append.append(ra_tp_event) if (EventTypes.JoinRules, "") not in initial_state: - join_rules_event, join_rules_context = await create_event( + join_rules_event, join_rules_context, jr_tp_event = await create_event( EventTypes.JoinRules, {"join_rule": config["join_rules"]}, True, ) events_to_send.append((join_rules_event, join_rules_context)) + if jr_tp_event: + third_party_events_to_append.append(jr_tp_event) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - visibility_event, visibility_context = await create_event( + visibility_event, visibility_context, vis_tp_event = await create_event( EventTypes.RoomHistoryVisibility, {"history_visibility": config["history_visibility"]}, True, ) events_to_send.append((visibility_event, visibility_context)) + if vis_tp_event: + third_party_events_to_append.append(vis_tp_event) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - guest_access_event, guest_access_context = await create_event( + ( + guest_access_event, + guest_access_context, + ga_tp_event, + ) = await create_event( EventTypes.GuestAccess, {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) events_to_send.append((guest_access_event, guest_access_context)) + if ga_tp_event: + third_party_events_to_append.append(ga_tp_event) for (etype, state_key), content in initial_state.items(): - event, context = await create_event( + event, context, tp_event = await create_event( etype, content, True, state_key=state_key ) events_to_send.append((event, context)) + if tp_event: + third_party_events_to_append.append(tp_event) if config["encrypted"]: - encryption_event, encryption_context = await create_event( + encryption_event, encryption_context, encrypt_tp_event = await create_event( EventTypes.RoomEncryption, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}, True, state_key="", ) events_to_send.append((encryption_event, encryption_context)) + if encrypt_tp_event: + third_party_events_to_append.append(encrypt_tp_event) if "name" in room_config: name = room_config["name"] - name_event, name_context = await create_event( + name_event, name_context, name_tp_event = await create_event( EventTypes.Name, {"name": name}, True, ) events_to_send.append((name_event, name_context)) + if name_tp_event: + third_party_events_to_append.append(name_tp_event) if "topic" in room_config: topic = room_config["topic"] - topic_event, topic_context = await create_event( + topic_event, topic_context, topic_tp_event = await create_event( EventTypes.Topic, {"topic": topic}, True, ) events_to_send.append((topic_event, topic_context)) + if topic_tp_event: + third_party_events_to_append.append(topic_tp_event) + + for event_dict in third_party_events_to_append: + ( + event, + unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( + creator, + event_dict, + prev_event_ids=prev_event, + state_map=state_map, + for_batch=True, + current_state_group=current_state_group, + ) + context = await unpersisted_context.persist(event) + events_to_send.append((event, context)) datastore = self.hs.get_datastores().state events_and_context = ( diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index bf9df60218a1..8b5e02af17f7 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,11 @@ async def persist_historical_events( # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, unpersisted_context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 509c5578895b..9d3096df8d7c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -418,6 +418,7 @@ async def _local_membership_update( ( event, unpersisted_context, + third_party_event, ) = await self.event_creation_handler.create_event( requester, { @@ -472,6 +473,20 @@ async def _local_membership_update( ratelimit=ratelimit, ) ) + if third_party_event: + ( + tp_event, + tp_unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( + requester, + third_party_event, + prev_event_ids=[result_event.event_id], + ) + tp_context = await tp_unpersisted_context.persist(tp_event) + await self.event_creation_handler.handle_new_client_event( + requester, events_and_context=[(tp_event, tp_context)] + ) if event.membership == Membership.LEAVE: if prev_member_event_id: @@ -1951,6 +1966,7 @@ async def _generate_local_out_of_band_leave( ( event, unpersisted_context, + third_party_event_dict, ) = await self.event_creation_handler.create_event( requester, event_dict, @@ -1962,10 +1978,24 @@ async def _generate_local_out_of_band_leave( context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True + events_and_context = [(event, context)] + if third_party_event_dict: + ( + third_party_event, + third_party_unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( + requester, third_party_event_dict + ) + third_party_context = await third_party_unpersisted_context.persist( + event + ) + events_and_context.append((third_party_event, third_party_context)) + result_event = ( await self.event_creation_handler.handle_new_client_event( requester, - events_and_context=[(event, context)], + events_and_context=events_and_context, extra_users=[UserID.from_string(target_user)], ) ) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9691d66b48a0..2e838e6572c0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Tuple +from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -81,7 +81,7 @@ def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: def _create_duplicate_event( self, txn_id: str - ) -> Tuple[EventBase, UnpersistedEventContextBase]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -109,7 +109,7 @@ def test_duplicated_txn_id(self) -> None: txn_id = "something_suitably_random" - event1, unpersisted_context = self._create_duplicate_event(txn_id) + event1, unpersisted_context, _ = self._create_duplicate_event(txn_id) context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( @@ -122,7 +122,7 @@ def test_duplicated_txn_id(self) -> None: self.assertEqual(event1.event_id, ret_event1.event_id) - event2, unpersisted_context = self._create_duplicate_event(txn_id) + event2, unpersisted_context, _ = self._create_duplicate_event(txn_id) context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, @@ -144,7 +144,7 @@ def test_duplicated_txn_id(self) -> None: # Let's test that calling `persist_event` directly also does the right # thing. - event3, unpersisted_context = self._create_duplicate_event(txn_id) + event3, unpersisted_context, _ = self._create_duplicate_event(txn_id) context = self.get_success(unpersisted_context.persist(event3)) self.assertNotEqual(event1.event_id, event3.event_id) @@ -160,8 +160,9 @@ def test_duplicated_txn_id(self) -> None: # Let's test that calling `persist_events` directly also does the right # thing. - event4, unpersisted_context = self._create_duplicate_event(txn_id) + event4, unpersisted_context, _ = self._create_duplicate_event(txn_id) context = self.get_success(unpersisted_context.persist(event4)) + self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -181,9 +182,9 @@ def test_duplicated_txn_id_one_call(self) -> None: txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id) context1 = self.get_success(unpersisted_context1.persist(event1)) - event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id) context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with @@ -209,7 +210,7 @@ def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events( memberEvent, _ = self._create_and_persist_member_event() # Try to create the event with empty prev_events bit with some auth_events - event, _ = self.get_success( + event, _, _ = self.get_success( self.handler.create_event( self.requester, { diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index aff1ec475884..161ff0a6c144 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -507,7 +507,8 @@ def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: # Lower the permissions of the inviter. event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, unpersisted_context = self.get_success( + + event, unpersisted_context, _ = self.get_success( event_creation_handler.create_event( requester, { diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index da4d24082648..ce095eb68a05 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -965,7 +965,7 @@ def _add_user_to_room( }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 46df0102f77a..978c2d5a3462 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -130,7 +130,7 @@ def test_action_for_event_by_user_handles_noninteger_room_power_levels( # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -171,7 +171,7 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -202,7 +202,7 @@ def _create_and_process( ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -378,7 +378,7 @@ def test_suppress_edits(self) -> None: bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4b8f889a71ba..c278f6bbade3 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2935,7 +2935,7 @@ def test_get_rooms_with_nonlocal_user(self) -> None: }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( event_creation_handler.create_new_client_event(builder) ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 753ecc8d161a..1bdb6bb6a5a2 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -275,6 +275,46 @@ async def check( ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") + def test_add_event(self) -> None: + # needs checking of combo of return conditions, ie replace event and send event + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict], Optional[dict]]: + event_dict = { + "type": "m.room.test", + "room_id": self.room_id, + "sender": self.user_id, + "content": { + "creator": "test_user", + "body": "message", + "msgtype": "message", + }, + } + if ev.type == "message": + return True, None, event_dict + else: + return True, None, None + + self.hs.get_third_party_event_rules()._check_event_allowed_v2_callbacks = [ + check + ] + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/message/1" % self.room_id, + {"x": "x"}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events = self.get_success( + self.hs.get_datastores().main.get_forward_extremities_for_room(self.room_id) + ) + event = events[1] + + e = self.get_success(self.hs.get_datastores().main.get_event(event["event_id"])) + self.assertEqual("m.room.test", e.type) + def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e39b63edac42..2a9aa9e21c54 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,8 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, unpersisted_context = self.get_success( + + event, unpersisted_context, _ = self.get_success( event_handler.create_event( self.requester, { @@ -545,7 +546,7 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: assert state_ids1 is not None state1 = set(state_ids1.values()) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( event_handler.create_event( self.requester, { diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0100f7da14c6..b8b997f38a71 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -74,7 +74,7 @@ def inject_room_member( # type: ignore[override] }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) @@ -98,7 +98,7 @@ def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase: }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) @@ -123,7 +123,7 @@ def inject_redaction( }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) @@ -265,7 +265,7 @@ def type(self) -> str: def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata - event_1, unpersisted_context_1 = self.get_success( + event_1, unpersisted_context_1, _ = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -290,7 +290,7 @@ def internal_metadata(self) -> _EventInternalMetadata: self.get_success(self._persistence.persist_event(event_1, context_1)) - event_2, unpersisted_context_2 = self.get_success( + event_2, unpersisted_context_2, _ = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -431,7 +431,7 @@ def test_store_redacted_redaction(self) -> None: }, ) - redaction_event, unpersisted_context = self.get_success( + redaction_event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 62aed6af0a73..1a1214c7a291 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -67,7 +67,7 @@ def inject_state_event( }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) @@ -521,7 +521,7 @@ def test_batched_state_group_storing(self) -> None: }, ) - event1, unpersisted_context1 = self.get_success( + event1, unpersisted_context1, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) events_and_context.append((event1, unpersisted_context1)) @@ -537,7 +537,7 @@ def test_batched_state_group_storing(self) -> None: }, ) - event2, unpersisted_context2 = self.get_success( + event2, unpersisted_context2, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder2) ) events_and_context.append((event2, unpersisted_context2)) @@ -552,7 +552,7 @@ def test_batched_state_group_storing(self) -> None: }, ) - event3, unpersisted_context3 = self.get_success( + event3, unpersisted_context3, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder3) ) events_and_context.append((event3, unpersisted_context3)) @@ -568,7 +568,7 @@ def test_batched_state_group_storing(self) -> None: }, ) - event4, unpersisted_context4 = self.get_success( + event4, unpersisted_context4, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder4) ) events_and_context.append((event4, unpersisted_context4)) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 9679904c3321..c619ef7f386b 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -95,6 +95,7 @@ async def create_event( ( event, unpersisted_context, + _, ) = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 9ed330f55497..6004490b8c3f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -207,7 +207,7 @@ def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) context = self.get_success(unpersisted_context.persist(event)) @@ -233,7 +233,7 @@ def _inject_room_member( }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) context = self.get_success(unpersisted_context.persist(event)) @@ -256,7 +256,7 @@ def _inject_message( }, ) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) context = self.get_success(unpersisted_context.persist(event)) diff --git a/tests/unittest.py b/tests/unittest.py index f9160faa1d0c..4b31f84494af 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -723,7 +723,7 @@ def create_and_send_event( event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, unpersisted_context = self.get_success( + event, unpersisted_context, _ = self.get_success( event_creator.create_event( requester, { diff --git a/tests/utils.py b/tests/utils.py index a0ac11bc5cd2..3badfb7d413c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,9 +335,11 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: }, ) - event, unpersisted_context = await event_creation_handler.create_new_client_event( - builder - ) + ( + event, + unpersisted_context, + _, + ) = await event_creation_handler.create_new_client_event(builder) context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context)