diff --git a/client.go b/client.go index adbdd92..53957ff 100644 --- a/client.go +++ b/client.go @@ -470,12 +470,14 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) { // (HTTP Code 429) is found in the resp parameter. Hence it will return the number of // seconds the server states it may be ready to process more requests from this client. func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { - if resp != nil { - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - if s, ok := resp.Header["Retry-After"]; ok { - if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil { - return time.Second * time.Duration(sleep) - } + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + if s, ok := resp.Header["Retry-After"]; ok { + if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil { + return time.Second * time.Duration(sleep) + } + + if after, err := time.Parse(time.RFC1123, s[0]); err == nil { + return after.Sub(time.Now()) } } } diff --git a/client_test.go b/client_test.go index 082b407..50ff8c3 100644 --- a/client_test.go +++ b/client_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/http" "net/http/httptest" @@ -524,38 +525,59 @@ func TestClient_CheckRetry(t *testing.T) { } func TestClient_DefaultBackoff(t *testing.T) { - for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { - t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Retry-After", "2") - http.Error(w, fmt.Sprintf("test_%d_body", code), code) - })) - defer ts.Close() - - client := NewClient() - - var retryAfter time.Duration - retryable := false - - client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { - retryable, _ = DefaultRetryPolicy(context.Background(), resp, err) - retryAfter = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) - return false, nil - } + type tcase struct { + name string + header string + } - _, err := client.Get(ts.URL) - if err != nil { - t.Fatalf("expected no errors since retryable") - } + cases := []tcase{ + { + name: "RFC1123_datetime", + header: time.Now().Add(2 * time.Second).Format(time.RFC1123), + }, + { + name: "numeric_duration", + header: "2", + }, + } - if !retryable { - t.Fatal("Since the error is recoverable, the default policy shall return true") - } + var header string - if retryAfter != 2*time.Second { - t.Fatalf("The header Retry-After specified 2 seconds, and shall not be %d seconds", retryAfter/time.Second) - } - }) + for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { + for _, tt := range cases { + t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", header) + http.Error(w, fmt.Sprintf("test_%d_body", code), code) + })) + defer ts.Close() + + client := NewClient() + + var retryAfter time.Duration + retryable := false + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + retryable, _ = DefaultRetryPolicy(context.Background(), resp, err) + retryAfter = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) + return false, nil + } + + header = tt.header + + _, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("expected no errors since retryable") + } + + if !retryable { + t.Fatal("Since the error is recoverable, the default policy shall return true") + } + + if math.Ceil(retryAfter.Seconds()) != 2 { + t.Fatalf("The header Retry-After specified 2 seconds, and shall not be %.0f seconds", math.Ceil(retryAfter.Seconds())) + } + }) + } } }