Skip to content

Commit

Permalink
Merge pull request #168 from tum-ei-eda/feature-ssh-target
Browse files Browse the repository at this point in the history
Feature ssh target
  • Loading branch information
PhilippvK authored Jul 22, 2024
2 parents 1499f87 + b9e8e2b commit 3bb3b08
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlonmcu/target/_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from .arm import Corstone300Target
from .host_x86 import HostX86Target
from .host_x86_ssh import HostX86SSHTarget

TARGET_REGISTRY = {}

Expand All @@ -47,6 +48,7 @@ def get_targets():
register_target("etiss_pulpino", EtissPulpinoTarget)
register_target("etiss", EtissTarget)
register_target("host_x86", HostX86Target)
register_target("host_x86_ssh", HostX86SSHTarget)
register_target("corstone300", Corstone300Target)
register_target("spike", SpikeTarget)
register_target("ovpsim", OVPSimTarget)
Expand Down
60 changes: 60 additions & 0 deletions mlonmcu/target/host_x86_ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering.
#
# This file is part of MLonMCU.
# See https://github.com/tum-ei-eda/mlonmcu.git for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""MLonMCU Host/x86 Target definitions"""

import stat

Check failure on line 21 in mlonmcu/target/host_x86_ssh.py

View workflow job for this annotation

GitHub Actions / Flake8

mlonmcu/target/host_x86_ssh.py#L21

'stat' imported but unused (F401)
from pathlib import Path

Check failure on line 22 in mlonmcu/target/host_x86_ssh.py

View workflow job for this annotation

GitHub Actions / Flake8

mlonmcu/target/host_x86_ssh.py#L22

'pathlib.Path' imported but unused (F401)

from mlonmcu.config import str2bool

Check failure on line 24 in mlonmcu/target/host_x86_ssh.py

View workflow job for this annotation

GitHub Actions / Flake8

mlonmcu/target/host_x86_ssh.py#L24

'mlonmcu.config.str2bool' imported but unused (F401)
from mlonmcu.setup.utils import execute

Check failure on line 25 in mlonmcu/target/host_x86_ssh.py

View workflow job for this annotation

GitHub Actions / Flake8

mlonmcu/target/host_x86_ssh.py#L25

'mlonmcu.setup.utils.execute' imported but unused (F401)
from .common import cli
from .ssh_target import SSHTarget
from .host_x86 import HostX86Target


class HostX86SSHTarget(SSHTarget, HostX86Target):
"""TODO"""

FEATURES = SSHTarget.FEATURES | HostX86Target.FEATURES # TODO: do not allow gdbserver

DEFAULTS = {
**SSHTarget.DEFAULTS,
**HostX86Target.DEFAULTS,
}

def __init__(self, name="host_x86_ssh", features=None, config=None):
super().__init__(name, features=features, config=config)

def exec(self, program, *args, handle_exit=None, **kwargs):
if self.gdbserver_enable:
raise NotImplementedError("gdbserver via ssh")

output = self.exec_via_ssh(program, *args, **kwargs)
if handle_exit:
exit_code = handle_exit(0, out=output)
print("exit_code", exit_code)
assert exit_code == 0
return output, []

def get_target_system(self):
return "host_x86"


if __name__ == "__main__":
cli(target=HostX86SSHTarget)
169 changes: 169 additions & 0 deletions mlonmcu/target/ssh_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#
# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering.
#
# This file is part of MLonMCU.
# See https://github.com/tum-ei-eda/mlonmcu.git for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""MLonMCU SSH Target definitions"""

import os
import re
# import tempfile
# import time
import socket
from pathlib import Path

import paramiko
from paramiko import BadHostKeyException, AuthenticationException, SSHException

from mlonmcu.config import str2bool
from .target import Target


class SSHTarget(Target):
"""TODO"""

DEFAULTS = {
**Target.DEFAULTS,
"hostname": None,
"port": 22,
"username": None,
"password": None,
"ignore_known_hosts": True,
"workdir": None,
}

@property
def hostname(self):
value = self.config["hostname"]
assert value is not None, "hostname not defined"
return value

@property
def port(self):
value = self.config["port"]
if isinstance(value, str):
value = int(value)
assert isinstance(value, int)
return value

@property
def username(self):
value = self.config["username"]
return value

@property
def password(self):
value = self.config["password"]
return value

@property
def ignore_known_hosts(self):
value = self.config["ignore_known_hosts"]
return str2bool(value) if not isinstance(value, (bool, int)) else value

@property
def workdir(self):
value = self.config["workdir"]
if value is not None:
if isinstance(value, str):
value = Path(value)
assert isinstance(value, Path)
return value

def __repr__(self):
return f"SSHTarget({self.name})"

def check_remote(self):
ssh = paramiko.SSHClient()
try:
ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password)
# TODO: key_filename=key_file)
return True
except (BadHostKeyException, AuthenticationException,
SSHException, socket.error) as e:
print(e) # TODO: remove
return False
raise NotImplementedError

def create_remote_directory(self, path):
ssh = paramiko.SSHClient()
# ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts")))
if self.ignore_known_hosts:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password)
command = f"mkdir -p {path}"
stdin, stdout, stderr = ssh.exec_command(command)
ssh.close()

def copy_to_remote(self, src, dest):
ssh = paramiko.SSHClient()
# ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts")))
if self.ignore_known_hosts:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password)
sftp = ssh.open_sftp()
sftp.put(str(src), str(dest))
sftp.close()
ssh.close()

def copy_from_remote(self, src, dest):
ssh = paramiko.SSHClient()
# ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts")))
if self.ignore_known_hosts:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password)
sftp = ssh.open_sftp()
sftp.get(str(src), str(dest))
sftp.close()
ssh.close()

def parse_exit(self, out):
print("parse_exit (ssh_target)")
exit_code = super().parse_exit(out)
exit_match = re.search(r"SSH EXIT=(.*)", out)
if exit_match:
exit_code = int(exit_match.group(1))
return exit_code

def exec_via_ssh(self, program: Path, *args, cwd=os.getcwd(), **kwargs):
# TODO: keep connection established!
self.check_remote()
if self.workdir is None:
raise NotImplementedError("temp workdir")
else:
self.create_remote_directory(self.workdir)
workdir = self.workdir
remote_program = workdir / program.name
self.copy_to_remote(program, remote_program)
ssh = paramiko.SSHClient()
# ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts")))
if self.ignore_known_hosts:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(self.hostname, port=self.port, username=self.username, password=self.password)
args_str = " ".join(args)
command = f"cd {workdir} && chmod +x {remote_program} && {remote_program} {args_str}; echo SSH EXIT=$?"
stdin, stdout, stderr = ssh.exec_command(command)
# print("stdin", stdin)
# print("stdout", stdout)
# print("stderr", stderr)
output = stderr.read().strip() + stdout.read().strip()
output = output.decode()
if self.print_outputs:
print("output", output) # TODO: cleanup
ssh.close()
return output

# TODO: logger

0 comments on commit 3bb3b08

Please sign in to comment.