Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve static type checking #333

Merged
merged 5 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/333.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve static type checking.
4 changes: 2 additions & 2 deletions sygnal/apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
]


def json_encode(payload) -> bytes:
def json_encode(payload: Dict[str, Any]) -> bytes:
return json.dumps(payload, ensure_ascii=False).encode()


Expand Down Expand Up @@ -115,7 +115,7 @@ def _choppables_for_aps(aps: Dict[str, Any]) -> List[Choppable]:
def _choppable_get(
aps: Dict[str, Any],
choppable: Choppable,
):
) -> str:
if choppable[0] == "alert":
return aps["alert"]
elif choppable[0] == "alert.body":
Expand Down
4 changes: 2 additions & 2 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def create(
return cls(name, sygnal, config)

async def _perform_http_request(
self, body: Dict, headers: Dict[AnyStr, List[AnyStr]]
self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]]
) -> Tuple[IResponse, str]:
"""
Perform an HTTP request to the FCM server with the body and headers
Expand Down Expand Up @@ -208,7 +208,7 @@ async def _request_dispatch(
self,
n: Notification,
log: NotificationLoggerAdapter,
body: dict,
body: Dict[str, Any],
headers: Dict[AnyStr, List[AnyStr]],
pushkeys: List[str],
span: Span,
Expand Down
26 changes: 13 additions & 13 deletions tests/test_apns.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it now possible to mark these files as disallow_untyped_defs = true in mypy.ini?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@


class ApnsTestCase(testutils.TestCase):
def setUp(self):
def setUp(self) -> None:
self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start()
self.apns_mock = MagicMock()
self.apns_mock_class.return_value = self.apns_mock
Expand All @@ -82,7 +82,7 @@ def get_test_pushkin(self, name: str) -> ApnsPushkin:
assert isinstance(test_pushkin, ApnsPushkin)
return test_pushkin

def config_setup(self, config):
def config_setup(self, config) -> None:
super().config_setup(config)
config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH}
config["apps"][PUSHKIN_ID_WITH_PUSH_TYPE] = {
Expand All @@ -91,7 +91,7 @@ def config_setup(self, config):
"push_type": "alert",
}

def test_payload_truncation(self):
def test_payload_truncation(self) -> None:
"""
Tests that APNS message bodies will be truncated to fit the limits of
APNS.
Expand All @@ -114,7 +114,7 @@ def test_payload_truncation(self):

self.assertLessEqual(len(apnstruncate.json_encode(payload)), 240)

def test_payload_truncation_test_validity(self):
def test_payload_truncation_test_validity(self) -> None:
"""
This tests that L{test_payload_truncation_success} is a valid test
by showing that not limiting the truncation size would result in a
Expand All @@ -138,7 +138,7 @@ def test_payload_truncation_test_validity(self):

self.assertGreater(len(apnstruncate.json_encode(payload)), 200)

def test_expected(self):
def test_expected(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_expected(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_event_id_only_with_default_payload(self):
def test_expected_event_id_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_expected_event_id_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_badge_only_with_default_payload(self):
def test_expected_badge_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_expected_badge_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_full_with_default_payload(self):
def test_expected_full_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_expected_full_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_misconfigured_payload_is_rejected(self):
def test_misconfigured_payload_is_rejected(self) -> None:
"""Test that a malformed default_payload causes pushkey to be rejected"""

resp = self._request(
Expand All @@ -294,7 +294,7 @@ def test_misconfigured_payload_is_rejected(self):

self.assertEqual({"rejected": ["badpayload"]}, resp)

def test_rejection(self):
def test_rejection(self) -> None:
"""
Tests the rejection case: a rejection response from APNS leads to us
passing on a rejection to the homeserver.
Expand All @@ -312,7 +312,7 @@ def test_rejection(self):
self.assertEqual(1, method.call_count)
self.assertEqual({"rejected": ["spqr"]}, resp)

