Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workload identity credential defaults to environment configuration #20478

Merged
merged 3 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
### Features Added

### Breaking Changes
> These changes affect only code written against a beta version such as v1.3.0-beta.4
* Moved `NewWorkloadIdentityCredential()` parameters into `WorkloadIdentityCredentialOptions`.
The constructor now reads default configuration from environment variables set by the Azure
workload identity webhook by default.

### Bugs Fixed
* Fixed an issue in `DefaultAzureCredential` that could cause the managed identity endpoint check to fail in rare circumstances.
Expand Down
9 changes: 7 additions & 2 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,13 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
{
name: credNameWorkloadIdentity,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := WorkloadIdentityCredentialOptions{AdditionallyAllowedTenants: test.allowed, ClientOptions: co}
return NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, af, &o)
return NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: test.allowed,
ClientID: fakeClientID,
ClientOptions: co,
TenantID: fakeTenantID,
TokenFilePath: af,
})
},
},
{
Expand Down
38 changes: 11 additions & 27 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,36 +79,20 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
}

// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
haveWorkloadConfig := false
clientID, haveClientID := os.LookupEnv(azureClientID)
if haveClientID {
if file, ok := os.LookupEnv(azureFederatedTokenFile); ok {
if _, ok := os.LookupEnv(azureAuthorityHost); ok {
if tenantID, ok := os.LookupEnv(azureTenantID); ok {
haveWorkloadConfig = true
workloadCred, err := NewWorkloadIdentityCredential(tenantID, clientID, file, &WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: additionalTenants,
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
})
if err == nil {
creds = append(creds, workloadCred)
} else {
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}
}
}
}
}
if !haveWorkloadConfig {
err := errors.New("missing environment variables for workload identity. Check webhook and pod configuration")
wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: additionalTenants,
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
})
if err == nil {
creds = append(creds, wic)
} else {
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}

o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
if haveClientID {
o.ID = ClientID(clientID)
if ID, ok := os.LookupEnv(azureClientID); ok {
o.ID = ClientID(ID)
}
miCred, err := NewManagedIdentityCredential(o)
if err == nil {
Expand Down
33 changes: 30 additions & 3 deletions sdk/azidentity/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azidentity

import (
"context"
"errors"
"os"
"sync"
"time"
Expand Down Expand Up @@ -37,16 +38,42 @@ type WorkloadIdentityCredentialOptions struct {
// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the
// application is registered.
AdditionallyAllowedTenants []string
// ClientID of the service principal. Defaults to the value of the environment variable AZURE_CLIENT_ID.
ClientID string
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
DisableInstanceDiscovery bool
// TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID.
TenantID string
// TokenFilePath is the path a file containing the workload identity token. Defaults to the value of the
// environment variable AZURE_FEDERATED_TOKEN_FILE.
TokenFilePath string
}

// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. tenantID and clientID specify the identity the credential authenticates.
// file is a path to a file containing a Kubernetes service account token that authenticates the identity.
func NewWorkloadIdentityCredential(tenantID, clientID, file string, options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read
// from environment variables as set by the Azure workload identity webhook. Set options to override those values.
func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
if options == nil {
options = &WorkloadIdentityCredentialOptions{}
}
ok := false
clientID := options.ClientID
if clientID == "" {
if clientID, ok = os.LookupEnv(azureClientID); !ok {
return nil, errors.New("no client ID specified. Check pod configuration or set ClientID in the options")
}
}
file := options.TokenFilePath
if file == "" {
if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok {
return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options")
}
}
tenantID := options.TenantID
if tenantID == "" {
if tenantID, ok = os.LookupEnv(azureTenantID); !ok {
return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options")
}
}
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
caco := ClientAssertionCredentialOptions{
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
Expand Down
133 changes: 102 additions & 31 deletions sdk/azidentity/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
Expand Down Expand Up @@ -71,8 +70,13 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) {
t.Run(name, func(t *testing.T) {
co, stop := initRecording(t)
defer stop()
o := WorkloadIdentityCredentialOptions{ClientOptions: co, DisableInstanceDiscovery: b}
cred, err := NewWorkloadIdentityCredential(liveSP.tenantID, liveSP.clientID, f, &o)
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: liveSP.clientID,
ClientOptions: co,
DisableInstanceDiscovery: b,
TenantID: liveSP.tenantID,
TokenFilePath: f,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -86,7 +90,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
validateReq := func(req *http.Request) bool {
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
Expand All @@ -103,18 +107,13 @@ func TestWorkloadIdentityCredential(t *testing.T) {
if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
return true
}
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
opts := WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
}
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -124,7 +123,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
tokenReqs := 0
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
validateReq := func(req *http.Request) bool {
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
Expand All @@ -134,20 +133,13 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0])
}
tokenReqs++
return true
}
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
opts := WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
}
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -167,3 +159,82 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
t.Fatalf("expected 2 token requests, got %d", tokenReqs)
}
}

func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
f := filepath.Join(t.TempDir(), t.Name())
for _, env := range []map[string]string{
{},

{azureClientID: fakeClientID},
{azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID},

{azureClientID: fakeClientID, azureTenantID: fakeTenantID},
{azureClientID: fakeClientID, azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID, azureFederatedTokenFile: f},
} {
t.Run("", func(t *testing.T) {
for k, v := range env {
t.Setenv(k, v)
}
if _, err := NewWorkloadIdentityCredential(nil); err == nil {
t.Fatal("expected an error")
}
})
}
}

func TestWorkloadIdentityCredential_Options(t *testing.T) {
clientID := "not-" + fakeClientID
tenantID := "not-" + fakeTenantID
wrongFile := filepath.Join(t.TempDir(), "wrong")
rightFile := filepath.Join(t.TempDir(), "right")
if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatal(err)
}
sts := mockSTS{
tenant: tenantID,
tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
if actual, ok := req.PostForm["client_assertion"]; !ok {
t.Error("expected a client_assertion")
} else if len(actual) != 1 || actual[0] != tokenValue {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual, ok := req.PostForm["client_id"]; !ok {
t.Error("expected a client_id")
} else if len(actual) != 1 || actual[0] != clientID {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
},
}
// options should override environment variables
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: wrongFile,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: clientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: tenantID,
TokenFilePath: rightFile,
})
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf("unexpected token %q", tk.Token)
}
}