Skip to content

Commit

Permalink
Kinesis - support Stream ARNs across all methods (#5893)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Feb 1, 2023
1 parent 67ecc3b commit 19bfa92
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 109 deletions.
169 changes: 112 additions & 57 deletions moto/kinesis/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools

from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple

from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
Expand Down Expand Up @@ -439,12 +440,14 @@ def create_from_cloudformation_json(
for tag_item in properties.get("Tags", [])
}

backend = kinesis_backends[account_id][region_name]
backend: KinesisBackend = kinesis_backends[account_id][region_name]
stream = backend.create_stream(
resource_name, shard_count, retention_period_hours=retention_period_hours
)
if any(tags):
backend.add_tags_to_stream(stream.stream_name, tags)
backend.add_tags_to_stream(
stream_arn=None, stream_name=stream.stream_name, tags=tags
)
return stream

@classmethod
Expand Down Expand Up @@ -489,8 +492,8 @@ def update_from_cloudformation_json(
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
backend = kinesis_backends[account_id][region_name]
backend.delete_stream(resource_name)
backend: KinesisBackend = kinesis_backends[account_id][region_name]
backend.delete_stream(stream_arn=None, stream_name=resource_name)

@staticmethod
def is_replacement_update(properties):
Expand Down Expand Up @@ -521,7 +524,7 @@ def physical_resource_id(self):
class KinesisBackend(BaseBackend):
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.streams = OrderedDict()
self.streams: Dict[str, Stream] = OrderedDict()

@staticmethod
def default_vpc_endpoint_service(service_region, zones):
Expand All @@ -546,38 +549,49 @@ def create_stream(
self.streams[stream_name] = stream
return stream

def describe_stream(self, stream_name) -> Stream:
if stream_name in self.streams:
def describe_stream(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
if stream_name and stream_name in self.streams:
return self.streams[stream_name]
else:
raise StreamNotFoundError(stream_name, self.account_id)
if stream_arn:
for stream in self.streams.values():
if stream.arn == stream_arn:
return stream
if stream_arn:
stream_name = stream_arn.split("/")[1]
raise StreamNotFoundError(stream_name, self.account_id)

def describe_stream_summary(self, stream_name):
return self.describe_stream(stream_name)
def describe_stream_summary(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

def list_streams(self):
return self.streams.values()

def delete_stream(self, stream_name):
if stream_name in self.streams:
return self.streams.pop(stream_name)
raise StreamNotFoundError(stream_name, self.account_id)
def delete_stream(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
return self.streams.pop(stream.stream_name)

def get_shard_iterator(
self,
stream_name,
shard_id,
shard_iterator_type,
starting_sequence_number,
at_timestamp,
stream_arn: Optional[str],
stream_name: Optional[str],
shard_id: str,
shard_iterator_type: str,
starting_sequence_number: str,
at_timestamp: str,
):
# Validate params
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
try:
shard = stream.get_shard(shard_id)
except ShardNotFoundError:
raise ResourceNotFoundError(
message=f"Shard {shard_id} in stream {stream_name} under account {self.account_id} does not exist"
message=f"Shard {shard_id} in stream {stream.stream_name} under account {self.account_id} does not exist"
)

shard_iterator = compose_new_shard_iterator(
Expand All @@ -589,11 +603,13 @@ def get_shard_iterator(
)
return shard_iterator

def get_records(self, shard_iterator, limit):
def get_records(
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
):
decomposed = decompose_shard_iterator(shard_iterator)
stream_name, shard_id, last_sequence_id = decomposed

stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shard = stream.get_shard(shard_id)

records, last_sequence_id, millis_behind_latest = shard.get_records(
Expand All @@ -608,21 +624,22 @@ def get_records(self, shard_iterator, limit):

def put_record(
self,
stream_arn,
stream_name,
partition_key,
explicit_hash_key,
data,
):
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, data
)

return sequence_number, shard_id

def put_records(self, stream_name, records):
stream = self.describe_stream(stream_name)
def put_records(self, stream_arn, stream_name, records):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

response = {"FailedRecordCount": 0, "Records": []}

Expand Down Expand Up @@ -651,8 +668,10 @@ def put_records(self, stream_name, records):

return response

def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
stream = self.describe_stream(stream_name)
def split_shard(
self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
raise ValidationException(
Expand All @@ -675,37 +694,46 @@ def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):

stream.split_shard(shard_to_split, new_starting_hash_key)

def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
stream = self.describe_stream(stream_name)
def merge_shards(
self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

if shard_to_merge not in stream.shards:
raise ShardNotFoundError(
shard_to_merge, stream=stream_name, account_id=self.account_id
shard_to_merge, stream=stream.stream_name, account_id=self.account_id
)

if adjacent_shard_to_merge not in stream.shards:
raise ShardNotFoundError(
adjacent_shard_to_merge, stream=stream_name, account_id=self.account_id
adjacent_shard_to_merge,
stream=stream.stream_name,
account_id=self.account_id,
)

stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)

def update_shard_count(self, stream_name, target_shard_count):
stream = self.describe_stream(stream_name)
def update_shard_count(self, stream_arn, stream_name, target_shard_count):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open])

stream.update_shard_count(target_shard_count)

return current_shard_count

@paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_name):
stream = self.describe_stream(stream_name)
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
return [shard.to_json() for shard in shards]

def increase_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
def increase_stream_retention_period(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
Expand All @@ -718,8 +746,13 @@ def increase_stream_retention_period(self, stream_name, retention_period_hours):
)
stream.retention_period_hours = retention_period_hours

def decrease_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
def decrease_stream_retention_period(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
Expand All @@ -733,9 +766,9 @@ def decrease_stream_retention_period(self, stream_name, retention_period_hours):
stream.retention_period_hours = retention_period_hours

def list_tags_for_stream(
self, stream_name, exclusive_start_tag_key=None, limit=None
self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
):
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

tags = []
result = {"HasMoreTags": False, "Tags": tags}
Expand All @@ -750,25 +783,47 @@ def list_tags_for_stream(

return result

def add_tags_to_stream(self, stream_name, tags):
stream = self.describe_stream(stream_name)
def add_tags_to_stream(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
tags: Dict[str, str],
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.tags.update(tags)

def remove_tags_from_stream(self, stream_name, tag_keys):
stream = self.describe_stream(stream_name)
def remove_tags_from_stream(
self, stream_arn: Optional[str], stream_name: Optional[str], tag_keys: List[str]
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
for key in tag_keys:
if key in stream.tags:
del stream.tags[key]

def enable_enhanced_monitoring(self, stream_name, shard_level_metrics):
stream = self.describe_stream(stream_name)
def enable_enhanced_monitoring(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
shard_level_metrics: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
stream.shard_level_metrics = desired_metrics
return current_shard_level_metrics, desired_metrics
return (
stream.arn,
stream.stream_name,
current_shard_level_metrics,
desired_metrics,
)

def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
stream = self.describe_stream(stream_name)
def disable_enhanced_monitoring(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
to_be_disabled: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled:
desired_metrics = []
Expand All @@ -777,7 +832,7 @@ def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
metric for metric in current_metrics if metric not in to_be_disabled
]
stream.shard_level_metrics = desired_metrics
return current_metrics, desired_metrics
return stream.arn, stream.stream_name, current_metrics, desired_metrics

def _find_stream_by_arn(self, stream_arn):
for stream in self.streams.values():
Expand Down Expand Up @@ -826,13 +881,13 @@ def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
# It will be a noop for other streams
stream.delete_consumer(consumer_arn)

def start_stream_encryption(self, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_name)
def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = encryption_type
stream.key_id = key_id

def stop_stream_encryption(self, stream_name):
stream = self.describe_stream(stream_name)
def stop_stream_encryption(self, stream_arn, stream_name):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = "NONE"
stream.key_id = None

Expand Down
Loading

0 comments on commit 19bfa92

Please sign in to comment.