Skip to content

Commit

Permalink
EC2: Add support for ED25519 SSH key pairs (#6904)
Browse files Browse the repository at this point in the history
  • Loading branch information
viren-nadkarni authored Oct 12, 2023
1 parent e594430 commit 760e28b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 52 deletions.
19 changes: 12 additions & 7 deletions moto/ec2/models/key_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
InvalidKeyPairFormatError,
)
from ..utils import (
random_key_pair,
rsa_public_key_fingerprint,
rsa_public_key_parse,
generic_filter,
public_key_fingerprint,
public_key_parse,
random_key_pair_id,
random_ed25519_key_pair,
random_rsa_key_pair,
)
from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow

Expand Down Expand Up @@ -42,10 +43,14 @@ class KeyPairBackend:
def __init__(self) -> None:
self.keypairs: Dict[str, KeyPair] = {}

def create_key_pair(self, name: str) -> KeyPair:
def create_key_pair(self, name: str, key_type: str = "rsa") -> KeyPair:
if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name)
keypair = KeyPair(name, **random_key_pair())
if key_type == "ed25519":
keypair = KeyPair(name, **random_ed25519_key_pair())
else:
keypair = KeyPair(name, **random_rsa_key_pair())

self.keypairs[name] = keypair
return keypair

Expand Down Expand Up @@ -77,11 +82,11 @@ def import_key_pair(self, key_name: str, public_key_material: str) -> KeyPair:
raise InvalidKeyPairDuplicateError(key_name)

try:
rsa_public_key = rsa_public_key_parse(public_key_material)
public_key = public_key_parse(public_key_material)
except ValueError:
raise InvalidKeyPairFormatError()

fingerprint = rsa_public_key_fingerprint(rsa_public_key)
fingerprint = public_key_fingerprint(public_key)
keypair = KeyPair(
key_name, material=public_key_material, fingerprint=fingerprint
)
Expand Down
4 changes: 2 additions & 2 deletions moto/ec2/responses/key_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
class KeyPairs(EC2BaseResponse):
def create_key_pair(self) -> str:
name = self._get_param("KeyName")
key_type = self._get_param("KeyType")
self.error_on_dryrun()

keypair = self.ec2_backend.create_key_pair(name)
keypair = self.ec2_backend.create_key_pair(name, key_type)
return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair)

def delete_key_pair(self) -> str:
Expand Down
51 changes: 40 additions & 11 deletions moto/ec2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PublicKey,
Ed25519PrivateKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from typing import Any, Dict, List, Set, TypeVar, Tuple, Optional, Union

from moto.core.utils import utcnow
Expand Down Expand Up @@ -552,7 +557,22 @@ def simple_aws_filter_to_re(filter_string: str) -> str:
return tmp_filter


def random_key_pair() -> Dict[str, str]:
def random_ed25519_key_pair() -> Dict[str, str]:
private_key = Ed25519PrivateKey.generate()
private_key_material = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
fingerprint = public_key_fingerprint(private_key.public_key())

return {
"fingerprint": fingerprint,
"material": private_key_material.decode("ascii"),
}


def random_rsa_key_pair() -> Dict[str, str]:
private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
Expand All @@ -561,10 +581,10 @@ def random_key_pair() -> Dict[str, str]:
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_fingerprint = rsa_public_key_fingerprint(private_key.public_key())
fingerprint = public_key_fingerprint(private_key.public_key())

return {
"fingerprint": public_key_fingerprint,
"fingerprint": fingerprint,
"material": private_key_material.decode("ascii"),
}

Expand Down Expand Up @@ -645,7 +665,9 @@ def generate_instance_identity_document(instance: Any) -> Dict[str, Any]:
return document


def rsa_public_key_parse(key_material: Any) -> Any:
def public_key_parse(
key_material: Union[str, bytes]
) -> Union[RSAPublicKey, Ed25519PublicKey]:
# These imports take ~.5s; let's keep them local
import sshpubkeys.exceptions
from sshpubkeys.keys import SSHKey
Expand All @@ -654,19 +676,26 @@ def rsa_public_key_parse(key_material: Any) -> Any:
if not isinstance(key_material, bytes):
key_material = key_material.encode("ascii")

decoded_key = base64.b64decode(key_material).decode("ascii")
public_key = SSHKey(decoded_key)
decoded_key = base64.b64decode(key_material)
public_key = SSHKey(decoded_key.decode("ascii"))
except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError):
raise ValueError("bad key")

if not public_key.rsa:
raise ValueError("bad key")
if public_key.rsa:
return public_key.rsa

# `cryptography` currently does not support RSA RFC4716/SSH2 format, otherwise we could get rid of `sshpubkeys` and
# simply use `load_ssh_public_key()`
if public_key.key_type == b"ssh-ed25519":
return serialization.load_ssh_public_key(decoded_key) # type: ignore[return-value]

return public_key.rsa
raise ValueError("bad key")


