Skip to content

Commit

Permalink
refactor: upstream most of Azure managed CAS changes in cloudprovider…
Browse files Browse the repository at this point in the history
…/azure
  • Loading branch information
comtalyst committed Jul 21, 2024
1 parent c8e4721 commit 2eb6cbe
Show file tree
Hide file tree
Showing 23 changed files with 343 additions and 168 deletions.
4 changes: 2 additions & 2 deletions cluster-autoscaler/cloudprovider/azure/azure_agent_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (as *AgentPool) GetVMIndexes() ([]int, map[int]string, error) {
}

indexes = append(indexes, index)
resourceID, err := convertResourceGroupNameToLower("azure://" + *instance.ID)
resourceID, err := convertResourceGroupNameToLower(azurePrefix + *instance.ID)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -489,7 +489,7 @@ func (as *AgentPool) Nodes() ([]cloudprovider.Instance, error) {

// To keep consistent with providerID from kubernetes cloud provider, convert
// resourceGroupName in the ID to lower case.
resourceID, err := convertResourceGroupNameToLower("azure://" + *instance.ID)
resourceID, err := convertResourceGroupNameToLower(azurePrefix + *instance.ID)
if err != nil {
return nil, err
}
Expand Down
27 changes: 18 additions & 9 deletions cluster-autoscaler/cloudprovider/azure/azure_agent_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ import (
"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2021-09-01/storage"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)

var (
Expand Down Expand Up @@ -185,7 +185,8 @@ func TestGetVMsFromCache(t *testing.T) {
mockVMClient := mockvmclient.NewMockInterface(ctrl)
testAS.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), testAS.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(testAS.manager.azClient, refreshInterval, testAS.manager.config.ResourceGroup, vmTypeStandard, false, "")
testAS.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(testAS.manager.azClient, refreshInterval, *testAS.manager.config)
assert.NoError(t, err)
testAS.manager.azureCache = ac

Expand All @@ -203,7 +204,8 @@ func TestGetVMIndexes(t *testing.T) {
mockVMClient := mockvmclient.NewMockInterface(ctrl)
as.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand Down Expand Up @@ -242,7 +244,8 @@ func TestGetCurSize(t *testing.T) {
mockVMClient := mockvmclient.NewMockInterface(ctrl)
as.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand All @@ -266,7 +269,8 @@ func TestAgentPoolTargetSize(t *testing.T) {
as.manager.azClient.virtualMachinesClient = mockVMClient
expectedVMs := getExpectedVMs()
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand All @@ -285,7 +289,8 @@ func TestAgentPoolIncreaseSize(t *testing.T) {
as.manager.azClient.virtualMachinesClient = mockVMClient
expectedVMs := getExpectedVMs()
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil).MaxTimes(2)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand Down Expand Up @@ -313,7 +318,8 @@ func TestDecreaseTargetSize(t *testing.T) {
as.manager.azClient.virtualMachinesClient = mockVMClient
expectedVMs := getExpectedVMs()
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil).MaxTimes(3)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand Down Expand Up @@ -431,7 +437,9 @@ func TestAgentPoolDeleteNodes(t *testing.T) {
mockSAClient := mockstorageaccountclient.NewMockInterface(ctrl)
as.manager.azClient.storageAccountsClient = mockSAClient
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
as.manager.config.VMType = vmTypeVMSS
assert.NoError(t, err)
as.manager.azureCache = ac

Expand Down Expand Up @@ -497,7 +505,8 @@ func TestAgentPoolNodes(t *testing.T) {
mockVMClient := mockvmclient.NewMockInterface(ctrl)
as.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), as.manager.config.ResourceGroup).Return(expectedVMs, nil)
ac, err := newAzureCache(as.manager.azClient, refreshInterval, as.manager.config.ResourceGroup, vmTypeStandard, false, "")
as.manager.config.VMType = vmTypeStandard
ac, err := newAzureCache(as.manager.azClient, refreshInterval, *as.manager.config)
assert.NoError(t, err)
as.manager.azureCache = ac

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func matchDiscoveryConfig(labels map[string]*string, configs []labelAutoDiscover
return nil
}

if len(v) > 0 {
if v != "" {
if value == nil || *value != v {
return nil
}
Expand Down
77 changes: 57 additions & 20 deletions cluster-autoscaler/cloudprovider/azure/azure_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,72 @@ var (
// azureCache is used for caching cluster resources state.
//
// It is needed to:
// - keep track of node groups (VM and VMSS types) in the cluster,
// - keep track of instances and which node group they belong to,
// - limit repetitive Azure API calls.
// - keep track of node groups (VM and VMSS types) in the cluster,
// - keep track of instances and which node group they belong to,
// (for VMSS it only keeps track of instanceid-to-nodegroup mapping)
// - limit repetitive Azure API calls.
//
// It backs efficient responds to
// - cloudprovider.NodeGroups() (= registeredNodeGroups)
// - cloudprovider.NodeGroupForNode (via azureManager.GetNodeGroupForInstance => FindForInstance,
// using instanceToNodeGroup and unownedInstances)
//
// CloudProvider.Refresh, called before every autoscaler loop (every 10s by defaul),
// is implemented by AzureManager.Refresh which makes the cache refresh decision,
// based on AzureManager.lastRefresh and azureCache.refreshInterval.
type azureCache struct {
mutex sync.Mutex
interrupt chan struct{}
azClient *azClient
mutex sync.Mutex
interrupt chan struct{}
azClient *azClient

// refreshInterval specifies how often azureCache needs to be refreshed.
// The value comes from AZURE_VMSS_CACHE_TTL env var (or 1min if not specified),
// and is also used by some other caches. Together with AzureManager.lastRefresh,
// it is uses to decide whether a refresh is needed.
refreshInterval time.Duration

// Cache content.
resourceGroup string
vmType string
vmsPoolSet map[string]struct{} // track the nodepools that're vms pool
scaleSets map[string]compute.VirtualMachineScaleSet
virtualMachines map[string][]compute.VirtualMachine

// resourceGroup specifies the name of the resource group that this cache tracks
resourceGroup string

// vmType can be one of vmTypeVMSS (default), vmTypeStandard
vmType string

vmsPoolSet map[string]struct{} // track the nodepools that're vms pool

// scaleSets keeps the set of all known scalesets in the resource group, populated/refreshed via VMSS.List() call.
// It is only used/populated if vmType is vmTypeVMSS (default).
scaleSets map[string]compute.VirtualMachineScaleSet
// virtualMachines keeps the set of all VMs in the resource group.
// It is only used/populated if vmType is vmTypeStandard.
virtualMachines map[string][]compute.VirtualMachine

// registeredNodeGroups represents all known NodeGroups.
registeredNodeGroups []cloudprovider.NodeGroup
instanceToNodeGroup map[azureRef]cloudprovider.NodeGroup
unownedInstances map[azureRef]bool
autoscalingOptions map[azureRef]map[string]string
skus map[string]*skewer.Cache

// instanceToNodeGroup maintains a mapping from instance Ids to nodegroups.
// It is populated from the results of calling Nodes() on each nodegroup.
// It is used (together with unownedInstances) when looking up the nodegroup
// for a given instance id (see FindForInstance).
instanceToNodeGroup map[azureRef]cloudprovider.NodeGroup

// unownedInstance maintains a set of instance ids not belonging to any nodegroup.
// It is used (together with instanceToNodeGroup) when looking up the nodegroup for a given instance id.
// It is reset by invalidateUnownedInstanceCache().
unownedInstances map[azureRef]bool

autoscalingOptions map[azureRef]map[string]string
skus map[string]*skewer.Cache
}

func newAzureCache(client *azClient, cacheTTL time.Duration, resourceGroup, vmType string, enableDynamicInstanceList bool, defaultLocation string) (*azureCache, error) {
func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*azureCache, error) {
cache := &azureCache{
interrupt: make(chan struct{}),
azClient: client,
refreshInterval: cacheTTL,
resourceGroup: resourceGroup,
vmType: vmType,
resourceGroup: config.ResourceGroup,
vmType: config.VMType,
vmsPoolSet: make(map[string]struct{}),
scaleSets: make(map[string]compute.VirtualMachineScaleSet),
virtualMachines: make(map[string][]compute.VirtualMachine),
Expand All @@ -78,8 +115,8 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, resourceGroup, vmTy
skus: make(map[string]*skewer.Cache),
}

if enableDynamicInstanceList {
cache.skus[defaultLocation] = &skewer.Cache{}
if config.EnableDynamicInstanceList {
cache.skus[config.Location] = &skewer.Cache{}
}

if err := cache.regenerate(); err != nil {
Expand Down
9 changes: 5 additions & 4 deletions cluster-autoscaler/cloudprovider/azure/azure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func newServicePrincipalTokenFromCredentials(config *Config, env *azure.Environm
if err != nil {
return nil, fmt.Errorf("getting the managed service identity endpoint: %v", err)
}
if len(config.UserAssignedIdentityID) > 0 {
if config.UserAssignedIdentityID != "" {
klog.V(4).Info("azure: using User Assigned MSI ID to retrieve access token")
return adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint,
env.ServiceManagementEndpoint,
Expand All @@ -314,7 +314,7 @@ func newServicePrincipalTokenFromCredentials(config *Config, env *azure.Environm
env.ServiceManagementEndpoint)
}

if len(config.AADClientSecret) > 0 {
if config.AADClientSecret != "" {
klog.V(2).Infoln("azure: using client_id+client_secret to retrieve access token")
return adal.NewServicePrincipalToken(
*oauthConfig,
Expand All @@ -323,13 +323,13 @@ func newServicePrincipalTokenFromCredentials(config *Config, env *azure.Environm
env.ServiceManagementEndpoint)
}

if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 {
if config.AADClientCertPath != "" {
klog.V(2).Infoln("azure: using jwt client_assertion (client_cert+client_private_key) to retrieve access token")
certData, err := ioutil.ReadFile(config.AADClientCertPath)
if err != nil {
return nil, fmt.Errorf("reading the client certificate from file %s: %v", config.AADClientCertPath, err)
}
certificate, privateKey, err := decodePkcs12(certData, config.AADClientCertPassword)
certificate, privateKey, err := adal.DecodePfxCertificateData(certData, config.AADClientCertPassword)
if err != nil {
return nil, fmt.Errorf("decoding the client certificate: %v", err)
}
Expand Down Expand Up @@ -399,6 +399,7 @@ func newAzClient(cfg *Config, env *azure.Environment) (*azClient, error) {
// https://github.com/Azure/go-autorest/blob/main/autorest/azure/environments.go
skuClient := compute.NewResourceSkusClientWithBaseURI(azClientConfig.ResourceManagerEndpoint, cfg.SubscriptionID)
skuClient.Authorizer = azClientConfig.Authorizer
skuClient.UserAgent = azClientConfig.UserAgent
klog.V(5).Infof("Created sku client with authorizer: %v", skuClient)

agentPoolClient, err := newAgentpoolClient(cfg)
Expand Down
71 changes: 71 additions & 0 deletions cluster-autoscaler/cloudprovider/azure/azure_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package azure

import (
"os"
"testing"

"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
)

func TestGetServicePrincipalTokenFromCertificate(t *testing.T) {
config := &Config{
TenantID: "TenantID",
AADClientID: "AADClientID",
AADClientCertPath: "./testdata/test.pfx",
AADClientCertPassword: "id",
}
env := &azure.PublicCloud
token, err := newServicePrincipalTokenFromCredentials(config, env)
assert.NoError(t, err)

oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, config.TenantID)
assert.NoError(t, err)
pfxContent, err := os.ReadFile("./testdata/test.pfx")
assert.NoError(t, err)
certificate, privateKey, err := adal.DecodePfxCertificateData(pfxContent, "id")
assert.NoError(t, err)
spt, err := adal.NewServicePrincipalTokenFromCertificate(
*oauthConfig, config.AADClientID, certificate, privateKey, env.ServiceManagementEndpoint)
assert.NoError(t, err)
assert.Equal(t, token, spt)
}

func TestGetServicePrincipalTokenFromCertificateWithoutPassword(t *testing.T) {
config := &Config{
TenantID: "TenantID",
AADClientID: "AADClientID",
AADClientCertPath: "./testdata/testnopassword.pfx",
}
env := &azure.PublicCloud
token, err := newServicePrincipalTokenFromCredentials(config, env)
assert.NoError(t, err)

oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, config.TenantID)
assert.NoError(t, err)
pfxContent, err := os.ReadFile("./testdata/testnopassword.pfx")
assert.NoError(t, err)
certificate, privateKey, err := adal.DecodePfxCertificateData(pfxContent, "")
assert.NoError(t, err)
spt, err := adal.NewServicePrincipalTokenFromCertificate(
*oauthConfig, config.AADClientID, certificate, privateKey, env.ServiceManagementEndpoint)
assert.NoError(t, err)
assert.Equal(t, token, spt)
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import (

const (
// GPULabel is the label added to nodes with GPU resource.
GPULabel = "accelerator"
GPULabel = AKSLabelKeyPrefixValue + "accelerator"
legacyGPULabel = "accelerator"
)

var (
Expand Down Expand Up @@ -73,7 +74,7 @@ func (azure *AzureCloudProvider) Name() string {

// GPULabel returns the label added to nodes with GPU resource.
func (azure *AzureCloudProvider) GPULabel() string {
return GPULabel
return legacyGPULabel // Use legacy to avoid breaking, for now
}

// GetAvailableGPUTypes return all available GPU types cloud provider supports
Expand Down
Loading

0 comments on commit 2eb6cbe

Please sign in to comment.