diff --git a/aws/corehandlers/handlers.go b/aws/corehandlers/handlers.go index a054d39346c..8456e29b560 100644 --- a/aws/corehandlers/handlers.go +++ b/aws/corehandlers/handlers.go @@ -24,30 +24,38 @@ type lener interface { // BuildContentLengthHandler builds the content length of a request based on the body, // or will use the HTTPRequest.Header's "Content-Length" if defined. If unable // to determine request body length and no "Content-Length" was specified it will panic. +// +// The Content-Length will only be aded to the request if the length of the body +// is greater than 0. If the body is empty or the current `Content-Length` +// header is <= 0, the header will also be stripped. var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) { + var length int64 + if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" { - length, _ := strconv.ParseInt(slength, 10, 64) - r.HTTPRequest.ContentLength = length - return + length, _ = strconv.ParseInt(slength, 10, 64) + } else { + switch body := r.Body.(type) { + case nil: + length = 0 + case lener: + length = int64(body.Len()) + case io.Seeker: + r.BodyStart, _ = body.Seek(0, 1) + end, _ := body.Seek(0, 2) + body.Seek(r.BodyStart, 0) // make sure to seek back to original location + length = end - r.BodyStart + default: + panic("Cannot get length of body, must provide `ContentLength`") + } } - var length int64 - switch body := r.Body.(type) { - case nil: - length = 0 - case lener: - length = int64(body.Len()) - case io.Seeker: - r.BodyStart, _ = body.Seek(0, 1) - end, _ := body.Seek(0, 2) - body.Seek(r.BodyStart, 0) // make sure to seek back to original location - length = end - r.BodyStart - default: - panic("Cannot get length of body, must provide `ContentLength`") + if length > 0 { + r.HTTPRequest.ContentLength = length + r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) + } else { + r.HTTPRequest.ContentLength = 0 + r.HTTPRequest.Header.Del("Content-Length") } - - r.HTTPRequest.ContentLength = length - r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) }} // SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent. diff --git a/aws/corehandlers/handlers_test.go b/aws/corehandlers/handlers_test.go index 917c4cbe28c..5b61a33b67b 100644 --- a/aws/corehandlers/handlers_test.go +++ b/aws/corehandlers/handlers_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httptest" "os" "testing" @@ -16,6 +17,8 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting" + "github.com/aws/aws-sdk-go/awstesting/unit" + "github.com/aws/aws-sdk-go/service/s3" ) func TestValidateEndpointHandler(t *testing.T) { @@ -113,3 +116,77 @@ func TestSendHandlerError(t *testing.T) { assert.Error(t, r.Error) assert.NotNil(t, r.HTTPResponse) } + +func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, ok := r.Header["Content-Length"] + assert.Equal(t, hasContentLength, ok, "expect content length to be set, %t", hasContentLength) + assert.Equal(t, contentLength, r.ContentLength) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + r.Body.Close() + + authHeader := r.Header.Get("Authorization") + if hasContentLength { + assert.Contains(t, authHeader, "content-length") + } else { + assert.NotContains(t, authHeader, "content-length") + } + + assert.Equal(t, contentLength, int64(len(b))) + })) + + return server +} + +func TestBuildContentLength_ZeroBody(t *testing.T) { + server := setupContentLengthTestServer(t, false, 0) + + svc := s3.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + S3ForcePathStyle: aws.Bool(true), + DisableSSL: aws.Bool(true), + }) + _, err := svc.GetObject(&s3.GetObjectInput{ + Bucket: aws.String("bucketname"), + Key: aws.String("keyname"), + }) + + assert.NoError(t, err) +} + +func TestBuildContentLength_NegativeBody(t *testing.T) { + server := setupContentLengthTestServer(t, false, 0) + + svc := s3.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + S3ForcePathStyle: aws.Bool(true), + DisableSSL: aws.Bool(true), + }) + req, _ := svc.GetObjectRequest(&s3.GetObjectInput{ + Bucket: aws.String("bucketname"), + Key: aws.String("keyname"), + }) + + req.HTTPRequest.Header.Set("Content-Length", "-1") + + assert.NoError(t, req.Send()) +} + +func TestBuildContentLength_WithBody(t *testing.T) { + server := setupContentLengthTestServer(t, true, 1024) + + svc := s3.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + S3ForcePathStyle: aws.Bool(true), + DisableSSL: aws.Bool(true), + }) + _, err := svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String("bucketname"), + Key: aws.String("keyname"), + Body: bytes.NewReader(make([]byte, 1024)), + }) + + assert.NoError(t, err) +} diff --git a/private/signer/v4/v4.go b/private/signer/v4/v4.go index 8bfe5d08d26..5a434b3bbb5 100644 --- a/private/signer/v4/v4.go +++ b/private/signer/v4/v4.go @@ -29,9 +29,8 @@ const ( var ignoredHeaders = rules{ blacklist{ mapRule{ - "Authorization": struct{}{}, - "Content-Length": struct{}{}, - "User-Agent": struct{}{}, + "Authorization": struct{}{}, + "User-Agent": struct{}{}, }, }, } diff --git a/private/signer/v4/v4_test.go b/private/signer/v4/v4_test.go index f113f44c6d6..309da8ed6bc 100644 --- a/private/signer/v4/v4_test.go +++ b/private/signer/v4/v4_test.go @@ -56,8 +56,8 @@ func TestPresignRequest(t *testing.T) { signer.sign() expectedDate := "19700101T000000Z" - expectedHeaders := "content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore" - expectedSig := "59c79b83112a55d188a0708cdfd776f19e4265e700990c60798a05d8923a1300" + expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore" + expectedSig := "ea7856749041f727690c580569738282e99c79355fe0d8f125d3b5535d2ece83" expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request" expectedTarget := "prefix.Operation" @@ -75,7 +75,7 @@ func TestSignRequest(t *testing.T) { signer.sign() expectedDate := "19700101T000000Z" - expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=47f95059b6f4c3fb5043545281560b3366961d3014757f8aac7480953c344509" + expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21" q := signer.Request.Header assert.Equal(t, expectedSig, q.Get("Authorization"))