Feature: pure python ssh #577
Changes from 9 commits
@@ -8,10 +8,70 @@
import socket
import socketserver as SocketServer
import sys
import threading
from concurrent.futures import ThreadPoolExecutor

from aztk.error import AztkError
from . import helpers


g_verbose = False


def verbose(s):
    if g_verbose:
        print(s)


class ForwardServer(SocketServer.ThreadingTCPServer):
    daemon_threads = True
    allow_reuse_address = True


# pylint: disable=no-member
class Handler(SocketServer.BaseRequestHandler):
Comment: Could you separate this into another file?
Reply: I feel like this belongs here, since this is the ssh tunnel request handler and this file is meant to have the utilities necessary to make ssh connections and commands.
    def handle(self):
        try:
            channel = self.ssh_transport.open_channel('direct-tcpip',
                                                      (self.chain_host, self.chain_port),
                                                      self.request.getpeername())
        except Exception as e:
            verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
                                                              self.chain_port,
                                                              repr(e)))
            return
        if channel is None:
            verbose('Incoming request to %s:%d was rejected by the SSH server.' %
                    (self.chain_host, self.chain_port))
            return

        verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
                                                           channel.getpeername(),
                                                           (self.chain_host, self.chain_port)))
        # Relay bytes between the local socket and the SSH channel until
        # either side closes its end.
        while True:
            r, w, x = select.select([self.request, channel], [], [])
            if self.request in r:
                data = self.request.recv(1024)
                if len(data) == 0:
                    break
                channel.send(data)
            if channel in r:
                data = channel.recv(1024)
                if len(data) == 0:
                    break
                self.request.send(data)

        peername = self.request.getpeername()
        channel.close()
        self.request.close()
        verbose('Tunnel closed from %r' % (peername,))


def forward_tunnel(local_port, remote_host, remote_port, transport):
    class SubHandler(Handler):
        chain_host = remote_host
        chain_port = remote_port
        ssh_transport = transport
    thread = threading.Thread(target=ForwardServer(('', local_port), SubHandler).serve_forever, daemon=True)
    thread.start()
    return thread
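As an aside on how these pieces compose, here is a hedged usage sketch; the host, credentials, and port choices are placeholders for illustration, not taken from this diff:

    import paramiko

    # Placeholder connection details -- assumptions, not from the PR.
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect("203.0.113.10", port=22, username="spark", password="secret")

    # Serve localhost:8080 and relay each connection through the SSH transport
    # to 127.0.0.1:8080 on the remote side, with no external ssh binary needed.
    tunnel_thread = forward_tunnel(8080, "127.0.0.1", 8080, client.get_transport())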


def connect(hostname,

@@ -39,6 +99,21 @@ def connect(hostname,
    return client


def forward_ports(client, port_forward_list):
    threads = []
    if port_forward_list:
Comment: I would maybe do [...] to have less nesting. (The suggested snippet is missing here; presumably an early-return guard, as sketched below.)
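A guess at the reviewer's suggestion, assuming the usual early-return idiom (this is not the reviewer's actual snippet):

    def forward_ports(client, port_forward_list):
        threads = []
        if not port_forward_list:
            return threads

        for port_forwarding_specification in port_forward_list:
            threads.append(
                forward_tunnel(
                    port_forwarding_specification.remote_port,
                    "127.0.0.1",
                    port_forwarding_specification.local_port,
                    client.get_transport()
                )
            )
        return threads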
        for port_forwarding_specification in port_forward_list:
            threads.append(
                forward_tunnel(
                    port_forwarding_specification.remote_port,
                    "127.0.0.1",
                    port_forwarding_specification.local_port,
                    client.get_transport()
                )
            )
    return threads


def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None):
    try:
        client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
@@ -133,3 +208,25 @@ async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None
                          container_name,
                          timeout) for node, node_rls in nodes]
    )


def node_ssh(username, hostname, port, ssh_key=None, password=None, port_forward_list=None, timeout=None):
    try:
        client = connect(
            hostname=hostname,
            port=port,
            username=username,
            password=password,
            pkey=ssh_key,
            timeout=timeout
        )
        threads = forward_ports(client=client, port_forward_list=port_forward_list)
    except AztkError as e:
        raise e

    try:
        import time
        # Block forever to keep the tunnels open; intercepting
        # KeyboardInterrupt lets ctrl+c exit without printing a stack trace.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass
Comment: doesn't this prevent ctrl + d?
Reply: no, it will just intercept it so the stacktrace isn't printed.
Comment: Can you add a comment about it so we don't forget why it's needed.
@@ -8,13 +8,16 @@
from aztk_cli import config, log, utils
from aztk_cli.config import SshConfig

from aztk.spark.models import PortForwardingSpecification


def setup_parser(parser: argparse.ArgumentParser):
    parser.add_argument('--id', dest="cluster_id", help='The unique id of your spark cluster')
    parser.add_argument('--webui', help='Local port to port spark\'s master UI to')
    parser.add_argument('--jobui', help='Local port to port spark\'s job UI to')
    parser.add_argument('--jobhistoryui', help='Local port to port spark\'s job history UI to')
    parser.add_argument('-u', '--username', help='Username to spark cluster')
    parser.add_argument('--password', help='Password for the specified ssh user')
    parser.add_argument('--host', dest="host", action='store_true', help='Connect to the host of the Spark container')
    parser.add_argument('--no-connect', dest="connect", action='store_false',
                        help='Do not create the ssh session. Only print out the command to run.')
@@ -52,29 +55,11 @@ def execute(args: typing.NamedTuple):
     utils.log_property("connect", ssh_conf.connect)
     log.info("-------------------------------------------")

-    # get ssh command
-    try:
-        ssh_cmd = utils.ssh_in_master(
-            client=spark_client,
-            cluster_id=ssh_conf.cluster_id,
-            webui=ssh_conf.web_ui_port,
-            jobui=ssh_conf.job_ui_port,
-            jobhistoryui=ssh_conf.job_history_ui_port,
-            username=ssh_conf.username,
-            host=ssh_conf.host,
-            connect=ssh_conf.connect,
-            internal=ssh_conf.internal)
-
-        if not ssh_conf.connect:
-            log.info("")
-            log.info("Use the following command to connect to your spark head node:")
-            log.info("\t%s", ssh_cmd)
-
-    except batch_error.BatchErrorException as e:
-        if e.error.code == "PoolNotFound":
-            raise aztk.error.AztkError("The cluster you are trying to connect to does not exist.")
-        else:
-            raise
+    try:
+        shell_out_ssh(spark_client, ssh_conf)
+    except OSError:
+        # no ssh client is found, falling back to pure python
+        native_python_ssh_into_master(spark_client, cluster, ssh_conf, args.password)
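The OSError branch presumably fires when no ssh client binary is installed; a minimal sketch of that failure mode (hypothetical, not from this PR):

    import subprocess

    try:
        # Raises FileNotFoundError (a subclass of OSError) if ssh is not on PATH.
        subprocess.call(["ssh", "-V"])
    except OSError:
        print("no ssh client found; fall back to the pure-python implementation")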


def print_plugin_ports(cluster_config: ClusterConfiguration):

@@ -88,7 +73,7 @@ def print_plugin_ports(cluster_config: ClusterConfiguration):
                if port.expose_publicly:
                    has_ports = True
                    plugin_ports[plugin.name].append(port)

        if has_ports:
            log.info("plugins:")

@@ -101,3 +86,55 @@ def print_plugin_ports(cluster_config: ClusterConfiguration):
                    label += " {}".format(port.name)
                url = "{0}{1}".format(http_prefix, port.public_port)
                utils.log_property(label, url)
|
||
def native_python_ssh_into_master(spark_client, cluster, ssh_conf, password): | ||
configuration = spark_client.get_cluster_config(cluster.id) | ||
plugin_ports = [] | ||
if configuration and configuration.plugins: | ||
ports = [ | ||
PortForwardingSpecification( | ||
port.internal, | ||
port.public_port) for plugin in configuration.plugins for port in plugin.ports if port.expose_publicly | ||
] | ||
plugin_ports.extend(ports) | ||
|
||
print("Press ctrl+c to exit...") | ||
Comment: shouldn't that be ctrl + d? or not, as we are not giving shell access.
Reply: We aren't giving shell access; the process just hangs so the ports stay open until killed (with ctrl+c).
    spark_client.cluster_ssh_into_master(
        cluster.id,
        cluster.master_node_id,
        ssh_conf.username,
        ssh_key=None,
        password=password,
        port_forward_list=[
            PortForwardingSpecification(remote_port=8080, local_port=8080),  # web ui
            PortForwardingSpecification(remote_port=4040, local_port=4040),  # job ui
            PortForwardingSpecification(remote_port=18080, local_port=18080),  # job history ui
        ] + plugin_ports,
        internal=ssh_conf.internal
    )
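Once these tunnels are up, the forwarded UIs should answer on localhost. A hedged way to check (assumes the Spark master UI is actually listening on the cluster's port 8080):

    import urllib.request

    # Hits the local end of the 8080 tunnel; a 200 means the relay works.
    with urllib.request.urlopen("http://localhost:8080") as response:
        print(response.status)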


def shell_out_ssh(spark_client, ssh_conf):
    try:
        ssh_cmd = utils.ssh_in_master(
            client=spark_client,
            cluster_id=ssh_conf.cluster_id,
            webui=ssh_conf.web_ui_port,
            jobui=ssh_conf.job_ui_port,
            jobhistoryui=ssh_conf.job_history_ui_port,
            username=ssh_conf.username,
            host=ssh_conf.host,
            connect=ssh_conf.connect,
            internal=ssh_conf.internal)

        if not ssh_conf.connect:
            log.info("")
            log.info("Use the following command to connect to your spark head node:")
            log.info("\t%s", ssh_cmd)

    except batch_error.BatchErrorException as e:
        if e.error.code == "PoolNotFound":
Comment: I think we should have an enum (or const) for those error codes.
Reply: agreed, but I think that should probably come in a single PR for the entire repo.
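A hedged sketch of what the suggested constants might look like (the module and names are hypothetical, not part of this PR or the repo):

    # batch_error_codes.py (hypothetical)
    from enum import Enum

    class BatchErrorCode(str, Enum):
        POOL_NOT_FOUND = "PoolNotFound"

    # usage: if e.error.code == BatchErrorCode.POOL_NOT_FOUND: ...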
            raise aztk.error.AztkError("The cluster you are trying to connect to does not exist.")
        else:
            raise
Comment: ?
Comment: log.debug? and a --verbose flag?
Reply: yeah, that would be better.
Reply: actually, since this is in the SDK, we don't currently have a logger. Maybe that would be nice to have, though.
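For the logger idea floated above, a hedged sketch of what replacing the print-based verbose helper could look like (assumes a plain stdlib logger; per the thread, the SDK has none today):

    import logging

    # Module-level logger; a CLI --verbose flag could raise its level to DEBUG.
    logger = logging.getLogger("aztk.ssh")

    def verbose(s):
        logger.debug(s)

    # e.g. in the CLI entry point:
    # logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)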