Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mhr3 committed Nov 19, 2021
1 parent ee97a23 commit 587b484
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
13 changes: 6 additions & 7 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1239,13 +1239,12 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}

continueTimeout := cc.t.expectContinueTimeout()
if continueTimeout != 0 &&
!httpguts.HeaderValuesContainsToken(
req.Header["Expect"],
"100-continue") {
continueTimeout = 0
} else if continueTimeout != 0 {
cs.on100 = make(chan struct{}, 1)
if continueTimeout != 0 {
if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") {
continueTimeout = 0
} else {
cs.on100 = make(chan struct{}, 1)
}
}

// Past this point (where we send request headers), it is possible for
Expand Down
91 changes: 66 additions & 25 deletions http2/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5109,10 +5109,28 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
res.Body.Close()
}

type trackingReader struct {
rdr io.Reader
wasRead uint32
}

func (tr *trackingReader) Read(p []byte) (int, error) {
atomic.StoreUint32(&tr.wasRead, 1)
return tr.rdr.Read(p)
}

func (tr *trackingReader) WasRead() bool {
return atomic.LoadUint32(&tr.wasRead) != 0
}

func TestTransportExpectContinue(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.Copy(io.Discard, r.Body)
return
switch r.URL.Path {
case "/reject":
w.WriteHeader(403)
default:
io.Copy(io.Discard, r.Body)
}
}, optOnlyServer)
defer st.Close()

Expand All @@ -5130,31 +5148,54 @@ func TestTransportExpectContinue(t *testing.T) {
Transport: tr,
}

reqCh := make(chan error)
startTime := time.Now()
testCases := []struct {
Name string
Path string
Body *trackingReader
ExpectedCode int
ShouldRead bool
}{
{
Name: "read-all",
Path: "/",
Body: &trackingReader{rdr: strings.NewReader("hello")},
ExpectedCode: 200,
ShouldRead: true,
},
{
Name: "reject",
Path: "/reject",
Body: &trackingReader{rdr: strings.NewReader("hello")},
ExpectedCode: 403,
ShouldRead: false,
},
}

go func() {
req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader("hello"))
if err != nil {
reqCh <- err
return
}
req.Header.Set("Expect", "100-continue")
res, err := client.Do(req)
if err != nil {
reqCh <- err
return
}
reqCh <- res.Body.Close()
}()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
startTime := time.Now()

err = <-reqCh
if err != nil {
t.Fatal(err)
}
delta := time.Since(startTime)
if delta >= tr.ExpectContinueTimeout {
t.Error("Request didn't resume after receiving 100 continue")
req, err := http.NewRequest("POST", st.ts.URL+tc.Path, tc.Body)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Expect", "100-continue")
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()

if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
t.Error("Request didn't finish before expect continue timeout")
}
if res.StatusCode != tc.ExpectedCode {
t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
}
if tc.Body.WasRead() != tc.ShouldRead {
t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
}
})
}
}

Expand Down

0 comments on commit 587b484

Please sign in to comment.