diff --git a/modules/azure/client_factory.go b/modules/azure/client_factory.go index d871de180..7a30f82bb 100644 --- a/modules/azure/client_factory.go +++ b/modules/azure/client_factory.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute" "github.com/Azure/azure-sdk-for-go/services/containerservice/mgmt/2019-11-01/containerservice" kvmng "github.com/Azure/azure-sdk-for-go/services/keyvault/mgmt/2016-10-01/keyvault" + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-09-01/network" "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-06-01/subscriptions" autorestAzure "github.com/Azure/go-autorest/autorest/azure" ) @@ -138,6 +139,26 @@ func CreateCosmosDBSQLClientE(subscriptionID string) (*documentdb.SQLResourcesCl return &cosmosClient, nil } +// CreatePublicIPAddressesClientE returns a public IP address client instance configured with the correct BaseURI depending on +// the Azure environment that is currently setup (or "Public", if none is setup). +func CreatePublicIPAddressesClientE(subscriptionID string) (*network.PublicIPAddressesClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // Create client + client := network.NewPublicIPAddressesClientWithBaseURI(baseURI, subscriptionID) + return &client, err +} + // CreateKeyVaultManagementClientE is a helper function that will setup a key vault management client with the correct BaseURI depending on // the Azure environment that is currently setup (or "Public", if none is setup). func CreateKeyVaultManagementClientE(subscriptionID string) (*kvmng.VaultsClient, error) { diff --git a/modules/azure/client_factory_test.go b/modules/azure/client_factory_test.go index f52a5ad83..0796979d7 100644 --- a/modules/azure/client_factory_test.go +++ b/modules/azure/client_factory_test.go @@ -222,3 +222,37 @@ func TestCosmosDBSQLClientBaseURISetCorrectly(t *testing.T) { }) } } + +func TestPublicIPAddressesClientBaseURISetCorrectly(t *testing.T) { + var cases = []struct { + CaseName string + EnvironmentName string + ExpectedBaseURI string + }{ + {"GovCloud/CosmosDBAccountClient", govCloudEnvName, autorest.USGovernmentCloud.ResourceManagerEndpoint}, + {"PublicCloud/CosmosDBAccountClient", publicCloudEnvName, autorest.PublicCloud.ResourceManagerEndpoint}, + {"ChinaCloud/CosmosDBAccountClient", chinaCloudEnvName, autorest.ChinaCloud.ResourceManagerEndpoint}, + {"GermanCloud/CosmosDBAccountClient", germanyCloudEnvName, autorest.GermanCloud.ResourceManagerEndpoint}, + } + + // save any current env value and restore on exit + currentEnv := os.Getenv(AzureEnvironmentEnvName) + defer os.Setenv(AzureEnvironmentEnvName, currentEnv) + + for _, tt := range cases { + // The following is necessary to make sure testCase's values don't + // get updated due to concurrency within the scope of t.Run(..) below + tt := tt + t.Run(tt.CaseName, func(t *testing.T) { + // Override env setting + os.Setenv(AzureEnvironmentEnvName, tt.EnvironmentName) + + // Get a VM client + client, err := CreatePublicIPAddressesClientE("") + require.NoError(t, err) + + // Check for correct ARM URI + assert.Equal(t, tt.ExpectedBaseURI, client.BaseURI) + }) + } +} diff --git a/modules/azure/publicaddress.go b/modules/azure/publicaddress.go index 28eea1cf6..ef21415ea 100644 --- a/modules/azure/publicaddress.go +++ b/modules/azure/publicaddress.go @@ -97,15 +97,12 @@ func GetPublicIPAddressE(publicIPAddressName string, resGroupName string, subscr // GetPublicIPAddressClientE creates a Public IP Addresses client in the specified Azure Subscription. func GetPublicIPAddressClientE(subscriptionID string) (*network.PublicIPAddressesClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Get the Public IP Address client from clientfactory + client, err := CreatePublicIPAddressesClientE(subscriptionID) if err != nil { return nil, err } - // Get the Public IP Address client - client := network.NewPublicIPAddressesClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -113,5 +110,5 @@ func GetPublicIPAddressClientE(subscriptionID string) (*network.PublicIPAddresse } client.Authorizer = *authorizer - return &client, nil + return client, nil }