Skip to content

Commit

Permalink
Merge pull request #356 from aws/fixup/RequestBody
Browse files Browse the repository at this point in the history
transport/http: Fix handling of nil and http.NoBody in Request.Build
  • Loading branch information
jasdel authored Mar 7, 2022
2 parents e7e1256 + 6630cb6 commit 90a0225
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 33 deletions.
8 changes: 8 additions & 0 deletions .changelog/0271f0ea23254cf68a4c08e4c69fd6c8.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "0271f0ea-2325-4cf6-8a4c-08e4c69fd6c8",
"type": "bugfix",
"description": "Updates the smithy-go HTTP Request to correctly handle building the request to an http.Request. Related to [aws/aws-sdk-go-v2#1583](https://github.com/aws/aws-sdk-go-v2/issues/1583)",
"modules": [
"."
]
}
6 changes: 0 additions & 6 deletions transport/http/middleware_content_length.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ func (m *ComputeContentLength) HandleBuild(
"failed getting length of request stream, %w", err)
} else if ok {
req.ContentLength = n
if n == 0 {
// If the content length could be determined, and the body is empty
// the stream must be cleared to prevent unexpected chunk encoding.
req, _ = req.SetStream(nil)
in.Request = req
}
}

return next.HandleBuild(ctx, in)
Expand Down
61 changes: 42 additions & 19 deletions transport/http/middleware_content_length_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,51 @@ import (

func TestContentLengthMiddleware(t *testing.T) {
cases := map[string]struct {
Stream io.Reader
ExpectLen int64
ExpectErr string
Stream io.Reader
ExpectNilStream bool
ExpectLen int64
ExpectErr string
}{
// Cases
"bytes.Reader": {
Stream: bytes.NewReader(make([]byte, 10)),
ExpectLen: 10,
Stream: bytes.NewReader(make([]byte, 10)),
ExpectLen: 10,
ExpectNilStream: false,
},
"bytes.Buffer": {
Stream: bytes.NewBuffer(make([]byte, 10)),
ExpectLen: 10,
Stream: bytes.NewBuffer(make([]byte, 10)),
ExpectLen: 10,
ExpectNilStream: false,
},
"strings.Reader": {
Stream: strings.NewReader("hello"),
ExpectLen: 5,
Stream: strings.NewReader("hello"),
ExpectLen: 5,
ExpectNilStream: false,
},
"empty stream": {
Stream: strings.NewReader(""),
ExpectLen: 0,
Stream: strings.NewReader(""),
ExpectLen: 0,
ExpectNilStream: false,
},
"empty stream bytes": {
Stream: bytes.NewReader([]byte{}),
ExpectLen: 0,
ExpectNilStream: false,
},
"nil stream": {
ExpectLen: 0,
ExpectLen: 0,
ExpectNilStream: true,
},
"un-seekable and no length": {
Stream: &basicReader{buf: make([]byte, 10)},
ExpectLen: -1,
Stream: &basicReader{buf: make([]byte, 10)},
ExpectLen: -1,
ExpectNilStream: false,
},
"with error": {
Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
ExpectErr: "seek failed",
ExpectLen: -1,
Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
ExpectErr: "seek failed",
ExpectLen: -1,
ExpectNilStream: false,
},
}

Expand All @@ -57,10 +70,15 @@ func TestContentLengthMiddleware(t *testing.T) {
t.Fatalf("expect to set stream, %v", err)
}

var updatedRequest *Request
var m ComputeContentLength
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
nopBuildHandler,
middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
updatedRequest = input.Request.(*Request)
return out, metadata, nil
}),
)
if len(c.ExpectErr) != 0 {
if err == nil {
Expand All @@ -69,13 +87,18 @@ func TestContentLengthMiddleware(t *testing.T) {
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %q, got %v", e, a)
}
return
} else if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.ExpectLen, req.ContentLength; e != a {
if e, a := c.ExpectLen, updatedRequest.ContentLength; e != a {
t.Errorf("expect %v content-length, got %v", e, a)
}

if e, a := c.ExpectNilStream, updatedRequest.stream == nil; e != a {
t.Errorf("expect %v nil stream, got %v", e, a)
}
})
}
}
Expand Down
10 changes: 9 additions & 1 deletion transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ func (r *Request) IsStreamSeekable() bool {
func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) {
rc = r.Clone()

if reader == http.NoBody {
reader = nil
}

switch v := reader.(type) {
case io.Seeker:
n, err := v.Seek(0, io.SeekCurrent)
Expand Down Expand Up @@ -139,7 +143,11 @@ func (r *Request) Build(ctx context.Context) *http.Request {
req.Body = ioutil.NopCloser(stream)
req.ContentLength = -1
default:
if r.stream != nil {
// HTTP Client Request must only have a non-nil body if the
// ContentLength is explicitly unknown (-1) or non-zero. The HTTP
// Client will interpret a non-nil body and ContentLength 0 as
// "unknown". This is unwanted behavior.
if req.ContentLength != 0 && r.stream != nil {
req.Body = iointernal.NewSafeReadCloser(ioutil.NopCloser(stream))
}
}
Expand Down
121 changes: 114 additions & 7 deletions transport/http/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"testing"
Expand All @@ -28,12 +29,7 @@ func TestRequestRewindable(t *testing.T) {

for name, c := range cases {
t.Run(name, func(t *testing.T) {
req := &Request{
Request: &http.Request{
URL: &url.URL{},
Header: http.Header{},
},
}
req := NewStackRequest().(*Request)

req, err := req.SetStream(c.Stream)
if err != nil {
Expand Down Expand Up @@ -108,3 +104,114 @@ func TestRequestBuild_contentLength(t *testing.T) {
})
}
}

func TestRequestSetStream(t *testing.T) {
cases := map[string]struct {
reader io.Reader
expectSeekable bool
expectStreamStartPos int64
expectContentLength int64
expectNilStream bool
expectNilBody bool
expectReqContentLength int64
}{
"nil stream": {
expectNilStream: true,
expectNilBody: true,
},
"empty unseekable stream": {
reader: bytes.NewBuffer([]byte{}),
expectNilStream: false,
expectNilBody: true,
},
"empty seekable stream": {
reader: bytes.NewReader([]byte{}),
expectContentLength: 0,
expectSeekable: true,
expectNilStream: false,
expectNilBody: true,
},
"unseekable no len stream": {
reader: ioutil.NopCloser(bytes.NewBuffer([]byte("abc123"))),
expectContentLength: -1,
expectNilStream: false,
expectNilBody: false,
expectReqContentLength: -1,
},
"unseekable stream": {
reader: bytes.NewBuffer([]byte("abc123")),
expectContentLength: 6,
expectNilStream: false,
expectNilBody: false,
expectReqContentLength: 6,
},
"seekable stream": {
reader: bytes.NewReader([]byte("abc123")),
expectContentLength: 6,
expectNilStream: false,
expectSeekable: true,
expectNilBody: false,
expectReqContentLength: 6,
},
"offset seekable stream": {
reader: func() io.Reader {
r := bytes.NewReader([]byte("abc123"))
_, _ = r.Seek(1, os.SEEK_SET)
return r
}(),
expectStreamStartPos: 1,
expectContentLength: 5,
expectSeekable: true,
expectNilStream: false,
expectNilBody: false,
expectReqContentLength: 5,
},
"NoBody stream": {
reader: http.NoBody,
expectNilStream: true,
expectNilBody: true,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := NewStackRequest().(*Request)
req, err = req.SetStream(c.reader)
if err != nil {
t.Fatalf("expect not error, got %v", err)
}

if e, a := c.expectSeekable, req.IsStreamSeekable(); e != a {
t.Errorf("expect %v seekable, got %v", e, a)
}
if e, a := c.expectStreamStartPos, req.streamStartPos; e != a {
t.Errorf("expect %v seek start position, got %v", e, a)
}
if e, a := c.expectNilStream, req.stream == nil; e != a {
t.Errorf("expect %v nil stream, got %v", e, a)
}

if l, ok, err := req.StreamLength(); err != nil {
t.Fatalf("expect no stream length error, got %v", err)
} else if ok {
req.ContentLength = l
}

if e, a := c.expectContentLength, req.ContentLength; e != a {
t.Errorf("expect %v content-length, got %v", e, a)
}
if e, a := c.expectStreamStartPos, req.streamStartPos; e != a {
t.Errorf("expect %v streamStartPos, got %v", e, a)
}

r := req.Build(context.Background())
if e, a := c.expectNilBody, r.Body == nil; e != a {
t.Errorf("expect %v request nil body, got %v", e, a)
}
if e, a := c.expectContentLength, req.ContentLength; e != a {
t.Errorf("expect %v request content-length, got %v", e, a)
}
})
}
}

0 comments on commit 90a0225

Please sign in to comment.