Skip to content

Commit

Permalink
[SSH] Add support for private IPs and force lowercase usernames (Azur…
Browse files Browse the repository at this point in the history
  • Loading branch information
N6UDP authored Jan 7, 2021
1 parent a59ebae commit 7a16cc5
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 43 deletions.
7 changes: 7 additions & 0 deletions src/ssh/HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Release History
===============

0.1.3
-----
* Add support for using private IPs
* Add option alias `--name` for `--vm-name`
* Use lowercase username by default
* Fix various typos

0.1.2
-----
* Add support for hardware tokens (don't require the private key be passed in)
Expand Down
2 changes: 1 addition & 1 deletion src/ssh/azext_ssh/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

helps['ssh cert'] = """
type: command
short-summary: Create an SSH RSA certifcate signed by AAD
short-summary: Create an SSH RSA certificate signed by AAD
examples:
- name: Create a short lived ssh certificate signed by AAD
text: |
Expand Down
10 changes: 8 additions & 2 deletions src/ssh/azext_ssh/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
def load_arguments(self, _):

with self.argument_context('ssh vm') as c:
c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM')
c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM')
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM')
c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path')
c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path')
c.argument('use_private_ip', options_list=['--prefer-private-ip'],
help='Will use a private IP if available. By default only public IPs are used.')

with self.argument_context('ssh config') as c:
c.argument('config_path', options_list=['--file', '-f'], help='The file path to write the SSH config to')
c.argument('vm_name', options_list=['--vm-name', '-n'], help='The name of the VM')
c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the VM')
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address (or hostname) of the VM')
c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path')
c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path')
c.argument('use_private_ip', options_list=['--prefer-private-ip'],
help='Will use a private IP if available. By default only public IPs are used.')
c.argument('overwrite', action='store_true', options_list=['--overwrite'],
help='Overwrites the config file if this flag is set')

with self.argument_context('ssh cert') as c:
c.argument('cert_path', options_list=['--file', '-f'],
Expand Down
28 changes: 16 additions & 12 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
from . import ssh_utils


def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None):
def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None,
private_key_file=None, use_private_ip=False):
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip,
public_key_file, private_key_file, ssh_utils.start_ssh_connection)
public_key_file, private_key_file, use_private_ip, ssh_utils.start_ssh_connection)


def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None,
public_key_file=None, private_key_file=None):
op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name)
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call)
public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False):
op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite)
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call)


def ssh_cert(cmd, cert_path=None, public_key_file=None):
Expand All @@ -33,13 +34,16 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None):
print(cert_file + "\n")


def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call):
def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call):
_assert_args(resource_group, vm_name, ssh_ip)
public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, private_key_file)
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name)
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip)

if not ssh_ip:
raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to")
if not use_private_ip:
raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to")

raise util.CLIError(f"VM '{vm_name}' does not have a public or private IP address to SSH to")

cert_file, username = _get_and_write_certificate(cmd, public_key_file, None)
op_call(ssh_ip, username, cert_file, private_key_file)
Expand All @@ -53,7 +57,7 @@ def _get_and_write_certificate(cmd, public_key_file, cert_file):
username, certificate = profile.get_msal_token(scopes, data)
if not cert_file:
cert_file = public_key_file + "-aadcert.pub"
return _write_cert_file(certificate, cert_file), username
return _write_cert_file(certificate, cert_file), username.lower()


def _prepare_jwk_data(public_key_file):
Expand All @@ -79,13 +83,13 @@ def _prepare_jwk_data(public_key_file):

def _assert_args(resource_group, vm_name, ssh_ip):
if not (resource_group or vm_name or ssh_ip):
raise util.CLIError("The VM must be specified by --ip or --resource-group and --name")
raise util.CLIError("The VM must be specified by --ip or --resource-group and --vm-name/--name")

if resource_group and not vm_name or vm_name and not resource_group:
raise util.CLIError("--resource-group and --name must be provided together")
raise util.CLIError("--resource-group and --vm-name/--name must be provided together")

if ssh_ip and (vm_name or resource_group):
raise util.CLIError("--ip cannot be used with --resource-group or --name")
raise util.CLIError("--ip cannot be used with --resource-group or --vm-name/--name")


def _check_or_create_public_private_files(public_key_file, private_key_file):
Expand Down
22 changes: 9 additions & 13 deletions src/ssh/azext_ssh/ip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,25 @@
from msrestazure import tools


def get_ssh_ip(cmd, resource_group, vm_name):
def get_ssh_ip(cmd, resource_group, vm_name, use_private_ip):
compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE)
network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK)
vm_client = compute_client.virtual_machines
nic_client = network_client.network_interfaces
ip_client = network_client.public_ip_addresses

vm = vm_client.get(resource_group, vm_name)
nics = vm.network_profile.network_interfaces
ssh_ip = None
for nic_ref in nics:

for nic_ref in vm.network_profile.network_interfaces:
parsed_id = tools.parse_resource_id(nic_ref.id)
nic = nic_client.get(parsed_id['resource_group'], parsed_id['name'])
ip_configs = nic.ip_configurations
for ip_config in ip_configs:
for ip_config in nic.ip_configurations:
if use_private_ip and ip_config.private_ip_address:
return ip_config.private_ip_address
public_ip_ref = ip_config.public_ip_address
parsed_ip_id = tools.parse_resource_id(public_ip_ref.id)
public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name'])
ssh_ip = public_ip.ip_address

if ssh_ip:
break
if public_ip.ip_address:
return public_ip.ip_address

if ssh_ip:
break
return ssh_ip
return None
13 changes: 10 additions & 3 deletions src/ssh/azext_ssh/ssh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ def create_ssh_keyfile(private_key_file):
subprocess.call(command, shell=platform.system() == 'Windows')


def write_ssh_config(config_path, resource_group, vm_name,
def write_ssh_config(config_path, resource_group, vm_name, overwrite,
ip, username, cert_file, private_key_file):
file_utils.make_dirs_for_file(config_path)
lines = []

lines = [""]

if resource_group and vm_name:
lines.append("Host " + resource_group + "-" + vm_name)
lines.append("\tUser " + username)
Expand All @@ -49,7 +51,12 @@ def write_ssh_config(config_path, resource_group, vm_name,
if private_key_file:
lines.append("\tIdentityFile " + private_key_file)

with open(config_path, 'w') as f:
if overwrite:
mode = 'w'
else:
mode = 'a'

with open(config_path, mode) as f:
f.write('\n'.join(lines))


Expand Down
18 changes: 9 additions & 9 deletions src/ssh/azext_ssh/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ class SshCustomCommandTest(unittest.TestCase):
@mock.patch('azext_ssh.custom.ssh_utils')
def test_ssh_vm(self, mock_ssh_utils, mock_do_op):
cmd = mock.Mock()
custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private")
custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False)

mock_do_op.assert_called_once_with(
cmd, "rg", "vm", "ip", "public", "private", mock_ssh_utils.start_ssh_connection)
cmd, "rg", "vm", "ip", "public", "private", False, mock_ssh_utils.start_ssh_connection)

@mock.patch('azext_ssh.custom._do_ssh_op')
@mock.patch('azext_ssh.ssh_utils.write_ssh_config')
def test_ssh_config(self, mock_ssh_utils, mock_do_op):
cmd = mock.Mock()

def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call):
def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call):
op_call(ssh_ip, "username", "cert_file", private_key_file)

mock_do_op.side_effect = do_op_side_effect
custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private")
custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False)

mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", "ip", "username", "cert_file", "private")
mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private")

mock_do_op.assert_called_once_with(
cmd, "rg", "vm", "ip", "public", "private", mock.ANY)
cmd, "rg", "vm", "ip", "public", "private", False, mock.ANY)

@mock.patch('os.path.join')
@mock.patch('azext_ssh.custom._assert_args')
Expand All @@ -53,7 +53,7 @@ def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock
mock_ssh_creds.return_value = "username", "certificate"
mock_join.return_value = "public-aadcert.pub"

custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", mock_op)
custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, mock_op)

mock_assert.assert_called_once_with(None, None, "1.2.3.4")
mock_check_files.assert_called_once_with("publicfile", "privatefile")
Expand All @@ -76,11 +76,11 @@ def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_file

self.assertRaises(
util.CLIError, custom._do_ssh_op, cmd, "rg", "vm", None,
"publicfile", "privatefile", mock_op)
"publicfile", "privatefile", False, mock_op)

mock_assert.assert_called_once_with("rg", "vm", None)
mock_check_files.assert_called_once_with("publicfile", "privatefile")
mock_ip.assert_called_once_with(cmd, "rg", "vm")
mock_ip.assert_called_once_with(cmd, "rg", "vm", False)

def test_assert_args_no_ip_or_vm(self):
self.assertRaises(util.CLIError, custom._assert_args, None, None, None)
Expand Down
32 changes: 30 additions & 2 deletions src/ssh/azext_ssh/tests/latest/test_ssh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path)
@mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
def test_write_ssh_config_ip_and_vm(self, mock_make_dirs):
expected_lines = [
"",
"Host rg-vm",
"\tUser username",
"\tHostName 1.2.3.4",
Expand All @@ -47,16 +48,43 @@ def test_write_ssh_config_ip_and_vm(self, mock_make_dirs):
mock_file = mock.Mock()
mock_open.return_value.__enter__.return_value = mock_file
ssh_utils.write_ssh_config(
"path/to/file", "rg", "vm", "1.2.3.4", "username", "cert", "privatekey"
"path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey"
)

mock_make_dirs.assert_called_once_with("path/to/file")
mock_open.assert_called_once_with("path/to/file", "w")
mock_file.write.assert_called_once_with('\n'.join(expected_lines))

@mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
def test_write_ssh_config_append(self, mock_make_dirs):
expected_lines = [
"",
"Host rg-vm",
"\tUser username",
"\tHostName 1.2.3.4",
"\tCertificateFile cert",
"\tIdentityFile privatekey",
"Host 1.2.3.4",
"\tUser username",
"\tCertificateFile cert",
"\tIdentityFile privatekey"
]

with mock.patch('builtins.open') as mock_open:
mock_file = mock.Mock()
mock_open.return_value.__enter__.return_value = mock_file
ssh_utils.write_ssh_config(
"path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey"
)

mock_make_dirs.assert_called_once_with("path/to/file")
mock_open.assert_called_once_with("path/to/file", "a")
mock_file.write.assert_called_once_with('\n'.join(expected_lines))

@mock.patch('azext_ssh.ssh_utils.file_utils.make_dirs_for_file')
def test_write_ssh_config_ip_only(self, mock_make_dirs):
expected_lines = [
"",
"Host 1.2.3.4",
"\tUser username",
"\tCertificateFile cert",
Expand All @@ -67,7 +95,7 @@ def test_write_ssh_config_ip_only(self, mock_make_dirs):
mock_file = mock.Mock()
mock_open.return_value.__enter__.return_value = mock_file
ssh_utils.write_ssh_config(
"path/to/file", None, None, "1.2.3.4", "username", "cert", "privatekey"
"path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey"
)

mock_make_dirs.assert_called_once_with("path/to/file")
Expand Down
2 changes: 1 addition & 1 deletion src/ssh/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from setuptools import setup, find_packages

VERSION = "0.1.2"
VERSION = "0.1.3"

CLASSIFIERS = [
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit 7a16cc5

Please sign in to comment.