diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 08b526520..cbc3587bd 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -3,17 +3,13 @@ package cli import ( "context" "fmt" - "net/http" - "strings" - "github.com/auth0/go-auth0/management" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/auth0/auth0-cli/internal/analytics" "github.com/auth0/auth0-cli/internal/ansi" "github.com/auth0/auth0-cli/internal/auth0" - "github.com/auth0/auth0-cli/internal/buildinfo" "github.com/auth0/auth0-cli/internal/config" "github.com/auth0/auth0-cli/internal/display" "github.com/auth0/auth0-cli/internal/iostream" @@ -108,15 +104,7 @@ func (c *cli) setupWithAuthentication(ctx context.Context) error { } } - userAgent := fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v")) - - api, err := management.New( - tenant.Domain, - management.WithStaticToken(tenant.GetAccessToken()), - management.WithUserAgent(userAgent), - management.WithAuth0ClientEnvEntry("Auth0-CLI", strings.TrimPrefix(buildinfo.Version, "v")), - management.WithRetries(5, []int{http.StatusTooManyRequests, http.StatusInternalServerError}), - ) + api, err := initializeManagementClient(tenant.Domain, tenant.GetAccessToken()) if err != nil { return err } diff --git a/internal/cli/management.go b/internal/cli/management.go new file mode 100644 index 000000000..5f0286c8f --- /dev/null +++ b/internal/cli/management.go @@ -0,0 +1,122 @@ +package cli + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/PuerkitoBio/rehttp" + "github.com/auth0/go-auth0/management" + + "github.com/auth0/auth0-cli/internal/buildinfo" +) + +func initializeManagementClient(tenantDomain string, accessToken string) (*management.Management, error) { + client, err := management.New( + tenantDomain, + management.WithStaticToken(accessToken), + management.WithUserAgent(fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))), + management.WithAuth0ClientEnvEntry("Auth0-CLI", strings.TrimPrefix(buildinfo.Version, "v")), + management.WithNoRetries(), + management.WithClient(customClientWithRetries()), + ) + + return client, err +} + +func customClientWithRetries() *http.Client { + client := &http.Client{ + Transport: rateLimitTransport( + retryableErrorTransport( + http.DefaultTransport, + ), + ), + } + + return client +} + +func rateLimitTransport(tripper http.RoundTripper) http.RoundTripper { + return rehttp.NewTransport(tripper, rateLimitRetry, rateLimitDelay) +} + +func rateLimitRetry(attempt rehttp.Attempt) bool { + if attempt.Response == nil { + return false + } + + return attempt.Response.StatusCode == http.StatusTooManyRequests +} + +func rateLimitDelay(attempt rehttp.Attempt) time.Duration { + resetAt := attempt.Response.Header.Get("X-RateLimit-Reset") + + resetAtUnix, err := strconv.ParseInt(resetAt, 10, 64) + if err != nil { + resetAtUnix = time.Now().Add(5 * time.Second).Unix() + } + + return time.Duration(resetAtUnix-time.Now().Unix()) * time.Second +} + +func retryableErrorTransport(tripper http.RoundTripper) http.RoundTripper { + retryableCodes := []int{ + http.StatusServiceUnavailable, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + // Cloudflare-specific server error that is generated + // because Cloudflare did not receive an HTTP response + // from the origin server after an HTTP Connection was made. + 524, + } + + return rehttp.NewTransport( + tripper, + rehttp.RetryAll( + rehttp.RetryMaxRetries(3), + rehttp.RetryAny( + rehttp.RetryStatuses(retryableCodes...), + rehttp.RetryIsErr(retryableErrorRetryFunc), + ), + ), + rehttp.ExpJitterDelay(500*time.Millisecond, 10*time.Second), + ) +} + +func retryableErrorRetryFunc(err error) bool { + if err == nil { + return false + } + + if v, ok := err.(*url.Error); ok { + // Don't retry if the error was due to too many redirects. + if regexp.MustCompile(`stopped after \d+ redirects\z`).MatchString(v.Error()) { + return false + } + + // Don't retry if the error was due to an invalid protocol scheme. + if regexp.MustCompile(`unsupported protocol scheme`).MatchString(v.Error()) { + return false + } + + // Don't retry if the certificate issuer is unknown. + if _, ok := v.Err.(*tls.CertificateVerificationError); ok { + return false + } + + // Don't retry if the certificate issuer is unknown. + if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + return false + } + } + + // The error is likely recoverable so retry. + return true +} diff --git a/internal/cli/management_test.go b/internal/cli/management_test.go new file mode 100644 index 000000000..2342eb18f --- /dev/null +++ b/internal/cli/management_test.go @@ -0,0 +1,175 @@ +package cli + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCustomClientWithRetries(t *testing.T) { + t.Run("it retries on rate limit error", func(t *testing.T) { + apiCalls := 0 + fail := true + testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + apiCalls++ + + if fail { + fail = false + writer.WriteHeader(429) + resetAt := time.Now().Add(time.Second).Unix() + writer.Header().Set("X-RateLimit-Reset", strconv.Itoa(int(resetAt))) + return + } + + writer.WriteHeader(200) + })) + + client := customClientWithRetries() + + request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + + response, err := client.Do(request) + require.NoError(t, err) + + assert.Equal(t, 200, response.StatusCode) + assert.False(t, fail) + assert.Equal(t, 2, apiCalls) + + t.Cleanup(func() { + testServer.Close() + err := response.Body.Close() + require.NoError(t, err) + }) + }) + + t.Run("it retries on server error", func(t *testing.T) { + apiCalls := 0 + fail := true + testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + apiCalls++ + + if fail { + fail = false + writer.WriteHeader(500) + return + } + + writer.WriteHeader(200) + })) + + client := customClientWithRetries() + + request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + + response, err := client.Do(request) + require.NoError(t, err) + + assert.Equal(t, 200, response.StatusCode) + assert.False(t, fail) + assert.Equal(t, 2, apiCalls) + + t.Cleanup(func() { + testServer.Close() + err := response.Body.Close() + require.NoError(t, err) + }) + }) + + t.Run("it does not retry more than 3 times on server error", func(t *testing.T) { + apiCalls := 0 + testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + apiCalls++ + writer.WriteHeader(500) + })) + + client := customClientWithRetries() + + request, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + + response, err := client.Do(request) + require.NoError(t, err) + + assert.Equal(t, 500, response.StatusCode) + assert.Equal(t, 3+1, apiCalls) // 3 retries + 1 first call. + + t.Cleanup(func() { + testServer.Close() + err := response.Body.Close() + require.NoError(t, err) + }) + }) +} + +func TestRetryableErrorRetryFunc(t *testing.T) { + testCases := []struct { + name string + err error + expected bool + }{ + { + name: "NilError", + err: nil, + expected: false, + }, + { + name: "TooManyRedirectsError", + err: &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: errors.New("stopped after 5 redirects"), + }, + expected: false, + }, + { + name: "UnsupportedProtocolSchemeError", + err: &url.Error{ + Op: "Get", + URL: "ftp://example.com", + Err: errors.New("unsupported protocol scheme"), + }, + expected: false, + }, + { + name: "CertificateVerificationError", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{}, + }, + expected: false, + }, + { + name: "UnknownAuthorityError", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: x509.UnknownAuthorityError{}, + }, + expected: false, + }, + { + name: "OtherError", + err: errors.New("some other error"), + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual := retryableErrorRetryFunc(testCase.err) + assert.Equal(t, testCase.expected, actual) + }) + } +}