From 3aeca2588b79111a48a6083c88efc4d68a2cea19 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 16 Dec 2022 08:53:28 -0500 Subject: [PATCH] Add missing type hints to tests.config. (#14681) --- changelog.d/14681.misc | 1 + mypy.ini | 4 +- synapse/config/cache.py | 4 +- synapse/util/caches/lrucache.py | 9 +--- tests/config/test___main__.py | 6 +-- tests/config/test_background_update.py | 4 +- tests/config/test_base.py | 10 ++--- tests/config/test_cache.py | 57 ++++++++++++------------ tests/config/test_database.py | 2 +- tests/config/test_generate.py | 8 ++-- tests/config/test_load.py | 12 ++--- tests/config/test_ratelimiting.py | 2 +- tests/config/test_registration_config.py | 4 +- tests/config/test_room_directory.py | 4 +- tests/config/test_server.py | 18 ++++---- tests/config/test_tls.py | 53 +++++++++++++--------- tests/config/test_util.py | 2 +- tests/config/utils.py | 11 ++--- 18 files changed, 108 insertions(+), 103 deletions(-) create mode 100644 changelog.d/14681.misc diff --git a/changelog.d/14681.misc b/changelog.d/14681.misc new file mode 100644 index 000000000000..d44571b73149 --- /dev/null +++ b/changelog.d/14681.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 1a37414e581c..80fbcdfeabf0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -36,8 +36,6 @@ exclude = (?x) |tests/api/test_ratelimiting.py |tests/app/test_openid_listener.py |tests/appservice/test_scheduler.py - |tests/config/test_cache.py - |tests/config/test_tls.py |tests/crypto/test_keyring.py |tests/events/test_presence_router.py |tests/events/test_utils.py @@ -89,7 +87,7 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False -[mypy-tests.config.test_api] +[mypy-tests.config.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/synapse/config/cache.py b/synapse/config/cache.py index eb4194a5a91b..015b2a138e85 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -16,7 +16,7 @@ import os import re import threading -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Mapping, Optional import attr @@ -94,7 +94,7 @@ def add_resizable_cache( class CacheConfig(Config): section = "caches" - _environ = os.environ + _environ: Mapping[str, str] = os.environ event_cache_size: int cache_factors: Dict[str, float] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index dcf0eac3bf08..452d5d04c1c0 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -788,26 +788,21 @@ def __len__(self) -> int: def __contains__(self, key: KT) -> bool: return self.contains(key) - def set_cache_factor(self, factor: float) -> bool: + def set_cache_factor(self, factor: float) -> None: """ Set the cache factor for this individual cache. This will trigger a resize if it changes, which may require evicting items from the cache. - - Returns: - Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: - return False + return new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size if self._on_resize: self._on_resize() - return True - return False def __del__(self) -> None: # We're about to be deleted, so we make sure to clear up all the nodes diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py index b1c73d36124f..cb5d4b05c366 100644 --- a/tests/config/test___main__.py +++ b/tests/config/test___main__.py @@ -17,15 +17,15 @@ class ConfigMainFileTestCase(ConfigFileTestCase): - def test_executes_without_an_action(self): + def test_executes_without_an_action(self) -> None: self.generate_config() main(["", "-c", self.config_file]) - def test_read__error_if_key_not_found(self): + def test_read__error_if_key_not_found(self) -> None: self.generate_config() with self.assertRaises(SystemExit): main(["", "read", "foo.bar.hello", "-c", self.config_file]) - def test_read__passes_if_key_found(self): + def test_read__passes_if_key_found(self) -> None: self.generate_config() main(["", "read", "server.server_name", "-c", self.config_file]) diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py index 0c32c1ca299e..e4bad2ba6e5f 100644 --- a/tests/config/test_background_update.py +++ b/tests/config/test_background_update.py @@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase): # Tests that the default values in the config are correctly loaded. Note that the default # values are loaded when the corresponding config options are commented out, which is why there isn't # a config specified here. - def test_default_configuration(self): + def test_default_configuration(self) -> None: background_updater = BackgroundUpdater( self.hs, self.hs.get_datastores().main.db_pool ) @@ -46,7 +46,7 @@ def test_default_configuration(self): """ ) ) - def test_custom_configuration(self): + def test_custom_configuration(self) -> None: background_updater = BackgroundUpdater( self.hs, self.hs.get_datastores().main.db_pool ) diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 6a52f862f488..3fbfe6c1da2d 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -24,13 +24,13 @@ class BaseConfigTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # The root object needs a server property with a public_baseurl. root = Mock() root.server.public_baseurl = "http://test" self.config = Config(root) - def test_loading_missing_templates(self): + def test_loading_missing_templates(self) -> None: # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: @@ -50,7 +50,7 @@ def test_loading_missing_templates(self): "Template file did not contain our test string", ) - def test_loading_custom_templates(self): + def test_loading_custom_templates(self) -> None: # Use a temporary directory that exists on the system with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary bogus template file @@ -79,7 +79,7 @@ def test_loading_custom_templates(self): "Template file did not contain our test string", ) - def test_multiple_custom_template_directories(self): + def test_multiple_custom_template_directories(self) -> None: """Tests that directories are searched in the right order if multiple custom template directories are provided. """ @@ -137,7 +137,7 @@ def test_multiple_custom_template_directories(self): for td in tempdirs: td.cleanup() - def test_loading_template_from_nonexistent_custom_directory(self): + def test_loading_template_from_nonexistent_custom_directory(self) -> None: with self.assertRaises(ConfigError): self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index d2b3c299e354..96f66af328dd 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -13,26 +13,27 @@ # limitations under the License. from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase class CacheConfigTests(TestCase): - def setUp(self): + def setUp(self) -> None: # Reset caches before each test since there's global state involved. self.config = CacheConfig() self.config.reset() - def tearDown(self): + def tearDown(self) -> None: # Also reset the caches after each test to leave state pristine. self.config.reset() - def test_individual_caches_from_environ(self): + def test_individual_caches_from_environ(self) -> None: """ Individual cache factors will be loaded from the environment. """ - config = {} + config: JsonDict = {} self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", @@ -42,15 +43,15 @@ def test_individual_caches_from_environ(self): self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) - def test_config_overrides_environ(self): + def test_config_overrides_environ(self) -> None: """ Individual cache factors defined in the environment will take precedence over those in the config. """ - config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", - "SYNAPSE_CACHE_FACTOR_FOO": 1, + "SYNAPSE_CACHE_FACTOR_FOO": "1", } self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() @@ -60,104 +61,104 @@ def test_config_overrides_environ(self): {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) - def test_individual_instantiated_before_config_load(self): + def test_individual_instantiated_before_config_load(self) -> None: """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized once the config is loaded. """ - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) - config = {"caches": {"per_cache_factors": {"foo": 3}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}} self.config.read_config(config) self.config.resize_all_caches() self.assertEqual(cache.max_size, 300) - def test_individual_instantiated_after_config_load(self): + def test_individual_instantiated_after_config_load(self) -> None: """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the per_cache_factor if there is one. """ - config = {"caches": {"per_cache_factors": {"foo": 2}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) - def test_global_instantiated_before_config_load(self): + def test_global_instantiated_before_config_load(self) -> None: """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized to the new default cache size once the config is loaded. """ - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) - config = {"caches": {"global_factor": 4}} + config: JsonDict = {"caches": {"global_factor": 4}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() self.assertEqual(cache.max_size, 400) - def test_global_instantiated_after_config_load(self): + def test_global_instantiated_after_config_load(self) -> None: """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the global factor if there is no per-cache factor. """ - config = {"caches": {"global_factor": 1.5}} + config: JsonDict = {"caches": {"global_factor": 1.5}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) - def test_cache_with_asterisk_in_name(self): + def test_cache_with_asterisk_in_name(self) -> None: """Some caches have asterisks in their name, test that they are set correctly.""" - config = { + config: JsonDict = { "caches": { "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", - "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, + "SYNAPSE_CACHE_FACTOR_CACHE_B": "3", } self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache_a = LruCache(100) + cache_a: LruCache = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) self.assertEqual(cache_a.max_size, 200) - cache_b = LruCache(100) + cache_b: LruCache = LruCache(100) add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) self.assertEqual(cache_b.max_size, 300) - cache_c = LruCache(100) + cache_c: LruCache = LruCache(100) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) - def test_apply_cache_factor_from_config(self): + def test_apply_cache_factor_from_config(self) -> None: """Caches can disable applying cache factor updates, mainly used by event cache size. """ - config = {"caches": {"event_cache_size": "10k"}} + config: JsonDict = {"caches": {"event_cache_size": "10k"}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache( + cache: LruCache = LruCache( max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 9eca10bbe9b6..240277bcc6b0 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -20,7 +20,7 @@ class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly(self): + def test_database_configured_correctly(self) -> None: conf = yaml.safe_load( DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index fdfbb0e38e9c..3a023669320f 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -25,14 +25,14 @@ class ConfigGenerationTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dir = tempfile.mkdtemp() self.file = os.path.join(self.dir, "homeserver.yaml") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.dir) - def test_generate_config_generates_files(self): + def test_generate_config_generates_files(self) -> None: with redirect_stdout(StringIO()): HomeServerConfig.load_or_generate_config( "", @@ -56,7 +56,7 @@ def test_generate_config_generates_files(self): os.path.join(os.getcwd(), "homeserver.log"), ) - def assert_log_filename_is(self, log_config_file, expected): + def assert_log_filename_is(self, log_config_file: str, expected: str) -> None: with open(log_config_file) as f: config = f.read() # find the 'filename' line diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 69a4e9413b26..fcbe79cc7a28 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -21,14 +21,14 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): - def test_load_fails_if_server_name_missing(self): + def test_load_fails_if_server_name_missing(self) -> None: self.generate_config_and_remove_lines_containing("server_name") with self.assertRaises(ConfigError): HomeServerConfig.load_config("", ["-c", self.config_file]) with self.assertRaises(ConfigError): HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) - def test_generates_and_loads_macaroon_secret_key(self): + def test_generates_and_loads_macaroon_secret_key(self) -> None: self.generate_config() with open(self.config_file) as f: @@ -58,7 +58,7 @@ def test_generates_and_loads_macaroon_secret_key(self): "was: %r" % (config2.key.macaroon_secret_key,) ) - def test_load_succeeds_if_macaroon_secret_key_missing(self): + def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None: self.generate_config_and_remove_lines_containing("macaroon") config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) @@ -73,7 +73,7 @@ def test_load_succeeds_if_macaroon_secret_key_missing(self): config1.key.macaroon_secret_key, config3.key.macaroon_secret_key ) - def test_disable_registration(self): + def test_disable_registration(self) -> None: self.generate_config() self.add_lines_to_config( ["enable_registration: true", "disable_registration: true"] @@ -93,7 +93,7 @@ def test_disable_registration(self): assert config3 is not None self.assertTrue(config3.registration.enable_registration) - def test_stats_enabled(self): + def test_stats_enabled(self) -> None: self.generate_config_and_remove_lines_containing("enable_metrics") self.add_lines_to_config(["enable_metrics: true"]) @@ -101,7 +101,7 @@ def test_stats_enabled(self): config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) - def test_depreciated_identity_server_flag_throws_error(self): + def test_depreciated_identity_server_flag_throws_error(self) -> None: self.generate_config() # Needed to ensure that actual key/value pair added below don't end up on a line with a comment self.add_lines_to_config([" "]) diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index 1b63e1adfd36..f12147eaa000 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -18,7 +18,7 @@ class RatelimitConfigTestCase(TestCase): - def test_parse_rc_federation(self): + def test_parse_rc_federation(self) -> None: config_dict = default_config("test") config_dict["rc_federation"] = { "window_size": 20000, diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py index 33d7b70e32ed..f6869d7f0645 100644 --- a/tests/config/test_registration_config.py +++ b/tests/config/test_registration_config.py @@ -21,7 +21,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase): - def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None: """ session_lifetime should logically be larger than, or at least as large as, all the different token lifetimes. @@ -91,7 +91,7 @@ def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): "", ) - def test_refuse_to_start_if_open_registration_and_no_verification(self): + def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None: self.generate_config() self.add_lines_to_config( [ diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index db745815eff2..297ab377921a 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -20,7 +20,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): - def test_alias_creation_acl(self): + def test_alias_creation_acl(self) -> None: config = yaml.safe_load( """ alias_creation_rules: @@ -78,7 +78,7 @@ def test_alias_creation_acl(self): ) ) - def test_room_publish_acl(self): + def test_room_publish_acl(self) -> None: config = yaml.safe_load( """ alias_creation_rules: [] diff --git a/tests/config/test_server.py b/tests/config/test_server.py index 1f27a5470189..41a3fb0b6db6 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -21,7 +21,7 @@ class ServerConfigTestCase(unittest.TestCase): - def test_is_threepid_reserved(self): + def test_is_threepid_reserved(self) -> None: user1 = {"medium": "email", "address": "user1@example.com"} user2 = {"medium": "email", "address": "user2@example.com"} user3 = {"medium": "email", "address": "user3@example.com"} @@ -32,7 +32,7 @@ def test_is_threepid_reserved(self): self.assertFalse(is_threepid_reserved(config, user3)) self.assertFalse(is_threepid_reserved(config, user1_msisdn)) - def test_unsecure_listener_no_listeners_open_private_ports_false(self): + def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None: conf = yaml.safe_load( ServerConfig().generate_config_section( "CONFDIR", "/data_dir_path", "che.org", False, None @@ -52,7 +52,7 @@ def test_unsecure_listener_no_listeners_open_private_ports_false(self): self.assertEqual(conf["listeners"], expected_listeners) - def test_unsecure_listener_no_listeners_open_private_ports_true(self): + def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None: conf = yaml.safe_load( ServerConfig().generate_config_section( "CONFDIR", "/data_dir_path", "che.org", True, None @@ -71,7 +71,7 @@ def test_unsecure_listener_no_listeners_open_private_ports_true(self): self.assertEqual(conf["listeners"], expected_listeners) - def test_listeners_set_correctly_open_private_ports_false(self): + def test_listeners_set_correctly_open_private_ports_false(self) -> None: listeners = [ { "port": 8448, @@ -95,7 +95,7 @@ def test_listeners_set_correctly_open_private_ports_false(self): self.assertEqual(conf["listeners"], listeners) - def test_listeners_set_correctly_open_private_ports_true(self): + def test_listeners_set_correctly_open_private_ports_true(self) -> None: listeners = [ { "port": 8448, @@ -131,14 +131,14 @@ def test_listeners_set_correctly_open_private_ports_true(self): class GenerateIpSetTestCase(unittest.TestCase): - def test_empty(self): + def test_empty(self) -> None: ip_set = generate_ip_set(()) self.assertFalse(ip_set) ip_set = generate_ip_set((), ()) self.assertFalse(ip_set) - def test_generate(self): + def test_generate(self) -> None: """Check adding IPv4 and IPv6 addresses.""" # IPv4 address ip_set = generate_ip_set(("1.2.3.4",)) @@ -160,7 +160,7 @@ def test_generate(self): ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4")) self.assertEqual(len(ip_set.iter_cidrs()), 4) - def test_extra(self): + def test_extra(self) -> None: """Extra IP addresses are treated the same.""" ip_set = generate_ip_set((), ("1.2.3.4",)) self.assertEqual(len(ip_set.iter_cidrs()), 4) @@ -172,7 +172,7 @@ def test_extra(self): ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",)) self.assertEqual(len(ip_set.iter_cidrs()), 4) - def test_bad_value(self): + def test_bad_value(self) -> None: """An error should be raised if a bad value is passed in.""" with self.assertRaises(ConfigError): generate_ip_set(("not-an-ip",)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 9ba5781573fc..7510fc464333 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -13,13 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast + import idna from OpenSSL import SSL from synapse.config._base import Config, RootConfig +from synapse.config.homeserver import HomeServerConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import FederationPolicyForHTTPS +from synapse.crypto.context_factory import ( + FederationPolicyForHTTPS, + SSLClientConnectionCreator, +) +from synapse.types import JsonDict from tests.unittest import TestCase @@ -27,7 +34,7 @@ class FakeServer(Config): section = "server" - def has_tls_listener(self): + def has_tls_listener(self) -> bool: return False @@ -36,21 +43,21 @@ class TestConfig(RootConfig): class TLSConfigTests(TestCase): - def test_tls_client_minimum_default(self): + def test_tls_client_minimum_default(self) -> None: """ The default client TLS version is 1.0. """ - config = {} + config: JsonDict = {} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") - def test_tls_client_minimum_set(self): + def test_tls_client_minimum_set(self) -> None: """ The default client TLS version can be set to 1.0, 1.1, and 1.2. """ - config = {"federation_client_minimum_tls_version": 1} + config: JsonDict = {"federation_client_minimum_tls_version": 1} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") @@ -76,7 +83,7 @@ def test_tls_client_minimum_set(self): t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") - def test_tls_client_minimum_1_point_3_missing(self): + def test_tls_client_minimum_1_point_3_missing(self) -> None: """ If TLS 1.3 support is missing and it's configured, it will raise a ConfigError. @@ -88,7 +95,7 @@ def test_tls_client_minimum_1_point_3_missing(self): self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3) assert not hasattr(SSL, "OP_NO_TLSv1_3") - config = {"federation_client_minimum_tls_version": 1.3} + config: JsonDict = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: t.tls.read_config(config, config_dir_path="", data_dir_path="") @@ -100,7 +107,7 @@ def test_tls_client_minimum_1_point_3_missing(self): ), ) - def test_tls_client_minimum_1_point_3_exists(self): + def test_tls_client_minimum_1_point_3_exists(self) -> None: """ If TLS 1.3 support exists and it's configured, it will be settable. """ @@ -110,20 +117,20 @@ def test_tls_client_minimum_1_point_3_exists(self): self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3")) assert hasattr(SSL, "OP_NO_TLSv1_3") - config = {"federation_client_minimum_tls_version": 1.3} + config: JsonDict = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") - def test_tls_client_minimum_set_passed_through_1_2(self): + def test_tls_client_minimum_set_passed_through_1_2(self) -> None: """ The configured TLS version is correctly configured by the ContextFactory. """ - config = {"federation_client_minimum_tls_version": 1.2} + config: JsonDict = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 @@ -131,15 +138,15 @@ def test_tls_client_minimum_set_passed_through_1_2(self): self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) - def test_tls_client_minimum_set_passed_through_1_0(self): + def test_tls_client_minimum_set_passed_through_1_0(self) -> None: """ The configured TLS version is correctly configured by the ContextFactory. """ - config = {"federation_client_minimum_tls_version": 1} + config: JsonDict = {"federation_client_minimum_tls_version": 1} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. @@ -147,11 +154,11 @@ def test_tls_client_minimum_set_passed_through_1_0(self): self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) - def test_whitelist_idna_failure(self): + def test_whitelist_idna_failure(self) -> None: """ The federation certificate whitelist will not allow IDNA domain names. """ - config = { + config: JsonDict = { "federation_certificate_verification_whitelist": [ "example.com", "*.ドメイン.テスト", @@ -163,11 +170,11 @@ def test_whitelist_idna_failure(self): ) self.assertIn("IDNA domain names", str(e)) - def test_whitelist_idna_result(self): + def test_whitelist_idna_result(self) -> None: """ The federation certificate whitelist will match on IDNA encoded names. """ - config = { + config: JsonDict = { "federation_certificate_verification_whitelist": [ "example.com", "*.xn--eckwd4c7c.xn--zckzah", @@ -176,14 +183,16 @@ def test_whitelist_idna_result(self): t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) # Not in the whitelist opts = cf.get_options(b"notexample.com") + assert isinstance(opts, SSLClientConnectionCreator) self.assertTrue(opts._verifier._verify_certs) # Caught by the wildcard opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) + assert isinstance(opts, SSLClientConnectionCreator) self.assertFalse(opts._verifier._verify_certs) @@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int: """get the options bits from an openssl context object""" # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to # use the low-level interface - return SSL._lib.SSL_CTX_get_options(ssl_context._context) + return SSL._lib.SSL_CTX_get_options(ssl_context._context) # type: ignore[attr-defined] diff --git a/tests/config/test_util.py b/tests/config/test_util.py index 3d4929daacf0..7073654832e1 100644 --- a/tests/config/test_util.py +++ b/tests/config/test_util.py @@ -21,7 +21,7 @@ class ValidateConfigTestCase(TestCase): """Test cases for synapse.config._util.validate_config""" - def test_bad_object_in_array(self): + def test_bad_object_in_array(self) -> None: """malformed objects within an array should be validated correctly""" # consider a structure: diff --git a/tests/config/utils.py b/tests/config/utils.py index 94c18a052ba4..4c0e8a064a6c 100644 --- a/tests/config/utils.py +++ b/tests/config/utils.py @@ -17,19 +17,20 @@ import unittest from contextlib import redirect_stdout from io import StringIO +from typing import List from synapse.config.homeserver import HomeServerConfig class ConfigFileTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dir = tempfile.mkdtemp() self.config_file = os.path.join(self.dir, "homeserver.yaml") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.dir) - def generate_config(self): + def generate_config(self) -> None: with redirect_stdout(StringIO()): HomeServerConfig.load_or_generate_config( "", @@ -43,7 +44,7 @@ def generate_config(self): ], ) - def generate_config_and_remove_lines_containing(self, needle): + def generate_config_and_remove_lines_containing(self, needle: str) -> None: self.generate_config() with open(self.config_file) as f: @@ -52,7 +53,7 @@ def generate_config_and_remove_lines_containing(self, needle): with open(self.config_file, "w") as f: f.write("".join(contents)) - def add_lines_to_config(self, lines): + def add_lines_to_config(self, lines: List[str]) -> None: with open(self.config_file, "a") as f: for line in lines: f.write(line + "\n")