Skip to content

Commit

Permalink
Merge pull request #370 from bingosummer/retrablehttpclient
Browse files Browse the repository at this point in the history
Make retrable http client's retry count configurable in Azure auth provider
  • Loading branch information
weinong authored Jul 31, 2023
2 parents 540dd48 + 5225666 commit 68273e5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
26 changes: 13 additions & 13 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {
c := &Authenticator{
Options: opts,
}
authInfoVal, err := getAuthInfo(ctx, c.Environment, c.TenantID, getMetadata)
authInfoVal, err := getAuthInfo(ctx, c.Environment, c.TenantID, c.HttpClientRetryCount, getMetadata)
if err != nil {
return nil, err
}

klog.V(3).Infof("Using issuer url: %v", authInfoVal.Issuer)

ctx = withRetryableHttpClient(ctx)
ctx = withRetryableHttpClient(ctx, c.HttpClientRetryCount)
provider, err := oidc.NewProvider(ctx, authInfoVal.Issuer)
if err != nil {
return nil, errors.Wrap(err, "failed to create provider for azure")
Expand Down Expand Up @@ -118,8 +118,8 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {
}

// makeRetryableHttpClient creates an HTTP client which attempts the request
// 3 times and has a 3 second timeout per attempt.
func makeRetryableHttpClient() retryablehttp.Client {
// (1 + retryCount) times and has a 3 second timeout per attempt.
func makeRetryableHttpClient(retryCount int) retryablehttp.Client {
// Copy the default HTTP client so we can set a timeout.
// (It uses the same transport since the pointer gets copied)
httpClient := *httpclient.DefaultHTTPClient
Expand All @@ -130,7 +130,7 @@ func makeRetryableHttpClient() retryablehttp.Client {
HTTPClient: &httpClient,
RetryWaitMin: 500 * time.Millisecond,
RetryWaitMax: 2 * time.Second,
RetryMax: 2, // initial + 2 retries = 3 attempts
RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
Logger: log.Default(),
Expand All @@ -141,8 +141,8 @@ func makeRetryableHttpClient() retryablehttp.Client {
// *http.Client made from makeRetryableHttpClient.
// Some of the libraries we use will take the client out of the context via
// oauth2.HTTPClient and use it, so this way we can add retries to external code.
func withRetryableHttpClient(ctx context.Context) context.Context {
retryClient := makeRetryableHttpClient()
func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
retryClient := makeRetryableHttpClient(retryCount)
return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient())
}

Expand All @@ -152,9 +152,9 @@ type metadataJSON struct {
}

// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string) (*metadataJSON, error) {
func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient()
retryClient := makeRetryableHttpClient(retryCount)

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
Expand Down Expand Up @@ -198,7 +198,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf
}
}

ctx = withRetryableHttpClient(ctx)
ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify token for azure")
Expand Down Expand Up @@ -365,9 +365,9 @@ func (c claims) string(key string) (string, error) {
return s, nil
}

type getMetadataFunc = func(context.Context, string, string) (*metadataJSON, error)
type getMetadataFunc = func(context.Context, string, string, int) (*metadataJSON, error)

func getAuthInfo(ctx context.Context, environment, tenantID string, getMetadata getMetadataFunc) (*authInfo, error) {
func getAuthInfo(ctx context.Context, environment, tenantID string, retryCount int, getMetadata getMetadataFunc) (*authInfo, error) {
var err error
env := azure.PublicCloud
if environment != "" {
Expand All @@ -377,7 +377,7 @@ func getAuthInfo(ctx context.Context, environment, tenantID string, getMetadata
}
}

metadata, err := getMetadata(ctx, env.ActiveDirectoryEndpoint, tenantID)
metadata, err := getMetadata(ctx, env.ActiveDirectoryEndpoint, tenantID, retryCount)
if err != nil {
return nil, errors.Wrap(err, "failed to get metadata for azure")
}
Expand Down
36 changes: 19 additions & 17 deletions auth/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const (
accessTokenWithNoGroups = `{ "aud": "client_id", "iss" : "%v", "oid": "abc-123d4" }`
accessTokenWithoutOverageClaim = `{ "aud": "client_id", "iss" : "%v", "upn": "nahid", "_claim_names": {"foo": "src1"}, "_claim_sources": {"src1": {"endpoint": "https://foobar" }} }`
badToken = "bad_token"
httpClientRetryCount = 2
)

type signingKey struct {
Expand Down Expand Up @@ -103,15 +104,16 @@ func newRSAKey() (*signingKey, error) {
func clientSetup(clientID, clientSecret, tenantID, serverUrl string, useGroupUID, verifyClientID bool, authMode string) (*Authenticator, error) {
c := &Authenticator{
Options: Options{
Environment: "",
ClientID: clientID,
ClientSecret: clientSecret,
TenantID: tenantID,
UseGroupUID: useGroupUID,
AuthMode: ClientCredentialAuthMode,
AKSTokenURL: "",
VerifyClientID: verifyClientID,
AzureRegion: "eastus",
Environment: "",
ClientID: clientID,
ClientSecret: clientSecret,
TenantID: tenantID,
UseGroupUID: useGroupUID,
AuthMode: ClientCredentialAuthMode,
AKSTokenURL: "",
VerifyClientID: verifyClientID,
AzureRegion: "eastus",
HttpClientRetryCount: httpClientRetryCount,
},
}

Expand Down Expand Up @@ -545,16 +547,16 @@ func TestString(t *testing.T) {

func TestGetAuthInfo(t *testing.T) {
ctx := context.Background()
authInfo, err := getAuthInfo(ctx, "AzurePublicCloud", "testTenant", localGetMetadata)
authInfo, err := getAuthInfo(ctx, "AzurePublicCloud", "testTenant", httpClientRetryCount, localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.microsoftonline.com")

authInfo, err = getAuthInfo(ctx, "AzureChinaCloud", "testTenant", localGetMetadata)
authInfo, err = getAuthInfo(ctx, "AzureChinaCloud", "testTenant", httpClientRetryCount, localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.chinacloudapi.cn")
}

func localGetMetadata(context.Context, string, string) (*metadataJSON, error) {
func localGetMetadata(context.Context, string, string, int) (*metadataJSON, error) {
return &metadataJSON{
Issuer: "testIssuer",
MsgraphHost: "testHost",
Expand All @@ -570,9 +572,9 @@ func TestGetMetadata(t *testing.T) {
_, _ = writer.Write([]byte(`{"issuer":"testIssuer","msgraph_host":"testHost"}`))
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata(ctx, "", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "", httpClientRetryCount)

metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant", httpClientRetryCount)
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -599,9 +601,9 @@ func TestGetMetadata(t *testing.T) {
}
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata(ctx, "", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "", httpClientRetryCount)

metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant", httpClientRetryCount)
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -612,7 +614,7 @@ func TestGetMetadata(t *testing.T) {
testServer.CloseClientConnections()
}))

metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant", httpClientRetryCount)
assert.Error(t, err)
assert.Nil(t, metadata)
})
Expand Down
4 changes: 4 additions & 0 deletions auth/providers/azure/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type Options struct {
VerifyClientID bool
ResourceId string
AzureRegion string
HttpClientRetryCount int
}

func NewOptions() Options {
Expand All @@ -81,6 +82,7 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) {
// resource id and region are needed to retrieve user's security group info via Arc obo service
fs.StringVar(&o.ResourceId, "azure.auth-resource-id", "", "azure cluster resource id (//subscription/<subName>/resourcegroups/<RGname>/providers/Microsoft.Kubernetes/connectedClusters/<clustername> for connectedk8s) used for making getMemberGroups to ARC OBO service")
fs.StringVar(&o.AzureRegion, "azure.region", "", "region where cluster is deployed")
fs.IntVar(&o.HttpClientRetryCount, "azure.http-client-retry-count", 2, "number of retries for retryablehttp client")
}

func (o *Options) Validate() []error {
Expand Down Expand Up @@ -222,6 +224,8 @@ func (o Options) Apply(d *apps.Deployment) (extraObjs []runtime.Object, err erro

args = append(args, fmt.Sprintf("--azure.verify-clientID=%t", o.VerifyClientID))

args = append(args, fmt.Sprintf("--azure.http-client-retry-count=%d", o.HttpClientRetryCount))

container.Args = args
d.Spec.Template.Spec.Containers[0] = container

Expand Down

0 comments on commit 68273e5

Please sign in to comment.