diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 736da23ccf42..fab8bb992c8b 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -7,6 +7,7 @@ import ( "encoding/json" "net/http" "net/url" + "os" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -118,8 +119,13 @@ type TokenCredentialOptions struct { // NewIdentityClientOptions initializes an instance of IdentityClientOptions with default settings // NewIdentityClientOptions initializes an instance of IdentityClientOptions with default settings func (c *TokenCredentialOptions) setDefaultValues() (*TokenCredentialOptions, error) { + authorityHost := defaultAuthorityHost + if envAuthorityHost := os.Getenv("AZURE_AUTHORITY_HOST"); envAuthorityHost != "" { + authorityHost = envAuthorityHost + } + if c == nil { - defaultAuthorityHostURL, err := url.Parse(defaultAuthorityHost) + defaultAuthorityHostURL, err := url.Parse(authorityHost) if err != nil { return nil, err } @@ -127,7 +133,7 @@ func (c *TokenCredentialOptions) setDefaultValues() (*TokenCredentialOptions, er } if c.AuthorityHost == nil { - defaultAuthorityHostURL, err := url.Parse(defaultAuthorityHost) + defaultAuthorityHostURL, err := url.Parse(authorityHost) if err != nil { return nil, err } diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 6ca1245f47a8..af560deef281 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -5,11 +5,17 @@ package azidentity import ( "net/url" + "os" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) +const ( + envHostString = "https://mock.com/" + customHostString = "https://custommock.com/" +) + func Test_AuthorityHost_Parse(t *testing.T) { _, err := url.Parse(defaultAuthorityHost) if err != nil { @@ -27,3 +33,54 @@ func Test_NonNilTokenCredentialOptsNilAuthorityHost(t *testing.T) { t.Fatalf("Did not set default authority host") } } + +func Test_SetEnvAuthorityHost(t *testing.T) { + err := os.Setenv("AZURE_AUTHORITY_HOST", envHostString) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + + opts := &TokenCredentialOptions{} + opts, err = opts.setDefaultValues() + if opts.AuthorityHost.String() != envHostString { + t.Fatalf("Unexpected error when get host from environment vairable: %v", err) + } + + // Unset that host environment vairable to avoid other tests failed. + err = os.Unsetenv("AZURE_AUTHORITY_HOST") + if err != nil { + t.Fatalf("Unexpected error when unset environment vairable: %v", err) + } +} + +func Test_CustomAuthorityHost(t *testing.T) { + err := os.Setenv("AZURE_AUTHORITY_HOST", envHostString) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + + customHost, err := url.Parse(customHostString) + if err != nil { + t.Fatalf("Received an error: %v", err) + } + + opts := &TokenCredentialOptions{AuthorityHost: customHost} + opts, err = opts.setDefaultValues() + if opts.AuthorityHost.String() != customHostString { + t.Fatalf("Unexpected error when get host from environment vairable: %v", err) + } + + // Unset that host environment vairable to avoid other tests failed. + err = os.Unsetenv("AZURE_AUTHORITY_HOST") + if err != nil { + t.Fatalf("Unexpected error when unset environment vairable: %v", err) + } +} + +func Test_DefaultAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{} + opts, err := opts.setDefaultValues() + if opts.AuthorityHost.String() != defaultAuthorityHost { + t.Fatalf("Unexpected error when set default AuthorityHost: %v", err) + } +}