From 7a16cc56af651523d825071e28bf59241203c820 Mon Sep 17 00:00:00 2001 From: Lee Burton <lburton@microsoft.com> Date: Wed, 6 Jan 2021 19:06:00 -0800 Subject: [PATCH] [SSH] Add support for private IPs and force lowercase usernames (#2858) --- src/ssh/HISTORY.md | 7 ++++ src/ssh/azext_ssh/_help.py | 2 +- src/ssh/azext_ssh/_params.py | 10 ++++-- src/ssh/azext_ssh/custom.py | 28 +++++++++------- src/ssh/azext_ssh/ip_utils.py | 22 ++++++------- src/ssh/azext_ssh/ssh_utils.py | 13 ++++++-- src/ssh/azext_ssh/tests/latest/test_custom.py | 18 +++++------ .../azext_ssh/tests/latest/test_ssh_utils.py | 32 +++++++++++++++++-- src/ssh/setup.py | 2 +- 9 files changed, 91 insertions(+), 43 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 68e4fe4892d..c4cb447b37a 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,6 +1,13 @@ Release History =============== +0.1.3 +----- +* Add support for using private IPs +* Add option alias `--name` for `--vm-name` +* Use lowercase username by default +* Fix various typos + 0.1.2 ----- * Add support for hardware tokens (don't require the private key be passed in) diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index 2a3e942b743..7c1667b6e82 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -42,7 +42,7 @@ helps['ssh cert'] = """ type: command - short-summary: Create an SSH RSA certifcate signed by AAD + short-summary: Create an SSH RSA certificate signed by AAD examples: - name: Create a short lived ssh certificate signed by AAD text: | diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 90c78d1145f..b73c677d180 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -7,17 +7,23 @@ def load_arguments(self, _): with self.argument_context('ssh vm') as c: - c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM') + c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM') c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM') c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path') c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path') + c.argument('use_private_ip', options_list=['--prefer-private-ip'], + help='Will use a private IP if available. By default only public IPs are used.') with self.argument_context('ssh config') as c: c.argument('config_path', options_list=['--file', '-f'], help='The file path to write the SSH config to') - c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM') + c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM') c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM') c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path') c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path') + c.argument('use_private_ip', options_list=['--prefer-private-ip'], + help='Will use a private IP if available. By default only public IPs are used.') + c.argument('overwrite', action='store_true', options_list=['--overwrite'], + help='Overwrites the config file if this flag is set') with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 4f460ca667f..8c479d584c8 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -16,15 +16,16 @@ from . import ssh_utils -def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None): +def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, + private_key_file=None, use_private_ip=False): _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, - public_key_file, private_key_file, ssh_utils.start_ssh_connection) + public_key_file, private_key_file, use_private_ip, ssh_utils.start_ssh_connection) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, - public_key_file=None, private_key_file=None): - op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call) + public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False): + op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call) def ssh_cert(cmd, cert_path=None, public_key_file=None): @@ -33,13 +34,16 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): print(cert_file + "\n") -def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call): +def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call): _assert_args(resource_group, vm_name, ssh_ip) public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, private_key_file) - ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name) + ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) if not ssh_ip: - raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to") + if not use_private_ip: + raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to") + + raise util.CLIError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) op_call(ssh_ip, username, cert_file, private_key_file) @@ -53,7 +57,7 @@ def _get_and_write_certificate(cmd, public_key_file, cert_file): username, certificate = profile.get_msal_token(scopes, data) if not cert_file: cert_file = public_key_file + "-aadcert.pub" - return _write_cert_file(certificate, cert_file), username + return _write_cert_file(certificate, cert_file), username.lower() def _prepare_jwk_data(public_key_file): @@ -79,13 +83,13 @@ def _prepare_jwk_data(public_key_file): def _assert_args(resource_group, vm_name, ssh_ip): if not (resource_group or vm_name or ssh_ip): - raise util.CLIError("The VM must be specified by --ip or --resource-group and --name") + raise util.CLIError("The VM must be specified by --ip or --resource-group and --vm-name/--name") if resource_group and not vm_name or vm_name and not resource_group: - raise util.CLIError("--resource-group and --name must be provided together") + raise util.CLIError("--resource-group and --vm-name/--name must be provided together") if ssh_ip and (vm_name or resource_group): - raise util.CLIError("--ip cannot be used with --resource-group or --name") + raise util.CLIError("--ip cannot be used with --resource-group or --vm-name/--name") def _check_or_create_public_private_files(public_key_file, private_key_file): diff --git a/src/ssh/azext_ssh/ip_utils.py b/src/ssh/azext_ssh/ip_utils.py index c08f8bfc041..04d8de26ccf 100644 --- a/src/ssh/azext_ssh/ip_utils.py +++ b/src/ssh/azext_ssh/ip_utils.py @@ -8,7 +8,7 @@ from msrestazure import tools -def get_ssh_ip(cmd, resource_group, vm_name): +def get_ssh_ip(cmd, resource_group, vm_name, use_private_ip): compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE) network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK) vm_client = compute_client.virtual_machines @@ -16,21 +16,17 @@ def get_ssh_ip(cmd, resource_group, vm_name): ip_client = network_client.public_ip_addresses vm = vm_client.get(resource_group, vm_name) - nics = vm.network_profile.network_interfaces - ssh_ip = None - for nic_ref in nics: + + for nic_ref in vm.network_profile.network_interfaces: parsed_id = tools.parse_resource_id(nic_ref.id) nic = nic_client.get(parsed_id['resource_group'], parsed_id['name']) - ip_configs = nic.ip_configurations - for ip_config in ip_configs: + for ip_config in nic.ip_configurations: + if use_private_ip and ip_config.private_ip_address: + return ip_config.private_ip_address public_ip_ref = ip_config.public_ip_address parsed_ip_id = tools.parse_resource_id(public_ip_ref.id) public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name']) - ssh_ip = public_ip.ip_address - - if ssh_ip: - break + if public_ip.ip_address: + return public_ip.ip_address - if ssh_ip: - break - return ssh_ip + return None diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 593efbd929f..fa232dc5f4c 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -27,10 +27,12 @@ def create_ssh_keyfile(private_key_file): subprocess.call(command, shell=platform.system() == 'Windows') -def write_ssh_config(config_path, resource_group, vm_name, +def write_ssh_config(config_path, resource_group, vm_name, overwrite, ip, username, cert_file, private_key_file): file_utils.make_dirs_for_file(config_path) - lines = [] + + lines = [""] + if resource_group and vm_name: lines.append("Host " + resource_group + "-" + vm_name) lines.append("\tUser " + username) @@ -49,7 +51,12 @@ def write_ssh_config(config_path, resource_group, vm_name, if private_key_file: lines.append("\tIdentityFile " + private_key_file) - with open(config_path, 'w') as f: + if overwrite: + mode = 'w' + else: + mode = 'a' + + with open(config_path, mode) as f: f.write('\n'.join(lines)) diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index d54363e1416..3ff75ca5fbc 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -16,26 +16,26 @@ class SshCustomCommandTest(unittest.TestCase): @mock.patch('azext_ssh.custom.ssh_utils') def test_ssh_vm(self, mock_ssh_utils, mock_do_op): cmd = mock.Mock() - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private") + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False) mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", mock_ssh_utils.start_ssh_connection) + cmd, "rg", "vm", "ip", "public", "private", False, mock_ssh_utils.start_ssh_connection) @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.ssh_utils.write_ssh_config') def test_ssh_config(self, mock_ssh_utils, mock_do_op): cmd = mock.Mock() - def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call): + def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call): op_call(ssh_ip, "username", "cert_file", private_key_file) mock_do_op.side_effect = do_op_side_effect - custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private") + custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False) - mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", "ip", "username", "cert_file", "private") + mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private") mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", mock.ANY) + cmd, "rg", "vm", "ip", "public", "private", False, mock.ANY) @mock.patch('os.path.join') @mock.patch('azext_ssh.custom._assert_args') @@ -53,7 +53,7 @@ def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock mock_ssh_creds.return_value = "username", "certificate" mock_join.return_value = "public-aadcert.pub" - custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", mock_op) + custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, mock_op) mock_assert.assert_called_once_with(None, None, "1.2.3.4") mock_check_files.assert_called_once_with("publicfile", "privatefile") @@ -76,11 +76,11 @@ def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_file self.assertRaises( util.CLIError, custom._do_ssh_op, cmd, "rg", "vm", None, - "publicfile", "privatefile", mock_op) + "publicfile", "privatefile", False, mock_op) mock_assert.assert_called_once_with("rg", "vm", None) mock_check_files.assert_called_once_with("publicfile", "privatefile") - mock_ip.assert_called_once_with(cmd, "rg", "vm") + mock_ip.assert_called_once_with(cmd, "rg", "vm", False) def test_assert_args_no_ip_or_vm(self): self.assertRaises(util.CLIError, custom._assert_args, None, None, None) diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 11d74dc62e8..fceb5e09fb6 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -32,6 +32,7 @@ def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path) @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file') def test_write_ssh_config_ip_and_vm(self, mock_make_dirs): expected_lines = [ + "", "Host rg-vm", "\tUser username", "\tHostName 1.2.3.4", @@ -47,16 +48,43 @@ def test_write_ssh_config_ip_and_vm(self, mock_make_dirs): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", "1.2.3.4", "username", "cert", "privatekey" + "path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey" ) mock_make_dirs.assert_called_once_with("path/to/file") mock_open.assert_called_once_with("path/to/file", "w") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) + @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file') + def test_write_ssh_config_append(self, mock_make_dirs): + expected_lines = [ + "", + "Host rg-vm", + "\tUser username", + "\tHostName 1.2.3.4", + "\tCertificateFile cert", + "\tIdentityFile privatekey", + "Host 1.2.3.4", + "\tUser username", + "\tCertificateFile cert", + "\tIdentityFile privatekey" + ] + + with mock.patch('builtins.open') as mock_open: + mock_file = mock.Mock() + mock_open.return_value.__enter__.return_value = mock_file + ssh_utils.write_ssh_config( + "path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey" + ) + + mock_make_dirs.assert_called_once_with("path/to/file") + mock_open.assert_called_once_with("path/to/file", "a") + mock_file.write.assert_called_once_with('\n'.join(expected_lines)) + @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file') def test_write_ssh_config_ip_only(self, mock_make_dirs): expected_lines = [ + "", "Host 1.2.3.4", "\tUser username", "\tCertificateFile cert", @@ -67,7 +95,7 @@ def test_write_ssh_config_ip_only(self, mock_make_dirs): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", None, None, "1.2.3.4", "username", "cert", "privatekey" + "path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey" ) mock_make_dirs.assert_called_once_with("path/to/file") diff --git a/src/ssh/setup.py b/src/ssh/setup.py index e3853dc8caa..7876f43ad93 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "0.1.2" +VERSION = "0.1.3" CLASSIFIERS = [ 'Development Status :: 4 - Beta',