Skip to content

Commit

Permalink
Scp send rc (#340)
Browse files Browse the repository at this point in the history
* Added SCP large file test.
* Fix issue with scp_send - resolves #337 
* Updated changelog
* Updated embedded server, tests
  • Loading branch information
pkittenis authored Mar 20, 2022
1 parent 8e20f47 commit cf29d9d
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 38 deletions.
1 change: 1 addition & 0 deletions .environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ dependencies:
- setuptools
- pip
- toolchain3
- cython
7 changes: 6 additions & 1 deletion Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@ Changes
--------

* ``pssh.exceptions.ConnectionError`` is now the same as built-in ``ConnectionError`` and deprecated - to be removed.
* Clients now continue connecting with all addresses in DNS list. In the case where an address refuses connection,
* Clients now attempt to connect with all addresses in DNS list. In the case where an address refuses connection,
other available addresses are attempted without delay.

For example where a host resolves to both IPv4 and v6 addresses while only one address is
accepting connections, or multiple v4/v6 addresses where only some are accepting connections.
* Connection actively refused error is no longer subject to retries.

Fixes
-----

* ``scp_send`` in native clients would sometimes fail to send all data in a race condition with client going out of scope.


2.8.0
+++++
Expand Down
6 changes: 5 additions & 1 deletion pssh/clients/native/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ def _scp_recv(self, remote_file, local_file):
total += size
local_fh.write(data)
finally:
local_fh.flush()
local_fh.close()
file_chan.close()

Expand Down Expand Up @@ -659,13 +660,16 @@ def _scp_send(self, local_file, remote_file):
raise SCPError(msg, remote_file, self.host, ex)
try:
with open(local_file, 'rb', 2097152) as local_fh:
for data in local_fh:
data = local_fh.read(self._BUF_SIZE)
while data:
self.eagain_write(chan.write, data)
data = local_fh.read(self._BUF_SIZE)
except Exception as ex:
msg = "Error writing to remote file %s on host %s - %s"
logger.error(msg, remote_file, self.host, ex)
raise SCPError(msg, remote_file, self.host, ex)
finally:
self._eagain(chan.flush)
chan.close()

def _sftp_openfh(self, open_func, remote_file, *args):
Expand Down
8 changes: 5 additions & 3 deletions tests/embedded_server/openssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ def start_server(self):
logger.debug("Starting server with %s" % (" ".join(cmd),))
self.server_proc = Popen(cmd)
try:
self.server_proc.wait(.1)
self.server_proc.wait(.3)
except TimeoutExpired:
pass
else:
logger.error(self.server_proc.stdout.read())
logger.error(self.server_proc.stderr.read())
if self.server_proc.stdout is not None:
logger.error(self.server_proc.stdout.read())
if self.server_proc.stderr is not None:
logger.error(self.server_proc.stderr.read())
raise Exception("Server could not start")

def stop(self):
Expand Down
68 changes: 63 additions & 5 deletions tests/native/test_parallel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def test_client_shells_read_timeout(self):
def test_client_shells_timeout(self):
client = ParallelSSHClient([self.host], pkey=self.user_key, port=self.port,
timeout=0.01, num_retries=1)
client._make_ssh_client = MagicMock()
client._make_ssh_client.side_effect = Timeout
self.assertRaises(Timeout, client.open_shell)

def test_client_shells_join_timeout(self):
Expand Down Expand Up @@ -1517,8 +1519,8 @@ def test_scp_send_dir_recurse(self):
except OSError:
pass

def test_scp_send_large_files_timeout(self):
hosts = ['127.0.0.1%s' % (i,) for i in range(1, 10)]
def test_scp_send_larger_files(self):
hosts = ['127.0.0.1%s' % (i,) for i in range(1, 3)]
servers = [OpenSSHServer(host, port=self.port) for host in hosts]
for server in servers:
server.start_server()
Expand All @@ -1535,7 +1537,7 @@ def test_scp_send_large_files_timeout(self):
remote_file_names = [arg['remote_file'] for arg in copy_args]
sha = sha256()
with open(local_filename, 'wb') as file_h:
for _ in range(5000):
for _ in range(10000):
data = os.urandom(1024)
file_h.write(data)
sha.update(data)
Expand All @@ -1547,13 +1549,15 @@ def test_scp_send_large_files_timeout(self):
except Exception:
raise
else:
sleep(.2)
del client
for remote_file_name in remote_file_names:
remote_file_abspath = os.path.expanduser('~/' + remote_file_name)
self.assertTrue(os.path.isfile(remote_file_abspath))
with open(remote_file_abspath, 'rb') as remote_fh:
for data in remote_fh:
data = remote_fh.read(10240)
while data:
sha.update(data)
data = remote_fh.read(10240)
remote_file_sha = sha.hexdigest()
sha = sha256()
self.assertEqual(source_file_sha, remote_file_sha)
Expand Down Expand Up @@ -1679,6 +1683,60 @@ def test_scp_recv(self):
except Exception:
pass

def test_scp_recv_larger_files(self):
hosts = ['127.0.0.1%s' % (i,) for i in range(1, 3)]
servers = [OpenSSHServer(host, port=self.port) for host in hosts]
for server in servers:
server.start_server()
client = ParallelSSHClient(
hosts, port=self.port, pkey=self.user_key, num_retries=1, timeout=1,
pool_size=len(hosts),
)
dir_name = os.path.dirname(__file__)
remote_filename = 'test_file'
remote_filepath = os.path.join(dir_name, remote_filename)
local_filename = 'file_copy'
copy_args = [{
'remote_file': remote_filepath,
'local_file': os.path.expanduser("~/" + 'host_%s_%s' % (n, local_filename))}
for n in range(len(hosts))
]
local_file_names = [
arg['local_file'] for arg in copy_args]
sha = sha256()
with open(remote_filepath, 'wb') as file_h:
for _ in range(10000):
data = os.urandom(1024)
file_h.write(data)
sha.update(data)
file_h.flush()
source_file_sha = sha.hexdigest()
sha = sha256()
cmds = client.scp_recv('%(remote_file)s', '%(local_file)s', copy_args=copy_args)
try:
joinall(cmds, raise_error=True)
except Exception:
raise
else:
del client
for _local_file_name in local_file_names:
self.assertTrue(os.path.isfile(_local_file_name))
with open(_local_file_name, 'rb') as fh:
data = fh.read(10240)
while data:
sha.update(data)
data = fh.read(10240)
local_file_sha = sha.hexdigest()
sha = sha256()
self.assertEqual(source_file_sha, local_file_sha)
finally:
try:
os.unlink(remote_filepath)
for _local_file_name in local_file_names:
os.unlink(_local_file_name)
except OSError:
pass

def test_bad_hosts_value(self):
self.assertRaises(TypeError, ParallelSSHClient, 'a host')
self.assertRaises(TypeError, ParallelSSHClient, b'a host')
Expand Down
44 changes: 16 additions & 28 deletions tests/native/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,10 @@ def test_auth_retry_failure(self):

def test_connection_timeout(self):
cmd = spawn(SSHClient, 'fakehost.com', port=12345,
num_retries=1, timeout=1, _auth_thread_pool=False)
num_retries=1, timeout=.1, _auth_thread_pool=False)
# Should fail within greenlet timeout, otherwise greenlet will
# raise timeout which will fail the test
self.assertRaises(ConnectionErrorException, cmd.get, timeout=2)
self.assertRaises(ConnectionErrorException, cmd.get, timeout=1)

