diff --git a/mlonmcu/target/ssh_target.py b/mlonmcu/target/ssh_target.py index 30cfbf32..4ca65eb3 100644 --- a/mlonmcu/target/ssh_target.py +++ b/mlonmcu/target/ssh_target.py @@ -23,7 +23,6 @@ # import tempfile # import time -import socket from pathlib import Path import paramiko @@ -87,48 +86,30 @@ def workdir(self): def __repr__(self): return f"SSHTarget({self.name})" - def check_remote(self): - ssh = paramiko.SSHClient() - try: - ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) - # TODO: key_filename=key_file) - return True - except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: - print(e) # TODO: remove - return False - raise NotImplementedError - - def create_remote_directory(self, path): - ssh = paramiko.SSHClient() - # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) - if self.ignore_known_hosts: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + # def check_remote(self): + # ssh = paramiko.SSHClient() + # try: + # ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + # # TODO: key_filename=key_file) + # return True + # except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: + # print(e) # TODO: remove + # return False + # raise NotImplementedError + + def create_remote_directory(self, ssh, path): command = f"mkdir -p {path}" stdin, stdout, stderr = ssh.exec_command(command) - ssh.close() - - def copy_to_remote(self, src, dest): - ssh = paramiko.SSHClient() - # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) - if self.ignore_known_hosts: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + + def copy_to_remote(self, ssh, src, dest): sftp = ssh.open_sftp() sftp.put(str(src), str(dest)) sftp.close() - ssh.close() - - def copy_from_remote(self, src, dest): - ssh = paramiko.SSHClient() - # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) - if self.ignore_known_hosts: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + + def copy_from_remote(self, ssh, src, dest): sftp = ssh.open_sftp() sftp.get(str(src), str(dest)) sftp.close() - ssh.close() def parse_exit(self, out): exit_code = super().parse_exit(out) @@ -139,30 +120,29 @@ def parse_exit(self, out): def exec_via_ssh(self, program: Path, *args, cwd=os.getcwd(), **kwargs): # TODO: keep connection established! - self.check_remote() - if self.workdir is None: - raise NotImplementedError("temp workdir") - else: - self.create_remote_directory(self.workdir) - workdir = self.workdir - remote_program = workdir / program.name - self.copy_to_remote(program, remote_program) - ssh = paramiko.SSHClient() - # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) - if self.ignore_known_hosts: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) - args_str = " ".join(args) - command = f"cd {workdir} && chmod +x {remote_program} && {remote_program} {args_str}; echo SSH EXIT=$?" - stdin, stdout, stderr = ssh.exec_command(command) - # print("stdin", stdin) - # print("stdout", stdout) - # print("stderr", stderr) - output = stderr.read().strip() + stdout.read().strip() - output = output.decode() - if self.print_outputs: - print("output", output) # TODO: cleanup - ssh.close() + # self.check_remote() + with paramiko.SSHClient() as ssh: + # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) + if self.ignore_known_hosts: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + if self.workdir is None: + raise NotImplementedError("temp workdir") + else: + self.create_remote_directory(ssh, self.workdir) + workdir = self.workdir + remote_program = workdir / program.name + self.copy_to_remote(ssh, program, remote_program) + args_str = " ".join(args) + command = f"cd {workdir} && chmod +x {remote_program} && {remote_program} {args_str}; echo SSH EXIT=$?" + stdin, stdout, stderr = ssh.exec_command(command) + # print("stdin", stdin) + # print("stdout", stdout) + # print("stderr", stderr) + output = stderr.read().strip() + stdout.read().strip() + output = output.decode() + if self.print_outputs: + print("output", output) # TODO: cleanup return output