From f6cba8ab7391f76d5feb96e4915328c84184bc89 Mon Sep 17 00:00:00 2001 From: Mehmet Nuri Deveci <5735811+mndeveci@users.noreply.github.com> Date: Mon, 12 Jun 2023 18:01:04 -0700 Subject: [PATCH] feat: add lambda streaming support for remote invoke (#5307) * feat: support response streaming with remote invoke * add invoker and mappers * Update output formatting of stream response * add unit tests * fix formatting * Add docs * address comments * formatting * move is_function_invoke_mode_response_stream into lambda invoke executors and add/update string constants --- requirements/base.txt | 4 +- .../remote_invoke/lambda_invoke_executors.py | 132 ++++++++-- .../remote_invoke_executor_factory.py | 22 +- .../test_lambda_invoke_executors.py | 236 ++++++++++++++---- .../test_remote_invoke_executor_factory.py | 62 +++-- 5 files changed, 371 insertions(+), 85 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index 5632d14797..7fa1b843d5 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,8 +1,8 @@ chevron~=0.12 click~=8.0 Flask<2.3 -#Need to add Schemas latest SDK. -boto3>=1.19.5,==1.* +#Need to add latest lambda changes which will return invoke mode details +boto3>=1.26.109,==1.* jmespath~=1.0.1 ruamel_yaml~=0.17.21 PyYAML>=5.4.1,==5.* diff --git a/samcli/lib/remote_invoke/lambda_invoke_executors.py b/samcli/lib/remote_invoke/lambda_invoke_executors.py index 82e65b8117..30f046127c 100644 --- a/samcli/lib/remote_invoke/lambda_invoke_executors.py +++ b/samcli/lib/remote_invoke/lambda_invoke_executors.py @@ -4,9 +4,11 @@ import base64 import json import logging +from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any, Dict, cast +from typing import Any, cast +from botocore.eventstream import EventStream from botocore.exceptions import ClientError, ParamValidationError from botocore.response import StreamingBody @@ -26,12 +28,19 @@ LOG = logging.getLogger(__name__) FUNCTION_NAME = "FunctionName" PAYLOAD = "Payload" +EVENT_STREAM = "EventStream" +PAYLOAD_CHUNK = "PayloadChunk" +INVOKE_COMPLETE = "InvokeComplete" +LOG_RESULT = "LogResult" +INVOKE_MODE = "InvokeMode" +RESPONSE_STREAM = "RESPONSE_STREAM" -class LambdaInvokeExecutor(BotoActionExecutor): + +class AbstractLambdaInvokeExecutor(BotoActionExecutor, ABC): """ - Calls "invoke" method of "lambda" service with given input. - If a file location provided, the file handle will be passed as Payload object + Abstract class for different lambda invocation executors, see implementation for details. + For Payload parameter, if a file location provided, the file handle will be passed as Payload object """ _lambda_client: Any @@ -59,14 +68,9 @@ def validate_action_parameters(self, parameters: dict) -> None: def _execute_action(self, payload: str): self.request_parameters[FUNCTION_NAME] = self._function_name self.request_parameters[PAYLOAD] = payload - LOG.debug( - "Calling lambda_client.invoke with FunctionName:%s, Payload:%s, parameters:%s", - self._function_name, - payload, - self.request_parameters, - ) + try: - response = self._lambda_client.invoke(**self.request_parameters) + return self._execute_lambda_invoke(payload) except ParamValidationError as param_val_ex: raise InvalidResourceBotoParameterException( f"Invalid parameter key provided." @@ -80,7 +84,40 @@ def _execute_action(self, payload: str): elif boto_utils.get_client_error_code(client_ex) == "InvalidRequestContentException": raise InvalidResourceBotoParameterException(client_ex) from client_ex raise ErrorBotoApiCallException(client_ex) from client_ex - return response + + @abstractmethod + def _execute_lambda_invoke(self, payload: str): + pass + + +class LambdaInvokeExecutor(AbstractLambdaInvokeExecutor): + """ + Calls "invoke" method of "lambda" service with given input. + """ + + def _execute_lambda_invoke(self, payload: str) -> dict: + LOG.debug( + "Calling lambda_client.invoke with FunctionName:%s, Payload:%s, parameters:%s", + self._function_name, + payload, + self.request_parameters, + ) + return cast(dict, self._lambda_client.invoke(**self.request_parameters)) + + +class LambdaInvokeWithResponseStreamExecutor(AbstractLambdaInvokeExecutor): + """ + Calls "invoke_with_response_stream" method of "lambda" service with given input. + """ + + def _execute_lambda_invoke(self, payload: str) -> dict: + LOG.debug( + "Calling lambda_client.invoke_with_response_stream with FunctionName:%s, Payload:%s, parameters:%s", + self._function_name, + payload, + self.request_parameters, + ) + return cast(dict, self._lambda_client.invoke_with_response_stream(**self.request_parameters)) class DefaultConvertToJSON(RemoteInvokeRequestResponseMapper): @@ -124,6 +161,31 @@ def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe return remote_invoke_input +class LambdaStreamResponseConverter(RemoteInvokeRequestResponseMapper): + """ + This class helps to convert response from lambda invoke_with_response_stream API call. + That API call returns 'EventStream' which yields 'PayloadChunk's and 'InvokeComplete' as they become available. + This mapper, gets all 'PayloadChunk's and 'InvokeComplete' events and decodes them for next mapper. + """ + + def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + LOG.debug("Mapping Lambda response to string object") + if not isinstance(remote_invoke_input.response, dict): + raise InvalideBotoResponseException("Invalid response type received from Lambda service, expecting dict") + + event_stream: EventStream = remote_invoke_input.response.get(EVENT_STREAM, []) + decoded_event_stream = [] + for event in event_stream: + if PAYLOAD_CHUNK in event: + decoded_payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD).decode("utf-8") + decoded_event_stream.append({PAYLOAD_CHUNK: {PAYLOAD: decoded_payload_chunk}}) + if INVOKE_COMPLETE in event: + log_output = event.get(INVOKE_COMPLETE).get(LOG_RESULT, b"") + decoded_event_stream.append({INVOKE_COMPLETE: {LOG_RESULT: log_output}}) + remote_invoke_input.response[EVENT_STREAM] = decoded_event_stream + return remote_invoke_input + + class LambdaResponseOutputFormatter(RemoteInvokeRequestResponseMapper): """ This class helps to format output response for lambda service that will be printed on the CLI. @@ -139,8 +201,8 @@ def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe """ if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT: LOG.debug("Formatting Lambda output response") - boto_response = cast(Dict, remote_invoke_input.response) - log_field = boto_response.get("LogResult") + boto_response = cast(dict, remote_invoke_input.response) + log_field = boto_response.get(LOG_RESULT) if log_field: log_result = base64.b64decode(log_field).decode("utf-8") remote_invoke_input.log_output = log_result @@ -152,3 +214,45 @@ def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe remote_invoke_input.response = boto_response.get(PAYLOAD) return remote_invoke_input + + +class LambdaStreamResponseOutputFormatter(RemoteInvokeRequestResponseMapper): + """ + This class helps to format streaming output response for lambda service that will be printed on the CLI. + It loops through EventStream elements and adds them to response, and once InvokeComplete is reached, it updates + log_output and response objects in remote_invoke_input. + """ + + def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo: + """ + Maps the lambda response output to the type of output format specified as user input. + If output_format is original-boto-response, write the original boto API response + to stdout. + """ + if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT: + LOG.debug("Formatting Lambda output response") + boto_response = cast(dict, remote_invoke_input.response) + combined_response = "" + for event in boto_response.get(EVENT_STREAM, []): + if PAYLOAD_CHUNK in event: + payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD) + combined_response = f"{combined_response}{payload_chunk}" + if INVOKE_COMPLETE in event: + log_result = base64.b64decode(event.get(INVOKE_COMPLETE).get(LOG_RESULT)).decode("utf-8") + remote_invoke_input.log_output = log_result + remote_invoke_input.response = combined_response + return remote_invoke_input + + +def _is_function_invoke_mode_response_stream(lambda_client: Any, function_name: str): + """ + Returns True if given function has RESPONSE_STREAM as InvokeMode, False otherwise + """ + try: + function_url_config = lambda_client.get_function_url_config(FunctionName=function_name) + function_invoke_mode = function_url_config.get(INVOKE_MODE) + LOG.debug("InvokeMode of function %s: %s", function_name, function_invoke_mode) + return function_invoke_mode == RESPONSE_STREAM + except ClientError as ex: + LOG.debug("Function %s, doesn't have Function URL configured, using regular invoke", function_name, exc_info=ex) + return False diff --git a/samcli/lib/remote_invoke/remote_invoke_executor_factory.py b/samcli/lib/remote_invoke/remote_invoke_executor_factory.py index a30a9532db..33ec958e1f 100644 --- a/samcli/lib/remote_invoke/remote_invoke_executor_factory.py +++ b/samcli/lib/remote_invoke/remote_invoke_executor_factory.py @@ -7,8 +7,12 @@ from samcli.lib.remote_invoke.lambda_invoke_executors import ( DefaultConvertToJSON, LambdaInvokeExecutor, + LambdaInvokeWithResponseStreamExecutor, LambdaResponseConverter, LambdaResponseOutputFormatter, + LambdaStreamResponseConverter, + LambdaStreamResponseOutputFormatter, + _is_function_invoke_mode_response_stream, ) from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutor, ResponseObjectToJsonStringMapper from samcli.lib.utils.cloudformation import CloudFormationResourceSummary @@ -64,6 +68,22 @@ def _create_lambda_boto_executor(self, cfn_resource_summary: CloudFormationResou :return: Returns the created remote invoke Executor """ + lambda_client = self._boto_client_provider("lambda") + if _is_function_invoke_mode_response_stream(lambda_client, cfn_resource_summary.physical_resource_id): + LOG.debug("Creating response stream invocator for function %s", cfn_resource_summary.physical_resource_id) + return RemoteInvokeExecutor( + request_mappers=[DefaultConvertToJSON()], + response_mappers=[ + LambdaStreamResponseConverter(), + LambdaStreamResponseOutputFormatter(), + ResponseObjectToJsonStringMapper(), + ], + boto_action_executor=LambdaInvokeWithResponseStreamExecutor( + lambda_client, + cfn_resource_summary.physical_resource_id, + ), + ) + return RemoteInvokeExecutor( request_mappers=[DefaultConvertToJSON()], response_mappers=[ @@ -72,7 +92,7 @@ def _create_lambda_boto_executor(self, cfn_resource_summary: CloudFormationResou ResponseObjectToJsonStringMapper(), ], boto_action_executor=LambdaInvokeExecutor( - self._boto_client_provider("lambda"), + lambda_client, cfn_resource_summary.physical_resource_id, ), ) diff --git a/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py b/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py index b719b77da5..15ff272bac 100644 --- a/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py +++ b/tests/unit/lib/remote_invoke/test_lambda_invoke_executors.py @@ -1,23 +1,96 @@ +import base64 +from abc import ABC, abstractmethod +from typing import Any from unittest import TestCase from unittest.mock import Mock, patch + from parameterized import parameterized from samcli.lib.remote_invoke.lambda_invoke_executors import ( - LambdaInvokeExecutor, + EVENT_STREAM, + INVOKE_COMPLETE, + LOG_RESULT, + PAYLOAD, + PAYLOAD_CHUNK, + AbstractLambdaInvokeExecutor, + ClientError, DefaultConvertToJSON, - LambdaResponseConverter, - LambdaResponseOutputFormatter, ErrorBotoApiCallException, - InvalidResourceBotoParameterException, InvalideBotoResponseException, - RemoteInvokeOutputFormat, - ClientError, + InvalidResourceBotoParameterException, + LambdaInvokeExecutor, + LambdaInvokeWithResponseStreamExecutor, + LambdaResponseConverter, + LambdaResponseOutputFormatter, + LambdaStreamResponseConverter, + LambdaStreamResponseOutputFormatter, ParamValidationError, + RemoteInvokeOutputFormat, + _is_function_invoke_mode_response_stream, ) from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutionInfo -class TestLambdaInvokeExecutor(TestCase): +class CommonTestsLambdaInvokeExecutor: + class AbstractLambdaInvokeExecutorTest(ABC, TestCase): + lambda_client: Any + lambda_invoke_executor: AbstractLambdaInvokeExecutor + + @abstractmethod + def _get_boto3_method(self): + pass + + @parameterized.expand( + [ + ("ValidationException",), + ("InvalidRequestContentException",), + ] + ) + def test_execute_action_invalid_parameter_value_throws_client_error(self, error_code): + given_payload = Mock() + error = ClientError(error_response={"Error": {"Code": error_code}}, operation_name="invoke") + self._get_boto3_method().side_effect = error + with self.assertRaises(InvalidResourceBotoParameterException): + self.lambda_invoke_executor._execute_action(given_payload) + + def test_execute_action_invalid_parameter_key_throws_parameter_validation_exception(self): + given_payload = Mock() + error = ParamValidationError(report="Invalid parameters") + self._get_boto3_method().side_effect = error + with self.assertRaises(InvalidResourceBotoParameterException): + self.lambda_invoke_executor._execute_action(given_payload) + + def test_execute_action_throws_client_error_exception(self): + given_payload = Mock() + error = ClientError(error_response={"Error": {"Code": "MockException"}}, operation_name="invoke") + self._get_boto3_method().side_effect = error + with self.assertRaises(ErrorBotoApiCallException): + self.lambda_invoke_executor._execute_action(given_payload) + + @parameterized.expand( + [ + ({}, {"InvocationType": "RequestResponse", "LogType": "Tail"}), + ({"InvocationType": "Event"}, {"InvocationType": "Event", "LogType": "Tail"}), + ( + {"InvocationType": "DryRun", "Qualifier": "TestQualifier"}, + {"InvocationType": "DryRun", "LogType": "Tail", "Qualifier": "TestQualifier"}, + ), + ( + {"InvocationType": "RequestResponse", "LogType": "None"}, + {"InvocationType": "RequestResponse", "LogType": "None"}, + ), + ( + {"FunctionName": "MyFunction", "Payload": "{hello world}"}, + {"InvocationType": "RequestResponse", "LogType": "Tail"}, + ), + ] + ) + def test_validate_action_parameters(self, parameters, expected_boto_parameters): + self.lambda_invoke_executor.validate_action_parameters(parameters) + self.assertEqual(self.lambda_invoke_executor.request_parameters, expected_boto_parameters) + + +class TestLambdaInvokeExecutor(CommonTestsLambdaInvokeExecutor.AbstractLambdaInvokeExecutorTest): def setUp(self) -> None: self.lambda_client = Mock() self.function_name = Mock() @@ -35,54 +108,30 @@ def test_execute_action(self): FunctionName=self.function_name, Payload=given_payload, InvocationType="RequestResponse", LogType="Tail" ) - @parameterized.expand( - [ - ("ValidationException",), - ("InvalidRequestContentException",), - ] - ) - def test_execute_action_invalid_parameter_value_throws_client_error(self, error_code): - given_payload = Mock() - error = ClientError(error_response={"Error": {"Code": error_code}}, operation_name="invoke") - self.lambda_client.invoke.side_effect = error - with self.assertRaises(InvalidResourceBotoParameterException): - self.lambda_invoke_executor._execute_action(given_payload) + def _get_boto3_method(self): + return self.lambda_client.invoke - def test_execute_action_invalid_parameter_key_throws_parameter_validation_exception(self): - given_payload = Mock() - error = ParamValidationError(report="Invalid parameters") - self.lambda_client.invoke.side_effect = error - with self.assertRaises(InvalidResourceBotoParameterException): - self.lambda_invoke_executor._execute_action(given_payload) - def test_execute_action_throws_client_error_exception(self): +class TestLambdaInvokeWithResponseStreamExecutor(CommonTestsLambdaInvokeExecutor.AbstractLambdaInvokeExecutorTest): + def setUp(self) -> None: + self.lambda_client = Mock() + self.function_name = Mock() + self.lambda_invoke_executor = LambdaInvokeWithResponseStreamExecutor(self.lambda_client, self.function_name) + + def test_execute_action(self): given_payload = Mock() - error = ClientError(error_response={"Error": {"Code": "MockException"}}, operation_name="invoke") - self.lambda_client.invoke.side_effect = error - with self.assertRaises(ErrorBotoApiCallException): - self.lambda_invoke_executor._execute_action(given_payload) + given_result = Mock() + self.lambda_client.invoke_with_response_stream.return_value = given_result - @parameterized.expand( - [ - ({}, {"InvocationType": "RequestResponse", "LogType": "Tail"}), - ({"InvocationType": "Event"}, {"InvocationType": "Event", "LogType": "Tail"}), - ( - {"InvocationType": "DryRun", "Qualifier": "TestQualifier"}, - {"InvocationType": "DryRun", "LogType": "Tail", "Qualifier": "TestQualifier"}, - ), - ( - {"InvocationType": "RequestResponse", "LogType": "None"}, - {"InvocationType": "RequestResponse", "LogType": "None"}, - ), - ( - {"FunctionName": "MyFunction", "Payload": "{hello world}"}, - {"InvocationType": "RequestResponse", "LogType": "Tail"}, - ), - ] - ) - def test_validate_action_parameters(self, parameters, expected_boto_parameters): - self.lambda_invoke_executor.validate_action_parameters(parameters) - self.assertEqual(self.lambda_invoke_executor.request_parameters, expected_boto_parameters) + result = self.lambda_invoke_executor._execute_action(given_payload) + + self.assertEqual(result, given_result) + self.lambda_client.invoke_with_response_stream.assert_called_with( + FunctionName=self.function_name, Payload=given_payload, InvocationType="RequestResponse", LogType="Tail" + ) + + def _get_boto3_method(self): + return self.lambda_client.invoke_with_response_stream class TestDefaultConvertToJSON(TestCase): @@ -143,6 +192,46 @@ def test_lambda_streaming_body_invalid_response_exception(self): self.lambda_response_converter.map(remote_invoke_execution_info) +class TestLambdaStreamResponseConverter(TestCase): + def setUp(self) -> None: + self.lambda_stream_response_converter = LambdaStreamResponseConverter() + + @parameterized.expand([({LOG_RESULT: base64.b64encode(b"log output")}, base64.b64encode(b"log output")), ({}, b"")]) + def test_lambda_streaming_body_response_conversion(self, invoke_complete_response, mapped_log_response): + output_format = RemoteInvokeOutputFormat.DEFAULT + given_test_result = { + EVENT_STREAM: [ + {PAYLOAD_CHUNK: {PAYLOAD: b"stream1"}}, + {PAYLOAD_CHUNK: {PAYLOAD: b"stream2"}}, + {PAYLOAD_CHUNK: {PAYLOAD: b"stream3"}}, + {INVOKE_COMPLETE: invoke_complete_response}, + ] + } + remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, output_format) + remote_invoke_execution_info.response = given_test_result + + expected_result = { + EVENT_STREAM: [ + {PAYLOAD_CHUNK: {PAYLOAD: "stream1"}}, + {PAYLOAD_CHUNK: {PAYLOAD: "stream2"}}, + {PAYLOAD_CHUNK: {PAYLOAD: "stream3"}}, + {INVOKE_COMPLETE: {LOG_RESULT: mapped_log_response}}, + ] + } + + result = self.lambda_stream_response_converter.map(remote_invoke_execution_info) + + self.assertEqual(result.response, expected_result) + + def test_lambda_streaming_body_invalid_response_exception(self): + output_format = RemoteInvokeOutputFormat.DEFAULT + remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, output_format) + remote_invoke_execution_info.response = Mock() + + with self.assertRaises(InvalideBotoResponseException): + self.lambda_stream_response_converter.map(remote_invoke_execution_info) + + class TestLambdaResponseOutputFormatter(TestCase): def setUp(self) -> None: self.lambda_response_converter = LambdaResponseOutputFormatter() @@ -191,3 +280,48 @@ def test_non_default_invocation_type_output_formatter(self, parameters): result = self.lambda_response_converter.map(remote_invoke_execution_info) self.assertEqual(result.response, expected_result) + + +class TestLambdaStreamResponseOutputFormatter(TestCase): + def setUp(self) -> None: + self.lambda_response_converter = LambdaStreamResponseOutputFormatter() + + def test_none_event_stream(self): + remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, RemoteInvokeOutputFormat.DEFAULT) + remote_invoke_execution_info.response = {} + + mapped_response = self.lambda_response_converter.map(remote_invoke_execution_info) + self.assertEqual(mapped_response.response, "") + + def test_event_stream(self): + remote_invoke_execution_info = RemoteInvokeExecutionInfo(None, None, {}, RemoteInvokeOutputFormat.DEFAULT) + remote_invoke_execution_info.response = { + EVENT_STREAM: [ + {PAYLOAD_CHUNK: {PAYLOAD: "stream1"}}, + {PAYLOAD_CHUNK: {PAYLOAD: "stream2"}}, + {PAYLOAD_CHUNK: {PAYLOAD: "stream3"}}, + {INVOKE_COMPLETE: {LOG_RESULT: base64.b64encode(b"log output")}}, + ] + } + + mapped_response = self.lambda_response_converter.map(remote_invoke_execution_info) + self.assertEqual(mapped_response.response, "stream1stream2stream3") + self.assertEqual(mapped_response.log_output, "log output") + + +class TestLambdaInvokeExecutorUtilities(TestCase): + @parameterized.expand( + [ + ({}, False), + ({"InvokeMode": "BUFFERED"}, False), + ({"InvokeMode": "RESPONSE_STREAM"}, True), + (ClientError({}, "operation"), False), + ] + ) + def test_is_function_invoke_mode_response_stream(self, boto_response, expected_result): + given_boto_client = Mock() + if type(boto_response) is ClientError: + given_boto_client.get_function_url_config.side_effect = boto_response + else: + given_boto_client.get_function_url_config.return_value = boto_response + self.assertEqual(_is_function_invoke_mode_response_stream(given_boto_client, "function_id"), expected_result) diff --git a/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py b/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py index 6ba12c409f..3a1f938e19 100644 --- a/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py +++ b/tests/unit/lib/remote_invoke/test_remote_invoke_executor_factory.py @@ -1,7 +1,11 @@ from unittest import TestCase from unittest.mock import patch, Mock -from samcli.lib.remote_invoke.remote_invoke_executor_factory import RemoteInvokeExecutorFactory +from parameterized import parameterized + +from samcli.lib.remote_invoke.remote_invoke_executor_factory import ( + RemoteInvokeExecutorFactory, +) class TestRemoteInvokeExecutorFactory(TestCase): @@ -33,21 +37,32 @@ def test_failed_create_test_executor(self): executor = self.remote_invoke_executor_factory.create_remote_invoke_executor(given_cfn_resource_summary) self.assertIsNone(executor) + @parameterized.expand([(True,), (False,)]) @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaInvokeExecutor") + @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaInvokeWithResponseStreamExecutor") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.DefaultConvertToJSON") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaResponseConverter") + @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaStreamResponseConverter") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaResponseOutputFormatter") + @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.LambdaStreamResponseOutputFormatter") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.ResponseObjectToJsonStringMapper") @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory.RemoteInvokeExecutor") + @patch("samcli.lib.remote_invoke.remote_invoke_executor_factory._is_function_invoke_mode_response_stream") def test_create_lambda_test_executor( self, + is_function_invoke_mode_response_stream, + patched_is_function_invoke_mode_response_stream, patched_remote_invoke_executor, patched_object_to_json_converter, + patched_stream_response_output_formatter, patched_response_output_formatter, + patched_stream_response_converter, patched_response_converter, patched_convert_to_default_json, + patched_lambda_invoke_with_response_stream_executor, patched_lambda_invoke_executor, ): + patched_is_function_invoke_mode_response_stream.return_value = is_function_invoke_mode_response_stream given_physical_resource_id = "physical_resource_id" given_cfn_resource_summary = Mock(physical_resource_id=given_physical_resource_id) @@ -60,20 +75,33 @@ def test_create_lambda_test_executor( lambda_executor = self.remote_invoke_executor_factory._create_lambda_boto_executor(given_cfn_resource_summary) self.assertEqual(lambda_executor, given_remote_invoke_executor) - - patched_convert_to_default_json.assert_called_once() - patched_response_output_formatter.assert_called_once() - patched_response_converter.assert_called_once() - self.boto_client_provider_mock.assert_called_with("lambda") - patched_lambda_invoke_executor.assert_called_with(given_lambda_client, given_physical_resource_id) - - patched_remote_invoke_executor.assert_called_with( - request_mappers=[patched_convert_to_default_json()], - response_mappers=[ - patched_response_converter(), - patched_response_output_formatter(), - patched_object_to_json_converter(), - ], - boto_action_executor=patched_lambda_invoke_executor(), - ) + patched_convert_to_default_json.assert_called_once() + patched_object_to_json_converter.assert_called_once() + + if is_function_invoke_mode_response_stream: + patched_stream_response_output_formatter.assert_called_once() + patched_stream_response_converter.assert_called_once() + patched_lambda_invoke_with_response_stream_executor.assert_called_once() + patched_remote_invoke_executor.assert_called_with( + request_mappers=[patched_convert_to_default_json()], + response_mappers=[ + patched_stream_response_converter(), + patched_stream_response_output_formatter(), + patched_object_to_json_converter(), + ], + boto_action_executor=patched_lambda_invoke_with_response_stream_executor(), + ) + else: + patched_response_output_formatter.assert_called_once() + patched_response_converter.assert_called_once() + patched_lambda_invoke_executor.assert_called_with(given_lambda_client, given_physical_resource_id) + patched_remote_invoke_executor.assert_called_with( + request_mappers=[patched_convert_to_default_json()], + response_mappers=[ + patched_response_converter(), + patched_response_output_formatter(), + patched_object_to_json_converter(), + ], + boto_action_executor=patched_lambda_invoke_executor(), + )