Skip to content

Commit

Permalink
Merge pull request aws-cloudformation#1 from brianlaoaws/feature/down…
Browse files Browse the repository at this point in the history
…load-payload

Add support to download hook target data for stack-level hooks
  • Loading branch information
brianlaoaws authored Mar 15, 2024
2 parents ebf8346 + 13500e4 commit 727b2df
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
42 changes: 42 additions & 0 deletions src/cloudformation_cli_python_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from dataclasses import dataclass, field, fields

import json
import requests # type: ignore
from datetime import date, datetime, time
from requests.adapters import HTTPAdapter # type: ignore
from typing import (
Any,
Callable,
Expand All @@ -14,6 +16,7 @@
Type,
Union,
)
from urllib3 import Retry # type: ignore

from .exceptions import InvalidRequest
from .interface import (
Expand All @@ -25,6 +28,12 @@
HookInvocationPoint,
)

HOOK_REQUEST_DATA_TARGET_MODEL_FIELD_NAME = "targetModel"
HOOK_REMOTE_PAYLOAD_CONNECT_AND_READ_TIMEOUT_SECONDS = 10
HOOK_REMOTE_PAYLOAD_RETRY_LIMIT = 3
HOOK_REMOTE_PAYLOAD_RETRY_BACKOFF_FACTOR = 1
HOOK_REMOTE_PAYLOAD_RETRY_STATUSES = [500, 502, 503, 504]


class KitchenSinkEncoder(json.JSONEncoder):
def default(self, o): # type: ignore # pylint: disable=method-hidden
Expand Down Expand Up @@ -214,6 +223,7 @@ class HookRequestData:
targetType: str
targetLogicalId: str
targetModel: Mapping[str, Any]
payload: Optional[str] = None
callerCredentials: Optional[Credentials] = None
providerCredentials: Optional[Credentials] = None
providerLogGroupName: Optional[str] = None
Expand All @@ -234,6 +244,30 @@ def deserialize(cls, json_data: MutableMapping[str, Any]) -> "HookRequestData":
if creds:
cred_data = json.loads(creds)
setattr(req_data, key, Credentials(**cred_data))

if req_data.is_hook_invocation_payload_remote():
with requests.Session() as s:
retries = Retry(
total=HOOK_REMOTE_PAYLOAD_RETRY_LIMIT,
backoff_factor=HOOK_REMOTE_PAYLOAD_RETRY_BACKOFF_FACTOR,
status_forcelist=HOOK_REMOTE_PAYLOAD_RETRY_STATUSES,
)

s.mount("http://", HTTPAdapter(max_retries=retries))
s.mount("https://", HTTPAdapter(max_retries=retries))

response = s.get(
req_data.payload,
timeout=HOOK_REMOTE_PAYLOAD_CONNECT_AND_READ_TIMEOUT_SECONDS,
)

if response.status_code == 200:
setattr(
req_data,
HOOK_REQUEST_DATA_TARGET_MODEL_FIELD_NAME,
response.json(),
)

return req_data

def serialize(self) -> Mapping[str, Any]:
Expand All @@ -247,6 +281,14 @@ def serialize(self) -> Mapping[str, Any]:
if value is not None
}

def is_hook_invocation_payload_remote(self) -> bool:
if (
not self.targetModel and self.payload
): # pylint: disable=simplifiable-if-statement
return True

return False


@dataclass
class HookInvocationRequest:
Expand Down
1 change: 1 addition & 0 deletions src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
install_requires=[
"boto3>=1.10.20",
'dataclasses;python_version<"3.7"',
"requests>=2.22",
],
license="Apache License 2.0",
classifiers=[
Expand Down
98 changes: 97 additions & 1 deletion tests/lib/hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
OperationStatus,
ProgressEvent,
)
from cloudformation_cli_python_lib.utils import Credentials, HookInvocationRequest
from cloudformation_cli_python_lib.utils import (
Credentials,
HookInvocationRequest,
HookRequestData,
)

import json
from datetime import datetime
from typing import Any, Mapping
from unittest.mock import Mock, call, patch, sentinel

ENTRYPOINT_PAYLOAD = {
Expand Down Expand Up @@ -50,6 +55,34 @@
"hookModel": sentinel.type_configuration,
}

