diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c24d8c24255..4652dff59bb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -120,6 +120,8 @@ /src/datashare/ @fengzhou-msft +/src/ssh/ @rlrossiter @danybeam @arrownj + /src/k8sconfiguration/ @NarayanThiru /src/log-analytics-solution/ @zhoxing-ms @@ -142,4 +144,4 @@ /src/attestation/ @YalinLi0312 @bim-msft -/src/guestconfig/ @gehuan \ No newline at end of file +/src/guestconfig/ @gehuan diff --git a/src/ssh/azext_ssh/__init__.py b/src/ssh/azext_ssh/__init__.py new file mode 100644 index 00000000000..57cd43aa7a2 --- /dev/null +++ b/src/ssh/azext_ssh/__init__.py @@ -0,0 +1,28 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from azure.cli.core import AzCommandsLoader +from azext_ssh._help import helps # pylint: disable=unused-import + + +class SshCommandsLoader(AzCommandsLoader): + + def __init__(self, cli_ctx=None): + from azure.cli.core.commands import CliCommandType + + ssh_custom = CliCommandType(operations_tmpl='azext_ssh.custom#{}') + super(SshCommandsLoader, self).__init__(cli_ctx=cli_ctx, custom_command_type=ssh_custom) + + def load_command_table(self, args): + from azext_ssh.commands import load_command_table + load_command_table(self, args) + return self.command_table + + def load_arguments(self, command): + from azext_ssh._params import load_arguments + load_arguments(self, command) + + +COMMAND_LOADER_CLS = SshCommandsLoader diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py new file mode 100644 index 00000000000..c064b6d375e --- /dev/null +++ b/src/ssh/azext_ssh/_help.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from knack.help_files import helps + +helps['ssh'] = """ + type: group + short-summary: SSH into Azure VMs +""" + +helps['ssh vm'] = """ + type: command + short-summary: SSH into Azure VMs + examples: + - name: Give a resource group and VM to SSH to + text: | + az ssh vm --resource-group myResourceGroup --vm-name myVm + - name: Give the public IP of a VM to SSH to + text: | + az ssh vm --ip 1.2.3.4 +""" + +helps['ssh config'] = """ + type: command + short-summary: Create an SSH config for Azure VMs which can then be imported to 3rd party SSH clients + examples: + - name: Give a resource group and VM for which to create a config, and save in a local file + text: | + az ssh config --resource-group myResourceGroup --vm-name myVm --file ./sshconfig + - name: Give the public IP of a VM for which to create a config + text: | + az ssh config --ip 1.2.3.4 --file ./sshconfig +""" diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py new file mode 100644 index 00000000000..22324511773 --- /dev/null +++ b/src/ssh/azext_ssh/_params.py @@ -0,0 +1,20 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +def load_arguments(self, _): + + with self.argument_context('ssh vm') as c: + c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM') + c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM') + c.argument('public_key_file', help='The RSA public key file path') + c.argument('private_key_file', help='The RSA private key file path') + + with self.argument_context('ssh config') as c: + c.argument('config_path', options_list=['--file'], help='The file path to write the SSH config to') + c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM') + c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM') + c.argument('public_key_file', help='The RSA public key file path') + c.argument('private_key_file', help='The RSA private key file path') diff --git a/src/ssh/azext_ssh/azext_metadata.json b/src/ssh/azext_ssh/azext_metadata.json new file mode 100644 index 00000000000..6a44beb25b4 --- /dev/null +++ b/src/ssh/azext_ssh/azext_metadata.json @@ -0,0 +1,4 @@ +{ + "azext.isPreview": true, + "azext.minCliCoreVersion": "2.4.0" +} \ No newline at end of file diff --git a/src/ssh/azext_ssh/commands.py b/src/ssh/azext_ssh/commands.py new file mode 100644 index 00000000000..e8049935233 --- /dev/null +++ b/src/ssh/azext_ssh/commands.py @@ -0,0 +1,11 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +def load_command_table(self, _): + + with self.command_group('ssh') as g: + g.custom_command('vm', 'ssh_vm') + g.custom_command('config', 'ssh_config') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py new file mode 100644 index 00000000000..e0b09f6ddf5 --- /dev/null +++ b/src/ssh/azext_ssh/custom.py @@ -0,0 +1,115 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import functools +import os +import hashlib +import json + +from knack import util + +from . import ip_utils +from . import rsa_parser +from . import ssh_utils + + +def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None): + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, + public_key_file, private_key_file, ssh_utils.start_ssh_connection) + + +def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, + public_key_file=None, private_key_file=None): + op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name) + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call) + + +def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call): + _assert_args(resource_group, vm_name, ssh_ip) + public_key_file, private_key_file = _check_public_private_files(public_key_file, private_key_file) + ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name) + + if not ssh_ip: + raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to") + + scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"] + data = _prepare_jwk_data(public_key_file) + from azure.cli.core._profile import Profile + profile = Profile(cli_ctx=cmd.cli_ctx) + username, certificate = profile.get_msal_token(scopes, data) + + cert_file = _write_cert_file(public_key_file, certificate) + op_call(ssh_ip, username, cert_file, private_key_file) + + +def _prepare_jwk_data(public_key_file): + modulus, exponent = _get_modulus_exponent(public_key_file) + key_hash = hashlib.sha256() + key_hash.update(modulus.encode('utf-8')) + key_hash.update(exponent.encode('utf-8')) + key_id = key_hash.hexdigest() + jwk = { + "kty": "RSA", + "n": modulus, + "e": exponent, + "kid": key_id + } + json_jwk = json.dumps(jwk) + data = { + "token_type": "ssh-cert", + "req_cnf": json_jwk, + "key_id": key_id + } + return data + + +def _assert_args(resource_group, vm_name, ssh_ip): + if not (resource_group or vm_name or ssh_ip): + raise util.CLIError("The VM must be specified by --ip or --resource-group and --name") + + if resource_group and not vm_name or vm_name and not resource_group: + raise util.CLIError("--resource-group and --name must be provided together") + + if ssh_ip and (vm_name or resource_group): + raise util.CLIError("--ip cannot be used with --resource-group or --name") + + +def _check_public_private_files(public_key_file, private_key_file): + ssh_dir_parts = ["~", ".ssh"] + public_key_file = public_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa.pub")) + private_key_file = private_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa")) + + if not os.path.isfile(public_key_file): + raise util.CLIError(f"Pulic key file {public_key_file} not found") + if not os.path.isfile(private_key_file): + raise util.CLIError(f"Private key file {private_key_file} not found") + + return public_key_file, private_key_file + + +def _write_cert_file(public_key_file, certificate_contents): + cert_file = os.path.join(*os.path.split(public_key_file)[:-1], "id_rsa-cert.pub") + with open(cert_file, 'w') as f: + f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}") + + return cert_file + + +def _get_modulus_exponent(public_key_file): + if not os.path.isfile(public_key_file): + raise util.CLIError(f"Public key file '{public_key_file}' was not found") + + with open(public_key_file, 'r') as f: + public_key_text = f.read() + + parser = rsa_parser.RSAParser() + try: + parser.parse(public_key_text) + except Exception as e: + raise util.CLIError(f"Could not parse public key. Error: {str(e)}") + modulus = parser.modulus + exponent = parser.exponent + + return modulus, exponent diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py new file mode 100644 index 00000000000..3c3bd795c15 --- /dev/null +++ b/src/ssh/azext_ssh/file_utils.py @@ -0,0 +1,22 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import errno +import os + + +def make_dirs_for_file(file_path): + if not os.path.exists(file_path): + mkdir_p(os.path.dirname(file_path)) + + +def mkdir_p(path): + try: + os.makedirs(path) + except OSError as exc: # Python <= 2.5 + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise diff --git a/src/ssh/azext_ssh/ip_utils.py b/src/ssh/azext_ssh/ip_utils.py new file mode 100644 index 00000000000..c08f8bfc041 --- /dev/null +++ b/src/ssh/azext_ssh/ip_utils.py @@ -0,0 +1,36 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from azure.cli.core.commands import client_factory +from azure.cli.core import profiles +from msrestazure import tools + + +def get_ssh_ip(cmd, resource_group, vm_name): + compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE) + network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK) + vm_client = compute_client.virtual_machines + nic_client = network_client.network_interfaces + ip_client = network_client.public_ip_addresses + + vm = vm_client.get(resource_group, vm_name) + nics = vm.network_profile.network_interfaces + ssh_ip = None + for nic_ref in nics: + parsed_id = tools.parse_resource_id(nic_ref.id) + nic = nic_client.get(parsed_id['resource_group'], parsed_id['name']) + ip_configs = nic.ip_configurations + for ip_config in ip_configs: + public_ip_ref = ip_config.public_ip_address + parsed_ip_id = tools.parse_resource_id(public_ip_ref.id) + public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name']) + ssh_ip = public_ip.ip_address + + if ssh_ip: + break + + if ssh_ip: + break + return ssh_ip diff --git a/src/ssh/azext_ssh/rsa_parser.py b/src/ssh/azext_ssh/rsa_parser.py new file mode 100644 index 00000000000..68ad8619c59 --- /dev/null +++ b/src/ssh/azext_ssh/rsa_parser.py @@ -0,0 +1,60 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import base64 +import struct + + +class RSAParser(object): + # pylint: disable=too-few-public-methods + RSAAlgorithm = 'ssh-rsa' + + def __init__(self): + self.algorithm = '' + self.modulus = '' + self.exponent = '' + self._key_length_big_endian = True + + def parse(self, public_key_text): + text_parts = public_key_text.split(' ') + + if len(text_parts) < 2: + error_str = ("Incorrectly formatted public key. " + "Key must be format ' '") + raise ValueError(error_str) + + algorithm = text_parts[0] + if algorithm != RSAParser.RSAAlgorithm: + raise ValueError(f"Public key is not ssh-rsa algorithm ({algorithm})") + + b64_string = text_parts[1] + key_bytes = base64.b64decode(b64_string) + fields = list(self._get_fields(key_bytes)) + + if len(fields) < 3: + error_str = ("Incorrectly encoded public key. " + "Encoded key must be base64 encoded ") + raise ValueError(error_str) + + encoded_algorithm = fields[0].decode("ascii") + if encoded_algorithm != RSAParser.RSAAlgorithm: + raise ValueError(f"Encoded public key is not ssh-rsa algorithm ({encoded_algorithm})") + + self.algorithm = encoded_algorithm + self.exponent = base64.urlsafe_b64encode(fields[1]).decode("ascii") + self.modulus = base64.urlsafe_b64encode(fields[2]).decode("ascii") + + def _get_fields(self, key_bytes): + read = 0 + while read < len(key_bytes): + length = struct.unpack(self._get_struct_format(), key_bytes[read:read + 4])[0] + read = read + 4 + data = key_bytes[read:read + length] + read = read + length + yield data + + def _get_struct_format(self): + format_start = ">" if self._key_length_big_endian else "<" + return format_start + "L" diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py new file mode 100644 index 00000000000..974f47dfc00 --- /dev/null +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import os +import platform +import subprocess + +from knack import log +from knack import util + +from . import file_utils + +logger = log.get_logger(__name__) + + +def start_ssh_connection(ip, username, cert_file, private_key_file): + command = [_get_ssh_path(), _get_host(username, ip)] + command = command + _build_args(cert_file, private_key_file) + logger.debug("Running ssh command %s", ' '.join(command)) + subprocess.call(command, shell=True) + + +def write_ssh_config(config_path, resource_group, vm_name, + ip, username, cert_file, private_key_file): + file_utils.make_dirs_for_file(config_path) + lines = [] + if resource_group and vm_name: + lines.append("Host " + resource_group + "-" + vm_name) + lines.append("\tUser " + username) + lines.append("\tHostName " + ip) + lines.append("\tCertificateFile " + cert_file) + lines.append("\tIdentityFile " + private_key_file) + + lines.append("Host " + ip) + lines.append("\tUser " + username) + lines.append("\tHostName " + ip) + lines.append("\tCertificateFile " + cert_file) + lines.append("\tIdentityFile " + private_key_file) + + with open(config_path, 'w') as f: + f.write('\n'.join(lines)) + + +def _get_ssh_path(): + ssh_path = "ssh" + + if platform.system() == 'Windows': + arch_data = platform.architecture() + is_32bit = arch_data[0] == '32bit' + sys_path = 'SysNative' if is_32bit else 'System32' + system_root = os.environ['SystemRoot'] + system32_path = os.path.join(system_root, sys_path) + ssh_path = os.path.join(system32_path, "openSSH", "ssh.exe") + logger.debug("Platform architecture: %s", str(arch_data)) + logger.debug("System Root: %s", system_root) + logger.debug("Attempting to run ssh from path %s", ssh_path) + + if not os.path.isfile(ssh_path): + raise util.CLIError("Could not find ssh.exe. Is the OpenSSH client installed?") + + return ssh_path + + +def _get_host(username, ip): + return username + "@" + ip + + +def _build_args(cert_file, private_key_file): + private_key = ["-i", private_key_file] + certificate = ["-o", "CertificateFile=" + cert_file] + return private_key + certificate diff --git a/src/ssh/azext_ssh/tests/__init__.py b/src/ssh/azext_ssh/tests/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/ssh/azext_ssh/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/ssh/azext_ssh/tests/latest/__init__.py b/src/ssh/azext_ssh/tests/latest/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py new file mode 100644 index 00000000000..7580cf75d5d --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -0,0 +1,196 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import io +from knack import util +import mock +import unittest + +from azext_ssh import custom + + +class SshCustomCommandTest(unittest.TestCase): + @mock.patch('azext_ssh.custom._do_ssh_op') + @mock.patch('azext_ssh.custom.ssh_utils') + def test_ssh_vm(self, mock_ssh_utils, mock_do_op): + cmd = mock.Mock() + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private") + + mock_do_op.assert_called_once_with( + cmd, "rg", "vm", "ip", "public", "private", mock_ssh_utils.start_ssh_connection) + + @unittest.skip('have problems with patching functools.partial, will enable it after getting the root cause.') + @mock.patch('azext_ssh.custom._do_ssh_op') + @mock.patch('azext_ssh.custom.ssh_utils') + @mock.patch('functools.partial') + def test_ssh_config(self, mock_partial, mock_ssh_utils, mock_do_op): + cmd = mock.Mock() + custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private") + + mock_partial.assert_called_once_with( + mock_ssh_utils.write_ssh_config, "path/to/file", "rg", "vm") + mock_do_op.assert_called_once_with( + cmd, "rg", "vm", "ip", "public", "private", mock_partial.return_value) + + @mock.patch('azext_ssh.custom._assert_args') + @mock.patch('azext_ssh.custom._check_public_private_files') + @mock.patch('azext_ssh.ip_utils.get_ssh_ip') + @mock.patch('azext_ssh.custom._get_modulus_exponent') + @mock.patch('azure.cli.core._profile.Profile.get_msal_token') + @mock.patch('azext_ssh.custom._write_cert_file') + def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_ip, + mock_check_files, mock_assert): + cmd = mock.Mock() + mock_op = mock.Mock() + mock_check_files.return_value = "public", "private" + mock_get_mod_exp.return_value = "modulus", "exponent" + mock_ssh_creds.return_value = "username", "certificate" + + custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", mock_op) + + mock_assert.assert_called_once_with(None, None, "1.2.3.4") + mock_check_files.assert_called_once_with("publicfile", "privatefile") + mock_ip.assert_not_called() + mock_get_mod_exp.assert_called_once_with("public") + mock_write_cert.assert_called_once_with("public", "certificate") + mock_op.assert_called_once_with( + "1.2.3.4", "username", mock_write_cert.return_value, "private") + + @mock.patch('azext_ssh.custom._assert_args') + @mock.patch('azext_ssh.custom._check_public_private_files') + @mock.patch('azext_ssh.ip_utils.get_ssh_ip') + @mock.patch('azext_ssh.custom._get_modulus_exponent') + def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_files, mock_assert): + cmd = mock.Mock() + mock_op = mock.Mock() + mock_check_files.return_value = "public", "private" + mock_get_mod_exp.return_value = "modulus", "exponent" + mock_ip.return_value = None + + self.assertRaises( + util.CLIError, custom._do_ssh_op, cmd, "rg", "vm", None, + "publicfile", "privatefile", mock_op) + + mock_assert.assert_called_once_with("rg", "vm", None) + mock_check_files.assert_called_once_with("publicfile", "privatefile") + mock_ip.assert_called_once_with(cmd, "rg", "vm") + + def test_assert_args_no_ip_or_vm(self): + self.assertRaises(util.CLIError, custom._assert_args, None, None, None) + + def test_assert_args_vm_rg_mismatch(self): + self.assertRaises(util.CLIError, custom._assert_args, "rg", None, None) + self.assertRaises(util.CLIError, custom._assert_args, None, "vm", None) + + def test_assert_args_ip_with_vm_or_rg(self): + self.assertRaises(util.CLIError, custom._assert_args, None, "vm", "ip") + self.assertRaises(util.CLIError, custom._assert_args, "rg", "vm", "ip") + + @mock.patch('os.path.isfile') + @mock.patch('os.path.expanduser') + @mock.patch('os.path.join') + def test_check_public_private_files_defaults(self, mock_join, mock_expand, mock_isfile): + mock_expand.side_effect = ['publicfile', 'privatefile'] + mock_isfile.return_value = True + + public, private = custom._check_public_private_files(None, None) + + self.assertEqual('publicfile', public) + self.assertEqual('privatefile', private) + mock_expand.assert_has_calls([ + mock.call(mock_join.return_value), + mock.call(mock_join.return_value) + ]) + mock_join.assert_has_calls([ + mock.call("~", ".ssh", "id_rsa.pub"), + mock.call("~", ".ssh", "id_rsa") + ]) + mock_isfile.assert_has_calls([ + mock.call('publicfile'), + mock.call('privatefile') + ]) + + @mock.patch('os.path.isfile') + @mock.patch('os.path.expanduser') + @mock.patch('os.path.join') + def test_check_public_private_files_no_public(self, mock_join, mock_expand, mock_isfile): + mock_isfile.side_effect = [False, True] + + self.assertRaises( + util.CLIError, custom._check_public_private_files, "public", None) + + mock_expand.assert_called_once_with(mock_join.return_value) + mock_join.assert_called_once_with("~", ".ssh", "id_rsa") + mock_isfile.assert_called_once_with("public") + + @mock.patch('os.path.isfile') + @mock.patch('os.path.expanduser') + @mock.patch('os.path.join') + def test_check_public_private_files_no_private(self, mock_join, mock_expand, mock_isfile): + mock_isfile.side_effect = [True, False] + + self.assertRaises( + util.CLIError, custom._check_public_private_files, "public", "private") + + mock_expand.assert_not_called() + mock_join.assert_not_called() + mock_isfile.assert_has_calls([ + mock.call("public"), + mock.call("private") + ]) + + @mock.patch('os.path.join') + @mock.patch('os.path.split') + @mock.patch('builtins.open') + def test_write_cert_file(self, mock_open, mock_split, mock_join): + mock_file = mock.Mock() + mock_open.return_value.__enter__.return_value = mock_file + mock_split.return_value = ["path", "to", "publickey"] + + file_name = custom._write_cert_file("public", "cert") + + self.assertEqual(mock_join.return_value, file_name) + mock_split.assert_called_once_with("public") + mock_join.assert_called_once_with("path", "to", "id_rsa-cert.pub") + mock_open.assert_called_once_with(mock_join.return_value, 'w') + mock_file.write.assert_called_once_with("ssh-rsa-cert-v01@openssh.com cert") + + @mock.patch('azext_ssh.rsa_parser.RSAParser') + @mock.patch('os.path.isfile') + @mock.patch('builtins.open') + def test_get_modulus_exponent_success(self, mock_open, mock_isfile, mock_parser): + mock_isfile.return_value = True + mock_open.return_value = io.StringIO('publickey') + + modulus, exponent = custom._get_modulus_exponent('file') + + self.assertEqual(mock_parser.return_value.modulus, modulus) + self.assertEqual(mock_parser.return_value.exponent, exponent) + mock_isfile.assert_called_once_with('file') + mock_open.assert_called_once_with('file', 'r') + mock_parser.return_value.parse.assert_called_once_with('publickey') + + @mock.patch('os.path.isfile') + def test_get_modulus_exponent_file_not_found(self, mock_isfile): + mock_isfile.return_value = False + + self.assertRaises(util.CLIError, custom._get_modulus_exponent, 'file') + mock_isfile.assert_called_once_with('file') + + @mock.patch('azext_ssh.rsa_parser.RSAParser') + @mock.patch('os.path.isfile') + @mock.patch('builtins.open') + def test_get_modulus_exponent_parse_error(self, mock_open, mock_isfile, mock_parser): + mock_isfile.return_value = True + mock_open.return_value = io.StringIO('publickey') + mock_parser_obj = mock.Mock() + mock_parser.return_value = mock_parser_obj + mock_parser_obj.parse.side_effect = ValueError + + self.assertRaises(util.CLIError, custom._get_modulus_exponent, 'file') + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_rsa_parser.py b/src/ssh/azext_ssh/tests/latest/test_rsa_parser.py new file mode 100644 index 00000000000..a3804bde281 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_rsa_parser.py @@ -0,0 +1,86 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +import mock + +from azext_ssh import rsa_parser + + +class RSAParserTest(unittest.TestCase): + def test_rsa_parser_success(self): + public_key_text = 'ssh-rsa ' + self._get_good_key() + parser = rsa_parser.RSAParser() + + parser.parse(public_key_text) + + print(parser.modulus) + self.assertEqual('ssh-rsa', parser.algorithm) + self.assertEqual(self._get_good_modulus(), parser.modulus) + self.assertEqual(self._get_good_exponent(), parser.exponent) + + def test_rsa_parser_too_few_public_key_text_fields(self): + public_key_text = 'algo' + parser = rsa_parser.RSAParser() + + self.assertRaises(ValueError, parser.parse, public_key_text) + + def test_rsa_parser_wrong_algorithm(self): + public_key_text = 'wrongalgo key' + parser = rsa_parser.RSAParser() + + self.assertRaises(ValueError, parser.parse, public_key_text) + + @mock.patch('base64.b64decode') + def test_rsa_parser_algorithm_mismatch(self, mock_decode): + public_key_text = 'ssh-rsa key' + parser = rsa_parser.RSAParser() + + with mock.patch.object(parser, '_get_fields') as mock_get_fields: + mock_get_fields.return_value = [b'otheralgo', b'exp', b'mod'] + self.assertRaises(ValueError, parser.parse, public_key_text) + + mock_decode.assert_called_once_with('key') + mock_get_fields.assert_called_once_with(mock_decode.return_value) + + @mock.patch('base64.b64decode') + def test_rsa_parser_too_few_encoded_fields(self, mock_decode): + public_key_text = 'ssh-rsa key' + mock_decode.return_value = b'decodedkey' + parser = rsa_parser.RSAParser() + + with mock.patch.object(parser, '_get_fields') as mock_get_fields: + mock_get_fields.return_value = [b'ssh-rsa', b'exp'] + self.assertRaises(ValueError, parser.parse, public_key_text) + + mock_decode.assert_called_once_with('key') + mock_get_fields.assert_called_once_with(mock_decode.return_value) + + def _get_good_key(self): + return ( + "AAAAB3NzaC1yc2EAAAADAQABAAABAQChdsBRgNFUAmv4UEYFVSVP2xf0z3rPiS" + "ewgrV16p3Qu7VdxBokCAwvV6KGOGjAU/DKopmKaXcSTDg0mADdgtjJHfZi38Pg" + "55UbFnz/G5RteiUt/IVcz6XdR1ejkxmzFkkAP1LqGSsZWOT+0mJIDuydGleS4h" + "Y5KLle/elhlL8DBbmGFiQwxkAV+ujHCAVs8XDPJPkdiP3F5NGOFIHW09KnuRvE" + "TGgBEJmwCtqr7dWm5rGIU3CTcQHNP+LiYUFTbQKLmwKO6YN7tFGp+DrQNjVTtO" + "01WNK+pzLPEynJr2tJ5g3VgJKJ8QwaDBuK+OASyeTS3ejmvn+b0FDzAHASHn+H" + ) + + def _get_good_modulus(self): + return ( + "AKF2wFGA0VQCa_hQRgVVJU_bF_TPes-JJ7CCtXXqndC7tV3EGiQIDC9XooY4aM" + "BT8MqimYppdxJMODSYAN2C2Mkd9mLfw-DnlRsWfP8blG16JS38hVzPpd1HV6OT" + "GbMWSQA_UuoZKxlY5P7SYkgO7J0aV5LiFjkouV796WGUvwMFuYYWJDDGQBX66M" + "cIBWzxcM8k-R2I_cXk0Y4UgdbT0qe5G8RMaAEQmbAK2qvt1abmsYhTcJNxAc0_" + "4uJhQVNtAoubAo7pg3u0Uan4OtA2NVO07TVY0r6nMs8TKcmva0nmDdWAkonxDB" + "oMG4r44BLJ5NLd6Oa-f5vQUPMAcBIef4c=" + ) + + def _get_good_exponent(self): + return "AQAB" + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py new file mode 100644 index 00000000000..66b2d2bec19 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -0,0 +1,136 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from knack import util +import mock +import unittest + +from azext_ssh import ssh_utils + + +class SSHUtilsTests(unittest.TestCase): + @mock.patch.object(ssh_utils, '_get_ssh_path') + @mock.patch.object(ssh_utils, '_get_host') + @mock.patch.object(ssh_utils, '_build_args') + @mock.patch('subprocess.call') + def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path): + mock_path.return_value = "ssh" + mock_host.return_value = "user@ip" + mock_build.return_value = ['-i', 'file', '-o', 'option'] + expected_command = ["ssh", "user@ip", "-i", "file", "-o", "option"] + + ssh_utils.start_ssh_connection("ip", "user", "cert", "private") + + mock_path.assert_called_once_with() + mock_host.assert_called_once_with("user", "ip") + mock_build.assert_called_once_with("cert", "private") + mock_call.assert_called_once_with(expected_command, shell=True) + + @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file') + def test_write_ssh_config_ip_and_vm(self, mock_make_dirs): + expected_lines = [ + "Host rg-vm", + "\tUser username", + "\tHostName 1.2.3.4", + "\tCertificateFile cert", + "\tIdentityFile privatekey", + "Host 1.2.3.4", + "\tUser username", + "\tHostName 1.2.3.4", + "\tCertificateFile cert", + "\tIdentityFile privatekey" + ] + + with mock.patch('builtins.open') as mock_open: + mock_file = mock.Mock() + mock_open.return_value.__enter__.return_value = mock_file + ssh_utils.write_ssh_config( + "path/to/file", "rg", "vm", "1.2.3.4", "username", "cert", "privatekey" + ) + + mock_make_dirs.assert_called_once_with("path/to/file") + mock_open.assert_called_once_with("path/to/file", "w") + mock_file.write.assert_called_once_with('\n'.join(expected_lines)) + + @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file') + def test_write_ssh_config_ip_only(self, mock_make_dirs): + expected_lines = [ + "Host 1.2.3.4", + "\tUser username", + "\tHostName 1.2.3.4", + "\tCertificateFile cert", + "\tIdentityFile privatekey" + ] + + with mock.patch('builtins.open') as mock_open: + mock_file = mock.Mock() + mock_open.return_value.__enter__.return_value = mock_file + ssh_utils.write_ssh_config( + "path/to/file", None, None, "1.2.3.4", "username", "cert", "privatekey" + ) + + mock_make_dirs.assert_called_once_with("path/to/file") + mock_open.assert_called_once_with("path/to/file", "w") + mock_file.write.assert_called_once_with('\n'.join(expected_lines)) + + @mock.patch('platform.system') + def test_get_ssh_path_non_windows(self, mock_system): + mock_system.return_value = "Mac" + + actual_path = ssh_utils._get_ssh_path() + self.assertEqual('ssh', actual_path) + mock_system.assert_called_once_with() + + def test_get_ssh_path_windows_32bit(self): + self._test_ssh_path_windows('32bit', 'SysNative') + + def test_get_ssh_path_windows_64bit(self): + self._test_ssh_path_windows('64bit', 'System32') + + @mock.patch('platform.system') + @mock.patch('platform.architecture') + @mock.patch('os.environ') + @mock.patch('os.path.isfile') + def test_get_ssh_path_windows_ssh_not_found(self, mock_isfile, mock_environ, mock_arch, mock_sys): + mock_sys.return_value = "Windows" + mock_arch.return_value = ("32bit", "foo", "bar") + mock_environ.__getitem__.return_value = "rootpath" + mock_isfile.return_value = False + + self.assertRaises(util.CLIError, ssh_utils._get_ssh_path) + + def test_get_host(self): + actual_host = ssh_utils._get_host("username", "10.0.0.1") + self.assertEqual("username@10.0.0.1", actual_host) + + def test_build_args(self): + actual_args = ssh_utils._build_args("cert", "privatekey") + expected_args = ["-i", "privatekey", "-o", "CertificateFile=cert"] + self.assertEqual(expected_args, actual_args) + + @mock.patch('platform.system') + @mock.patch('platform.architecture') + @mock.patch('os.path.join') + @mock.patch('os.environ') + @mock.patch('os.path.isfile') + def _test_ssh_path_windows(self, arch, expected_sys_path, mock_isfile, mock_environ, mock_join, mock_arch, mock_system): + mock_system.return_value = "Windows" + mock_arch.return_value = (arch, "foo", "bar") + mock_environ.__getitem__.return_value = "rootpath" + mock_join.side_effect = ["system32path", "sshfilepath"] + mock_isfile.return_value = True + expected_join_calls = [ + mock.call("rootpath", expected_sys_path), + mock.call("system32path", "openSSH", "ssh.exe") + ] + + actual_path = ssh_utils._get_ssh_path() + + self.assertEqual("sshfilepath", actual_path) + mock_system.assert_called_once_with() + mock_arch.assert_called_once_with() + mock_environ.__getitem__.assert_called_once_with("SystemRoot") + mock_join.assert_has_calls(expected_join_calls) + mock_isfile.assert_called_once_with("sshfilepath") diff --git a/src/ssh/setup.cfg b/src/ssh/setup.cfg new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ssh/setup.py b/src/ssh/setup.py new file mode 100644 index 00000000000..f3c01549186 --- /dev/null +++ b/src/ssh/setup.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python + +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from setuptools import setup, find_packages + +VERSION = "0.1.0" + +CLASSIFIERS = [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: System Administrators', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'License :: OSI Approved :: MIT License', +] + +DEPENDENCIES = [ + 'paramiko==2.6.0', + 'cryptography==2.8.0' +] + +setup( + name='ssh', + version=VERSION, + description='SSH into VMs', + long_description='SSH into VMs using RBAC', + license='MIT', + author='Ryan Rossiter', + author_email='ryrossit@microsoft.com', + url='https://github.com/Azure/azure-cli-extensions/tree/master/src/ssh', + classifiers=CLASSIFIERS, + packages=find_packages(), + install_requires=DEPENDENCIES, + package_data={'azext_ssh': ['azext_metadata.json']} +)