Skip to content

Commit

Permalink
refactor: make remote invoke reactive to display results as soon as t…
Browse files Browse the repository at this point in the history
…hey are available (aws#5359)

* refactor: make remote invoke reactive to display results as soon as they are available

* addressed the comments
  • Loading branch information
mndeveci authored and lucashuy committed Jun 22, 2023
1 parent a0e4063 commit c35758c
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 305 deletions.
13 changes: 1 addition & 12 deletions samcli/commands/remote/invoke/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""CLI command for "invoke" command."""
import logging
from io import TextIOWrapper
from typing import cast

import click

Expand Down Expand Up @@ -124,16 +123,6 @@ def do_cli(
payload=event, payload_file=event_file, parameters=parameter, output_format=output_format
)

remote_invoke_result = remote_invoke_context.run(remote_invoke_input=remote_invoke_input)

if remote_invoke_result.is_succeeded():
LOG.debug("Invoking resource was successfull, writing response to stdout")
if remote_invoke_result.log_output:
LOG.debug("Writing log output to stderr")
remote_invoke_context.stderr.write(remote_invoke_result.log_output.encode())
output_response = cast(str, remote_invoke_result.response)
remote_invoke_context.stdout.write(output_response.encode())
else:
raise cast(Exception, remote_invoke_result.exception)
remote_invoke_context.run(remote_invoke_input=remote_invoke_input)
except (ErrorBotoApiCallException, InvalideBotoResponseException, InvalidResourceBotoParameterException) as ex:
raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex
48 changes: 39 additions & 9 deletions samcli/commands/remote/remote_invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Context object used by `sam remote invoke` command
"""
import logging
from dataclasses import dataclass
from typing import Optional, cast

from botocore.exceptions import ClientError
Expand All @@ -15,7 +16,12 @@
UnsupportedServiceForRemoteInvoke,
)
from samcli.lib.remote_invoke.remote_invoke_executor_factory import RemoteInvokeExecutorFactory
from samcli.lib.remote_invoke.remote_invoke_executors import RemoteInvokeExecutionInfo
from samcli.lib.remote_invoke.remote_invoke_executors import (
RemoteInvokeConsumer,
RemoteInvokeExecutionInfo,
RemoteInvokeLogOutput,
RemoteInvokeResponse,
)
from samcli.lib.utils import osutils
from samcli.lib.utils.arn_utils import ARNParts, InvalidArnValue
from samcli.lib.utils.boto_utils import BotoProviderType, get_client_error_code
Expand Down Expand Up @@ -61,7 +67,7 @@ def __enter__(self) -> "RemoteInvokeContext":
def __exit__(self, *args) -> None:
pass

def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo:
def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> None:
"""
Instantiates remote invoke executor with populated resource summary information, executes it with the provided
input & returns its response back to the caller. If no executor can be instantiated it raises
Expand All @@ -72,11 +78,6 @@ def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe
remote_invoke_input: RemoteInvokeExecutionInfo
RemoteInvokeExecutionInfo which contains the payload and other information that will be required during
the invocation
Returns
-------
RemoteInvokeExecutionInfo
Populates result and exception info (if any) and returns back to the caller
"""
if not self._resource_summary:
raise AmbiguousResourceForRemoteInvoke(
Expand All @@ -85,13 +86,18 @@ def run(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe
)

remote_invoke_executor_factory = RemoteInvokeExecutorFactory(self._boto_client_provider)
remote_invoke_executor = remote_invoke_executor_factory.create_remote_invoke_executor(self._resource_summary)
remote_invoke_executor = remote_invoke_executor_factory.create_remote_invoke_executor(
self._resource_summary,
remote_invoke_input.output_format,
DefaultRemoteInvokeResponseConsumer(self.stdout),
DefaultRemoteInvokeLogConsumer(self.stderr),
)
if not remote_invoke_executor:
raise NoExecutorFoundForRemoteInvoke(
f"Resource type {self._resource_summary.resource_type} is not supported for remote invoke"
)

return remote_invoke_executor.execute(remote_invoke_input)
remote_invoke_executor.execute(remote_invoke_input)

def _populate_resource_summary(self) -> None:
"""
Expand Down Expand Up @@ -225,3 +231,27 @@ def stderr(self) -> StreamWriter:
"""
stream = osutils.stderr()
return StreamWriter(stream, auto_flush=True)


@dataclass
class DefaultRemoteInvokeResponseConsumer(RemoteInvokeConsumer[RemoteInvokeResponse]):
"""
Default RemoteInvokeResponse consumer, writes given response event to the configured StreamWriter
"""

_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeResponse) -> None:
self._stream_writer.write(cast(str, remote_invoke_response.response).encode())


@dataclass
class DefaultRemoteInvokeLogConsumer(RemoteInvokeConsumer[RemoteInvokeLogOutput]):
"""
Default RemoteInvokeLogOutput consumer, writes given log event to the configured StreamWriter
"""

