Skip to content

Commit

Permalink
NewManagedIdentityCredential returns an error for unsupported ID opti…
Browse files Browse the repository at this point in the history
…ons (#23267)
  • Loading branch information
chlowell authored Aug 6, 2024
1 parent ea67e9c commit 84d213c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 100 deletions.
10 changes: 10 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
### Features Added

### Breaking Changes
* `NewManagedIdentityCredential` now returns an error when a user-assigned identity
is specified on a platform whose managed identity API doesn't support that.
`ManagedIdentityCredential.GetToken()` formerly logged a warning in these cases.
Returning an error instead prevents the credential authenticating an unexpected
identity, causing a client to act with unexpected privileges. The affected
platforms are:
* Azure Arc
* Azure ML (when a resource ID is specified; client IDs are supported)
* Cloud Shell
* Service Fabric

### Bugs Fixed

Expand Down
25 changes: 25 additions & 0 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azidentity

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -395,3 +396,27 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
require.Equal(t, tokenValue, tk.Token)
})
}

func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) {
fail := true
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
if fail {
return nil, errors.New("fail")
}
return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
}
t.Setenv(azureClientID, fakeClientID)
t.Setenv(msiEndpoint, fakeMIEndpoint)

cred, err := NewDefaultAzureCredential(nil)
require.NoError(t, err, "an unsupported client ID isn't a constructor error")

_, err = cred.GetToken(ctx, testTRO)
require.ErrorContains(t, err, "Cloud Shell", "error should mention the unsupported ID")

fail = false
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err, "expected a token from AzureCLICredential")
}
53 changes: 19 additions & 34 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
if endpoint, ok := os.LookupEnv(identityEndpoint); ok {
if _, ok := os.LookupEnv(identityHeader); ok {
if _, ok := os.LookupEnv(identityServerThumbprint); ok {
if options.ID != nil {
return nil, errors.New("the Service Fabric API doesn't support specifying a user-assigned managed identity at runtime")
}
env = "Service Fabric"
c.endpoint = endpoint
c.msiType = msiTypeServiceFabric
Expand All @@ -152,16 +155,25 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
c.msiType = msiTypeAppService
}
} else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
if options.ID != nil {
return nil, errors.New("the Azure Arc API doesn't support specifying a user-assigned managed identity at runtime")
}
env = "Azure Arc"
c.endpoint = endpoint
c.msiType = msiTypeAzureArc
}
} else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
c.endpoint = endpoint
if _, ok := os.LookupEnv(msiSecret); ok {
if options.ID != nil && options.ID.idKind() == miResourceID {
return nil, errors.New("the Azure ML API doesn't support specifying a managed identity by resource ID")
}
env = "Azure ML"
c.msiType = msiTypeAzureML
} else {
if options.ID != nil {
return nil, errors.New("the Cloud Shell API doesn't support user-assigned managed identities")
}
env = "Cloud Shell"
c.msiType = msiTypeCloudShell
}
Expand Down Expand Up @@ -314,13 +326,13 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
}
return c.createAzureArcAuthRequest(ctx, id, scopes, key)
return c.createAzureArcAuthRequest(ctx, scopes, key)
case msiTypeAzureML:
return c.createAzureMLAuthRequest(ctx, id, scopes)
case msiTypeServiceFabric:
return c.createServiceFabricAuthRequest(ctx, id, scopes)
return c.createServiceFabricAuthRequest(ctx, scopes)
case msiTypeCloudShell:
return c.createCloudShellAuthRequest(ctx, id, scopes)
return c.createCloudShellAuthRequest(ctx, scopes)
default:
return nil, newCredentialUnavailableError(credNameManagedIdentity, "managed identity isn't supported in this environment")
}
Expand Down Expand Up @@ -378,9 +390,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
q.Add("clientid", os.Getenv(defaultIdentityClientID))
if id != nil {
if id.idKind() == miResourceID {
log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
q.Set("clientid", "")
q.Set(miResID, id.String())
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil, nil)
} else {
q.Set("clientid", id.String())
}
Expand All @@ -389,7 +399,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
return request, nil
}

func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -399,14 +409,6 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte
request.Raw().Header.Set("Secret", os.Getenv(identityHeader))
q.Add("api-version", serviceFabricAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if id != nil {
log.Write(EventAuthentication, "WARNING: Service Fabric doesn't support selecting a user-assigned identity at runtime")
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
request.Raw().URL.RawQuery = q.Encode()
return request, nil
}
Expand Down Expand Up @@ -463,7 +465,7 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
return string(key), nil
}

