From 533b456a53d3f4eccf82aaad02d856950eefb933 Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Mon, 22 Jul 2024 12:39:53 +0200 Subject: [PATCH 1/2] add generic SSHTarget class --- mlonmcu/target/ssh_target.py | 169 +++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 mlonmcu/target/ssh_target.py diff --git a/mlonmcu/target/ssh_target.py b/mlonmcu/target/ssh_target.py new file mode 100644 index 00000000..f272b5d8 --- /dev/null +++ b/mlonmcu/target/ssh_target.py @@ -0,0 +1,169 @@ +# +# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering. +# +# This file is part of MLonMCU. +# See https://github.com/tum-ei-eda/mlonmcu.git for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""MLonMCU SSH Target definitions""" + +import os +import re +# import tempfile +# import time +import socket +from pathlib import Path + +import paramiko +from paramiko import BadHostKeyException, AuthenticationException, SSHException + +from mlonmcu.config import str2bool +from .target import Target + + +class SSHTarget(Target): + """TODO""" + + DEFAULTS = { + **Target.DEFAULTS, + "hostname": None, + "port": 22, + "username": None, + "password": None, + "ignore_known_hosts": True, + "workdir": None, + } + + @property + def hostname(self): + value = self.config["hostname"] + assert value is not None, "hostname not defined" + return value + + @property + def port(self): + value = self.config["port"] + if isinstance(value, str): + value = int(value) + assert isinstance(value, int) + return value + + @property + def username(self): + value = self.config["username"] + return value + + @property + def password(self): + value = self.config["password"] + return value + + @property + def ignore_known_hosts(self): + value = self.config["ignore_known_hosts"] + return str2bool(value) if not isinstance(value, (bool, int)) else value + + @property + def workdir(self): + value = self.config["workdir"] + if value is not None: + if isinstance(value, str): + value = Path(value) + assert isinstance(value, Path) + return value + + def __repr__(self): + return f"SSHTarget({self.name})" + + def check_remote(self): + ssh = paramiko.SSHClient() + try: + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + # TODO: key_filename=key_file) + return True + except (BadHostKeyException, AuthenticationException, + SSHException, socket.error) as e: + print(e) # TODO: remove + return False + raise NotImplementedError + + def create_remote_directory(self, path): + ssh = paramiko.SSHClient() + # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) + if self.ignore_known_hosts: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + command = f"mkdir -p {path}" + stdin, stdout, stderr = ssh.exec_command(command) + ssh.close() + + def copy_to_remote(self, src, dest): + ssh = paramiko.SSHClient() + # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) + if self.ignore_known_hosts: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + sftp = ssh.open_sftp() + sftp.put(str(src), str(dest)) + sftp.close() + ssh.close() + + def copy_from_remote(self, src, dest): + ssh = paramiko.SSHClient() + # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) + if self.ignore_known_hosts: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + sftp = ssh.open_sftp() + sftp.get(str(src), str(dest)) + sftp.close() + ssh.close() + + def parse_exit(self, out): + print("parse_exit (ssh_target)") + exit_code = super().parse_exit(out) + exit_match = re.search(r"SSH EXIT=(.*)", out) + if exit_match: + exit_code = int(exit_match.group(1)) + return exit_code + + def exec_via_ssh(self, program: Path, *args, cwd=os.getcwd(), **kwargs): + # TODO: keep connection established! + self.check_remote() + if self.workdir is None: + raise NotImplementedError("temp workdir") + else: + self.create_remote_directory(self.workdir) + workdir = self.workdir + remote_program = workdir / program.name + self.copy_to_remote(program, remote_program) + ssh = paramiko.SSHClient() + # ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) + if self.ignore_known_hosts: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password) + args_str = " ".join(args) + command = f"cd {workdir} && chmod +x {remote_program} && {remote_program} {args_str}; echo SSH EXIT=$?" + stdin, stdout, stderr = ssh.exec_command(command) + # print("stdin", stdin) + # print("stdout", stdout) + # print("stderr", stderr) + output = stderr.read().strip() + stdout.read().strip() + output = output.decode() + if self.print_outputs: + print("output", output) # TODO: cleanup + ssh.close() + return output + +# TODO: logger From b9e8e2ba30956023b25b6b44bb9a6e825a49e41f Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Mon, 22 Jul 2024 12:41:29 +0200 Subject: [PATCH 2/2] add host_x86_ssh target --- mlonmcu/target/_target.py | 2 ++ mlonmcu/target/host_x86_ssh.py | 60 ++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 mlonmcu/target/host_x86_ssh.py diff --git a/mlonmcu/target/_target.py b/mlonmcu/target/_target.py index 937a494f..336eb0ae 100644 --- a/mlonmcu/target/_target.py +++ b/mlonmcu/target/_target.py @@ -28,6 +28,7 @@ ) from .arm import Corstone300Target from .host_x86 import HostX86Target +from .host_x86_ssh import HostX86SSHTarget TARGET_REGISTRY = {} @@ -47,6 +48,7 @@ def get_targets(): register_target("etiss_pulpino", EtissPulpinoTarget) register_target("etiss", EtissTarget) register_target("host_x86", HostX86Target) +register_target("host_x86_ssh", HostX86SSHTarget) register_target("corstone300", Corstone300Target) register_target("spike", SpikeTarget) register_target("ovpsim", OVPSimTarget) diff --git a/mlonmcu/target/host_x86_ssh.py b/mlonmcu/target/host_x86_ssh.py new file mode 100644 index 00000000..b66c378c --- /dev/null +++ b/mlonmcu/target/host_x86_ssh.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering. +# +# This file is part of MLonMCU. +# See https://github.com/tum-ei-eda/mlonmcu.git for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""MLonMCU Host/x86 Target definitions""" + +import stat +from pathlib import Path + +from mlonmcu.config import str2bool +from mlonmcu.setup.utils import execute +from .common import cli +from .ssh_target import SSHTarget +from .host_x86 import HostX86Target + + +class HostX86SSHTarget(SSHTarget, HostX86Target): + """TODO""" + + FEATURES = SSHTarget.FEATURES | HostX86Target.FEATURES # TODO: do not allow gdbserver + + DEFAULTS = { + **SSHTarget.DEFAULTS, + **HostX86Target.DEFAULTS, + } + + def __init__(self, name="host_x86_ssh", features=None, config=None): + super().__init__(name, features=features, config=config) + + def exec(self, program, *args, handle_exit=None, **kwargs): + if self.gdbserver_enable: + raise NotImplementedError("gdbserver via ssh") + + output = self.exec_via_ssh(program, *args, **kwargs) + if handle_exit: + exit_code = handle_exit(0, out=output) + print("exit_code", exit_code) + assert exit_code == 0 + return output, [] + + def get_target_system(self): + return "host_x86" + + +if __name__ == "__main__": + cli(target=HostX86SSHTarget)