From 11ea16777f52264f0a1d6f4db5b7db5c6c147523 Mon Sep 17 00:00:00 2001 From: Quentin Dufour Date: Thu, 9 May 2019 12:01:41 +0200 Subject: [PATCH 1/9] Limit the number of EDUs in transactions to 100 as expected by receiver (#5138) Fixes #3951. --- changelog.d/5138.bugfix | 1 + .../sender/per_destination_queue.py | 56 ++++++++++--------- synapse/storage/deviceinbox.py | 2 +- 3 files changed, 32 insertions(+), 27 deletions(-) create mode 100644 changelog.d/5138.bugfix diff --git a/changelog.d/5138.bugfix b/changelog.d/5138.bugfix new file mode 100644 index 000000000000..c1945a8eb257 --- /dev/null +++ b/changelog.d/5138.bugfix @@ -0,0 +1 @@ +Limit the number of EDUs in transactions to 100 as expected by synapse. Thanks to @superboum for this work! diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index be992110032b..4d96f026c664 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -33,6 +33,9 @@ from synapse.storage import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +# This is defined in the Matrix spec and enforced by the receiver. +MAX_EDUS_PER_TRANSACTION = 100 + logger = logging.getLogger(__name__) @@ -197,7 +200,8 @@ def _transaction_transmission_loop(self): pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( - yield self._get_new_device_messages() + # We have to keep 2 free slots for presence and rr_edus + yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2) ) # BEGIN CRITICAL SECTION @@ -216,19 +220,9 @@ def _transaction_transmission_loop(self): pending_edus = [] - pending_edus.extend(self._get_rr_edus(force_flush=False)) - # We can only include at most 100 EDUs per transactions - pending_edus.extend(self._pop_pending_edus(100 - len(pending_edus))) - - pending_edus.extend( - self._pending_edus_keyed.values() - ) - - self._pending_edus_keyed = {} - - pending_edus.extend(device_message_edus) - + # rr_edus and pending_presence take at most one slot each + pending_edus.extend(self._get_rr_edus(force_flush=False)) pending_presence = self._pending_presence self._pending_presence = {} if pending_presence: @@ -248,6 +242,12 @@ def _transaction_transmission_loop(self): ) ) + pending_edus.extend(device_message_edus) + pending_edus.extend(self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))) + while len(pending_edus) < MAX_EDUS_PER_TRANSACTION and self._pending_edus_keyed: + _, val = self._pending_edus_keyed.popitem() + pending_edus.append(val) + if pending_pdus: logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", self._destination, len(pending_pdus)) @@ -259,7 +259,7 @@ def _transaction_transmission_loop(self): # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < 100: + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: pending_edus.extend(self._get_rr_edus(force_flush=True)) # END CRITICAL SECTION @@ -346,33 +346,37 @@ def _pop_pending_edus(self, limit): return pending_edus @defer.inlineCallbacks - def _get_new_device_messages(self): - last_device_stream_id = self._last_device_stream_id - to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, last_device_stream_id, to_device_stream_id + def _get_new_device_messages(self, limit): + last_device_list = self._last_device_list_stream_id + # Will return at most 20 entries + now_stream_id, results = yield self._store.get_devices_by_remote( + self._destination, last_device_list ) edus = [ Edu( origin=self._server_name, destination=self._destination, - edu_type="m.direct_to_device", + edu_type="m.device_list_update", content=content, ) - for content in contents + for content in results ] - last_device_list = self._last_device_list_stream_id - now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list + assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + + last_device_stream_id = self._last_device_stream_id + to_device_stream_id = self._store.get_to_device_stream_token() + contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + self._destination, last_device_stream_id, to_device_stream_id, limit - len(edus) ) edus.extend( Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type="m.direct_to_device", content=content, ) - for content in results + for content in contents ) + defer.returnValue((edus, stream_id, now_stream_id)) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index fed4ea361035..9b0a99cb490e 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -118,7 +118,7 @@ def delete_messages_for_device_txn(txn): defer.returnValue(count) def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit=100 + self, destination, last_stream_id, current_stream_id, limit ): """ Args: From 130f932cbc5ca5477af263c5290e0c9a51b5150c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 9 May 2019 16:27:02 +0100 Subject: [PATCH 2/9] Run `black` on per_destination_queue ... mostly to fix pep8 fails --- .../sender/per_destination_queue.py | 74 ++++++++++--------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4d96f026c664..fae8bea392b3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -40,8 +40,7 @@ sent_edus_counter = Counter( - "synapse_federation_client_sent_edus", - "Total number of EDUs successfully sent", + "synapse_federation_client_sent_edus", "Total number of EDUs successfully sent" ) sent_edus_by_type = Counter( @@ -61,6 +60,7 @@ class PerDestinationQueue(object): destination (str): the server_name of the destination that we are managing transmission for. """ + def __init__(self, hs, transaction_manager, destination): self._server_name = hs.hostname self._clock = hs.get_clock() @@ -71,17 +71,17 @@ def __init__(self, hs, transaction_manager, destination): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: list[tuple[EventBase, int]] + self._pending_edus = [] # type: list[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict self._pending_rrs = {} @@ -123,9 +123,7 @@ def send_presence(self, states): Args: states (iterable[UserPresenceState]): presence to send """ - self._pending_presence.update({ - state.user_id: state for state in states - }) + self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() def queue_read_receipt(self, receipt): @@ -135,14 +133,9 @@ def queue_read_receipt(self, receipt): Args: receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued """ - self._pending_rrs.setdefault( - receipt.room_id, {}, - ).setdefault( + self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} - )[receipt.user_id] = { - "event_ids": receipt.event_ids, - "data": receipt.data, - } + )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} def flush_read_receipts_for_room(self, room_id): # if we don't have any read-receipts for this room, it may be that we've already @@ -173,10 +166,7 @@ def attempt_new_transaction(self): # request at which point pending_pdus just keeps growing. # we need application-layer timeouts of some flavour of these # requests - logger.debug( - "TX [%s] Transaction already in progress", - self._destination - ) + logger.debug("TX [%s] Transaction already in progress", self._destination) return logger.debug("TX [%s] Starting transaction loop", self._destination) @@ -243,14 +233,22 @@ def _transaction_transmission_loop(self): ) pending_edus.extend(device_message_edus) - pending_edus.extend(self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))) - while len(pending_edus) < MAX_EDUS_PER_TRANSACTION and self._pending_edus_keyed: + pending_edus.extend( + self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self._pending_edus_keyed + ): _, val = self._pending_edus_keyed.popitem() pending_edus.append(val) if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - self._destination, len(pending_pdus)) + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), + ) if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", self._destination) @@ -303,22 +301,25 @@ def _transaction_transmission_loop(self): except HttpResponseException as e: logger.warning( "TX [%s] Received %d response to transaction: %s", - self._destination, e.code, e, + self._destination, + e.code, + e, ) except RequestSendFailed as e: - logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e) + logger.warning( + "TX [%s] Failed to send transaction: %s", self._destination, e + ) for p, _ in pending_pdus: - logger.info("Failed to send event %s to %s", p.event_id, - self._destination) + logger.info( + "Failed to send event %s to %s", p.event_id, self._destination + ) except Exception: - logger.exception( - "TX [%s] Failed to send transaction", - self._destination, - ) + logger.exception("TX [%s] Failed to send transaction", self._destination) for p, _ in pending_pdus: - logger.info("Failed to send event %s to %s", p.event_id, - self._destination) + logger.info( + "Failed to send event %s to %s", p.event_id, self._destination + ) finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -367,7 +368,10 @@ def _get_new_device_messages(self, limit): last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, last_device_stream_id, to_device_stream_id, limit - len(edus) + self._destination, + last_device_stream_id, + to_device_stream_id, + limit - len(edus), ) edus.extend( Edu( From 84cebb89cc90bc5f03ce0f586d457f1a2226b34c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 9 May 2019 22:53:46 +0100 Subject: [PATCH 3/9] remove instructions for jessie installation (#5164) We don't ship jessie packages, so these were a bit misleading. --- INSTALL.md | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index 1032a2c68e77..b88d826f6c53 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -257,9 +257,8 @@ https://github.com/spantaleev/matrix-docker-ansible-deploy #### Matrix.org packages Matrix.org provides Debian/Ubuntu packages of the latest stable version of -Synapse via https://packages.matrix.org/debian/. To use them: - -For Debian 9 (Stretch), Ubuntu 16.04 (Xenial), and later: +Synapse via https://packages.matrix.org/debian/. They are available for Debian +9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them: ``` sudo apt install -y lsb-release wget apt-transport-https @@ -270,19 +269,6 @@ sudo apt update sudo apt install matrix-synapse-py3 ``` -For Debian 8 (Jessie): - -``` -sudo apt install -y lsb-release wget apt-transport-https -sudo wget -O /etc/apt/trusted.gpg.d/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg -echo "deb [signed-by=5586CCC0CBBBEFC7A25811ADF473DD4473365DE1] https://packages.matrix.org/debian/ $(lsb_release -cs) main" | - sudo tee /etc/apt/sources.list.d/matrix-org.list -sudo apt update -sudo apt install matrix-synapse-py3 -``` - -The fingerprint of the repository signing key is AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058. - **Note**: if you followed a previous version of these instructions which recommended using `apt-key add` to add an old key from `https://matrix.org/packages/debian/`, you should note that this key has been @@ -290,6 +276,9 @@ revoked. You should remove the old key with `sudo apt-key remove C35EB17E1EAE708E6603A9B3AD0592FE47F0DF61`, and follow the above instructions to update your configuration. +The fingerprint of the repository signing key (as shown by `gpg +/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is +`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. #### Downstream Debian/Ubuntu packages From d9a02d1201d030dc846e194ecb6d4c5e4619ed83 Mon Sep 17 00:00:00 2001 From: colonelkrud Date: Thu, 9 May 2019 18:27:04 -0400 Subject: [PATCH 4/9] Add AllowEncodedSlashes to apache (#5068) * Add AllowEncodedSlashes to apache Add `AllowEncodedSlashes On` to apache config to support encoding for v3 rooms. "The AllowEncodedSlashes setting is not inherited by virtual hosts, and virtual hosts are used in many default Apache configurations, such as the one in Ubuntu. The workaround is to add the AllowEncodedSlashes setting inside a container (/etc/apache2/sites-available/default in Ubuntu)." Source: https://stackoverflow.com/questions/4390436/need-to-allow-encoded-slashes-on-apache * change allowencodedslashes to nodecode --- docs/reverse_proxy.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/reverse_proxy.rst b/docs/reverse_proxy.rst index cc81ceb84b59..7619b1097bcb 100644 --- a/docs/reverse_proxy.rst +++ b/docs/reverse_proxy.rst @@ -69,6 +69,7 @@ Let's assume that we expect clients to connect to our server at SSLEngine on ServerName matrix.example.com; + AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix @@ -77,6 +78,7 @@ Let's assume that we expect clients to connect to our server at SSLEngine on ServerName example.com; + AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix From b36c82576e3bb7ea72600ecf0e80c904ccf47d1d Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 10 May 2019 00:12:11 -0500 Subject: [PATCH 5/9] Run Black on the tests again (#5170) --- changelog.d/5170.misc | 1 + tests/api/test_filtering.py | 1 - tests/api/test_ratelimiting.py | 6 +- tests/app/test_openid_listener.py | 46 ++- tests/config/test_generate.py | 8 +- tests/config/test_room_directory.py | 178 ++++++----- tests/config/test_server.py | 1 - tests/config/test_tls.py | 11 +- tests/crypto/test_keyring.py | 3 +- tests/federation/test_federation_sender.py | 113 ++++--- tests/handlers/test_directory.py | 24 +- tests/handlers/test_e2e_room_keys.py | 277 +++++++++--------- tests/handlers/test_presence.py | 14 +- tests/handlers/test_typing.py | 152 ++++------ tests/handlers/test_user_directory.py | 8 +- tests/http/__init__.py | 6 +- .../test_matrix_federation_agent.py | 212 +++++--------- tests/http/federation/test_srv_resolver.py | 36 +-- tests/http/test_fedclient.py | 31 +- tests/patch_inline_callbacks.py | 16 +- tests/replication/slave/storage/_base.py | 15 +- .../replication/slave/storage/test_events.py | 19 +- tests/replication/tcp/streams/_base.py | 6 +- tests/rest/admin/test_admin.py | 103 +++---- tests/rest/client/test_identity.py | 8 +- tests/rest/client/v1/test_directory.py | 53 ++-- tests/rest/client/v1/test_login.py | 36 +-- tests/rest/client/v1/test_profile.py | 27 +- tests/rest/client/v2_alpha/test_register.py | 82 +++--- tests/rest/media/v1/test_base.py | 14 +- tests/rest/test_well_known.py | 13 +- tests/server.py | 7 +- .../test_resource_limits_server_notices.py | 1 - tests/state/test_v2.py | 204 +++---------- tests/storage/test_background_update.py | 4 +- tests/storage/test_base.py | 5 +- tests/storage/test_end_to_end_keys.py | 1 - tests/storage/test_monthly_active_users.py | 17 +- tests/storage/test_redaction.py | 6 +- tests/storage/test_registration.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 83 ++---- tests/storage/test_user_directory.py | 4 +- tests/test_event_auth.py | 8 +- tests/test_federation.py | 1 - tests/test_mau.py | 8 +- tests/test_metrics.py | 56 ++-- tests/test_terms_auth.py | 8 +- tests/test_types.py | 6 +- tests/test_utils/logging_setup.py | 4 +- tests/test_visibility.py | 6 +- tests/unittest.py | 13 +- tests/util/test_async_utils.py | 23 +- tests/utils.py | 9 +- 54 files changed, 829 insertions(+), 1169 deletions(-) create mode 100644 changelog.d/5170.misc diff --git a/changelog.d/5170.misc b/changelog.d/5170.misc new file mode 100644 index 000000000000..7919dac55580 --- /dev/null +++ b/changelog.d/5170.misc @@ -0,0 +1 @@ +Run `black` on the tests directory. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2a7044801a91..6ba623de1336 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -109,7 +109,6 @@ def test_valid_filters(self): "event_format": "client", "event_fields": ["type", "content", "sender"], }, - # a single backslash should be permitted (though it is debatable whether # it should be permitted before anything other than `.`, and what that # actually means) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 30a255d44196..dbdd427cac22 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -10,19 +10,19 @@ def test_allowed(self): key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(20., time_allowed) + self.assertEquals(20.0, time_allowed) def test_pruning(self): limiter = Ratelimiter() diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 590abc1e92ba..48792d14801b 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -25,16 +25,18 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FederationReaderServer, + http_client=None, homeserverToUse=FederationReaderServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -53,17 +55,14 @@ def test_openid_listener(self, names, expectation): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) @@ -74,16 +73,18 @@ def test_openid_listener(self, names, expectation): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer, + http_client=None, homeserverToUse=SynapseHomeServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -102,17 +103,14 @@ def test_openid_listener(self, names, expectation): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 795b4c298d99..5017cbce853b 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -45,13 +45,7 @@ def test_generate_config_generates_files(self): ) self.assertSetEqual( - set( - [ - "homeserver.yaml", - "lemurs.win.log.config", - "lemurs.win.signing.key", - ] - ), + set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), set(os.listdir(self.dir)), ) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 47fffcfeb22c..0ec10019b3aa 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -22,7 +22,8 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -38,43 +39,49 @@ def test_alias_creation_acl(self): action: "allow" room_list_publication_rules: [] - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@bob:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#unofficial_st:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@foobar:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@gah:example.com", - room_id="!test", - alias="#goo:example.com", - )) - - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#test:example.com", - )) + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", room_id="!test", alias="#test:example.com" + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + room_id="!test", + alias="#unofficial_st:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + room_id="!test", + alias="#test:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", room_id="!test", alias="#goo:example.com" + ) + ) + + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", room_id="!test", alias="#test:example.com" + ) + ) def test_room_publish_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: [] room_list_publication_rules: @@ -92,55 +99,66 @@ def test_room_publish_acl(self): action: "allow" - room_id: "!test-deny" action: "deny" - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@bob:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test", - aliases=["#goo:example.com"], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com", "#blah:example.com"], - )) + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@bob:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", + room_id="!test", + aliases=["#goo:example.com"], + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com", "#blah:example.com"], + ) + ) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index f5836d73ac4e..de64965a6069 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -19,7 +19,6 @@ class ServerConfigTestCase(unittest.TestCase): - def test_is_threepid_reserved(self): user1 = {'medium': 'email', 'address': 'user1@example.com'} user2 = {'medium': 'email', 'address': 'user2@example.com'} diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index c260d3359f36..40ca42877843 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -26,7 +26,6 @@ def has_tls_listener(self): class TLSConfigTests(TestCase): - def test_warn_self_signed(self): """ Synapse will give a warning when it loads a self-signed certificate. @@ -34,7 +33,8 @@ def test_warn_self_signed(self): config_dir = self.mktemp() os.mkdir(config_dir) with open(os.path.join(config_dir, "cert.pem"), 'w') as f: - f.write("""-----BEGIN CERTIFICATE----- + f.write( + """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb @@ -56,11 +56,12 @@ def test_warn_self_signed(self): iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2 SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= ------END CERTIFICATE-----""") +-----END CERTIFICATE-----""" + ) config = { "tls_certificate_path": os.path.join(config_dir, "cert.pem"), - "tls_fingerprints": [] + "tls_fingerprints": [], } t = TestConfig() @@ -75,5 +76,5 @@ def test_warn_self_signed(self): "Self-signed TLS certificates will not be accepted by " "Synapse 1.0. Please either provide a valid certificate, " "or use Synapse's ACME support to provision one." - ) + ), ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f5bd7a1aa1c6..3c79d4afe749 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -169,7 +169,7 @@ def second_lookup(): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1, )] + [("server10", json1)] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -345,6 +345,7 @@ def _verify_json_for_server(keyring, server_name, json_object): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ + @defer.inlineCallbacks def v(): rv1 = yield keyring.verify_json_for_server(server_name, json_object) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 28e7e27416c9..7bb106b5f7df 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -33,11 +33,15 @@ def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -46,21 +50,24 @@ def test_send_receipts(self): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but @@ -68,11 +75,15 @@ def test_send_receipts_with_backoff(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -81,25 +92,30 @@ def test_send_receipts_with_backoff(self): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) mock_send_transaction.reset_mock() # send the second RR - receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() mock_send_transaction.assert_not_called() @@ -111,18 +127,21 @@ def test_send_receipts_with_backoff(self): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['other_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5b2105bc760d..917548bb31d0 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -115,11 +115,7 @@ def prepare(self, reactor, clock, hs): # We cheekily override the config to add custom alias creation rules config = {} config["alias_creation_rules"] = [ - { - "user_id": "*", - "alias": "#unofficial_*", - "action": "allow", - } + {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] config["room_list_publication_rules"] = [] @@ -162,9 +158,7 @@ def prepare(self, reactor, clock, hs): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -179,10 +173,7 @@ def test_disabling_room_list(self): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -191,10 +182,7 @@ def test_disabling_room_list(self): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) @@ -202,9 +190,7 @@ def test_disabling_room_list(self): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 1c49bbbc3c51..2e72a1dd2381 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -36,7 +36,7 @@ "first_message_index": 1, "forwarded_count": 1, "is_verified": False, - "session_data": "SSBBTSBBIEZJU0gK" + "session_data": "SSBBTSBBIEZJU0gK", } } } @@ -47,15 +47,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer + self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, - handlers=None, - replication_layer=mock.Mock(), + self.addCleanup, handlers=None, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname @@ -88,67 +86,86 @@ def test_get_missing_version_info(self): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # upload a new one... - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(res, "2") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "2", - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) @defer.inlineCallbacks def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - res = yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + res = yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) @defer.inlineCallbacks def test_update_missing_version(self): @@ -156,11 +173,15 @@ def test_update_missing_version(self): """ res = None try: - yield self.handler.update_version(self.local_user, "1", { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1" - }) + yield self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -170,29 +191,37 @@ def test_update_bad_version(self): """Check that we get a 400 if the version in the body is missing or doesn't match """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) @@ -223,10 +252,10 @@ def test_delete_missing_current_version(self): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can delete it @@ -255,16 +284,14 @@ def test_get_missing_backup(self): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = yield self.handler.get_room_keys(self.local_user, version) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest @@ -275,7 +302,9 @@ def test_upload_room_keys_no_versions(self): """ res = None try: - yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + yield self.handler.upload_room_keys( + self.local_user, "no_version", room_keys + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -285,10 +314,10 @@ def test_upload_room_keys_bogus_version(self): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None @@ -304,16 +333,19 @@ def test_upload_room_keys_bogus_version(self): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(version, "2") res = None @@ -327,10 +359,10 @@ def test_upload_room_keys_wrong_version(self): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -340,18 +372,13 @@ def test_upload_room_keys_insert(self): # check getting room_keys for a given room res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org" + self.local_user, version, room_id="!abc:matrix.org" ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) self.assertDictEqual(res, room_keys) @@ -359,10 +386,10 @@ def test_upload_room_keys_insert(self): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -378,7 +405,7 @@ def test_upload_room_keys_merge(self): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "SSBBTSBBIEZJU0gK" + "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it @@ -387,8 +414,7 @@ def test_upload_room_keys_merge(self): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # test that a session with a higher forwarded_count doesn't replace one @@ -399,8 +425,7 @@ def test_upload_room_keys_merge(self): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # TODO: check edge cases as well as the common variations here @@ -409,56 +434,36 @@ def test_upload_room_keys_merge(self): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") # check for bulk-delete yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys(self.local_user, version) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", + self.local_user, version, room_id="!abc:matrix.org" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 94c6080e3485..f70c6e7d656c 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, - federation_sender=Mock(), + "server", http_client=None, federation_sender=Mock() ) return hs @@ -457,7 +456,7 @@ def test_remote_joins(self): # Mark test2 as online, test will be offline with a last_active of 0 self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -506,13 +505,13 @@ def test_remote_gets_presence_when_local_user_joins(self): # Mark test as online self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) # Add servers to the room @@ -541,8 +540,7 @@ def test_remote_gets_presence_when_local_user_joins(self): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), - states=[expected_state] + destinations=set(("server2", "server3")), states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -565,7 +563,7 @@ def _add_new_user(self, room_id, user_id): type=EventTypes.Member, sender=user_id, state_key=user_id, - content={"membership": Membership.JOIN} + content={"membership": Membership.JOIN}, ) prev_event_ids = self.get_success( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5a0b6c201c9b..cb8b4d29138b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,20 +64,22 @@ def make_homeserver(self, reactor, clock): mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) hs = self.setup_test_homeserver( - datastore=(Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - )), + datastore=( + Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + ] + ) + ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, @@ -87,7 +89,7 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): # the tests assume that we are starting at unix time 1000 - reactor.pump((1000, )) + reactor.pump((1000,)) mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -114,6 +116,7 @@ def get_received_txn_response(*args): def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): @@ -123,6 +126,7 @@ def get_joined_hosts_for_room(room_id): def get_current_users_in_room(room_id): return set(str(u) for u in self.room_members) + hs.get_state_handler().get_current_users_in_room = get_current_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( @@ -141,21 +145,16 @@ def test_started_typing_local(self): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -170,12 +169,11 @@ def test_started_typing_local(self): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) + ) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -216,14 +214,10 @@ def test_started_typing_remote_recv(self): self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -247,14 +241,14 @@ def test_stopped_typing(self): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", @@ -274,18 +268,10 @@ def test_stopped_typing(self): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) def test_typing_timeout(self): @@ -293,22 +279,17 @@ def test_typing_timeout(self): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -320,45 +301,30 @@ def test_typing_timeout(self): ], ) - self.reactor.pump([16, ]) + self.reactor.pump([16]) - self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=1 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) # SYN-230 - see if we can still set after timeout - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 7dd1a1daf8c9..44468f5382f9 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -352,9 +352,7 @@ def test_disabling_room_list(self): # Assert user directory is not empty request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -363,9 +361,7 @@ def test_disabling_room_list(self): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index ee8010f59883..851fc0eb332e 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -24,14 +24,12 @@ def get_test_cert_file(): # # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \ # -nodes -subj '/CN=testserv' - return os.path.join( - os.path.dirname(__file__), - 'server.pem', - ) + return os.path.join(os.path.dirname(__file__), 'server.pem') class ServerTLSContext(object): """A TLS Context which presents our test cert.""" + def __init__(self): self.filename = get_test_cert_file() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e9eb662c4c71..7036615041fe 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -79,12 +79,12 @@ def _make_connection(self, client_factory, expected_sni): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol), + FakeTransport(server_tls_protocol, self.reactor, client_protocol) ) # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol), + FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) # give the reactor a pump to get the TLS juices flowing. @@ -125,7 +125,7 @@ def _make_get_request(self, uri): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={}, + self, client_factory, expected_sni, content, response_headers={} ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -139,8 +139,7 @@ def _handle_well_known_connection( """ # make the connection for .well-known well_known_server = self._make_connection( - client_factory, - expected_sni=expected_sni, + client_factory, expected_sni=expected_sni ) # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) @@ -154,17 +153,14 @@ def _send_well_known_response(self, request, content, headers={}): """ self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # send back a response for k, v in headers.items(): request.setHeader(k, v) request.write(content) request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) def test_get(self): """ @@ -184,18 +180,14 @@ def test_get(self): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b"testserv", - ) + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] ) content = request.content.read() self.assertEqual(content, b'') @@ -244,19 +236,13 @@ def test_get_ip_address(self): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'1.2.3.4'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) # finish the request request.finish() @@ -285,19 +271,13 @@ def test_get_ipv6_address(self): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) # finish the request request.finish() @@ -326,19 +306,13 @@ def test_get_ipv6_address_with_port(self): self.assertEqual(port, 80) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]:80'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) # finish the request request.finish() @@ -377,7 +351,7 @@ def test_get_no_srv_no_well_known(self): # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -387,19 +361,13 @@ def test_get_no_srv_no_well_known(self): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -427,13 +395,14 @@ def test_get_well_known(self): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -444,8 +413,7 @@ def test_get_well_known(self): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -453,8 +421,7 @@ def test_get_well_known(self): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -490,8 +457,7 @@ def test_get_well_known_redirect(self): self.assertEqual(port, 443) redirect_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) # send a 302 redirect @@ -500,7 +466,7 @@ def test_get_well_known_redirect(self): request.redirect(b'https://testserv/even_better_known') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # now there should be another connection clients = self.reactor.tcpClients @@ -510,8 +476,7 @@ def test_get_well_known_redirect(self): self.assertEqual(port, 443) well_known_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) self.assertEqual(len(well_known_server.requests), 1, "No request after 302") @@ -521,11 +486,11 @@ def test_get_well_known_redirect(self): request.write(b'{ "m.server": "target-server" }') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -536,8 +501,7 @@ def test_get_well_known_redirect(self): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -545,8 +509,7 @@ def test_get_well_known_redirect(self): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -585,12 +548,12 @@ def test_get_invalid_well_known(self): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON', + client_factory, expected_sni=b"testserv", content=b'NOT JSON' ) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -600,19 +563,13 @@ def test_get_invalid_well_known(self): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -635,7 +592,7 @@ def test_get_hostname_srv(self): # the request for a .well-known will have failed with a DNS lookup error. self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # Make sure treq is trying to connect @@ -646,19 +603,13 @@ def test_get_hostname_srv(self): self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -685,17 +636,18 @@ def test_get_well_known_srv(self): self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443), + Server(host=b"srvtarget", port=8443) ] self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target of the SRV record @@ -706,8 +658,7 @@ def test_get_well_known_srv(self): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -715,8 +666,7 @@ def test_get_well_known_srv(self): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -757,7 +707,7 @@ def test_idna_servername(self): # now there should have been a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # We should fall back to port 8448 @@ -769,8 +719,7 @@ def test_idna_servername(self): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -778,8 +727,7 @@ def test_idna_servername(self): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -801,7 +749,7 @@ def test_idna_srv_target(self): self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # Make sure treq is trying to connect @@ -813,8 +761,7 @@ def test_idna_srv_target(self): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -822,8 +769,7 @@ def test_idna_srv_target(self): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -897,67 +843,70 @@ def test_cache_control(self): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}), - ), 100, + Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + ), + 100, ) # missing value - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=, bar']}), - )) + self.assertIsNone( + _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + ) # hackernews: bogus due to semicolon - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}), - )) + self.assertIsNone( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'private; max-age=0']}) + ) + ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}), - ), 0, + Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + ), + 0, ) # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}), - ), 0, + Headers({b'cache-control': [b'private, max-age=0']}) + ), + 0, ) def test_expires(self): self.assertEqual( _cache_period_from_headers( Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), - time_now=lambda: 1548833700 - ), 33, + time_now=lambda: 1548833700, + ), + 33, ) # cache-control overrides expires self.assertEqual( _cache_period_from_headers( - Headers({ - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'] - }), - time_now=lambda: 1548833700 - ), 10, + Headers( + { + b'cache-control': [b'max-age=10'], + b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + } + ), + time_now=lambda: 1548833700, + ), + 10, ) # invalid expires means immediate expiry - self.assertEqual( - _cache_period_from_headers( - Headers({b'Expires': [b'0']}), - ), 0, - ) + self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) def _check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _build_test_server(): @@ -973,7 +922,7 @@ def _build_test_server(): server_factory.log = _log_request server_tls_factory = TLSMemoryBIOFactory( - ServerTLSContext(), isClient=False, wrappedFactory=server_factory, + ServerTLSContext(), isClient=False, wrappedFactory=server_factory ) return server_tls_factory.buildProtocol(None) @@ -987,6 +936,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) class TrustingTLSPolicyForHTTPS(object): """An IPolicyForHTTPS which doesn't do any certificate verification""" + def creatorForNetloc(self, hostname, port): certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index a872e2441e79..034c0db8d2f7 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -68,9 +68,7 @@ def do_lookup(): dns_client_mock.lookupService.assert_called_once_with(service_name) - result_deferred.callback( - ([answer_srv], None, None) - ) + result_deferred.callback(([answer_srv], None, None)) servers = self.successResultOf(test_d) @@ -112,7 +110,7 @@ def test_from_cache(self): cache = {service_name: [entry]} resolver = SrvResolver( - dns_client=dns_client_mock, cache=cache, get_time=clock.time, + dns_client=dns_client_mock, cache=cache, get_time=clock.time ) servers = yield resolver.resolve_service(service_name) @@ -168,11 +166,13 @@ def test_disabled_service(self): self.assertNoResult(resolve_d) # returning a single "." should make the lookup fail with a ConenctError - lookup_deferred.callback(( - [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], - None, - None, - )) + lookup_deferred.callback( + ( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + ) + ) self.failureResultOf(resolve_d, ConnectError) @@ -191,14 +191,16 @@ def test_non_srv_answer(self): resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) - lookup_deferred.callback(( - [ - dns.RRHeader(type=dns.A, payload=dns.Record_A()), - dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), - ], - None, - None, - )) + lookup_deferred.callback( + ( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + ) + ) servers = self.successResultOf(resolve_d) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index cd8e086f869a..279e456614ac 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -36,9 +36,7 @@ def check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): @@ -54,6 +52,7 @@ def test_client_get(self): """ happy-path test of a GET request """ + @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: @@ -175,8 +174,7 @@ def test_client_never_connect(self): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance( - f.value.inner_exception, - (ConnectingCancelledError, TimeoutError), + f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) def test_client_connect_no_response(self): @@ -216,9 +214,7 @@ def test_client_gets_headers(self): Once the client gets the headers, _request returns successfully. """ request = MatrixFederationRequest( - method="GET", - destination="testserv:8008", - path="foo/bar", + method="GET", destination="testserv:8008", path="foo/bar" ) d = self.cl._send_request(request, timeout=10000) @@ -258,8 +254,10 @@ def test_client_headers_no_body(self): # Send it the HTTP response client.dataReceived( - (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n") + ( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" + ) ) # Push by enough to time it out @@ -274,9 +272,7 @@ def test_client_requires_trailing_slashes(self): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -329,9 +325,7 @@ def test_client_does_not_retry_on_400_plus(self): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -368,10 +362,7 @@ def test_client_does_not_retry_on_400_plus(self): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json( - "testserv:8008", "foo/bar", timeout=10000, - data={"a": "b"} - ) + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) self.pump() diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 0f613945c8e8..ee0add3455e7 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -45,7 +45,9 @@ def wrapped(*args, **kwargs): except Exception: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s on exception" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -54,7 +56,9 @@ def wrapped(*args, **kwargs): if not isinstance(res, Deferred) or res.called: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -66,9 +70,7 @@ def wrapped(*args, **kwargs): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % ( - f, LoggingContext.current_context(), start_context, - ) + ) % (f, LoggingContext.current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) @@ -76,7 +78,9 @@ def check_ctx(r): if LoggingContext.current_context() != start_context: err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", - f, start_context, LoggingContext.current_context(), + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 1f72a2a04f8a..104349cdbd83 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -74,21 +74,18 @@ def check(self, method, args, expected_result=None): self.assertEqual( master_result, expected_result, - "Expected master result to be %r but was %r" % ( - expected_result, master_result - ), + "Expected master result to be %r but was %r" + % (expected_result, master_result), ) self.assertEqual( slaved_result, expected_result, - "Expected slave result to be %r but was %r" % ( - expected_result, slaved_result - ), + "Expected slave result to be %r but was %r" + % (expected_result, slaved_result), ) self.assertEqual( master_result, slaved_result, - "Slave result %r does not match master result %r" % ( - slaved_result, master_result - ), + "Slave result %r does not match master result %r" + % (slaved_result, master_result), ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 65ecff3bd6eb..a368117b43ca 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,10 +234,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([ - (j2, j2ctx), - (msg, msgctx), - ])) + self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) self.replicate() event_source = RoomEventSource(self.hs) @@ -257,15 +254,13 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): # # First, we get a list of the rooms we are joined to joined_rooms = self.get_success( - self.slaved_store.get_rooms_for_user_with_stream_ordering( - USER_ID_2, - ), + self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) ) # Then, we get a list of the events since the last sync membership_changes = self.get_success( self.slaved_store.get_membership_changes_for_user( - USER_ID_2, prev_token, current_token, + USER_ID_2, prev_token, current_token ) ) @@ -298,9 +293,7 @@ def persist(self, backfill=False, **kwargs): self.master_store.persist_events([(event, context)], backfilled=True) ) else: - self.get_success( - self.master_store.persist_event(event, context) - ) + self.get_success(self.master_store.persist_event(event, context)) return event @@ -359,9 +352,7 @@ def build_event( ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context( - event - )) + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 38b368a97244..ce3835ae6a15 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -22,6 +22,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) @@ -52,6 +53,7 @@ def replicate_stream(self, stream, token="NOW"): class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): self.received_rdata_rows = [] @@ -69,6 +71,4 @@ def finished_connecting(self): def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append( - (stream_name, token, r) - ) + self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index da19a83918c8..ee5f09041f27 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -41,10 +41,10 @@ def test_version_string(self): request, channel = self.make_request("GET", self.url, shorthand=False) self.render(request) - self.assertEqual(200, int(channel.result["code"]), - msg=channel.result["body"]) - self.assertEqual({'server_version', 'python_version'}, - set(channel.json_body.keys())) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + {'server_version', 'python_version'}, set(channel.json_body.keys()) + ) class UserRegisterTestCase(unittest.HomeserverTestCase): @@ -200,9 +200,7 @@ def test_nonce_reuse(self): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin" - ) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -330,11 +328,13 @@ def nonce(): # # Invalid user_type - body = json.dumps({ - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid"} + body = json.dumps( + { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } ) request, channel = self.make_request("POST", self.url, body.encode('utf8')) self.render(request) @@ -357,9 +357,7 @@ def prepare(self, reactor, clock, hs): hs.config.user_consent_version = "1" consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = ( - "http://example.com" - ) + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder self.store = hs.get_datastore() @@ -371,9 +369,7 @@ def prepare(self, reactor, clock, hs): self.other_user_token = self.login("user", "pass") # Mark the admin user as having consented - self.get_success( - self.store.user_set_consent_version(self.admin_user, "1"), - ) + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not @@ -385,9 +381,7 @@ def test_shutdown_room_consent(self): room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) # Assert one user in room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([self.other_user], users_in_room) # Enable require consent to send events @@ -395,8 +389,7 @@ def test_shutdown_room_consent(self): # Assert that the user is getting consent error self.helper.send( - room_id, - body="foo", tok=self.other_user_token, expect_code=403, + room_id, body="foo", tok=self.other_user_token, expect_code=403 ) # Test that the admin can still send shutdown @@ -412,9 +405,7 @@ def test_shutdown_room_consent(self): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Assert there is now no longer anyone in the room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) @unittest.DEBUG @@ -459,24 +450,20 @@ def _assert_peek(self, room_id, expect_code): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -502,15 +489,11 @@ def test_delete_group(self): "POST", "/create_group".encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -520,27 +503,17 @@ def test_delete_group(self): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.admin_user_tok, - content={} + "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.other_user_token, - content={} + "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -552,15 +525,11 @@ def test_delete_group(self): "POST", url.encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 self._check_group(group_id, expect_code=404) @@ -576,28 +545,22 @@ def _check_group(self, group_id, expect_code): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", - "/joined_groups".encode('ascii'), - access_token=access_token, + "GET", "/joined_groups".encode('ascii'), access_token=access_token ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 2e51ffa4189f..1a714ff58a59 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -44,7 +44,7 @@ def test_3pid_lookup_disabled(self): tok = self.login("kermit", "monkey") request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok, + b"POST", "/createRoom", b"{}", access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -56,11 +56,9 @@ def test_3pid_lookup_disabled(self): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ( - "/rooms/%s/invite" % (room_id) - ).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok, + b"POST", request_url, request_data, access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index f63c68e7ed73..73c5b44b46c7 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -45,7 +45,7 @@ def prepare(self, reactor, clock, homeserver): self.room_owner_tok = self.login("room_owner", "test") self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok, + self.room_owner, tok=self.room_owner_tok ) self.user = self.register_user("user", "test") @@ -80,12 +80,10 @@ def test_room_creation_too_long(self): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = { - "room_alias_name": random_string(256 - len(self.hs.hostname)), - } + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok, + "POST", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, 400, channel.result) @@ -96,51 +94,42 @@ def test_room_creation(self): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = { - "room_alias_name": random_string(5), - } + data = {"room_alias_name": random_string(5)} request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok, + "POST", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, 200, channel.result) def set_alias_via_state_event(self, expected_code, alias_length=5): - url = ("/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" - % (self.room_id, self.hs.hostname)) - - data = { - "aliases": [ - self.random_alias(alias_length), - ], - } + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} request_data = json.dumps(data) request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok, + "PUT", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) - data = { - "room_id": self.room_id, - } + data = {"room_id": self.room_id} request_data = json.dumps(data) request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok, + "PUT", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def random_alias(self, length): - return RoomAlias( - random_string(length), - self.hs.hostname, - ).to_string() + return RoomAlias(random_string(length), self.hs.hostname).to_string() def ensure_user_left_room(self): self.ensure_membership("leave") @@ -151,17 +140,9 @@ def ensure_user_joined_room(self): def ensure_membership(self, membership): try: if membership == "leave": - self.helper.leave( - room=self.room_id, - user=self.user, - tok=self.user_tok, - ) + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) if membership == "join": - self.helper.join( - room=self.room_id, - user=self.user, - tok=self.user_tok, - ) + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) except AssertionError: # We don't care whether the leave request didn't return a 200 (e.g. # if the user isn't already in the room), because we only want to diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 9ebd91f6789e..0397f91a9e69 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -37,10 +37,7 @@ def test_POST_ratelimiting_per_address(self): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -57,14 +54,11 @@ def test_POST_ratelimiting_per_address(self): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -82,10 +76,7 @@ def test_POST_ratelimiting_per_account(self): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -102,14 +93,11 @@ def test_POST_ratelimiting_per_account(self): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -127,10 +115,7 @@ def test_POST_ratelimiting_per_account_failed_attempts(self): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) @@ -147,14 +132,11 @@ def test_POST_ratelimiting_per_account_failed_attempts(self): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 7306e61b7c5b..ed034879cf6f 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -199,37 +199,24 @@ def test_not_in_shared_room(self): def test_in_shared_room(self): self.ensure_requester_left_room() - self.helper.join( - room=self.room_id, - user=self.requester, - tok=self.requester_tok, - ) + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) def try_fetch_profile(self, expected_code, access_token=None): - self.request_profile( - expected_code, - access_token=access_token - ) + self.request_profile(expected_code, access_token=access_token) self.request_profile( - expected_code, - url_suffix="/displayname", - access_token=access_token, + expected_code, url_suffix="/displayname", access_token=access_token ) self.request_profile( - expected_code, - url_suffix="/avatar_url", - access_token=access_token, + expected_code, url_suffix="/avatar_url", access_token=access_token ) def request_profile(self, expected_code, url_suffix="", access_token=None): request, channel = self.make_request( - "GET", - self.profile_url + url_suffix, - access_token=access_token, + "GET", self.profile_url + url_suffix, access_token=access_token ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) @@ -237,9 +224,7 @@ def request_profile(self, expected_code, url_suffix="", access_token=None): def ensure_requester_left_room(self): try: self.helper.leave( - room=self.room_id, - user=self.requester, - tok=self.requester_tok, + room=self.room_id, user=self.requester, tok=self.requester_tok ) except AssertionError: # We don't care whether the leave request didn't return a 200 (e.g. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1c3a621d262e..be95dc592d0f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -41,11 +41,10 @@ def test_POST_appservice_registration_valid(self): as_token = "i_am_an_app_service" appservice = ApplicationService( - as_token, self.hs.config.server_name, + as_token, + self.hs.config.server_name, id="1234", - namespaces={ - "users": [{"regex": r"@as_user.*", "exclusive": True}], - }, + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, ) self.hs.get_datastore().services_cache.append(appservice) @@ -57,10 +56,7 @@ def test_POST_appservice_registration_valid(self): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): @@ -128,10 +124,7 @@ def test_POST_guest_registration(self): request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) - det_data = { - "home_server": self.hs.hostname, - "device_id": "guest_device", - } + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @@ -159,7 +152,7 @@ def test_POST_ratelimiting_guest(self): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -187,7 +180,7 @@ def test_POST_ratelimiting(self): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -221,23 +214,19 @@ def test_validity_period(self): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) def test_manual_renewal(self): @@ -253,21 +242,17 @@ def test_manual_renewal(self): admin_tok = self.login("admin", "adminpassword") url = "/_matrix/client/unstable/admin/account_validity/validity" - params = { - "user_id": user_id, - } + params = {"user_id": user_id} request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -286,20 +271,18 @@ def test_manual_expire(self): } request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -358,10 +341,15 @@ def test_renewal_email(self): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) # Move 6 days forward. This should trigger a renewal email to be sent. self.reactor.advance(datetime.timedelta(days=6).total_seconds()) @@ -379,9 +367,7 @@ def test_renewal_email(self): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -393,13 +379,19 @@ def test_manual_email_send(self): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) request, channel = self.make_request( - b"POST", "/_matrix/client/unstable/account_validity/send_mail", + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) self.render(request) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index af8f74eb4217..00688a732521 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b'inline; filename="aze%20rty"': u"aze%20rty", b'inline; filename="aze\"rty"': u'aze"rty', b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers( - { - b'Content-Disposition': [hdr], - }, - ) + res = get_filename_from_headers({b'Content-Disposition': [hdr]}) self.assertEqual( - res, expected, - "expected output for %s to be %s but was %s" % ( - hdr, expected, res, - ) + res, + expected, + "expected output for %s to be %s but was %s" % (hdr, expected, res), ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 8d8f03e00580..b090bb974cf9 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -31,27 +31,24 @@ def test_well_known(self): self.hs.config.default_identity_server = "https://testis" request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) self.assertEqual(request.code, 200) self.assertEqual( - channel.json_body, { + channel.json_body, + { "m.homeserver": {"base_url": "https://tesths"}, "m.identity_server": {"base_url": "https://testis"}, - } + }, ) def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) diff --git a/tests/server.py b/tests/server.py index 8f89f4a83d44..fc4134548850 100644 --- a/tests/server.py +++ b/tests/server.py @@ -182,7 +182,8 @@ def make_request( if federation_auth_origin is not None: req.requestHeaders.addRawHeader( - b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + b"Authorization", + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,), ) if content: @@ -233,7 +234,7 @@ def __init__(self): class FakeResolver(object): def getHostByName(self, name, timeout=None): if name not in lookups: - return fail(DNSLookupError("OH NO: unknown %s" % (name, ))) + return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) @@ -454,6 +455,6 @@ def flush(self, maxbytes=None): logger.warning("Exception writing to protocol: %s", e) return - self.buffer = self.buffer[len(to_write):] + self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index be73e718c27d..a490b81ed4b2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,7 +27,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f448b0132672..9c5311d916ad 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -50,6 +50,7 @@ class FakeEvent(object): refer to events. The event_id has node_id as localpart and example.com as domain. """ + def __init__(self, id, sender, type, state_key, content): self.node_id = id self.event_id = EventID(id, "example.com").to_string() @@ -142,24 +143,14 @@ def to_event(self, auth_events, prev_events): content=MEMBERSHIP_CONTENT_JOIN, ), FakeEvent( - id="START", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), FakeEvent( - id="END", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), ] -INITIAL_EDGES = [ - "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", -] +INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): @@ -170,12 +161,7 @@ def test_ban_vs_pl(self): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="MA", @@ -196,19 +182,11 @@ def test_ban_vs_pl(self): sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), ] - edges = [ - ["END", "MB", "MA", "PA", "START"], - ["END", "PB", "PA"], - ] + edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]] expected_state_ids = ["PA", "MA", "MB"] @@ -232,10 +210,7 @@ def test_join_rule_evasion(self): ), ] - edges = [ - ["END", "JR", "START"], - ["END", "ME", "START"], - ] + edges = [["END", "JR", "START"], ["END", "ME", "START"]] expected_state_ids = ["JR"] @@ -248,45 +223,25 @@ def test_offtopic_pl(self): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] - edges = [ - ["END", "PC", "PB", "PA", "START"], - ["END", "PA"], - ] + edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]] expected_state_ids = ["PC"] @@ -295,68 +250,38 @@ def test_offtopic_pl(self): def test_topic_basic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), ] - edges = [ - ["END", "PA2", "T2", "PA1", "T1", "START"], - ["END", "T3", "PB", "PA1"], - ] + edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]] expected_state_ids = ["PA2", "T2"] @@ -365,30 +290,17 @@ def test_topic_basic(self): def test_topic_reset(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MB", @@ -399,10 +311,7 @@ def test_topic_reset(self): ), ] - edges = [ - ["END", "MB", "T2", "PA", "T1", "START"], - ["END", "T1"], - ] + edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]] expected_state_ids = ["T1", "MB", "PA"] @@ -411,61 +320,34 @@ def test_topic_reset(self): def test_topic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MZ1", @@ -475,11 +357,7 @@ def test_topic(self): content={}, ), FakeEvent( - id="T4", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), ] @@ -587,13 +465,7 @@ def do_check(self, events, edges, expected_state_ids): class LexicographicalTestCase(unittest.TestCase): def test_simple(self): - graph = { - "l": {"o"}, - "m": {"n", "o"}, - "n": {"o"}, - "o": set(), - "p": {"o"}, - } + graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -680,7 +552,13 @@ def setUp(self): self.expected_combined_state = { (e.type, e.state_key): e.event_id - for e in [create_event, alice_member, join_rules, bob_member, charlie_member] + for e in [ + create_event, + alice_member, + join_rules, + bob_member, + charlie_member, + ] } def test_event_map_none(self): @@ -720,11 +598,7 @@ def get_events(self, event_ids, allow_rejected=False): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return { - eid: self.event_map[eid] - for eid in event_ids - if eid in self.event_map - } + return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} def get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5568a607c774..fbb930269431 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,9 +9,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup - ) + hs = yield setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f18db8c38465..c778de1f0c05 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,10 +56,7 @@ def runWithConnection(func, *args, **kwargs): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False hs = TestHomeServer( - "test", - db_pool=self.db_pool, - config=config, - database_engine=fake_engine, + "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) self.datastore = SQLBaseStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 11fb8c0c198d..cd2bcd4ca301 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -20,7 +20,6 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6569a82bbbf..f458c03054f0 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -56,8 +56,7 @@ def test_initialise_reserved_users(self): self.store.register(user_id=user1, token="123", password_hash=None) self.store.register(user_id=user2, token="456", password_hash=None) self.store.register( - user_id=user3, token="789", - password_hash=None, user_type=UserTypes.SUPPORT + user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -173,9 +172,7 @@ def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) @@ -187,13 +184,9 @@ def test_populate_monthly_users_should_update(self): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed( - self.hs.get_clock().time_msec() - ) + return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.store.populate_monthly_active_users('user_id') self.pump() @@ -243,7 +236,7 @@ def test_support_user_not_add_to_mau_limits(self): user_id=support_user_id, token="123", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) self.store.upsert_monthly_active_user(support_user_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0fc5019e9f46..4823d44decfc 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -60,7 +60,7 @@ def inject_room_member( "state_key": user.to_string(), "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,7 +83,7 @@ def inject_message(self, room, user, body): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"body": body, "msgtype": u"message"}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -105,7 +105,7 @@ def inject_redaction(self, room, event_id, user, reason): "room_id": room.to_string(), "content": {"reason": reason}, "redacts": event_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index cb3cc4d2e507..c0e0155bb4a6 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -116,7 +116,7 @@ def test_is_support_user(self): user_id=SUPPORT_USER, token="456", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 063387863eb3..73ed943f5a39 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -58,7 +58,7 @@ def inject_room_member(self, room, user, membership, replaces_state=None): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"membership": membership}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 78e260a7faff..b6169436de49 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,6 @@ class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) @@ -57,7 +56,7 @@ def inject_state_event(self, room, sender, typ, state_key, content): "state_key": state_key, "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,15 +82,14 @@ def test_get_state_groups_ids(self): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups_ids( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - { - (EventTypes.Create, ''): e1.event_id, - (EventTypes.Name, ''): e2.event_id, - }, + {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, ) @defer.inlineCallbacks @@ -103,15 +101,11 @@ def test_get_state_groups(self): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups( - self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] - self.assertEqual( - {ev.event_id for ev in state_list}, - {e1.event_id, e2.event_id}, - ) + self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) @defer.inlineCallbacks def test_get_state_for_event(self): @@ -147,9 +141,7 @@ def test_get_state_for_event(self): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event( - e5.event_id, - ) + state = yield self.store.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -194,7 +186,7 @@ def test_get_state_for_event(self): state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, include_others=True, - ) + ), ) self.assertStateMapEqual( @@ -208,9 +200,9 @@ def test_get_state_for_event(self): # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True ), ) @@ -229,10 +221,10 @@ def test_get_state_for_event(self): # test _get_state_for_group_using_cache correctly filters out members # with types=[] (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -249,8 +241,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -263,8 +254,7 @@ def test_get_state_for_event(self): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -281,8 +271,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -302,8 +291,7 @@ def test_get_state_for_event(self): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -320,8 +308,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -334,8 +321,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -384,10 +370,10 @@ def test_get_state_for_event(self): # with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -399,8 +385,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -413,8 +398,7 @@ def test_get_state_for_event(self): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -425,8 +409,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -445,8 +428,7 @@ def test_get_state_for_event(self): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -457,8 +439,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -471,8 +452,7 @@ def test_get_state_for_event(self): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -483,8 +463,7 @@ def test_get_state_for_event(self): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index fd3361404fc3..d7d244ce97d7 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -36,9 +36,7 @@ def setUp(self): yield self.store.update_profile_in_user_dir(ALICE, "alice", None) yield self.store.update_profile_in_user_dir(BOB, "bob", None) yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms( - "!room:id", (ALICE, BOB) - ) + yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 4c8f87e95813..8b2741d27704 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,9 @@ def test_random_users_cannot_send_state_before_first_pl(self): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(creator), auth_events, + RoomVersions.V1.identifier, + _random_state_event(creator), + auth_events, do_sig_check=False, ) @@ -82,7 +84,9 @@ def test_state_default_level(self): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(king), auth_events, + RoomVersions.V1.identifier, + _random_state_event(king), + auth_events, do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 1a5dc32c88ae..6a8339b5614f 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,4 +1,3 @@ - from mock import Mock from twisted.internet.defer import maybeDeferred, succeed diff --git a/tests/test_mau.py b/tests/test_mau.py index 00be1a8c21f4..1fbe0d51fffc 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), + "red", http_client=None, federation_client=Mock() ) self.store = self.hs.get_datastore() @@ -210,9 +208,7 @@ def create_user(self, localpart): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request( - "GET", "/sync", access_token=token - ) + request, channel = self.make_request("GET", "/sync", access_token=token) self.render(request) if channel.code != 200: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 0ff6d0e283cb..2edbae5c6d7b 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,9 +44,7 @@ def get_sample_labels_value(sample): class TestMauLimit(unittest.TestCase): def test_basic(self): gauge = InFlightGauge( - "test1", "", - labels=["test_label"], - sub_metrics=["foo", "bar"], + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] ) def handle1(metrics): @@ -59,37 +57,49 @@ def handle2(metrics): gauge.register(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.register(("key1",), handle1) gauge.register(("key2",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key2",), handle2) gauge.register(("key1",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) def get_metrics_from_gauge(self, gauge): results = {} diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0968e86a7b03..f412985d2c8c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -69,10 +69,10 @@ def test_ui_auth(self): "name": "My Cool Privacy Policy", "url": "https://example.org/_matrix/consent?v=1.0", }, - "version": "1.0" - }, - }, - }, + "version": "1.0", + } + } + } } self.assertIsInstance(channel.json_body["params"], dict) self.assertDictContainsSubset(channel.json_body["params"], expected_params) diff --git a/tests/test_types.py b/tests/test_types.py index d314a7ff58ec..d83c36559fb2 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -94,8 +94,7 @@ def testUpperCase(self): def testSymbols(self): self.assertEqual( - map_username_to_mxid_localpart("test=$?_1234"), - "test=3d=24=3f_1234", + map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) def testLeadingUnderscore(self): @@ -105,6 +104,5 @@ def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), - "t=c3=aast", + map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index d0bc8e2112cd..fde0baee8ee6 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -22,6 +22,7 @@ class ToTwistedHandler(logging.Handler): """logging handler which sends the logs to the twisted log""" + tx_log = twisted.logger.Logger() def emit(self, record): @@ -41,7 +42,8 @@ def setup_logging(): root_logger = logging.getLogger() log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" + "%(asctime)s - %(name)s - %(lineno)d - " + "%(levelname)s - %(request)s - %(message)s" ) handler = ToTwistedHandler() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 3bdb500514fc..6a180ddc3229 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -132,7 +132,7 @@ def inject_visibility(self, user_id, visibility): "state_key": "", "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -153,7 +153,7 @@ def inject_room_member(self, user_id, membership="join", extra_content={}): "state_key": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -174,7 +174,7 @@ def inject_message(self, user_id, content=None): "sender": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/unittest.py b/tests/unittest.py index 029a88d770a7..94df8cf47ebf 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -84,9 +84,8 @@ def setUp(orig): # all future bets are off. if LoggingContext.current_context() is not LoggingContext.sentinel: self.fail( - "Test starting with non-sentinel logging context %s" % ( - LoggingContext.current_context(), - ) + "Test starting with non-sentinel logging context %s" + % (LoggingContext.current_context(),) ) old_level = logging.getLogger().level @@ -300,7 +299,13 @@ def make_request( content = json.dumps(content).encode('utf8') return make_request( - self.reactor, method, path, content, access_token, request, shorthand, + self.reactor, + method, + path, + content, + access_token, + request, + shorthand, federation_auth_origin, ) diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index 84dd71e47a96..bf85d3b8ecff 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -42,10 +42,10 @@ def canceller(_d): self.assertNoResult(timing_out_d) self.assertFalse(cancelled[0], "deferred was cancelled prematurely") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_times_out_when_canceller_throws(self): """Test that we have successfully worked around @@ -59,9 +59,9 @@ def canceller(_d): self.assertNoResult(timing_out_d) - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_logcontext_is_preserved_on_cancellation(self): blocking_was_cancelled = [False] @@ -80,10 +80,10 @@ def blocking(): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), context_one, - "errback %s run in unexpected logcontext %s" % ( - deferred_name, LoggingContext.current_context(), - ) + LoggingContext.current_context(), + context_one, + "errback %s run in unexpected logcontext %s" + % (deferred_name, LoggingContext.current_context()), ) return res @@ -94,11 +94,10 @@ def errback(res, deferred_name): self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) timing_out_d.addErrback(errback, "timingout") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], - "non-completing deferred was not cancelled", + blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/utils.py b/tests/utils.py index cb75514851ea..c2ef4b0bb580 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,7 +68,9 @@ def setupdb(): # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -94,7 +96,9 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -114,7 +118,6 @@ def default_config(name): "server_name": name, "media_store_path": "media", "uploads_path": "uploads", - # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", From ee90c06e38c16f9ca6d5e01c081ee99720fafafb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20M=C3=BCller?= Date: Fri, 10 May 2019 10:09:25 +0200 Subject: [PATCH 6/9] Set syslog identifiers in systemd units (#5023) --- changelog.d/5023.feature | 3 +++ .../system/matrix-synapse-worker@.service | 1 + contrib/systemd-with-workers/system/matrix-synapse.service | 1 + contrib/systemd/matrix-synapse.service | 2 +- debian/changelog | 7 +++++++ debian/matrix-synapse.service | 1 + 6 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/5023.feature diff --git a/changelog.d/5023.feature b/changelog.d/5023.feature new file mode 100644 index 000000000000..98551f79a323 --- /dev/null +++ b/changelog.d/5023.feature @@ -0,0 +1,3 @@ +Configure the example systemd units to have a log identifier of `matrix-synapse` +instead of the executable name, `python`. +Contributed by Christoph Müller. diff --git a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service index 912984b9d2f1..9d980d516881 100644 --- a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service +++ b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service @@ -12,6 +12,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.%i --config-path=/ ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse-%i [Install] WantedBy=matrix-synapse.service diff --git a/contrib/systemd-with-workers/system/matrix-synapse.service b/contrib/systemd-with-workers/system/matrix-synapse.service index 8bb4e400dc12..3aae19034cb8 100644 --- a/contrib/systemd-with-workers/system/matrix-synapse.service +++ b/contrib/systemd-with-workers/system/matrix-synapse.service @@ -11,6 +11,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.homeserver --confi ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse [Install] WantedBy=matrix.target diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service index efb157e941ae..595b69916c59 100644 --- a/contrib/systemd/matrix-synapse.service +++ b/contrib/systemd/matrix-synapse.service @@ -22,10 +22,10 @@ Group=nogroup WorkingDirectory=/opt/synapse ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml +SyslogIdentifier=matrix-synapse # adjust the cache factor if necessary # Environment=SYNAPSE_CACHE_FACTOR=2.0 [Install] WantedBy=multi-user.target - diff --git a/debian/changelog b/debian/changelog index c25425bf2611..454fa8eb164a 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +matrix-synapse-py3 (0.99.3.2+nmu1) UNRELEASED; urgency=medium + + [ Christoph Müller ] + * Configure the systemd units to have a log identifier of `matrix-synapse` + + -- Christoph Müller Wed, 17 Apr 2019 16:17:32 +0200 + matrix-synapse-py3 (0.99.3.2) stable; urgency=medium * New synapse release 0.99.3.2. diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service index 942e4b83fef9..b0a8d72e6d25 100644 --- a/debian/matrix-synapse.service +++ b/debian/matrix-synapse.service @@ -11,6 +11,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.homeserver --confi ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse [Install] WantedBy=multi-user.target From cd3f30014ad7bd71b4c86a34487a7719d3ac8caf Mon Sep 17 00:00:00 2001 From: Gergely Polonkai Date: Fri, 10 May 2019 10:15:08 +0200 Subject: [PATCH 7/9] Make Prometheus snippet less confusing on the metrics collection doc (#4288) Signed-off-by: Gergely Polonkai --- docs/metrics-howto.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst index 5bbb5a4f3a22..32b064e2da22 100644 --- a/docs/metrics-howto.rst +++ b/docs/metrics-howto.rst @@ -48,7 +48,10 @@ How to monitor Synapse metrics using Prometheus - job_name: "synapse" metrics_path: "/_synapse/metrics" static_configs: - - targets: ["my.server.here:9092"] + - targets: ["my.server.here:port"] + + where ``my.server.here`` is the IP address of Synapse, and ``port`` is the listener port + configured with the ``metrics`` resource. If your prometheus is older than 1.5.2, you will need to replace ``static_configs`` in the above with ``target_groups``. From 085ae346ace418e0fc043ac5f568f85ebf80038e Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 10 May 2019 10:52:24 +0100 Subject: [PATCH 8/9] Add a DUMMY stage to captcha-only registration flow This allows the client to complete the email last which is more natual for the user. Without this stage, if the client would complete the recaptcha (and terms, if enabled) stages and then the registration request would complete because you've now completed a flow, even if you were intending to complete the flow that's the same except has email auth at the end. Adding a dummy auth stage to the recaptcha-only flow means it's always unambiguous which flow the client was trying to complete. Longer term we should think about changing the protocol so the client explicitly says which flow it's trying to complete. https://github.com/vector-im/riot-web/issues/9586 --- synapse/rest/client/v2_alpha/register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index dc3e265bcd1d..95578b76aea6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -345,7 +345,7 @@ def on_POST(self, request): if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: - flows.extend([[LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) # only support the email-only flow if we don't require MSISDN 3PIDs if not require_msisdn: flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) From c2bb7476c93d1f5026a4c8b715b2f70316d3a506 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 10 May 2019 11:08:01 +0100 Subject: [PATCH 9/9] Revert 085ae346ace418e0fc043ac5f568f85ebf80038e Accidentally went straight to develop --- synapse/rest/client/v2_alpha/register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 95578b76aea6..dc3e265bcd1d 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -345,7 +345,7 @@ def on_POST(self, request): if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: - flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) + flows.extend([[LoginType.RECAPTCHA]]) # only support the email-only flow if we don't require MSISDN 3PIDs if not require_msisdn: flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])