Skip to content

Commit

Permalink
feat: Resource level attributes support (aws#2008)
Browse files Browse the repository at this point in the history
* Fix for invalid MQ event source managed policy

* Fix for invalid managed policy for MQ, included support for new MQ event source property, updated test cases

* Black reformatting

* Test case changes

* Changed policy name

* Modified test cases with new policy name

* Added resource attributes and unit tests

* Resource attributes initial work

* Passthrough attributes for some resources, updated some tests

* Resolve merge conflicts

* Fixed a typo

* Modified implicit api plugin for resource attributes support

* Partial update of the tests

* Partially updated test cases, black reformatted

* Partially updated test templates

* Partially updated test templates

* Partially updated test templates

* Added event bridge support for passthrough resource attributes

* Partially updated test templates (up to function with amq kms)

* Partially updated test templates (up to sns)

* Partially updated test templates (all the ones left)

* Prevented passthrough resource attributes from changing layer version hashes

* Added test to verify resource passthrough precedence for implicit api

* Modified tests related to lambda layer to revert the hash changes, keeping the hash the same with resource attributes added

* fix: mutable default values in method definitions (aws#1997)

* fix: remove explicit logging level set in single module (aws#1998)

* run automated tests for resource level attribute support

* Skipping metadata in layer hashing

* Refactored the classes for TestTranslatorEndToEnd and TestResourceLevelAttributes to share the same parent class

* Added new translator tests for version and layer resources

* Added new unit tests

* Removed after transform resource plugin

* Black reformatting

* Refactoring implicit api plugin support for DeletionPolicy and UpdateReplacePolicy

* Refactoring to improve code quality

* Added simple documentation

* Black reformatting

* Added input template that was missing

* Refactoring: use sets instead of lists for implicit api plugin

* Changing import to be compatible with py2.7

* Changing test deployment hashes to their actual values

Co-authored-by: Mehmet Nuri Deveci <[email protected]>
  • Loading branch information
qingchm and mndeveci committed May 19, 2021
1 parent 8061f10 commit a754bb3
Show file tree
Hide file tree
Showing 40 changed files with 3,419 additions and 1,558 deletions.
29 changes: 25 additions & 4 deletions samtranslator/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ class Resource(object):
property_types = None
_keywords = ["logical_id", "relative_id", "depends_on", "resource_attributes"]

_supported_resource_attributes = ["DeletionPolicy", "UpdatePolicy", "Condition"]
# For attributes in this list, they will be passed into the translated template for the same resource itself.
_supported_resource_attributes = ["DeletionPolicy", "UpdatePolicy", "Condition", "UpdateReplacePolicy", "Metadata"]
# For attributes in this list, they will be passed into the translated template for the same resource,
# as well as all the auto-generated resources that are created from this resource.
_pass_through_attributes = ["Condition", "DeletionPolicy", "UpdateReplacePolicy"]

# Runtime attributes that can be qureied resource. They are CloudFormation attributes like ARN, Name etc that
# will be resolvable at runtime. This map will be implemented by sub-classes to express list of attributes they
Expand Down Expand Up @@ -76,6 +80,22 @@ def __init__(self, logical_id, relative_id=None, depends_on=None, attributes=Non
for attr, value in attributes.items():
self.set_resource_attribute(attr, value)

@classmethod
def get_supported_resource_attributes(cls):
"""
A getter method for the supported resource attributes
returns: a tuple that contains the name of all supported resource attributes
"""
return tuple(cls._supported_resource_attributes)

@classmethod
def get_pass_through_attributes(cls):
"""
A getter method for the resource attributes to be passed to auto-generated resources
returns: a tuple that contains the name of all pass through attributes
"""
return tuple(cls._pass_through_attributes)

@classmethod
def from_dict(cls, logical_id, resource_dict, relative_id=None, sam_plugins=None):
"""Constructs a Resource object with the given logical id, based on the given resource dict. The resource dict
Expand Down Expand Up @@ -318,9 +338,10 @@ def get_passthrough_resource_attributes(self):
:return: Dictionary of resource attributes.
"""
attributes = None
if "Condition" in self.resource_attributes:
attributes = {"Condition": self.resource_attributes["Condition"]}
attributes = {}
for resource_attribute in self.get_pass_through_attributes():
if resource_attribute in self.resource_attributes:
attributes[resource_attribute] = self.resource_attributes.get(resource_attribute)
return attributes


Expand Down
30 changes: 25 additions & 5 deletions samtranslator/model/api/api_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,11 @@ def _construct_usage_plan(self, rest_api_stage=None):
# create usage plan for this api only
elif usage_plan_properties.get("CreateUsagePlan") == "PER_API":
usage_plan_logical_id = self.logical_id + "UsagePlan"
usage_plan = ApiGatewayUsagePlan(logical_id=usage_plan_logical_id, depends_on=[self.logical_id])
usage_plan = ApiGatewayUsagePlan(
logical_id=usage_plan_logical_id,
depends_on=[self.logical_id],
attributes=self.passthrough_resource_attributes,
)
api_stages = list()
api_stage = dict()
api_stage["ApiId"] = ref(self.logical_id)
Expand All @@ -649,7 +653,9 @@ def _construct_usage_plan(self, rest_api_stage=None):
if self.logical_id not in self.shared_api_usage_plan.depends_on_shared:
self.shared_api_usage_plan.depends_on_shared.append(self.logical_id)
usage_plan = ApiGatewayUsagePlan(
logical_id=usage_plan_logical_id, depends_on=self.shared_api_usage_plan.depends_on_shared
logical_id=usage_plan_logical_id,
depends_on=self.shared_api_usage_plan.depends_on_shared,
attributes=self.passthrough_resource_attributes,
)
api_stage = dict()
api_stage["ApiId"] = ref(self.logical_id)
Expand Down Expand Up @@ -684,7 +690,11 @@ def _construct_api_key(self, usage_plan_logical_id, create_usage_plan, rest_api_
# create an api key resource for all the apis
LOG.info("Creating api key resource for all the Apis from SHARED usage plan")
api_key_logical_id = "ServerlessApiKey"
api_key = ApiGatewayApiKey(logical_id=api_key_logical_id, depends_on=[usage_plan_logical_id])
api_key = ApiGatewayApiKey(
logical_id=api_key_logical_id,
depends_on=[usage_plan_logical_id],
attributes=self.passthrough_resource_attributes,
)
api_key.Enabled = True
stage_key = dict()
stage_key["RestApiId"] = ref(self.logical_id)
Expand All @@ -696,7 +706,12 @@ def _construct_api_key(self, usage_plan_logical_id, create_usage_plan, rest_api_
else:
# create an api key resource for this api
api_key_logical_id = self.logical_id + "ApiKey"
api_key = ApiGatewayApiKey(logical_id=api_key_logical_id, depends_on=[usage_plan_logical_id])
api_key = ApiGatewayApiKey(
logical_id=api_key_logical_id,
depends_on=[usage_plan_logical_id],
attributes=self.passthrough_resource_attributes,
)
# api_key = ApiGatewayApiKey(logical_id=api_key_logical_id, depends_on=[usage_plan_logical_id])
api_key.Enabled = True
stage_keys = list()
stage_key = dict()
Expand All @@ -721,7 +736,12 @@ def _construct_usage_plan_key(self, usage_plan_logical_id, create_usage_plan, ap
# create a mapping between api key and the usage plan
usage_plan_key_logical_id = self.logical_id + "UsagePlanKey"

usage_plan_key = ApiGatewayUsagePlanKey(logical_id=usage_plan_key_logical_id, depends_on=[api_key.logical_id])
usage_plan_key = ApiGatewayUsagePlanKey(
logical_id=usage_plan_key_logical_id,
depends_on=[api_key.logical_id],
attributes=self.passthrough_resource_attributes,
)
# usage_plan_key = ApiGatewayUsagePlanKey(logical_id=usage_plan_key_logical_id, depends_on=[api_key.logical_id])
usage_plan_key.KeyId = ref(api_key.logical_id)
usage_plan_key.KeyType = "API_KEY"
usage_plan_key.UsagePlanId = ref(usage_plan_logical_id)
Expand Down
10 changes: 5 additions & 5 deletions samtranslator/model/eventbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

class EventBridgeRuleUtils:
@staticmethod
def create_dead_letter_queue_with_policy(rule_logical_id, rule_arn, queue_logical_id=None):
def create_dead_letter_queue_with_policy(rule_logical_id, rule_arn, queue_logical_id=None, attributes=None):
resources = []

queue = SQSQueue(queue_logical_id or rule_logical_id + "Queue")
queue = SQSQueue(queue_logical_id or rule_logical_id + "Queue", attributes=attributes)
dlq_queue_arn = queue.get_runtime_attr("arn")
dlq_queue_url = queue.get_runtime_attr("queue_url")

# grant necessary permission to Eventbridge Rule resource for sending messages to dead-letter queue
policy = SQSQueuePolicy(rule_logical_id + "QueuePolicy")
policy = SQSQueuePolicy(rule_logical_id + "QueuePolicy", attributes=attributes)
policy.PolicyDocument = SQSQueuePolicies.eventbridge_dlq_send_message_resource_based_policy(
rule_arn, dlq_queue_arn
)
Expand Down Expand Up @@ -41,14 +41,14 @@ def validate_dlq_config(source_logical_id, dead_letter_config):
raise InvalidEventException(source_logical_id, "No 'Arn' or 'Type' property provided for DeadLetterConfig")

@staticmethod
def get_dlq_queue_arn_and_resources(cw_event_source, source_arn):
def get_dlq_queue_arn_and_resources(cw_event_source, source_arn, attributes):
"""returns dlq queue arn and dlq_resources, assuming cw_event_source.DeadLetterConfig has been validated"""
dlq_queue_arn = cw_event_source.DeadLetterConfig.get("Arn")
if dlq_queue_arn is not None:
return dlq_queue_arn, []
queue_logical_id = cw_event_source.DeadLetterConfig.get("QueueLogicalId")
dlq_resources = EventBridgeRuleUtils.create_dead_letter_queue_with_policy(
cw_event_source.logical_id, source_arn, queue_logical_id
cw_event_source.logical_id, source_arn, queue_logical_id, attributes
)
dlq_queue_arn = dlq_resources[0].get_runtime_attr("arn")
return dlq_queue_arn, dlq_resources
8 changes: 5 additions & 3 deletions samtranslator/model/eventsources/cloudwatchlogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ def get_source_arn(self):
)

def get_subscription_filter(self, function, permission):
subscription_filter = SubscriptionFilter(self.logical_id, depends_on=[permission.logical_id])
subscription_filter = SubscriptionFilter(
self.logical_id,
depends_on=[permission.logical_id],
attributes=function.get_passthrough_resource_attributes(),
)
subscription_filter.LogGroupName = self.LogGroupName
subscription_filter.FilterPattern = self.FilterPattern
subscription_filter.DestinationArn = function.get_runtime_attr("arn")
if "Condition" in function.resource_attributes:
subscription_filter.set_resource_attribute("Condition", function.resource_attributes["Condition"])

return subscription_filter
7 changes: 3 additions & 4 deletions samtranslator/model/eventsources/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def to_cloudformation(self, **kwargs):

resources = []

lambda_eventsourcemapping = LambdaEventSourceMapping(self.logical_id)
lambda_eventsourcemapping = LambdaEventSourceMapping(
self.logical_id, attributes=function.get_passthrough_resource_attributes()
)
resources.append(lambda_eventsourcemapping)

try:
Expand Down Expand Up @@ -122,9 +124,6 @@ def to_cloudformation(self, **kwargs):
)
lambda_eventsourcemapping.DestinationConfig = self.DestinationConfig

if "Condition" in function.resource_attributes:
lambda_eventsourcemapping.set_resource_attribute("Condition", function.resource_attributes["Condition"])

if "role" in kwargs:
self._link_policy(kwargs["role"], destination_config_policy)

Expand Down
69 changes: 37 additions & 32 deletions samtranslator/model/eventsources/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def to_cloudformation(self, **kwargs):

resources = []

events_rule = EventsRule(self.logical_id)
passthrough_resource_attributes = function.get_passthrough_resource_attributes()
events_rule = EventsRule(self.logical_id, attributes=passthrough_resource_attributes)
resources.append(events_rule)

events_rule.ScheduleExpression = self.Schedule
Expand All @@ -126,13 +127,13 @@ def to_cloudformation(self, **kwargs):
dlq_queue_arn = None
if self.DeadLetterConfig is not None:
EventBridgeRuleUtils.validate_dlq_config(self.logical_id, self.DeadLetterConfig)
dlq_queue_arn, dlq_resources = EventBridgeRuleUtils.get_dlq_queue_arn_and_resources(self, source_arn)
dlq_queue_arn, dlq_resources = EventBridgeRuleUtils.get_dlq_queue_arn_and_resources(
self, source_arn, passthrough_resource_attributes
)
resources.extend(dlq_resources)

events_rule.Targets = [self._construct_target(function, dlq_queue_arn)]

if CONDITION in function.resource_attributes:
events_rule.set_resource_attribute(CONDITION, function.resource_attributes[CONDITION])
resources.append(self._construct_permission(function, source_arn=source_arn))

return resources
Expand Down Expand Up @@ -186,20 +187,21 @@ def to_cloudformation(self, **kwargs):

resources = []

events_rule = EventsRule(self.logical_id)
passthrough_resource_attributes = function.get_passthrough_resource_attributes()
events_rule = EventsRule(self.logical_id, attributes=passthrough_resource_attributes)
events_rule.EventBusName = self.EventBusName
events_rule.EventPattern = self.Pattern
source_arn = events_rule.get_runtime_attr("arn")

dlq_queue_arn = None
if self.DeadLetterConfig is not None:
EventBridgeRuleUtils.validate_dlq_config(self.logical_id, self.DeadLetterConfig)
dlq_queue_arn, dlq_resources = EventBridgeRuleUtils.get_dlq_queue_arn_and_resources(self, source_arn)
dlq_queue_arn, dlq_resources = EventBridgeRuleUtils.get_dlq_queue_arn_and_resources(
self, source_arn, passthrough_resource_attributes
)
resources.extend(dlq_resources)

events_rule.Targets = [self._construct_target(function, dlq_queue_arn)]
if CONDITION in function.resource_attributes:
events_rule.set_resource_attribute(CONDITION, function.resource_attributes[CONDITION])

resources.append(events_rule)
resources.append(self._construct_permission(function, source_arn=source_arn))
Expand Down Expand Up @@ -427,20 +429,20 @@ def to_cloudformation(self, **kwargs):
self.Topic,
self.Region,
self.FilterPolicy,
function.resource_attributes,
function,
)
return [self._construct_permission(function, source_arn=self.Topic), subscription]

# SNS -> SQS(Create New) -> Lambda
if isinstance(self.SqsSubscription, bool):
resources = []
queue = self._inject_sqs_queue()
queue = self._inject_sqs_queue(function)
queue_arn = queue.get_runtime_attr("arn")
queue_url = queue.get_runtime_attr("queue_url")

queue_policy = self._inject_sqs_queue_policy(self.Topic, queue_arn, queue_url, function.resource_attributes)
queue_policy = self._inject_sqs_queue_policy(self.Topic, queue_arn, queue_url, function)
subscription = self._inject_subscription(
"sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function.resource_attributes
"sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function
)
event_source = self._inject_sqs_event_source_mapping(function, role, queue_arn)

Expand All @@ -462,47 +464,46 @@ def to_cloudformation(self, **kwargs):
enabled = self.SqsSubscription.get("Enabled", None)

queue_policy = self._inject_sqs_queue_policy(
self.Topic, queue_arn, queue_url, function.resource_attributes, queue_policy_logical_id
)
subscription = self._inject_subscription(
"sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function.resource_attributes
self.Topic, queue_arn, queue_url, function, queue_policy_logical_id
)
subscription = self._inject_subscription("sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function)
event_source = self._inject_sqs_event_source_mapping(function, role, queue_arn, batch_size, enabled)

resources = resources + event_source
resources.append(queue_policy)
resources.append(subscription)
return resources

def _inject_subscription(self, protocol, endpoint, topic, region, filterPolicy, resource_attributes):
subscription = SNSSubscription(self.logical_id)
def _inject_subscription(self, protocol, endpoint, topic, region, filterPolicy, function):
subscription = SNSSubscription(self.logical_id, attributes=function.get_passthrough_resource_attributes())
subscription.Protocol = protocol
subscription.Endpoint = endpoint
subscription.TopicArn = topic

if region is not None:
subscription.Region = region
if CONDITION in resource_attributes:
subscription.set_resource_attribute(CONDITION, resource_attributes[CONDITION])

if filterPolicy is not None:
subscription.FilterPolicy = filterPolicy

return subscription

def _inject_sqs_queue(self):
return SQSQueue(self.logical_id + "Queue")
def _inject_sqs_queue(self, function):
return SQSQueue(self.logical_id + "Queue", attributes=function.get_passthrough_resource_attributes())

def _inject_sqs_event_source_mapping(self, function, role, queue_arn, batch_size=None, enabled=None):
event_source = SQS(self.logical_id + "EventSourceMapping")
event_source = SQS(
self.logical_id + "EventSourceMapping", attributes=function.get_passthrough_resource_attributes()
)
event_source.Queue = queue_arn
event_source.BatchSize = batch_size or 10
event_source.Enabled = enabled or True
return event_source.to_cloudformation(function=function, role=role)

def _inject_sqs_queue_policy(self, topic_arn, queue_arn, queue_url, resource_attributes, logical_id=None):
policy = SQSQueuePolicy(logical_id or self.logical_id + "QueuePolicy")
if CONDITION in resource_attributes:
policy.set_resource_attribute(CONDITION, resource_attributes[CONDITION])
def _inject_sqs_queue_policy(self, topic_arn, queue_arn, queue_url, function, logical_id=None):
policy = SQSQueuePolicy(
logical_id or self.logical_id + "QueuePolicy", attributes=function.get_passthrough_resource_attributes()
)

policy.PolicyDocument = SQSQueuePolicies.sns_topic_send_message_role_policy(topic_arn, queue_arn)
policy.Queues = [queue_url]
Expand Down Expand Up @@ -895,7 +896,7 @@ def to_cloudformation(self, **kwargs):
return resources

def _construct_iot_rule(self, function):
rule = IotTopicRule(self.logical_id)
rule = IotTopicRule(self.logical_id, attributes=function.get_passthrough_resource_attributes())

payload = {
"Sql": self.Sql,
Expand All @@ -907,8 +908,6 @@ def _construct_iot_rule(self, function):
payload["AwsIotSqlVersion"] = self.AwsIotSqlVersion

rule.TopicRulePayload = payload
if CONDITION in function.resource_attributes:
rule.set_resource_attribute(CONDITION, function.resource_attributes[CONDITION])

return rule

Expand Down Expand Up @@ -953,11 +952,17 @@ def to_cloudformation(self, **kwargs):

resources = []
source_arn = fnGetAtt(userpool_id, "Arn")
resources.append(
self._construct_permission(function, source_arn=source_arn, prefix=function.logical_id + "Cognito")
lambda_permission = self._construct_permission(
function, source_arn=source_arn, prefix=function.logical_id + "Cognito"
)
for attribute, value in function.get_passthrough_resource_attributes().items():
lambda_permission.set_resource_attribute(attribute, value)
resources.append(lambda_permission)

self._inject_lambda_config(function, userpool)
userpool_resource = CognitoUserPool.from_dict(userpool_id, userpool)
for attribute, value in function.get_passthrough_resource_attributes().items():
userpool_resource.set_resource_attribute(attribute, value)
resources.append(CognitoUserPool.from_dict(userpool_id, userpool))
return resources

Expand Down
Loading

0 comments on commit a754bb3

Please sign in to comment.