Skip to content

Commit

Permalink
Conn refused (#339)
Browse files Browse the repository at this point in the history
* Handle connection refused better when there are multiple available addresses from DNS resolution
* Updated tests, changelog
  • Loading branch information
pkittenis authored Mar 20, 2022
1 parent 9c9b678 commit 8e20f47
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 43 deletions.
15 changes: 15 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
Change Log
============

2.9.0
+++++

Changes
--------

* ``pssh.exceptions.ConnectionError`` is now the same as built-in ``ConnectionError`` and deprecated - to be removed.
* Clients now continue connecting with all addresses in DNS list. In the case where an address refuses connection,
other available addresses are attempted without delay.

For example where a host resolves to both IPv4 and v6 addresses while only one address is
accepting connections, or multiple v4/v6 addresses where only some are accepting connections.
* Connection actively refused error is no longer subject to retries.


2.8.0
+++++

Expand Down
18 changes: 15 additions & 3 deletions pssh/clients/base/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def _auth_retry(self, retries=1):
if retries < self.num_retries:
sleep(self.retry_delay)
return self._auth_retry(retries=retries+1)
msg = "Authentication error while connecting to %s:%s - %s"
raise AuthenticationError(msg, self.host, self.port, ex)
msg = "Authentication error while connecting to %s:%s - %s - retries %s/%s"
raise AuthenticationError(msg, self.host, self.port, ex, retries, self.num_retries)

def disconnect(self):
raise NotImplementedError
Expand Down Expand Up @@ -284,13 +284,25 @@ def _connect(self, host, port, retries=1):
host, str(ex.args[1]), retries,
self.num_retries)
raise unknown_ex from ex
family, _type, proto, _, sock_addr = addr_info[0]
for i, (family, _type, proto, _, sock_addr) in enumerate(addr_info):
try:
return self._connect_socket(family, _type, proto, sock_addr, host, port, retries)
except ConnectionRefusedError as ex:
if i+1 == len(addr_info):
logger.error("No available addresses from %s", [addr[4] for addr in addr_info])
ex.args += (host, port)
raise
continue

def _connect_socket(self, family, _type, proto, sock_addr, host, port, retries):
self.sock = socket.socket(family, _type)
if self.timeout:
self.sock.settimeout(self.timeout)
logger.debug("Connecting to %s:%s", host, port)
try:
self.sock.connect(sock_addr)
except ConnectionRefusedError:
raise
except sock_error as ex:
logger.error("Error connecting to host '%s:%s' - retry %s/%s",
host, port, retries, self.num_retries)
Expand Down
8 changes: 1 addition & 7 deletions pssh/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@ class UnknownHostError(Exception):


UnknownHostException = UnknownHostError


class ConnectionError(Exception):
"""Raised on error connecting (connection refused/timed out)"""
pass


ConnectionError = ConnectionError
ConnectionErrorException = ConnectionError


Expand Down
36 changes: 16 additions & 20 deletions tests/native/test_parallel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from pssh.exceptions import UnknownHostException, \
AuthenticationException, ConnectionErrorException, \
HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \
PKeyFileError, ShellError, HostArgumentError, NoIPv6AddressFoundError
PKeyFileError, ShellError, HostArgumentError, NoIPv6AddressFoundError, \
AuthenticationError
from pssh.output import HostOutput

from .base_ssh2_case import PKEY_FILENAME, PUB_FILE
Expand Down Expand Up @@ -276,7 +277,7 @@ def test_pssh_client_hosts_list_part_failure(self):
self.assertTrue(client.finished(output))
self.assertEqual(len(hosts), len(output))
self.assertIsNotNone(output[1].exception)
self.assertEqual(output[1].exception.args[1], hosts[1])
self.assertEqual(output[1].host, hosts[1])
self.assertIsInstance(output[1].exception, ConnectionErrorException)

def test_pssh_client_timeout(self):
Expand Down Expand Up @@ -350,23 +351,23 @@ def test_pssh_client_long_running_command_exit_codes_no_stdout(self):

