Skip to content

Commit

Permalink
SSH extension (#1363)
Browse files Browse the repository at this point in the history
* Add SSH extension

* Add unit tests and break things down a bit

* Add ssh config command

* URL-encode modulus and exponent

* Make az ssh vm start interactive

* Finalize ssh commands

* Updated help docs

* Reorder imports

* Updated arguments

CLI arguments were changed from extra to argument with None defaults in
custom. Changed resource_group to resource_group_name argument for
auto-population from cli-core.

Removed Python 2 and added new Python 3s to setup.

* Fix lint error

* move ssh certificate part from core to extension

* follow CLI extension guide

* update CODEOWNERS

* update CODEOWNERS

* roll back to [email protected] for certificate file

* update --name to --vm-name

* fix sytle error

Co-authored-by: Ryan Rossiter <[email protected]>
Co-authored-by: Xiaojian Xu <[email protected]>
  • Loading branch information
3 people authored Aug 17, 2020
1 parent 16610a3 commit 92a2530
Show file tree
Hide file tree
Showing 18 changed files with 874 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@

/src/datashare/ @fengzhou-msft

/src/ssh/ @rlrossiter @danybeam @arrownj

/src/k8sconfiguration/ @NarayanThiru

/src/log-analytics-solution/ @zhoxing-ms
Expand All @@ -142,4 +144,4 @@

/src/attestation/ @YalinLi0312 @bim-msft

/src/guestconfig/ @gehuan
/src/guestconfig/ @gehuan
28 changes: 28 additions & 0 deletions src/ssh/azext_ssh/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from azure.cli.core import AzCommandsLoader
from azext_ssh._help import helps # pylint: disable=unused-import


class SshCommandsLoader(AzCommandsLoader):

def __init__(self, cli_ctx=None):
from azure.cli.core.commands import CliCommandType

ssh_custom = CliCommandType(operations_tmpl='azext_ssh.custom#{}')
super(SshCommandsLoader, self).__init__(cli_ctx=cli_ctx, custom_command_type=ssh_custom)

def load_command_table(self, args):
from azext_ssh.commands import load_command_table
load_command_table(self, args)
return self.command_table

def load_arguments(self, command):
from azext_ssh._params import load_arguments
load_arguments(self, command)


COMMAND_LOADER_CLS = SshCommandsLoader
35 changes: 35 additions & 0 deletions src/ssh/azext_ssh/_help.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from knack.help_files import helps

helps['ssh'] = """
type: group
short-summary: SSH into Azure VMs
"""

helps['ssh vm'] = """
type: command
short-summary: SSH into Azure VMs
examples:
- name: Give a resource group and VM to SSH to
text: |
az ssh vm --resource-group myResourceGroup --vm-name myVm
- name: Give the public IP of a VM to SSH to
text: |
az ssh vm --ip 1.2.3.4
"""

helps['ssh config'] = """
type: command
short-summary: Create an SSH config for Azure VMs which can then be imported to 3rd party SSH clients
examples:
- name: Give a resource group and VM for which to create a config, and save in a local file
text: |
az ssh config --resource-group myResourceGroup --vm-name myVm --file ./sshconfig
- name: Give the public IP of a VM for which to create a config
text: |
az ssh config --ip 1.2.3.4 --file ./sshconfig
"""
20 changes: 20 additions & 0 deletions src/ssh/azext_ssh/_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


def load_arguments(self, _):

with self.argument_context('ssh vm') as c:
c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM')
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM')
c.argument('public_key_file', help='The RSA public key file path')
c.argument('private_key_file', help='The RSA private key file path')

with self.argument_context('ssh config') as c:
c.argument('config_path', options_list=['--file'], help='The file path to write the SSH config to')
c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM')
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM')
c.argument('public_key_file', help='The RSA public key file path')
c.argument('private_key_file', help='The RSA private key file path')
4 changes: 4 additions & 0 deletions src/ssh/azext_ssh/azext_metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"azext.isPreview": true,
"azext.minCliCoreVersion": "2.4.0"
}
11 changes: 11 additions & 0 deletions src/ssh/azext_ssh/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


def load_command_table(self, _):

with self.command_group('ssh') as g:
g.custom_command('vm', 'ssh_vm')
g.custom_command('config', 'ssh_config')
115 changes: 115 additions & 0 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import functools
import os
import hashlib
import json

from knack import util

from . import ip_utils
from . import rsa_parser
from . import ssh_utils


def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None):
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip,
public_key_file, private_key_file, ssh_utils.start_ssh_connection)


def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None,
public_key_file=None, private_key_file=None):
op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name)
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call)


def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call):
_assert_args(resource_group, vm_name, ssh_ip)
public_key_file, private_key_file = _check_public_private_files(public_key_file, private_key_file)
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name)

if not ssh_ip:
raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to")

scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]
data = _prepare_jwk_data(public_key_file)
from azure.cli.core._profile import Profile
profile = Profile(cli_ctx=cmd.cli_ctx)
username, certificate = profile.get_msal_token(scopes, data)

cert_file = _write_cert_file(public_key_file, certificate)
op_call(ssh_ip, username, cert_file, private_key_file)