def test_no_retry_on_4xx(self):
def test_no_retry_on_4xx(self) -> None:
"""
Test that we don't retry when we get a 4xx error but do not mark as
rejected.
Expand All @@ -330,7 +330,7 @@ def test_no_retry_on_4xx(self):
self.assertEqual(1, method.call_count)
self.assertEqual(502, resp)

def test_retry_on_5xx(self):
def test_retry_on_5xx(self) -> None:
"""
Test that we DO retry when we get a 5xx error and do not mark as
rejected.
Expand All @@ -348,7 +348,7 @@ def test_retry_on_5xx(self):
self.assertGreater(method.call_count, 1)
self.assertEqual(502, resp)

def test_expected_with_push_type(self):
def test_expected_with_push_type(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down
20 changes: 10 additions & 10 deletions tests/test_apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sygnal.apnstruncate import json_encode, truncate


def simplestring(length, offset=0):
def simplestring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string.
Args:
Expand All @@ -41,7 +41,7 @@ def simplestring(length, offset=0):
)


def sillystring(length, offset=0):
def sillystring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string
Args:
Expand All @@ -63,7 +63,7 @@ def payload_for_aps(aps):


class TruncateTestCase(unittest.TestCase):
def test_dont_truncate(self):
def test_dont_truncate(self) -> None:
"""
Tests that truncation is not performed if unnecessary.
"""
Expand All @@ -72,7 +72,7 @@ def test_dont_truncate(self):
aps = {"alert": txt}
self.assertEqual(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"])

def test_truncate_alert(self):
def test_truncate_alert(self) -> None:
"""
Tests that the 'alert' string field will be truncated when needed.
"""
Expand All @@ -83,7 +83,7 @@ def test_truncate_alert(self):
txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]
)

def test_truncate_alert_body(self):
def test_truncate_alert_body(self) -> None:
"""
Tests that the 'alert' 'body' field will be truncated when needed.
"""
Expand All @@ -95,7 +95,7 @@ def test_truncate_alert_body(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"],
)

def test_truncate_loc_arg(self):
def test_truncate_loc_arg(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with one loc arg)
Expand All @@ -108,7 +108,7 @@ def test_truncate_loc_arg(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0],
)

def test_truncate_loc_args(self):
def test_truncate_loc_args(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with two loc args)
Expand All @@ -130,7 +130,7 @@ def test_truncate_loc_args(self):
],
)

def test_python_unicode_support(self):
def test_python_unicode_support(self) -> None:
"""
Tests Python's unicode support :-
a one character unicode string should have a length of one, even if it's one
Expand All @@ -146,7 +146,7 @@ def test_python_unicode_support(self):
)
self.fail(msg)

def test_truncate_string_with_multibyte(self):
def test_truncate_string_with_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing one
multibyte character.
Expand All @@ -160,7 +160,7 @@ def test_truncate_string_with_multibyte(self):
txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"]
)

def test_truncate_multibyte(self):
def test_truncate_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing only
multibyte characters.
Expand Down
38 changes: 23 additions & 15 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple

from sygnal.gcmpushkin import GcmPushkin

from tests import testutils
from tests.testutils import DummyResponse

if TYPE_CHECKING:
from sygnal.sygnal import Sygnal

DEVICE_EXAMPLE = {"app_id": "com.example.gcm", "pushkey": "spqr", "pushkey_ts": 42}
DEVICE_EXAMPLE2 = {"app_id": "com.example.gcm", "pushkey": "spqr2", "pushkey_ts": 42}
DEVICE_EXAMPLE_WITH_DEFAULT_PAYLOAD = {
Expand Down Expand Up @@ -57,30 +61,34 @@ class TestGcmPushkin(GcmPushkin):
can be preloaded with virtual requests.
"""

def __init__(self, name, sygnal, config):
def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]):
super().__init__(name, sygnal, config)
self.preloaded_response = None
self.preloaded_response_payload = None
self.last_request_body = None
self.last_request_headers = None
self.preloaded_response = DummyResponse(0)
self.preloaded_response_payload: Dict[str, Any] = {}
self.last_request_body: Dict[str, Any] = {}
self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.num_requests = 0

def preload_with_response(self, code, response_payload):
def preload_with_response(
self, code: int, response_payload: Dict[str, Any]
) -> None:
"""
Preloads a fake GCM response.
"""
self.preloaded_response = DummyResponse(code)
self.preloaded_response_payload = response_payload

