Skip to content

Commit

Permalink
fix: don't expect response to be json in endpointcreds provider (#2381)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Nov 21, 2023
1 parent 3bd97c0 commit 1c69d08
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 26 deletions.
8 changes: 8 additions & 0 deletions .changelog/018d3cef4def4b019c5ac7c60555b7e3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "018d3cef-4def-4b01-9c5a-c7c60555b7e3",
"type": "bugfix",
"description": "Don't expect error responses to have a JSON payload in the endpointcreds provider.",
"modules": [
"credentials"
]
}
23 changes: 19 additions & 4 deletions credentials/endpointcreds/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ func New(options Options, optFns ...func(*Options)) *Client {
}

if options.Retryer == nil {
options.Retryer = retry.NewStandard()
// Amazon-owned implementations of this endpoint are known to sometimes
// return plaintext responses (i.e. no Code) like normal, add a few
// additional status codes
options.Retryer = retry.NewStandard(func(o *retry.StandardOptions) {
o.Retryables = append(o.Retryables, retry.RetryableHTTPStatusCode{
Codes: map[int]struct{}{
http.StatusTooManyRequests: {},
},
})
})
}

for _, fn := range optFns {
Expand Down Expand Up @@ -122,9 +131,10 @@ type GetCredentialsOutput struct {

// EndpointError is an error returned from the endpoint service
type EndpointError struct {
Code string `json:"code"`
Message string `json:"message"`
Fault smithy.ErrorFault `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
Fault smithy.ErrorFault `json:"-"`
statusCode int `json:"-"`
}

// Error is the error mesage string
Expand All @@ -146,3 +156,8 @@ func (e *EndpointError) ErrorMessage() string {
func (e *EndpointError) ErrorFault() smithy.ErrorFault {
return e.Fault
}

// HTTPStatusCode implements retry.HTTPStatusCode.
func (e *EndpointError) HTTPStatusCode() int {
return e.statusCode
}
53 changes: 38 additions & 15 deletions credentials/endpointcreds/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ import (

func TestClient_GetCredentials(t *testing.T) {
cases := map[string]struct {
Token string
RelativeURI string
ResponseCode int
ResponseBody []byte
ExpectResult *GetCredentialsOutput
ExpectErr bool
ValidateRequest func(*testing.T, *http.Request)
ValidateError func(*testing.T, error) bool
Token string
RelativeURI string
ResponseCode int
ResponseBody []byte
ResponseContentType string
ExpectResult *GetCredentialsOutput
ExpectErr bool
ValidateRequest func(*testing.T, *http.Request)
ValidateError func(*testing.T, error) bool
}{
"success static": {
ResponseCode: 200,
ResponseBody: []byte(` {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -45,6 +47,7 @@ func TestClient_GetCredentials(t *testing.T) {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -59,6 +62,7 @@ func TestClient_GetCredentials(t *testing.T) {
"Token": "FooToken",
"Expiration": "2016-02-25T06:03:31Z"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -76,6 +80,7 @@ func TestClient_GetCredentials(t *testing.T) {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ValidateRequest: func(t *testing.T, r *http.Request) {
t.Helper()
if e, a := "/path/to/thing", r.URL.Path; e != a {
Expand All @@ -96,7 +101,8 @@ func TestClient_GetCredentials(t *testing.T) {
"code": "Unauthorized",
"message": "not authorized for endpoint"
}`),
ExpectErr: true,
ResponseContentType: "application/json",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
var apiError smithy.APIError
Expand Down Expand Up @@ -126,7 +132,8 @@ func TestClient_GetCredentials(t *testing.T) {
"code": "InternalError",
"message": "an error occurred"
}`),
ExpectErr: true,
ResponseContentType: "application/json",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
var apiError smithy.APIError
Expand All @@ -151,13 +158,28 @@ func TestClient_GetCredentials(t *testing.T) {
},
},
"non-json error response": {
ResponseCode: 500,
ResponseBody: []byte(`<html><body>unexpected message format</body></html>`),
ExpectErr: true,
ResponseCode: 500,
ResponseBody: []byte(`<html><body>unexpected message format</body></html>`),
ResponseContentType: "text/html",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
if e, a := "failed to decode error message", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v, got %v", e, a)
var apiError smithy.APIError
if errors.As(err, &apiError) {
if e, a := "", apiError.ErrorCode(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
if e, a := "<html><body>unexpected message format</body></html>", apiError.ErrorMessage(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
if e, a := smithy.FaultServer, apiError.ErrorFault(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
} else {
t.Errorf("expect %T error type, got %T: %v", apiError, err, err)
ok = false
}
return ok
Expand All @@ -177,6 +199,7 @@ func TestClient_GetCredentials(t *testing.T) {

actualReq.Body = ioutil.NopCloser(bytes.NewReader(buf.Bytes()))

w.Header().Set("Content-Type", tt.ResponseContentType)
w.WriteHeader(tt.ResponseCode)
w.Write(tt.ResponseBody)
}))
Expand Down
42 changes: 35 additions & 7 deletions credentials/endpointcreds/internal/client/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/url"

"github.com/aws/smithy-go"
Expand Down Expand Up @@ -104,17 +105,44 @@ func (d *deserializeOpGetCredential) HandleDeserialize(ctx context.Context, in s
}

func deserializeError(response *smithyhttp.Response) error {
var errShape *EndpointError
err := json.NewDecoder(response.Body).Decode(&errShape)
// we could be talking to anything, json isn't guaranteed
// see https://github.com/aws/aws-sdk-go-v2/issues/2316
if response.Header.Get("Content-Type") == "application/json" {
return deserializeJSONError(response)
}

msg, err := io.ReadAll(response.Body)
if err != nil {
return &smithy.DeserializationError{Err: fmt.Errorf("failed to decode error message, %w", err)}
return &smithy.DeserializationError{
Err: fmt.Errorf("read response, %w", err),
}
}

return &EndpointError{
// no sensible value for Code
Message: string(msg),
Fault: stof(response.StatusCode),
statusCode: response.StatusCode,
}
}

if response.StatusCode >= 500 {
errShape.Fault = smithy.FaultServer
} else {
errShape.Fault = smithy.FaultClient
func deserializeJSONError(response *smithyhttp.Response) error {
var errShape *EndpointError
if err := json.NewDecoder(response.Body).Decode(&errShape); err != nil {
return &smithy.DeserializationError{
Err: fmt.Errorf("failed to decode error message, %w", err),
}
}

errShape.Fault = stof(response.StatusCode)
errShape.statusCode = response.StatusCode
return errShape
}

// maps HTTP status code to smithy ErrorFault
func stof(code int) smithy.ErrorFault {
if code >= 500 {
return smithy.FaultServer
}
return smithy.FaultClient
}
73 changes: 73 additions & 0 deletions credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
Expand Down Expand Up @@ -201,6 +202,9 @@ func TestFailedRetrieveCredentials(t *testing.T) {
"code": "Error",
"message": "Message"
}`))),
Header: http.Header{
"Content-Type": {"application/json"},
},
}, nil
})
})
Expand Down Expand Up @@ -238,3 +242,72 @@ func TestFailedRetrieveCredentials(t *testing.T) {
t.Errorf("expect empty creds not to be expired")
}
}

type mockClientN struct {
responses []*http.Response
index int
}

func (c *mockClientN) Do(r *http.Request) (*http.Response, error) {
resp := c.responses[c.index]
c.index++
return resp, nil
}

func TestRetryHTTPStatusCode(t *testing.T) {
expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
credsResp := fmt.Sprintf(`{"AccessKeyID":"AKID","SecretAccessKey":"SECRET","Token":"TOKEN","Expiration":"%s"}`, expTime)

p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = &mockClientN{
responses: []*http.Response{
{
StatusCode: 429,
Body: io.NopCloser(strings.NewReader("You have made too many requests.")),
Header: http.Header{
"Content-Type": {"text/plain"},
},
},
{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("Internal server error.")),
Header: http.Header{
"Content-Type": {"text/plain"},
},
},
{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(credsResp)),
Header: http.Header{
"Content-Type": {"application/json"},
},
},
},
}
})

creds, err := p.Retrieve(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if creds.Expired() {
t.Errorf("expect not expired")
}

sdk.NowTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !creds.Expired() {
t.Errorf("expect to be expired")
}
}

0 comments on commit 1c69d08

Please sign in to comment.