diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index a83c3ce8..84384c29 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -6,6 +6,7 @@ package lambda import ( "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" //nolint: staticcheck @@ -21,6 +22,8 @@ const ( headerCognitoIdentity = "Lambda-Runtime-Cognito-Identity" headerClientContext = "Lambda-Runtime-Client-Context" headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + trailerLambdaErrorType = "Lambda-Runtime-Function-Error-Type" + trailerLambdaErrorBody = "Lambda-Runtime-Function-Error-Body" contentTypeJSON = "application/json" contentTypeBytes = "application/octet-stream" apiVersion = "2018-06-01" @@ -106,10 +109,12 @@ func (c *runtimeAPIClient) next() (*invoke, error) { } func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) error { - req, err := http.NewRequest(http.MethodPost, url, body) + b := newErrorCapturingReader(body) + req, err := http.NewRequest(http.MethodPost, url, b) if err != nil { return fmt.Errorf("failed to construct POST request to %s: %v", url, err) } + req.Trailer = b.Trailer req.Header.Set("User-Agent", c.userAgent) req.Header.Set("Content-Type", contentType) @@ -122,7 +127,6 @@ func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) log.Printf("runtime API client failed to close %s response body: %v", url, err) } }() - if resp.StatusCode != http.StatusAccepted { return fmt.Errorf("failed to POST to %s: got unexpected status code: %d", url, resp.StatusCode) } @@ -134,3 +138,30 @@ func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) return nil } + +func newErrorCapturingReader(r io.Reader) *errorCapturingReader { + trailer := http.Header{ + trailerLambdaErrorType: nil, + trailerLambdaErrorBody: nil, + } + return &errorCapturingReader{r, trailer} +} + +type errorCapturingReader struct { + reader io.Reader + Trailer http.Header +} + +func (r *errorCapturingReader) Read(p []byte) (int, error) { + if r.reader == nil { + return 0, io.EOF + } + n, err := r.reader.Read(p) + if err != nil && err != io.EOF { + lambdaErr := lambdaErrorResponse(err) + r.Trailer.Set(trailerLambdaErrorType, lambdaErr.Type) + r.Trailer.Set(trailerLambdaErrorBody, base64.StdEncoding.EncodeToString(safeMarshal(lambdaErr))) + return 0, io.EOF + } + return n, err +}