def rsa_public_key_fingerprint(rsa_public_key: Any) -> str:
key_data = rsa_public_key.public_bytes(
def public_key_fingerprint(public_key: Union[RSAPublicKey, Ed25519PublicKey]) -> str:
# TODO: Use different fingerprint calculation methods based on key type and source
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-keys.html#how-ec2-key-fingerprints-are-calculated
key_data = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
Expand Down
25 changes: 17 additions & 8 deletions tests/test_ec2/helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import rsa, ed25519


def rsa_check_private_key(private_key_material):
def check_private_key(private_key_material, key_type):
assert isinstance(private_key_material, str)

private_key = serialization.load_pem_private_key(
data=private_key_material.encode("ascii"),
backend=default_backend(),
password=None,
)
assert isinstance(private_key, rsa.RSAPrivateKey)
if key_type == "rsa":
private_key = serialization.load_pem_private_key(
data=private_key_material.encode("ascii"),
backend=default_backend(),
password=None,
)
assert isinstance(private_key, rsa.RSAPrivateKey)
elif key_type == "ed25519":
private_key = serialization.load_ssh_private_key(
data=private_key_material.encode("ascii"),
password=None,
)
assert isinstance(private_key, ed25519.Ed25519PrivateKey)
else:
raise AssertionError("Bad private key")
48 changes: 27 additions & 21 deletions tests/test_ec2/test_key_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
from uuid import uuid4
from unittest import SkipTest

from .helpers import rsa_check_private_key
from .helpers import check_private_key


ED25519_PUBLIC_KEY_OPENSSH = b"""\
ssh-ed25519 \
AAAAC3NzaC1lZDI1NTE5AAAAIEwsSB9HbTeKCdkSlMZeTq9jZggaPJUwAsUi/7wakB+B \
moto@getmoto"""

ED25519_PUBLIC_KEY_FINGERPRINT = "6c:d9:cf:90:d7:f7:bc:46:83:9e:f5:56:aa:e1:13:38"

RSA_PUBLIC_KEY_OPENSSH = b"""\
ssh-rsa \
AAAAB3NzaC1yc2EAAAADAQABAAABAQDusXfgTE4eBP50NglSzCSEGnIL6+cr6m3H\
Expand Down Expand Up @@ -77,17 +84,18 @@ def test_key_pairs_create_dryrun_boto3():


@mock_ec2
def test_key_pairs_create_boto3():
@pytest.mark.parametrize("key_type", ["rsa", "ed25519"])
def test_key_pairs_create_boto3(key_type):
ec2 = boto3.resource("ec2", "us-west-1")
client = boto3.client("ec2", "us-west-1")

key_name = str(uuid4())[0:6]
kp = ec2.create_key_pair(KeyName=key_name)
rsa_check_private_key(kp.key_material)
kp = ec2.create_key_pair(KeyName=key_name, KeyType=key_type)
check_private_key(kp.key_material, key_type)
# Verify the client can create a key_pair as well - should behave the same
key_name2 = str(uuid4())
kp2 = client.create_key_pair(KeyName=key_name2)
rsa_check_private_key(kp2["KeyMaterial"])
kp2 = client.create_key_pair(KeyName=key_name2, KeyType=key_type)
check_private_key(kp2["KeyMaterial"], key_type)

assert kp.key_material != kp2["KeyMaterial"]

Expand Down Expand Up @@ -145,13 +153,22 @@ def test_key_pairs_delete_exist_boto3():


@mock_ec2
def test_key_pairs_import_boto3():
@pytest.mark.parametrize(
"public_key,fingerprint",
[
(RSA_PUBLIC_KEY_OPENSSH, RSA_PUBLIC_KEY_FINGERPRINT),
(RSA_PUBLIC_KEY_RFC4716, RSA_PUBLIC_KEY_FINGERPRINT),
(ED25519_PUBLIC_KEY_OPENSSH, ED25519_PUBLIC_KEY_FINGERPRINT),
],
ids=["rsa-openssh", "rsa-rfc4716", "ed25519"],
)
def test_key_pairs_import_boto3(public_key, fingerprint):
client = boto3.client("ec2", "us-west-1")

key_name = str(uuid4())[0:6]
with pytest.raises(ClientError) as ex:
client.import_key_pair(
KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH, DryRun=True
KeyName=key_name, PublicKeyMaterial=public_key, DryRun=True
)
assert ex.value.response["Error"]["Code"] == "DryRunOperation"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 412
Expand All @@ -160,26 +177,15 @@ def test_key_pairs_import_boto3():
== "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set"
)

kp1 = client.import_key_pair(
KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH
)
kp1 = client.import_key_pair(KeyName=key_name, PublicKeyMaterial=public_key)

assert "KeyPairId" in kp1
assert kp1["KeyName"] == key_name
assert kp1["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT

key_name2 = str(uuid4())
kp2 = client.import_key_pair(
KeyName=key_name2, PublicKeyMaterial=RSA_PUBLIC_KEY_RFC4716
)
assert "KeyPairId" in kp2
assert kp2["KeyName"] == key_name2
assert kp2["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT
assert kp1["KeyFingerprint"] == fingerprint

all_kps = client.describe_key_pairs()["KeyPairs"]
all_names = [kp["KeyName"] for kp in all_kps]
assert kp1["KeyName"] in all_names
assert kp2["KeyName"] in all_names


@mock_ec2
Expand Down
9 changes: 6 additions & 3 deletions tests/test_ec2/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

from moto.ec2 import utils

from .helpers import rsa_check_private_key
from .helpers import check_private_key


def test_random_key_pair():
key_pair = utils.random_key_pair()
rsa_check_private_key(key_pair["material"])
key_pair = utils.random_rsa_key_pair()
check_private_key(key_pair["material"], "rsa")

# AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1
# fingerprints with 59 characters.
assert len(key_pair["fingerprint"]) == 47

key_pair = utils.random_ed25519_key_pair()
check_private_key(key_pair["material"], "ed25519")


def test_random_ipv6_cidr():
def mocked_random_resource_id(chars: int):
Expand Down

0 comments on commit 760e28b

Please sign in to comment.