diff --git a/pkg/cloudprovider/aws/ami.go b/pkg/cloudprovider/aws/ami.go new file mode 100644 index 000000000000..2a2a19b6ffe7 --- /dev/null +++ b/pkg/cloudprovider/aws/ami.go @@ -0,0 +1,51 @@ +package aws + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/patrickmn/go-cache" + "k8s.io/client-go/kubernetes" +) + +type AMIProvider struct { + cache *cache.Cache + ssm ssmiface.SSMAPI + clientSet *kubernetes.Clientset +} + +func NewAMIProvider(ssm ssmiface.SSMAPI, clientSet *kubernetes.Clientset) *AMIProvider { + return &AMIProvider{ + ssm: ssm, + clientSet: clientSet, + cache: cache.New(CacheTTL, CacheCleanupInterval), + } +} + +func (p *AMIProvider) Get(ctx context.Context, constraints *Constraints) (string, error) { + version, err := p.kubeServerVersion() + if err != nil { + return "", fmt.Errorf("kube server version, %w", err) + } + name := fmt.Sprintf("/aws/service/bottlerocket/aws-k8s-%s/%s/latest/image_id", version, KubeToAWSArchitectures[*constraints.Architecture]) + if id, ok := p.cache.Get(name); ok { + return id.(string), nil + } + output, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{Name: aws.String(name)}) + if err != nil { + return "", fmt.Errorf("getting ssm parameter, %w", err) + } + return aws.StringValue(output.Parameter.Value), nil +} + +func (p *AMIProvider) kubeServerVersion() (string, error) { + version, err := p.clientSet.Discovery().ServerVersion() + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", version.Major, strings.TrimSuffix(version.Minor, "+")), nil +} diff --git a/pkg/cloudprovider/aws/cloudprovider.go b/pkg/cloudprovider/aws/cloudprovider.go index 5e37ab83fbc6..12a87198e2ae 100644 --- a/pkg/cloudprovider/aws/cloudprovider.go +++ b/pkg/cloudprovider/aws/cloudprovider.go @@ -79,9 +79,8 @@ func NewCloudProvider(options cloudprovider.Options) *CloudProvider { launchTemplateProvider: &LaunchTemplateProvider{ ec2api: ec2api, cache: cache.New(CacheTTL, CacheCleanupInterval), + amiProvider: NewAMIProvider(ssm.New(sess), options.ClientSet), securityGroupProvider: NewSecurityGroupProvider(ec2api), - ssm: ssm.New(sess), - clientSet: options.ClientSet, }, subnetProvider: NewSubnetProvider(ec2api), instanceTypeProvider: NewInstanceTypeProvider(ec2api), diff --git a/pkg/cloudprovider/aws/constraints.go b/pkg/cloudprovider/aws/constraints.go index b1fdb931c820..0f95fbb5a438 100644 --- a/pkg/cloudprovider/aws/constraints.go +++ b/pkg/cloudprovider/aws/constraints.go @@ -38,12 +38,16 @@ var ( LaunchTemplateVersionLabel = AWSLabelPrefix + "launch-template-version" SubnetNameLabel = AWSLabelPrefix + "subnet-name" SubnetTagKeyLabel = AWSLabelPrefix + "subnet-tag-key" + SecurityGroupNameLabel = AWSLabelPrefix + "security-group-name" + SecurityGroupTagKeyLabel = AWSLabelPrefix + "security-group-tag-key" AllowedLabels = []string{ CapacityTypeLabel, LaunchTemplateIdLabel, LaunchTemplateVersionLabel, SubnetNameLabel, SubnetTagKeyLabel, + SecurityGroupNameLabel, + SecurityGroupTagKeyLabel, } AWSToKubeArchitectures = map[string]string{ "x86_64": v1alpha1.ArchitectureAmd64, @@ -66,10 +70,18 @@ func (c *Constraints) GetCapacityType() string { } type LaunchTemplate struct { - Id *string - Version *string + Id string + Version string } +// func (c *Constraints) GetProvisionerName() string { +// return c.Labels[v1alpha1.ProvisionerNameLabelKey] +// } + +// func (c *Constraints) GetProvisionerNamespace() string { +// return c.Labels[v1alpha1.ProvisionerNamespaceLabelKey] +// } + func (c *Constraints) GetLaunchTemplate() *LaunchTemplate { id, ok := c.Labels[LaunchTemplateIdLabel] if !ok { @@ -80,28 +92,44 @@ func (c *Constraints) GetLaunchTemplate() *LaunchTemplate { version = DefaultLaunchTemplateVersion } return &LaunchTemplate{ - Id: &id, - Version: &version, + Id: id, + Version: version, } } func (c *Constraints) GetSubnetName() *string { - subnetName, ok := c.Labels[SubnetNameLabel] + name, ok := c.Labels[SubnetNameLabel] if !ok { return nil } - return aws.String(subnetName) + return aws.String(name) } func (c *Constraints) GetSubnetTagKey() *string { - subnetTag, ok := c.Labels[SubnetTagKeyLabel] + tag, ok := c.Labels[SubnetTagKeyLabel] + if !ok { + return nil + } + return aws.String(tag) +} + +func (c *Constraints) GetSecurityGroupName() *string { + name, ok := c.Labels[SecurityGroupNameLabel] + if !ok { + return nil + } + return aws.String(name) +} + +func (c *Constraints) GetSecurityGroupTagKey() *string { + tag, ok := c.Labels[SecurityGroupTagKeyLabel] if !ok { return nil } - return aws.String(subnetTag) + return aws.String(tag) } -func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) { +func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) { return errs.Also( c.validateAllowedLabels(ctx), c.validateCapacityType(ctx), diff --git a/pkg/cloudprovider/aws/fake/ec2api.go b/pkg/cloudprovider/aws/fake/ec2api.go index 0d87e854ceb5..2e74edca5755 100644 --- a/pkg/cloudprovider/aws/fake/ec2api.go +++ b/pkg/cloudprovider/aws/fake/ec2api.go @@ -123,7 +123,11 @@ func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.Describ if e.DescribeSecurityGroupsOutput != nil { return e.DescribeSecurityGroupsOutput, nil } - return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil + return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupId: aws.String("test-group-1"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-group-1")}}}, + {GroupId: aws.String("test-group-1"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-group-1")}}}, + {GroupId: aws.String("test-group-1"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-group-1")}}}, + }}, nil } func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { diff --git a/pkg/cloudprovider/aws/instance.go b/pkg/cloudprovider/aws/instance.go index d7dbb0cac747..2a36d5d0fd85 100644 --- a/pkg/cloudprovider/aws/instance.go +++ b/pkg/cloudprovider/aws/instance.go @@ -86,8 +86,8 @@ func (p *InstanceProvider) Create(ctx context.Context, }, LaunchTemplateConfigs: []*ec2.FleetLaunchTemplateConfigRequest{{ LaunchTemplateSpecification: &ec2.FleetLaunchTemplateSpecificationRequest{ - LaunchTemplateId: launchTemplate.Id, - Version: launchTemplate.Version, + LaunchTemplateId: aws.String(launchTemplate.Id), + Version: aws.String(launchTemplate.Version), }, Overrides: overrides, }}, diff --git a/pkg/cloudprovider/aws/launchtemplate.go b/pkg/cloudprovider/aws/launchtemplate.go index cbea03f7d302..e97e7c3ac788 100644 --- a/pkg/cloudprovider/aws/launchtemplate.go +++ b/pkg/cloudprovider/aws/launchtemplate.go @@ -19,37 +19,31 @@ import ( "context" "encoding/base64" "fmt" - "strings" "text/template" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "go.uber.org/zap" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/types" - "k8s.io/client-go/kubernetes" ) const ( - launchTemplateNameFormat = "Karpenter-%s/%s/%s-%s" + launchTemplateNameFormat = "Karpenter-%s-%s" bottlerocketUserData = ` [settings.kubernetes] api-server = "{{.Cluster.Endpoint}}" cluster-certificate = "{{.Cluster.CABundle}}" cluster-name = "{{.Cluster.Name}}" -{{if .Labels }}[settings.kubernetes.node-labels]{{ end }} -{{ range $Key, $Value := .Labels }}"{{ $Key }}" = "{{ $Value }}" +{{if .Constraints.Labels }}[settings.kubernetes.node-labels]{{ end }} +{{ range $Key, $Value := .Constraints.Labels }}"{{ $Key }}" = "{{ $Value }}" {{ end }} -{{if .Taints }}[settings.kubernetes.node-taints]{{ end }} -{{ range $Taint := .Taints }}"{{ $Taint.Key }}" = "{{ $Taint.Value}}:{{ $Taint.Effect }}" +{{if .Constraints.Taints }}[settings.kubernetes.node-taints]{{ end }} +{{ range $Taint := .Constraints.Taints }}"{{ $Taint.Key }}" = "{{ $Taint.Value}}:{{ $Taint.Effect }}" {{ end }} ` ) @@ -57,9 +51,8 @@ cluster-name = "{{.Cluster.Name}}" type LaunchTemplateProvider struct { ec2api ec2iface.EC2API cache *cache.Cache + amiProvider *AMIProvider securityGroupProvider *SecurityGroupProvider - ssm ssmiface.SSMAPI - clientSet *kubernetes.Clientset } func launchTemplateName(options *launchTemplateOptions) string { @@ -67,90 +60,80 @@ func launchTemplateName(options *launchTemplateOptions) string { if err != nil { zap.S().Panicf("hashing launch template, %w", err) } - return fmt.Sprintf(launchTemplateNameFormat, options.Cluster.Name, options.Provisioner.Name, options.Provisioner.Namespace, fmt.Sprint(hash)) + return fmt.Sprintf(launchTemplateNameFormat, options.Cluster.Name, fmt.Sprint(hash)) } // launchTemplateOptions is hashed and results in the creation of a real EC2 // LaunchTemplate. Do not change this struct without thinking through the impact // to the number of LaunchTemplates that will result from this change. type launchTemplateOptions struct { - Provisioner types.NamespacedName - Cluster v1alpha1.ClusterSpec - Architecture string - Labels map[string]string - Taints []v1.Taint + // Edge-triggered fields that will only change on kube events. + Cluster v1alpha1.ClusterSpec + UserData string + // Constraints Constraints + // Level-triggered fields that may change out of sync. + SecurityGroups []string + AMIID string } func (p *LaunchTemplateProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) (*LaunchTemplate, error) { - // If the customer specified a launch template then just use it + // 1. If the customer specified a launch template then just use it if result := constraints.GetLaunchTemplate(); result != nil { return result, nil } - options := launchTemplateOptions{ - Provisioner: types.NamespacedName{Name: provisioner.Name, Namespace: provisioner.Namespace}, - Cluster: *provisioner.Spec.Cluster, - Architecture: KubeToAWSArchitectures[*constraints.Architecture], - Labels: constraints.Labels, - Taints: constraints.Taints, - } - // See if we have a cached copy of the default one first, to avoid - // making an API call to EC2 - key, err := hashstructure.Hash(options, hashstructure.FormatV2, nil) + // 2. Get constrained AMI ID + amiID, err := p.amiProvider.Get(ctx, constraints) if err != nil { - return nil, fmt.Errorf("hashing launch template, %w", err) + return nil, err } - result := &LaunchTemplate{Version: aws.String(DefaultLaunchTemplateVersion)} - if cached, ok := p.cache.Get(fmt.Sprint(key)); ok { - result.Id = cached.(*ec2.LaunchTemplate).LaunchTemplateId - return result, nil + // 3. Get constrained security groups + securityGroups, err := p.getSecurityGroupIds(ctx, provisioner, constraints) + if err != nil { + return nil, err } - // Call EC2 to get launch template, creating if necessary - launchTemplate, err := p.getLaunchTemplate(ctx, &options) + // 4. Ensure the launch template exists, or create it + launchTemplate, err := p.ensureLaunchTemplate(ctx, &launchTemplateOptions{ + Cluster: *provisioner.Spec.Cluster, + UserData: p.getUserData(provisioner, constraints), + AMIID: amiID, + SecurityGroups: securityGroups, + }) if err != nil { return nil, err } - result.Id = launchTemplate.LaunchTemplateId - p.cache.Set(fmt.Sprint(key), launchTemplate, CacheTTL) - return result, nil + return &LaunchTemplate{ + Id: aws.StringValue(launchTemplate.LaunchTemplateId), + Version: fmt.Sprint(DefaultLaunchTemplateVersion), + }, nil } -// TODO, reconcile launch template if not equal to desired launch template (AMI upgrade, role changed, etc) -func (p *LaunchTemplateProvider) getLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { - describelaunchTemplateOutput, err := p.ec2api.DescribeLaunchTemplatesWithContext(ctx, &ec2.DescribeLaunchTemplatesInput{ - LaunchTemplateNames: []*string{aws.String(launchTemplateName(options))}, +func (p *LaunchTemplateProvider) ensureLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { + name := launchTemplateName(options) + output, err := p.ec2api.DescribeLaunchTemplatesWithContext(ctx, &ec2.DescribeLaunchTemplatesInput{ + LaunchTemplateNames: []*string{aws.String(name)}, }) + var launchTemplate *ec2.LaunchTemplate if aerr, ok := err.(awserr.Error); ok && aerr.Code() == "InvalidLaunchTemplateName.NotFoundException" { - return p.createLaunchTemplate(ctx, options) - } - if err != nil { + launchTemplate, err = p.createLaunchTemplate(ctx, options) + if err != nil { + return nil, fmt.Errorf("creating launch template, %w", err) + } + } else if err != nil { return nil, fmt.Errorf("describing launch templates, %w", err) - } - if length := len(describelaunchTemplateOutput.LaunchTemplates); length > 1 { + } else if length := len(output.LaunchTemplates); length > 1 { return nil, fmt.Errorf("expected to find one launch template, but found %d", length) + } else { + launchTemplate = output.LaunchTemplates[0] } - launchTemplate := describelaunchTemplateOutput.LaunchTemplates[0] - zap.S().Debugf("Successfully discovered launch template %s for %s/%s", *launchTemplate.LaunchTemplateName, options.Provisioner.Name, options.Provisioner.Namespace) + p.cache.Set(name, launchTemplate, CacheTTL) + zap.S().Debugf("Successfully discovered launch template %s", name) return launchTemplate, nil } func (p *LaunchTemplateProvider) createLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { - securityGroupIds, err := p.getSecurityGroupIds(ctx, options.Cluster.Name) - if err != nil { - return nil, fmt.Errorf("getting security groups, %w", err) - } - amiID, err := p.getAMIID(ctx, options.Architecture) - if err != nil { - return nil, fmt.Errorf("getting AMI ID, %w", err) - } - zap.S().Debugf("Successfully discovered AMI ID %s for architecture %s", *amiID, options.Architecture) - userData, err := p.getUserData(options) - if err != nil { - return nil, fmt.Errorf("getting user data, %w", err) - } - output, err := p.ec2api.CreateLaunchTemplate(&ec2.CreateLaunchTemplateInput{ LaunchTemplateName: aws.String(launchTemplateName(options)), LaunchTemplateData: &ec2.RequestLaunchTemplateData{ @@ -170,57 +153,38 @@ func (p *LaunchTemplateProvider) createLaunchTemplate(ctx context.Context, optio }, }, }}, - SecurityGroupIds: securityGroupIds, - UserData: userData, - ImageId: amiID, + SecurityGroupIds: aws.StringSlice(options.SecurityGroups), + UserData: aws.String(options.UserData), + ImageId: aws.String(options.AMIID), }, }) if err != nil { - return nil, fmt.Errorf("creating launch template, %w", err) + return nil, err } zap.S().Debugf("Successfully created default launch template, %s", *output.LaunchTemplate.LaunchTemplateName) return output.LaunchTemplate, nil } -func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, clusterName string) ([]*string, error) { - securityGroupIds := []*string{} - securityGroups, err := p.securityGroupProvider.Get(ctx, clusterName) +func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) ([]string, error) { + securityGroupIds := []string{} + securityGroups, err := p.securityGroupProvider.Get(ctx, provisioner, constraints) if err != nil { - return nil, err + return nil, fmt.Errorf("getting security group ids, %w", err) } for _, securityGroup := range securityGroups { - securityGroupIds = append(securityGroupIds, securityGroup.GroupId) + securityGroupIds = append(securityGroupIds, aws.StringValue(securityGroup.GroupId)) } return securityGroupIds, nil } -func (p *LaunchTemplateProvider) getAMIID(ctx context.Context, arch string) (*string, error) { - version, err := p.kubeServerVersion() - if err != nil { - return nil, fmt.Errorf("kube server version, %w", err) - } - paramOutput, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{ - Name: aws.String(fmt.Sprintf("/aws/service/bottlerocket/aws-k8s-%s/%s/latest/image_id", version, arch)), - }) - if err != nil { - return nil, fmt.Errorf("getting ssm parameter, %w", err) - } - return paramOutput.Parameter.Value, nil -} - -func (p *LaunchTemplateProvider) getUserData(options *launchTemplateOptions) (*string, error) { +func (p *LaunchTemplateProvider) getUserData(provisioner *v1alpha1.Provisioner, constraints *Constraints) string { t := template.Must(template.New("userData").Parse(bottlerocketUserData)) var userData bytes.Buffer - if err := t.Execute(&userData, options); err != nil { - return nil, err - } - return aws.String(base64.StdEncoding.EncodeToString(userData.Bytes())), nil -} - -func (p *LaunchTemplateProvider) kubeServerVersion() (string, error) { - version, err := p.clientSet.Discovery().ServerVersion() - if err != nil { - return "", err + if err := t.Execute(&userData, struct { + Constraints *Constraints + Cluster v1alpha1.ClusterSpec + }{constraints, *provisioner.Spec.Cluster}); err != nil { + panic(fmt.Sprintf("Parsing user data from %v, %v, %s", provisioner, constraints, err.Error())) } - return fmt.Sprintf("%s.%s", version.Major, strings.TrimSuffix(version.Minor, "+")), nil + return base64.StdEncoding.EncodeToString(userData.Bytes()) } diff --git a/pkg/cloudprovider/aws/securitygroups.go b/pkg/cloudprovider/aws/securitygroups.go index 0e305debeceb..821d0b615f68 100644 --- a/pkg/cloudprovider/aws/securitygroups.go +++ b/pkg/cloudprovider/aws/securitygroups.go @@ -17,9 +17,12 @@ package aws import ( "context" "fmt" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" + "github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils" "github.com/patrickmn/go-cache" "go.uber.org/zap" ) @@ -36,26 +39,47 @@ func NewSecurityGroupProvider(ec2api ec2iface.EC2API) *SecurityGroupProvider { } } -func (s *SecurityGroupProvider) Get(ctx context.Context, clusterName string) ([]*ec2.SecurityGroup, error) { - if securityGroups, ok := s.cache.Get(clusterName); ok { - return securityGroups.([]*ec2.SecurityGroup), nil +func (s *SecurityGroupProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) ([]*ec2.SecurityGroup, error) { + // 1. Get Security Groups + securityGroups, err := s.getSecurityGroups(ctx, provisioner.Spec.Cluster.Name) + if err != nil { + return nil, err } - return s.getSecurityGroups(ctx, clusterName) + // 2. Filter by subnet name if constrained + if name := constraints.GetSecurityGroupName(); name != nil { + securityGroups = filterSecurityGroups(utils.ByName(aws.StringValue(name)), securityGroups) + } + // 3. Filter by security group tag key if constrained + if name := constraints.GetSecurityGroupTagKey(); name != nil { + securityGroups = filterSecurityGroups(utils.ByTagKey(aws.StringValue(name)), securityGroups) + } + return securityGroups, nil } func (s *SecurityGroupProvider) getSecurityGroups(ctx context.Context, clusterName string) ([]*ec2.SecurityGroup, error) { - describeSecurityGroupOutput, err := s.ec2api.DescribeSecurityGroupsWithContext(ctx, &ec2.DescribeSecurityGroupsInput{ + if securityGroups, ok := s.cache.Get(clusterName); ok { + return securityGroups.([]*ec2.SecurityGroup), nil + } + output, err := s.ec2api.DescribeSecurityGroupsWithContext(ctx, &ec2.DescribeSecurityGroupsInput{ Filters: []*ec2.Filter{{ - Name: aws.String("tag-key"), + Name: aws.String("tag-key"), // Security Groups must be tagged for the cluster Values: []*string{aws.String(fmt.Sprintf(ClusterTagKeyFormat, clusterName))}, }}, }) if err != nil { return nil, fmt.Errorf("describing security groups with tag key %s, %w", fmt.Sprintf(ClusterTagKeyFormat, clusterName), err) } + s.cache.Set(clusterName, output.SecurityGroups, CacheTTL) + zap.S().Debugf("Successfully discovered %d security groups for cluster %s", len(output.SecurityGroups), clusterName) + return output.SecurityGroups, nil +} - securityGroups := describeSecurityGroupOutput.SecurityGroups - s.cache.Set(clusterName, securityGroups, CacheTTL) - zap.S().Debugf("Successfully discovered %d security groups for cluster %s", len(securityGroups), clusterName) - return securityGroups, nil +func filterSecurityGroups(predicate func([]*ec2.Tag) bool, securityGroups []*ec2.SecurityGroup) []*ec2.SecurityGroup { + result := []*ec2.SecurityGroup{} + for _, securityGroup := range securityGroups { + if predicate(securityGroup.Tags) { + result = append(result, securityGroup) + } + } + return result } diff --git a/pkg/cloudprovider/aws/subnets.go b/pkg/cloudprovider/aws/subnets.go index f478e9336ab9..3bc6f6baf77b 100644 --- a/pkg/cloudprovider/aws/subnets.go +++ b/pkg/cloudprovider/aws/subnets.go @@ -22,6 +22,7 @@ import ( "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" + "github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils" "github.com/patrickmn/go-cache" "go.uber.org/zap" ) @@ -46,11 +47,11 @@ func (s *SubnetProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisio } // 2. Filter by subnet name if constrained if name := constraints.GetSubnetName(); name != nil { - subnets = filter(byName(aws.StringValue(name)), subnets) + subnets = filterSubnets(utils.ByName(aws.StringValue(name)), subnets) } // 3. Filter by subnet tag key if constrained if tagKey := constraints.GetSubnetTagKey(); tagKey != nil { - subnets = filter(byTagKey(*tagKey), subnets) + subnets = filterSubnets(utils.ByTagKey(*tagKey), subnets) } return subnets, nil } @@ -66,39 +67,17 @@ func (s *SubnetProvider) getSubnets(ctx context.Context, provisioner *v1alpha1.P if err != nil { return nil, fmt.Errorf("describing subnets, %w", err) } - zap.S().Debugf("Successfully discovered %d subnets for cluster %s", len(output.Subnets), provisioner.Spec.Cluster.Name) s.cache.Set(provisioner.Spec.Cluster.Name, output.Subnets, CacheTTL) + zap.S().Debugf("Successfully discovered %d subnets for cluster %s", len(output.Subnets), provisioner.Spec.Cluster.Name) return output.Subnets, nil } -func filter(predicate func(*ec2.Subnet) bool, subnets []*ec2.Subnet) []*ec2.Subnet { +func filterSubnets(predicate func([]*ec2.Tag) bool, subnets []*ec2.Subnet) []*ec2.Subnet { result := []*ec2.Subnet{} for _, subnet := range subnets { - if predicate(subnet) { + if predicate(subnet.Tags) { result = append(result, subnet) } } return result } - -func byName(name string) func(*ec2.Subnet) bool { - return func(subnet *ec2.Subnet) bool { - for _, tag := range subnet.Tags { - if aws.StringValue(tag.Key) == "Name" { - return aws.StringValue(tag.Value) == name - } - } - return false - } -} - -func byTagKey(tagKey string) func(*ec2.Subnet) bool { - return func(subnet *ec2.Subnet) bool { - for _, tag := range subnet.Tags { - if aws.StringValue(tag.Key) == tagKey { - return true - } - } - return false - } -} diff --git a/pkg/cloudprovider/aws/suite_test.go b/pkg/cloudprovider/aws/suite_test.go index f502f079f96d..59234618ebfe 100644 --- a/pkg/cloudprovider/aws/suite_test.go +++ b/pkg/cloudprovider/aws/suite_test.go @@ -59,8 +59,7 @@ var env = test.NewEnvironment(func(e *test.Environment) { ec2api: fakeEC2API, cache: launchTemplateCache, securityGroupProvider: NewSecurityGroupProvider(fakeEC2API), - ssm: &fake.SSMAPI{}, - clientSet: clientSet, + amiProvider: NewAMIProvider(&fake.SSMAPI{}, clientSet), }, subnetProvider: NewSubnetProvider(fakeEC2API), instanceTypeProvider: NewInstanceTypeProvider(fakeEC2API), diff --git a/pkg/cloudprovider/aws/utils/tags.go b/pkg/cloudprovider/aws/utils/tags.go new file mode 100644 index 000000000000..5b10a301ad4d --- /dev/null +++ b/pkg/cloudprovider/aws/utils/tags.go @@ -0,0 +1,28 @@ +package utils + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" +) + +func ByName(name string) func([]*ec2.Tag) bool { + return func(tags []*ec2.Tag) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == "Name" { + return aws.StringValue(tag.Value) == name + } + } + return false + } +} + +func ByTagKey(tagKey string) func([]*ec2.Tag) bool { + return func(tags []*ec2.Tag) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == tagKey { + return true + } + } + return false + } +}