diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 00c9cf30d4ac..69c5235854d4 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -29,6 +29,23 @@ ``` * Removed `InteractiveBrowserCredentialOptions.ClientSecret` and `.Port` * Removed `AADAuthenticationFailedError` +* Removed `id` parameter of `NewManagedIdentityCredential()`. User assigned identities are now + specified by `ManagedIdentityCredentialOptions.ID`: + ```go + // before + cred, err := NewManagedIdentityCredential("client-id", nil) + // or, for a resource ID + opts := &ManagedIdentityCredentialOptions{ID: ResourceID} + cred, err := NewManagedIdentityCredential("/subscriptions/...", opts) + + // after + clientID := ClientID("7cf7db0d-...") + opts := &ManagedIdentityCredentialOptions{ID: clientID} + // or, for a resource ID + resID: ResourceID("/subscriptions/...") + opts := &ManagedIdentityCredentialOptions{ID: resID} + cred, err := NewManagedIdentityCredential(opts) + ``` ### Features Added * Added connection configuration options to `DefaultAzureCredentialOptions` diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 76f5d6d6e4d4..de944c71c2a1 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -57,7 +57,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Chained errMsg += err.Error() } - msiCred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: options.HTTPClient, + msiCred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{HTTPClient: options.HTTPClient, Logging: options.Logging, Telemetry: options.Telemetry, }) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 804e4d370c15..659f94b4bb4c 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -59,7 +59,7 @@ type managedIdentityClient struct { imdsAvailableTimeout time.Duration msiType msiType endpoint string - id ManagedIdentityIDKind + id ManagedIDKind unavailableMessage string } @@ -92,12 +92,12 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *manage // ctx: The current context for controlling the request lifetime. // clientID: The client (application) ID of the service principal. // scopes: The scopes required for the token. -func (c *managedIdentityClient) authenticate(ctx context.Context, clientID string, scopes []string) (*azcore.AccessToken, error) { +func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (*azcore.AccessToken, error) { if len(c.unavailableMessage) > 0 { return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} } - msg, err := c.createAuthRequest(ctx, clientID, scopes) + msg, err := c.createAuthRequest(ctx, id, scopes) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, clientID strin } if c.msiType == msiTypeIMDS && resp.StatusCode == 400 { - if len(clientID) > 0 { + if id != nil { return nil, &AuthenticationFailedError{msg: "The requested identity isn't assigned to this resource."} } c.unavailableMessage = "No default identity is assigned to this resource." @@ -163,12 +163,12 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (*azcore.A } } -func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID string, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { switch c.msiType { case msiTypeIMDS: - return c.createIMDSAuthRequest(ctx, clientID, scopes) + return c.createIMDSAuthRequest(ctx, id, scopes) case msiTypeAppServiceV20170901, msiTypeAppServiceV20190801: - return c.createAppServiceAuthRequest(ctx, clientID, scopes) + return c.createAppServiceAuthRequest(ctx, id, scopes) case msiTypeAzureArc: // need to perform preliminary request to retreive the secret key challenge provided by the HIMDS service key, err := c.getAzureArcSecretKey(ctx, scopes) @@ -177,9 +177,9 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID } return c.createAzureArcAuthRequest(ctx, key, scopes) case msiTypeServiceFabric: - return c.createServiceFabricAuthRequest(ctx, clientID, scopes) + return c.createServiceFabricAuthRequest(ctx, id, scopes) case msiTypeCloudShell: - return c.createCloudShellAuthRequest(ctx, clientID, scopes) + return c.createCloudShellAuthRequest(ctx, id, scopes) default: errorMsg := "" switch c.msiType { @@ -193,7 +193,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID } } -func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id string, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -202,16 +202,18 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id st q := request.Raw().URL.Query() q.Add("api-version", c.imdsAPIVersion) q.Add("resource", strings.Join(scopes, " ")) - if c.id == ResourceID { - q.Add(qpResID, id) - } else if id != "" { - q.Add(qpClientID, id) + if id != nil { + if id.idKind() == miResourceID { + q.Add(qpResID, id.String()) + } else { + q.Add(qpClientID, id.String()) + } } request.Raw().URL.RawQuery = q.Encode() return request, nil } -func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id string, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -221,20 +223,24 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, request.Raw().Header.Set("secret", os.Getenv(msiSecret)) q.Add("api-version", "2017-09-01") q.Add("resource", strings.Join(scopes, " ")) - if c.id == ResourceID { - q.Add(qpResID, id) - } else if id != "" { - // the legacy 2017 API version specifically specifies "clientid" and not "client_id" as a query param - q.Add("clientid", id) + if id != nil { + if id.idKind() == miResourceID { + q.Add(qpResID, id.String()) + } else { + // the legacy 2017 API version specifically specifies "clientid" and not "client_id" as a query param + q.Add("clientid", id.String()) + } } } else if c.msiType == msiTypeAppServiceV20190801 { request.Raw().Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeader)) q.Add("api-version", "2019-08-01") q.Add("resource", scopes[0]) - if c.id == ResourceID { - q.Add(qpResID, id) - } else if id != "" { - q.Add(qpClientID, id) + if id != nil { + if id.idKind() == miResourceID { + q.Add(qpResID, id.String()) + } else { + q.Add(qpClientID, id.String()) + } } } @@ -242,7 +248,7 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, return request, nil } -func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id string, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -252,8 +258,8 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte request.Raw().Header.Set("Secret", os.Getenv(identityHeader)) q.Add("api-version", serviceFabricAPIVersion) q.Add("resource", strings.Join(scopes, " ")) - if id != "" { - q.Add(qpClientID, id) + if id != nil { + q.Add(qpClientID, id.String()) } request.Raw().URL.RawQuery = q.Encode() return request, nil @@ -310,7 +316,7 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, k return request, nil } -func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, clientID string, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { request, err := runtime.NewRequest(ctx, http.MethodPost, c.endpoint) if err != nil { return nil, err @@ -318,8 +324,8 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, request.Raw().Header.Set(headerMetadata, "true") data := url.Values{} data.Set("resource", strings.Join(scopes, " ")) - if clientID != "" { - data.Set(qpClientID, clientID) + if id != nil { + data.Set(qpClientID, id.String()) } dataEncoded := data.Encode() body := streaming.NopCloser(strings.NewReader(dataEncoded)) diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index 2e0e416033d8..2a91e15e5baf 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "fmt" "os" "strings" @@ -12,24 +13,50 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) -// ManagedIdentityIDKind is used to specify the type of identifier that is passed in for a user-assigned managed identity. -type ManagedIdentityIDKind int +type managedIdentityIDKind int const ( - // ClientID is the default identifier for a user-assigned managed identity. - ClientID ManagedIdentityIDKind = 0 - // ResourceID is set when the resource ID of the user-assigned managed identity is to be used. - ResourceID ManagedIdentityIDKind = 1 + miClientID managedIdentityIDKind = 0 + miResourceID managedIdentityIDKind = 1 ) +// ManagedIDKind identifies the ID of a managed identity as either a client or resource ID +type ManagedIDKind interface { + fmt.Stringer + idKind() managedIdentityIDKind +} + +// ClientID is an identity's client ID. Use it with ManagedIdentityCredentialOptions, for example: +// ManagedIdentityCredentialOptions{ID: ClientID("7cf7db0d-...")} +type ClientID string + +func (ClientID) idKind() managedIdentityIDKind { + return miClientID +} + +func (c ClientID) String() string { + return string(c) +} + +// ResourceID is an identity's resource ID. Use it with ManagedIdentityCredentialOptions, for example: +// ManagedIdentityCredentialOptions{ID: ResourceID("/subscriptions/...")} +type ResourceID string + +func (ResourceID) idKind() managedIdentityIDKind { + return miResourceID +} + +func (r ResourceID) String() string { + return string(r) +} + // ManagedIdentityCredentialOptions contains parameters that can be used to configure the pipeline used with Managed Identity Credential. // All zero-value fields will be initialized with their default values. type ManagedIdentityCredentialOptions struct { - // ID is used to configure an alternate identifier for a user-assigned identity. The default is client ID. - // Select the identifier to be used and pass the corresponding ID value in the string param in - // NewManagedIdentityCredential(). - // Hint: Choose from the list of allowed ManagedIdentityIDKind values. - ID ManagedIdentityIDKind + // ID is the ID of a managed identity the credential should authenticate. Set this field to use a specific identity + // instead of the hosting environment's default. The value may be the identity's client ID or resource ID, but note that + // some platforms don't accept resource IDs. + ID ManagedIDKind // HTTPClient sets the transport for making HTTP requests. // Leave this as nil to use the default HTTP transport. @@ -46,17 +73,15 @@ type ManagedIdentityCredentialOptions struct { // managed identity environments such as Azure VMs, App Service, Azure Functions, Azure CloudShell, among others. More information about configuring managed identities can be found here: // https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview type ManagedIdentityCredential struct { - id string + id ManagedIDKind client *managedIdentityClient } -// NewManagedIdentityCredential creates an instance of the ManagedIdentityCredential capable of authenticating a resource that has a managed identity. -// id: The ID that corresponds to the user assigned managed identity. Defaults to the identity's client ID. To use another identifier, -// pass in the value for the identifier here AND choose the correct ID kind to be used in the request by setting ManagedIdentityIDKind in the options. +// NewManagedIdentityCredential creates a credential instance capable of authenticating an Azure managed identity in any hosting environment +// supporting managed identities. See https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview for more +// information about Azure Managed Identity. // options: ManagedIdentityCredentialOptions that configure the pipeline for requests sent to Azure Active Directory. -// More information on user assigned managed identities cam be found here: -// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview#how-a-user-assigned-managed-identity-works-with-an-azure-vm -func NewManagedIdentityCredential(id string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) { +func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) { // Create a new Managed Identity Client with default options if options == nil { options = &ManagedIdentityCredentialOptions{} @@ -72,11 +97,16 @@ func NewManagedIdentityCredential(id string, options *ManagedIdentityCredentialO // Assign the msiType discovered onto the client client.msiType = msiType // check if no clientID is specified then check if it exists in an environment variable - if len(id) == 0 { - if options.ID == ResourceID { - id = os.Getenv("AZURE_RESOURCE_ID") + id := options.ID + if id == nil { + cID := os.Getenv("AZURE_CLIENT_ID") + if cID != "" { + id = ClientID(cID) } else { - id = os.Getenv("AZURE_CLIENT_ID") + rID := os.Getenv("AZURE_RESOURCE_ID") + if rID != "" { + id = ResourceID(rID) + } } } return &ManagedIdentityCredential{id: id, client: client}, nil diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index e156104adc03..fe3c4dcdb3ce 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -56,7 +56,7 @@ func TestManagedIdentityCredential_GetTokenInAzureArcLive(t *testing.T) { if len(os.Getenv(arcIMDSEndpoint)) == 0 { t.Skip() } - msiCred, err := NewManagedIdentityCredential(clientID, nil) + msiCred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -70,7 +70,7 @@ func TestManagedIdentityCredential_GetTokenInCloudShellLive(t *testing.T) { if len(os.Getenv("MSI_ENDPOINT")) == 0 { t.Skip() } - msiCred, err := NewManagedIdentityCredential(clientID, nil) + msiCred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -89,7 +89,7 @@ func TestManagedIdentityCredential_GetTokenInCloudShellMock(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -108,7 +108,7 @@ func TestManagedIdentityCredential_GetTokenInCloudShellMockFail(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -128,7 +128,7 @@ func TestManagedIdentityCredential_GetTokenInAppServiceV20170901Mock_windows(t * defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -154,7 +154,7 @@ func TestManagedIdentityCredential_GetTokenInAppServiceV20170901Mock_linux(t *te defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -180,7 +180,7 @@ func TestManagedIdentityCredential_GetTokenInAppServiceV20190801Mock_windows(t * defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -206,7 +206,7 @@ func TestManagedIdentityCredential_GetTokenInAppServiceV20190801Mock_linux(t *te defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -235,7 +235,7 @@ func TestManagedIdentityCredential_GetTokenInAzureFunctions_linux(t *testing.T) _ = os.Setenv("IDENTITY_ENDPOINT", srv.URL()) _ = os.Setenv("IDENTITY_HEADER", "header") defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") - msiCred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{ + msiCred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ HTTPClient: srv, }) if err != nil { @@ -259,12 +259,12 @@ func TestManagedIdentityCredential_CreateAppServiceAuthRequestV20190801(t *testi _ = os.Setenv("IDENTITY_ENDPOINT", "somevalue") _ = os.Setenv("IDENTITY_HEADER", "header") defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") - cred, err := NewManagedIdentityCredential(clientID, nil) + cred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } cred.client.endpoint = imdsEndpoint - req, err := cred.client.createAuthRequest(context.Background(), clientID, []string{msiScope}) + req, err := cred.client.createAuthRequest(context.Background(), ClientID(clientID), []string{msiScope}) if err != nil { t.Fatal(err) } @@ -292,12 +292,12 @@ func TestManagedIdentityCredential_CreateAppServiceAuthRequestV20170901(t *testi _ = os.Setenv("MSI_ENDPOINT", "somevalue") _ = os.Setenv("MSI_SECRET", "secret") defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") - cred, err := NewManagedIdentityCredential(clientID, nil) + cred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } cred.client.endpoint = imdsEndpoint - req, err := cred.client.createAuthRequest(context.Background(), clientID, []string{msiScope}) + req, err := cred.client.createAuthRequest(context.Background(), ClientID(clientID), []string{msiScope}) if err != nil { t.Fatal(err) } @@ -329,7 +329,7 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnStringInt(t *testin defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -349,7 +349,7 @@ func TestManagedIdentityCredential_GetTokenInAppServiceMockFail(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -369,7 +369,7 @@ func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { } res2 := res1 options.HTTPClient = newMockImds(res1, res2) - cred, err := NewManagedIdentityCredential("", &options) + cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -393,7 +393,7 @@ func TestManagedIdentityCredential_NewManagedIdentityCredentialFail(t *testing.T defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - cred, err := NewManagedIdentityCredential("", &options) + cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatal(err) } @@ -412,7 +412,7 @@ func TestBearerPolicy_ManagedIdentityCredential(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - cred, err := NewManagedIdentityCredential(clientID, &options) + cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -436,7 +436,7 @@ func TestManagedIdentityCredential_GetTokenUnexpectedJSON(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -451,12 +451,12 @@ func TestManagedIdentityCredential_CreateIMDSAuthRequest(t *testing.T) { // to test IMDS authentication request creation. _ = os.Setenv("MSI_ENDPOINT", "somevalue") defer clearEnvVars("MSI_ENDPOINT") - cred, err := NewManagedIdentityCredential(clientID, nil) + cred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } cred.client.endpoint = imdsEndpoint - req, err := cred.client.createIMDSAuthRequest(context.Background(), clientID, []string{msiScope}) + req, err := cred.client.createIMDSAuthRequest(context.Background(), ClientID(clientID), []string{msiScope}) if err != nil { t.Fatal(err) } @@ -497,7 +497,7 @@ func TestManagedIdentityCredential_GetTokenEnvVar(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -519,7 +519,7 @@ func TestManagedIdentityCredential_GetTokenNilResource(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -542,7 +542,7 @@ func TestManagedIdentityCredential_ScopesImmutable(t *testing.T) { options := ManagedIdentityCredentialOptions{ HTTPClient: srv, } - cred, err := NewManagedIdentityCredential("", &options) + cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -566,7 +566,7 @@ func TestManagedIdentityCredential_GetTokenMultipleResources(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential("", &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -589,8 +589,8 @@ func TestManagedIdentityCredential_UseResourceID(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - options.ID = ResourceID - cred, err := NewManagedIdentityCredential("sample/resource/id", &options) + options.ID = ResourceID("sample/resource/id") + cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -610,12 +610,12 @@ func TestManagedIdentityCredential_ResourceID_AppService(t *testing.T) { _ = os.Setenv("IDENTITY_HEADER", "header") defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") resID := "sample/resource/id" - cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID}) + cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(resID)}) if err != nil { t.Fatalf("unexpected error: %v", err) } cred.client.endpoint = imdsEndpoint - req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope}) + req, err := cred.client.createAuthRequest(context.Background(), cred.id, []string{msiScope}) if err != nil { t.Fatal(err) } @@ -642,13 +642,13 @@ func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) { _ = os.Setenv("MSI_ENDPOINT", "http://foo.com/") defer clearEnvVars("MSI_ENDPOINT") resID := "sample/resource/id" - cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID}) + cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(resID)}) if err != nil { t.Fatalf("unexpected error: %v", err) } cred.client.msiType = msiTypeIMDS cred.client.endpoint = imdsEndpoint - req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope}) + req, err := cred.client.createAuthRequest(context.Background(), cred.id, []string{msiScope}) if err != nil { t.Fatal(err) } @@ -677,7 +677,7 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) { defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -698,7 +698,7 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnFail(t *testing.T) defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") options := ManagedIdentityCredentialOptions{} options.HTTPClient = srv - msiCred, err := NewManagedIdentityCredential(clientID, &options) + msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -712,28 +712,29 @@ func TestManagedIdentityCredential_ResourceID_envVar(t *testing.T) { // setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type _ = os.Setenv("IDENTITY_ENDPOINT", "somevalue") _ = os.Setenv("IDENTITY_HEADER", "header") - _ = os.Setenv("AZURE_CLIENT_ID", "client_id") _ = os.Setenv("AZURE_RESOURCE_ID", "resource_id") defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER", "AZURE_CLIENT_ID", "AZURE_RESOURCE_ID") - cred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{ID: ResourceID}) + cred, err := NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } - if cred.id != "resource_id" { + if cred.id != ResourceID("resource_id") { t.Fatal("unexpected id value stored") } - cred, err = NewManagedIdentityCredential("", nil) + _ = os.Setenv("AZURE_RESOURCE_ID", "") + _ = os.Setenv("AZURE_CLIENT_ID", "client_id") + cred, err = NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } - if cred.id != "client_id" { + if cred.id != ClientID("client_id") { t.Fatal("unexpected id value stored") } - cred, err = NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{ID: ClientID}) + cred, err = NewManagedIdentityCredential(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } - if cred.id != "client_id" { + if cred.id != ClientID("client_id") { t.Fatal("unexpected id value stored") } }