From 936525a43fb00d44dc7c44473b4ab74656fa457d Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 16 May 2023 10:59:32 -0700 Subject: [PATCH] Fix invite acceptance over federation (#15) --- synapse_auto_accept_invite/__init__.py | 55 ++++++++-- tests/__init__.py | 28 ++++- tests/test_accept_invite.py | 146 ++++++++++++++++++++++--- 3 files changed, 204 insertions(+), 25 deletions(-) diff --git a/synapse_auto_accept_invite/__init__.py b/synapse_auto_accept_invite/__init__.py index b93d414..fb23c51 100644 --- a/synapse_auto_accept_invite/__init__.py +++ b/synapse_auto_accept_invite/__init__.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Optional, Tuple import attr -from synapse.module_api import EventBase, ModuleApi +from synapse.module_api import EventBase, ModuleApi, run_as_background_process logger = logging.getLogger(__name__) ACCOUNT_DATA_DIRECT_MESSAGE_LIST = "m.direct" @@ -95,12 +95,16 @@ async def on_new_event(self, event: EventBase, *args: Any) -> None: not self._config.accept_invites_only_for_direct_messages or is_direct_message is True ): - # Make the user join the room. - await self._api.update_room_membership( - sender=event.state_key, - target=event.state_key, - room_id=event.room_id, - new_membership="join", + # Make the user join the room. We run this as a background process to circumvent a race condition + # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) + run_as_background_process( + "retry_make_join", + self._retry_make_join, + event.state_key, + event.state_key, + event.room_id, + "join", + bg_start_span=False, ) if is_direct_message: @@ -149,3 +153,40 @@ async def _mark_room_as_direct_message( await self._api.account_data_manager.put_global( user_id, ACCOUNT_DATA_DIRECT_MESSAGE_LIST, dm_map ) + + async def _retry_make_join( + self, sender: str, target: str, room_id: str, new_membership: str + ) -> None: + """ + A function to retry sending the `make_join` request with an increasing backoff. This is + implemented to work around a race condition when receiving invites over federation. + + Args: + sender: the user performing the membership change + target: the for whom the membership is changing + room_id: room id of the room to join to + new_membership: the type of membership event (in this case will be "join") + """ + + sleep = 0 + retries = 0 + join_event = None + + while retries < 5: + try: + await self._api.sleep(sleep) + join_event = await self._api.update_room_membership( + sender=sender, + target=target, + room_id=room_id, + new_membership=new_membership, + ) + except Exception as e: + logger.info( + f"Update_room_membership raised the following exception: {e}" + ) + sleep = 2**retries + retries += 1 + + if join_event is not None: + break diff --git a/tests/__init__.py b/tests/__init__.py index 43dcec9..14b39b8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, TypeVar +import asyncio +from asyncio import Future +from typing import Any, Awaitable, Dict, Optional, TypeVar from unittest.mock import Mock import attr -from synapse.module_api import ModuleApi +from synapse.module_api import ModuleApi, run_as_background_process from synapse_auto_accept_invite import InviteAutoAccepter @@ -44,12 +46,24 @@ def membership(self) -> str: T = TypeVar("T") +TV = TypeVar("TV") async def make_awaitable(value: T) -> T: return value +def make_multiple_awaitable(result: TV) -> Awaitable[TV]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. Stolen from synapse. + """ + future: Future[TV] = Future() + future.set_result(result) + return future + + def create_module( config_override: Dict[str, Any] = {}, worker_name: Optional[str] = None ) -> InviteAutoAccepter: @@ -58,10 +72,14 @@ def create_module( module_api = Mock(spec=ModuleApi) module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test" module_api.worker_name = worker_name + module_api.sleep.return_value = make_multiple_awaitable(None) - # Python 3.6 doesn't support awaiting on a mock, so we make it return an awaitable - # value. - module_api.update_room_membership.return_value = make_awaitable(None) config = InviteAutoAccepter.parse_config(config_override) + run_as_background_process.side_effect = ( + lambda desc, func, *args, bg_start_span, **kwargs: asyncio.create_task( + func(*args, **kwargs) + ) + ) + return InviteAutoAccepter(config, module_api) diff --git a/tests/test_accept_invite.py b/tests/test_accept_invite.py index 09b8bec..1365701 100644 --- a/tests/test_accept_invite.py +++ b/tests/test_accept_invite.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast +import asyncio +from typing import Any, cast from unittest.mock import Mock import aiounittest @@ -31,9 +32,10 @@ def setUp(self) -> None: # We know our module API is a mock, but mypy doesn't. self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment] - async def test_accept_invite(self) -> None: + async def test_simple_accept_invite(self) -> None: """Tests that receiving an invite for a local user makes the module attempt to - make the invitee join the room. + make the invitee join the room. This test verifies that it works if the call to + update membership returns a join event on the first try. """ invite = MockEvent( sender=self.user_id, @@ -42,13 +44,86 @@ async def test_accept_invite(self) -> None: content={"membership": "invite"}, ) + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + self.mocked_update_membership.return_value = make_awaitable(join_event) + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an # EventBase. await self.module.on_new_event(event=invite) # type: ignore[arg-type] - # Check that the mocked method is called exactly once and with the right - # arguments to attempt to make the user join the room. - self.mocked_update_membership.assert_called_once_with( + await self.retry_assertions( + self.mocked_update_membership, + 1, + sender=invite.state_key, + target=invite.state_key, + room_id=invite.room_id, + new_membership="join", + ) + + async def test_accept_invite_with_failures(self) -> None: + """Tests that receiving an invite for a local user makes the module attempt to + make the invitee join the room. This test verifies that it works if the call to + update membership returns exceptions before successfully completing and returning an event. + """ + invite = MockEvent( + sender=self.user_id, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + # the first two calls raise an exception while the third call is successful + self.mocked_update_membership.side_effect = [ + Exception(), + Exception(), + make_awaitable(join_event), + ] + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + await self.retry_assertions( + self.mocked_update_membership, + 3, + sender=invite.state_key, + target=invite.state_key, + room_id=invite.room_id, + new_membership="join", + ) + + async def test_accept_invite_failures(self) -> None: + """Tests that receiving an invite for a local user makes the module attempt to + make the invitee join the room. This test verifies that if the update_membership call + fails consistently, _retry_make_join will break the loop after the set number of retries and + execution will continue. + """ + invite = MockEvent( + sender=self.user_id, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + self.mocked_update_membership.side_effect = Exception() + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + await self.retry_assertions( + self.mocked_update_membership, + 5, sender=invite.state_key, target=invite.state_key, room_id=invite.room_id, @@ -68,6 +143,14 @@ async def test_accept_invite_direct_message(self) -> None: room_id="!the:room", ) + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + self.mocked_update_membership.return_value = make_awaitable(join_event) + # We will mock out the account data get/put methods to check that the flags # are properly set. account_data_put: Mock = cast( @@ -90,9 +173,9 @@ async def test_accept_invite_direct_message(self) -> None: # EventBase. await self.module.on_new_event(event=invite) # type: ignore[arg-type] - # Check that the mocked method is called exactly once and with the right - # arguments to attempt to make the user join the room. - self.mocked_update_membership.assert_called_once_with( + await self.retry_assertions( + self.mocked_update_membership, + 1, sender=invite.state_key, target=invite.state_key, room_id=invite.room_id, @@ -188,6 +271,15 @@ async def test_accept_invite_direct_message_if_only_enabled_for_direct_messages( account_data_get: Mock = cast(Mock, module._api.account_data_manager.get_global) account_data_get.return_value = make_awaitable({}) + mocked_update_membership: Mock = module._api.update_room_membership # type: ignore[assignment] + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + mocked_update_membership.return_value = make_awaitable(join_event) + invite = MockEvent( sender=self.user_id, state_key=self.invitee, @@ -199,10 +291,9 @@ async def test_accept_invite_direct_message_if_only_enabled_for_direct_messages( # EventBase. await module.on_new_event(event=invite) # type: ignore[arg-type] - # Check that the mocked method is called exactly once and with the right - # arguments to attempt to make the user join the room. - mocked_update_membership: Mock = module._api.update_room_membership # type: ignore[assignment] - mocked_update_membership.assert_called_once_with( + await self.retry_assertions( + mocked_update_membership, + 1, sender=invite.state_key, target=invite.state_key, room_id=invite.room_id, @@ -262,3 +353,32 @@ def test_runs_on_only_one_worker(self) -> None: cast( Mock, specified_module._api.register_third_party_rules_callbacks ).assert_called_once() + + async def retry_assertions( + self, mock: Mock, call_count: int, **kwargs: Any + ) -> None: + """ + This is a hacky way to ensure that the assertions are not called before the other coroutine + has a chance to call `update_room_membership`. It catches the exception caused by a failure, + and sleeps the thread before retrying, up until 5 tries. + + Args: + call_count: the number of times the mock should have been called + mock: the mocked function we want to assert on + kwargs: keyword arguments to assert that the mock was called with + """ + + i = 0 + while i < 5: + try: + # Check that the mocked method is called the expected amount of times and with the right + # arguments to attempt to make the user join the room. + mock.assert_called_with(**kwargs) + self.assertEqual(call_count, mock.call_count) + break + except AssertionError as e: + i += 1 + if i == 5: + # we've used up the tries, force the test to fail as we've already caught the exception + self.fail(e) + await asyncio.sleep(1)