STACK_LEVEL_HOOK_ENTRYPOINT_PAYLOAD = {
"awsAccountId": "123456789012",
"clientRequestToken": "4b90a7e4-b790-456b-a937-0cfdfa211dfe",
"region": "us-east-1",
"actionInvocationPoint": "CREATE_PRE_PROVISION",
"hookTypeName": "AWS::Test::TestHook",
"hookTypeVersion": "1.0",
"requestContext": {
"invocation": 1,
"callbackContext": {},
},
"requestData": {
"callerCredentials": '{"accessKeyId": "IASAYK835GAIFHAHEI23", "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DHfW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg"}', # noqa: B950
"providerCredentials": '{"accessKeyId": "HDI0745692Y45IUTYR78", "secretAccessKey": "4976TUYVI2345GW87ERYG823RF87GY9EIUH452I3", "sessionToken": "842HYOFIQAEUDF78R8T7IU43HSADYGIFHBJSDHFA87SDF9PYvN1CEYASDUYFT5TQ97YASIHUDFAIUEYRISDKJHFAYSUDTFSDFADS"}', # noqa: B950
"providerLogGroupName": "providerLoggingGroupName",
"targetName": "STACK",
"targetType": "STACK",
"targetLogicalId": "myStack",
"hookEncryptionKeyArn": None,
"hookEncryptionKeyRole": None,
"payload": "https://someS3PresignedURL",
"targetModel": {},
},
"stackId": "arn:aws:cloudformation:us-east-1:123456789012:stack/SampleStack/e"
"722ae60-fe62-11e8-9a0e-0ae8cc519968",
"hookModel": sentinel.type_configuration,
}


TYPE_NAME = "Test::Foo::Bar"

Expand Down Expand Up @@ -456,3 +489,66 @@ def test_test_entrypoint_success():
)
def test_get_hook_status(operation_status, hook_status):
assert hook_status == Hook._get_hook_status(operation_status)


def test__hook_request_data_remote_payload():
non_remote_input = HookRequestData(
targetName="someTargetName",
targetType="someTargetModel",
targetLogicalId="someTargetLogicalId",
targetModel={"resourceProperties": {"propKeyA": "propValueA"}},
)
assert non_remote_input.is_hook_invocation_payload_remote() is False

non_remote_input_1 = HookRequestData(
targetName="someTargetName",
targetType="someTargetModel",
targetLogicalId="someTargetLogicalId",
targetModel={"resourceProperties": {"propKeyA": "propValueA"}},
payload="https://someUrl",
)
assert non_remote_input_1.is_hook_invocation_payload_remote() is False

remote_input = HookRequestData(
targetName="someTargetName",
targetType="someTargetModel",
targetLogicalId="someTargetLogicalId",
targetModel={},
payload="https://someUrl",
)
assert remote_input.is_hook_invocation_payload_remote() is True


def test__test_stack_level_hook_input(hook):
hook = Hook(TYPE_NAME, Mock())

with patch("cloudformation_cli_python_lib.utils.requests.get") as mock_requests_lib:
mock_requests_lib.return_value = MockResponse(200, {"foo": "bar"})
_, _, _, req = hook._parse_request(STACK_LEVEL_HOOK_ENTRYPOINT_PAYLOAD)

assert req.requestData.targetName == "STACK"
assert req.requestData.targetType == "STACK"
assert req.requestData.targetLogicalId == "myStack"
assert req.requestData.targetModel == {"foo": "bar"}


def test__test_stack_level_hook_input_failed_s3_download(hook):
hook = Hook(TYPE_NAME, Mock())

with patch("cloudformation_cli_python_lib.utils.requests.get") as mock_requests_lib:
mock_requests_lib.return_value = MockResponse(404, {"foo": "bar"})
_, _, _, req = hook._parse_request(STACK_LEVEL_HOOK_ENTRYPOINT_PAYLOAD)

assert req.requestData.targetName == "STACK"
assert req.requestData.targetType == "STACK"
assert req.requestData.targetLogicalId == "myStack"
assert req.requestData.targetModel == {}


@dataclass
class MockResponse:
status_code: int
_json: Mapping[str, Any]

def json(self) -> Mapping[str, Any]:
return self._json

0 comments on commit 727b2df

Please sign in to comment.