-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
4 changed files
with
694 additions
and
595 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,281 @@ | ||
from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase | ||
from tests.unittest import HomeserverTestCase | ||
from twisted.test.proto_helpers import MemoryReactor | ||
|
||
from synapse.api.constants import EduTypes | ||
from synapse.rest import admin | ||
from synapse.rest.client import login, sendtodevice, sync | ||
from synapse.server import HomeServer | ||
from synapse.util import Clock | ||
|
||
class SendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): | ||
# See SendToDeviceTestCaseBase for tests | ||
from tests.unittest import HomeserverTestCase, override_config | ||
|
||
|
||
class NotTested: | ||
""" | ||
We nest the base test class to avoid the tests being run twice by the test runner | ||
when we share/import these tests in other files. Without this, Twisted trial throws | ||
a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`). | ||
""" | ||
|
||
class SendToDeviceTestCaseBase(HomeserverTestCase): | ||
""" | ||
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. | ||
In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. | ||
`class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` | ||
""" | ||
|
||
servlets = [ | ||
admin.register_servlets, | ||
login.register_servlets, | ||
sendtodevice.register_servlets, | ||
sync.register_servlets, | ||
] | ||
|
||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||
self.sync_endpoint = "/sync" | ||
|
||
def test_user_to_user(self) -> None: | ||
"""A to-device message from one user to another should get delivered""" | ||
|
||
user1 = self.register_user("u1", "pass") | ||
user1_tok = self.login("u1", "pass", "d1") | ||
|
||
user2 = self.register_user("u2", "pass") | ||
user2_tok = self.login("u2", "pass", "d2") | ||
|
||
# send the message | ||
test_msg = {"foo": "bar"} | ||
chan = self.make_request( | ||
"PUT", | ||
"/_matrix/client/r0/sendToDevice/m.test/1234", | ||
content={"messages": {user2: {"d2": test_msg}}}, | ||
access_token=user1_tok, | ||
) | ||
self.assertEqual(chan.code, 200, chan.result) | ||
|
||
# check it appears | ||
channel = self.make_request( | ||
"GET", self.sync_endpoint, access_token=user2_tok | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
expected_result = { | ||
"events": [ | ||
{ | ||
"sender": user1, | ||
"type": "m.test", | ||
"content": test_msg, | ||
} | ||
] | ||
} | ||
self.assertEqual(channel.json_body["to_device"], expected_result) | ||
|
||
# it should re-appear if we do another sync because the to-device message is not | ||
# deleted until we acknowledge it by sending a `?since=...` parameter in the | ||
# next sync request corresponding to the `next_batch` value from the response. | ||
channel = self.make_request( | ||
"GET", self.sync_endpoint, access_token=user2_tok | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
self.assertEqual(channel.json_body["to_device"], expected_result) | ||
|
||
# it should *not* appear if we do an incremental sync | ||
sync_token = channel.json_body["next_batch"] | ||
channel = self.make_request( | ||
"GET", | ||
f"{self.sync_endpoint}?since={sync_token}", | ||
access_token=user2_tok, | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
self.assertEqual( | ||
channel.json_body.get("to_device", {}).get("events", []), [] | ||
) | ||
|
||
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) | ||
def test_local_room_key_request(self) -> None: | ||
"""m.room_key_request has special-casing; test from local user""" | ||
user1 = self.register_user("u1", "pass") | ||
user1_tok = self.login("u1", "pass", "d1") | ||
|
||
user2 = self.register_user("u2", "pass") | ||
user2_tok = self.login("u2", "pass", "d2") | ||
|
||
# send three messages | ||
for i in range(3): | ||
chan = self.make_request( | ||
"PUT", | ||
f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", | ||
content={"messages": {user2: {"d2": {"idx": i}}}}, | ||
access_token=user1_tok, | ||
) | ||
self.assertEqual(chan.code, 200, chan.result) | ||
|
||
# now sync: we should get two of the three (because burst_count=2) | ||
channel = self.make_request( | ||
"GET", self.sync_endpoint, access_token=user2_tok | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
msgs = channel.json_body["to_device"]["events"] | ||
self.assertEqual(len(msgs), 2) | ||
for i in range(2): | ||
self.assertEqual( | ||
msgs[i], | ||
{ | ||
"sender": user1, | ||
"type": "m.room_key_request", | ||
"content": {"idx": i}, | ||
}, | ||
) | ||
sync_token = channel.json_body["next_batch"] | ||
|
||
# ... time passes | ||
self.reactor.advance(1) | ||
|
||
# and we can send more messages | ||
chan = self.make_request( | ||
"PUT", | ||
"/_matrix/client/r0/sendToDevice/m.room_key_request/3", | ||
content={"messages": {user2: {"d2": {"idx": 3}}}}, | ||
access_token=user1_tok, | ||
) | ||
self.assertEqual(chan.code, 200, chan.result) | ||
|
||
# ... which should arrive | ||
channel = self.make_request( | ||
"GET", | ||
f"{self.sync_endpoint}?since={sync_token}", | ||
access_token=user2_tok, | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
msgs = channel.json_body["to_device"]["events"] | ||
self.assertEqual(len(msgs), 1) | ||
self.assertEqual( | ||
msgs[0], | ||
{"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, | ||
) | ||
|
||
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) | ||
def test_remote_room_key_request(self) -> None: | ||
"""m.room_key_request has special-casing; test from remote user""" | ||
user2 = self.register_user("u2", "pass") | ||
user2_tok = self.login("u2", "pass", "d2") | ||
|
||
federation_registry = self.hs.get_federation_registry() | ||
|
||
# send three messages | ||
for i in range(3): | ||
self.get_success( | ||
federation_registry.on_edu( | ||
EduTypes.DIRECT_TO_DEVICE, | ||
"remote_server", | ||
{ | ||
"sender": "@user:remote_server", | ||
"type": "m.room_key_request", | ||
"messages": {user2: {"d2": {"idx": i}}}, | ||
"message_id": f"{i}", | ||
}, | ||
) | ||
) | ||
|
||
# now sync: we should get two of the three | ||
channel = self.make_request( | ||
"GET", self.sync_endpoint, access_token=user2_tok | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
msgs = channel.json_body["to_device"]["events"] | ||
self.assertEqual(len(msgs), 2) | ||
for i in range(2): | ||
self.assertEqual( | ||
msgs[i], | ||
{ | ||
"sender": "@user:remote_server", | ||
"type": "m.room_key_request", | ||
"content": {"idx": i}, | ||
}, | ||
) | ||
sync_token = channel.json_body["next_batch"] | ||
|
||
# ... time passes | ||
self.reactor.advance(1) | ||
|
||
# and we can send more messages | ||
self.get_success( | ||
federation_registry.on_edu( | ||
EduTypes.DIRECT_TO_DEVICE, | ||
"remote_server", | ||
{ | ||
"sender": "@user:remote_server", | ||
"type": "m.room_key_request", | ||
"messages": {user2: {"d2": {"idx": 3}}}, | ||
"message_id": "3", | ||
}, | ||
) | ||
) | ||
|
||
# ... which should arrive | ||
channel = self.make_request( | ||
"GET", | ||
f"{self.sync_endpoint}?since={sync_token}", | ||
access_token=user2_tok, | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
msgs = channel.json_body["to_device"]["events"] | ||
self.assertEqual(len(msgs), 1) | ||
self.assertEqual( | ||
msgs[0], | ||
{ | ||
"sender": "@user:remote_server", | ||
"type": "m.room_key_request", | ||
"content": {"idx": 3}, | ||
}, | ||
) | ||
|
||
def test_limited_sync(self) -> None: | ||
"""If a limited sync for to-devices happens the next /sync should respond immediately.""" | ||
|
||
self.register_user("u1", "pass") | ||
user1_tok = self.login("u1", "pass", "d1") | ||
|
||
user2 = self.register_user("u2", "pass") | ||
user2_tok = self.login("u2", "pass", "d2") | ||
|
||
# Do an initial sync | ||
channel = self.make_request( | ||
"GET", self.sync_endpoint, access_token=user2_tok | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
sync_token = channel.json_body["next_batch"] | ||
|
||
# Send 150 to-device messages. We limit to 100 in `/sync` | ||
for i in range(150): | ||
test_msg = {"foo": "bar"} | ||
chan = self.make_request( | ||
"PUT", | ||
f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", | ||
content={"messages": {user2: {"d2": test_msg}}}, | ||
access_token=user1_tok, | ||
) | ||
self.assertEqual(chan.code, 200, chan.result) | ||
|
||
channel = self.make_request( | ||
"GET", | ||
f"{self.sync_endpoint}?since={sync_token}&timeout=300000", | ||
access_token=user2_tok, | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
messages = channel.json_body.get("to_device", {}).get("events", []) | ||
self.assertEqual(len(messages), 100) | ||
sync_token = channel.json_body["next_batch"] | ||
|
||
channel = self.make_request( | ||
"GET", | ||
f"{self.sync_endpoint}?since={sync_token}&timeout=300000", | ||
access_token=user2_tok, | ||
) | ||
self.assertEqual(channel.code, 200, channel.result) | ||
messages = channel.json_body.get("to_device", {}).get("events", []) | ||
self.assertEqual(len(messages), 50) | ||
|
||
|
||
class SendToDeviceTestCase(NotTested.SendToDeviceTestCaseBase): | ||
# See SendToDeviceTestCaseBase above | ||
pass |
Oops, something went wrong.