Skip to content

Commit

Permalink
feat: add multipart file upload streaming #309
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevatkm committed Oct 4, 2024
1 parent 9215e71 commit 2e127ca
Show file tree
Hide file tree
Showing 12 changed files with 628 additions and 582 deletions.
52 changes: 7 additions & 45 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"maps"
"net/http"
Expand Down Expand Up @@ -62,6 +61,7 @@ var (
hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type")
hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length")
hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding")
hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition")
hdrLocationKey = http.CanonicalHeaderKey("Location")
hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization")
hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate")
Expand Down Expand Up @@ -616,8 +616,7 @@ func (c *Client) R() *Request {
ResponseBodyUnlimitedReads: c.resBodyUnlimitedReads,

client: c,
multipartFiles: []*File{},
multipartFields: []*MultipartField{},
multipartFields: make([]*MultipartField, 0),
jsonEscapeHTML: c.jsonEscapeHTML,
log: c.log,
setContentLength: c.setContentLength,
Expand Down Expand Up @@ -1860,6 +1859,11 @@ func (c *Client) execute(req *Request) (*Response, error) {
if err != nil {
return response, err
}
if req.multipartErrChan != nil {
if err = <-req.multipartErrChan; err != nil {
return response, err
}
}
if resp != nil {
response.Body = resp.Body

Expand Down Expand Up @@ -1968,45 +1972,3 @@ func (c *Client) onInvalidHooks(req *Request, err error) {
h(req, err)
}
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// File struct and its methods
//_______________________________________________________________________

// File struct represents file information for multipart request
type File struct {
Name string
ParamName string
io.Reader
}

// String method returns the string value of current file details
func (f *File) String() string {
return fmt.Sprintf("ParamName: %v; FileName: %v", f.ParamName, f.Name)
}

// Clone method returns deep copy of f.
func (f *File) Clone() *File {
ff := new(File)
*ff = *f
return ff
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// MultipartField struct
//_______________________________________________________________________

// MultipartField struct represents the custom data part for a multipart request
type MultipartField struct {
Param string
FileName string
ContentType string
io.Reader
}

// Clone method returns the deep copy of m except [io.Reader].
func (m *MultipartField) Clone() *MultipartField {
mm := new(MultipartField)
*mm = *m
return mm
}
1 change: 0 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ func TestClientSettingsCoverage(t *testing.T) {
s, err := resp.fmtBodyString(0)
assertNil(t, err)
assertEqual(t, "***** NO CONTENT *****", s)
fmt.Println(err, s)
}

func TestContentLengthWhenBodyIsNil(t *testing.T) {
Expand Down
100 changes: 71 additions & 29 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,54 +430,96 @@ func parseResponseBody(c *Client, res *Response) (err error) {
}

func handleMultipart(c *Client, r *Request) error {
r.bodyBuf = acquireBuffer()
w := multipart.NewWriter(r.bodyBuf)

// Set boundary if not set by user
if r.multipartBoundary != "" {
if err := w.SetBoundary(r.multipartBoundary); err != nil {
return err
for k, v := range c.FormData() {
if _, ok := r.FormData[k]; ok {
continue
}
r.FormData[k] = v[:]
}

for k, v := range c.FormData() {
for _, iv := range v {
if err := w.WriteField(k, iv); err != nil {
mfLen := len(r.multipartFields)
if mfLen == 0 {
r.bodyBuf = acquireBuffer()
mw := multipart.NewWriter(r.bodyBuf)

// set boundary if it is provided by the user
if !isStringEmpty(r.multipartBoundary) {
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
return err
}
}
}

for k, v := range r.FormData {
for _, iv := range v {
if strings.HasPrefix(k, "@") { // file
if err := addFile(w, k[1:], iv); err != nil {
return err
}
} else { // form value
if err := w.WriteField(k, iv); err != nil {
return err
}
}
if err := r.writeFormData(mw); err != nil {
return err

Check warning on line 453 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L453

Added line #L453 was not covered by tests
}

r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
closeq(mw)

return nil
}

// #21 - adding io.Reader support
for _, f := range r.multipartFiles {
if err := addFileReader(w, f); err != nil {
// multipart streaming
bodyReader, bodyWriter := io.Pipe()
mw := multipart.NewWriter(bodyWriter)
r.Body = bodyReader
r.multipartErrChan = make(chan error, 1)

// set boundary if it is provided by the user
if !isStringEmpty(r.multipartBoundary) {
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
return err
}
}

// GitHub #130 adding multipart field support with content type
go func() {
defer close(r.multipartErrChan)
if err := createMultipart(mw, r); err != nil {
r.multipartErrChan <- err
}
closeq(mw)
closeq(bodyWriter)
}()

r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
return nil
}

func createMultipart(w *multipart.Writer, r *Request) error {
if err := r.writeFormData(w); err != nil {
return err

Check warning on line 490 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L490

Added line #L490 was not covered by tests
}

for _, mf := range r.multipartFields {
if err := addMultipartFormField(w, mf); err != nil {
if err := mf.openFileIfRequired(); err != nil {
return err
}

p := make([]byte, 512)
size, err := mf.Reader.Read(p)
if err != nil && err != io.EOF {
return err
}
// auto detect content type if empty
if isStringEmpty(mf.ContentType) {
mf.ContentType = http.DetectContentType(p[:size])
}

partWriter, err := w.CreatePart(mf.createHeader())
if err != nil {
return err

Check warning on line 510 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L510

Added line #L510 was not covered by tests
}

if _, err = partWriter.Write(p[:size]); err != nil {
return err

Check warning on line 514 in middleware.go

View check run for this annotation

Codecov / codecov/patch

middleware.go#L514

Added line #L514 was not covered by tests
}
_, err = io.Copy(partWriter, mf.Reader)
if err != nil {
return err
}
}

r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
return w.Close()
return nil
}

func handleFormData(c *Client, r *Request) {
Expand Down
90 changes: 1 addition & 89 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package resty
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime"
"mime/multipart"
Expand Down Expand Up @@ -452,12 +451,6 @@ func Benchmark_parseRequestHeader(b *testing.B) {
}
}

type errorReader struct{}

func (errorReader) Read(p []byte) (n int, err error) {
return 0, errors.New("fake")
}

func TestParseRequestBody(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down Expand Up @@ -755,59 +748,6 @@ func TestParseRequestBody(t *testing.T) {
expectedContentType: "text/xml",
expectedContentLength: "41",
},
{
name: "multipart form data",
initClient: func(c *Client) {
c.SetFormData(map[string]string{
"foo": "1",
"bar": "2",
})
},
initRequest: func(r *Request) {
r.SetFormData(map[string]string{
"foo": "3",
"baz": "4",
})
r.SetMultipartFormData(map[string]string{
"foo": "5",
"xyz": "6",
}).SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2", "baz":"4", "foo":"5", "xyz":"6"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "744",
},
{
name: "multipart fields",
initRequest: func(r *Request) {
r.SetMultipartFields(
&MultipartField{
Param: "foo",
ContentType: "text/plain",
Reader: strings.NewReader("1"),
},
&MultipartField{
Param: "bar",
ContentType: "text/plain",
Reader: strings.NewReader("2"),
},
).SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2","foo":"1"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "344",
},
{
name: "multipart files",
initRequest: func(r *Request) {
r.SetFileReader("foo", "foo.txt", strings.NewReader("1")).
SetFileReader("bar", "bar.txt", strings.NewReader("2")).
SetContentLength(true)
},
expectedBodyBuf: []byte(`{"bar":"2","foo":"1"}`),
expectedContentType: "multipart/form-data; boundary=",
expectedContentLength: "414",
},
{
name: "body with errorReader",
initRequest: func(r *Request) {
Expand Down Expand Up @@ -835,34 +775,6 @@ func TestParseRequestBody(t *testing.T) {
},
wantErr: true,
},
{
name: "multipart fields with errorReader",
initRequest: func(r *Request) {
r.SetMultipartFields(&MultipartField{
Param: "foo",
ContentType: "text/plain",
Reader: &errorReader{},
})
},
wantErr: true,
},
{
name: "multipart files with errorReader",
initRequest: func(r *Request) {
r.SetFileReader("foo", "foo.txt", &errorReader{})
},
wantErr: true,
},
{
name: "multipart with file not found",
initRequest: func(r *Request) {
r.SetFormData(map[string]string{
"@foo": "foo.txt",
})
r.isMultiPart = true
},
wantErr: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
c := New()
Expand Down Expand Up @@ -1075,7 +987,7 @@ func Benchmark_parseRequestBody_MultiPart(b *testing.B) {
SetFileReader("qwe", "qwe.txt", strings.NewReader("7")).
SetMultipartFields(
&MultipartField{
Param: "sdj",
Name: "sdj",
ContentType: "text/plain",
Reader: strings.NewReader("8"),
},
Expand Down
Loading

0 comments on commit 2e127ca

Please sign in to comment.