def _prepare_jwk_data(public_key_file):
modulus, exponent = _get_modulus_exponent(public_key_file)
key_hash = hashlib.sha256()
key_hash.update(modulus.encode('utf-8'))
key_hash.update(exponent.encode('utf-8'))
key_id = key_hash.hexdigest()
jwk = {
"kty": "RSA",
"n": modulus,
"e": exponent,
"kid": key_id
}
json_jwk = json.dumps(jwk)
data = {
"token_type": "ssh-cert",
"req_cnf": json_jwk,
"key_id": key_id
}
return data


def _assert_args(resource_group, vm_name, ssh_ip):
if not (resource_group or vm_name or ssh_ip):
raise util.CLIError("The VM must be specified by --ip or --resource-group and --name")

if resource_group and not vm_name or vm_name and not resource_group:
raise util.CLIError("--resource-group and --name must be provided together")

if ssh_ip and (vm_name or resource_group):
raise util.CLIError("--ip cannot be used with --resource-group or --name")


def _check_public_private_files(public_key_file, private_key_file):
ssh_dir_parts = ["~", ".ssh"]
public_key_file = public_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa.pub"))
private_key_file = private_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa"))

if not os.path.isfile(public_key_file):
raise util.CLIError(f"Pulic key file {public_key_file} not found")
if not os.path.isfile(private_key_file):
raise util.CLIError(f"Private key file {private_key_file} not found")

return public_key_file, private_key_file


def _write_cert_file(public_key_file, certificate_contents):
cert_file = os.path.join(*os.path.split(public_key_file)[:-1], "id_rsa-cert.pub")
with open(cert_file, 'w') as f:
f.write(f"[email protected] {certificate_contents}")

return cert_file


def _get_modulus_exponent(public_key_file):
if not os.path.isfile(public_key_file):
raise util.CLIError(f"Public key file '{public_key_file}' was not found")

with open(public_key_file, 'r') as f:
public_key_text = f.read()

parser = rsa_parser.RSAParser()
try:
parser.parse(public_key_text)
except Exception as e:
raise util.CLIError(f"Could not parse public key. Error: {str(e)}")
modulus = parser.modulus
exponent = parser.exponent

return modulus, exponent
22 changes: 22 additions & 0 deletions src/ssh/azext_ssh/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import errno
import os


def make_dirs_for_file(file_path):
if not os.path.exists(file_path):
mkdir_p(os.path.dirname(file_path))


def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc: # Python <= 2.5
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
36 changes: 36 additions & 0 deletions src/ssh/azext_ssh/ip_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from azure.cli.core.commands import client_factory
from azure.cli.core import profiles
from msrestazure import tools


def get_ssh_ip(cmd, resource_group, vm_name):
compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE)
network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK)
vm_client = compute_client.virtual_machines
nic_client = network_client.network_interfaces
ip_client = network_client.public_ip_addresses

vm = vm_client.get(resource_group, vm_name)
nics = vm.network_profile.network_interfaces
ssh_ip = None
for nic_ref in nics:
parsed_id = tools.parse_resource_id(nic_ref.id)
nic = nic_client.get(parsed_id['resource_group'], parsed_id['name'])
ip_configs = nic.ip_configurations
for ip_config in ip_configs:
public_ip_ref = ip_config.public_ip_address
parsed_ip_id = tools.parse_resource_id(public_ip_ref.id)
public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name'])
ssh_ip = public_ip.ip_address

if ssh_ip:
break

if ssh_ip:
break
return ssh_ip
60 changes: 60 additions & 0 deletions src/ssh/azext_ssh/rsa_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import base64
import struct


class RSAParser(object):
# pylint: disable=too-few-public-methods
RSAAlgorithm = 'ssh-rsa'

def __init__(self):
self.algorithm = ''
self.modulus = ''
self.exponent = ''
self._key_length_big_endian = True

def parse(self, public_key_text):
text_parts = public_key_text.split(' ')

if len(text_parts) < 2:
error_str = ("Incorrectly formatted public key. "
"Key must be format '<algorithm> <base64_key>'")
raise ValueError(error_str)

algorithm = text_parts[0]
if algorithm != RSAParser.RSAAlgorithm:
raise ValueError(f"Public key is not ssh-rsa algorithm ({algorithm})")

b64_string = text_parts[1]
key_bytes = base64.b64decode(b64_string)
fields = list(self._get_fields(key_bytes))

if len(fields) < 3:
error_str = ("Incorrectly encoded public key. "
"Encoded key must be base64 encoded <algorithm><exponent><modulus>")
raise ValueError(error_str)

encoded_algorithm = fields[0].decode("ascii")
if encoded_algorithm != RSAParser.RSAAlgorithm:
raise ValueError(f"Encoded public key is not ssh-rsa algorithm ({encoded_algorithm})")

self.algorithm = encoded_algorithm
self.exponent = base64.urlsafe_b64encode(fields[1]).decode("ascii")
self.modulus = base64.urlsafe_b64encode(fields[2]).decode("ascii")

def _get_fields(self, key_bytes):
read = 0
while read < len(key_bytes):
length = struct.unpack(self._get_struct_format(), key_bytes[read:read + 4])[0]
read = read + 4
data = key_bytes[read:read + length]
read = read + length
yield data

def _get_struct_format(self):
format_start = ">" if self._key_length_big_endian else "<"
return format_start + "L"
Loading

0 comments on commit 92a2530

Please sign in to comment.