From 7a16cc56af651523d825071e28bf59241203c820 Mon Sep 17 00:00:00 2001
From: Lee Burton <lburton@microsoft.com>
Date: Wed, 6 Jan 2021 19:06:00 -0800
Subject: [PATCH] [SSH] Add support for private IPs and force lowercase
 usernames (#2858)

---
 src/ssh/HISTORY.md                            |  7 ++++
 src/ssh/azext_ssh/_help.py                    |  2 +-
 src/ssh/azext_ssh/_params.py                  | 10 ++++--
 src/ssh/azext_ssh/custom.py                   | 28 +++++++++-------
 src/ssh/azext_ssh/ip_utils.py                 | 22 ++++++-------
 src/ssh/azext_ssh/ssh_utils.py                | 13 ++++++--
 src/ssh/azext_ssh/tests/latest/test_custom.py | 18 +++++------
 .../azext_ssh/tests/latest/test_ssh_utils.py  | 32 +++++++++++++++++--
 src/ssh/setup.py                              |  2 +-
 9 files changed, 91 insertions(+), 43 deletions(-)

diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md
index 68e4fe4892d..c4cb447b37a 100644
--- a/src/ssh/HISTORY.md
+++ b/src/ssh/HISTORY.md
@@ -1,6 +1,13 @@
 Release History
 ===============
 
+0.1.3
+-----
+* Add support for using private IPs
+* Add option alias `--name` for `--vm-name`
+* Use lowercase username by default
+* Fix various typos
+
 0.1.2
 -----
 * Add support for hardware tokens (don't require the private key be passed in)
diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py
index 2a3e942b743..7c1667b6e82 100644
--- a/src/ssh/azext_ssh/_help.py
+++ b/src/ssh/azext_ssh/_help.py
@@ -42,7 +42,7 @@
 
 helps['ssh cert'] = """
     type: command
-    short-summary: Create an SSH RSA certifcate signed by AAD
+    short-summary: Create an SSH RSA certificate signed by AAD
     examples:
         - name: Create a short lived ssh certificate signed by AAD
           text: |
diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py
index 90c78d1145f..b73c677d180 100644
--- a/src/ssh/azext_ssh/_params.py
+++ b/src/ssh/azext_ssh/_params.py
@@ -7,17 +7,23 @@
 def load_arguments(self, _):
 
     with self.argument_context('ssh vm') as c:
-        c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM')
+        c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM')
         c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM')
         c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path')
         c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path')
+        c.argument('use_private_ip', options_list=['--prefer-private-ip'],
+                   help='Will use a private IP if available. By default only public IPs are used.')
 
     with self.argument_context('ssh config') as c:
         c.argument('config_path', options_list=['--file', '-f'], help='The file path to write the SSH config to')
-        c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM')
+        c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM')
         c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM')
         c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path')
         c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path')
+        c.argument('use_private_ip', options_list=['--prefer-private-ip'],
+                   help='Will use a private IP if available. By default only public IPs are used.')
+        c.argument('overwrite', action='store_true', options_list=['--overwrite'],
+                   help='Overwrites the config file if this flag is set')
 
     with self.argument_context('ssh cert') as c:
         c.argument('cert_path', options_list=['--file', '-f'],
diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py
index 4f460ca667f..8c479d584c8 100644
--- a/src/ssh/azext_ssh/custom.py
+++ b/src/ssh/azext_ssh/custom.py
@@ -16,15 +16,16 @@
 from . import ssh_utils
 
 
-def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None):
+def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None,
+           private_key_file=None, use_private_ip=False):
     _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip,
-               public_key_file, private_key_file, ssh_utils.start_ssh_connection)
+               public_key_file, private_key_file, use_private_ip, ssh_utils.start_ssh_connection)
 
 
 def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None,
-               public_key_file=None, private_key_file=None):
-    op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name)
-    _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call)
+               public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False):
+    op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite)
+    _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call)
 
 
 def ssh_cert(cmd, cert_path=None, public_key_file=None):
@@ -33,13 +34,16 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None):
     print(cert_file + "\n")
 
 
-def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call):
+def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call):
     _assert_args(resource_group, vm_name, ssh_ip)
     public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, private_key_file)
-    ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name)
+    ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip)
 
     if not ssh_ip:
-        raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to")
+        if not use_private_ip:
+            raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to")
+
+        raise util.CLIError(f"VM '{vm_name}' does not have a public or private IP address to SSH to")
 
     cert_file, username = _get_and_write_certificate(cmd, public_key_file, None)
     op_call(ssh_ip, username, cert_file, private_key_file)
@@ -53,7 +57,7 @@ def _get_and_write_certificate(cmd, public_key_file, cert_file):
     username, certificate = profile.get_msal_token(scopes, data)
     if not cert_file:
         cert_file = public_key_file + "-aadcert.pub"
-    return _write_cert_file(certificate, cert_file), username
+    return _write_cert_file(certificate, cert_file), username.lower()
 
 
 def _prepare_jwk_data(public_key_file):
@@ -79,13 +83,13 @@ def _prepare_jwk_data(public_key_file):
 
 def _assert_args(resource_group, vm_name, ssh_ip):
     if not (resource_group or vm_name or ssh_ip):
-        raise util.CLIError("The VM must be specified by --ip or --resource-group and --name")
+        raise util.CLIError("The VM must be specified by --ip or --resource-group and --vm-name/--name")
 
     if resource_group and not vm_name or vm_name and not resource_group:
-        raise util.CLIError("--resource-group and --name must be provided together")
+        raise util.CLIError("--resource-group and --vm-name/--name must be provided together")
 
     if ssh_ip and (vm_name or resource_group):
-        raise util.CLIError("--ip cannot be used with --resource-group or --name")
+        raise util.CLIError("--ip cannot be used with --resource-group or --vm-name/--name")
 
 
 def _check_or_create_public_private_files(public_key_file, private_key_file):
diff --git a/src/ssh/azext_ssh/ip_utils.py b/src/ssh/azext_ssh/ip_utils.py
index c08f8bfc041..04d8de26ccf 100644
--- a/src/ssh/azext_ssh/ip_utils.py
+++ b/src/ssh/azext_ssh/ip_utils.py
@@ -8,7 +8,7 @@
 from msrestazure import tools
 
 
-def get_ssh_ip(cmd, resource_group, vm_name):
+def get_ssh_ip(cmd, resource_group, vm_name, use_private_ip):
     compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE)
     network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK)
     vm_client = compute_client.virtual_machines
@@ -16,21 +16,17 @@ def get_ssh_ip(cmd, resource_group, vm_name):
     ip_client = network_client.public_ip_addresses
 
     vm = vm_client.get(resource_group, vm_name)
-    nics = vm.network_profile.network_interfaces
-    ssh_ip = None
-    for nic_ref in nics:
+
+    for nic_ref in vm.network_profile.network_interfaces:
         parsed_id = tools.parse_resource_id(nic_ref.id)
         nic = nic_client.get(parsed_id['resource_group'], parsed_id['name'])
-        ip_configs = nic.ip_configurations
-        for ip_config in ip_configs:
+        for ip_config in nic.ip_configurations:
+            if use_private_ip and ip_config.private_ip_address:
+                return ip_config.private_ip_address
             public_ip_ref = ip_config.public_ip_address
             parsed_ip_id = tools.parse_resource_id(public_ip_ref.id)
             public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name'])
-            ssh_ip = public_ip.ip_address
-
-            if ssh_ip:
-                break
+            if public_ip.ip_address:
+                return public_ip.ip_address
 
-        if ssh_ip:
-            break
-    return ssh_ip
+    return None
diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py
index 593efbd929f..fa232dc5f4c 100644
--- a/src/ssh/azext_ssh/ssh_utils.py
+++ b/src/ssh/azext_ssh/ssh_utils.py
@@ -27,10 +27,12 @@ def create_ssh_keyfile(private_key_file):
     subprocess.call(command, shell=platform.system() == 'Windows')
 
 
-def write_ssh_config(config_path, resource_group, vm_name,
+def write_ssh_config(config_path, resource_group, vm_name, overwrite,
                      ip, username, cert_file, private_key_file):
     file_utils.make_dirs_for_file(config_path)
-    lines = []
+
+    lines = [""]
+
     if resource_group and vm_name:
         lines.append("Host " + resource_group + "-" + vm_name)
         lines.append("\tUser " + username)
@@ -49,7 +51,12 @@ def write_ssh_config(config_path, resource_group, vm_name,
     if private_key_file:
         lines.append("\tIdentityFile " + private_key_file)
 
-    with open(config_path, 'w') as f:
+    if overwrite:
+        mode = 'w'
+    else:
+        mode = 'a'
+
+    with open(config_path, mode) as f:
         f.write('\n'.join(lines))
 
 
diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py
index d54363e1416..3ff75ca5fbc 100644
--- a/src/ssh/azext_ssh/tests/latest/test_custom.py
+++ b/src/ssh/azext_ssh/tests/latest/test_custom.py
@@ -16,26 +16,26 @@ class SshCustomCommandTest(unittest.TestCase):
     @mock.patch('azext_ssh.custom.ssh_utils')
     def test_ssh_vm(self, mock_ssh_utils, mock_do_op):
         cmd = mock.Mock()
-        custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private")
+        custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False)
 
         mock_do_op.assert_called_once_with(
-            cmd, "rg", "vm", "ip", "public", "private", mock_ssh_utils.start_ssh_connection)
+            cmd, "rg", "vm", "ip", "public", "private", False, mock_ssh_utils.start_ssh_connection)
 
     @mock.patch('azext_ssh.custom._do_ssh_op')
     @mock.patch('azext_ssh.ssh_utils.write_ssh_config')
     def test_ssh_config(self, mock_ssh_utils, mock_do_op):
         cmd = mock.Mock()
 
-        def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call):
+        def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call):
             op_call(ssh_ip, "username", "cert_file", private_key_file)
 
         mock_do_op.side_effect = do_op_side_effect
-        custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private")
+        custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False)
 
-        mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", "ip", "username", "cert_file", "private")
+        mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private")
 
         mock_do_op.assert_called_once_with(
-            cmd, "rg", "vm", "ip", "public", "private", mock.ANY)
+            cmd, "rg", "vm", "ip", "public", "private", False, mock.ANY)
 
     @mock.patch('os.path.join')
     @mock.patch('azext_ssh.custom._assert_args')
@@ -53,7 +53,7 @@ def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock
         mock_ssh_creds.return_value = "username", "certificate"
         mock_join.return_value = "public-aadcert.pub"
 
-        custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", mock_op)
+        custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, mock_op)
 
         mock_assert.assert_called_once_with(None, None, "1.2.3.4")
         mock_check_files.assert_called_once_with("publicfile", "privatefile")
@@ -76,11 +76,11 @@ def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_file
 
         self.assertRaises(
             util.CLIError, custom._do_ssh_op, cmd, "rg", "vm", None,
-            "publicfile", "privatefile", mock_op)
+            "publicfile", "privatefile", False, mock_op)
 
         mock_assert.assert_called_once_with("rg", "vm", None)
         mock_check_files.assert_called_once_with("publicfile", "privatefile")
-        mock_ip.assert_called_once_with(cmd, "rg", "vm")
+        mock_ip.assert_called_once_with(cmd, "rg", "vm", False)
 
     def test_assert_args_no_ip_or_vm(self):
         self.assertRaises(util.CLIError, custom._assert_args, None, None, None)
diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py
index 11d74dc62e8..fceb5e09fb6 100644
--- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py
+++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py
@@ -32,6 +32,7 @@ def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path)
     @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
     def test_write_ssh_config_ip_and_vm(self, mock_make_dirs):
         expected_lines = [
+            "",
             "Host rg-vm",
             "\tUser username",
             "\tHostName 1.2.3.4",
@@ -47,16 +48,43 @@ def test_write_ssh_config_ip_and_vm(self, mock_make_dirs):
             mock_file = mock.Mock()
             mock_open.return_value.__enter__.return_value = mock_file
             ssh_utils.write_ssh_config(
-                "path/to/file", "rg", "vm", "1.2.3.4", "username", "cert", "privatekey"
+                "path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey"
             )
 
         mock_make_dirs.assert_called_once_with("path/to/file")
         mock_open.assert_called_once_with("path/to/file", "w")
         mock_file.write.assert_called_once_with('\n'.join(expected_lines))
 
+    @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
+    def test_write_ssh_config_append(self, mock_make_dirs):
+        expected_lines = [
+            "",
+            "Host rg-vm",
+            "\tUser username",
+            "\tHostName 1.2.3.4",
+            "\tCertificateFile cert",
+            "\tIdentityFile privatekey",
+            "Host 1.2.3.4",
+            "\tUser username",
+            "\tCertificateFile cert",
+            "\tIdentityFile privatekey"
+        ]
+
+        with mock.patch('builtins.open') as mock_open:
+            mock_file = mock.Mock()
+            mock_open.return_value.__enter__.return_value = mock_file
+            ssh_utils.write_ssh_config(
+                "path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey"
+            )
+
+        mock_make_dirs.assert_called_once_with("path/to/file")
+        mock_open.assert_called_once_with("path/to/file", "a")
+        mock_file.write.assert_called_once_with('\n'.join(expected_lines))
+
     @mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
     def test_write_ssh_config_ip_only(self, mock_make_dirs):
         expected_lines = [
+            "",
             "Host 1.2.3.4",
             "\tUser username",
             "\tCertificateFile cert",
@@ -67,7 +95,7 @@ def test_write_ssh_config_ip_only(self, mock_make_dirs):
             mock_file = mock.Mock()
             mock_open.return_value.__enter__.return_value = mock_file
             ssh_utils.write_ssh_config(
-                "path/to/file", None, None, "1.2.3.4", "username", "cert", "privatekey"
+                "path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey"
             )
 
         mock_make_dirs.assert_called_once_with("path/to/file")
diff --git a/src/ssh/setup.py b/src/ssh/setup.py
index e3853dc8caa..7876f43ad93 100644
--- a/src/ssh/setup.py
+++ b/src/ssh/setup.py
@@ -7,7 +7,7 @@
 
 from setuptools import setup, find_packages
 
-VERSION = "0.1.2"
+VERSION = "0.1.3"
 
 CLASSIFIERS = [
     'Development Status :: 4 - Beta',