async def _perform_http_request(self, body, headers):
async def _perform_http_request( # type: ignore[override]
self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]]
) -> Tuple[DummyResponse, str]:
self.last_request_body = body
self.last_request_headers = headers
self.num_requests += 1
return self.preloaded_response, json.dumps(self.preloaded_response_payload)


class GcmTestCase(testutils.TestCase):
def config_setup(self, config):
def config_setup(self, config: Dict[str, Any]) -> None:
config["apps"]["com.example.gcm"] = {
"type": "tests.test_gcm.TestGcmPushkin",
"api_key": "kii",
Expand All @@ -96,7 +104,7 @@ def get_test_pushkin(self, name: str) -> TestGcmPushkin:
assert isinstance(pushkin, TestGcmPushkin)
return pushkin

def test_expected(self):
def test_expected(self) -> None:
"""
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
Expand All @@ -111,7 +119,7 @@ def test_expected(self):
self.assertEqual(resp, {"rejected": []})
self.assertEqual(gcm.num_requests, 1)

def test_expected_with_default_payload(self):
def test_expected_with_default_payload(self) -> None:
"""
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
Expand All @@ -128,7 +136,7 @@ def test_expected_with_default_payload(self):
self.assertEqual(resp, {"rejected": []})
self.assertEqual(gcm.num_requests, 1)

def test_misformed_default_payload_rejected(self):
def test_misformed_default_payload_rejected(self) -> None:
"""
Tests that a non-dict default_payload is rejected.
"""
Expand All @@ -144,7 +152,7 @@ def test_misformed_default_payload_rejected(self):
self.assertEqual(resp, {"rejected": ["badpayload"]})
self.assertEqual(gcm.num_requests, 0)

def test_rejected(self):
def test_rejected(self) -> None:
"""
Tests the rejected case: a pushkey rejected to GCM leads to Sygnal
informing the homeserver of the rejection.
Expand All @@ -159,7 +167,7 @@ def test_rejected(self):
self.assertEqual(resp, {"rejected": ["spqr"]})
self.assertEqual(gcm.num_requests, 1)

def test_batching(self):
def test_batching(self) -> None:
"""
Tests that multiple GCM devices have their notification delivered to GCM
together, instead of being delivered separately.
Expand All @@ -184,7 +192,7 @@ def test_batching(self):
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

def test_batching_individual_failure(self):
def test_batching_individual_failure(self) -> None:
"""
Tests that multiple GCM devices have their notification delivered to GCM
together, instead of being delivered separately,
Expand All @@ -211,7 +219,7 @@ def test_batching_individual_failure(self):
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

def test_fcm_options(self):
def test_fcm_options(self) -> None:
"""
Tests that the config option `fcm_options` allows setting a base layer
of options to pass to FCM, for example ones that would be needed for iOS.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_proxy_url_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class ProxyUrlTestCase(unittest.TestCase):
def test_decompose_http_proxy_url(self):
def test_decompose_http_proxy_url(self) -> None:
parts = decompose_http_proxy_url("http://example.org")
self.assertEqual(parts, HttpProxyUrl("example.org", 80, None))

Expand All @@ -35,7 +35,7 @@ def test_decompose_http_proxy_url(self):
parts, HttpProxyUrl("example.org", 8080, ("bob", "secretsquirrel"))
)

def test_decompose_username_only(self):
def test_decompose_username_only(self) -> None:
"""
We do not support usernames without passwords for now — this tests the
current behaviour, though (it ignores the username).
Expand All @@ -44,7 +44,7 @@ def test_decompose_username_only(self):
parts = decompose_http_proxy_url("http://[email protected]:8080")
self.assertEqual(parts, HttpProxyUrl("example.org", 8080, None))

def test_decompose_http_proxy_url_failure(self):
def test_decompose_http_proxy_url_failure(self) -> None:
# test that non-HTTP schemes raise an exception
self.assertRaises(
RuntimeError, lambda: decompose_http_proxy_url("ftp://example.org")
Expand Down