_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeLogOutput) -> None:
self._stream_writer.write(remote_invoke_response.log_output.encode())
114 changes: 41 additions & 73 deletions samcli/lib/remote_invoke/lambda_invoke_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
from samcli.lib.remote_invoke.remote_invoke_executors import (
BotoActionExecutor,
RemoteInvokeExecutionInfo,
RemoteInvokeIterableResponseType,
RemoteInvokeLogOutput,
RemoteInvokeOutputFormat,
RemoteInvokeRequestResponseMapper,
RemoteInvokeResponse,
)
from samcli.lib.utils import boto_utils

Expand All @@ -45,10 +48,12 @@ class AbstractLambdaInvokeExecutor(BotoActionExecutor, ABC):

_lambda_client: Any
_function_name: str
_remote_output_format: RemoteInvokeOutputFormat

def __init__(self, lambda_client: Any, function_name: str):
def __init__(self, lambda_client: Any, function_name: str, remote_output_format: RemoteInvokeOutputFormat):
self._lambda_client = lambda_client
self._function_name = function_name
self._remote_output_format = remote_output_format
self.request_parameters = {"InvocationType": "RequestResponse", "LogType": "Tail"}

def validate_action_parameters(self, parameters: dict) -> None:
Expand All @@ -65,12 +70,15 @@ def validate_action_parameters(self, parameters: dict) -> None:
else:
self.request_parameters[parameter_key] = parameter_value

def _execute_action(self, payload: str):
def _execute_action(self, payload: str) -> RemoteInvokeIterableResponseType:
self.request_parameters[FUNCTION_NAME] = self._function_name
self.request_parameters[PAYLOAD] = payload

return self._execute_lambda_invoke(payload)

