From 64d8465a4ebb7af3b7be553fb5bbfde11935a77e Mon Sep 17 00:00:00 2001 From: Simon Morgan Date: Thu, 4 Aug 2022 12:32:21 +0100 Subject: [PATCH 1/4] Es 7755 (#1) * ES-7755 Simple HTTPX implementation * ES-7755 Add http2 extras to httpx --- apns2/client.py | 181 ++++++++++--------------------------------- apns2/credentials.py | 26 +------ pyproject.toml | 4 +- 3 files changed, 45 insertions(+), 166 deletions(-) diff --git a/apns2/client.py b/apns2/client.py index 0947350..e2f751c 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -1,15 +1,12 @@ import collections +import httpx import json import logging -import time -import typing -import weakref from enum import Enum -from threading import Thread from typing import Dict, Iterable, Optional, Tuple, Union -from .credentials import CertificateCredentials, Credentials -from .errors import ConnectionFailed, exception_class_for_reason +from .credentials import TokenCredentials +from .errors import exception_class_for_reason # We don't generally need to know about the Credentials subclasses except to # keep the old API, where APNsClient took a cert_file from .payload import Payload @@ -29,7 +26,7 @@ class NotificationType(Enum): MDM = 'mdm' -RequestStream = collections.namedtuple('RequestStream', ['stream_id', 'token']) +RequestStream = collections.namedtuple('RequestStream', ['token', 'status', 'reason']) Notification = collections.namedtuple('Notification', ['token', 'payload']) DEFAULT_APNS_PRIORITY = NotificationPriority.Immediate @@ -47,62 +44,40 @@ class APNsClient(object): ALTERNATIVE_PORT = 2197 def __init__(self, - credentials: Union[Credentials, str], + credentials: TokenCredentials, use_sandbox: bool = False, use_alternative_port: bool = False, proto: Optional[str] = None, json_encoder: Optional[type] = None, password: Optional[str] = None, proxy_host: Optional[str] = None, proxy_port: Optional[int] = None, heartbeat_period: Optional[float] = None) -> None: - if isinstance(credentials, str): - self.__credentials = CertificateCredentials(credentials, password) # type: Credentials - else: - self.__credentials = credentials + + self.__credentials = credentials self._init_connection(use_sandbox, use_alternative_port, proto, proxy_host, proxy_port) if heartbeat_period: - self._start_heartbeat(heartbeat_period) + raise NotImplementedError("heartbeat not supported") self.__json_encoder = json_encoder - self.__max_concurrent_streams = 0 - self.__previous_server_max_concurrent_streams = None def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: Optional[str], proxy_host: Optional[str], proxy_port: Optional[int]) -> None: - server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER - port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT - self._connection = self.__credentials.create_connection(server, port, proto, proxy_host, proxy_port) - - def _start_heartbeat(self, heartbeat_period: float) -> None: - conn_ref = weakref.ref(self._connection) - - def watchdog() -> None: - while True: - conn = conn_ref() - if conn is None: - break - - conn.ping('-' * 8) - time.sleep(heartbeat_period) - - thread = Thread(target=watchdog) - thread.setDaemon(True) - thread.start() + self.__server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER + self.__port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, expiration: Optional[int] = None, collapse_id: Optional[str] = None) -> None: - stream_id = self.send_notification_async(token_hex, notification, topic, priority, expiration, collapse_id) - result = self.get_notification_result(stream_id) - if result != 'Success': - if isinstance(result, tuple): - reason, info = result - raise exception_class_for_reason(reason)(info) - else: - raise exception_class_for_reason(result) - - def send_notification_async(self, token_hex: str, notification: Payload, topic: Optional[str] = None, - priority: NotificationPriority = NotificationPriority.Immediate, - expiration: Optional[int] = None, collapse_id: Optional[str] = None, - push_type: Optional[NotificationType] = None) -> int: + with httpx.Client(http2=True) as client: + status, reason = self.send_notification_sync(token_hex, notification, client, topic, priority, expiration, + collapse_id) + + if status != 200: + raise exception_class_for_reason(reason) + + def send_notification_sync(self, token_hex: str, notification: Payload, client: httpx.Client, + topic: Optional[str] = None, + priority: NotificationPriority = NotificationPriority.Immediate, + expiration: Optional[int] = None, collapse_id: Optional[str] = None, + push_type: Optional[NotificationType] = None) -> int: json_str = json.dumps(notification.dict(), cls=self.__json_encoder, ensure_ascii=False, separators=(',', ':')) json_payload = json_str.encode('utf-8') @@ -146,120 +121,48 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic: headers['apns-collapse-id'] = collapse_id url = '/3/device/{}'.format(token_hex) - stream_id = self._connection.request('POST', url, json_payload, headers) # type: int - return stream_id + response = client.post(f"https://{self.__server}:{self.__port}{url}", headers=headers, data=json_payload) + return response.status_code, response.text - def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]: + def get_notification_result(self, status: int, reason: str) -> Union[str, Tuple[str, str]]: """ Get result for specified stream - The function returns: 'Success' or 'failure reason' or ('Unregistered', timestamp) + The function returns: 'Success' or 'failure reason' """ - with self._connection.get_response(stream_id) as response: - if response.status == 200: - return 'Success' - else: - raw_data = response.read().decode('utf-8') - data = json.loads(raw_data) # type: Dict[str, str] - if response.status == 410: - return data['reason'], data['timestamp'] - else: - return data['reason'] + if status == 200: + return 'Success' + else: + return reason def send_notification_batch(self, notifications: Iterable[Notification], topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, expiration: Optional[int] = None, collapse_id: Optional[str] = None, push_type: Optional[NotificationType] = None) -> Dict[str, Union[str, Tuple[str, str]]]: """ - Send a notification to a list of tokens in batch. Instead of sending a synchronous request - for each token, send multiple requests concurrently. This is done on the same connection, - using HTTP/2 streams (one request per stream). - - APNs allows many streams simultaneously, but the number of streams can vary depending on - server load. This method reads the SETTINGS frame sent by the server to figure out the - maximum number of concurrent streams. Typically, APNs reports a maximum of 500. + Send a notification to a list of tokens in batch. The function returns a dictionary mapping each token to its result. The result is "Success" if the token was sent successfully, or the string returned by APNs in the 'reason' field of the response, if the token generated an error. """ - notification_iterator = iter(notifications) - next_notification = next(notification_iterator, None) - # Make sure we're connected to APNs, so that we receive and process the server's SETTINGS - # frame before starting to send notifications. - self.connect() - results = {} - open_streams = collections.deque() # type: typing.Deque[RequestStream] - # Loop on the tokens, sending as many requests as possible concurrently to APNs. - # When reaching the maximum concurrent streams limit, wait for a response before sending - # another request. - while len(open_streams) > 0 or next_notification is not None: - # Update the max_concurrent_streams on every iteration since a SETTINGS frame can be - # sent by the server at any time. - self.update_max_concurrent_streams() - if next_notification is not None and len(open_streams) < self.__max_concurrent_streams: + + # Loop over notifications + with httpx.Client(http2=True) as client: + for next_notification in notifications: logger.info('Sending to token %s', next_notification.token) - stream_id = self.send_notification_async(next_notification.token, next_notification.payload, topic, - priority, expiration, collapse_id, push_type) - open_streams.append(RequestStream(stream_id, next_notification.token)) - - next_notification = next(notification_iterator, None) - if next_notification is None: - # No tokens remaining. Proceed to get results for pending requests. - logger.info('Finished sending all tokens, waiting for pending requests.') - else: - # We have at least one request waiting for response (otherwise we would have either - # sent new requests or exited the while loop.) Wait for the first outstanding stream - # to return a response. - pending_stream = open_streams.popleft() - result = self.get_notification_result(pending_stream.stream_id) - logger.info('Got response for %s: %s', pending_stream.token, result) - results[pending_stream.token] = result + status, reason = self.send_notification_sync(next_notification.token, next_notification.payload, client, + topic, priority, expiration, collapse_id, push_type) + result = self.get_notification_result(status, reason) + logger.info('Got response for %s: %s', next_notification.token, result) + results[next_notification.token] = result return results - def update_max_concurrent_streams(self) -> None: - # Get the max_concurrent_streams setting returned by the server. - # The max_concurrent_streams value is saved in the H2Connection instance that must be - # accessed using a with statement in order to acquire a lock. - # pylint: disable=protected-access - with self._connection._conn as connection: - max_concurrent_streams = connection.remote_settings.max_concurrent_streams - - if max_concurrent_streams == self.__previous_server_max_concurrent_streams: - # The server hasn't issued an updated SETTINGS frame. - return - - self.__previous_server_max_concurrent_streams = max_concurrent_streams - # Handle and log unexpected values sent by APNs, just in case. - if max_concurrent_streams > CONCURRENT_STREAMS_SAFETY_MAXIMUM: - logger.warning('APNs max_concurrent_streams too high (%s), resorting to default maximum (%s)', - max_concurrent_streams, CONCURRENT_STREAMS_SAFETY_MAXIMUM) - self.__max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM - elif max_concurrent_streams < 1: - logger.warning('APNs reported max_concurrent_streams less than 1 (%s), using value of 1', - max_concurrent_streams) - self.__max_concurrent_streams = 1 - else: - logger.info('APNs set max_concurrent_streams to %s', max_concurrent_streams) - self.__max_concurrent_streams = max_concurrent_streams - def connect(self) -> None: """ Establish a connection to APNs. If already connected, the function does nothing. If the connection fails, the function retries up to MAX_CONNECTION_RETRIES times. """ - retries = 0 - while retries < MAX_CONNECTION_RETRIES: - # noinspection PyBroadException - try: - self._connection.connect() - logger.info('Connected to APNs') - return - except Exception: # pylint: disable=broad-except - # close the connnection, otherwise next connect() call would do nothing - self._connection.close() - retries += 1 - logger.exception('Failed connecting to APNs (attempt %s of %s)', retries, MAX_CONNECTION_RETRIES) - - raise ConnectionFailed() + # Not needed for HTTPX + logger.info('APNsClient.connect called') diff --git a/apns2/credentials.py b/apns2/credentials.py index 028093e..ae8ede7 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -3,43 +3,19 @@ import jwt -from hyper import HTTP20Connection # type: ignore -from hyper.tls import init_context # type: ignore - -if TYPE_CHECKING: - from hyper.ssl_compat import SSLContext # type: ignore - DEFAULT_TOKEN_LIFETIME = 2700 DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256' # Abstract Base class. This should not be instantiated directly. class Credentials(object): - def __init__(self, ssl_context: 'Optional[SSLContext]' = None) -> None: + def __init__(self): super().__init__() - self.__ssl_context = ssl_context - - # Creates a connection with the credentials, if available or necessary. - def create_connection(self, server: str, port: int, proto: Optional[str], proxy_host: Optional[str] = None, - proxy_port: Optional[int] = None) -> HTTP20Connection: - # self.__ssl_context may be none, and that's fine. - return HTTP20Connection(server, port, ssl_context=self.__ssl_context, force_proto=proto or 'h2', - secure=True, proxy_host=proxy_host, proxy_port=proxy_port) def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: return None -# Credentials subclass for certificate authentication -class CertificateCredentials(Credentials): - def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None, - cert_chain: Optional[str] = None) -> None: - ssl_context = init_context(cert=cert_file, cert_password=password) - if cert_chain: - ssl_context.load_cert_chain(cert_chain) - super(CertificateCredentials, self).__init__(ssl_context) - - # Credentials subclass for JWT token based authentication class TokenCredentials(Credentials): def __init__(self, auth_key_path: str, auth_key_id: str, team_id: str, diff --git a/pyproject.toml b/pyproject.toml index ac5145a..b9f5cb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool] [tool.poetry] name = "apns2" -version = "0.7.1" +version = "0.7.3" description = "A python library for interacting with the Apple Push Notification Service via HTTP/2 protocol" readme = 'README.md' authors = [ @@ -25,7 +25,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.7" cryptography = ">=1.7.2" -hyper = ">=0.7" +httpx = {version = ">=0.13.0", extras = ["http2"] } pyjwt = ">=2.0.0" [tool.poetry.dev-dependencies] From 97bb4c03942822740a6c746e9d21897c992f2195 Mon Sep 17 00:00:00 2001 From: Ori Avtalion Date: Tue, 16 Aug 2022 17:04:36 +0300 Subject: [PATCH 2/4] Add httpx authorization by credentials --- apns2/client.py | 23 ++++++++++++++--------- apns2/credentials.py | 14 ++++++++++++-- pyproject.toml | 5 +++-- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/apns2/client.py b/apns2/client.py index e2f751c..5b9a827 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Dict, Iterable, Optional, Tuple, Union -from .credentials import TokenCredentials +from .credentials import Credentials, CertificateCredentials, TokenCredentials from .errors import exception_class_for_reason # We don't generally need to know about the Credentials subclasses except to # keep the old API, where APNsClient took a cert_file @@ -44,13 +44,17 @@ class APNsClient(object): ALTERNATIVE_PORT = 2197 def __init__(self, - credentials: TokenCredentials, + credentials: Union[Credentials, str], use_sandbox: bool = False, use_alternative_port: bool = False, proto: Optional[str] = None, json_encoder: Optional[type] = None, password: Optional[str] = None, proxy_host: Optional[str] = None, proxy_port: Optional[int] = None, heartbeat_period: Optional[float] = None) -> None: - self.__credentials = credentials + if isinstance(credentials, str): + self.__credentials = CertificateCredentials(credentials, password) + else: + self.__credentials = credentials + self._init_connection(use_sandbox, use_alternative_port, proto, proxy_host, proxy_port) if heartbeat_period: @@ -113,15 +117,16 @@ def send_notification_sync(self, token_hex: str, notification: Payload, client: if expiration is not None: headers['apns-expiration'] = '%d' % expiration - auth_header = self.__credentials.get_authorization_header(topic) - if auth_header is not None: - headers['authorization'] = auth_header + if isinstance(self.__credentials, TokenCredentials): + auth_header = self.__credentials.get_authorization_header(topic) + if auth_header is not None: + headers['authorization'] = auth_header if collapse_id is not None: headers['apns-collapse-id'] = collapse_id - url = '/3/device/{}'.format(token_hex) - response = client.post(f"https://{self.__server}:{self.__port}{url}", headers=headers, data=json_payload) + url = f'https://{self.__server}:{self.__port}/3/device/{token_hex}' + response = client.post(url, headers=headers, data=json_payload) return response.status_code, response.text def get_notification_result(self, status: int, reason: str) -> Union[str, Tuple[str, str]]: @@ -148,7 +153,7 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic: results = {} # Loop over notifications - with httpx.Client(http2=True) as client: + with httpx.Client(http2=True, verify=self.__credentials.ssl_context) as client: for next_notification in notifications: logger.info('Sending to token %s', next_notification.token) status, reason = self.send_notification_sync(next_notification.token, next_notification.payload, client, diff --git a/apns2/credentials.py b/apns2/credentials.py index ae8ede7..d71bffb 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -1,5 +1,6 @@ +import ssl import time -from typing import Optional, Tuple, TYPE_CHECKING +from typing import Optional, Tuple import jwt @@ -9,13 +10,22 @@ # Abstract Base class. This should not be instantiated directly. class Credentials(object): - def __init__(self): + def __init__(self, ssl_context: Optional[ssl.SSLContext] = None) -> None: super().__init__() + self.ssl_context = ssl_context def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: return None +# Credentials subclass for certificate authentication +class CertificateCredentials(Credentials): + def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None) -> None: + ssl_context = ssl.create_default_context() + ssl_context.load_cert_chain(cert_file, password=password) + super(CertificateCredentials, self).__init__(ssl_context) + + # Credentials subclass for JWT token based authentication class TokenCredentials(Credentials): def __init__(self, auth_key_path: str, auth_key_id: str, team_id: str, diff --git a/pyproject.toml b/pyproject.toml index b9f5cb5..c400f91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool] [tool.poetry] name = "apns2" -version = "0.7.3" +version = "0.7.1" description = "A python library for interacting with the Apple Push Notification Service via HTTP/2 protocol" readme = 'README.md' authors = [ @@ -19,6 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries" ] @@ -49,7 +50,7 @@ disable = "missing-docstring, too-few-public-methods, locally-disabled, invalid- [tool.tox] legacy_tox_ini = """ [tox] -envlist = py37, py38, py39 +envlist = py37, py38, py39, py310 isolated_build = True [testenv] From 9bda01969eb5d62d82a378247d68178732de5ec5 Mon Sep 17 00:00:00 2001 From: Ori Avtalion Date: Thu, 14 Dec 2023 17:51:08 +0200 Subject: [PATCH 3/4] Fix missing ssl certificate in send_notification --- apns2/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apns2/client.py b/apns2/client.py index 5b9a827..abf9a6c 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -70,7 +70,7 @@ def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, expiration: Optional[int] = None, collapse_id: Optional[str] = None) -> None: - with httpx.Client(http2=True) as client: + with httpx.Client(http2=True, verify=self.__credentials.ssl_context) as client: status, reason = self.send_notification_sync(token_hex, notification, client, topic, priority, expiration, collapse_id) From 897f80e3f1d9c6582880ce1c52839d7784223c30 Mon Sep 17 00:00:00 2001 From: Ori Avtalion Date: Sat, 10 Feb 2024 11:45:04 +0200 Subject: [PATCH 4/4] Update signature of cert_file --- apns2/credentials.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/apns2/credentials.py b/apns2/credentials.py index d71bffb..fb74ba4 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -1,6 +1,7 @@ import ssl import time -from typing import Optional, Tuple +from os import PathLike +from typing import Optional, Tuple, Union import jwt @@ -20,7 +21,8 @@ def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: # Credentials subclass for certificate authentication class CertificateCredentials(Credentials): - def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None) -> None: + def __init__(self, cert_file: Optional[Union[str, bytes, PathLike[str], PathLike[bytes]]] = None, + password: Optional[str] = None) -> None: ssl_context = ssl.create_default_context() ssl_context.load_cert_chain(cert_file, password=password) super(CertificateCredentials, self).__init__(ssl_context)