diff --git a/client.go b/client.go index a6256345..c66bcb8b 100644 --- a/client.go +++ b/client.go @@ -784,14 +784,14 @@ func (c *Client) execute(req *Request) (*Response, error) { // to modify the *resty.Request object for _, f := range c.udBeforeRequest { if err = f(c, req); err != nil { - return nil, err + return nil, wrapNoRetryErr(err) } } // resty middlewares for _, f := range c.beforeRequest { if err = f(c, req); err != nil { - return nil, err + return nil, wrapNoRetryErr(err) } } @@ -802,12 +802,12 @@ func (c *Client) execute(req *Request) (*Response, error) { // call pre-request if defined if c.preReqHook != nil { if err = c.preReqHook(c, req.RawRequest); err != nil { - return nil, err + return nil, wrapNoRetryErr(err) } } if err = requestLogger(c, req); err != nil { - return nil, err + return nil, wrapNoRetryErr(err) } req.Time = time.Now() @@ -855,7 +855,7 @@ func (c *Client) execute(req *Request) (*Response, error) { } } - return response, err + return response, wrapNoRetryErr(err) } // getting TLS client config if not exists then create one diff --git a/request.go b/request.go index edaed15a..776a26e5 100644 --- a/request.go +++ b/request.go @@ -633,6 +633,7 @@ func (r *Request) Send() (*Response, error) { // resp, err := client.R().Execute(resty.GET, "http://httpbin.org/get") func (r *Request) Execute(method, url string) (*Response, error) { var addrs []*net.SRV + var resp *Response var err error if r.isMultiPart && !(method == MethodPost || method == MethodPut || method == MethodPatch) { @@ -650,10 +651,10 @@ func (r *Request) Execute(method, url string) (*Response, error) { r.URL = r.selectAddr(addrs, url, 0) if r.client.RetryCount == 0 { - return r.client.execute(r) + resp, err = r.client.execute(r) + return resp, unwrapNoRetryErr(err) } - var resp *Response attempt := 0 err = Backoff( func() (*Response, error) { @@ -674,7 +675,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { RetryConditions(r.client.RetryConditions), ) - return resp, err + return resp, unwrapNoRetryErr(err) } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ diff --git a/request_test.go b/request_test.go index 25a478d8..3bf1f169 100644 --- a/request_test.go +++ b/request_test.go @@ -315,8 +315,14 @@ func TestForceContentTypeForGH276andGH240(t *testing.T) { ts := createPostServer(t) defer ts.Close() + retried := 0 c := dc() c.SetDebug(false) + c.SetRetryCount(3) + c.SetRetryAfter(RetryAfterFunc(func(*Client, *Response) (time.Duration, error) { + retried++ + return 0, nil + })) resp, err := c.R(). SetBody(map[string]interface{}{"username": "testuser", "password": "testpass"}). @@ -326,6 +332,7 @@ func TestForceContentTypeForGH276andGH240(t *testing.T) { assertNotNil(t, err) // expecting error due to incorrect content type from server end assertEqual(t, http.StatusOK, resp.StatusCode()) + assertEqual(t, 0, retried) t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) @@ -524,7 +531,7 @@ func TestRequestAuthScheme(t *testing.T) { resp, err := c.R(). SetAuthScheme("Bearer"). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF-Request"). - Get(ts.URL + "/profile") + Get(ts.URL + "/profile") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) diff --git a/retry.go b/retry.go index fa97315a..0b7c6ffe 100644 --- a/retry.go +++ b/retry.go @@ -99,10 +99,11 @@ func Backoff(operation func() (*Response, error), options ...Option) error { return err } - needsRetry := err != nil // retry on operation errors by default + err1 := unwrapNoRetryErr(err) // raw error, it used for return users callback. + needsRetry := err != nil && err == err1 // retry on a few operation errors by default for _, condition := range opts.retryConditions { - needsRetry = condition(resp, err) + needsRetry = condition(resp, err1) if needsRetry { break } diff --git a/util.go b/util.go index 6f71dba6..a247be46 100644 --- a/util.go +++ b/util.go @@ -331,3 +331,25 @@ func copyHeaders(hdrs http.Header) http.Header { } return nh } + +type noRetryErr struct { + err error +} + +func (e *noRetryErr) Error() string { + return e.err.Error() +} + +func wrapNoRetryErr(err error) error { + if err != nil { + err = &noRetryErr{err: err} + } + return err +} + +func unwrapNoRetryErr(err error) error { + if e, ok := err.(*noRetryErr); ok { + err = e.err + } + return err +}