diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index dcf40c2..28d57fb 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -59,6 +59,7 @@ def __init__( def _transport(self): if self._opened: return self._transport_instance + self._engineIO_session = None self._engineIO_session = self._get_engineIO_session() self._negotiate_transport() self._connect_namespaces() @@ -69,11 +70,13 @@ def _transport(self): def _get_engineIO_session(self): warning_screen = self._yield_warning_screen() for elapsed_time in warning_screen: - transport = XHR_PollingTransport( - self._http_session, self._is_secure, self._url) try: + transport_name = self._client_transports[0] + self._transport_instance = self._get_transport(transport_name) + self.transport_name = transport_name + engineIO_packet_type, engineIO_packet_data = next( - transport.recv_packet()) + self._transport_instance.recv_packet()) break except (TimeoutError, ConnectionError) as e: if not self._wait_for_connection: @@ -81,27 +84,41 @@ def _get_engineIO_session(self): warning = Exception( '[engine.io waiting for connection] %s' % e) warning_screen.throw(warning) + assert engineIO_packet_type == 0 # engineIO_packet_type == open - return parse_engineIO_session(engineIO_packet_data) + session = parse_engineIO_session(engineIO_packet_data) + # Set the timeout on the WebSocket transport if needed, since we didn't + # have it earlier. + self._transport_instance.set_timeout(session.ping_timeout) + return session def _negotiate_transport(self): - self._transport_instance = self._get_transport('xhr-polling') - self.transport_name = 'xhr-polling' - is_ws_client = 'websocket' in self._client_transports - is_ws_server = 'websocket' in self._engineIO_session.transport_upgrades - if is_ws_client and is_ws_server: - try: - transport = self._get_transport('websocket') - transport.send_packet(2, 'probe') - for packet_type, packet_data in transport.recv_packet(): - if packet_type == 3 and packet_data == b'probe': - transport.send_packet(5, '') - self._transport_instance = transport - self.transport_name = 'websocket' - else: - self._warn('unexpected engine.io packet') - except Exception: - pass + if self.transport_name != 'websocket': + # If not using websocket transport, recreate initial transport with + # session information. This is essentially free and ensures correct + # initialization. + self.transport_name = self._client_transports[0] + self._transport_instance = self._get_transport(self.transport_name) + + # Attempt to upgrade to websocket transport if possible. + is_ws_client = \ + 'websocket' in self._client_transports + is_ws_server = \ + 'websocket' in self._engineIO_session.transport_upgrades + if is_ws_client and is_ws_server: + try: + transport = self._get_transport('websocket') + transport.send_packet(2, 'probe') + for packet_type, packet_data in transport.recv_packet(): + if packet_type == 3 and packet_data == b'probe': + transport.send_packet(5, '') + self._transport_instance = transport + self.transport_name = 'websocket' + else: + self._warn('unexpected engine.io packet') + except Exception: + pass + self._debug('[engine.io transport selected] %s', self.transport_name) def _reset_heartbeat(self): diff --git a/socketIO_client/tests/.gitignore b/socketIO_client/tests/.gitignore new file mode 100644 index 0000000..3c3629e --- /dev/null +++ b/socketIO_client/tests/.gitignore @@ -0,0 +1 @@ +node_modules diff --git a/socketIO_client/tests/__init__.py b/socketIO_client/tests/__init__.py index 8da37a1..1887249 100644 --- a/socketIO_client/tests/__init__.py +++ b/socketIO_client/tests/__init__.py @@ -323,6 +323,16 @@ def setUp(self): self.assertEqual(self.socketIO.transport_name, 'websocket') +class Test_WebsocketTransport_Only(BaseMixin, TestCase): + + def setUp(self): + super(Test_WebsocketTransport_Only, self).setUp() + self.socketIO = SocketIO(HOST, PORT, LoggingNamespace, transports=[ + 'websocket'], verify=False) + self.assertEqual(self.socketIO.transport_name, 'websocket') + self.assertEqual(self.socketIO._transport._timeout, 60) + + class Namespace(LoggingNamespace): def initialize(self): diff --git a/socketIO_client/transports.py b/socketIO_client/transports.py index 85baa06..562e95e 100644 --- a/socketIO_client/transports.py +++ b/socketIO_client/transports.py @@ -118,6 +118,8 @@ def __init__(self, http_session, is_secure, url, engineIO_session=None): if engineIO_session: params['sid'] = engineIO_session.id kw['timeout'] = self._timeout = engineIO_session.ping_timeout + else: + self._timeout = None ws_url = '%s://%s/?%s' % ( 'wss' if is_secure else 'ws', url, format_query(params)) http_scheme = 'https' if is_secure else 'http' @@ -168,6 +170,8 @@ def send_packet(self, engineIO_packet_type, engineIO_packet_data=''): raise ConnectionError('send disconnected (%s)' % e) def set_timeout(self, seconds=None): + if self._timeout is None: + self._timeout = seconds self._connection.settimeout(seconds or self._timeout)