def test_pssh_client_retries(self):
"""Test connection error retries"""
listen_port = self.make_random_port()
# listen_port = self.make_random_port()
expected_num_tries = 2
client = ParallelSSHClient([self.host], port=listen_port,
pkey=self.user_key,
client = ParallelSSHClient([self.host], port=self.port,
pkey=b"fake",
num_retries=expected_num_tries,
retry_delay=.1,
)
self.assertRaises(ConnectionErrorException, client.run_command, 'blah')
self.assertRaises(AuthenticationError, client.run_command, 'blah')
try:
client.run_command('blah')
except ConnectionErrorException as ex:
except AuthenticationError as ex:
max_tries = ex.args[-2:][0]
num_tries = ex.args[-1:][0]
self.assertEqual(expected_num_tries, num_tries,
msg="Got unexpected number of retries %s - "
"expected %s" % (num_tries, expected_num_tries,))
self.assertEqual(expected_num_tries, max_tries)
self.assertEqual(expected_num_tries, num_tries)
else:
raise Exception('No ConnectionErrorException')
raise Exception('No AuthenticationError')

def test_sftp_exceptions(self):
# Port with no server listening on it on separate ip
Expand All @@ -380,7 +381,8 @@ def test_sftp_exceptions(self):
try:
cmd.get()
except Exception as ex:
self.assertEqual(ex.args[1], self.host)
self.assertEqual(ex.args[2], self.host)
self.assertEqual(ex.args[3], port)
self.assertIsInstance(ex, ConnectionErrorException)
else:
raise Exception("Expected ConnectionErrorException, got none")
Expand Down Expand Up @@ -859,7 +861,7 @@ def test_identical_hosts_in_host_list(self):
_host_stdout = list(host_out.stdout)
self.assertListEqual(_host_stdout, expected_stdout)

def test_connection_error_exception(self):
def test_connection_error(self):
"""Test that we get connection error exception in output with correct arguments"""
# Make port with no server listening on it on separate ip
host = '127.0.0.3'
Expand All @@ -874,13 +876,7 @@ def test_connection_error_exception(self):
for host_output in output:
exit_code = host_output.exit_code
self.assertEqual(exit_code, None)
try:
raise output[0].exception
except ConnectionErrorException as ex:
self.assertEqual(ex.args[1], host)
self.assertEqual(ex.args[2], port)
else:
raise Exception("Expected ConnectionErrorException")
self.assertIsInstance(output[0].exception, ConnectionError)

def test_bad_pkey_path(self):
self.assertRaises(PKeyFileError, ParallelSSHClient, [self.host], port=self.port,
Expand Down
37 changes: 33 additions & 4 deletions tests/native/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import shutil
import tempfile
from tempfile import NamedTemporaryFile

import pytest
from pytest import raises
from unittest.mock import MagicMock, call, patch
from hashlib import sha256
Expand All @@ -36,7 +38,7 @@
)
from pssh.exceptions import (AuthenticationException, ConnectionErrorException,
SessionError, SFTPIOError, SFTPError, SCPError, PKeyFileError, Timeout,
AuthenticationError, NoIPv6AddressFoundError,
AuthenticationError, NoIPv6AddressFoundError, ConnectionError
)

from .base_ssh2_case import SSH2TestCase
Expand Down Expand Up @@ -89,6 +91,10 @@ def _sftp_exc(local_file, remote_file):
self.assertRaises(
SFTPIOError, client.copy_remote_file, 'fake_remote_file_not_exists', 'local')

def test_conn_refused(self):
with pytest.raises(ConnectionRefusedError):
SSHClient('127.0.0.99', port=self.port, num_retries=1, timeout=1)

@patch('pssh.clients.base.single.socket')
def test_ipv6(self, gsocket):
# As of Oct 2021, CircleCI does not support IPv6 in its containers.
Expand All @@ -102,18 +108,41 @@ def test_ipv6(self, gsocket):
_sock = MagicMock()
gsocket.socket.return_value = _sock
sock_con = MagicMock()
sock_con.side_effect = ConnectionRefusedError
_sock.connect = sock_con
getaddrinfo = MagicMock()
gsocket.getaddrinfo = getaddrinfo
getaddrinfo.return_value = [(
socket.AF_INET6, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info)]
with raises(TypeError):
# Mock object as a file descriptor will raise TypeError
with raises(ConnectionError):
client = SSHClient(host, port=self.port, pkey=self.user_key,
num_retries=1)
getaddrinfo.assert_called_once_with(host, self.port, proto=socket.IPPROTO_TCP)
sock_con.assert_called_once_with(addr_info)

@patch('pssh.clients.base.single.socket')
def test_multiple_available_addr(self, gsocket):
host = '127.0.0.1'
addr_info = (host, self.port)
gsocket.IPPROTO_TCP = socket.IPPROTO_TCP
gsocket.socket = MagicMock()
_sock = MagicMock()
gsocket.socket.return_value = _sock
sock_con = MagicMock()
sock_con.side_effect = ConnectionRefusedError
_sock.connect = sock_con
getaddrinfo = MagicMock()
gsocket.getaddrinfo = getaddrinfo
getaddrinfo.return_value = [
(socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info),
(socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info),
]
with raises(ConnectionError):
client = SSHClient(host, port=self.port, pkey=self.user_key,
num_retries=1)
getaddrinfo.assert_called_with(host, self.port, proto=socket.IPPROTO_TCP)
assert sock_con.call_count == len(getaddrinfo.return_value)

def test_no_ipv6(self):
try:
SSHClient(self.host,
Expand Down Expand Up @@ -357,7 +386,7 @@ def test_password_auth_failure(self):
raise AssertionError

def test_retry_failure(self):
self.assertRaises(ConnectionErrorException,
self.assertRaises(ConnectionError,
SSHClient, self.host, port=12345,
num_retries=2, _auth_thread_pool=False,
retry_delay=.1,
Expand Down
15 changes: 7 additions & 8 deletions tests/ssh/test_parallel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ def test_pssh_client_hosts_list_part_failure(self):
self.assertIsNotNone(output[1].exception,
msg="Failed host %s has no exception in output - %s" % (hosts[1], output,))
self.assertTrue(output[1].exception is not None)
self.assertEqual(output[1].exception.args[1], hosts[1])
self.assertEqual(output[1].host, hosts[1])
self.assertEqual(output[1].exception.args[-2], hosts[1])
try:
raise output[1].exception
except ConnectionErrorException:
pass
else:
raise Exception("Expected ConnectionError, got %s instead" % (
output[1].exception,))
raise Exception("Expected ConnectionError, got %s instead" % (output[1].exception,))

def test_pssh_client_timeout(self):
# 1ms timeout
Expand Down Expand Up @@ -316,14 +316,13 @@ def test_connection_error_exception(self):
num_retries=1)
output = client.run_command(self.cmd, stop_on_errors=False)
client.join(output)
self.assertIsNotNone(output[0].exception,
msg="Got no exception for host %s - expected connection error" % (
host,))
self.assertIsInstance(output[0].exception, ConnectionErrorException)
self.assertEqual(output[0].host, host)
try:
raise output[0].exception
except ConnectionErrorException as ex:
self.assertEqual(ex.args[1], host)
self.assertEqual(ex.args[2], port)
self.assertEqual(ex.args[-2], host)
self.assertEqual(ex.args[-1], port)
else:
raise Exception("Expected ConnectionErrorException")

Expand Down
2 changes: 1 addition & 1 deletion tests/ssh/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import datetime

from gevent import sleep, Timeout as GTimeout, spawn
from ssh.session import Session
# from ssh.session import Session
from ssh.exceptions import AuthenticationDenied
from pssh.exceptions import AuthenticationException, ConnectionErrorException, \
SessionError, SFTPIOError, SFTPError, SCPError, PKeyFileError, Timeout, \
Expand Down

0 comments on commit 8e20f47

Please sign in to comment.