Skip to content

Commit

Permalink
Fix invite acceptance over federation (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
H-Shay authored May 16, 2023
1 parent 85725a0 commit 936525a
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 25 deletions.
55 changes: 48 additions & 7 deletions synapse_auto_accept_invite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Dict, Optional, Tuple

import attr
from synapse.module_api import EventBase, ModuleApi
from synapse.module_api import EventBase, ModuleApi, run_as_background_process

logger = logging.getLogger(__name__)
ACCOUNT_DATA_DIRECT_MESSAGE_LIST = "m.direct"
Expand Down Expand Up @@ -95,12 +95,16 @@ async def on_new_event(self, event: EventBase, *args: Any) -> None:
not self._config.accept_invites_only_for_direct_messages
or is_direct_message is True
):
# Make the user join the room.
await self._api.update_room_membership(
sender=event.state_key,
target=event.state_key,
room_id=event.room_id,
new_membership="join",
# Make the user join the room. We run this as a background process to circumvent a race condition
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
run_as_background_process(
"retry_make_join",
self._retry_make_join,
event.state_key,
event.state_key,
event.room_id,
"join",
bg_start_span=False,
)

if is_direct_message:
Expand Down Expand Up @@ -149,3 +153,40 @@ async def _mark_room_as_direct_message(
await self._api.account_data_manager.put_global(
user_id, ACCOUNT_DATA_DIRECT_MESSAGE_LIST, dm_map
)

async def _retry_make_join(
self, sender: str, target: str, room_id: str, new_membership: str
) -> None:
"""
A function to retry sending the `make_join` request with an increasing backoff. This is
implemented to work around a race condition when receiving invites over federation.
Args:
sender: the user performing the membership change
target: the for whom the membership is changing
room_id: room id of the room to join to
new_membership: the type of membership event (in this case will be "join")
"""

sleep = 0
retries = 0
join_event = None

while retries < 5:
try:
await self._api.sleep(sleep)
join_event = await self._api.update_room_membership(
sender=sender,
target=target,
room_id=room_id,
new_membership=new_membership,
)
except Exception as e:
logger.info(
f"Update_room_membership raised the following exception: {e}"
)
sleep = 2**retries
retries += 1

if join_event is not None:
break
28 changes: 23 additions & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, TypeVar
import asyncio
from asyncio import Future
from typing import Any, Awaitable, Dict, Optional, TypeVar
from unittest.mock import Mock

import attr
from synapse.module_api import ModuleApi
from synapse.module_api import ModuleApi, run_as_background_process

from synapse_auto_accept_invite import InviteAutoAccepter

Expand Down Expand Up @@ -44,12 +46,24 @@ def membership(self) -> str:


T = TypeVar("T")
TV = TypeVar("TV")


async def make_awaitable(value: T) -> T:
return value


def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers. Stolen from synapse.
"""
future: Future[TV] = Future()
future.set_result(result)
return future


def create_module(
config_override: Dict[str, Any] = {}, worker_name: Optional[str] = None
) -> InviteAutoAccepter:
Expand All @@ -58,10 +72,14 @@ def create_module(
module_api = Mock(spec=ModuleApi)
module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
module_api.worker_name = worker_name
module_api.sleep.return_value = make_multiple_awaitable(None)

# Python 3.6 doesn't support awaiting on a mock, so we make it return an awaitable
# value.
module_api.update_room_membership.return_value = make_awaitable(None)
config = InviteAutoAccepter.parse_config(config_override)

run_as_background_process.side_effect = (
lambda desc, func, *args, bg_start_span, **kwargs: asyncio.create_task(
func(*args, **kwargs)
)
)

return InviteAutoAccepter(config, module_api)
146 changes: 133 additions & 13 deletions tests/test_accept_invite.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import asyncio
from typing import Any, cast
from unittest.mock import Mock

import aiounittest
Expand All @@ -31,9 +32,10 @@ def setUp(self) -> None:
# We know our module API is a mock, but mypy doesn't.
self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment]

async def test_accept_invite(self) -> None:
async def test_simple_accept_invite(self) -> None:
"""Tests that receiving an invite for a local user makes the module attempt to
make the invitee join the room.
make the invitee join the room. This test verifies that it works if the call to
update membership returns a join event on the first try.
"""
invite = MockEvent(
sender=self.user_id,
Expand All @@ -42,13 +44,86 @@ async def test_accept_invite(self) -> None:
content={"membership": "invite"},
)

join_event = MockEvent(
sender="someone",
state_key="someone",
type="m.room.member",
content={"membership": "join"},
)
self.mocked_update_membership.return_value = make_awaitable(join_event)

# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
# EventBase.
await self.module.on_new_event(event=invite) # type: ignore[arg-type]

# Check that the mocked method is called exactly once and with the right
# arguments to attempt to make the user join the room.
self.mocked_update_membership.assert_called_once_with(
await self.retry_assertions(
self.mocked_update_membership,
1,
sender=invite.state_key,
target=invite.state_key,
room_id=invite.room_id,
new_membership="join",
)

async def test_accept_invite_with_failures(self) -> None:
"""Tests that receiving an invite for a local user makes the module attempt to
make the invitee join the room. This test verifies that it works if the call to
update membership returns exceptions before successfully completing and returning an event.
"""
invite = MockEvent(
sender=self.user_id,
state_key=self.invitee,
type="m.room.member",
content={"membership": "invite"},
)

join_event = MockEvent(
sender="someone",
state_key="someone",
type="m.room.member",
content={"membership": "join"},
)
# the first two calls raise an exception while the third call is successful
self.mocked_update_membership.side_effect = [
Exception(),
Exception(),
make_awaitable(join_event),
]

# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
# EventBase.
await self.module.on_new_event(event=invite) # type: ignore[arg-type]

await self.retry_assertions(
self.mocked_update_membership,
3,
sender=invite.state_key,
target=invite.state_key,
room_id=invite.room_id,
new_membership="join",
)

async def test_accept_invite_failures(self) -> None:
"""Tests that receiving an invite for a local user makes the module attempt to
make the invitee join the room. This test verifies that if the update_membership call
fails consistently, _retry_make_join will break the loop after the set number of retries and
execution will continue.
"""
invite = MockEvent(
sender=self.user_id,
state_key=self.invitee,
type="m.room.member",
content={"membership": "invite"},
)
self.mocked_update_membership.side_effect = Exception()

# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
# EventBase.
await self.module.on_new_event(event=invite) # type: ignore[arg-type]

await self.retry_assertions(
self.mocked_update_membership,
5,
sender=invite.state_key,
target=invite.state_key,
room_id=invite.room_id,
Expand All @@ -68,6 +143,14 @@ async def test_accept_invite_direct_message(self) -> None:
room_id="!the:room",
)

join_event = MockEvent(
sender="someone",
state_key="someone",
type="m.room.member",
content={"membership": "join"},
)
self.mocked_update_membership.return_value = make_awaitable(join_event)

# We will mock out the account data get/put methods to check that the flags
# are properly set.
account_data_put: Mock = cast(
Expand All @@ -90,9 +173,9 @@ async def test_accept_invite_direct_message(self) -> None:
# EventBase.
await self.module.on_new_event(event=invite) # type: ignore[arg-type]

# Check that the mocked method is called exactly once and with the right
# arguments to attempt to make the user join the room.
self.mocked_update_membership.assert_called_once_with(
await self.retry_assertions(
self.mocked_update_membership,
1,
sender=invite.state_key,
target=invite.state_key,
room_id=invite.room_id,
Expand Down Expand Up @@ -188,6 +271,15 @@ async def test_accept_invite_direct_message_if_only_enabled_for_direct_messages(
account_data_get: Mock = cast(Mock, module._api.account_data_manager.get_global)
account_data_get.return_value = make_awaitable({})

mocked_update_membership: Mock = module._api.update_room_membership # type: ignore[assignment]
join_event = MockEvent(
sender="someone",
state_key="someone",
type="m.room.member",
content={"membership": "join"},
)
mocked_update_membership.return_value = make_awaitable(join_event)

invite = MockEvent(
sender=self.user_id,
state_key=self.invitee,
Expand All @@ -199,10 +291,9 @@ async def test_accept_invite_direct_message_if_only_enabled_for_direct_messages(
# EventBase.
await module.on_new_event(event=invite) # type: ignore[arg-type]

# Check that the mocked method is called exactly once and with the right
# arguments to attempt to make the user join the room.
mocked_update_membership: Mock = module._api.update_room_membership # type: ignore[assignment]
mocked_update_membership.assert_called_once_with(
await self.retry_assertions(
mocked_update_membership,
1,
sender=invite.state_key,
target=invite.state_key,
room_id=invite.room_id,
Expand Down Expand Up @@ -262,3 +353,32 @@ def test_runs_on_only_one_worker(self) -> None:
cast(
Mock, specified_module._api.register_third_party_rules_callbacks
).assert_called_once()

async def retry_assertions(
self, mock: Mock, call_count: int, **kwargs: Any
) -> None:
"""
This is a hacky way to ensure that the assertions are not called before the other coroutine
has a chance to call `update_room_membership`. It catches the exception caused by a failure,
and sleeps the thread before retrying, up until 5 tries.
Args:
call_count: the number of times the mock should have been called
mock: the mocked function we want to assert on
kwargs: keyword arguments to assert that the mock was called with
"""

i = 0
while i < 5:
try:
# Check that the mocked method is called the expected amount of times and with the right
# arguments to attempt to make the user join the room.
mock.assert_called_with(**kwargs)
self.assertEqual(call_count, mock.call_count)
break
except AssertionError as e:
i += 1
if i == 5:
# we've used up the tries, force the test to fail as we've already caught the exception
self.fail(e)
await asyncio.sleep(1)

0 comments on commit 936525a

Please sign in to comment.