diff --git a/proxy.go b/proxy.go index d1474189..2ba1fc0f 100644 --- a/proxy.go +++ b/proxy.go @@ -573,6 +573,9 @@ func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error if _, ok := err.(*trafficshape.ErrForceClose); ok { closing = errClose } + if err == io.ErrUnexpectedEOF { + closing = errClose + } } err = brw.Flush() if err != nil { diff --git a/proxy_test.go b/proxy_test.go index 30bf17cf..d2f15cca 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -294,6 +294,109 @@ func TestIntegrationHTTP100Continue(t *testing.T) { } } +func TestIntegrationUnexpectedUpstreamFailure(t *testing.T) { + t.Parallel() + + l, err := net.Listen("tcp", "[::]:0") + if err != nil { + t.Fatalf("net.Listen(): got %v, want no error", err) + } + + p := NewProxy() + defer p.Close() + + // setting a large proxy timeout + p.SetTimeout(1000 * time.Second) + + sl, err := net.Listen("tcp", "[::]:0") + if err != nil { + t.Fatalf("net.Listen(): got %v, want no error", err) + } + + go func() { + time.Sleep(1 * time.Second) + conn, err := sl.Accept() + if err != nil { + log.Errorf("proxy_test: failed to accept connection: %v", err) + return + } + defer conn.Close() + + log.Infof("proxy_test: accepted connection: %s\n", conn.RemoteAddr()) + + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Errorf("proxy_test: failed to read request: %v", err) + return + } + + res := &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: ioutil.NopCloser(bytes.NewBufferString("body content")), + // Content length is set as 13 but response + // stops after sending 12 bytes + ContentLength: 13, + Request: req, + Header: make(http.Header, 0), + } + res.Write(conn) + conn.Close() + + log.Infof("proxy_test: sent 200 response\n") + }() + + tm := martiantest.NewModifier() + p.SetRequestModifier(tm) + p.SetResponseModifier(tm) + + go p.Serve(l) + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("net.Dial(): got %v, want no error", err) + } + defer conn.Close() + + host := sl.Addr().String() + raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+ + "Host: %s\r\n"+ + "\r\n", host, host) + if _, err := conn.Write([]byte(raw)); err != nil { + t.Fatalf("conn.Write(headers): got %v, want no error", err) + } + + res, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatalf("http.ReadResponse(): got %v, want no error", err) + } + defer res.Body.Close() + + if got, want := res.StatusCode, 200; got != want { + t.Fatalf("res.StatusCode: got %d, want %d", got, want) + } + + got, err := ioutil.ReadAll(res.Body) + // if below error is unhandled in proxy, the test will timeout. + if err != io.ErrUnexpectedEOF { + t.Fatalf("ioutil.ReadAll(): got %v, want %v", err, io.ErrUnexpectedEOF) + } + + if want := []byte("body content"); !bytes.Equal(got, want) { + t.Errorf("res.Body: got %q, want %q", got, want) + } + + if !tm.RequestModified() { + t.Error("tm.RequestModified(): got false, want true") + } + if !tm.ResponseModified() { + t.Error("tm.ResponseModified(): got false, want true") + } +} + func TestIntegrationHTTPDownstreamProxy(t *testing.T) { t.Parallel()