Skip to content

Commit

Permalink
Feat: Add KMS methods for Key Rotations and MACs (#8462)
Browse files Browse the repository at this point in the history
  • Loading branch information
MauriceBrg authored Jan 8, 2025
1 parent 7536c08 commit db72e95
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 2 deletions.
18 changes: 18 additions & 0 deletions moto/kms/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,21 @@ def __init__(self) -> None:
super().__init__("InvalidCiphertextException", "")

self.description = '{"__type":"InvalidCiphertextException"}'


class InvalidKeyUsageException(JsonRESTError):
code = 400

def __init__(self) -> None:
super().__init__("InvalidKeyUsageException", "")

self.description = '{"__type":"InvalidKeyUsageException"}'


class KMSInvalidMacException(JsonRESTError):
code = 400

def __init__(self) -> None:
super().__init__("KMSInvalidMacException", "")

self.description = '{"__type":"KMSInvalidMacException"}'
94 changes: 92 additions & 2 deletions moto/kms/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,29 @@
from collections import defaultdict
from copy import copy
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel, CloudFormationModel
from moto.core.exceptions import JsonRESTError
from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random
from moto.utilities.paginator import paginate
from moto.utilities.tagging_service import TaggingService
from moto.utilities.utils import get_partition

from .exceptions import ValidationException
from .exceptions import (
InvalidKeyUsageException,
KMSInvalidMacException,
ValidationException,
)
from .utils import (
RESERVED_ALIASES,
KeySpec,
SigningAlgorithm,
decrypt,
encrypt,
generate_hmac,
generate_key_id,
generate_master_key,
generate_private_key,
Expand Down Expand Up @@ -100,6 +106,8 @@ def __init__(
)
self.grants: Dict[str, Grant] = dict()

self.rotations: List[Dict[str, Any]] = []

def add_grant(
self,
name: str,
Expand Down Expand Up @@ -272,6 +280,15 @@ def get_cfn_attribute(self, attribute_name: str) -> str:


class KmsBackend(BaseBackend):
PAGINATION_MODEL = {
"list_key_rotations": {
"input_token": "next_marker",
"limit_key": "limit",
"limit_default": 1000,
"unique_attribute": "RotationDate",
}
}

def __init__(self, region_name: str, account_id: Optional[str] = None):
super().__init__(region_name=region_name, account_id=account_id) # type: ignore
self.keys: Dict[str, Key] = {}
Expand Down Expand Up @@ -460,6 +477,11 @@ def put_key_policy(self, key_id: str, policy: str) -> None:
def get_key_policy(self, key_id: str) -> str:
return self.keys[self.get_key_id(key_id)].policy

def list_key_policies(self) -> None:
# Marker to indicate this is implemented
# Responses uses 'describe_key'
pass

def disable_key(self, key_id: str) -> None:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = "Disabled"
Expand Down Expand Up @@ -526,6 +548,11 @@ def re_encrypt(
)
return new_ciphertext_blob, decrypting_arn, encrypting_arn

def generate_random(self) -> None:
# Marker to indicate this is implemented
# Responses uses 'os.urandom'
pass

def generate_data_key(
self,
key_id: str,
Expand Down Expand Up @@ -714,5 +741,68 @@ def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
key = self.describe_key(key_id)
return key, key.private_key.public_key()

def rotate_key_on_demand(self, key_id: str) -> str:
key: Key = self.keys[self.get_key_id(key_id)]

rotation = {
"KeyId": key_id,
"RotationDate": datetime.now().timestamp(),
"RotationType": "ON_DEMAND",
}

# Add to key rotations
key.rotations.append(rotation)

return key_id

@paginate(PAGINATION_MODEL)
def list_key_rotations(
self, key_id: str, limit: int, next_marker: str
) -> List[Dict[str, Union[str, float]]]:
key: Key = self.keys[self.get_key_id(key_id)]

return key.rotations

def generate_mac(
self,
message: bytes,
key_id: str,
mac_algorithm: str,
grant_tokens: List[str],
dry_run: bool,
) -> Tuple[str, str, str]:
key = self.keys[key_id]

if (
key.key_usage != "GENERATE_VERIFY_MAC"
or key.key_spec not in KeySpec.hmac_key_specs()
):
raise InvalidKeyUsageException()

mac = generate_hmac(
key=key.key_material, message=message, mac_algorithm=mac_algorithm
)
return mac, mac_algorithm, key_id

def verify_mac(
self,
message: bytes,
key_id: str,
mac_algorithm: str,
mac: str,
grant_tokens: List[str],
dry_run: bool,
) -> None:
regenerated_mac, _, _ = self.generate_mac(
message=message,
key_id=key_id,
mac_algorithm=mac_algorithm,
grant_tokens=grant_tokens,
dry_run=dry_run,
)

if mac != regenerated_mac:
raise KMSInvalidMacException()


kms_backends = BackendDict(KmsBackend, "kms")
91 changes: 91 additions & 0 deletions moto/kms/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,36 @@ def generate_data_key_without_plaintext(self) -> str:

return json.dumps(result)

def generate_mac(self) -> str:
message = self._get_param("Message")
key_id = self._get_param("KeyId")
mac_algorithm = self._get_param("MacAlgorithm")
grant_tokens = self._get_param("GrantTokens")
dry_run = self._get_param("DryRun")

self._validate_key_id(key_id)

mac_algorithms = {
"HMAC_SHA_224",
"HMAC_SHA_256",
"HMAC_SHA_384",
"HMAC_SHA_512",
}
if mac_algorithm and mac_algorithm not in mac_algorithms:
raise ValidationException(
f"MacAlgorithm must be one of {', '.join(mac_algorithms)}"
)

mac, mac_algorithm, key_id = self.kms_backend.generate_mac(
message=message,
key_id=key_id,
mac_algorithm=mac_algorithm,
grant_tokens=grant_tokens,
dry_run=dry_run,
)

return json.dumps(dict(Mac=mac, MacAlgorithm=mac_algorithm, KeyId=key_id))

def generate_random(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self._get_param("NumberOfBytes")
Expand Down Expand Up @@ -703,6 +733,38 @@ def verify(self) -> str:
}
)

def verify_mac(self) -> str:
message = self._get_param("Message")
mac = self._get_param("Mac")
key_id = self._get_param("KeyId")
mac_algorithm = self._get_param("MacAlgorithm")
grant_tokens = self._get_param("GrantTokens")
dry_run = self._get_param("DryRun")

self._validate_key_id(key_id)

mac_algorithms = {
"HMAC_SHA_224",
"HMAC_SHA_256",
"HMAC_SHA_384",
"HMAC_SHA_512",
}
if mac_algorithm and mac_algorithm not in mac_algorithms:
raise ValidationException(
f"MacAlgorithm must be one of {', '.join(mac_algorithms)}"
)

self.kms_backend.verify_mac(
message=message,
key_id=key_id,
mac_algorithm=mac_algorithm,
mac=mac,
grant_tokens=grant_tokens,
dry_run=dry_run,
)

return json.dumps(dict(KeyId=key_id, MacValid=True, MacAlgorithm=mac_algorithm))

def get_public_key(self) -> str:
key_id = self._get_param("KeyId")

Expand All @@ -719,6 +781,35 @@ def get_public_key(self) -> str:
}
)

def rotate_key_on_demand(self) -> str:
key_id = self._get_param("KeyId")

self._validate_key_id(key_id)

key_id = self.kms_backend.rotate_key_on_demand(
key_id=key_id,
)
return json.dumps(dict(KeyId=key_id))

def list_key_rotations(self) -> str:
key_id = self._get_param("KeyId")
limit = self._get_param("Limit", 1000)
marker = self._get_param("Marker")

self._validate_key_id(key_id)

rotations, next_marker = self.kms_backend.list_key_rotations(
key_id=key_id, limit=limit, next_marker=marker
)
is_truncated = next_marker is not None

response = {"Rotations": rotations, "Truncated": is_truncated}

if is_truncated:
response["NextMarker"] = next_marker

return json.dumps(response)


def _assert_default_policy(policy_name: str) -> None:
if policy_name != "default":
Expand Down
28 changes: 28 additions & 0 deletions moto/kms/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import base64
import hashlib
import hmac
import io
import os
import struct
Expand Down Expand Up @@ -443,3 +446,28 @@ def decrypt(
raise InvalidCiphertextException()

return plaintext, ciphertext.key_id


def generate_hmac(
key: bytes,
message: bytes,
mac_algorithm: str,
) -> str:
"""
Returns a base64 encoded HMAC
"""

algos = {
"HMAC_SHA_224": hashlib.sha224,
"HMAC_SHA_256": hashlib.sha256,
"HMAC_SHA_384": hashlib.sha384,
"HMAC_SHA_512": hashlib.sha512,
}

hmac_val = hmac.new(
key=key,
msg=message,
digestmod=algos[mac_algorithm],
)

return base64.b64encode(hmac_val.digest()).decode("utf-8")
Loading

0 comments on commit db72e95

Please sign in to comment.