diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index dc1bf9eb2d3b..05da21ba8770 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,7 +6,12 @@ ### Breaking Changes > These changes affect only code written against a beta version such as v1.3.0-beta.4 +* Moved `NewWorkloadIdentityCredential()` parameters into `WorkloadIdentityCredentialOptions`. + The constructor now reads default configuration from environment variables set by the Azure + workload identity webhook by default. + ([#20478](https://github.com/Azure/azure-sdk-for-go/pull/20478)) * Removed CAE support. It will return in the next beta release. + ([#20479](https://github.com/Azure/azure-sdk-for-go/pull/20479)) ### Bugs Fixed * Fixed an issue in `DefaultAzureCredential` that could cause the managed identity endpoint check to fail in rare circumstances. diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index d9aac0b6ad78..91caa2c24350 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -393,8 +393,13 @@ func TestAdditionallyAllowedTenants(t *testing.T) { { name: credNameWorkloadIdentity, ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { - o := WorkloadIdentityCredentialOptions{AdditionallyAllowedTenants: test.allowed, ClientOptions: co} - return NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, af, &o) + return NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + AdditionallyAllowedTenants: test.allowed, + ClientID: fakeClientID, + ClientOptions: co, + TenantID: fakeTenantID, + TokenFilePath: af, + }) }, }, { diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 2bb606aa7f35..d947b1b96e88 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -79,36 +79,20 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default } // workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID - haveWorkloadConfig := false - clientID, haveClientID := os.LookupEnv(azureClientID) - if haveClientID { - if file, ok := os.LookupEnv(azureFederatedTokenFile); ok { - if _, ok := os.LookupEnv(azureAuthorityHost); ok { - if tenantID, ok := os.LookupEnv(azureTenantID); ok { - haveWorkloadConfig = true - workloadCred, err := NewWorkloadIdentityCredential(tenantID, clientID, file, &WorkloadIdentityCredentialOptions{ - AdditionallyAllowedTenants: additionalTenants, - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - }) - if err == nil { - creds = append(creds, workloadCred) - } else { - errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error()) - creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err}) - } - } - } - } - } - if !haveWorkloadConfig { - err := errors.New("missing environment variables for workload identity. Check webhook and pod configuration") + wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + AdditionallyAllowedTenants: additionalTenants, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + }) + if err == nil { + creds = append(creds, wic) + } else { + errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error()) creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err}) } - o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions} - if haveClientID { - o.ID = ClientID(clientID) + if ID, ok := os.LookupEnv(azureClientID); ok { + o.ID = ClientID(ID) } miCred, err := NewManagedIdentityCredential(o) if err == nil { diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index d4f0dc667b6e..76ce57f45494 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -8,6 +8,7 @@ package azidentity import ( "context" + "errors" "os" "sync" "time" @@ -37,16 +38,42 @@ type WorkloadIdentityCredentialOptions struct { // Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the // application is registered. AdditionallyAllowedTenants []string + // ClientID of the service principal. Defaults to the value of the environment variable AZURE_CLIENT_ID. + ClientID string // DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts. DisableInstanceDiscovery bool + // TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID. + TenantID string + // TokenFilePath is the path a file containing the workload identity token. Defaults to the value of the + // environment variable AZURE_FEDERATED_TOKEN_FILE. + TokenFilePath string } -// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. tenantID and clientID specify the identity the credential authenticates. -// file is a path to a file containing a Kubernetes service account token that authenticates the identity. -func NewWorkloadIdentityCredential(tenantID, clientID, file string, options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) { +// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read +// from environment variables as set by the Azure workload identity webhook. Set options to override those values. +func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) { if options == nil { options = &WorkloadIdentityCredentialOptions{} } + ok := false + clientID := options.ClientID + if clientID == "" { + if clientID, ok = os.LookupEnv(azureClientID); !ok { + return nil, errors.New("no client ID specified. Check pod configuration or set ClientID in the options") + } + } + file := options.TokenFilePath + if file == "" { + if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok { + return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options") + } + } + tenantID := options.TenantID + if tenantID == "" { + if tenantID, ok = os.LookupEnv(azureTenantID); !ok { + return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options") + } + } w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}} caco := ClientAssertionCredentialOptions{ AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index dccbe0d7eae4..121e60a0e389 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -23,7 +23,6 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" ) @@ -71,8 +70,13 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) { t.Run(name, func(t *testing.T) { co, stop := initRecording(t) defer stop() - o := WorkloadIdentityCredentialOptions{ClientOptions: co, DisableInstanceDiscovery: b} - cred, err := NewWorkloadIdentityCredential(liveSP.tenantID, liveSP.clientID, f, &o) + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: liveSP.clientID, + ClientOptions: co, + DisableInstanceDiscovery: b, + TenantID: liveSP.tenantID, + TokenFilePath: f, + }) if err != nil { t.Fatal(err) } @@ -86,7 +90,7 @@ func TestWorkloadIdentityCredential(t *testing.T) { if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) } - validateReq := func(req *http.Request) bool { + sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) { if err := req.ParseForm(); err != nil { t.Error(err) } @@ -103,18 +107,13 @@ func TestWorkloadIdentityCredential(t *testing.T) { if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID { t.Errorf(`unexpected tenant "%s"`, actual) } - return true - } - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - opts := WorkloadIdentityCredentialOptions{ - ClientOptions: policy.ClientOptions{Transport: srv}, - } - cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts) + }} + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: policy.ClientOptions{Transport: &sts}, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + }) if err != nil { t.Fatal(err) } @@ -124,7 +123,7 @@ func TestWorkloadIdentityCredential(t *testing.T) { func TestWorkloadIdentityCredential_Expiration(t *testing.T) { tokenReqs := 0 tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") - validateReq := func(req *http.Request) bool { + sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) { if err := req.ParseForm(); err != nil { t.Error(err) } @@ -134,20 +133,13 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) { t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0]) } tokenReqs++ - return true - } - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - opts := WorkloadIdentityCredentialOptions{ - ClientOptions: policy.ClientOptions{Transport: srv}, - } - cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts) + }} + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: policy.ClientOptions{Transport: &sts}, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + }) if err != nil { t.Fatal(err) } @@ -167,3 +159,82 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) { t.Fatalf("expected 2 token requests, got %d", tokenReqs) } } + +func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) { + f := filepath.Join(t.TempDir(), t.Name()) + for _, env := range []map[string]string{ + {}, + + {azureClientID: fakeClientID}, + {azureFederatedTokenFile: f}, + {azureTenantID: fakeTenantID}, + + {azureClientID: fakeClientID, azureTenantID: fakeTenantID}, + {azureClientID: fakeClientID, azureFederatedTokenFile: f}, + {azureTenantID: fakeTenantID, azureFederatedTokenFile: f}, + } { + t.Run("", func(t *testing.T) { + for k, v := range env { + t.Setenv(k, v) + } + if _, err := NewWorkloadIdentityCredential(nil); err == nil { + t.Fatal("expected an error") + } + }) + } +} + +func TestWorkloadIdentityCredential_Options(t *testing.T) { + clientID := "not-" + fakeClientID + tenantID := "not-" + fakeTenantID + wrongFile := filepath.Join(t.TempDir(), "wrong") + rightFile := filepath.Join(t.TempDir(), "right") + if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil { + t.Fatal(err) + } + sts := mockSTS{ + tenant: tenantID, + tokenRequestCallback: func(req *http.Request) { + if err := req.ParseForm(); err != nil { + t.Error(err) + } + if actual, ok := req.PostForm["client_assertion"]; !ok { + t.Error("expected a client_assertion") + } else if len(actual) != 1 || actual[0] != tokenValue { + t.Errorf(`unexpected assertion "%s"`, actual[0]) + } + if actual, ok := req.PostForm["client_id"]; !ok { + t.Error("expected a client_id") + } else if len(actual) != 1 || actual[0] != clientID { + t.Errorf(`unexpected assertion "%s"`, actual[0]) + } + if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID { + t.Errorf(`unexpected tenant "%s"`, actual) + } + }, + } + // options should override environment variables + for k, v := range map[string]string{ + azureClientID: fakeClientID, + azureFederatedTokenFile: wrongFile, + azureTenantID: fakeTenantID, + } { + t.Setenv(k, v) + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: clientID, + ClientOptions: policy.ClientOptions{Transport: &sts}, + TenantID: tenantID, + TokenFilePath: rightFile, + }) + if err != nil { + t.Fatal(err) + } + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) + if err != nil { + t.Fatal(err) + } + if tk.Token != tokenValue { + t.Fatalf("unexpected token %q", tk.Token) + } +}