diff --git a/modules/azure/client_factory.go b/modules/azure/client_factory.go index 91307aa72..1dd97d8e9 100644 --- a/modules/azure/client_factory.go +++ b/modules/azure/client_factory.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute" "github.com/Azure/azure-sdk-for-go/services/containerservice/mgmt/2019-11-01/containerservice" kvmng "github.com/Azure/azure-sdk-for-go/services/keyvault/mgmt/2016-10-01/keyvault" + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-09-01/network" "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-06-01/subscriptions" "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage" autorestAzure "github.com/Azure/go-autorest/autorest/azure" @@ -63,21 +64,23 @@ func CreateSubscriptionsClientE() (subscriptions.Client, error) { // CreateVirtualMachinesClientE returns a virtual machines client instance configured with the correct BaseURI depending on // the Azure environment that is currently setup (or "Public", if none is setup). -func CreateVirtualMachinesClientE(subscriptionID string) (compute.VirtualMachinesClient, error) { +func CreateVirtualMachinesClientE(subscriptionID string) (*compute.VirtualMachinesClient, error) { // Validate Azure subscription ID subscriptionID, err := getTargetAzureSubscription(subscriptionID) if err != nil { - return compute.VirtualMachinesClient{}, err + return nil, err } // Lookup environment URI baseURI, err := getBaseURI() if err != nil { - return compute.VirtualMachinesClient{}, err + return nil, err } // Create correct client based on type passed - return compute.NewVirtualMachinesClientWithBaseURI(baseURI, subscriptionID), nil + vmClient := compute.NewVirtualMachinesClientWithBaseURI(baseURI, subscriptionID) + + return &vmClient, nil } // snippet-tag-end::client_factory_example.CreateClient @@ -479,6 +482,166 @@ func CreateDiagnosticsSettingsClientE(subscriptionID string) (*insights.Diagnost return &client, nil } +// CreateNsgDefaultRulesClientE returns an NSG default (platform) rules client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNsgDefaultRulesClientE(subscriptionID string) (*network.DefaultSecurityRulesClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // Create new client + nsgClient := network.NewDefaultSecurityRulesClientWithBaseURI(baseURI, subscriptionID) + return &nsgClient, nil +} + +// CreateNsgCustomRulesClientE returns an NSG custom (user) rules client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNsgCustomRulesClientE(subscriptionID string) (*network.SecurityRulesClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // Create new client + nsgClient := network.NewSecurityRulesClientWithBaseURI(baseURI, subscriptionID) + return &nsgClient, nil +} + +// CreateNewNetworkInterfacesClientE returns an NIC client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNewNetworkInterfacesClientE(subscriptionID string) (*network.InterfacesClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // create client + nicClient := network.NewInterfacesClientWithBaseURI(baseURI, subscriptionID) + return &nicClient, nil +} + +// CreateNewNetworkInterfaceIPConfigurationClientE returns an NIC IP configuration client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNewNetworkInterfaceIPConfigurationClientE(subscriptionID string) (*network.InterfaceIPConfigurationsClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // create client + ipConfigClient := network.NewInterfaceIPConfigurationsClientWithBaseURI(baseURI, subscriptionID) + return &ipConfigClient, nil +} + +// CreatePublicIPAddressesClientE returns a public IP address client instance configured with the correct BaseURI depending on +// the Azure environment that is currently setup (or "Public", if none is setup). +func CreatePublicIPAddressesClientE(subscriptionID string) (*network.PublicIPAddressesClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // Create client + client := network.NewPublicIPAddressesClientWithBaseURI(baseURI, subscriptionID) + return &client, nil +} + +// CreateLoadBalancerClientE returns a load balancer client instance configured with the correct BaseURI depending on +// the Azure environment that is currently setup (or "Public", if none is setup). +func CreateLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + //create LB client + client := network.NewLoadBalancersClientWithBaseURI(baseURI, subscriptionID) + return &client, nil +} + +// CreateNewSubnetClientE returns a Subnet client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNewSubnetClientE(subscriptionID string) (*network.SubnetsClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // create client + subnetClient := network.NewSubnetsClientWithBaseURI(baseURI, subscriptionID) + return &subnetClient, nil +} + +// CreateNewVirtualNetworkClientE returns a Virtual Network client instance configured with the +// correct BaseURI depending on the Azure environment that is currently setup (or "Public", if none is setup). +func CreateNewVirtualNetworkClientE(subscriptionID string) (*network.VirtualNetworksClient, error) { + // Validate Azure subscription ID + subscriptionID, err := getTargetAzureSubscription(subscriptionID) + if err != nil { + return nil, err + } + + // Lookup environment URI + baseURI, err := getEnvironmentEndpointE(ResourceManagerEndpointName) + if err != nil { + return nil, err + } + + // create client + vnetClient := network.NewVirtualNetworksClientWithBaseURI(baseURI, subscriptionID) + return &vnetClient, nil +} + // GetKeyVaultURISuffixE returns the proper KeyVault URI suffix for the configured Azure environment. // This function would fail the test if there is an error. func GetKeyVaultURISuffixE() (string, error) { diff --git a/modules/azure/client_factory_test.go b/modules/azure/client_factory_test.go index f52a5ad83..81cf00dbb 100644 --- a/modules/azure/client_factory_test.go +++ b/modules/azure/client_factory_test.go @@ -222,3 +222,69 @@ func TestCosmosDBSQLClientBaseURISetCorrectly(t *testing.T) { }) } } +func TestPublicIPAddressesClientBaseURISetCorrectly(t *testing.T) { + var cases = []struct { + CaseName string + EnvironmentName string + ExpectedBaseURI string + }{ + {"GovCloud/CosmosDBAccountClient", govCloudEnvName, autorest.USGovernmentCloud.ResourceManagerEndpoint}, + {"PublicCloud/CosmosDBAccountClient", publicCloudEnvName, autorest.PublicCloud.ResourceManagerEndpoint}, + {"ChinaCloud/CosmosDBAccountClient", chinaCloudEnvName, autorest.ChinaCloud.ResourceManagerEndpoint}, + {"GermanCloud/CosmosDBAccountClient", germanyCloudEnvName, autorest.GermanCloud.ResourceManagerEndpoint}, + } + + // save any current env value and restore on exit + currentEnv := os.Getenv(AzureEnvironmentEnvName) + defer os.Setenv(AzureEnvironmentEnvName, currentEnv) + + for _, tt := range cases { + // The following is necessary to make sure testCase's values don't + // get updated due to concurrency within the scope of t.Run(..) below + tt := tt + t.Run(tt.CaseName, func(t *testing.T) { + // Override env setting + os.Setenv(AzureEnvironmentEnvName, tt.EnvironmentName) + + // Get a VM client + client, err := CreatePublicIPAddressesClientE("") + require.NoError(t, err) + + // Check for correct ARM URI + assert.Equal(t, tt.ExpectedBaseURI, client.BaseURI) + }) + } +} +func TestLoadBalancerClientBaseURISetCorrectly(t *testing.T) { + var cases = []struct { + CaseName string + EnvironmentName string + ExpectedBaseURI string + }{ + {"GovCloud/CosmosDBAccountClient", govCloudEnvName, autorest.USGovernmentCloud.ResourceManagerEndpoint}, + {"PublicCloud/CosmosDBAccountClient", publicCloudEnvName, autorest.PublicCloud.ResourceManagerEndpoint}, + {"ChinaCloud/CosmosDBAccountClient", chinaCloudEnvName, autorest.ChinaCloud.ResourceManagerEndpoint}, + {"GermanCloud/CosmosDBAccountClient", germanyCloudEnvName, autorest.GermanCloud.ResourceManagerEndpoint}, + } + + // save any current env value and restore on exit + currentEnv := os.Getenv(AzureEnvironmentEnvName) + defer os.Setenv(AzureEnvironmentEnvName, currentEnv) + + for _, tt := range cases { + // The following is necessary to make sure testCase's values don't + // get updated due to concurrency within the scope of t.Run(..) below + tt := tt + t.Run(tt.CaseName, func(t *testing.T) { + // Override env setting + os.Setenv(AzureEnvironmentEnvName, tt.EnvironmentName) + + // Get a VM client + client, err := CreateLoadBalancerClientE("") + require.NoError(t, err) + + // Check for correct ARM URI + assert.Equal(t, tt.ExpectedBaseURI, client.BaseURI) + }) + } +} diff --git a/modules/azure/compute.go b/modules/azure/compute.go index 8fc666698..9d8540cd2 100644 --- a/modules/azure/compute.go +++ b/modules/azure/compute.go @@ -34,7 +34,7 @@ func GetVirtualMachineClientE(subscriptionID string) (*compute.VirtualMachinesCl // Attach authorizer to the client vmClient.Authorizer = *authorizer - return &vmClient, nil + return vmClient, nil } // VirtualMachineExists indicates whether the specifcied Azure Virtual Machine exists. diff --git a/modules/azure/loadbalancer.go b/modules/azure/loadbalancer.go index cc98f64ea..382543995 100644 --- a/modules/azure/loadbalancer.go +++ b/modules/azure/loadbalancer.go @@ -183,15 +183,12 @@ func GetLoadBalancerE(loadBalancerName string, resourceGroupName string, subscri // GetLoadBalancerClientE gets a new Load Balancer client in the specified Azure Subscription. func GetLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Get the Load Balancer client + client, err := CreateLoadBalancerClientE(subscriptionID) if err != nil { return nil, err } - // Get the Load Balancer client - client := network.NewLoadBalancersClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -199,5 +196,5 @@ func GetLoadBalancerClientE(subscriptionID string) (*network.LoadBalancersClient } client.Authorizer = *authorizer - return &client, nil + return client, nil } diff --git a/modules/azure/networkinterface.go b/modules/azure/networkinterface.go index 14c6de0d0..eb6134c27 100644 --- a/modules/azure/networkinterface.go +++ b/modules/azure/networkinterface.go @@ -118,15 +118,12 @@ func GetNetworkInterfaceConfigurationE(nicName string, nicConfigName string, res // GetNetworkInterfaceConfigurationClientE creates a new Network Interface Configuration client in the specified Azure Subscription. func GetNetworkInterfaceConfigurationClientE(subscriptionID string) (*network.InterfaceIPConfigurationsClient, error) { - // Validate Azure Subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Create a new client from client factory + client, err := CreateNewNetworkInterfaceIPConfigurationClientE(subscriptionID) if err != nil { return nil, err } - // Get the NIC client - client := network.NewInterfaceIPConfigurationsClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -134,7 +131,7 @@ func GetNetworkInterfaceConfigurationClientE(subscriptionID string) (*network.In } client.Authorizer = *authorizer - return &client, nil + return client, nil } // GetNetworkInterfaceE gets a Network Interface in the specified Azure Resource Group. @@ -162,15 +159,12 @@ func GetNetworkInterfaceE(nicName string, resGroupName string, subscriptionID st // GetNetworkInterfaceClientE creates a new Network Interface client in the specified Azure Subscription. func GetNetworkInterfaceClientE(subscriptionID string) (*network.InterfacesClient, error) { - // Validate Azure Subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Create new NIC client from client factory + client, err := CreateNewNetworkInterfacesClientE(subscriptionID) if err != nil { return nil, err } - // Get the NIC client - client := network.NewInterfacesClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -178,5 +172,5 @@ func GetNetworkInterfaceClientE(subscriptionID string) (*network.InterfacesClien } client.Authorizer = *authorizer - return &client, nil + return client, nil } diff --git a/modules/azure/nsg.go b/modules/azure/nsg.go index 934456337..865142587 100644 --- a/modules/azure/nsg.go +++ b/modules/azure/nsg.go @@ -46,14 +46,12 @@ func GetDefaultNsgRulesClient(t *testing.T, subscriptionID string) network.Defau // defined on an network security group. Note that the "default" rules are those provided implicitly // by the Azure platform. func GetDefaultNsgRulesClientE(subscriptionID string) (network.DefaultSecurityRulesClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Get new default client from client factory + nsgClient, err := CreateNsgDefaultRulesClientE(subscriptionID) if err != nil { return network.DefaultSecurityRulesClient{}, err } - nsgClient := network.NewDefaultSecurityRulesClient(subscriptionID) - // Get an authorizer auth, err := NewAuthorizer() if err != nil { @@ -61,7 +59,7 @@ func GetDefaultNsgRulesClientE(subscriptionID string) (network.DefaultSecurityRu } nsgClient.Authorizer = *auth - return nsgClient, nil + return *nsgClient, nil } // GetCustomNsgRulesClient returns a rules client which can be used to read the list of *custom* security rules @@ -78,14 +76,12 @@ func GetCustomNsgRulesClient(t *testing.T, subscriptionID string) network.Securi // defined on an network security group. Note that the "custom" rules are those defined by // end users. func GetCustomNsgRulesClientE(subscriptionID string) (network.SecurityRulesClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Get new custom rules client from client factory + nsgClient, err := CreateNsgCustomRulesClientE(subscriptionID) if err != nil { return network.SecurityRulesClient{}, err } - nsgClient := network.NewSecurityRulesClient(subscriptionID) - // Get an authorizer auth, err := NewAuthorizer() if err != nil { @@ -93,7 +89,7 @@ func GetCustomNsgRulesClientE(subscriptionID string) (network.SecurityRulesClien } nsgClient.Authorizer = *auth - return nsgClient, nil + return *nsgClient, nil } // GetAllNSGRules returns an NsgRuleSummaryList instance containing the combined "default" and "custom" rules from a network diff --git a/modules/azure/publicaddress.go b/modules/azure/publicaddress.go index 28eea1cf6..ef21415ea 100644 --- a/modules/azure/publicaddress.go +++ b/modules/azure/publicaddress.go @@ -97,15 +97,12 @@ func GetPublicIPAddressE(publicIPAddressName string, resGroupName string, subscr // GetPublicIPAddressClientE creates a Public IP Addresses client in the specified Azure Subscription. func GetPublicIPAddressClientE(subscriptionID string) (*network.PublicIPAddressesClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Get the Public IP Address client from clientfactory + client, err := CreatePublicIPAddressesClientE(subscriptionID) if err != nil { return nil, err } - // Get the Public IP Address client - client := network.NewPublicIPAddressesClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -113,5 +110,5 @@ func GetPublicIPAddressClientE(subscriptionID string) (*network.PublicIPAddresse } client.Authorizer = *authorizer - return &client, nil + return client, nil } diff --git a/modules/azure/virtualnetwork.go b/modules/azure/virtualnetwork.go index c810b0fc4..49a935917 100644 --- a/modules/azure/virtualnetwork.go +++ b/modules/azure/virtualnetwork.go @@ -162,15 +162,12 @@ func GetSubnetE(subnetName string, vnetName string, resGroupName string, subscri // GetSubnetClientE creates a subnet client. func GetSubnetClientE(subscriptionID string) (*network.SubnetsClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Create a new Subnet client from client factory + client, err := CreateNewSubnetClientE(subscriptionID) if err != nil { return nil, err } - // Get the Subnet client - client := network.NewSubnetsClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -178,7 +175,7 @@ func GetSubnetClientE(subscriptionID string) (*network.SubnetsClient, error) { } client.Authorizer = *authorizer - return &client, nil + return client, nil } // GetVirtualNetworkE gets Virtual Network in the specified Azure Resource Group. @@ -205,15 +202,12 @@ func GetVirtualNetworkE(vnetName string, resGroupName string, subscriptionID str // GetVirtualNetworksClientE creates a virtual network client in the specified Azure Subscription. func GetVirtualNetworksClientE(subscriptionID string) (*network.VirtualNetworksClient, error) { - // Validate Azure subscription ID - subscriptionID, err := getTargetAzureSubscription(subscriptionID) + // Create a new Virtual Network client from client factory + client, err := CreateNewVirtualNetworkClientE(subscriptionID) if err != nil { return nil, err } - // Get the Virtual Network client - client := network.NewVirtualNetworksClient(subscriptionID) - // Create an authorizer authorizer, err := NewAuthorizer() if err != nil { @@ -221,5 +215,5 @@ func GetVirtualNetworksClientE(subscriptionID string) (*network.VirtualNetworksC } client.Authorizer = *authorizer - return &client, nil + return client, nil }