Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SSL handshake retries by only counting failed attempts #128

Merged
merged 2 commits into from
Jan 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 48 additions & 45 deletions emailproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__author__ = 'Simon Robinson'
__copyright__ = 'Copyright (c) 2022 Simon Robinson'
__license__ = 'Apache 2.0'
__version__ = '2022-12-23' # ISO 8601 (YYYY-MM-DD)
__version__ = '2023-01-21' # ISO 8601 (YYYY-MM-DD)

import argparse
import base64
Expand All @@ -25,6 +25,7 @@
import plistlib
import queue
import re
import select
import signal
import socket
import ssl
Expand Down Expand Up @@ -127,7 +128,7 @@ class NSObject:

RECEIVE_BUFFER_SIZE = 65536 # number of bytes to try to read from the socket at a time (limit is per socket)

MAX_SSL_HANDSHAKE_ATTEMPTS = 65536 # maximum number of attempts before aborting local SSL/TLS handshake; 0 = no limit
MAX_SSL_HANDSHAKE_ATTEMPTS = 1024 # number of attempts before aborting SSL/TLS handshake (max 10ms each); 0 = no limit

# IMAP/POP/SMTP require \r\n as a line terminator (we use lines only pre-authentication; afterwards just pass through)
LINE_TERMINATOR = b'\r\n'
Expand Down Expand Up @@ -765,9 +766,11 @@ def decode_credentials(str_data):
class SSLAsyncoreDispatcher(asyncore.dispatcher_with_send):
def __init__(self, connection=None, socket_map=None):
asyncore.dispatcher_with_send.__init__(self, sock=connection, map=socket_map)
self.ssl_connection, self.ssl_handshake_attempts, self.ssl_handshake_completed = self.reset()
self.ssl_handshake_errors = (ssl.SSLWantReadError, ssl.SSLWantWriteError,
ssl.SSLEOFError, ssl.SSLZeroReturnError)
self.ssl_connection, self.ssl_handshake_attempts, self.ssl_handshake_completed = self._reset()

def reset(self, is_ssl=False):
def _reset(self, is_ssl=False):
self.ssl_connection = is_ssl
self.ssl_handshake_attempts = 0
self.ssl_handshake_completed = not is_ssl
Expand All @@ -779,88 +782,90 @@ def info_string(self):
def set_ssl_connection(self, is_ssl=False):
# note that the actual SSLContext.wrap_socket (and associated unwrap()) are handled outside this class
if not self.ssl_connection and is_ssl:
self.reset(True)
self._reset(True)
if is_ssl:
# we don't start negotiation here because a failed handshake in __init__ means remove_client also fails
Log.debug(self.info_string(), '<-> [ Starting TLS handshake ]')

elif self.ssl_connection and not is_ssl:
self.reset()
self._reset()

def ssl_handshake(self):
def _ssl_handshake(self):
if not isinstance(self.socket, ssl.SSLSocket):
Log.error(self.info_string(), 'Unable to initiate handshake with a non-SSL socket; aborting')
raise ssl.SSLError(-1, APP_PACKAGE)

# attempting to connect insecurely to a secure socket could loop indefinitely here - we set a maximum attempt
# count and catch in handle_error() when `ssl_handshake_attempts` expires, but there's not much else we can do
self.ssl_handshake_attempts += 1
if 0 < MAX_SSL_HANDSHAKE_ATTEMPTS < self.ssl_handshake_attempts:
Log.error(self.info_string(), 'SSL socket handshake failed (reached `MAX_SSL_HANDSHAKE_ATTEMPTS`)')
raise ssl.SSLError(-1, APP_PACKAGE)

# see: https://github.com/python/cpython/issues/54293
try:
# note that attempting to connect insecurely to a secure socket may loop indefinitely here - we attempt
# to catch this in handle_error() when the client gives up, but there's not much else we can do
# noinspection PyUnresolvedReferences
self.socket.do_handshake()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
pass
except (ssl.SSLEOFError, ssl.SSLZeroReturnError):
self.handle_close()
except self.ssl_handshake_errors as e:
self._ssl_handshake_error(e)
else:
Log.debug(self.info_string(), '<-> [ TLS handshake complete ]')
Log.debug(self.info_string(), '<-> [', self.socket.version(), 'handshake complete ]')
self.ssl_handshake_attempts = 0
self.ssl_handshake_completed = True

def _ssl_handshake_error(self, error):
if isinstance(error, ssl.SSLWantReadError):
select.select([self.socket], [], [], 0.01) # wait for the socket to be readable (10ms timeout)
elif isinstance(error, ssl.SSLWantWriteError):
select.select([], [self.socket], [], 0.01) # wait for the socket to be writable (10ms timeout)
else:
self.handle_close()

def handle_read_event(self):
# additional Exceptions are propagated to handle_error(); no need to handle here
if not self.ssl_handshake_completed:
self.ssl_handshake()
self._ssl_handshake()
else:
# on the first connection event to a secure server we need to handle SSL handshake events (because we don't
# have a 'not_currently_ssl_but_will_be_once_connected'-type state) - a version of this class that didn't
# have to deal with both unsecured, wrapped *and* STARTTLS-type sockets would only need this in recv/send
try:
super().handle_read_event()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
self.ssl_handshake_completed = False
except (ssl.SSLEOFError, ssl.SSLZeroReturnError):
self.handle_close()
except self.ssl_handshake_errors as e:
self._ssl_handshake_error(e)