def test_client_read_timeout(self):
client = SSHClient(self.host, port=self.port,
Expand Down Expand Up @@ -657,27 +657,22 @@ def test_scp_recv_large_file(self):
os.unlink(_path)
except OSError:
pass
sha = sha256()
try:
with open(file_path_from, 'wb') as fh:
# ~300MB
for _ in range(20000000):
fh.write(b"adsfasldkfjabafj")
for _ in range(10000):
data = os.urandom(1024)
fh.write(data)
sha.update(data)
source_file_sha = sha.hexdigest()
self.client.scp_recv(file_path_from, file_copy_to_dirpath)
self.assertTrue(os.path.isfile(file_copy_to_dirpath))
read_file_size = os.stat(file_path_from).st_size
written_file_size = os.stat(file_copy_to_dirpath).st_size
self.assertEqual(read_file_size, written_file_size)
sha = sha256()
with open(file_path_from, 'rb') as fh:
for block in fh:
sha.update(block)
read_file_hash = sha.hexdigest()
sha = sha256()
with open(file_copy_to_dirpath, 'rb') as fh:
for block in fh:
sha.update(block)
written_file_hash = sha.hexdigest()
self.assertEqual(read_file_hash, written_file_hash)
self.assertEqual(source_file_sha, written_file_hash)
finally:
for _path in (file_path_from, file_copy_to_dirpath):
try:
Expand Down Expand Up @@ -728,29 +723,22 @@ def test_scp_send_large_file(self):
os.unlink(_path)
except OSError:
pass
sha = sha256()
try:
with open(file_path_from, 'wb') as fh:
# ~300MB
for _ in range(20000000):
fh.write(b"adsfasldkfjabafj")
for _ in range(10000):
data = os.urandom(1024)
fh.write(data)
sha.update(data)
source_file_sha = sha.hexdigest()
self.client.scp_send(file_path_from, file_copy_to_dirpath)
self.assertTrue(os.path.isfile(file_copy_to_dirpath))
# OS file flush race condition
sleep(.1)
read_file_size = os.stat(file_path_from).st_size
written_file_size = os.stat(file_copy_to_dirpath).st_size
self.assertEqual(read_file_size, written_file_size)
sha = sha256()
with open(file_path_from, 'rb') as fh:
for block in fh:
sha.update(block)
read_file_hash = sha.hexdigest()
sha = sha256()
with open(file_copy_to_dirpath, 'rb') as fh:
for block in fh:
sha.update(block)
written_file_hash = sha.hexdigest()
self.assertEqual(read_file_hash, written_file_hash)
self.assertEqual(source_file_sha, written_file_hash)
finally:
for _path in (file_path_from, file_copy_to_dirpath):
try:
Expand Down

0 comments on commit cf29d9d

Please sign in to comment.