Skip to content
This repository was archived by the owner on Feb 3, 2021. It is now read-only.

Commit

Permalink
Feature: add node run command (#572)
Browse files Browse the repository at this point in the history
* add node run command

* whitespace

* add node-run doc

* add host flag

* refactor, print->log

* generated username

* more secure random

* better handling of find node, type conversion

* add generate_user_on_node

* docs update

* fix docs

* remove duplicate import, sort
  • Loading branch information
jafreck authored Jun 4, 2018
1 parent b9a863b commit af449dc
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 34 deletions.
71 changes: 56 additions & 15 deletions aztk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

import azure.batch.models as batch_models
import azure.batch.models.batch_error as batch_error
from Cryptodome.PublicKey import RSA

import aztk.error as error
import aztk.models as models
import aztk.utils.azure_api as azure_api
import aztk.utils.constants as constants
import aztk.utils.get_ssh_key as get_ssh_key
import aztk.utils.helpers as helpers
import aztk.utils.ssh as ssh_lib
import aztk.models as models
import azure.batch.models as batch_models
from azure.batch.models import batch_error
from Cryptodome.PublicKey import RSA
from aztk.internal import cluster_data
from aztk.utils import secure_utils


class Client:
def __init__(self, secrets_config: models.SecretsConfiguration):
Expand Down Expand Up @@ -205,17 +207,26 @@ def __create_user_on_node(self, username, pool_id, node_id, ssh_key=None, passwo
except batch_error.BatchErrorException as error:
raise error

def __generate_user_on_pool(self, username, pool_id, nodes):
def __generate_user_on_node(self, pool_id, node_id):
    """Provision a throwaway user with a fresh 2048-bit RSA key pair on one node.

    Returns a (username, RSA key object) tuple so the caller can SSH in with
    the private key and delete the user when finished.
    """
    username = secure_utils.generate_random_string()
    key_pair = RSA.generate(2048)
    public_key = key_pair.publickey().exportKey('OpenSSH').decode('utf-8')
    self.__create_user_on_node(username, pool_id, node_id, public_key)
    return username, key_pair

def __generate_user_on_pool(self, pool_id, nodes):
    """Provision the same throwaway user (random name + fresh RSA key pair)
    on every node of the pool, creating the users in parallel.

    Args:
        pool_id: id of the Batch pool the nodes belong to.
        nodes: iterable of node objects exposing an ``id`` attribute.

    Returns:
        (username, RSA key object) — the caller is responsible for deleting
        the user (see __delete_user_on_pool) when done.
    """
    # NOTE: this span contained diff residue (a stale `username,` argument
    # and a dead duplicate `return ssh_key`); this is the coherent version.
    generated_username = secure_utils.generate_random_string()
    ssh_key = RSA.generate(2048)
    ssh_pub_key = ssh_key.publickey().exportKey('OpenSSH').decode('utf-8')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(self.__create_user_on_node,
                                   generated_username,
                                   pool_id,
                                   node.id,
                                   ssh_pub_key): node for node in nodes}
        concurrent.futures.wait(futures)

    return generated_username, ssh_key

def __create_user_on_pool(self, username, pool_id, nodes, ssh_pub_key=None, password=None):
with concurrent.futures.ThreadPoolExecutor() as executor:
Expand All @@ -232,19 +243,48 @@ def __delete_user_on_pool(self, username, pool_id, nodes):
futures = [exector.submit(self.__delete_user, pool_id, node.id, username) for node in nodes]
concurrent.futures.wait(futures)

def __node_run(self, cluster_id, node_id, command, internal, container_name=None, timeout=None):
    """Run `command` on a single node of the cluster over SSH as a
    temporary, randomly named user.

    Args:
        cluster_id: id of the cluster (backing Batch pool).
        node_id: id of the target node within the cluster.
        command: shell command to execute.
        internal: if True, connect to the node's internal IP on port 22
            (VPN scenario) instead of the public remote-login settings.
        container_name: optional Docker container to exec inside.
        timeout: optional command timeout.

    Returns:
        The output of ssh_lib.node_exec_command.

    Raises:
        error.AztkError: if no node with `node_id` exists in the cluster.
    """
    pool, nodes = self.__get_pool_details(cluster_id)
    try:
        node = next(node for node in nodes if node.id == node_id)
    except StopIteration:
        raise error.AztkError("Node with id {} not found".format(node_id))

    if internal:
        node_rls = models.RemoteLogin(ip_address=node.ip_address, port="22")
    else:
        node_rls = self.__get_remote_login_settings(pool.id, node.id)

    # Create the temporary user BEFORE entering the try block: the original
    # bound generated_username inside try, so a failure during user creation
    # made the finally clause raise UnboundLocalError, masking the real error.
    generated_username, ssh_key = self.__generate_user_on_node(pool.id, node.id)
    try:
        output = ssh_lib.node_exec_command(
            node.id,
            command,
            generated_username,
            node_rls.ip_address,
            node_rls.port,
            ssh_key=ssh_key.exportKey().decode('utf-8'),
            container_name=container_name,
            timeout=timeout
        )
        return output
    finally:
        # Use pool.id for symmetry with the creation call above
        # (pool was looked up from cluster_id, so the ids are equivalent).
        self.__delete_user(pool.id, node.id, generated_username)

def __cluster_run(self, cluster_id, command, internal, container_name=None, timeout=None):
pool, nodes = self.__get_pool_details(cluster_id)
nodes = [node for node in nodes]
nodes = list(nodes)
if internal:
cluster_nodes = [(node, models.RemoteLogin(ip_address=node.ip_address, port="22")) for node in nodes]
else:
cluster_nodes = [(node, self.__get_remote_login_settings(pool.id, node.id)) for node in nodes]

try:
ssh_key = self.__generate_user_on_pool('aztk', pool.id, nodes)
generated_username, ssh_key = self.__generate_user_on_pool(pool.id, nodes)
output = asyncio.get_event_loop().run_until_complete(
ssh_lib.clus_exec_command(
command,
'aztk',
generated_username,
cluster_nodes,
ssh_key=ssh_key.exportKey().decode('utf-8'),
container_name=container_name,
Expand All @@ -255,21 +295,22 @@ def __cluster_run(self, cluster_id, command, internal, container_name=None, time
except OSError as exc:
raise exc
finally:
self.__delete_user_on_pool('aztk', pool.id, nodes)
self.__delete_user_on_pool(generated_username, pool.id, nodes)

def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False, timeout=None):
pool, nodes = self.__get_pool_details(cluster_id)
nodes = [node for node in nodes]
nodes = list(nodes)
if internal:
cluster_nodes = [(node, models.RemoteLogin(ip_address=node.ip_address, port="22")) for node in nodes]
else:
cluster_nodes = [(node, self.__get_remote_login_settings(pool.id, node.id)) for node in nodes]

try:
ssh_key = self.__generate_user_on_pool('aztk', pool.id, nodes)
generated_username, ssh_key = self.__generate_user_on_pool(pool.id, nodes)
output = asyncio.get_event_loop().run_until_complete(
ssh_lib.clus_copy(
container_name=container_name,
username='aztk',
username=generated_username,
nodes=cluster_nodes,
source_path=source_path,
destination_path=destination_path,
Expand All @@ -282,7 +323,7 @@ def __cluster_copy(self, cluster_id, source_path, destination_path, container_na
except (OSError, batch_error.BatchErrorException) as exc:
raise exc
finally:
self.__delete_user_on_pool('aztk', pool.id, nodes)
self.__delete_user_on_pool(generated_username, pool.id, nodes)

def __submit_job(self,
job_configuration,
Expand Down
22 changes: 18 additions & 4 deletions aztk/spark/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import List

import azure.batch.models.batch_error as batch_error

import aztk
from aztk import error
from aztk.client import Client as BaseClient
from aztk.internal.cluster_data import NodeData
from aztk.spark import models
from aztk.utils import helpers
from aztk.spark.helpers import create_cluster as create_cluster_helper
from aztk.spark.helpers import submit as cluster_submit_helper
from aztk.spark.helpers import job_submission as job_submit_helper
from aztk.spark.helpers import get_log as get_log_helper
from aztk.spark.helpers import job_submission as job_submit_helper
from aztk.spark.helpers import submit as cluster_submit_helper
from aztk.spark.helpers import cluster_diagnostic_helper
from aztk.spark.utils import util
from aztk.internal.cluster_data import NodeData
from aztk.utils import helpers


class Client(BaseClient):
"""
Expand Down Expand Up @@ -170,6 +173,17 @@ def cluster_run(self, cluster_id: str, command: str, host=False, internal: bool
except batch_error.BatchErrorException as e:
raise error.AztkError(helpers.format_batch_exception(e))

def node_run(self, cluster_id: str, node_id: str, command: str, host=False, internal: bool = False, timeout=None):
    """Run a shell command on one node of the cluster.

    By default the command runs inside the 'spark' Docker container; pass
    host=True to run it directly on the host VM. Batch errors are wrapped
    in AztkError.
    """
    container = None if host else 'spark'
    try:
        return self.__node_run(
            cluster_id,
            node_id,
            command,
            internal,
            container_name=container,
            timeout=timeout)
    except batch_error.BatchErrorException as e:
        raise error.AztkError(helpers.format_batch_exception(e))

def cluster_copy(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False, timeout=None):
try:
container_name = None if host else 'spark'
Expand Down
10 changes: 9 additions & 1 deletion aztk/utils/secure_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import random
import string

from Cryptodome.Cipher import AES, PKCS1_OAEP
from Cryptodome.PublicKey import RSA
from Cryptodome.Random import get_random_bytes
from Cryptodome.Cipher import AES, PKCS1_OAEP


def encrypt_password(ssh_pub_key, password):
if not password:
Expand All @@ -16,3 +20,7 @@ def encrypt_password(ssh_pub_key, password):
cipher_aes = AES.new(session_key, AES.MODE_EAX)
ciphertext, tag = cipher_aes.encrypt_and_digest(password.encode())
return [encrypted_aes_session_key, cipher_aes.nonce, tag, ciphertext]


def generate_random_string(charset=string.ascii_uppercase + string.ascii_lowercase, length=16):
    """Return a cryptographically secure random string.

    Args:
        charset: pool of characters to draw from (default: ASCII letters).
        length: number of characters to generate (default: 16).

    Returns:
        A string of `length` characters chosen uniformly from `charset`.
    """
    # Instantiate the OS-backed RNG once, not once per character as before.
    rng = random.SystemRandom()
    return ''.join(rng.choice(charset) for _ in range(length))
27 changes: 13 additions & 14 deletions aztk_cli/spark/endpoints/cluster/cluster_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@ def setup_parser(parser: argparse.ArgumentParser):
dest='cluster_id',
required=True,
help='The unique id of your spark cluster')
parser.add_argument('--node-id', '-n',
dest='node_id',
required=False,
help='The unique id of the node in the cluster to run the command on')
parser.add_argument('command',
help='The command to run on your spark cluster')
parser.add_argument('--internal', action='store_true',
help='Connect using the local IP of the master node. Only use if using a VPN.')
parser.set_defaults(internal=False)
help='Connect using the local IP of the master node. Only use if using a VPN')
parser.add_argument('--host', action='store_true',
help='Run the command on the host instead of the Spark Docker container')
parser.set_defaults(internal=False, host=False)


def execute(args: typing.NamedTuple):
spark_client = aztk.spark.Client(config.load_aztk_secrets())
with utils.Spinner():
results = spark_client.cluster_run(args.cluster_id, args.command, args.internal)
[print_execute_result(node_id, result) for node_id, result in results]
if args.node_id:
results = [spark_client.node_run(args.cluster_id, args.node_id, args.command, args.host, args.internal)]
else:
results = spark_client.cluster_run(args.cluster_id, args.command, args.host, args.internal)


def print_execute_result(node_id, result):
print("-" * (len(node_id) + 6))
print("| ", node_id, " |")
print("-" * (len(node_id) + 6))
if isinstance(result, Exception):
print(result + "\n")
else:
for line in result:
print(line)
[utils.log_execute_result(node_id, result) for node_id, result in results]
11 changes: 11 additions & 0 deletions aztk_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,14 @@ def print_cluster_conf(cluster_conf: ClusterConfiguration, wait: bool):
def log_property(label: str, value: str):
label += ":"
log.info("{0:30} {1}".format(label, value))


def log_execute_result(node_id, result):
    """Log a banner identifying the node, then emit the command's result.

    `result` is either an Exception (logged as a single entry) or an
    iterable of output lines.
    """
    banner = "-" * (len(node_id) + 4)
    log.info(banner)
    log.info("| %s |", node_id)
    log.info(banner)
    if isinstance(result, Exception):
        log.info("%s\n", result)
        return
    # NOTE(review): output lines go to stdout via print while the banner
    # uses the logger — presumably intentional (raw command output); confirm.
    for line in result:
        print(line)
8 changes: 8 additions & 0 deletions docs/10-clusters.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ aztk spark cluster run --id <your_cluster_id> "<command>"

The command is executed through an SSH tunnel.

### Run a command on a specific node in the cluster
To run a command on a single node in the cluster, run:
```sh
aztk spark cluster run --id <your_cluster_id> --node-id <your_node_id> "<command>"
```
This command is executed through an SSH tunnel.
To get the id of nodes in your cluster, run `aztk spark cluster get --id <your_cluster_id>`.

### Copy a file to all nodes in the cluster
To securely copy a file to all nodes, run:
```sh
Expand Down

0 comments on commit af449dc

Please sign in to comment.