diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py
index 90f3672e6932..8efedbdb9641 100644
--- a/moto/kinesis/models.py
+++ b/moto/kinesis/models.py
@@ -4,6 +4,7 @@
 import itertools
 
 from operator import attrgetter
+from typing import Any, Dict, List, Optional, Tuple
 
 from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
 from moto.core.utils import unix_time
@@ -439,12 +440,14 @@ def create_from_cloudformation_json(
             for tag_item in properties.get("Tags", [])
         }
 
-        backend = kinesis_backends[account_id][region_name]
+        backend: KinesisBackend = kinesis_backends[account_id][region_name]
         stream = backend.create_stream(
             resource_name, shard_count, retention_period_hours=retention_period_hours
         )
         if any(tags):
-            backend.add_tags_to_stream(stream.stream_name, tags)
+            backend.add_tags_to_stream(
+                stream_arn=None, stream_name=stream.stream_name, tags=tags
+            )
         return stream
 
     @classmethod
@@ -489,8 +492,8 @@ def update_from_cloudformation_json(
     def delete_from_cloudformation_json(
         cls, resource_name, cloudformation_json, account_id, region_name
     ):
-        backend = kinesis_backends[account_id][region_name]
-        backend.delete_stream(resource_name)
+        backend: KinesisBackend = kinesis_backends[account_id][region_name]
+        backend.delete_stream(stream_arn=None, stream_name=resource_name)
 
     @staticmethod
     def is_replacement_update(properties):
@@ -521,7 +524,7 @@ def physical_resource_id(self):
 class KinesisBackend(BaseBackend):
     def __init__(self, region_name, account_id):
         super().__init__(region_name, account_id)
-        self.streams = OrderedDict()
+        self.streams: Dict[str, Stream] = OrderedDict()
 
     @staticmethod
     def default_vpc_endpoint_service(service_region, zones):
@@ -546,38 +549,49 @@ def create_stream(
         self.streams[stream_name] = stream
         return stream
 
-    def describe_stream(self, stream_name) -> Stream:
-        if stream_name in self.streams:
+    def describe_stream(
+        self, stream_arn: Optional[str], stream_name: Optional[str]
+    ) -> Stream:
+        if stream_name and stream_name in self.streams:
             return self.streams[stream_name]
-        else:
-            raise StreamNotFoundError(stream_name, self.account_id)
+        if stream_arn:
+            for stream in self.streams.values():
+                if stream.arn == stream_arn:
+                    return stream
+        if stream_arn:
+            stream_name = stream_arn.split("/")[1]
+        raise StreamNotFoundError(stream_name, self.account_id)
 
-    def describe_stream_summary(self, stream_name):
-        return self.describe_stream(stream_name)
+    def describe_stream_summary(
+        self, stream_arn: Optional[str], stream_name: Optional[str]
+    ) -> Stream:
+        return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
     def list_streams(self):
         return self.streams.values()
 
-    def delete_stream(self, stream_name):
-        if stream_name in self.streams:
-            return self.streams.pop(stream_name)
-        raise StreamNotFoundError(stream_name, self.account_id)
+    def delete_stream(
+        self, stream_arn: Optional[str], stream_name: Optional[str]
+    ) -> Stream:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
+        return self.streams.pop(stream.stream_name)
 
     def get_shard_iterator(
         self,
-        stream_name,
-        shard_id,
-        shard_iterator_type,
-        starting_sequence_number,
-        at_timestamp,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        shard_id: str,
+        shard_iterator_type: str,
+        starting_sequence_number: str,
+        at_timestamp: str,
     ):
         # Validate params
-        stream = self.describe_stream(stream_name)
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         try:
             shard = stream.get_shard(shard_id)
         except ShardNotFoundError:
             raise ResourceNotFoundError(
-                message=f"Shard {shard_id} in stream {stream_name} under account {self.account_id} does not exist"
+                message=f"Shard {shard_id} in stream {stream.stream_name} under account {self.account_id} does not exist"
            )
 
         shard_iterator = compose_new_shard_iterator(
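
For context, a minimal boto3 sketch of the dual lookup this enables (assuming a boto3/botocore version recent enough to model StreamARN on these operations):

    import boto3
    from moto import mock_kinesis

    @mock_kinesis
    def demo_lookup_by_arn():
        client = boto3.client("kinesis", region_name="us-east-1")
        client.create_stream(StreamName="my-stream", ShardCount=1)
        arn = client.describe_stream(StreamName="my-stream")["StreamDescription"][
            "StreamARN"
        ]
        # The backend checks the name first and falls back to scanning ARNs,
        # so both identifiers resolve to the same stream.
        desc = client.describe_stream(StreamARN=arn)["StreamDescription"]
        assert desc["StreamName"] == "my-stream"

    demo_lookup_by_arn()
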
@@ -589,11 +603,13 @@ def get_shard_iterator(
         )
         return shard_iterator
 
-    def get_records(self, shard_iterator, limit):
+    def get_records(
+        self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
+    ):
         decomposed = decompose_shard_iterator(shard_iterator)
         stream_name, shard_id, last_sequence_id = decomposed
 
-        stream = self.describe_stream(stream_name)
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         shard = stream.get_shard(shard_id)
 
         records, last_sequence_id, millis_behind_latest = shard.get_records(
@@ -608,12 +624,13 @@ def get_records(self, shard_iterator, limit):
 
     def put_record(
         self,
+        stream_arn,
         stream_name,
         partition_key,
         explicit_hash_key,
         data,
     ):
-        stream = self.describe_stream(stream_name)
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
         sequence_number, shard_id = stream.put_record(
             partition_key, explicit_hash_key, data
@@ -621,8 +638,8 @@ def put_record(
         )
 
         return sequence_number, shard_id
 
-    def put_records(self, stream_name, records):
-        stream = self.describe_stream(stream_name)
+    def put_records(self, stream_arn, stream_name, records):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
         response = {"FailedRecordCount": 0, "Records": []}
 
@@ -651,8 +668,10 @@ def put_records(self, stream_name, records):
 
         return response
 
-    def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
-        stream = self.describe_stream(stream_name)
+    def split_shard(
+        self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
+    ):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
         if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
             raise ValidationException(
@@ -675,23 +694,27 @@ def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
 
         stream.split_shard(shard_to_split, new_starting_hash_key)
 
-    def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
-        stream = self.describe_stream(stream_name)
+    def merge_shards(
+        self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
+    ):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
         if shard_to_merge not in stream.shards:
             raise ShardNotFoundError(
-                shard_to_merge, stream=stream_name, account_id=self.account_id
+                shard_to_merge, stream=stream.stream_name, account_id=self.account_id
             )
 
         if adjacent_shard_to_merge not in stream.shards:
             raise ShardNotFoundError(
-                adjacent_shard_to_merge, stream=stream_name, account_id=self.account_id
+                adjacent_shard_to_merge,
+                stream=stream.stream_name,
+                account_id=self.account_id,
             )
 
         stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
 
-    def update_shard_count(self, stream_name, target_shard_count):
-        stream = self.describe_stream(stream_name)
+    def update_shard_count(self, stream_arn, stream_name, target_shard_count):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         current_shard_count = len([s for s in stream.shards.values() if s.is_open])
 
         stream.update_shard_count(target_shard_count)
@@ -699,13 +722,18 @@ def update_shard_count(self, stream_name, target_shard_count):
         return current_shard_count
 
     @paginate(pagination_model=PAGINATION_MODEL)
-    def list_shards(self, stream_name):
-        stream = self.describe_stream(stream_name)
+    def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
         return [shard.to_json() for shard in shards]
 
-    def increase_stream_retention_period(self, stream_name, retention_period_hours):
-        stream = self.describe_stream(stream_name)
+    def increase_stream_retention_period(
+        self,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        retention_period_hours: int,
+    ) -> None:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         if retention_period_hours < 24:
             raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
         if retention_period_hours > 8760:
@@ -718,8 +746,13 @@ def increase_stream_retention_period(self, stream_name, retention_period_hours):
             )
         stream.retention_period_hours = retention_period_hours
 
-    def decrease_stream_retention_period(self, stream_name, retention_period_hours):
-        stream = self.describe_stream(stream_name)
+    def decrease_stream_retention_period(
+        self,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        retention_period_hours: int,
+    ) -> None:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         if retention_period_hours < 24:
             raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
         if retention_period_hours > 8760:
@@ -733,9 +766,9 @@ def decrease_stream_retention_period(self, stream_name, retention_period_hours):
         stream.retention_period_hours = retention_period_hours
 
     def list_tags_for_stream(
-        self, stream_name, exclusive_start_tag_key=None, limit=None
+        self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
     ):
-        stream = self.describe_stream(stream_name)
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
 
         tags = []
         result = {"HasMoreTags": False, "Tags": tags}
@@ -750,25 +783,47 @@ def list_tags_for_stream(
 
         return result
 
-    def add_tags_to_stream(self, stream_name, tags):
-        stream = self.describe_stream(stream_name)
+    def add_tags_to_stream(
+        self,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        tags: Dict[str, str],
+    ) -> None:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         stream.tags.update(tags)
 
-    def remove_tags_from_stream(self, stream_name, tag_keys):
-        stream = self.describe_stream(stream_name)
+    def remove_tags_from_stream(
+        self, stream_arn: Optional[str], stream_name: Optional[str], tag_keys: List[str]
+    ) -> None:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         for key in tag_keys:
             if key in stream.tags:
                 del stream.tags[key]
 
-    def enable_enhanced_monitoring(self, stream_name, shard_level_metrics):
-        stream = self.describe_stream(stream_name)
+    def enable_enhanced_monitoring(
+        self,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        shard_level_metrics: List[str],
+    ) -> Tuple[str, str, List[str], List[str]]:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         current_shard_level_metrics = stream.shard_level_metrics
         desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
         stream.shard_level_metrics = desired_metrics
-        return current_shard_level_metrics, desired_metrics
+        return (
+            stream.arn,
+            stream.stream_name,
+            current_shard_level_metrics,
+            desired_metrics,
+        )
 
-    def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
-        stream = self.describe_stream(stream_name)
+    def disable_enhanced_monitoring(
+        self,
+        stream_arn: Optional[str],
+        stream_name: Optional[str],
+        to_be_disabled: List[str],
+    ) -> Tuple[str, str, List[str], List[str]]:
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         current_metrics = stream.shard_level_metrics
         if "ALL" in to_be_disabled:
             desired_metrics = []
@@ -777,7 +832,7 @@ def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
                 metric for metric in current_metrics if metric not in to_be_disabled
             ]
         stream.shard_level_metrics = desired_metrics
-        return current_metrics, desired_metrics
+        return stream.arn, stream.stream_name, current_metrics, desired_metrics
 
     def _find_stream_by_arn(self, stream_arn):
         for stream in self.streams.values():
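
Both monitoring methods now hand back the stream's identity along with the metric lists, so the response layer can echo the ARN to the caller. A minimal sketch of the new contract (the backend instance and stream name are illustrative):

    # current/desired are lists of shard-level metric names
    arn, name, current, desired = backend.enable_enhanced_monitoring(
        stream_arn=None,
        stream_name="my-stream",
        shard_level_metrics=["IncomingBytes"],
    )

The responses.py changes below rely on exactly this four-element shape.
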
@@ -826,13 +881,13 @@ def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
             # It will be a noop for other streams
             stream.delete_consumer(consumer_arn)
 
-    def start_stream_encryption(self, stream_name, encryption_type, key_id):
-        stream = self.describe_stream(stream_name)
+    def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         stream.encryption_type = encryption_type
         stream.key_id = key_id
 
-    def stop_stream_encryption(self, stream_name):
-        stream = self.describe_stream(stream_name)
+    def stop_stream_encryption(self, stream_arn, stream_name):
+        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
         stream.encryption_type = "NONE"
         stream.key_id = None
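
With that, every backend method accepts either identifier. A sketch of direct backend usage under the new signatures (the create_stream arguments mirror the CloudFormation path above; DEFAULT_ACCOUNT_ID is moto's stock test account):

    from moto.core import DEFAULT_ACCOUNT_ID
    from moto.kinesis.models import kinesis_backends

    backend = kinesis_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
    stream = backend.create_stream("demo", 1, retention_period_hours=24)
    # A supplied stream_name wins; the ARN is only scanned when the name misses.
    assert backend.describe_stream(stream_arn=stream.arn, stream_name=None) is stream
    backend.delete_stream(stream_arn=stream.arn, stream_name=None)
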
diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py
index 59b09c0cea35..92f8501c1fbf 100644
--- a/moto/kinesis/responses.py
+++ b/moto/kinesis/responses.py
@@ -1,7 +1,7 @@
 import json
 
 from moto.core.responses import BaseResponse
-from .models import kinesis_backends
+from .models import kinesis_backends, KinesisBackend
 
 
 class KinesisResponse(BaseResponse):
@@ -13,7 +13,7 @@ def parameters(self):
         return json.loads(self.body)
 
     @property
-    def kinesis_backend(self):
+    def kinesis_backend(self) -> KinesisBackend:
         return kinesis_backends[self.current_account][self.region]
 
     def create_stream(self):
@@ -27,13 +27,15 @@ def create_stream(self):
 
     def describe_stream(self):
         stream_name = self.parameters.get("StreamName")
+        stream_arn = self.parameters.get("StreamARN")
         limit = self.parameters.get("Limit")
-        stream = self.kinesis_backend.describe_stream(stream_name)
+        stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
         return json.dumps(stream.to_json(shard_limit=limit))
 
     def describe_stream_summary(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
-        stream = self.kinesis_backend.describe_stream_summary(stream_name)
+        stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
         return json.dumps(stream.to_json_summary())
 
     def list_streams(self):
@@ -58,11 +60,13 @@ def list_streams(self):
         )
 
     def delete_stream(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
-        self.kinesis_backend.delete_stream(stream_name)
+        self.kinesis_backend.delete_stream(stream_arn, stream_name)
         return ""
 
     def get_shard_iterator(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         shard_id = self.parameters.get("ShardId")
         shard_iterator_type = self.parameters.get("ShardIteratorType")
@@ -70,6 +74,7 @@ def get_shard_iterator(self):
         at_timestamp = self.parameters.get("Timestamp")
 
         shard_iterator = self.kinesis_backend.get_shard_iterator(
+            stream_arn,
             stream_name,
             shard_id,
             shard_iterator_type,
@@ -80,6 +85,7 @@ def get_shard_iterator(self):
         return json.dumps({"ShardIterator": shard_iterator})
 
     def get_records(self):
+        stream_arn = self.parameters.get("StreamARN")
         shard_iterator = self.parameters.get("ShardIterator")
         limit = self.parameters.get("Limit")
 
@@ -87,7 +93,7 @@ def get_records(self):
             next_shard_iterator,
             records,
             millis_behind_latest,
-        ) = self.kinesis_backend.get_records(shard_iterator, limit)
+        ) = self.kinesis_backend.get_records(stream_arn, shard_iterator, limit)
 
         return json.dumps(
             {
@@ -98,12 +104,14 @@ def get_records(self):
         )
 
     def put_record(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         partition_key = self.parameters.get("PartitionKey")
         explicit_hash_key = self.parameters.get("ExplicitHashKey")
         data = self.parameters.get("Data")
 
         sequence_number, shard_id = self.kinesis_backend.put_record(
+            stream_arn,
             stream_name,
             partition_key,
             explicit_hash_key,
@@ -113,37 +121,44 @@ def put_record(self):
         return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
 
     def put_records(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         records = self.parameters.get("Records")
-        response = self.kinesis_backend.put_records(stream_name, records)
+        response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
         return json.dumps(response)
 
     def split_shard(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         shard_to_split = self.parameters.get("ShardToSplit")
         new_starting_hash_key = self.parameters.get("NewStartingHashKey")
         self.kinesis_backend.split_shard(
-            stream_name, shard_to_split, new_starting_hash_key
+            stream_arn, stream_name, shard_to_split, new_starting_hash_key
         )
         return ""
 
     def merge_shards(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         shard_to_merge = self.parameters.get("ShardToMerge")
         adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
         self.kinesis_backend.merge_shards(
-            stream_name, shard_to_merge, adjacent_shard_to_merge
+            stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
         )
         return ""
 
     def list_shards(self):
+        stream_arn = self.parameters.get("StreamARN")
         stream_name = self.parameters.get("StreamName")
         next_token = self.parameters.get("NextToken")
         max_results = self.parameters.get("MaxResults", 10000)
         shards, token = self.kinesis_backend.list_shards(
-            stream_name=stream_name, limit=max_results, next_token=next_token
+            stream_arn=stream_arn,
+            stream_name=stream_name,
+            limit=max_results,
+            next_token=next_token,
         )
         res = {"Shards": shards}
         if token:
self.parameters.get("StreamName") retention_period_hours = self.parameters.get("RetentionPeriodHours") self.kinesis_backend.increase_stream_retention_period( - stream_name, retention_period_hours + stream_arn, stream_name, retention_period_hours ) return "" def decrease_stream_retention_period(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") retention_period_hours = self.parameters.get("RetentionPeriodHours") self.kinesis_backend.decrease_stream_retention_period( - stream_name, retention_period_hours + stream_arn, stream_name, retention_period_hours ) return "" def add_tags_to_stream(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") tags = self.parameters.get("Tags") - self.kinesis_backend.add_tags_to_stream(stream_name, tags) + self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags) return json.dumps({}) def list_tags_for_stream(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") limit = self.parameters.get("Limit") response = self.kinesis_backend.list_tags_for_stream( - stream_name, exclusive_start_tag_key, limit + stream_arn, stream_name, exclusive_start_tag_key, limit ) return json.dumps(response) def remove_tags_from_stream(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") tag_keys = self.parameters.get("TagKeys") - self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys) + self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys) return json.dumps({}) def enable_enhanced_monitoring(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") shard_level_metrics = self.parameters.get("ShardLevelMetrics") - current, desired = self.kinesis_backend.enable_enhanced_monitoring( - stream_name=stream_name, shard_level_metrics=shard_level_metrics + arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring( + stream_arn=stream_arn, + stream_name=stream_name, + shard_level_metrics=shard_level_metrics, ) return json.dumps( dict( - StreamName=stream_name, + StreamName=name, CurrentShardLevelMetrics=current, DesiredShardLevelMetrics=desired, + StreamARN=arn, ) ) def disable_enhanced_monitoring(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") shard_level_metrics = self.parameters.get("ShardLevelMetrics") - current, desired = self.kinesis_backend.disable_enhanced_monitoring( - stream_name=stream_name, to_be_disabled=shard_level_metrics + arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring( + stream_arn=stream_arn, + stream_name=stream_name, + to_be_disabled=shard_level_metrics, ) return json.dumps( dict( - StreamName=stream_name, + StreamName=name, CurrentShardLevelMetrics=current, DesiredShardLevelMetrics=desired, + StreamARN=arn, ) ) @@ -267,17 +298,24 @@ def deregister_stream_consumer(self): return json.dumps(dict()) def start_stream_encryption(self): + stream_arn = self.parameters.get("StreamARN") stream_name = self.parameters.get("StreamName") encryption_type = self.parameters.get("EncryptionType") key_id = self.parameters.get("KeyId") self.kinesis_backend.start_stream_encryption( - stream_name=stream_name, encryption_type=encryption_type, key_id=key_id + stream_arn=stream_arn, + stream_name=stream_name, + 
diff --git a/moto/kinesis/urls.py b/moto/kinesis/urls.py
index 94f8028599c0..5c6ab5288397 100644
--- a/moto/kinesis/urls.py
+++ b/moto/kinesis/urls.py
@@ -6,6 +6,8 @@
     # Somewhere around boto3-1.26.31 botocore-1.29.31, AWS started using a new endpoint:
     #     111122223333.control-kinesis.us-east-1.amazonaws.com
     r"https?://(.+)\.control-kinesis\.(.+)\.amazonaws\.com",
+    # When passing in the StreamARN to get_shard_iterator/get_records, this endpoint is called:
+    r"https?://(.+)\.data-kinesis\.(.+)\.amazonaws\.com",
 ]
 
 url_paths = {"{0}/$": KinesisResponse.dispatch}
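
A quick sanity check of the new pattern against the host shape boto3 resolves when StreamARN is supplied (the host below is an illustrative example):

    import re

    pattern = r"https?://(.+)\.data-kinesis\.(.+)\.amazonaws\.com"
    host = "https://123456789012.data-kinesis.us-east-1.amazonaws.com"
    assert re.match(pattern, host)
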
diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py
index 005e3382ae9b..116e8d73361c 100644
--- a/tests/test_kinesis/test_kinesis.py
+++ b/tests/test_kinesis/test_kinesis.py
@@ -19,15 +19,17 @@ def test_stream_creation_on_demand():
     client.create_stream(
         StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
     )
+    # At the same time, test whether we can pass the StreamARN instead of the name
+    stream_arn = get_stream_arn(client, "my_stream")
 
     # AWS starts with 4 shards by default
-    shard_list = client.list_shards(StreamName="my_stream")["Shards"]
+    shard_list = client.list_shards(StreamARN=stream_arn)["Shards"]
     shard_list.should.have.length_of(4)
 
     # Cannot update-shard-count when we're in on-demand mode
     with pytest.raises(ClientError) as exc:
         client.update_shard_count(
-            StreamName="my_stream", TargetShardCount=3, ScalingType="UNIFORM_SCALING"
+            StreamARN=stream_arn, TargetShardCount=3, ScalingType="UNIFORM_SCALING"
         )
     err = exc.value.response["Error"]
     err["Code"].should.equal("ValidationException")
@@ -39,7 +41,7 @@ def test_stream_creation_on_demand():
 @mock_kinesis
 def test_update_stream_mode():
     client = boto3.client("kinesis", region_name="eu-west-1")
-    resp = client.create_stream(
+    client.create_stream(
         StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
     )
     arn = client.describe_stream(StreamName="my_stream")["StreamDescription"][
@@ -56,7 +58,7 @@ def test_update_stream_mode():
 
 
 @mock_kinesis
-def test_describe_non_existent_stream_boto3():
+def test_describe_non_existent_stream():
     client = boto3.client("kinesis", region_name="us-west-2")
     with pytest.raises(ClientError) as exc:
         client.describe_stream_summary(StreamName="not-a-stream")
@@ -68,7 +70,7 @@ def test_describe_non_existent_stream_boto3():
 
 
 @mock_kinesis
-def test_list_and_delete_stream_boto3():
+def test_list_and_delete_stream():
     client = boto3.client("kinesis", region_name="us-west-2")
     client.list_streams()["StreamNames"].should.have.length_of(0)
 
@@ -79,6 +81,10 @@ def test_list_and_delete_stream_boto3():
     client.delete_stream(StreamName="stream1")
     client.list_streams()["StreamNames"].should.have.length_of(1)
 
+    stream_arn = get_stream_arn(client, "stream2")
+    client.delete_stream(StreamARN=stream_arn)
+    client.list_streams()["StreamNames"].should.have.length_of(0)
+
 
 @mock_kinesis
 def test_delete_unknown_stream():
@@ -128,9 +134,15 @@ def test_describe_stream_summary():
     )
     stream["StreamStatus"].should.equal("ACTIVE")
 
+    stream_arn = get_stream_arn(conn, stream_name)
+    resp = conn.describe_stream_summary(StreamARN=stream_arn)
+    stream = resp["StreamDescriptionSummary"]
+
+    stream["StreamName"].should.equal(stream_name)
+
 
 @mock_kinesis
-def test_basic_shard_iterator_boto3():
+def test_basic_shard_iterator():
     client = boto3.client("kinesis", region_name="us-west-1")
 
     stream_name = "mystream"
@@ -149,7 +161,30 @@ def test_basic_shard_iterator():
 
 
 @mock_kinesis
-def test_get_invalid_shard_iterator_boto3():
+def test_basic_shard_iterator_by_stream_arn():
+    client = boto3.client("kinesis", region_name="us-west-1")
+
+    stream_name = "mystream"
+    client.create_stream(StreamName=stream_name, ShardCount=1)
+    stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
+    shard_id = stream["Shards"][0]["ShardId"]
+
+    resp = client.get_shard_iterator(
+        StreamARN=stream["StreamARN"],
+        ShardId=shard_id,
+        ShardIteratorType="TRIM_HORIZON",
+    )
+    shard_iterator = resp["ShardIterator"]
+
+    resp = client.get_records(
+        StreamARN=stream["StreamARN"], ShardIterator=shard_iterator
+    )
+    resp.should.have.key("Records").length_of(0)
+    resp.should.have.key("MillisBehindLatest").equal(0)
+
+
+@mock_kinesis
+def test_get_invalid_shard_iterator():
     client = boto3.client("kinesis", region_name="us-west-1")
 
     stream_name = "mystream"
@@ -169,21 +204,22 @@ def test_get_invalid_shard_iterator():
 
 
 @mock_kinesis
-def test_put_records_boto3():
+def test_put_records():
     client = boto3.client("kinesis", region_name="eu-west-2")
 
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=1)
     stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
+    stream_arn = stream["StreamARN"]
     shard_id = stream["Shards"][0]["ShardId"]
 
     data = b"hello world"
     partition_key = "1234"
 
-    response = client.put_record(
-        StreamName=stream_name, Data=data, PartitionKey=partition_key
+    client.put_records(
+        Records=[{"Data": data, "PartitionKey": partition_key}] * 5,
+        StreamARN=stream_arn,
     )
-    response["SequenceNumber"].should.equal("1")
 
     resp = client.get_shard_iterator(
         StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@@ -191,27 +227,28 @@ def test_put_records():
     shard_iterator = resp["ShardIterator"]
 
     resp = client.get_records(ShardIterator=shard_iterator)
-    resp["Records"].should.have.length_of(1)
+    resp["Records"].should.have.length_of(5)
     record = resp["Records"][0]
 
-    record["Data"].should.equal(b"hello world")
-    record["PartitionKey"].should.equal("1234")
+    record["Data"].should.equal(data)
+    record["PartitionKey"].should.equal(partition_key)
     record["SequenceNumber"].should.equal("1")
 
 
 @mock_kinesis
-def test_get_records_limit_boto3():
+def test_get_records_limit():
     client = boto3.client("kinesis", region_name="eu-west-2")
 
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=1)
     stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
+    stream_arn = stream["StreamARN"]
     shard_id = stream["Shards"][0]["ShardId"]
 
     data = b"hello world"
 
     for index in range(5):
-        client.put_record(StreamName=stream_name, Data=data, PartitionKey=str(index))
+        client.put_record(StreamARN=stream_arn, Data=data, PartitionKey=str(index))
 
     resp = client.get_shard_iterator(
         StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@@ -229,7 +266,7 @@ def test_get_records_limit():
 
 
 @mock_kinesis
-def test_get_records_at_sequence_number_boto3():
+def test_get_records_at_sequence_number():
     client = boto3.client("kinesis", region_name="eu-west-2")
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=1)
@@ -268,7 +305,7 @@ def test_get_records_at_sequence_number():
 
 
 @mock_kinesis
-def test_get_records_after_sequence_number_boto3():
+def test_get_records_after_sequence_number():
     client = boto3.client("kinesis", region_name="eu-west-2")
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=1)
@@ -308,7 +345,7 @@ def test_get_records_after_sequence_number():
 
 
 @mock_kinesis
-def test_get_records_latest_boto3():
+def test_get_records_latest():
     client = boto3.client("kinesis", region_name="eu-west-2")
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=1)
"my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=1) @@ -268,7 +305,7 @@ def test_get_records_at_sequence_number_boto3(): @mock_kinesis -def test_get_records_after_sequence_number_boto3(): +def test_get_records_after_sequence_number(): client = boto3.client("kinesis", region_name="eu-west-2") stream_name = "my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=1) @@ -308,7 +345,7 @@ def test_get_records_after_sequence_number_boto3(): @mock_kinesis -def test_get_records_latest_boto3(): +def test_get_records_latest(): client = boto3.client("kinesis", region_name="eu-west-2") stream_name = "my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=1) @@ -607,6 +644,7 @@ def test_valid_decrease_stream_retention_period(): conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "decrease_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) + stream_arn = get_stream_arn(conn, stream_name) conn.increase_stream_retention_period( StreamName=stream_name, RetentionPeriodHours=30 @@ -618,6 +656,12 @@ def test_valid_decrease_stream_retention_period(): response = conn.describe_stream(StreamName=stream_name) response["StreamDescription"]["RetentionPeriodHours"].should.equal(25) + conn.increase_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=29) + conn.decrease_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=26) + + response = conn.describe_stream(StreamARN=stream_arn) + response["StreamDescription"]["RetentionPeriodHours"].should.equal(26) + @mock_kinesis def test_decrease_stream_retention_period_upwards(): @@ -671,7 +715,7 @@ def test_decrease_stream_retention_period_too_high(): @mock_kinesis -def test_invalid_shard_iterator_type_boto3(): +def test_invalid_shard_iterator_type(): client = boto3.client("kinesis", region_name="eu-west-2") stream_name = "my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=1) @@ -688,10 +732,11 @@ def test_invalid_shard_iterator_type_boto3(): @mock_kinesis -def test_add_list_remove_tags_boto3(): +def test_add_list_remove_tags(): client = boto3.client("kinesis", region_name="eu-west-2") stream_name = "my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=1) + stream_arn = get_stream_arn(client, stream_name) client.add_tags_to_stream( StreamName=stream_name, Tags={"tag1": "val1", "tag2": "val2", "tag3": "val3", "tag4": "val4"}, @@ -704,9 +749,9 @@ def test_add_list_remove_tags_boto3(): tags.should.contain({"Key": "tag3", "Value": "val3"}) tags.should.contain({"Key": "tag4", "Value": "val4"}) - client.add_tags_to_stream(StreamName=stream_name, Tags={"tag5": "val5"}) + client.add_tags_to_stream(StreamARN=stream_arn, Tags={"tag5": "val5"}) - tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"] + tags = client.list_tags_for_stream(StreamARN=stream_arn)["Tags"] tags.should.have.length_of(5) tags.should.contain({"Key": "tag5", "Value": "val5"}) @@ -718,19 +763,33 @@ def test_add_list_remove_tags_boto3(): tags.should.contain({"Key": "tag4", "Value": "val4"}) tags.should.contain({"Key": "tag5", "Value": "val5"}) + client.remove_tags_from_stream(StreamARN=stream_arn, TagKeys=["tag4"]) + + tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"] + tags.should.have.length_of(2) + tags.should.contain({"Key": "tag1", "Value": "val1"}) + tags.should.contain({"Key": "tag5", "Value": "val5"}) + @mock_kinesis -def test_merge_shards_boto3(): +def test_merge_shards(): client = boto3.client("kinesis", 
region_name="eu-west-2") stream_name = "my_stream_summary" client.create_stream(StreamName=stream_name, ShardCount=4) + stream_arn = get_stream_arn(client, stream_name) - for index in range(1, 100): + for index in range(1, 50): client.put_record( StreamName=stream_name, Data=f"data_{index}".encode("utf-8"), PartitionKey=str(index), ) + for index in range(51, 100): + client.put_record( + StreamARN=stream_arn, + Data=f"data_{index}".encode("utf-8"), + PartitionKey=str(index), + ) stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] shards = stream["Shards"] @@ -757,7 +816,7 @@ def test_merge_shards_boto3(): active_shards.should.have.length_of(3) client.merge_shards( - StreamName=stream_name, + StreamARN=stream_arn, ShardToMerge="shardId-000000000004", AdjacentShardToMerge="shardId-000000000002", ) @@ -804,3 +863,9 @@ def test_merge_shards_invalid_arg(): err = exc.value.response["Error"] err["Code"].should.equal("InvalidArgumentException") err["Message"].should.equal("shardId-000000000002") + + +def get_stream_arn(client, stream_name): + return client.describe_stream(StreamName=stream_name)["StreamDescription"][ + "StreamARN" + ] diff --git a/tests/test_kinesis/test_kinesis_boto3.py b/tests/test_kinesis/test_kinesis_boto3.py index cc8ab6d17162..27bdc49c7c97 100644 --- a/tests/test_kinesis/test_kinesis_boto3.py +++ b/tests/test_kinesis/test_kinesis_boto3.py @@ -4,6 +4,7 @@ from botocore.exceptions import ClientError from moto import mock_kinesis from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from .test_kinesis import get_stream_arn import sure # noqa # pylint: disable=unused-import @@ -279,9 +280,10 @@ def test_split_shard(): def test_split_shard_that_was_split_before(): client = boto3.client("kinesis", region_name="us-west-2") client.create_stream(StreamName="my-stream", ShardCount=2) + stream_arn = get_stream_arn(client, "my-stream") client.split_shard( - StreamName="my-stream", + StreamARN=stream_arn, ShardToSplit="shardId-000000000001", NewStartingHashKey="170141183460469231731687303715884105829", ) diff --git a/tests/test_kinesis/test_kinesis_encryption.py b/tests/test_kinesis/test_kinesis_encryption.py index 2d3d5e4d790b..088ac2db1a54 100644 --- a/tests/test_kinesis/test_kinesis_encryption.py +++ b/tests/test_kinesis/test_kinesis_encryption.py @@ -1,6 +1,7 @@ import boto3 from moto import mock_kinesis +from .test_kinesis import get_stream_arn @mock_kinesis @@ -44,3 +45,27 @@ def test_disable_encryption(): desc = resp["StreamDescription"] desc.should.have.key("EncryptionType").should.equal("NONE") desc.shouldnt.have.key("KeyId") + + +@mock_kinesis +def test_disable_encryption__using_arns(): + client = boto3.client("kinesis", region_name="us-west-2") + client.create_stream(StreamName="my-stream", ShardCount=2) + stream_arn = get_stream_arn(client, "my-stream") + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("NONE") + + client.start_stream_encryption( + StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a" + ) + + client.stop_stream_encryption( + StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a" + ) + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("NONE") + desc.shouldnt.have.key("KeyId") diff --git a/tests/test_kinesis/test_kinesis_monitoring.py b/tests/test_kinesis/test_kinesis_monitoring.py index 21e0eaff381e..5da00303a97f 100644 --- 
a/tests/test_kinesis/test_kinesis_monitoring.py
+++ b/tests/test_kinesis/test_kinesis_monitoring.py
@@ -1,6 +1,8 @@
 import boto3
 
 from moto import mock_kinesis
+from tests import DEFAULT_ACCOUNT_ID
+from .test_kinesis import get_stream_arn
 
 
 @mock_kinesis
@@ -16,6 +18,9 @@ def test_enable_enhanced_monitoring_all():
     resp.should.have.key("StreamName").equals(stream_name)
     resp.should.have.key("CurrentShardLevelMetrics").equals([])
     resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"])
+    resp.should.have.key("StreamARN").equals(
+        f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
+    )
 
 
 @mock_kinesis
@@ -70,9 +75,10 @@ def test_disable_enhanced_monitoring():
     client = boto3.client("kinesis", region_name="us-east-1")
     stream_name = "my_stream_summary"
     client.create_stream(StreamName=stream_name, ShardCount=4)
+    stream_arn = get_stream_arn(client, stream_name)
 
     client.enable_enhanced_monitoring(
-        StreamName=stream_name,
+        StreamARN=stream_arn,
         ShardLevelMetrics=[
             "IncomingBytes",
             "OutgoingBytes",
@@ -84,6 +90,11 @@ def test_disable_enhanced_monitoring():
         StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"]
     )
 
+    resp.should.have.key("StreamName").equals(stream_name)
+    resp.should.have.key("StreamARN").equals(
+        f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
+    )
+
     resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3)
     resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes")
     resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes")
@@ -102,6 +113,13 @@ def test_disable_enhanced_monitoring():
     metrics.should.contain("IncomingBytes")
     metrics.should.contain("WriteProvisionedThroughputExceeded")
 
+    resp = client.disable_enhanced_monitoring(
+        StreamARN=stream_arn, ShardLevelMetrics=["IncomingBytes"]
+    )
+
+    resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(2)
+    resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(1)
+
 
 @mock_kinesis
 def test_disable_enhanced_monitoring_all():
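
All of the new assertions build ARNs of the same shape; a hypothetical helper makes the format explicit (it mirrors Stream.arn and the f-strings in the tests above):

    def stream_arn(region, account_id, stream_name):
        # ARN format asserted throughout these tests
        return f"arn:aws:kinesis:{region}:{account_id}:stream/{stream_name}"

    assert (
        stream_arn("us-east-1", "123456789012", "my_stream_summary")
        == "arn:aws:kinesis:us-east-1:123456789012:stream/my_stream_summary"
    )
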