Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws: Add config option to unmarshal API response header maps to normalized lower case map keys #3033

Merged
merged 12 commits into from
Jan 7, 2020
7 changes: 7 additions & 0 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ type Config struct {
// in the ARN, when an ARN is provided as an argument to a bucket parameter.
S3UseARNRegion *bool

// Set this to `true` to enable the SDK to unmarshal API response header maps to
// normalized lower case map keys.
//
// For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case
// Metadata member's map keys. The value of the header in the map is unaffected.
LowerCaseHeaderMaps *bool

// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
Expand Down
5 changes: 2 additions & 3 deletions aws/signer/v4/header_rules.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package v4

import (
"net/http"
"strings"
"github.com/aws/aws-sdk-go/internal/strings"
)

// validator houses a set of rule needed for validation of a
Expand Down Expand Up @@ -61,7 +60,7 @@ type patterns []string
// been found
func (p patterns) IsValid(value string) bool {
for _, pattern := range p {
if strings.HasPrefix(http.CanonicalHeaderKey(value), pattern) {
if strings.HasPrefixFold(value, pattern) {
return true
}
}
Expand Down
3 changes: 1 addition & 2 deletions aws/signer/v4/v4.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,7 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
var headers []string
headers = append(headers, "host")
for k, v := range header {
canonicalKey := http.CanonicalHeaderKey(k)
if !r.IsValid(canonicalKey) {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
Expand Down
11 changes: 11 additions & 0 deletions internal/strings/strings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package strings

import (
"strings"
)

// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}
83 changes: 83 additions & 0 deletions internal/strings/strings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// +build go1.7

package strings

import (
"strings"
"testing"
)

func TestHasPrefixFold(t *testing.T) {
type args struct {
s string
prefix string
}
tests := map[string]struct {
args args
want bool
}{
"empty strings and prefix": {
args: args{
s: "",
prefix: "",
},
want: true,
},
"strings starts with prefix": {
args: args{
s: "some string",
prefix: "some",
},
want: true,
},
"prefix longer then string": {
args: args{
s: "some",
prefix: "some string",
},
},
"equal length string and prefix": {
args: args{
s: "short string",
prefix: "short string",
},
want: true,
},
"different cases": {
args: args{
s: "ShOrT StRING",
prefix: "short",
},
want: true,
},
"empty prefix not empty string": {
args: args{
s: "ShOrT StRING",
prefix: "",
},
want: true,
},
"mixed-case prefixes": {
args: args{
s: "SoMe String",
prefix: "sOme",
},
want: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := HasPrefixFold(tt.args.s, tt.args.prefix); got != tt.want {
t.Errorf("HasPrefixFold() = %v, want %v", got, tt.want)
}
})
}
}

func BenchmarkHasPrefixFold(b *testing.B) {
HasPrefixFold("SoME string", "sOmE")
}

func BenchmarkHasPrefix(b *testing.B) {
strings.HasPrefix(strings.ToLower("SoME string"), strings.ToLower("sOmE"))
}
23 changes: 23 additions & 0 deletions private/model/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,29 @@ func (a *API) writeInputOutputLocationName() {
}
}

func (a *API) addHeaderMapDocumentation() {
for _, shape := range a.Shapes {
if !shape.UsedAsOutput {
continue
}
for _, shapeRef := range shape.MemberRefs {
if shapeRef.Location == "headers" {
if dLen := len(shapeRef.Documentation); dLen > 0 {
if shapeRef.Documentation[dLen-1] != '\n' {
shapeRef.Documentation += "\n"
}
shapeRef.Documentation += "//"
}
shapeRef.Documentation += `
// By default unmarshaled keys are written as a map keys in following canonicalized format:
// the first letter and any letter following a hyphen will be capitalized, and the rest as lowercase.
// Set ` + "`aws.Config.LowerCaseHeaderMaps`" + ` to ` + "`true`" + ` to write unmarshaled keys to the map as lowercase.
`
}
}
}
}

func getDeprecatedMessage(msg string, name string) string {
if len(msg) == 0 {
return name + " has been deprecated"
Expand Down
2 changes: 2 additions & 0 deletions private/model/api/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ func (a *API) Setup() error {
return err
}

a.addHeaderMapDocumentation()
skmcgrail marked this conversation as resolved.
Show resolved Hide resolved

if !a.NoRemoveUnusedShapes {
a.removeUnusedShapes()
}
Expand Down
113 changes: 113 additions & 0 deletions private/protocol/rest/rest_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// +build go1.7

package rest_test

import (
"bytes"
"io/ioutil"
"net/http"
"reflect"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -61,3 +64,113 @@ func TestUnsetHeaders(t *testing.T) {
t.Fatal(req.Error)
}
}

func TestNormalizedHeaders(t *testing.T) {
cases := map[string]struct {
inputValues map[string]*string
outputValues http.Header
expectedInputHeaders http.Header
expectedOutput map[string]*string
normalize bool
}{
"non-normalized headers": {
inputValues: map[string]*string{
"baz": aws.String("bazValue"),
"BAR": aws.String("barValue"),
},
expectedInputHeaders: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
outputValues: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
expectedOutput: map[string]*string{
"Baz": aws.String("bazValue"),
"Bar": aws.String("barValue"),
},
},
"normalized headers": {
inputValues: map[string]*string{
"baz": aws.String("bazValue"),
"BAR": aws.String("barValue"),
},
expectedInputHeaders: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
outputValues: http.Header{
"X-Amz-Meta-Baz": []string{"bazValue"},
"X-Amz-Meta-Bar": []string{"barValue"},
},
expectedOutput: map[string]*string{
"baz": aws.String("bazValue"),
"bar": aws.String("barValue"),
},
normalize: true,
},
}

for name, tt := range cases {
t.Run(name, func(t *testing.T) {
cfg := &aws.Config{Region: aws.String("us-west-2"), LowerCaseHeaderMaps: &tt.normalize}
c := unit.Session.ClientConfig("testService", cfg)
svc := client.New(
*cfg,
metadata.ClientInfo{
ServiceName: "testService",
SigningName: c.SigningName,
SigningRegion: c.SigningRegion,
Endpoint: c.Endpoint,
APIVersion: "",
},
c.Handlers,
)

// Handlers
op := &request.Operation{
Name: "test-operation",
HTTPPath: "/",
}

input := &struct {
Foo map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
}{
Foo: tt.inputValues,
}

output := &struct {
Foo map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
}{}

req := svc.NewRequest(op, input, output)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBuffer(nil)),
Header: tt.outputValues,
}

// Build Request
rest.Build(req)
if req.Error != nil {
t.Fatal(req.Error)
}

if e, a := tt.expectedInputHeaders, req.HTTPRequest.Header; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but got %v", e, a)
}

// unmarshal response
rest.UnmarshalMeta(req)
rest.Unmarshal(req)
if req.Error != nil {
t.Fatal(req.Error)
}

if e, a := tt.expectedOutput, output.Foo; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but got %v", e, a)
}
})
}
}
13 changes: 9 additions & 4 deletions private/protocol/rest/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
awsStrings "github.com/aws/aws-sdk-go/internal/strings"
"github.com/aws/aws-sdk-go/private/protocol"
)

Expand Down Expand Up @@ -120,7 +121,7 @@ func unmarshalLocationElements(r *request.Request, v reflect.Value) {
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix, aws.BoolValue(r.Config.LowerCaseHeaderMaps))
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
Expand All @@ -145,16 +146,20 @@ func unmarshalStatusCode(v reflect.Value, statusCode int) {
}
}

func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
skmcgrail marked this conversation as resolved.
Show resolved Hide resolved
if len(headers) == 0 {
return nil
}
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
if awsStrings.HasPrefixFold(k, prefix) {
if normalize == true {
k = strings.ToLower(k)
} else {
k = http.CanonicalHeaderKey(k)
}
out[k[len(prefix):]] = &v[0]
}
}
Expand Down
6 changes: 6 additions & 0 deletions private/protocol/restjson/unmarshal_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions service/s3/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.