From 534325e3cb651fe2b41f889a8e1f72be9b86fd68 Mon Sep 17 00:00:00 2001 From: Mike Yeaney Date: Wed, 24 Mar 2021 11:40:50 -0400 Subject: [PATCH] Updated LoadBalancer module to use sovereign client factory. --- modules/azure/client_factory.go | 21 +++++++++++++++++ modules/azure/client_factory_test.go | 34 ++++++++++++++++++++++++++++ modules/azure/loadbalancer.go | 7 ++++-- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/modules/azure/client_factory.go b/modules/azure/client_factory.go index d871de180..91875c1fd 100644 --- a/modules/azure/client_factory.go +++ b/modules/azure/client_factory.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute" "github.com/Azure/azure-sdk-for-go/services/containerservice/mgmt/2019-11-01/containerservice" kvmng "github.com/Azure/azure-sdk-for-go/services/keyvault/mgmt/2016-10-01/keyvault" + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-09-01/network" "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-06-01/subscriptions" autorestAzure "github.com/Azure/go-autorest/autorest/azure" ) @@ -170,6 +171,26 @@ func GetKeyVaultURISuffixE() (string, error) { return env.KeyVaultDNSSuffix, nil } +// CreateLoadBalancerClientE returns a load balancer client instance configured with the correct BaseURI depending on +// the Azure environment that is currently setup (or "Public", if none is setup). +func CreateLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + //create LB client + client := network.NewLoadBalancersClientWithBaseURI(baseURI, subscriptionID) + return &client, nil +} + // getDefaultEnvironmentName returns either a configured Azure environment name, or the public default func getDefaultEnvironmentName() string { envName, exists := os.LookupEnv(AzureEnvironmentEnvName) diff --git a/modules/azure/client_factory_test.go b/modules/azure/client_factory_test.go index f52a5ad83..068203222 100644 --- a/modules/azure/client_factory_test.go +++ b/modules/azure/client_factory_test.go @@ -222,3 +222,37 @@ func TestCosmosDBSQLClientBaseURISetCorrectly(t *testing.T) { }) } } + +func TestLoadBalancerClientBaseURISetCorrectly(t *testing.T) { + var cases = []struct { + CaseName string + EnvironmentName string + ExpectedBaseURI string + }{ + {"GovCloud/CosmosDBAccountClient", govCloudEnvName, autorest.USGovernmentCloud.ResourceManagerEndpoint}, + {"PublicCloud/CosmosDBAccountClient", publicCloudEnvName, autorest.PublicCloud.ResourceManagerEndpoint}, + {"ChinaCloud/CosmosDBAccountClient", chinaCloudEnvName, autorest.ChinaCloud.ResourceManagerEndpoint}, + {"GermanCloud/CosmosDBAccountClient", germanyCloudEnvName, autorest.GermanCloud.ResourceManagerEndpoint}, + } + + // save any current env value and restore on exit + currentEnv := os.Getenv(AzureEnvironmentEnvName) + defer os.Setenv(AzureEnvironmentEnvName, currentEnv) + + for _, tt := range cases { + // The following is necessary to make sure testCase's values don't + // get updated due to concurrency within the scope of t.Run(..) below + tt := tt + t.Run(tt.CaseName, func(t *testing.T) { + // Override env setting + os.Setenv(AzureEnvironmentEnvName, tt.EnvironmentName) + + // Get a VM client + client, err := CreateLoadBalancerClientE("") + require.NoError(t, err) + + // Check for correct ARM URI + assert.Equal(t, tt.ExpectedBaseURI, client.BaseURI) + }) + } +} diff --git a/modules/azure/loadbalancer.go b/modules/azure/loadbalancer.go index cc98f64ea..bedc1dc96 100644 --- a/modules/azure/loadbalancer.go +++ b/modules/azure/loadbalancer.go @@ -190,7 +190,10 @@ func GetLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient } // Get the Load Balancer client - client := network.NewLoadBalancersClient(subscriptionID) + client, err := CreateLoadBalancerClientE(subscriptionID) + if err != nil { + return nil, err + } // Create an authorizer authorizer, err := NewAuthorizer() @@ -199,5 +202,5 @@ func GetLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient } client.Authorizer = *authorizer - return &client, nil + return client, nil }