Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix empty response body deser in case of error response #801

Merged
merged 2 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aws/protocol/ec2query/error_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type ErrorComponents struct {
// GetErrorResponseComponents returns the error components from a ec2query error response body
func GetErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
var er ErrorComponents
if err := xml.NewDecoder(r).Decode(&er); err != nil {
if err := xml.NewDecoder(r).Decode(&er); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while fetching xml error response code: %w", err)
}
return er, nil
Expand Down
3 changes: 3 additions & 0 deletions aws/protocol/ec2query/error_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func TestGetResponseErrorCode(t *testing.T) {
expectedErrorMessage: "Hi",
expectedErrorRequestID: "foo-id",
},
"no response body": {
errorResponse: bytes.NewReader([]byte(``)),
},
}

for name, c := range cases {
Expand Down
8 changes: 4 additions & 4 deletions aws/protocol/xml/error_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ type ErrorComponents struct {
func GetErrorResponseComponents(r io.Reader, noErrorWrapping bool) (ErrorComponents, error) {
if noErrorWrapping {
var errResponse noWrappedErrorResponse
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil {
return ErrorComponents{}, fmt.Errorf("error while deserializingg xml error response: %w", err)
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response: %w", err)
}
return ErrorComponents{
Code: errResponse.Code,
Expand All @@ -29,8 +29,8 @@ func GetErrorResponseComponents(r io.Reader, noErrorWrapping bool) (ErrorCompone
}

var errResponse wrappedErrorResponse
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil {
return ErrorComponents{}, fmt.Errorf("error while deserializingg xml error response: %w", err)
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response: %w", err)
}
return ErrorComponents{
Code: errResponse.Code,
Expand Down
3 changes: 3 additions & 0 deletions aws/protocol/xml/error_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func TestGetResponseErrorCode(t *testing.T) {
expectedErrorMessage: "Hi",
expectedErrorRequestID: "foo-id",
},
"no response body": {
errorResponse: bytes.NewReader([]byte(``)),
},
}

for name, c := range cases {
Expand Down
12 changes: 3 additions & 9 deletions service/internal/s3shared/xml_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/xml"
"fmt"
"io"
"io/ioutil"
)

// ErrorComponents represents the error response fields
Expand All @@ -18,14 +17,9 @@ type ErrorComponents struct {

// GetErrorResponseComponents returns the error fields from an xml error response body
func GetErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
rb, err := ioutil.ReadAll(r)
if err != nil {
return ErrorComponents{}, err
}

var errComponents ErrorComponents
if err := xml.Unmarshal(rb, &errComponents); err != nil {
return ErrorComponents{}, fmt.Errorf("error while deserializingg xml error response : %w", err)
if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err)
}
return errComponents, err
return errComponents, nil
}
57 changes: 57 additions & 0 deletions service/s3/internal/customizations/unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package customizations_test

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"

"github.com/awslabs/smithy-go"
)

func Test_HeadBucket(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(400)
}))
defer server.Close()

ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()

cfg := aws.Config{
Region: "us-east-1",
EndpointResolver: aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
URL: server.URL,
SigningName: "s3",
}, nil
}),
Retryer: aws.NoOpRetryer{},
}

client := s3.NewFromConfig(cfg, func(options *s3.Options) {
options.UsePathStyle = true
})
params := &s3.HeadBucketInput{Bucket: aws.String("aws-sdk-go-data")}
_, err := client.HeadBucket(ctx, params)
if err == nil {
t.Error("expected error, got none")
}

var apiErr smithy.APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expect error to be API error, was not, %v", err)
}
if len(apiErr.ErrorCode()) == 0 {
t.Errorf("expect non-empty error code")
}
if len(apiErr.ErrorMessage()) == 0 {
t.Errorf("expect non-empty error message")
}
}