diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index 1ebf0d1233b5..5a8ad4fc3e23 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -47,3 +47,21 @@ def __init__(self) -> None: super().__init__("InvalidCiphertextException", "") self.description = '{"__type":"InvalidCiphertextException"}' + + +class InvalidKeyUsageException(JsonRESTError): + code = 400 + + def __init__(self) -> None: + super().__init__("InvalidKeyUsageException", "") + + self.description = '{"__type":"InvalidKeyUsageException"}' + + +class KMSInvalidMacException(JsonRESTError): + code = 400 + + def __init__(self) -> None: + super().__init__("KMSInvalidMacException", "") + + self.description = '{"__type":"KMSInvalidMacException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 5e9f3f6f0cc0..03b834c37cee 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -4,23 +4,29 @@ from collections import defaultdict from copy import copy from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.exceptions import JsonRESTError from moto.core.utils import unix_time from moto.moto_api._internal import mock_random +from moto.utilities.paginator import paginate from moto.utilities.tagging_service import TaggingService from moto.utilities.utils import get_partition -from .exceptions import ValidationException +from .exceptions import ( + InvalidKeyUsageException, + KMSInvalidMacException, + ValidationException, +) from .utils import ( RESERVED_ALIASES, KeySpec, SigningAlgorithm, decrypt, encrypt, + generate_hmac, generate_key_id, generate_master_key, generate_private_key, @@ -100,6 +106,8 @@ def __init__( ) self.grants: Dict[str, Grant] = dict() + self.rotations: List[Dict[str, Any]] = [] + def add_grant( self, name: str, @@ -272,6 +280,15 @@ def get_cfn_attribute(self, attribute_name: str) -> str: class KmsBackend(BaseBackend): + PAGINATION_MODEL = { + "list_key_rotations": { + "input_token": "next_marker", + "limit_key": "limit", + "limit_default": 1000, + "unique_attribute": "RotationDate", + } + } + def __init__(self, region_name: str, account_id: Optional[str] = None): super().__init__(region_name=region_name, account_id=account_id) # type: ignore self.keys: Dict[str, Key] = {} @@ -460,6 +477,11 @@ def put_key_policy(self, key_id: str, policy: str) -> None: def get_key_policy(self, key_id: str) -> str: return self.keys[self.get_key_id(key_id)].policy + def list_key_policies(self) -> None: + # Marker to indicate this is implemented + # Responses uses 'describe_key' + pass + def disable_key(self, key_id: str) -> None: self.keys[key_id].enabled = False self.keys[key_id].key_state = "Disabled" @@ -526,6 +548,11 @@ def re_encrypt( ) return new_ciphertext_blob, decrypting_arn, encrypting_arn + def generate_random(self) -> None: + # Marker to indicate this is implemented + # Responses uses 'os.urandom' + pass + def generate_data_key( self, key_id: str, @@ -714,5 +741,68 @@ def get_public_key(self, key_id: str) -> Tuple[Key, bytes]: key = self.describe_key(key_id) return key, key.private_key.public_key() + def rotate_key_on_demand(self, key_id: str) -> str: + key: Key = self.keys[self.get_key_id(key_id)] + + rotation = { + "KeyId": key_id, + "RotationDate": datetime.now().timestamp(), + "RotationType": "ON_DEMAND", + } + + # Add to key rotations + key.rotations.append(rotation) + + return key_id + + @paginate(PAGINATION_MODEL) + def list_key_rotations( + self, key_id: str, limit: int, next_marker: str + ) -> List[Dict[str, Union[str, float]]]: + key: Key = self.keys[self.get_key_id(key_id)] + + return key.rotations + + def generate_mac( + self, + message: bytes, + key_id: str, + mac_algorithm: str, + grant_tokens: List[str], + dry_run: bool, + ) -> Tuple[str, str, str]: + key = self.keys[key_id] + + if ( + key.key_usage != "GENERATE_VERIFY_MAC" + or key.key_spec not in KeySpec.hmac_key_specs() + ): + raise InvalidKeyUsageException() + + mac = generate_hmac( + key=key.key_material, message=message, mac_algorithm=mac_algorithm + ) + return mac, mac_algorithm, key_id + + def verify_mac( + self, + message: bytes, + key_id: str, + mac_algorithm: str, + mac: str, + grant_tokens: List[str], + dry_run: bool, + ) -> None: + regenerated_mac, _, _ = self.generate_mac( + message=message, + key_id=key_id, + mac_algorithm=mac_algorithm, + grant_tokens=grant_tokens, + dry_run=dry_run, + ) + + if mac != regenerated_mac: + raise KMSInvalidMacException() + kms_backends = BackendDict(KmsBackend, "kms") diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 7f5b9b674004..a5c90a268788 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -604,6 +604,36 @@ def generate_data_key_without_plaintext(self) -> str: return json.dumps(result) + def generate_mac(self) -> str: + message = self._get_param("Message") + key_id = self._get_param("KeyId") + mac_algorithm = self._get_param("MacAlgorithm") + grant_tokens = self._get_param("GrantTokens") + dry_run = self._get_param("DryRun") + + self._validate_key_id(key_id) + + mac_algorithms = { + "HMAC_SHA_224", + "HMAC_SHA_256", + "HMAC_SHA_384", + "HMAC_SHA_512", + } + if mac_algorithm and mac_algorithm not in mac_algorithms: + raise ValidationException( + f"MacAlgorithm must be one of {', '.join(mac_algorithms)}" + ) + + mac, mac_algorithm, key_id = self.kms_backend.generate_mac( + message=message, + key_id=key_id, + mac_algorithm=mac_algorithm, + grant_tokens=grant_tokens, + dry_run=dry_run, + ) + + return json.dumps(dict(Mac=mac, MacAlgorithm=mac_algorithm, KeyId=key_id)) + def generate_random(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" number_of_bytes = self._get_param("NumberOfBytes") @@ -703,6 +733,38 @@ def verify(self) -> str: } ) + def verify_mac(self) -> str: + message = self._get_param("Message") + mac = self._get_param("Mac") + key_id = self._get_param("KeyId") + mac_algorithm = self._get_param("MacAlgorithm") + grant_tokens = self._get_param("GrantTokens") + dry_run = self._get_param("DryRun") + + self._validate_key_id(key_id) + + mac_algorithms = { + "HMAC_SHA_224", + "HMAC_SHA_256", + "HMAC_SHA_384", + "HMAC_SHA_512", + } + if mac_algorithm and mac_algorithm not in mac_algorithms: + raise ValidationException( + f"MacAlgorithm must be one of {', '.join(mac_algorithms)}" + ) + + self.kms_backend.verify_mac( + message=message, + key_id=key_id, + mac_algorithm=mac_algorithm, + mac=mac, + grant_tokens=grant_tokens, + dry_run=dry_run, + ) + + return json.dumps(dict(KeyId=key_id, MacValid=True, MacAlgorithm=mac_algorithm)) + def get_public_key(self) -> str: key_id = self._get_param("KeyId") @@ -719,6 +781,35 @@ def get_public_key(self) -> str: } ) + def rotate_key_on_demand(self) -> str: + key_id = self._get_param("KeyId") + + self._validate_key_id(key_id) + + key_id = self.kms_backend.rotate_key_on_demand( + key_id=key_id, + ) + return json.dumps(dict(KeyId=key_id)) + + def list_key_rotations(self) -> str: + key_id = self._get_param("KeyId") + limit = self._get_param("Limit", 1000) + marker = self._get_param("Marker") + + self._validate_key_id(key_id) + + rotations, next_marker = self.kms_backend.list_key_rotations( + key_id=key_id, limit=limit, next_marker=marker + ) + is_truncated = next_marker is not None + + response = {"Rotations": rotations, "Truncated": is_truncated} + + if is_truncated: + response["NextMarker"] = next_marker + + return json.dumps(response) + def _assert_default_policy(policy_name: str) -> None: if policy_name != "default": diff --git a/moto/kms/utils.py b/moto/kms/utils.py index 0d63d052c38d..d1d510972c77 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -1,3 +1,6 @@ +import base64 +import hashlib +import hmac import io import os import struct @@ -443,3 +446,28 @@ def decrypt( raise InvalidCiphertextException() return plaintext, ciphertext.key_id + + +def generate_hmac( + key: bytes, + message: bytes, + mac_algorithm: str, +) -> str: + """ + Returns a base64 encoded HMAC + """ + + algos = { + "HMAC_SHA_224": hashlib.sha224, + "HMAC_SHA_256": hashlib.sha256, + "HMAC_SHA_384": hashlib.sha384, + "HMAC_SHA_512": hashlib.sha512, + } + + hmac_val = hmac.new( + key=key, + msg=message, + digestmod=algos[mac_algorithm], + ) + + return base64.b64encode(hmac_val.digest()).decode("utf-8") diff --git a/tests/test_kms/test_kms_key_rotation.py b/tests/test_kms/test_kms_key_rotation.py new file mode 100644 index 000000000000..f809a78b3330 --- /dev/null +++ b/tests/test_kms/test_kms_key_rotation.py @@ -0,0 +1,106 @@ +"""Unit tests for kms-supported APIs.""" + +import boto3 +import pytest + +from moto import mock_aws + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_aws +def test_rotate_key_on_demand_with_existing_key(): + client = boto3.client("kms", region_name="us-east-2") + + key_id = client.create_key(Policy="my policy")["KeyMetadata"]["KeyId"] + + resp = client.rotate_key_on_demand(KeyId=key_id) + + assert resp["KeyId"] == key_id + + +@mock_aws +def test_rotate_key_on_demand_with_non_existing_key_fails(): + client = boto3.client("kms", region_name="us-east-2") + + with pytest.raises(client.exceptions.NotFoundException): + client.rotate_key_on_demand(KeyId="some-id") + + +@mock_aws +def test_list_key_rotations_with_non_existing_key_fails(): + client = boto3.client("kms", region_name="us-east-2") + + with pytest.raises(client.exceptions.NotFoundException): + client.list_key_rotations(KeyId="some-id") + + +@mock_aws +def test_list_key_rotations_are_empty_on_new_key(): + client = boto3.client("kms", region_name="us-east-2") + + key_id = client.create_key(Policy="my policy")["KeyMetadata"]["KeyId"] + + resp = client.list_key_rotations(KeyId=key_id) + + assert len(resp["Rotations"]) == 0 + assert resp["Truncated"] is False + assert "NextMarker" not in resp + + +@mock_aws +def test_list_key_rotations_returns_one_rotation(): + # Arrange + client = boto3.client("kms", region_name="us-east-2") + key_id = client.create_key(Policy="my policy")["KeyMetadata"]["KeyId"] + client.rotate_key_on_demand(KeyId=key_id) + + # Act + resp = client.list_key_rotations(KeyId=key_id) + + # Assert + assert len(resp["Rotations"]) == 1 + assert resp["Truncated"] is False + assert "NextMarker" not in resp + assert resp["Rotations"][0]["RotationType"] == "ON_DEMAND" + + +@mock_aws +def test_list_key_rotations_returns_truncated_and_next_marker(): + # Arrange + client = boto3.client("kms", region_name="us-east-2") + key_id = client.create_key(Policy="my policy")["KeyMetadata"]["KeyId"] + client.rotate_key_on_demand(KeyId=key_id) + client.rotate_key_on_demand(KeyId=key_id) + client.rotate_key_on_demand(KeyId=key_id) + + # Act + resp = client.list_key_rotations(KeyId=key_id, Limit=1) + + # Assert + assert len(resp["Rotations"]) == 1 + assert resp["Truncated"] is True + assert "NextMarker" in resp + + +@mock_aws +def test_list_key_rotations_pagination(): + # Arrange + client = boto3.client("kms", region_name="us-east-2") + key_id = client.create_key(Policy="my policy")["KeyMetadata"]["KeyId"] + client.rotate_key_on_demand(KeyId=key_id) + client.rotate_key_on_demand(KeyId=key_id) + client.rotate_key_on_demand(KeyId=key_id) + initial_page = client.list_key_rotations(KeyId=key_id, Limit=1) + + # Act + + final_page = client.list_key_rotations( + KeyId=key_id, Limit=2, Marker=initial_page["NextMarker"] + ) + + # Assert + assert len(final_page["Rotations"]) == 2 + assert final_page["Truncated"] is False + assert "NextMarker" not in final_page diff --git a/tests/test_kms/test_kms_mac.py b/tests/test_kms/test_kms_mac.py new file mode 100644 index 000000000000..cf209f9a742a --- /dev/null +++ b/tests/test_kms/test_kms_mac.py @@ -0,0 +1,135 @@ +"""Unit tests for kms-supported APIs.""" + +import base64 + +import boto3 +import pytest + +from moto import mock_aws + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +def create_hmac_key() -> str: + client = boto3.client("kms", region_name="eu-central-1") + + key_id = client.create_key( + KeyUsage="GENERATE_VERIFY_MAC", KeySpec="HMAC_512", Policy="My Policy" + )["KeyMetadata"]["KeyId"] + + return key_id + + +@mock_aws +def test_generate_mac(): + # Arrange + key_id = create_hmac_key() + client = boto3.client("kms", region_name="eu-central-1") + + # Act + resp = client.generate_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + ) + + # Assert + assert "Mac" in resp + assert resp["KeyId"] == key_id + assert resp["MacAlgorithm"] == "HMAC_SHA_512" + + +@mock_aws +def test_generate_fails_for_non_existent_key(): + # Arrange + client = boto3.client("kms", region_name="eu-central-1") + + # Act + Assert + with pytest.raises(client.exceptions.NotFoundException): + client.generate_mac( + KeyId="some-key", + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + ) + + +@mock_aws +def test_generate_fails_for_invalid_key_usage(): + # Arrange + client = boto3.client("kms", region_name="eu-central-1") + key_id = client.create_key( + KeyUsage="ENCRYPT_DECRYPT", KeySpec="HMAC_512", Policy="My Policy" + )["KeyMetadata"]["KeyId"] + + # Act + Assert + with pytest.raises(client.exceptions.InvalidKeyUsageException): + client.generate_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + ) + + +@mock_aws +def test_generate_fails_for_invalid_key_spec(): + # Arrange + client = boto3.client("kms", region_name="eu-central-1") + key_id = client.create_key( + KeyUsage="GENERATE_VERIFY_MAC", KeySpec="RSA_2048", Policy="My Policy" + )["KeyMetadata"]["KeyId"] + + # Act + Assert + with pytest.raises(client.exceptions.InvalidKeyUsageException): + client.generate_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + ) + + +@mock_aws +def test_verify_mac(): + # Arrange + key_id = create_hmac_key() + client = boto3.client("kms", region_name="eu-central-1") + mac = client.generate_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + )["Mac"] + + # Act + resp = client.verify_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + Mac=mac, + ) + + # Assert + assert resp["KeyId"] == key_id + assert resp["MacAlgorithm"] == "HMAC_SHA_512" + assert resp["MacValid"] is True + + +@mock_aws +def test_verify_mac_fails_for_another_key_id(): + # Arrange + key_id = create_hmac_key() + other_key_id = create_hmac_key() + client = boto3.client("kms", region_name="eu-central-1") + mac = client.generate_mac( + KeyId=key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + )["Mac"] + + # Act + Assert + with pytest.raises(client.exceptions.KMSInvalidMacException): + client.verify_mac( + KeyId=other_key_id, + MacAlgorithm="HMAC_SHA_512", + Message=base64.b64encode("Hello World".encode("utf-8")), + Mac=mac, + )