Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multipart file upload streaming #309 #879

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 7 additions & 45 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"maps"
"net/http"
Expand Down Expand Up @@ -62,6 +61,7 @@ var (
hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type")
hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length")
hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding")
hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition")
hdrLocationKey = http.CanonicalHeaderKey("Location")
hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization")
hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate")
Expand Down Expand Up @@ -616,8 +616,7 @@ func (c *Client) R() *Request {
ResponseBodyUnlimitedReads: c.resBodyUnlimitedReads,

client: c,
multipartFiles: []*File{},
multipartFields: []*MultipartField{},
multipartFields: make([]*MultipartField, 0),
jsonEscapeHTML: c.jsonEscapeHTML,
log: c.log,
setContentLength: c.setContentLength,
Expand Down Expand Up @@ -1860,6 +1859,11 @@ func (c *Client) execute(req *Request) (*Response, error) {
if err != nil {
return response, err
}
if req.multipartErrChan != nil {
if err = <-req.multipartErrChan; err != nil {
return response, err
}
}
if resp != nil {
response.Body = resp.Body

Expand Down Expand Up @@ -1968,45 +1972,3 @@ func (c *Client) onInvalidHooks(req *Request, err error) {
h(req, err)
}
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// File struct and its methods
//_______________________________________________________________________

// File struct represents file information for multipart request
type File struct {
Name string
ParamName string
io.Reader
}

// String method returns the string value of current file details
func (f *File) String() string {
return fmt.Sprintf("ParamName: %v; FileName: %v", f.ParamName, f.Name)
}

// Clone method returns deep copy of f.
func (f *File) Clone() *File {
ff := new(File)
*ff = *f
return ff
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// MultipartField struct
//_______________________________________________________________________

// MultipartField struct represents the custom data part for a multipart request
type MultipartField struct {
Param string
FileName string
ContentType string
io.Reader
}

// Clone method returns the deep copy of m except [io.Reader].
func (m *MultipartField) Clone() *MultipartField {
mm := new(MultipartField)
*mm = *m
return mm
}
1 change: 0 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ func TestClientSettingsCoverage(t *testing.T) {
s, err := resp.fmtBodyString(0)
assertNil(t, err)
assertEqual(t, "***** NO CONTENT *****", s)
fmt.Println(err, s)
}

func TestContentLengthWhenBodyIsNil(t *testing.T) {
Expand Down
100 changes: 71 additions & 29 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,54 +430,96 @@
}

func handleMultipart(c *Client, r *Request) error {
r.bodyBuf = acquireBuffer()
w := multipart.NewWriter(r.bodyBuf)

// Set boundary if not set by user
if r.multipartBoundary != "" {
if err := w.SetBoundary(r.multipartBoundary); err != nil {
return err
for k, v := range c.FormData() {
if _, ok := r.FormData[k]; ok {
continue
}
r.FormData[k] = v[:]
}

for k, v := range c.FormData() {
for _, iv := range v {
if err := w.WriteField(k, iv); err != nil {
mfLen := len(r.multipartFields)
if mfLen == 0 {
r.bodyBuf = acquireBuffer()
mw := multipart.NewWriter(r.bodyBuf)

// set boundary if it is provided by the user
if !isStringEmpty(r.multipartBoundary) {
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
return err
}
}
}

for k, v := range r.FormData {
for _, iv := range v {
if strings.HasPrefix(k, "@") { // file
if err := addFile(w, k[1:], iv); err != nil {
return err
}
} else { // form value
if err := w.WriteField(k, iv); err != nil {
return err
}
}
if err := r.writeFormData(mw); err != nil {
return err

Check warning on line 453 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L453

Added line #L453 was not covered by tests
}

r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
closeq(mw)

return nil
}

// #21 - adding io.Reader support
for _, f := range r.multipartFiles {
if err := addFileReader(w, f); err != nil {
// multipart streaming
bodyReader, bodyWriter := io.Pipe()
mw := multipart.NewWriter(bodyWriter)
r.Body = bodyReader
r.multipartErrChan = make(chan error, 1)

// set boundary if it is provided by the user
if !isStringEmpty(r.multipartBoundary) {
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
return err
}
}

// GitHub #130 adding multipart field support with content type
go func() {
defer close(r.multipartErrChan)
if err := createMultipart(mw, r); err != nil {
r.multipartErrChan <- err
}
closeq(mw)
closeq(bodyWriter)
}()

r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
return nil
}

func createMultipart(w *multipart.Writer, r *Request) error {
if err := r.writeFormData(w); err != nil {
return err

Check warning on line 490 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L490

Added line #L490 was not covered by tests
}

for _, mf := range r.multipartFields {
if err := addMultipartFormField(w, mf); err != nil {
if err := mf.openFileIfRequired(); err != nil {
return err
}

p := make([]byte, 512)
size, err := mf.Reader.Read(p)
if err != nil && err != io.EOF {
return err
}
// auto detect content type if empty
if isStringEmpty(mf.ContentType) {
mf.ContentType = http.DetectContentType(p[:size])
}

partWriter, err := w.CreatePart(mf.createHeader())
if err != nil {
return err

Check warning on line 510 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L510

Added line #L510 was not covered by tests
}

if _, err = partWriter.Write(p[:size]); err != nil {
return err

Check warning on line 514 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L514

Added line #L514 was not covered by tests
}
_, err = io.Copy(partWriter, mf.Reader)
if err != nil {
return err
}
}

r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
return w.Close()
return nil
}

func handleFormData(c *Client, r *Request) {
Expand Down
90 changes: 1 addition & 89 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package resty
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime"
"mime/multipart"
Expand Down Expand Up @@ -452,12 +451,6 @@ func Benchmark_parseRequestHeader(b *testing.B) {
}
}

type errorReader struct{}

func (errorReader) Read(p []byte) (n int, err error) {
return 0, errors.New("fake")
}

func TestParseRequestBody(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down Expand Up @@ -755,59 +748,6 @@ func TestParseRequestBody(t *testing.T) {
expectedContentType: "text/xml",
expectedContentLength: "41",
},
{
name: "multipart form data",
initClient: func(c *Client) {
c.SetFormData(map[string]string{
"foo": "1",
"bar": "2",
})
},
initRequest: func(r *Request) {
r.SetFormData(map[string]string{
"foo": "3",
"baz": "4",
})
r.SetMultipartFormData(map[string]string{
"foo": "5",
"xyz": "6",
}).SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2", "baz":"4", "foo":"5", "xyz":"6"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "744",
},
{
name: "multipart fields",
initRequest: func(r *Request) {
r.SetMultipartFields(
&MultipartField{
Param: "foo",
ContentType: "text/plain",
Reader: strings.NewReader("1"),
},
&MultipartField{
Param: "bar",
ContentType: "text/plain",
Reader: strings.NewReader("2"),
},
).SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2","foo":"1"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "344",
},
{
name: "multipart files",
initRequest: func(r *Request) {
r.SetFileReader("foo", "foo.txt", strings.NewReader("1")).
SetFileReader("bar", "bar.txt", strings.NewReader("2")).
SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2","foo":"1"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "414",
},
{
name: "body with errorReader",
initRequest: func(r *Request) {
Expand Down Expand Up @@ -835,34 +775,6 @@ func TestParseRequestBody(t *testing.T) {
},
wantErr: true,
},
{
name: "multipart fields with errorReader",
initRequest: func(r *Request) {
r.SetMultipartFields(&MultipartField{
Param: "foo",
ContentType: "text/plain",
Reader: &errorReader{},
})
},
wantErr: true,
},
{
name: "multipart files with errorReader",
initRequest: func(r *Request) {
r.SetFileReader("foo", "foo.txt", &errorReader{})
},
wantErr: true,
},
{
name: "multipart with file not found",
initRequest: func(r *Request) {
r.SetFormData(map[string]string{
"@foo": "foo.txt",
})
r.isMultiPart = true
},
wantErr: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
c := New()
Expand Down Expand Up @@ -1075,7 +987,7 @@ func Benchmark_parseRequestBody_MultiPart(b *testing.B) {
SetFileReader("qwe", "qwe.txt", strings.NewReader("7")).
SetMultipartFields(
&MultipartField{
Param: "sdj",
Name: "sdj",
ContentType: "text/plain",
Reader: strings.NewReader("8"),
},
Expand Down
Loading