This repository was archived by the owner on Feb 3, 2021. It is now read-only.

Feature: pure python ssh #577

Merged: 18 commits, Jun 5, 2018
17 changes: 17 additions & 0 deletions aztk/client.py
@@ -271,6 +271,23 @@ def __cluster_copy(self, cluster_id, source_path, destination_path, container_na
        finally:
            self.__delete_user_on_pool('aztk', pool.id, nodes)

    def __ssh_into_node(self, pool_id, node_id, username, ssh_key=None, password=None, port_forward_list=None, internal=False):
        if internal:
            result = self.batch_client.compute_node.get(pool_id=pool_id, node_id=node_id)
            rls = models.RemoteLogin(ip_address=result.ip_address, port="22")
        else:
            result = self.batch_client.compute_node.get_remote_login_settings(pool_id, node_id)
            rls = models.RemoteLogin(ip_address=result.remote_login_ip_address, port=str(result.remote_login_port))

        ssh_lib.node_ssh(
            username=username,
            hostname=rls.ip_address,
            port=rls.port,
            ssh_key=ssh_key,
            password=password,
            port_forward_list=port_forward_list,
        )

    def __submit_job(self,
                     job_configuration,
                     start_task,
4 changes: 4 additions & 0 deletions aztk/models/models.py
@@ -160,6 +160,10 @@ def __init__(self, ip_address, port):
        self.ip_address = ip_address
        self.port = port

class PortForwardingSpecification:
    def __init__(self, remote_port, local_port):
        self.remote_port = remote_port
        self.local_port = local_port

class ServicePrincipalConfiguration(ConfigurationBase):
    """
6 changes: 6 additions & 0 deletions aztk/spark/client.py
@@ -204,6 +204,12 @@ def cluster_download(self, cluster_id: str, source_path: str, destination_path:
        except batch_error.BatchErrorException as e:
            raise error.AztkError(helpers.format_batch_exception(e))

    def cluster_ssh_into_master(self, cluster_id, node_id, username, ssh_key=None, password=None, port_forward_list=None, internal=False):
        try:
            self.__ssh_into_node(cluster_id, node_id, username, ssh_key, password, port_forward_list, internal)
        except batch_error.BatchErrorException as e:
            raise error.AztkError(helpers.format_batch_exception(e))

    '''
    job submission
    '''
3 changes: 3 additions & 0 deletions aztk/spark/models/models.py
@@ -48,6 +48,8 @@ def __get_master_node_id(self):
class RemoteLogin(aztk.models.RemoteLogin):
    pass

class PortForwardingSpecification(aztk.models.PortForwardingSpecification):
    pass

class File(aztk.models.File):
    pass
@@ -91,6 +93,7 @@ class SharedKeyConfiguration(aztk.models.SharedKeyConfiguration):
class DockerConfiguration(aztk.models.DockerConfiguration):
    pass


class PluginConfiguration(aztk.models.PluginConfiguration):
    pass

99 changes: 98 additions & 1 deletion aztk/utils/ssh.py
@@ -8,10 +8,70 @@
import socket
import socketserver as SocketServer
import sys
import threading
from concurrent.futures import ThreadPoolExecutor

from aztk.error import AztkError
from . import helpers


g_verbose = False
Member: ?
timotheeguerin (Member), May 30, 2018: log.debug? and a --verbose flag?
Member Author: Yeah, that would be better.
Member Author: Actually, since this is in the SDK, we don't currently have a logger. Maybe that would be nice to have, though.
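
A minimal sketch of the logger-based approach discussed above; the logger name and setup are assumptions for illustration, not code from this PR:

import logging

# Hypothetical module-level logger; the aztk SDK does not currently ship one.
log = logging.getLogger("aztk.ssh")


def verbose(s):
    # Emit at DEBUG level instead of gating on a g_verbose global; a CLI
    # --verbose flag could then enable it, e.g. logging.basicConfig(level=logging.DEBUG).
    log.debug(s)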



def verbose(s):
    if g_verbose:
        print(s)


class ForwardServer(SocketServer.ThreadingTCPServer):
    daemon_threads = True
    allow_reuse_address = True

# pylint: disable=no-member
class Handler(SocketServer.BaseRequestHandler):
Member: Could you separate this into another file?
Member Author: I feel like this belongs here, since this is the ssh tunnel request handler and this file is meant to have the utilities necessary to make ssh connections and commands.

    def handle(self):
        try:
            channel = self.ssh_transport.open_channel('direct-tcpip',
                                                      (self.chain_host, self.chain_port),
                                                      self.request.getpeername())
        except Exception as e:
            verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
                                                              self.chain_port,
                                                              repr(e)))
            return
        if channel is None:
            verbose('Incoming request to %s:%d was rejected by the SSH server.' %
                    (self.chain_host, self.chain_port))
            return

        verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
                                                           channel.getpeername(), (self.chain_host, self.chain_port)))
        while True:
            r, w, x = select.select([self.request, channel], [], [])
            if self.request in r:
                data = self.request.recv(1024)
                if len(data) == 0:
                    break
                channel.send(data)
            if channel in r:
                data = channel.recv(1024)
                if len(data) == 0:
                    break
                self.request.send(data)

        peername = self.request.getpeername()
        channel.close()
        self.request.close()
        verbose('Tunnel closed from %r' % (peername,))


def forward_tunnel(local_port, remote_host, remote_port, transport):
    class SubHandler(Handler):
        chain_host = remote_host
        chain_port = remote_port
        ssh_transport = transport
    thread = threading.Thread(target=ForwardServer(('', local_port), SubHandler).serve_forever, daemon=True)
    thread.start()
    return thread


def connect(hostname,
@@ -39,6 +99,21 @@ def connect(hostname,
    return client


def forward_ports(client, port_forward_list):
    threads = []
    if port_forward_list:
Member: I would maybe do

    if not port_forward_list:
        return []

to have less nesting.

        for port_forwarding_specification in port_forward_list:
            threads.append(
                forward_tunnel(
                    port_forwarding_specification.remote_port,
                    "127.0.0.1",
                    port_forwarding_specification.local_port,
                    client.get_transport()
                )
            )
    return threads
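
For reference, a sketch of forward_ports with the reviewer's early-return suggestion applied (not the code as merged in this diff):

def forward_ports(client, port_forward_list):
    # Return early to avoid nesting the loop under an if, as suggested in review.
    if not port_forward_list:
        return []

    threads = []
    for port_forwarding_specification in port_forward_list:
        threads.append(
            forward_tunnel(
                port_forwarding_specification.remote_port,
                "127.0.0.1",
                port_forwarding_specification.local_port,
                client.get_transport()
            )
        )
    return threads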


def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None):
    try:
        client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
@@ -133,3 +208,25 @@ async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None
                container_name,
                timeout) for node, node_rls in nodes]
    )


def node_ssh(username, hostname, port, ssh_key=None, password=None, port_forward_list=None, timeout=None):
    try:
        client = connect(
            hostname=hostname,
            port=port,
            username=username,
            password=password,
            pkey=ssh_key,
            timeout=timeout
        )
        threads = forward_ports(client=client, port_forward_list=port_forward_list)
    except AztkError as e:
        raise e

    try:
        import time
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass
Member: Doesn't this prevent ctrl + d?
jafreck (Member Author), May 30, 2018: No, it will just intercept it so the stacktrace isn't printed.
Member: Can you add a comment about it so we don't forget why it's needed?
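
One way the requested comment could read, shown here on a standalone version of the keep-alive loop (the helper name is hypothetical and the wording is a suggestion, not text from the PR):

import time


def wait_until_interrupted():
    # Block until the user interrupts so the SSH connection and forwarded
    # ports stay open. KeyboardInterrupt (ctrl + c) is caught only to end
    # the session cleanly without printing a stack trace; it does not block ctrl + d.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass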

83 changes: 60 additions & 23 deletions aztk_cli/spark/endpoints/cluster/cluster_ssh.py
@@ -8,13 +8,16 @@
from aztk_cli import config, log, utils
from aztk_cli.config import SshConfig

from aztk.spark.models import PortForwardingSpecification


def setup_parser(parser: argparse.ArgumentParser):
    parser.add_argument('--id', dest="cluster_id", help='The unique id of your spark cluster')
    parser.add_argument('--webui', help='Local port to port spark\'s master UI to')
    parser.add_argument('--jobui', help='Local port to port spark\'s job UI to')
    parser.add_argument('--jobhistoryui', help='Local port to port spark\'s job history UI to')
    parser.add_argument('-u', '--username', help='Username to spark cluster')
    parser.add_argument('--password', help='Password for the specified ssh user')
    parser.add_argument('--host', dest="host", action='store_true', help='Connect to the host of the Spark container')
    parser.add_argument('--no-connect', dest="connect", action='store_false',
                        help='Do not create the ssh session. Only print out the command to run.')
@@ -52,29 +55,11 @@ def execute(args: typing.NamedTuple):
    utils.log_property("connect", ssh_conf.connect)
    log.info("-------------------------------------------")

-    # get ssh command
     try:
-        ssh_cmd = utils.ssh_in_master(
-            client=spark_client,
-            cluster_id=ssh_conf.cluster_id,
-            webui=ssh_conf.web_ui_port,
-            jobui=ssh_conf.job_ui_port,
-            jobhistoryui=ssh_conf.job_history_ui_port,
-            username=ssh_conf.username,
-            host=ssh_conf.host,
-            connect=ssh_conf.connect,
-            internal=ssh_conf.internal)
-
-        if not ssh_conf.connect:
-            log.info("")
-            log.info("Use the following command to connect to your spark head node:")
-            log.info("\t%s", ssh_cmd)
-
-    except batch_error.BatchErrorException as e:
-        if e.error.code == "PoolNotFound":
-            raise aztk.error.AztkError("The cluster you are trying to connect to does not exist.")
-        else:
-            raise
+        shell_out_ssh(spark_client, ssh_conf)
+    except OSError:
+        # no ssh client is found, falling back to pure python
+        native_python_ssh_into_master(spark_client, cluster, ssh_conf, args.password)


def print_plugin_ports(cluster_config: ClusterConfiguration):
@@ -88,7 +73,7 @@ def print_plugin_ports(cluster_config: ClusterConfiguration):
                if port.expose_publicly:
                    has_ports = True
                    plugin_ports[plugin.name].append(port)

        if has_ports:
            log.info("plugins:")

@@ -101,3 +86,55 @@
                label += " {}".format(port.name)
                url = "{0}{1}".format(http_prefix, port.public_port)
                utils.log_property(label, url)


def native_python_ssh_into_master(spark_client, cluster, ssh_conf, password):
    configuration = spark_client.get_cluster_config(cluster.id)
    plugin_ports = []
    if configuration and configuration.plugins:
        ports = [
            PortForwardingSpecification(
                port.internal,
                port.public_port) for plugin in configuration.plugins for port in plugin.ports if port.expose_publicly
        ]
        plugin_ports.extend(ports)

    print("Press ctrl+c to exit...")
Member: Shouldn't that be ctrl + d? Or not, since we are not giving shell access?
Member Author: We aren't giving shell access; the process just hangs so the ports stay open until killed (with ctrl + c).

    spark_client.cluster_ssh_into_master(
        cluster.id,
        cluster.master_node_id,
        ssh_conf.username,
        ssh_key=None,
        password=password,
        port_forward_list=[
            PortForwardingSpecification(remote_port=8080, local_port=8080),  # web ui
            PortForwardingSpecification(remote_port=4040, local_port=4040),  # job ui
            PortForwardingSpecification(remote_port=18080, local_port=18080),  # job history ui
        ] + plugin_ports,
        internal=ssh_conf.internal
    )


def shell_out_ssh(spark_client, ssh_conf):
    try:
        ssh_cmd = utils.ssh_in_master(
            client=spark_client,
            cluster_id=ssh_conf.cluster_id,
            webui=ssh_conf.web_ui_port,
            jobui=ssh_conf.job_ui_port,
            jobhistoryui=ssh_conf.job_history_ui_port,
            username=ssh_conf.username,
            host=ssh_conf.host,
            connect=ssh_conf.connect,
            internal=ssh_conf.internal)

        if not ssh_conf.connect:
            log.info("")
            log.info("Use the following command to connect to your spark head node:")
            log.info("\t%s", ssh_cmd)

    except batch_error.BatchErrorException as e:
        if e.error.code == "PoolNotFound":
Member: I think we should have an enum (or const) for those error codes.
Member Author: Agreed, but I think that should probably come in a single PR for the entire repo.

            raise aztk.error.AztkError("The cluster you are trying to connect to does not exist.")
        else:
            raise
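
A sketch of the constants the reviewer suggests for Batch error codes; the module path and names are hypothetical, not part of this PR:

# e.g. aztk/utils/batch_error_codes.py (hypothetical)
POOL_NOT_FOUND = "PoolNotFound"


def is_pool_not_found(batch_exception):
    # Compare against the shared constant instead of repeating the string literal.
    return batch_exception.error.code == POOL_NOT_FOUND
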
9 changes: 8 additions & 1 deletion aztk_cli/utils.py
@@ -1,15 +1,19 @@
import datetime
import getpass
import subprocess
import sys
import threading
import time
from subprocess import call
from typing import List

import azure.batch.models as batch_models

from aztk import error, utils
from aztk.utils import get_ssh_key, helpers
from aztk.models import ClusterConfiguration
from aztk.spark import models
from aztk.utils import get_ssh_key, helpers

from . import log


@@ -152,6 +156,8 @@ def ssh_in_master(
    :param ports: a list of local and remote ports
    :type ports: [[<local-port>, <remote-port>]]
    """
    # check if ssh is available, this throws OSError if ssh is not present
    subprocess.call(["ssh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Get master node id from task (job and task are both named pool_id)
    cluster = client.get_cluster(cluster_id)
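
The OSError fallback in cluster_ssh.py depends on the check added above: subprocess raises FileNotFoundError (a subclass of OSError) when no ssh binary is on the PATH. A minimal illustration of that pattern, with a hypothetical helper name:

import subprocess


def has_ssh_client():
    # subprocess raises FileNotFoundError (an OSError) if "ssh" is not installed.
    try:
        subprocess.call(["ssh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except OSError:
        return False
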
@@ -212,6 +218,7 @@

    if connect:
        call(command, shell=True)

    return '\n\t{}\n'.format(command)

def print_batch_exception(batch_exception):