-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add SSH extension * Add unit tests and break things down a bit * Add ssh config command * URL-encode modulus and exponent * Make az ssh vm start interactive * Finalize ssh commands * Updated help docs * Reorder imports * Updated arguments CLI arguments were changed from extra to argument with None defaults in custom. Changed resource_group to resource_group_name argument for auto-population from cli-core. Removed Python 2 and added new Python 3s to setup. * Fix lint error * move ssh certificate part from core to extension * follow CLI extension guide * update CODEOWNERS * update CODEOWNERS * roll back to [email protected] for certificate file * update --name to --vm-name * fix sytle error Co-authored-by: Ryan Rossiter <[email protected]> Co-authored-by: Xiaojian Xu <[email protected]>
- Loading branch information
1 parent
16610a3
commit 92a2530
Showing
18 changed files
with
874 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
from azure.cli.core import AzCommandsLoader | ||
from azext_ssh._help import helps # pylint: disable=unused-import | ||
|
||
|
||
class SshCommandsLoader(AzCommandsLoader): | ||
|
||
def __init__(self, cli_ctx=None): | ||
from azure.cli.core.commands import CliCommandType | ||
|
||
ssh_custom = CliCommandType(operations_tmpl='azext_ssh.custom#{}') | ||
super(SshCommandsLoader, self).__init__(cli_ctx=cli_ctx, custom_command_type=ssh_custom) | ||
|
||
def load_command_table(self, args): | ||
from azext_ssh.commands import load_command_table | ||
load_command_table(self, args) | ||
return self.command_table | ||
|
||
def load_arguments(self, command): | ||
from azext_ssh._params import load_arguments | ||
load_arguments(self, command) | ||
|
||
|
||
COMMAND_LOADER_CLS = SshCommandsLoader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
from knack.help_files import helps | ||
|
||
helps['ssh'] = """ | ||
type: group | ||
short-summary: SSH into Azure VMs | ||
""" | ||
|
||
helps['ssh vm'] = """ | ||
type: command | ||
short-summary: SSH into Azure VMs | ||
examples: | ||
- name: Give a resource group and VM to SSH to | ||
text: | | ||
az ssh vm --resource-group myResourceGroup --vm-name myVm | ||
- name: Give the public IP of a VM to SSH to | ||
text: | | ||
az ssh vm --ip 1.2.3.4 | ||
""" | ||
|
||
helps['ssh config'] = """ | ||
type: command | ||
short-summary: Create an SSH config for Azure VMs which can then be imported to 3rd party SSH clients | ||
examples: | ||
- name: Give a resource group and VM for which to create a config, and save in a local file | ||
text: | | ||
az ssh config --resource-group myResourceGroup --vm-name myVm --file ./sshconfig | ||
- name: Give the public IP of a VM for which to create a config | ||
text: | | ||
az ssh config --ip 1.2.3.4 --file ./sshconfig | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
|
||
def load_arguments(self, _): | ||
|
||
with self.argument_context('ssh vm') as c: | ||
c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM') | ||
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM') | ||
c.argument('public_key_file', help='The RSA public key file path') | ||
c.argument('private_key_file', help='The RSA private key file path') | ||
|
||
with self.argument_context('ssh config') as c: | ||
c.argument('config_path', options_list=['--file'], help='The file path to write the SSH config to') | ||
c.argument('vm_name', options_list=['--vm-name'], help='The name of the VM') | ||
c.argument('ssh_ip', options_list=['--ip'], help='The public IP address of the VM') | ||
c.argument('public_key_file', help='The RSA public key file path') | ||
c.argument('private_key_file', help='The RSA private key file path') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"azext.isPreview": true, | ||
"azext.minCliCoreVersion": "2.4.0" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
|
||
def load_command_table(self, _): | ||
|
||
with self.command_group('ssh') as g: | ||
g.custom_command('vm', 'ssh_vm') | ||
g.custom_command('config', 'ssh_config') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
import functools | ||
import os | ||
import hashlib | ||
import json | ||
|
||
from knack import util | ||
|
||
from . import ip_utils | ||
from . import rsa_parser | ||
from . import ssh_utils | ||
|
||
|
||
def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None): | ||
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, | ||
public_key_file, private_key_file, ssh_utils.start_ssh_connection) | ||
|
||
|
||
def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, | ||
public_key_file=None, private_key_file=None): | ||
op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name) | ||
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, op_call) | ||
|
||
|
||
def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, op_call): | ||
_assert_args(resource_group, vm_name, ssh_ip) | ||
public_key_file, private_key_file = _check_public_private_files(public_key_file, private_key_file) | ||
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name) | ||
|
||
if not ssh_ip: | ||
raise util.CLIError(f"VM '{vm_name}' does not have a public IP address to SSH to") | ||
|
||
scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"] | ||
data = _prepare_jwk_data(public_key_file) | ||
from azure.cli.core._profile import Profile | ||
profile = Profile(cli_ctx=cmd.cli_ctx) | ||
username, certificate = profile.get_msal_token(scopes, data) | ||
|
||
cert_file = _write_cert_file(public_key_file, certificate) | ||
op_call(ssh_ip, username, cert_file, private_key_file) | ||
|
||
|
||
def _prepare_jwk_data(public_key_file): | ||
modulus, exponent = _get_modulus_exponent(public_key_file) | ||
key_hash = hashlib.sha256() | ||
key_hash.update(modulus.encode('utf-8')) | ||
key_hash.update(exponent.encode('utf-8')) | ||
key_id = key_hash.hexdigest() | ||
jwk = { | ||
"kty": "RSA", | ||
"n": modulus, | ||
"e": exponent, | ||
"kid": key_id | ||
} | ||
json_jwk = json.dumps(jwk) | ||
data = { | ||
"token_type": "ssh-cert", | ||
"req_cnf": json_jwk, | ||
"key_id": key_id | ||
} | ||
return data | ||
|
||
|
||
def _assert_args(resource_group, vm_name, ssh_ip): | ||
if not (resource_group or vm_name or ssh_ip): | ||
raise util.CLIError("The VM must be specified by --ip or --resource-group and --name") | ||
|
||
if resource_group and not vm_name or vm_name and not resource_group: | ||
raise util.CLIError("--resource-group and --name must be provided together") | ||
|
||
if ssh_ip and (vm_name or resource_group): | ||
raise util.CLIError("--ip cannot be used with --resource-group or --name") | ||
|
||
|
||
def _check_public_private_files(public_key_file, private_key_file): | ||
ssh_dir_parts = ["~", ".ssh"] | ||
public_key_file = public_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa.pub")) | ||
private_key_file = private_key_file or os.path.expanduser(os.path.join(*ssh_dir_parts, "id_rsa")) | ||
|
||
if not os.path.isfile(public_key_file): | ||
raise util.CLIError(f"Pulic key file {public_key_file} not found") | ||
if not os.path.isfile(private_key_file): | ||
raise util.CLIError(f"Private key file {private_key_file} not found") | ||
|
||
return public_key_file, private_key_file | ||
|
||
|
||
def _write_cert_file(public_key_file, certificate_contents): | ||
cert_file = os.path.join(*os.path.split(public_key_file)[:-1], "id_rsa-cert.pub") | ||
with open(cert_file, 'w') as f: | ||
f.write(f"[email protected] {certificate_contents}") | ||
|
||
return cert_file | ||
|
||
|
||
def _get_modulus_exponent(public_key_file): | ||
if not os.path.isfile(public_key_file): | ||
raise util.CLIError(f"Public key file '{public_key_file}' was not found") | ||
|
||
with open(public_key_file, 'r') as f: | ||
public_key_text = f.read() | ||
|
||
parser = rsa_parser.RSAParser() | ||
try: | ||
parser.parse(public_key_text) | ||
except Exception as e: | ||
raise util.CLIError(f"Could not parse public key. Error: {str(e)}") | ||
modulus = parser.modulus | ||
exponent = parser.exponent | ||
|
||
return modulus, exponent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
import errno | ||
import os | ||
|
||
|
||
def make_dirs_for_file(file_path): | ||
if not os.path.exists(file_path): | ||
mkdir_p(os.path.dirname(file_path)) | ||
|
||
|
||
def mkdir_p(path): | ||
try: | ||
os.makedirs(path) | ||
except OSError as exc: # Python <= 2.5 | ||
if exc.errno == errno.EEXIST and os.path.isdir(path): | ||
pass | ||
else: | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
from azure.cli.core.commands import client_factory | ||
from azure.cli.core import profiles | ||
from msrestazure import tools | ||
|
||
|
||
def get_ssh_ip(cmd, resource_group, vm_name): | ||
compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE) | ||
network_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_NETWORK) | ||
vm_client = compute_client.virtual_machines | ||
nic_client = network_client.network_interfaces | ||
ip_client = network_client.public_ip_addresses | ||
|
||
vm = vm_client.get(resource_group, vm_name) | ||
nics = vm.network_profile.network_interfaces | ||
ssh_ip = None | ||
for nic_ref in nics: | ||
parsed_id = tools.parse_resource_id(nic_ref.id) | ||
nic = nic_client.get(parsed_id['resource_group'], parsed_id['name']) | ||
ip_configs = nic.ip_configurations | ||
for ip_config in ip_configs: | ||
public_ip_ref = ip_config.public_ip_address | ||
parsed_ip_id = tools.parse_resource_id(public_ip_ref.id) | ||
public_ip = ip_client.get(parsed_ip_id['resource_group'], parsed_ip_id['name']) | ||
ssh_ip = public_ip.ip_address | ||
|
||
if ssh_ip: | ||
break | ||
|
||
if ssh_ip: | ||
break | ||
return ssh_ip |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
import base64 | ||
import struct | ||
|
||
|
||
class RSAParser(object): | ||
# pylint: disable=too-few-public-methods | ||
RSAAlgorithm = 'ssh-rsa' | ||
|
||
def __init__(self): | ||
self.algorithm = '' | ||
self.modulus = '' | ||
self.exponent = '' | ||
self._key_length_big_endian = True | ||
|
||
def parse(self, public_key_text): | ||
text_parts = public_key_text.split(' ') | ||
|
||
if len(text_parts) < 2: | ||
error_str = ("Incorrectly formatted public key. " | ||
"Key must be format '<algorithm> <base64_key>'") | ||
raise ValueError(error_str) | ||
|
||
algorithm = text_parts[0] | ||
if algorithm != RSAParser.RSAAlgorithm: | ||
raise ValueError(f"Public key is not ssh-rsa algorithm ({algorithm})") | ||
|
||
b64_string = text_parts[1] | ||
key_bytes = base64.b64decode(b64_string) | ||
fields = list(self._get_fields(key_bytes)) | ||
|
||
if len(fields) < 3: | ||
error_str = ("Incorrectly encoded public key. " | ||
"Encoded key must be base64 encoded <algorithm><exponent><modulus>") | ||
raise ValueError(error_str) | ||
|
||
encoded_algorithm = fields[0].decode("ascii") | ||
if encoded_algorithm != RSAParser.RSAAlgorithm: | ||
raise ValueError(f"Encoded public key is not ssh-rsa algorithm ({encoded_algorithm})") | ||
|
||
self.algorithm = encoded_algorithm | ||
self.exponent = base64.urlsafe_b64encode(fields[1]).decode("ascii") | ||
self.modulus = base64.urlsafe_b64encode(fields[2]).decode("ascii") | ||
|
||
def _get_fields(self, key_bytes): | ||
read = 0 | ||
while read < len(key_bytes): | ||
length = struct.unpack(self._get_struct_format(), key_bytes[read:read + 4])[0] | ||
read = read + 4 | ||
data = key_bytes[read:read + length] | ||
read = read + length | ||
yield data | ||
|
||
def _get_struct_format(self): | ||
format_start = ">" if self._key_length_big_endian else "<" | ||
return format_start + "L" |
Oops, something went wrong.