diff --git a/client.go b/client.go index c9edbd0..3ea4e85 100644 --- a/client.go +++ b/client.go @@ -73,6 +73,10 @@ var ( // specifically so we resort to matching on the error string. schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) + // timeNow sets the function that returns the current time. + // This defaults to time.Now. Changes to this should only be done in tests. + timeNow = time.Now + // A regular expression to match the error returned by net/http when the // TLS certificate is not trusted. This error isn't typed // specifically so we resort to matching on the error string. @@ -535,10 +539,8 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) { func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { if resp != nil { if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - if s, ok := resp.Header["Retry-After"]; ok { - if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil { - return time.Second * time.Duration(sleep) - } + if sleep, ok := parseRetryAfterHeader(resp.Header["Retry-After"]); ok { + return sleep } } } @@ -551,6 +553,41 @@ func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) return sleep } +// parseRetryAfterHeader parses the Retry-After header and returns the +// delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after +// The bool returned will be true if the header was successfully parsed. +// Otherwise, the header was either not present, or was not parseable according to the spec. +// +// Retry-After headers come in two flavors: Seconds or HTTP-Date +// +// Examples: +// * Retry-After: Fri, 31 Dec 1999 23:59:59 GMT +// * Retry-After: 120 +func parseRetryAfterHeader(headers []string) (time.Duration, bool) { + if len(headers) == 0 || headers[0] == "" { + return 0, false + } + header := headers[0] + // Retry-After: 120 + if sleep, err := strconv.ParseInt(header, 10, 64); err == nil { + if sleep < 0 { // a negative sleep doesn't make sense + return 0, false + } + return time.Second * time.Duration(sleep), true + } + + // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT + retryTime, err := time.Parse(time.RFC1123, header) + if err != nil { + return 0, false + } + if until := retryTime.Sub(timeNow()); until > 0 { + return until, true + } + // date is in the past + return 0, true +} + // LinearJitterBackoff provides a callback for Client.Backoff which will // perform linear backoff based on the attempt number and with jitter to // prevent a thundering herd. diff --git a/client_test.go b/client_test.go index c5e98a5..0c0c13b 100644 --- a/client_test.go +++ b/client_test.go @@ -655,12 +655,67 @@ func TestClient_CheckRetry(t *testing.T) { } } +func testStaticTime(t *testing.T) { + timeNow = func() time.Time { + now, err := time.Parse(time.RFC1123, "Fri, 31 Dec 1999 23:59:57 GMT") + if err != nil { + panic(err) + } + return now + } + t.Cleanup(func() { + timeNow = time.Now + }) +} + +func TestParseRetryAfterHeader(t *testing.T) { + testStaticTime(t) + tests := []struct { + name string + headers []string + sleep time.Duration + ok bool + }{ + {"seconds", []string{"2"}, time.Second * 2, true}, + {"date", []string{"Fri, 31 Dec 1999 23:59:59 GMT"}, time.Second * 2, true}, + {"past-date", []string{"Fri, 31 Dec 1999 23:59:00 GMT"}, 0, true}, + {"nil", nil, 0, false}, + {"two-headers", []string{"2", "3"}, time.Second * 2, true}, + {"empty", []string{""}, 0, false}, + {"negative", []string{"-2"}, 0, false}, + {"bad-date", []string{"Fri, 32 Dec 1999 23:59:59 GMT"}, 0, false}, + {"bad-date-format", []string{"badbadbad"}, 0, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sleep, ok := parseRetryAfterHeader(test.headers) + if ok != test.ok { + t.Fatalf("expected ok=%t, got ok=%t", test.ok, ok) + } + if sleep != test.sleep { + t.Fatalf("expected sleep=%v, got sleep=%v", test.sleep, sleep) + } + }) + } +} + func TestClient_DefaultBackoff(t *testing.T) { - for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { - t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { + testStaticTime(t) + tests := []struct { + name string + code int + retryHeader string + }{ + {"http_429_seconds", http.StatusTooManyRequests, "2"}, + {"http_429_date", http.StatusTooManyRequests, "Fri, 31 Dec 1999 23:59:59 GMT"}, + {"http_503_seconds", http.StatusServiceUnavailable, "2"}, + {"http_503_date", http.StatusServiceUnavailable, "Fri, 31 Dec 1999 23:59:59 GMT"}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Retry-After", "2") - http.Error(w, fmt.Sprintf("test_%d_body", code), code) + w.Header().Set("Retry-After", test.retryHeader) + http.Error(w, fmt.Sprintf("test_%d_body", test.code), test.code) })) defer ts.Close()