diff --git a/apns2/client.py b/apns2/client.py
index 0947350..abf9a6c 100644
--- a/apns2/client.py
+++ b/apns2/client.py
@@ -1,15 +1,12 @@
 import collections
+import httpx
 import json
 import logging
-import time
-import typing
-import weakref
 from enum import Enum
-from threading import Thread
 from typing import Dict, Iterable, Optional, Tuple, Union
 
-from .credentials import CertificateCredentials, Credentials
-from .errors import ConnectionFailed, exception_class_for_reason
+from .credentials import Credentials, CertificateCredentials, TokenCredentials
+from .errors import exception_class_for_reason
 # We don't generally need to know about the Credentials subclasses except to
 # keep the old API, where APNsClient took a cert_file
 from .payload import Payload
@@ -29,7 +26,7 @@ class NotificationType(Enum):
     MDM = 'mdm'
 
 
-RequestStream = collections.namedtuple('RequestStream', ['stream_id', 'token'])
+RequestStream = collections.namedtuple('RequestStream', ['token', 'status', 'reason'])
 Notification = collections.namedtuple('Notification', ['token', 'payload'])
 
 DEFAULT_APNS_PRIORITY = NotificationPriority.Immediate
@@ -52,57 +49,39 @@ def __init__(self,
                  json_encoder: Optional[type] = None, password: Optional[str] = None,
                  proxy_host: Optional[str] = None, proxy_port: Optional[int] = None,
                  heartbeat_period: Optional[float] = None) -> None:
+
         if isinstance(credentials, str):
-            self.__credentials = CertificateCredentials(credentials, password)  # type: Credentials
+            self.__credentials = CertificateCredentials(credentials, password)
         else:
             self.__credentials = credentials
+
         self._init_connection(use_sandbox, use_alternative_port, proto, proxy_host, proxy_port)
 
         if heartbeat_period:
-            self._start_heartbeat(heartbeat_period)
+            raise NotImplementedError("heartbeat not supported")
 
         self.__json_encoder = json_encoder
-        self.__max_concurrent_streams = 0
-        self.__previous_server_max_concurrent_streams = None
 
     def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: Optional[str],
                          proxy_host: Optional[str], proxy_port: Optional[int]) -> None:
-        server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER
-        port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT
-        self._connection = self.__credentials.create_connection(server, port, proto, proxy_host, proxy_port)
-
-    def _start_heartbeat(self, heartbeat_period: float) -> None:
-        conn_ref = weakref.ref(self._connection)
-
-        def watchdog() -> None:
-            while True:
-                conn = conn_ref()
-                if conn is None:
-                    break
-
-                conn.ping('-' * 8)
-                time.sleep(heartbeat_period)
-
-        thread = Thread(target=watchdog)
-        thread.setDaemon(True)
-        thread.start()
+        self.__server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER
+        self.__port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT
 
     def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None,
                           priority: NotificationPriority = NotificationPriority.Immediate,
                           expiration: Optional[int] = None, collapse_id: Optional[str] = None) -> None:
-        stream_id = self.send_notification_async(token_hex, notification, topic, priority, expiration, collapse_id)
-        result = self.get_notification_result(stream_id)
-        if result != 'Success':
-            if isinstance(result, tuple):
-                reason, info = result
-                raise exception_class_for_reason(reason)(info)
-            else:
-                raise exception_class_for_reason(result)
-
-    def send_notification_async(self, token_hex: str, notification: Payload, topic: Optional[str] = None,
-                                priority: NotificationPriority = NotificationPriority.Immediate,
-                                expiration: Optional[int] = None, collapse_id: Optional[str] = None,
-                                push_type: Optional[NotificationType] = None) -> int:
+        with httpx.Client(http2=True, verify=self.__credentials.ssl_context) as client:
+            status, reason = self.send_notification_sync(token_hex, notification, client, topic, priority, expiration,
+                                                         collapse_id)
+
+        if status != 200:
+            raise exception_class_for_reason(reason)
+
+    def send_notification_sync(self, token_hex: str, notification: Payload, client: httpx.Client,
+                               topic: Optional[str] = None,
+                               priority: NotificationPriority = NotificationPriority.Immediate,
+                               expiration: Optional[int] = None, collapse_id: Optional[str] = None,
+                               push_type: Optional[NotificationType] = None) -> Tuple[int, str]:
         json_str = json.dumps(notification.dict(), cls=self.__json_encoder, ensure_ascii=False,
                               separators=(',', ':'))
         json_payload = json_str.encode('utf-8')
 
@@ -138,128 +117,57 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic:
         if expiration is not None:
             headers['apns-expiration'] = '%d' % expiration
 
-        auth_header = self.__credentials.get_authorization_header(topic)
-        if auth_header is not None:
-            headers['authorization'] = auth_header
+        if isinstance(self.__credentials, TokenCredentials):
+            auth_header = self.__credentials.get_authorization_header(topic)
+            if auth_header is not None:
+                headers['authorization'] = auth_header
 
         if collapse_id is not None:
             headers['apns-collapse-id'] = collapse_id
 
-        url = '/3/device/{}'.format(token_hex)
-        stream_id = self._connection.request('POST', url, json_payload, headers)  # type: int
-        return stream_id
+        url = f'https://{self.__server}:{self.__port}/3/device/{token_hex}'
+        response = client.post(url, headers=headers, data=json_payload)
+        return response.status_code, response.text
 
-    def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]:
+    def get_notification_result(self, status: int, reason: str) -> Union[str, Tuple[str, str]]:
         """
         Get result for specified stream
-        The function returns: 'Success' or 'failure reason' or ('Unregistered', timestamp)
+        The function returns: 'Success' or 'failure reason'
         """
-        with self._connection.get_response(stream_id) as response:
-            if response.status == 200:
-                return 'Success'
-            else:
-                raw_data = response.read().decode('utf-8')
-                data = json.loads(raw_data)  # type: Dict[str, str]
-                if response.status == 410:
-                    return data['reason'], data['timestamp']
-                else:
-                    return data['reason']
+        if status == 200:
+            return 'Success'
+        else:
+            return reason
 
     def send_notification_batch(self, notifications: Iterable[Notification], topic: Optional[str] = None,
                                 priority: NotificationPriority = NotificationPriority.Immediate,
                                 expiration: Optional[int] = None, collapse_id: Optional[str] = None,
                                 push_type: Optional[NotificationType] = None) -> Dict[str, Union[str, Tuple[str, str]]]:
         """
-        Send a notification to a list of tokens in batch. Instead of sending a synchronous request
-        for each token, send multiple requests concurrently. This is done on the same connection,
-        using HTTP/2 streams (one request per stream).
-
-        APNs allows many streams simultaneously, but the number of streams can vary depending on
-        server load. This method reads the SETTINGS frame sent by the server to figure out the
-        maximum number of concurrent streams. Typically, APNs reports a maximum of 500.
+        Send a notification to a list of tokens in batch.
 
         The function returns a dictionary mapping each token to its result.
        The result is "Success" if the token was sent successfully, or the string returned by APNs
        in the 'reason' field of the response, if the token generated an error.
        """
-        notification_iterator = iter(notifications)
-        next_notification = next(notification_iterator, None)
-        # Make sure we're connected to APNs, so that we receive and process the server's SETTINGS
-        # frame before starting to send notifications.
-        self.connect()
-
         results = {}
-        open_streams = collections.deque()  # type: typing.Deque[RequestStream]
-        # Loop on the tokens, sending as many requests as possible concurrently to APNs.
-        # When reaching the maximum concurrent streams limit, wait for a response before sending
-        # another request.
-        while len(open_streams) > 0 or next_notification is not None:
-            # Update the max_concurrent_streams on every iteration since a SETTINGS frame can be
-            # sent by the server at any time.
-            self.update_max_concurrent_streams()
-            if next_notification is not None and len(open_streams) < self.__max_concurrent_streams:
+
+        # Loop over notifications
+        with httpx.Client(http2=True, verify=self.__credentials.ssl_context) as client:
+            for next_notification in notifications:
                 logger.info('Sending to token %s', next_notification.token)
-                stream_id = self.send_notification_async(next_notification.token, next_notification.payload, topic,
-                                                         priority, expiration, collapse_id, push_type)
-                open_streams.append(RequestStream(stream_id, next_notification.token))
-
-                next_notification = next(notification_iterator, None)
-                if next_notification is None:
-                    # No tokens remaining. Proceed to get results for pending requests.
-                    logger.info('Finished sending all tokens, waiting for pending requests.')
-            else:
-                # We have at least one request waiting for response (otherwise we would have either
-                # sent new requests or exited the while loop.) Wait for the first outstanding stream
-                # to return a response.
-                pending_stream = open_streams.popleft()
-                result = self.get_notification_result(pending_stream.stream_id)
-                logger.info('Got response for %s: %s', pending_stream.token, result)
-                results[pending_stream.token] = result
+                status, reason = self.send_notification_sync(next_notification.token, next_notification.payload, client,
+                                                             topic, priority, expiration, collapse_id, push_type)
+                result = self.get_notification_result(status, reason)
+                logger.info('Got response for %s: %s', next_notification.token, result)
+                results[next_notification.token] = result
 
         return results
 
-    def update_max_concurrent_streams(self) -> None:
-        # Get the max_concurrent_streams setting returned by the server.
-        # The max_concurrent_streams value is saved in the H2Connection instance that must be
-        # accessed using a with statement in order to acquire a lock.
-        # pylint: disable=protected-access
-        with self._connection._conn as connection:
-            max_concurrent_streams = connection.remote_settings.max_concurrent_streams
-
-        if max_concurrent_streams == self.__previous_server_max_concurrent_streams:
-            # The server hasn't issued an updated SETTINGS frame.
-            return
-
-        self.__previous_server_max_concurrent_streams = max_concurrent_streams
-        # Handle and log unexpected values sent by APNs, just in case.
-        if max_concurrent_streams > CONCURRENT_STREAMS_SAFETY_MAXIMUM:
-            logger.warning('APNs max_concurrent_streams too high (%s), resorting to default maximum (%s)',
-                           max_concurrent_streams, CONCURRENT_STREAMS_SAFETY_MAXIMUM)
-            self.__max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM
-        elif max_concurrent_streams < 1:
-            logger.warning('APNs reported max_concurrent_streams less than 1 (%s), using value of 1',
-                           max_concurrent_streams)
-            self.__max_concurrent_streams = 1
-        else:
-            logger.info('APNs set max_concurrent_streams to %s', max_concurrent_streams)
-            self.__max_concurrent_streams = max_concurrent_streams
-
     def connect(self) -> None:
         """
         Establish a connection to APNs. If already connected, the function does nothing. If the
         connection fails, the function retries up to MAX_CONNECTION_RETRIES times.
         """
-        retries = 0
-        while retries < MAX_CONNECTION_RETRIES:
-            # noinspection PyBroadException
-            try:
-                self._connection.connect()
-                logger.info('Connected to APNs')
-                return
-            except Exception:  # pylint: disable=broad-except
-                # close the connnection, otherwise next connect() call would do nothing
-                self._connection.close()
-                retries += 1
-                logger.exception('Failed connecting to APNs (attempt %s of %s)', retries, MAX_CONNECTION_RETRIES)
-
-        raise ConnectionFailed()
+        # Not needed for HTTPX
+        logger.info('APNsClient.connect called')
diff --git a/apns2/credentials.py b/apns2/credentials.py
index 028093e..fb74ba4 100644
--- a/apns2/credentials.py
+++ b/apns2/credentials.py
@@ -1,30 +1,19 @@
+import ssl
 import time
-from typing import Optional, Tuple, TYPE_CHECKING
+from os import PathLike
+from typing import Optional, Tuple, Union
 
 import jwt
-from hyper import HTTP20Connection  # type: ignore
-from hyper.tls import init_context  # type: ignore
-
-if TYPE_CHECKING:
-    from hyper.ssl_compat import SSLContext  # type: ignore
-
 DEFAULT_TOKEN_LIFETIME = 2700
 DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256'
 
 
 # Abstract Base class. This should not be instantiated directly.
 class Credentials(object):
-    def __init__(self, ssl_context: 'Optional[SSLContext]' = None) -> None:
+    def __init__(self, ssl_context: Optional[ssl.SSLContext] = None) -> None:
         super().__init__()
-        self.__ssl_context = ssl_context
-
-    # Creates a connection with the credentials, if available or necessary.
-    def create_connection(self, server: str, port: int, proto: Optional[str], proxy_host: Optional[str] = None,
-                          proxy_port: Optional[int] = None) -> HTTP20Connection:
-        # self.__ssl_context may be none, and that's fine.
-        return HTTP20Connection(server, port, ssl_context=self.__ssl_context, force_proto=proto or 'h2',
-                                secure=True, proxy_host=proxy_host, proxy_port=proxy_port)
+        self.ssl_context = ssl_context
 
     def get_authorization_header(self, topic: Optional[str]) -> Optional[str]:
         return None
 
@@ -32,11 +21,10 @@ def get_authorization_header(self, topic: Optional[str]) -> Optional[str]:
 
 # Credentials subclass for certificate authentication
 class CertificateCredentials(Credentials):
-    def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None,
-                 cert_chain: Optional[str] = None) -> None:
-        ssl_context = init_context(cert=cert_file, cert_password=password)
-        if cert_chain:
-            ssl_context.load_cert_chain(cert_chain)
+    def __init__(self, cert_file: Optional[Union[str, bytes, PathLike[str], PathLike[bytes]]] = None,
+                 password: Optional[str] = None) -> None:
+        ssl_context = ssl.create_default_context()
+        ssl_context.load_cert_chain(cert_file, password=password)
         super(CertificateCredentials, self).__init__(ssl_context)
diff --git a/pyproject.toml b/pyproject.toml
index ac5145a..c400f91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,13 +19,14 @@ classifiers = [
     "Programming Language :: Python :: 3.7",
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
     "Topic :: Software Development :: Libraries"
 ]
 
 [tool.poetry.dependencies]
 python = ">=3.7"
 cryptography = ">=1.7.2"
-hyper = ">=0.7"
+httpx = {version = ">=0.13.0", extras = ["http2"] }
 pyjwt = ">=2.0.0"
 
 [tool.poetry.dev-dependencies]
@@ -49,7 +50,7 @@ disable = "missing-docstring, too-few-public-methods, locally-disabled, invalid-
 [tool.tox]
 legacy_tox_ini = """
 [tox]
-envlist = py37, py38, py39
+envlist = py37, py38, py39, py310
 isolated_build = True
 
 [testenv]
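
A minimal usage sketch of the client after this change, for reference only: APNsClient still wraps a certificate path in CertificateCredentials, and send_notification now opens a short-lived httpx.Client per call. The certificate path, device token, and topic below are placeholders, and the Payload keyword arguments are assumed to be unchanged from the existing library.

    # Smoke test against the APNs sandbox using the httpx-based client.
    # 'cert.pem', the device token, and the topic are placeholders.
    from apns2.client import APNsClient
    from apns2.payload import Payload

    client = APNsClient('cert.pem', use_sandbox=True)           # builds CertificateCredentials internally
    payload = Payload(alert='Hello', sound='default', badge=1)  # Payload signature assumed unchanged
    client.send_notification('<64-char-hex-device-token>', payload, topic='com.example.App')

For batches, send_notification_batch now opens a single httpx.Client and posts each token sequentially over it, so the returned dict maps every token to 'Success' or to the raw body of the APNs error response.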