Skip to content

Commit

Permalink
feat: add mutex to make Client thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
tttturtle-russ committed Aug 19, 2024
1 parent 7b974b5 commit 9f6bfc5
Show file tree
Hide file tree
Showing 14 changed files with 503 additions and 183 deletions.
445 changes: 382 additions & 63 deletions client.go

Large diffs are not rendered by default.

60 changes: 30 additions & 30 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestClientProxy(t *testing.T) {
assertNotNil(t, resp)
assertNotNil(t, err)

// Error
// error
c.SetProxy("//not.a.user@%66%6f%6f.com:8888")

resp, err = c.R().
Expand Down Expand Up @@ -337,9 +337,9 @@ func TestClientSetHeaderVerbatim(t *testing.T) {
SetHeader("header-lowercase", "value_standard")

//lint:ignore SA1008 valid one, so ignore this!
unConventionHdrValue := strings.Join(c.Header["header-lowercase"], "")
unConventionHdrValue := strings.Join(c.header["header-lowercase"], "")
assertEqual(t, "value_lowercase", unConventionHdrValue)
assertEqual(t, "value_standard", c.Header.Get("Header-Lowercase"))
assertEqual(t, "value_standard", c.header.Get("Header-Lowercase"))
}

func TestClientSetTransport(t *testing.T) {
Expand Down Expand Up @@ -385,20 +385,20 @@ func TestClientOptions(t *testing.T) {
assertEqual(t, client.setContentLength, true)

client.SetBaseURL("http://httpbin.org")
assertEqual(t, "http://httpbin.org", client.BaseURL)
assertEqual(t, "http://httpbin.org", client.baseURL)

client.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8")
client.SetHeaders(map[string]string{
hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.1",
"X-Request-Id": strconv.FormatInt(time.Now().UnixNano(), 10),
})
assertEqual(t, "application/json; charset=utf-8", client.Header.Get(hdrContentTypeKey))
assertEqual(t, "application/json; charset=utf-8", client.header.Get(hdrContentTypeKey))

client.SetCookie(&http.Cookie{
Name: "default-cookie",
Value: "This is cookie default-cookie value",
})
assertEqual(t, "default-cookie", client.Cookies[0].Name)
assertEqual(t, "default-cookie", client.cookies[0].Name)

cookies := []*http.Cookie{
{
Expand All @@ -410,45 +410,45 @@ func TestClientOptions(t *testing.T) {
},
}
client.SetCookies(cookies)
assertEqual(t, "default-cookie-1", client.Cookies[1].Name)
assertEqual(t, "default-cookie-2", client.Cookies[2].Name)
assertEqual(t, "default-cookie-1", client.cookies[1].Name)
assertEqual(t, "default-cookie-2", client.cookies[2].Name)

client.SetQueryParam("test_param_1", "Param_1")
client.SetQueryParams(map[string]string{"test_param_2": "Param_2", "test_param_3": "Param_3"})
assertEqual(t, "Param_3", client.QueryParam.Get("test_param_3"))
assertEqual(t, "Param_3", client.queryParam.Get("test_param_3"))

rTime := strconv.FormatInt(time.Now().UnixNano(), 10)
client.SetFormData(map[string]string{"r_time": rTime})
assertEqual(t, rTime, client.FormData.Get("r_time"))
assertEqual(t, rTime, client.formData.Get("r_time"))

client.SetBasicAuth("myuser", "mypass")
assertEqual(t, "myuser", client.UserInfo.Username)
assertEqual(t, "myuser", client.userInfo.Username)

client.SetAuthToken("AC75BD37F019E08FBC594900518B4F7E")
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token)
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.token)

client.SetDisableWarn(true)
assertEqual(t, client.DisableWarn, true)
assertEqual(t, client.disableWarn, true)

client.SetRetryCount(3)
assertEqual(t, 3, client.RetryCount)
assertEqual(t, 3, client.retryCount)

rwt := time.Duration(1000) * time.Millisecond
client.SetRetryWaitTime(rwt)
assertEqual(t, rwt, client.RetryWaitTime)
assertEqual(t, rwt, client.retryWaitTime)

mrwt := time.Duration(2) * time.Second
client.SetRetryMaxWaitTime(mrwt)
assertEqual(t, mrwt, client.RetryMaxWaitTime)
assertEqual(t, mrwt, client.retryMaxWaitTime)

client.AddRetryAfterErrorCondition()
equal(client.RetryConditions[0], func(response *Response, err error) bool {
equal(client.retryConditions[0], func(response *Response, err error) bool {
return response.IsError()
})

err := &AuthError{}
client.SetError(err)
if reflect.TypeOf(err) == client.Error {
if reflect.TypeOf(err) == client.error {
t.Error("SetError failed")
}

Expand All @@ -474,14 +474,14 @@ func TestClientOptions(t *testing.T) {
client.SetContentLength(true)

client.SetDebug(true)
assertEqual(t, client.Debug, true)
assertEqual(t, client.debug, true)

var sl int64 = 1000000
client.SetDebugBodyLimit(sl)
assertEqual(t, client.debugBodySizeLimit, sl)

client.SetAllowGetMethodPayload(true)
assertEqual(t, client.AllowGetMethodPayload, true)
assertEqual(t, client.allowGetMethodPayload, true)

client.SetScheme("http")
assertEqual(t, client.scheme, "http")
Expand Down Expand Up @@ -615,31 +615,31 @@ func TestClientNewRequest(t *testing.T) {
func TestClientSetJSONMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetJSONMarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONMarshal)
p1 := fmt.Sprintf("%p", c.jsonMarshal)
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetJSONUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetJSONUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONUnmarshal)
p1 := fmt.Sprintf("%p", c.jsonUnmarshal)
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetXMLMarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLMarshal)
p1 := fmt.Sprintf("%p", c.xmlMarshal)
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetXMLUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLUnmarshal)
p1 := fmt.Sprintf("%p", c.xmlUnmarshal)
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}
Expand Down Expand Up @@ -1114,21 +1114,21 @@ func TestClone(t *testing.T) {
parent.SetBaseURL("http://localhost")

// set an interface field
parent.UserInfo = &User{
parent.userInfo = &User{
Username: "parent",
}

clone := parent.Clone()
// update value of non-interface type - change will only happen on clone
clone.SetBaseURL("https://local.host")
// update value of interface type - change will also happen on parent
clone.UserInfo.Username = "clone"
clone.userInfo.Username = "clone"

// asert non-interface type
assertEqual(t, "http://localhost", parent.BaseURL)
assertEqual(t, "https://local.host", clone.BaseURL)
assertEqual(t, "http://localhost", parent.baseURL)
assertEqual(t, "https://local.host", clone.baseURL)

// assert interface type
assertEqual(t, "clone", parent.UserInfo.Username)
assertEqual(t, "clone", clone.UserInfo.Username)
assertEqual(t, "clone", parent.userInfo.Username)
assertEqual(t, "clone", clone.userInfo.Username)
}
6 changes: 3 additions & 3 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestSetContextCancel(t *testing.T) {

<-ch // wait for server to finish request handling

t.Logf("Error: %v", err)
t.Logf("error: %v", err)
if !errIsContextCanceled(err) {
t.Errorf("Got unexpected error: %v", err)
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestSetContextCancelRetry(t *testing.T) {

<-ch // wait for server to finish request handling

t.Logf("Error: %v", err)
t.Logf("error: %v", err)
if !errIsContextCanceled(err) {
t.Errorf("Got unexpected error: %v", err)
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func TestSetContextCancelWithError(t *testing.T) {

<-ch // wait for server to finish request handling

t.Logf("Error: %v", err)
t.Logf("error: %v", err)
if !errIsContextCanceled(err) {
t.Errorf("Got unexpected error: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func Example_put() {
Tags: []string{"article", "sample", "resty"},
}).
SetAuthToken("C6A79608-782F-4ED0-A11D-BD82FAD829CD").
SetError(&Error{}). // or SetError(Error{}).
SetError(&Error{}). // or SetError(error{}).
Put("https://myapp.com/article/1234")

printOutput(resp, err)
Expand Down
54 changes: 27 additions & 27 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ const debugRequestLogKey = "__restyDebugRequestLog"
//_______________________________________________________________________

func parseRequestURL(c *Client, r *Request) error {
if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
if l := len(c.pathParams) + len(c.rawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
params := make(map[string]string, l)

// GitHub #103 Path Params
for p, v := range r.PathParams {
params[p] = url.PathEscape(v)
}
for p, v := range c.PathParams {
for p, v := range c.pathParams {
if _, ok := params[p]; !ok {
params[p] = url.PathEscape(v)
}
Expand All @@ -46,7 +46,7 @@ func parseRequestURL(c *Client, r *Request) error {
params[p] = v
}
}
for p, v := range c.RawPathParams {
for p, v := range c.rawPathParams {
if _, ok := params[p]; !ok {
params[p] = v
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func parseRequestURL(c *Client, r *Request) error {
r.URL = "/" + r.URL
}

reqURL, err = url.Parse(c.BaseURL + r.URL)
reqURL, err = url.Parse(c.baseURL + r.URL)
if err != nil {
return err
}
Expand All @@ -126,8 +126,8 @@ func parseRequestURL(c *Client, r *Request) error {
}

// Adding Query Param
if len(c.QueryParam)+len(r.QueryParam) > 0 {
for k, v := range c.QueryParam {
if len(c.queryParam)+len(r.QueryParam) > 0 {
for k, v := range c.queryParam {
// skip query parameter if it was set in request
if _, ok := r.QueryParam[k]; ok {
continue
Expand Down Expand Up @@ -155,7 +155,7 @@ func parseRequestURL(c *Client, r *Request) error {
}

func parseRequestHeader(c *Client, r *Request) error {
for k, v := range c.Header {
for k, v := range c.header {
if _, ok := r.Header[k]; ok {
continue
}
Expand All @@ -174,13 +174,13 @@ func parseRequestHeader(c *Client, r *Request) error {
}

func parseRequestBody(c *Client, r *Request) error {
if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if isPayloadSupported(r.Method, c.allowGetMethodPayload) {
switch {
case r.isMultiPart: // Handling Multipart
if err := handleMultipart(c, r); err != nil {
return err
}
case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data
case len(c.formData) > 0 || len(r.FormData) > 0: // Handling Form Data
handleFormData(c, r)
case r.Body != nil: // Handling Request body
handleContentType(c, r)
Expand All @@ -205,7 +205,7 @@ func parseRequestBody(c *Client, r *Request) error {

func createHTTPRequest(c *Client, r *Request) (err error) {
if r.bodyBuf == nil {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.allowGetMethodPayload) {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
} else if c.setContentLength || r.setContentLength {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
Expand All @@ -229,7 +229,7 @@ func createHTTPRequest(c *Client, r *Request) (err error) {
r.RawRequest.Header = r.Header

// Add cookies from client instance into http request
for _, cookie := range c.Cookies {
for _, cookie := range c.cookies {
r.RawRequest.AddCookie(cookie)
}

Expand Down Expand Up @@ -271,32 +271,32 @@ func addCredentials(c *Client, r *Request) error {
if r.UserInfo != nil { // takes precedence
r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
isBasicAuth = true
} else if c.UserInfo != nil {
r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
} else if c.userInfo != nil {
r.RawRequest.SetBasicAuth(c.userInfo.Username, c.userInfo.Password)
isBasicAuth = true
}

if !c.DisableWarn {
if !c.disableWarn {
if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
}
}

// Set the Authorization Header Scheme
// Set the Authorization header Scheme
var authScheme string
if !IsStringEmpty(r.AuthScheme) {
authScheme = r.AuthScheme
} else if !IsStringEmpty(c.AuthScheme) {
authScheme = c.AuthScheme
} else if !IsStringEmpty(c.authScheme) {
authScheme = c.authScheme
} else {
authScheme = "Bearer"
}

// Build the Token Auth header
// Build the token Auth header
if !IsStringEmpty(r.Token) { // takes precedence
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.Token) {
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+c.Token)
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.token) {
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+c.token)
}

return nil
Expand Down Expand Up @@ -389,11 +389,11 @@ func parseResponseBody(c *Client, res *Response) (err error) {
}
}

// HTTP status code > 399, considered as Error
// HTTP status code > 399, considered as error
if res.IsError() {
// global error interface
if res.Request.Error == nil && c.Error != nil {
res.Request.Error = reflect.New(c.Error).Interface()
if res.Request.Error == nil && c.error != nil {
res.Request.Error = reflect.New(c.error).Interface()
}

if res.Request.Error != nil {
Expand All @@ -412,7 +412,7 @@ func handleMultipart(c *Client, r *Request) error {
r.bodyBuf = acquireBuffer()
w := multipart.NewWriter(r.bodyBuf)

for k, v := range c.FormData {
for k, v := range c.formData {
for _, iv := range v {
if err := w.WriteField(k, iv); err != nil {
return err
Expand Down Expand Up @@ -453,7 +453,7 @@ func handleMultipart(c *Client, r *Request) error {
}

func handleFormData(c *Client, r *Request) {
for k, v := range c.FormData {
for k, v := range c.formData {
if _, ok := r.FormData[k]; ok {
continue
}
Expand Down Expand Up @@ -501,7 +501,7 @@ func handleRequestBody(c *Client, r *Request) error {
if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
r.bodyBuf, err = jsonMarshal(c, r, r.Body)
} else if IsXMLType(contentType) && (kind == reflect.Struct) {
bodyBytes, err = c.XMLMarshal(r.Body)
bodyBytes, err = c.xmlMarshal(r.Body)
}
if err != nil {
return err
Expand Down
Loading

0 comments on commit 9f6bfc5

Please sign in to comment.