def _execute_boto_call(self, boto_client_method) -> dict:
try:
return self._execute_lambda_invoke(payload)
return cast(dict, boto_client_method(**self.request_parameters))
except ParamValidationError as param_val_ex:
raise InvalidResourceBotoParameterException(
f"Invalid parameter key provided."
Expand All @@ -86,41 +94,60 @@ def _execute_action(self, payload: str):
raise ErrorBotoApiCallException(client_ex) from client_ex

@abstractmethod
def _execute_lambda_invoke(self, payload: str):
pass
def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType:
raise NotImplementedError()


class LambdaInvokeExecutor(AbstractLambdaInvokeExecutor):
"""
Calls "invoke" method of "lambda" service with given input.
"""

def _execute_lambda_invoke(self, payload: str) -> dict:
def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType:
LOG.debug(
"Calling lambda_client.invoke with FunctionName:%s, Payload:%s, parameters:%s",
self._function_name,
payload,
self.request_parameters,
)
return cast(dict, self._lambda_client.invoke(**self.request_parameters))
lambda_response = self._execute_boto_call(self._lambda_client.invoke)
if self._remote_output_format == RemoteInvokeOutputFormat.RAW:
yield RemoteInvokeResponse(lambda_response)
if self._remote_output_format == RemoteInvokeOutputFormat.DEFAULT:
log_result = lambda_response.get(LOG_RESULT)
if log_result:
yield RemoteInvokeLogOutput(base64.b64decode(log_result).decode("utf-8"))
yield RemoteInvokeResponse(cast(StreamingBody, lambda_response.get(PAYLOAD)).read().decode("utf-8"))


class LambdaInvokeWithResponseStreamExecutor(AbstractLambdaInvokeExecutor):
"""
Calls "invoke_with_response_stream" method of "lambda" service with given input.
"""

def _execute_lambda_invoke(self, payload: str) -> dict:
def _execute_lambda_invoke(self, payload: str) -> RemoteInvokeIterableResponseType:
LOG.debug(
"Calling lambda_client.invoke_with_response_stream with FunctionName:%s, Payload:%s, parameters:%s",
self._function_name,
payload,
self.request_parameters,
)
return cast(dict, self._lambda_client.invoke_with_response_stream(**self.request_parameters))
lambda_response = self._execute_boto_call(self._lambda_client.invoke_with_response_stream)
if self._remote_output_format == RemoteInvokeOutputFormat.RAW:
yield RemoteInvokeResponse(lambda_response)
if self._remote_output_format == RemoteInvokeOutputFormat.DEFAULT:
event_stream: EventStream = lambda_response.get(EVENT_STREAM, [])
for event in event_stream:
if PAYLOAD_CHUNK in event:
yield RemoteInvokeResponse(event.get(PAYLOAD_CHUNK).get(PAYLOAD).decode("utf-8"))
if INVOKE_COMPLETE in event:
if LOG_RESULT in event.get(INVOKE_COMPLETE):
yield RemoteInvokeLogOutput(
base64.b64decode(event.get(INVOKE_COMPLETE).get(LOG_RESULT)).decode("utf-8")
)


class DefaultConvertToJSON(RemoteInvokeRequestResponseMapper):
class DefaultConvertToJSON(RemoteInvokeRequestResponseMapper[RemoteInvokeExecutionInfo]):
"""
If a regular string is provided as payload, this class will convert it into a JSON object
"""
Expand All @@ -143,13 +170,13 @@ def map(self, test_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInf
return test_input


class LambdaResponseConverter(RemoteInvokeRequestResponseMapper):
class LambdaResponseConverter(RemoteInvokeRequestResponseMapper[RemoteInvokeResponse]):
"""
This class helps to convert response from lambda service. Normally lambda service
returns 'Payload' field as stream, this class converts that stream into string object
"""

def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo:
def map(self, remote_invoke_input: RemoteInvokeResponse) -> RemoteInvokeResponse:
LOG.debug("Mapping Lambda response to string object")
if not isinstance(remote_invoke_input.response, dict):
raise InvalideBotoResponseException("Invalid response type received from Lambda service, expecting dict")
Expand All @@ -168,7 +195,7 @@ class LambdaStreamResponseConverter(RemoteInvokeRequestResponseMapper):
This mapper, gets all 'PayloadChunk's and 'InvokeComplete' events and decodes them for next mapper.
"""

def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo:
def map(self, remote_invoke_input: RemoteInvokeResponse) -> RemoteInvokeResponse:
LOG.debug("Mapping Lambda response to string object")
if not isinstance(remote_invoke_input.response, dict):
raise InvalideBotoResponseException("Invalid response type received from Lambda service, expecting dict")
Expand All @@ -180,70 +207,11 @@ def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExe
decoded_payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD).decode("utf-8")
decoded_event_stream.append({PAYLOAD_CHUNK: {PAYLOAD: decoded_payload_chunk}})
if INVOKE_COMPLETE in event:
log_output = event.get(INVOKE_COMPLETE).get(LOG_RESULT, b"")
decoded_event_stream.append({INVOKE_COMPLETE: {LOG_RESULT: log_output}})
decoded_event_stream.append(event)
remote_invoke_input.response[EVENT_STREAM] = decoded_event_stream
return remote_invoke_input


class LambdaResponseOutputFormatter(RemoteInvokeRequestResponseMapper):
"""
This class helps to format output response for lambda service that will be printed on the CLI.
If LogResult is found in the response, the decoded LogResult will be written to stderr. The response payload will
be written to stdout.
"""

def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo:
"""
Maps the lambda response output to the type of output format specified as user input.
If output_format is original-boto-response, write the original boto API response
to stdout.
"""
if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT:
LOG.debug("Formatting Lambda output response")
boto_response = cast(dict, remote_invoke_input.response)
log_field = boto_response.get(LOG_RESULT)
if log_field:
log_result = base64.b64decode(log_field).decode("utf-8")
remote_invoke_input.log_output = log_result

invocation_type_parameter = remote_invoke_input.parameters.get("InvocationType")
if invocation_type_parameter and invocation_type_parameter != "RequestResponse":
remote_invoke_input.response = {"StatusCode": boto_response["StatusCode"]}
else:
remote_invoke_input.response = boto_response.get(PAYLOAD)

return remote_invoke_input


class LambdaStreamResponseOutputFormatter(RemoteInvokeRequestResponseMapper):
"""
This class helps to format streaming output response for lambda service that will be printed on the CLI.
It loops through EventStream elements and adds them to response, and once InvokeComplete is reached, it updates
log_output and response objects in remote_invoke_input.
"""

def map(self, remote_invoke_input: RemoteInvokeExecutionInfo) -> RemoteInvokeExecutionInfo:
"""
Maps the lambda response output to the type of output format specified as user input.
If output_format is original-boto-response, write the original boto API response
to stdout.
"""
if remote_invoke_input.output_format == RemoteInvokeOutputFormat.DEFAULT:
LOG.debug("Formatting Lambda output response")
boto_response = cast(dict, remote_invoke_input.response)
combined_response = ""
for event in boto_response.get(EVENT_STREAM, []):
if PAYLOAD_CHUNK in event:
payload_chunk = event.get(PAYLOAD_CHUNK).get(PAYLOAD)
combined_response = f"{combined_response}{payload_chunk}"
if INVOKE_COMPLETE in event:
log_result = base64.b64decode(event.get(INVOKE_COMPLETE).get(LOG_RESULT)).decode("utf-8")
remote_invoke_input.log_output = log_result
remote_invoke_input.response = combined_response
return remote_invoke_input


def _is_function_invoke_mode_response_stream(lambda_client: Any, function_name: str):
"""
Returns True if given function has RESPONSE_STREAM as InvokeMode, False otherwise
Expand Down
Loading

0 comments on commit c35758c

Please sign in to comment.