diff --git a/http.go b/http.go index 89bc2b5fe8..5ea7847556 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package fasthttp import ( "bufio" "bytes" + "compress/gzip" "encoding/base64" "errors" "fmt" @@ -345,6 +346,15 @@ func (req *Request) bodyBytes() []byte { if req.bodyRaw != nil { return req.bodyRaw } + if req.bodyStream != nil { + bodyBuf := req.bodyBuffer() + bodyBuf.Reset() + _, err := copyZeroAlloc(bodyBuf, req.bodyStream) + req.closeBodyStream() //nolint:errcheck + if err != nil { + bodyBuf.SetString(err.Error()) + } + } if req.body == nil { return nil } @@ -630,14 +640,6 @@ func (req *Request) SwapBody(body []byte) []byte { func (req *Request) Body() []byte { if req.bodyRaw != nil { return req.bodyRaw - } else if req.bodyStream != nil { - bodyBuf := req.bodyBuffer() - bodyBuf.Reset() - _, err := copyZeroAlloc(bodyBuf, req.bodyStream) - req.closeBodyStream() //nolint:errcheck - if err != nil { - bodyBuf.SetString(err.Error()) - } } else if req.onlyMultipartForm() { body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { @@ -814,24 +816,43 @@ func (req *Request) MultipartForm() (*multipart.Form, error) { return nil, ErrNoMultipartForm } + var err error ce := req.Header.peek(strContentEncoding) - body := req.bodyBytes() - if bytes.Equal(ce, strGzip) { - // Do not care about memory usage here. - var err error - if body, err = AppendGunzipBytes(nil, body); err != nil { - return nil, fmt.Errorf("cannot gunzip request body: %s", err) + + if req.bodyStream != nil { + bodyStream := req.bodyStream + if bytes.Equal(ce, strGzip) { + // Do not care about memory usage here. + if bodyStream, err = gzip.NewReader(bodyStream); err != nil { + return nil, fmt.Errorf("cannot gunzip request body: %s", err) + } + } else if len(ce) > 0 { + return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) } - } else if len(ce) > 0 { - return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) - } - f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) - if err != nil { - return nil, err + mr := multipart.NewReader(bodyStream, req.multipartFormBoundary) + req.multipartForm, err = mr.ReadForm(8 * 1024) + if err != nil { + return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err) + } + } else { + body := req.bodyBytes() + if bytes.Equal(ce, strGzip) { + // Do not care about memory usage here. + if body, err = AppendGunzipBytes(nil, body); err != nil { + return nil, fmt.Errorf("cannot gunzip request body: %s", err) + } + } else if len(ce) > 0 { + return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) + } + + req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) + if err != nil { + return nil, err + } } - req.multipartForm = f - return f, nil + + return req.multipartForm, nil } func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) { @@ -1022,6 +1043,9 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool } func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error { + // Do not reset the request here - the caller must reset it before + // calling this method. + if getOnly && !req.Header.IsGet() { return ErrGetOnly } @@ -1033,39 +1057,7 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo return nil } - var err error - contentLength := req.Header.realContentLength() - if contentLength > 0 { - if preParseMultipartForm { - // Pre-read multipart form data of known length. - // This way we limit memory usage for large file uploads, since their contents - // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. - req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary()) - if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { - req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) - if err != nil { - req.Reset() - } - return err - } - } - } - - if contentLength == -2 { - // identity body has no sense for http requests, since - // the end of body is determined by connection close. - // So just ignore request body for requests without - // 'Content-Length' and 'Transfer-Encoding' headers. - req.Header.SetContentLength(0) - return nil - } - - bodyBuf := req.bodyBuffer() - bodyBuf.Reset() - - req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) - - return nil + return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm) } // MayContinue returns true if the request contains @@ -1170,21 +1162,15 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre bodyBuf := req.bodyBuffer() bodyBuf.Reset() bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B) - bodyBufLen := maxBodySize - if contentLength < maxBodySize { - bodyBufLen = cap(bodyBuf.B) - } if err != nil { if err == ErrBodyTooLarge { req.Header.SetContentLength(contentLength) req.body = bodyBuf - req.bodyRaw = bodyBuf.B[:bodyBufLen] req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) return nil } if err == errChunkedStream { req.body = bodyBuf - req.bodyRaw = bodyBuf.B[:bodyBufLen] req.bodyStream = acquireRequestStream(bodyBuf, r, -1) return nil } @@ -1193,7 +1179,6 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre } req.body = bodyBuf - req.bodyRaw = bodyBuf.B[:bodyBufLen] req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) req.Header.SetContentLength(len(bodyBuf.B)) return nil @@ -1936,24 +1921,27 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ( var errChunkedStream = errors.New("chunked stream") func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) { + if contentLength == -1 { + // handled in requestStream.Read() + return b, errChunkedStream + } + dst = dst[:0] - switch { - case contentLength >= 0 && maxBodySize >= contentLength: - readN := maxBodySize - if contentLength > 8*1024 { - readN = 8 * 1024 - } + + readN := maxBodySize + if readN > contentLength { + readN = contentLength + } + if readN > 8*1024 { + readN = 8 * 1024 + } + + if contentLength >= 0 && maxBodySize >= contentLength { b, err = appendBodyFixedSize(r, dst, readN) - case contentLength == -1: - // handled in requestStream.Read() - err = errChunkedStream - default: - readN := maxBodySize - if contentLength > 8*1024 { - readN = 8 * 1024 - } + } else { b, err = readBodyIdentity(r, readN, dst) } + if err != nil { return b, err } diff --git a/server_test.go b/server_test.go index c7baa3e080..a0583f6408 100644 --- a/server_test.go +++ b/server_test.go @@ -1073,7 +1073,16 @@ func TestServerServeTLSEmbed(t *testing.T) { func TestServerMultipartFormDataRequest(t *testing.T) { t.Parallel() - reqS := `POST /upload HTTP/1.1 + for _, test := range []struct { + StreamRequestBody bool + DisablePreParseMultipartForm bool + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } { + reqS := `POST /upload HTTP/1.1 Host: qwerty.com Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg @@ -1100,91 +1109,94 @@ Connection: close ` - ln := fasthttputil.NewInmemoryListener() - - s := &Server{ - Handler: func(ctx *RequestCtx) { - switch string(ctx.Path()) { - case "/upload": - f, err := ctx.MultipartForm() - if err != nil { - t.Errorf("unexpected error: %s", err) - } - if len(f.Value) != 1 { - t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1) - } - if len(f.File) != 1 { - t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1) - } - fv := ctx.FormValue("f1") - if string(fv) != "value1" { - t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1") + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + StreamRequestBody: test.StreamRequestBody, + DisablePreParseMultipartForm: test.DisablePreParseMultipartForm, + Handler: func(ctx *RequestCtx) { + switch string(ctx.Path()) { + case "/upload": + f, err := ctx.MultipartForm() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if len(f.Value) != 1 { + t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1) + } + if len(f.File) != 1 { + t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1) + } + fv := ctx.FormValue("f1") + if string(fv) != "value1" { + t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1") + } + ctx.Redirect("/", StatusSeeOther) + default: + ctx.WriteString("non-upload") //nolint:errcheck } - ctx.Redirect("/", StatusSeeOther) - default: - ctx.WriteString("non-upload") //nolint:errcheck - } - }, - } - - ch := make(chan struct{}) - go func() { - if err := s.Serve(ln); err != nil { - t.Errorf("unexpected error: %s", err) + }, } - close(ch) - }() - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if _, err = conn.Write([]byte(reqS)); err != nil { - t.Fatalf("unexpected error: %s", err) - } + ch := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() - var resp Response - br := bufio.NewReader(conn) - respCh := make(chan struct{}) - go func() { - if err := resp.Read(br); err != nil { - t.Errorf("error when reading response: %s", err) - } - if resp.StatusCode() != StatusSeeOther { - t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther) + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) } - loc := resp.Header.Peek(HeaderLocation) - if string(loc) != "http://qwerty.com/" { - t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/") + if _, err = conn.Write([]byte(reqS)); err != nil { + t.Fatalf("unexpected error: %s", err) } - if err := resp.Read(br); err != nil { - t.Errorf("error when reading the second response: %s", err) - } - if resp.StatusCode() != StatusOK { - t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) - } - body := resp.Body() - if string(body) != "non-upload" { - t.Errorf("unexpected body %q. Expecting %q", body, "non-upload") - } - close(respCh) - }() + var resp Response + br := bufio.NewReader(conn) + respCh := make(chan struct{}) + go func() { + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusSeeOther { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther) + } + loc := resp.Header.Peek(HeaderLocation) + if string(loc) != "http://qwerty.com/" { + t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/") + } - select { - case <-respCh: - case <-time.After(time.Second): - t.Fatal("timeout") - } + if err := resp.Read(br); err != nil { + t.Errorf("error when reading the second response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + body := resp.Body() + if string(body) != "non-upload" { + t.Errorf("unexpected body %q. Expecting %q", body, "non-upload") + } + close(respCh) + }() - if err := ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) - } + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } - select { - case <-ch: - case <-time.After(time.Second): - t.Fatal("timeout when waiting for the server to stop") + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") + } } } @@ -3413,8 +3425,8 @@ func TestMaxBodySizePerRequest(t *testing.T) { func TestStreamRequestBody(t *testing.T) { t.Parallel() - part1 := strings.Repeat("1", 1<<10) - part2 := strings.Repeat("2", 1<<20-1<<10) + part1 := strings.Repeat("1", 1<<15) + part2 := strings.Repeat("2", 1<<16) contentLength := len(part1) + len(part2) next := make(chan struct{}) @@ -3424,15 +3436,17 @@ func TestStreamRequestBody(t *testing.T) { close(next) checkReader(t, ctx.RequestBodyStream(), part2) }, - DisableKeepalive: true, StreamRequestBody: true, } pipe := fasthttputil.NewPipeConns() cc, sc := pipe.Conn1(), pipe.Conn2() //write headers and part1 body - if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil { - t.Error(err) + if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil { + t.Fatal(err) + } + if _, err := cc.Write([]byte(part1)); err != nil { + t.Fatal(err) } ch := make(chan error) @@ -3447,12 +3461,15 @@ func TestStreamRequestBody(t *testing.T) { } if _, err := cc.Write([]byte(part2)); err != nil { - t.Error(err) + t.Fatal(err) + } + if err := sc.Close(); err != nil { + t.Fatal(err) } select { case err := <-ch: - if err != nil { + if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match. t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(500 * time.Millisecond): diff --git a/streaming.go b/streaming.go index a6ad0a9afc..39000a26d5 100644 --- a/streaming.go +++ b/streaming.go @@ -45,7 +45,12 @@ func (rs *requestStream) Read(p []byte) (int, error) { } var n int var err error - if int(rs.prefetchedBytes.Size()) > rs.totalBytesRead { + prefetchedSize := int(rs.prefetchedBytes.Size()) + if prefetchedSize > rs.totalBytesRead { + left := prefetchedSize - rs.totalBytesRead + if len(p) > left { + p = p[:left] + } n, err := rs.prefetchedBytes.Read(p) rs.totalBytesRead += n if n == rs.contentLength { @@ -53,6 +58,10 @@ func (rs *requestStream) Read(p []byte) (int, error) { } return n, err } else { + left := rs.contentLength - rs.totalBytesRead + if len(p) > left { + p = p[:left] + } n, err = rs.reader.Read(p) rs.totalBytesRead += n if err != nil { diff --git a/streaming_test.go b/streaming_test.go index e99033c031..a943cb85b6 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -6,10 +6,100 @@ import ( "io/ioutil" "sync" "testing" + "time" "github.com/valyala/fasthttp/fasthttputil" ) +func TestStreamingPipeline(t *testing.T) { + t.Parallel() + + reqS := `POST /one HTTP/1.1 +Host: example.com +Content-Length: 10 + +aaaaaaaaaa +POST /two HTTP/1.1 +Host: example.com +Content-Length: 10 + +aaaaaaaaaa` + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + StreamRequestBody: true, + Handler: func(ctx *RequestCtx) { + body := "" + expected := "aaaaaaaaaa" + if string(ctx.Path()) == "/one" { + body = string(ctx.PostBody()) + } else { + all, err := ioutil.ReadAll(ctx.RequestBodyStream()) + if err != nil { + t.Error(err) + } + body = string(all) + } + if body != expected { + t.Errorf("expected %q got %q", expected, body) + } + }, + } + + ch := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() + + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = conn.Write([]byte(reqS)); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var resp Response + br := bufio.NewReader(conn) + respCh := make(chan struct{}) + go func() { + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + close(respCh) + }() + + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") + } +} + func TestRequestStream(t *testing.T) { body := createFixedBody(3) chunkedBody := createChunkedBody(body)