-
Notifications
You must be signed in to change notification settings - Fork 650
Commit
Add Recursion Detection middleware to all SDK requests
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"id": "d74f8a81-3ddb-431f-b600-6abefbdaba1b", | ||
"type": "feature", | ||
"description": "add recursion detection middleware to all SDK requests to avoid recursion invocation in Lambda", | ||
"modules": [ | ||
"." | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"github.com/aws/smithy-go/middleware" | ||
smithyhttp "github.com/aws/smithy-go/transport/http" | ||
"os" | ||
) | ||
|
||
const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME" | ||
const envAmznTraceID = "_X_AMZN_TRACE_ID" | ||
const amznTraceIDHeader = "X-Amzn-Trace-Id" | ||
|
||
// AddRecursionDetection adds recursionDetection to the middleware stack | ||
func AddRecursionDetection(stack *middleware.Stack) error { | ||
return stack.Build.Add(&RecursionDetection{}, middleware.After) | ||
} | ||
|
||
// RecursionDetection detects Lambda environment and sets its X-Ray trace ID to request header if absent | ||
// to avoid recursion invocation in Lambda | ||
type RecursionDetection struct{} | ||
|
||
// ID returns the middleware identifier | ||
func (m *RecursionDetection) ID() string { | ||
return "RecursionDetection" | ||
} | ||
|
||
// HandleBuild detects Lambda environment and adds its trace ID to request header if absent | ||
func (m *RecursionDetection) HandleBuild( | ||
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, | ||
) ( | ||
out middleware.BuildOutput, metadata middleware.Metadata, err error, | ||
) { | ||
req, ok := in.Request.(*smithyhttp.Request) | ||
if !ok { | ||
return out, metadata, fmt.Errorf("unknown request type %T", req) | ||
} | ||
|
||
_, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName) | ||
xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID) | ||
value := req.Header.Get(amznTraceIDHeader) | ||
// only set the X-Amzn-Trace-Id header when it is not set initially, the | ||
// current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists | ||
if value != "" || !hasLambdaEnv || !hasTraceID { | ||
return next.HandleBuild(ctx, in) | ||
} | ||
|
||
req.Header.Set(amznTraceIDHeader, percentEncode(xAmznTraceID)) | ||
return next.HandleBuild(ctx, in) | ||
} | ||
|
||
func percentEncode(s string) string { | ||
upperhex := "0123456789ABCDEF" | ||
hexCount := 0 | ||
for i := 0; i < len(s); i++ { | ||
c := s[i] | ||
if shouldEncode(c) { | ||
hexCount++ | ||
} | ||
} | ||
|
||
if hexCount == 0 { | ||
return s | ||
} | ||
|
||
required := len(s) + 2*hexCount | ||
t := make([]byte, required) | ||
j := 0 | ||
for i := 0; i < len(s); i++ { | ||
if c := s[i]; shouldEncode(c) { | ||
t[j] = '%' | ||
t[j+1] = upperhex[c>>4] | ||
t[j+2] = upperhex[c&15] | ||
j += 3 | ||
} else { | ||
t[j] = c | ||
j++ | ||
} | ||
} | ||
return string(t) | ||
} | ||
|
||
func shouldEncode(c byte) bool { | ||
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { | ||
return false | ||
} | ||
switch c { | ||
case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',': | ||
return false | ||
default: | ||
return true | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
smithymiddleware "github.com/aws/smithy-go/middleware" | ||
smithyhttp "github.com/aws/smithy-go/transport/http" | ||
"os" | ||
"testing" | ||
) | ||
|
||
func TestRecursionDetection(t *testing.T) { | ||
cases := map[string]struct { | ||
LambdaFuncName string | ||
TraceID string | ||
HeaderBefore string | ||
HeaderAfter string | ||
}{ | ||
"non lambda env and no trace ID header before": {}, | ||
"with lambda env but no trace ID env variable, no trace ID header before": { | ||
LambdaFuncName: "some-function1", | ||
}, | ||
"with lambda env and trace ID env variable, no trace ID header before": { | ||
LambdaFuncName: "some-function2", | ||
TraceID: "traceID1", | ||
HeaderAfter: "traceID1", | ||
}, | ||
"with lambda env and trace ID env variable, has trace ID header before": { | ||
LambdaFuncName: "some-function3", | ||
TraceID: "traceID2", | ||
HeaderBefore: "traceID1", | ||
HeaderAfter: "traceID1", | ||
}, | ||
"with lambda env and trace ID (needs encoding) env variable, no trace ID header before": { | ||
LambdaFuncName: "some-function4", | ||
TraceID: "traceID3\n", | ||
HeaderAfter: "traceID3%0A", | ||
}, | ||
"with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": { | ||
LambdaFuncName: "some-function5", | ||
TraceID: "traceID4-=;:+&[]{}\"'", | ||
HeaderAfter: "traceID4-=;:+&[]{}\"'", | ||
}, | ||
} | ||
|
||
for name, c := range cases { | ||
t.Run(name, func(t *testing.T) { | ||
// clear current case's environment variables and restore them at the end of the test func goroutine | ||
restoreEnv := clearEnv() | ||
defer restoreEnv() | ||
|
||
setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName) | ||
setEnvVar(t, envAmznTraceID, c.TraceID) | ||
|
||
req := smithyhttp.NewStackRequest().(*smithyhttp.Request) | ||
if c.HeaderBefore != "" { | ||
req.Header.Set(amznTraceIDHeader, c.HeaderBefore) | ||
} | ||
var updatedRequest *smithyhttp.Request | ||
m := RecursionDetection{} | ||
_, _, err := m.HandleBuild(context.Background(), | ||
smithymiddleware.BuildInput{Request: req}, | ||
smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) ( | ||
out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) { | ||
updatedRequest = input.Request.(*smithyhttp.Request) | ||
return out, metadata, nil | ||
}), | ||
) | ||
if err != nil { | ||
t.Fatalf("expect no error, got %v", err) | ||
} | ||
|
||
if e, a := c.HeaderAfter, updatedRequest.Header.Get(amznTraceIDHeader); e != a { | ||
t.Errorf("expect header value %v found, got %v", e, a) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
// check if test case has environment variable and set to os if it has | ||
func setEnvVar(t *testing.T, key, value string) { | ||
if value != "" { | ||
err := os.Setenv(key, value) | ||
if err != nil { | ||
t.Fatalf("expect no error, got %v", err) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package software.amazon.smithy.aws.go.codegen.customization; | ||
|
||
import software.amazon.smithy.aws.go.codegen.AwsGoDependency; | ||
import software.amazon.smithy.go.codegen.SymbolUtils; | ||
import software.amazon.smithy.go.codegen.integration.GoIntegration; | ||
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; | ||
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; | ||
import software.amazon.smithy.utils.ListUtils; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Add middleware during operation builder step, which detects Lambda environment and sets its X-Ray trace ID to | ||
* request header if absent to avoid recursion invocation in Lambda | ||
*/ | ||
public class LambdaRecursionDetection implements GoIntegration { | ||
/** | ||
* Gets the sort order of the customization from -128 to 127, with lowest | ||
* executed first. | ||
* | ||
* @return Returns the sort order, defaults to -40. | ||
*/ | ||
@Override | ||
public byte getOrder() { | ||
return 126; | ||
} | ||
|
||
@Override | ||
public List<RuntimeClientPlugin> getClientPlugins() { | ||
return ListUtils.of( | ||
RuntimeClientPlugin.builder() | ||
.registerMiddleware(MiddlewareRegistrar.builder() | ||
.resolvedFunction(SymbolUtils.createValueSymbolBuilder( | ||
"AddRecursionDetection", AwsGoDependency.AWS_MIDDLEWARE) | ||
.build()) | ||
.build() | ||
) | ||
.build() | ||
); | ||
} | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.