Skip to content

Commit

Permalink
private/protocol/rest: Support Normalization of Headers Location
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail committed Dec 19, 2019
1 parent 7a13965 commit 40078df
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 7 deletions.
11 changes: 8 additions & 3 deletions private/protocol/rest/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bo
var err error
switch field.Tag.Get("location") {
case "headers": // header maps
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag)
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag, aws.BoolValue(r.Config.NormalizeHeaders))
case "header":
err = buildHeader(&r.HTTPRequest.Header, m, name, field.Tag)
case "uri":
Expand Down Expand Up @@ -173,7 +173,7 @@ func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.
return nil
}

func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) error {
func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag, normalize bool) error {
prefix := tag.Get("locationName")
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key), tag)
Expand All @@ -186,7 +186,12 @@ func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag)
keyStr := strings.TrimSpace(key.String())
str = strings.TrimSpace(str)

header.Add(prefix+keyStr, str)
if normalize {
lk := strings.ToLower(prefix + keyStr)
(*header)[lk] = append((*header)[lk], str)
} else {
header.Add(prefix+keyStr, str)
}
}
return nil
}
Expand Down
111 changes: 111 additions & 0 deletions private/protocol/rest/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"io/ioutil"
"net/http"
"reflect"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -61,3 +62,113 @@ func TestUnsetHeaders(t *testing.T) {
t.Fatal(req.Error)
}
}

func TestNormalizedHeaders(t *testing.T) {
cases := map[string]struct {
inputValues map[string]*string
outputValues http.Header
expectedInputHeaders http.Header
expectedOutput map[string]*string
normalize bool
}{
"non-normalized headers": {
inputValues: map[string]*string{
"baz": aws.String("bazValue"),
"BAR": aws.String("barValue"),
},
expectedInputHeaders: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
outputValues: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
expectedOutput: map[string]*string{
"Baz": aws.String("bazValue"),
"Bar": aws.String("barValue"),
},
},
"normalized headers": {
inputValues: map[string]*string{
"baz": aws.String("bazValue"),
"BAR": aws.String("barValue"),
},
expectedInputHeaders: http.Header{
"x-amz-meta-baz": []string{"bazValue"},
"x-amz-meta-bar": []string{"barValue"},
},
outputValues: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
expectedOutput: map[string]*string{
"baz": aws.String("bazValue"),
"bar": aws.String("barValue"),
},
normalize: true,
},
}

for name, tt := range cases {
t.Run(name, func(t *testing.T) {
cfg := &aws.Config{Region: aws.String("us-west-2"), NormalizeHeaders: &tt.normalize}
c := unit.Session.ClientConfig("testService", cfg)
svc := client.New(
*cfg,
metadata.ClientInfo{
ServiceName: "testService",
SigningName: c.SigningName,
SigningRegion: c.SigningRegion,
Endpoint: c.Endpoint,
APIVersion: "",
},
c.Handlers,
)

// Handlers
op := &request.Operation{
Name: "test-operation",
HTTPPath: "/",
}

input := &struct {
Foo map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
}{
Foo: tt.inputValues,
}

output := &struct {
Foo map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
}{}

req := svc.NewRequest(op, input, output)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBuffer(nil)),
Header: tt.outputValues,
}

// Build Request
rest.Build(req)
if req.Error != nil {
t.Fatal(req.Error)
}

if e, a := tt.expectedInputHeaders, req.HTTPRequest.Header; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but got %v", e, a)
}

// unmarshal response
rest.UnmarshalMeta(req)
rest.Unmarshal(req)
if req.Error != nil {
t.Fatal(req.Error)
}

if e, a := tt.expectedOutput, output.Foo; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but got %v", e, a)
}
})
}
}
13 changes: 9 additions & 4 deletions private/protocol/rest/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol"
)
Expand Down Expand Up @@ -120,7 +121,7 @@ func unmarshalLocationElements(r *request.Request, v reflect.Value) {
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix, aws.BoolValue(r.Config.NormalizeHeaders))
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
Expand All @@ -145,16 +146,20 @@ func unmarshalStatusCode(v reflect.Value, statusCode int) {
}
}

func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
if len(headers) == 0 {
return nil
}
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
if awsutil.StringHasPrefixFold(k, prefix) {
if normalize == true {
k = strings.ToLower(k)
} else {
k = http.CanonicalHeaderKey(k)
}
out[k[len(prefix):]] = &v[0]
}
}
Expand Down

0 comments on commit 40078df

Please sign in to comment.