Skip to content

Commit

Permalink
Streaming fixes (#970)
Browse files Browse the repository at this point in the history
- Allow DisablePreParseMultipartForm in combination with
StreamRequestBody.
- Support streaming into MultipartForm instead of reading the whole body
  first.
- Support calling ctx.PostBody() when streaming is enabled.
  • Loading branch information
erikdubbelboer authored Feb 16, 2021
1 parent 1b61ca2 commit 3cd0862
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 160 deletions.
140 changes: 64 additions & 76 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fasthttp
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -345,6 +346,15 @@ func (req *Request) bodyBytes() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
}
if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
if req.body == nil {
return nil
}
Expand Down Expand Up @@ -630,14 +640,6 @@ func (req *Request) SwapBody(body []byte) []byte {
func (req *Request) Body() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
} else if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
} else if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
Expand Down Expand Up @@ -814,24 +816,43 @@ func (req *Request) MultipartForm() (*multipart.Form, error) {
return nil, ErrNoMultipartForm
}

var err error
ce := req.Header.peek(strContentEncoding)
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
var err error
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)

if req.bodyStream != nil {
bodyStream := req.bodyStream
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}

f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
req.multipartForm, err = mr.ReadForm(8 * 1024)
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
}
} else {
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}

req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
}
}
req.multipartForm = f
return f, nil

return req.multipartForm, nil
}

func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) {
Expand Down Expand Up @@ -1022,6 +1043,9 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
}

func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.

if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
Expand All @@ -1033,39 +1057,7 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo
return nil
}

var err error
contentLength := req.Header.realContentLength()
if contentLength > 0 {
if preParseMultipartForm {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}

if contentLength == -2 {
// identity body has no sense for http requests, since
// the end of body is determined by connection close.
// So just ignore request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
req.Header.SetContentLength(0)
return nil
}

bodyBuf := req.bodyBuffer()
bodyBuf.Reset()

req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)

return nil
return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}

// MayContinue returns true if the request contains
Expand Down Expand Up @@ -1170,21 +1162,15 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
bodyBufLen := maxBodySize
if contentLength < maxBodySize {
bodyBufLen = cap(bodyBuf.B)
}
if err != nil {
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
return nil
}
Expand All @@ -1193,7 +1179,6 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
}

req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
req.Header.SetContentLength(len(bodyBuf.B))
return nil
Expand Down Expand Up @@ -1936,24 +1921,27 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (
var errChunkedStream = errors.New("chunked stream")

func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) {
if contentLength == -1 {
// handled in requestStream.Read()
return b, errChunkedStream
}

dst = dst[:0]
switch {
case contentLength >= 0 && maxBodySize >= contentLength:
readN := maxBodySize
if contentLength > 8*1024 {
readN = 8 * 1024
}

readN := maxBodySize
if readN > contentLength {
readN = contentLength
}
if readN > 8*1024 {
readN = 8 * 1024
}

if contentLength >= 0 && maxBodySize >= contentLength {
b, err = appendBodyFixedSize(r, dst, readN)
case contentLength == -1:
// handled in requestStream.Read()
err = errChunkedStream
default:
readN := maxBodySize
if contentLength > 8*1024 {
readN = 8 * 1024
}
} else {
b, err = readBodyIdentity(r, readN, dst)
}

if err != nil {
return b, err
}
Expand Down
Loading

0 comments on commit 3cd0862

Please sign in to comment.