diff --git a/.changelog/0271f0ea23254cf68a4c08e4c69fd6c8.json b/.changelog/0271f0ea23254cf68a4c08e4c69fd6c8.json new file mode 100644 index 000000000..9f5b787b4 --- /dev/null +++ b/.changelog/0271f0ea23254cf68a4c08e4c69fd6c8.json @@ -0,0 +1,8 @@ +{ + "id": "0271f0ea-2325-4cf6-8a4c-08e4c69fd6c8", + "type": "bugfix", + "description": "Updates the smithy-go HTTP Request to correctly handle building the request to an http.Request. Related to [aws/aws-sdk-go-v2#1583](https://github.com/aws/aws-sdk-go-v2/issues/1583)", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go index fa2c82755..9969389bb 100644 --- a/transport/http/middleware_content_length.go +++ b/transport/http/middleware_content_length.go @@ -44,12 +44,6 @@ func (m *ComputeContentLength) HandleBuild( "failed getting length of request stream, %w", err) } else if ok { req.ContentLength = n - if n == 0 { - // If the content length could be determined, and the body is empty - // the stream must be cleared to prevent unexpected chunk encoding. - req, _ = req.SetStream(nil) - in.Request = req - } } return next.HandleBuild(ctx, in) diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go index cd27849d1..16a1f265c 100644 --- a/transport/http/middleware_content_length_test.go +++ b/transport/http/middleware_content_length_test.go @@ -13,38 +13,51 @@ import ( func TestContentLengthMiddleware(t *testing.T) { cases := map[string]struct { - Stream io.Reader - ExpectLen int64 - ExpectErr string + Stream io.Reader + ExpectNilStream bool + ExpectLen int64 + ExpectErr string }{ // Cases "bytes.Reader": { - Stream: bytes.NewReader(make([]byte, 10)), - ExpectLen: 10, + Stream: bytes.NewReader(make([]byte, 10)), + ExpectLen: 10, + ExpectNilStream: false, }, "bytes.Buffer": { - Stream: bytes.NewBuffer(make([]byte, 10)), - ExpectLen: 10, + Stream: bytes.NewBuffer(make([]byte, 10)), + ExpectLen: 10, + ExpectNilStream: false, }, "strings.Reader": { - Stream: strings.NewReader("hello"), - ExpectLen: 5, + Stream: strings.NewReader("hello"), + ExpectLen: 5, + ExpectNilStream: false, }, "empty stream": { - Stream: strings.NewReader(""), - ExpectLen: 0, + Stream: strings.NewReader(""), + ExpectLen: 0, + ExpectNilStream: false, + }, + "empty stream bytes": { + Stream: bytes.NewReader([]byte{}), + ExpectLen: 0, + ExpectNilStream: false, }, "nil stream": { - ExpectLen: 0, + ExpectLen: 0, + ExpectNilStream: true, }, "un-seekable and no length": { - Stream: &basicReader{buf: make([]byte, 10)}, - ExpectLen: -1, + Stream: &basicReader{buf: make([]byte, 10)}, + ExpectLen: -1, + ExpectNilStream: false, }, "with error": { - Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, - ExpectErr: "seek failed", - ExpectLen: -1, + Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, + ExpectErr: "seek failed", + ExpectLen: -1, + ExpectNilStream: false, }, } @@ -57,10 +70,15 @@ func TestContentLengthMiddleware(t *testing.T) { t.Fatalf("expect to set stream, %v", err) } + var updatedRequest *Request var m ComputeContentLength _, _, err = m.HandleBuild(context.Background(), middleware.BuildInput{Request: req}, - nopBuildHandler, + middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error) { + updatedRequest = input.Request.(*Request) + return out, metadata, nil + }), ) if len(c.ExpectErr) != 0 { if err == nil { @@ -69,13 +87,18 @@ func TestContentLengthMiddleware(t *testing.T) { if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect error to contain %q, got %v", e, a) } + return } else if err != nil { t.Fatalf("expect no error, got %v", err) } - if e, a := c.ExpectLen, req.ContentLength; e != a { + if e, a := c.ExpectLen, updatedRequest.ContentLength; e != a { t.Errorf("expect %v content-length, got %v", e, a) } + + if e, a := c.ExpectNilStream, updatedRequest.stream == nil; e != a { + t.Errorf("expect %v nil stream, got %v", e, a) + } }) } } diff --git a/transport/http/request.go b/transport/http/request.go index 5796a689c..6a759ff3e 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -108,6 +108,10 @@ func (r *Request) IsStreamSeekable() bool { func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) { rc = r.Clone() + if reader == http.NoBody { + reader = nil + } + switch v := reader.(type) { case io.Seeker: n, err := v.Seek(0, io.SeekCurrent) @@ -139,7 +143,11 @@ func (r *Request) Build(ctx context.Context) *http.Request { req.Body = ioutil.NopCloser(stream) req.ContentLength = -1 default: - if r.stream != nil { + // HTTP Client Request must only have a non-nil body if the + // ContentLength is explicitly unknown (-1) or non-zero. The HTTP + // Client will interpret a non-nil body and ContentLength 0 as + // "unknown". This is unwanted behavior. + if req.ContentLength != 0 && r.stream != nil { req.Body = iointernal.NewSafeReadCloser(ioutil.NopCloser(stream)) } } diff --git a/transport/http/request_test.go b/transport/http/request_test.go index 685f710e8..602dab1d2 100644 --- a/transport/http/request_test.go +++ b/transport/http/request_test.go @@ -4,8 +4,9 @@ import ( "bytes" "context" "io" + "io/ioutil" "net/http" - "net/url" + "os" "strconv" "strings" "testing" @@ -28,12 +29,7 @@ func TestRequestRewindable(t *testing.T) { for name, c := range cases { t.Run(name, func(t *testing.T) { - req := &Request{ - Request: &http.Request{ - URL: &url.URL{}, - Header: http.Header{}, - }, - } + req := NewStackRequest().(*Request) req, err := req.SetStream(c.Stream) if err != nil { @@ -108,3 +104,114 @@ func TestRequestBuild_contentLength(t *testing.T) { }) } } + +func TestRequestSetStream(t *testing.T) { + cases := map[string]struct { + reader io.Reader + expectSeekable bool + expectStreamStartPos int64 + expectContentLength int64 + expectNilStream bool + expectNilBody bool + expectReqContentLength int64 + }{ + "nil stream": { + expectNilStream: true, + expectNilBody: true, + }, + "empty unseekable stream": { + reader: bytes.NewBuffer([]byte{}), + expectNilStream: false, + expectNilBody: true, + }, + "empty seekable stream": { + reader: bytes.NewReader([]byte{}), + expectContentLength: 0, + expectSeekable: true, + expectNilStream: false, + expectNilBody: true, + }, + "unseekable no len stream": { + reader: ioutil.NopCloser(bytes.NewBuffer([]byte("abc123"))), + expectContentLength: -1, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: -1, + }, + "unseekable stream": { + reader: bytes.NewBuffer([]byte("abc123")), + expectContentLength: 6, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: 6, + }, + "seekable stream": { + reader: bytes.NewReader([]byte("abc123")), + expectContentLength: 6, + expectNilStream: false, + expectSeekable: true, + expectNilBody: false, + expectReqContentLength: 6, + }, + "offset seekable stream": { + reader: func() io.Reader { + r := bytes.NewReader([]byte("abc123")) + _, _ = r.Seek(1, os.SEEK_SET) + return r + }(), + expectStreamStartPos: 1, + expectContentLength: 5, + expectSeekable: true, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: 5, + }, + "NoBody stream": { + reader: http.NoBody, + expectNilStream: true, + expectNilBody: true, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := NewStackRequest().(*Request) + req, err = req.SetStream(c.reader) + if err != nil { + t.Fatalf("expect not error, got %v", err) + } + + if e, a := c.expectSeekable, req.IsStreamSeekable(); e != a { + t.Errorf("expect %v seekable, got %v", e, a) + } + if e, a := c.expectStreamStartPos, req.streamStartPos; e != a { + t.Errorf("expect %v seek start position, got %v", e, a) + } + if e, a := c.expectNilStream, req.stream == nil; e != a { + t.Errorf("expect %v nil stream, got %v", e, a) + } + + if l, ok, err := req.StreamLength(); err != nil { + t.Fatalf("expect no stream length error, got %v", err) + } else if ok { + req.ContentLength = l + } + + if e, a := c.expectContentLength, req.ContentLength; e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + if e, a := c.expectStreamStartPos, req.streamStartPos; e != a { + t.Errorf("expect %v streamStartPos, got %v", e, a) + } + + r := req.Build(context.Background()) + if e, a := c.expectNilBody, r.Body == nil; e != a { + t.Errorf("expect %v request nil body, got %v", e, a) + } + if e, a := c.expectContentLength, req.ContentLength; e != a { + t.Errorf("expect %v request content-length, got %v", e, a) + } + }) + } +}