func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) {
func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, resources []string, key string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -473,19 +475,11 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, i
q := request.Raw().URL.Query()
q.Add("api-version", azureArcAPIVersion)
q.Add("resource", strings.Join(resources, " "))
if id != nil {
log.Write(EventAuthentication, "WARNING: Azure Arc doesn't support user-assigned managed identities")
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
request.Raw().URL.RawQuery = q.Encode()
return request, nil
}

func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -498,14 +492,5 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context,
if err := request.SetBody(body, "application/x-www-form-urlencoded"); err != nil {
return nil, err
}
if id != nil {
log.Write(EventAuthentication, "WARNING: Cloud Shell doesn't support user-assigned managed identities")
q := request.Raw().URL.Query()
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
return request, nil
}
61 changes: 0 additions & 61 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

Expand Down Expand Up @@ -123,63 +122,3 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) {
})
}
}

func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) {
for _, test := range []struct {
name string
createRequest func(*managedIdentityClient) error
}{
{
name: "Azure Arc",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key")
return err
},
},
{
name: "Cloud Shell",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
{
name: "Service Fabric",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
} {
for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} {
s := "-ClientID"
if id.String() == fakeResourceID {
s = "-ResourceID"
}
t.Run(test.name+s, func(t *testing.T) {
msgs := []string{}
log.SetListener(func(event log.Event, msg string) {
if event == EventAuthentication {
msgs = append(msgs, msg)
}
})
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ID: id,
})
if err != nil {
t.Fatal(err)
}
err = test.createRequest(client)
if err != nil {
t.Fatal(err)
}
for _, msg := range msgs {
if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") {
return
}
}
t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n"))
})
}
}
}
22 changes: 17 additions & 5 deletions sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ type ManagedIDKind interface {
idKind() managedIdentityIDKind
}

// ClientID is the client ID of a user-assigned managed identity.
// ClientID is the client ID of a user-assigned managed identity. NewManagedIdentityCredential
// returns an error when a ClientID is specified on the following platforms:
//
// - Azure Arc
// - Cloud Shell
// - Service Fabric
type ClientID string

func (ClientID) idKind() managedIdentityIDKind {
Expand All @@ -44,7 +49,13 @@ func (c ClientID) String() string {
return string(c)
}

// ResourceID is the resource ID of a user-assigned managed identity.
// ResourceID is the resource ID of a user-assigned managed identity. NewManagedIdentityCredential
// returns an error when a ResourceID is specified on the following platforms:
//
// - Azure Arc
// - Azure ML
// - Cloud Shell
// - Service Fabric
type ResourceID string

func (ResourceID) idKind() managedIdentityIDKind {
Expand All @@ -60,9 +71,10 @@ func (r ResourceID) String() string {
type ManagedIdentityCredentialOptions struct {
azcore.ClientOptions

// ID is the ID of a managed identity the credential should authenticate. Set this field to use a specific identity
// instead of the hosting environment's default. The value may be the identity's client ID or resource ID, but note that
// some platforms don't accept resource IDs.
// ID of a managed identity the credential should authenticate. Set this field to use a specific identity instead of
// the hosting environment's default. The value may be the identity's client ID or resource ID.
// NewManagedIdentityCredential returns an error when the hosting environment doesn't support user-assigned managed
// identities, or the specified kind of ID.
ID ManagedIDKind

// dac indicates whether the credential is part of DefaultAzureCredential. When true, and the environment doesn't have
Expand Down
35 changes: 35 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,38 @@ func TestManagedIdentityCredential_ServiceFabric(t *testing.T) {
}
testGetTokenSuccess(t, cred)
}

func TestManagedIdentityCredential_UnsupportedID(t *testing.T) {
t.Run("Azure Arc", func(t *testing.T) {
t.Setenv(identityEndpoint, fakeMIEndpoint)
t.Setenv(arcIMDSEndpoint, fakeMIEndpoint)
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
t.Run("Azure ML", func(t *testing.T) {
t.Setenv(msiEndpoint, fakeMIEndpoint)
t.Setenv(msiSecret, "...")
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.NoError(t, err)
})
t.Run("Cloud Shell", func(t *testing.T) {
t.Setenv(msiEndpoint, fakeMIEndpoint)
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
t.Run("Service Fabric", func(t *testing.T) {
t.Setenv(identityEndpoint, fakeMIEndpoint)
t.Setenv(identityHeader, "...")
t.Setenv(identityServerThumbprint, "...")
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
}

0 comments on commit 84d213c

Please sign in to comment.