Skip to content

Commit

Permalink
Workload identity credential defaults to environment configuration (A…
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Mar 29, 2023
1 parent 9b0eb84 commit 1cd56ad
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 63 deletions.
5 changes: 5 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

### Breaking Changes
> These changes affect only code written against a beta version such as v1.3.0-beta.4
* Moved `NewWorkloadIdentityCredential()` parameters into `WorkloadIdentityCredentialOptions`.
The constructor now reads default configuration from environment variables set by the Azure
workload identity webhook by default.
([#20478](https://github.com/Azure/azure-sdk-for-go/pull/20478))
* Removed CAE support. It will return in the next beta release.
([#20479](https://github.com/Azure/azure-sdk-for-go/pull/20479))

### Bugs Fixed
* Fixed an issue in `DefaultAzureCredential` that could cause the managed identity endpoint check to fail in rare circumstances.
Expand Down
9 changes: 7 additions & 2 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,13 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
{
name: credNameWorkloadIdentity,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := WorkloadIdentityCredentialOptions{AdditionallyAllowedTenants: test.allowed, ClientOptions: co}
return NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, af, &o)
return NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: test.allowed,
ClientID: fakeClientID,
ClientOptions: co,
TenantID: fakeTenantID,
TokenFilePath: af,
})
},
},
{
Expand Down
38 changes: 11 additions & 27 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,36 +79,20 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
}

// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
haveWorkloadConfig := false
clientID, haveClientID := os.LookupEnv(azureClientID)
if haveClientID {
if file, ok := os.LookupEnv(azureFederatedTokenFile); ok {
if _, ok := os.LookupEnv(azureAuthorityHost); ok {
if tenantID, ok := os.LookupEnv(azureTenantID); ok {
haveWorkloadConfig = true
workloadCred, err := NewWorkloadIdentityCredential(tenantID, clientID, file, &WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: additionalTenants,
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
})
if err == nil {
creds = append(creds, workloadCred)
} else {
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}
}
}
}
}
if !haveWorkloadConfig {
err := errors.New("missing environment variables for workload identity. Check webhook and pod configuration")
wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: additionalTenants,
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
})
if err == nil {
creds = append(creds, wic)
} else {
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}

o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
if haveClientID {
o.ID = ClientID(clientID)
if ID, ok := os.LookupEnv(azureClientID); ok {
o.ID = ClientID(ID)
}
miCred, err := NewManagedIdentityCredential(o)
if err == nil {
Expand Down
33 changes: 30 additions & 3 deletions sdk/azidentity/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azidentity

import (
"context"
"errors"
"os"
"sync"
"time"
Expand Down Expand Up @@ -37,16 +38,42 @@ type WorkloadIdentityCredentialOptions struct {
// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the
// application is registered.
AdditionallyAllowedTenants []string
// ClientID of the service principal. Defaults to the value of the environment variable AZURE_CLIENT_ID.
ClientID string
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
DisableInstanceDiscovery bool
// TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID.
TenantID string
// TokenFilePath is the path a file containing the workload identity token. Defaults to the value of the
// environment variable AZURE_FEDERATED_TOKEN_FILE.
TokenFilePath string
}

// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. tenantID and clientID specify the identity the credential authenticates.
// file is a path to a file containing a Kubernetes service account token that authenticates the identity.
func NewWorkloadIdentityCredential(tenantID, clientID, file string, options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read
// from environment variables as set by the Azure workload identity webhook. Set options to override those values.
func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
if options == nil {
options = &WorkloadIdentityCredentialOptions{}
}
ok := false
clientID := options.ClientID
if clientID == "" {
if clientID, ok = os.LookupEnv(azureClientID); !ok {
return nil, errors.New("no client ID specified. Check pod configuration or set ClientID in the options")
}
}
file := options.TokenFilePath
if file == "" {
if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok {
return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options")
}
}
tenantID := options.TenantID
if tenantID == "" {
if tenantID, ok = os.LookupEnv(azureTenantID); !ok {
return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options")
}
}
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
caco := ClientAssertionCredentialOptions{
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
Expand Down
133 changes: 102 additions & 31 deletions sdk/azidentity/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
Expand Down Expand Up @@ -71,8 +70,13 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) {
t.Run(name, func(t *testing.T) {
co, stop := initRecording(t)
defer stop()
o := WorkloadIdentityCredentialOptions{ClientOptions: co, DisableInstanceDiscovery: b}
cred, err := NewWorkloadIdentityCredential(liveSP.tenantID, liveSP.clientID, f, &o)
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: liveSP.clientID,
ClientOptions: co,
DisableInstanceDiscovery: b,
TenantID: liveSP.tenantID,
TokenFilePath: f,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -86,7 +90,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
validateReq := func(req *http.Request) bool {
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
Expand All @@ -103,18 +107,13 @@ func TestWorkloadIdentityCredential(t *testing.T) {
if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
return true
}
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
opts := WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
}
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -124,7 +123,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
tokenReqs := 0
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
validateReq := func(req *http.Request) bool {
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
Expand All @@ -134,20 +133,13 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0])
}
tokenReqs++
return true
}
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
srv.AppendResponse()
opts := WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
}
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -167,3 +159,82 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
t.Fatalf("expected 2 token requests, got %d", tokenReqs)
}
}

func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
f := filepath.Join(t.TempDir(), t.Name())
for _, env := range []map[string]string{
{},

{azureClientID: fakeClientID},
{azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID},

{azureClientID: fakeClientID, azureTenantID: fakeTenantID},
{azureClientID: fakeClientID, azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID, azureFederatedTokenFile: f},
} {
t.Run("", func(t *testing.T) {
for k, v := range env {
t.Setenv(k, v)
}
if _, err := NewWorkloadIdentityCredential(nil); err == nil {
t.Fatal("expected an error")
}
})
}
}

func TestWorkloadIdentityCredential_Options(t *testing.T) {
clientID := "not-" + fakeClientID
tenantID := "not-" + fakeTenantID
wrongFile := filepath.Join(t.TempDir(), "wrong")
rightFile := filepath.Join(t.TempDir(), "right")
if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatal(err)
}
sts := mockSTS{
tenant: tenantID,
tokenRequestCallback: func(req *http.Request) {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
if actual, ok := req.PostForm["client_assertion"]; !ok {
t.Error("expected a client_assertion")
} else if len(actual) != 1 || actual[0] != tokenValue {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual, ok := req.PostForm["client_id"]; !ok {
t.Error("expected a client_id")
} else if len(actual) != 1 || actual[0] != clientID {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
},
}
// options should override environment variables
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: wrongFile,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: clientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: tenantID,
TokenFilePath: rightFile,
})
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf("unexpected token %q", tk.Token)
}
}

0 comments on commit 1cd56ad

Please sign in to comment.