def handle_write_event(self):
# additional Exceptions are propagated to handle_error(); no need to handle here
if not self.ssl_handshake_completed:
self.ssl_handshake()
self._ssl_handshake()
else:
# as in handle_read_event, we need to handle SSL handshake events
try:
super().handle_write_event()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
self.ssl_handshake_completed = False
except (ssl.SSLEOFError, ssl.SSLZeroReturnError):
self.handle_close()
except self.ssl_handshake_errors as e:
self._ssl_handshake_error(e)

def recv(self, buffer_size):
# additional Exceptions are propagated to handle_error(); no need to handle here
try:
return super().recv(buffer_size)
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
self.ssl_handshake_completed = False
except (ssl.SSLEOFError, ssl.SSLZeroReturnError):
self.handle_close()
except self.ssl_handshake_errors as e:
self._ssl_handshake_error(e)
return b''

def send(self, byte_data):
# additional Exceptions are propagated to handle_error(); no need to handle here
try:
return super().send(byte_data) # buffers before sending via the socket, so failure is okay; will auto-retry
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
self.ssl_handshake_completed = False
except (ssl.SSLEOFError, ssl.SSLZeroReturnError):
self.handle_close()
except self.ssl_handshake_errors as e:
self._ssl_handshake_error(e)
return 0

def handle_error(self):
error_type, value, _traceback = sys.exc_info()
del _traceback # used to be required in python 2; may no-longer be needed, but best to be safe
if self.ssl_connection:
# OSError 0 ('Error') and SSL errors here are caused by connection handshake failures or timeouts
# APP_PACKAGE is used when we throw our own SSLError on handshake timeout
# APP_PACKAGE is used when we throw our own SSLError on handshake timeout or socket misconfiguration
ssl_errors = ['SSLV3_ALERT_BAD_CERTIFICATE', 'PEER_DID_NOT_RETURN_A_CERTIFICATE', 'WRONG_VERSION_NUMBER',
'CERTIFICATE_VERIFY_FAILED', 'TLSV1_ALERT_PROTOCOL_VERSION', 'TLSV1_ALERT_UNKNOWN_CA',
APP_PACKAGE]
Expand Down Expand Up @@ -905,7 +910,7 @@ def __init__(self, proxy_type, connection, socket_map, connection_info, server_c
self.authenticated = False

self.set_ssl_connection(
custom_configuration['local_certificate_path'] and custom_configuration['local_key_path'])
bool(custom_configuration['local_certificate_path'] and custom_configuration['local_key_path']))

def info_string(self):
debug_string = '; %s:%d->%s:%d' % (self.connection_info[0], self.connection_info[1], self.server_address[0],
Expand Down Expand Up @@ -1382,10 +1387,10 @@ def handle_error(self):
error_type, value, _traceback = sys.exc_info()
del _traceback # used to be required in python 2; may no-longer be needed, but best to be safe
if error_type == TimeoutError and value.errno == errno.ETIMEDOUT or \
error_type == ConnectionResetError and value.errno == errno.ECONNRESET or \
issubclass(error_type, ConnectionError) and value.errno in [errno.ECONNRESET, errno.ECONNREFUSED] or \
error_type == OSError and value.errno in [0, errno.ENETDOWN, errno.EHOSTUNREACH]:
# TimeoutError 60 = 'Operation timed out'; # ConnectionResetError 54 = 'Connection reset by peer'; OSError
# 0 = 'Error' (typically network failure); OSError 50 = 'Network is down'; OSError 65 = 'No route to host'
# TimeoutError 60 = 'Operation timed out'; ConnectionError 54 = 'Connection reset by peer', 61 = 'Connection
# refused; OSError 0 = 'Error' (typically network failure), 50 = 'Network is down', 65 = 'No route to host'
Log.info(self.info_string(), 'Caught network error (server) - is there a network connection?',
'Error type', error_type, 'with message:', value)
self.handle_close()
Expand Down Expand Up @@ -1751,14 +1756,12 @@ def handle_error(self):
del _traceback # used to be required in python 2; may no-longer be needed, but best to be safe
if error_type == socket.gaierror and value.errno in [8, 11001] or \
error_type == TimeoutError and value.errno == errno.ETIMEDOUT or \
error_type == ConnectionResetError and value.errno == errno.ECONNRESET or \
error_type == ConnectionRefusedError and value.errno == errno.ECONNREFUSED or \
issubclass(error_type, ConnectionError) and value.errno in [errno.ECONNRESET, errno.ECONNREFUSED] or \
error_type == OSError and value.errno in [0, errno.EINVAL, errno.ENETDOWN, errno.EHOSTUNREACH]:
# gaierror 8 = 'nodename nor servname provided, or not known'; gaierror 11001 = 'getaddrinfo failed'
# (caused by getpeername() failing due to no network connection); TimeoutError 60 = 'Operation timed out';
# ConnectionResetError 54 = 'Connection reset by peer'; ConnectionRefusedError 61 = 'Connection refused';
# OSError 0 = 'Error' (local SSL failure); OSError 22 = 'Invalid argument' (same cause as gaierror 11001);
# OSError 50 = 'Network is down'; OSError 65 = 'No route to host'
# gaierror 8 = 'nodename nor servname provided, or not known', gaierror 11001 = 'getaddrinfo failed' (caused
# by getpeername() failing due to no connection); TimeoutError 60 = 'Operation timed out'; ConnectionError
# 54 = 'Connection reset by peer', 61 = 'Connection refused; OSError 0 = 'Error' (local SSL failure),
# 22 = 'Invalid argument' (same cause as gaierror 11001), 50 = 'Network is down', 65 = 'No route to host'
Log.info('Caught network error in', self.info_string(), '- is there a network connection?',
'Error type', error_type, 'with message:', value)
else:
Expand Down