diff --git a/pkg/cloudprovider/aws/ami.go b/pkg/cloudprovider/aws/ami.go new file mode 100644 index 000000000000..0250df721ae3 --- /dev/null +++ b/pkg/cloudprovider/aws/ami.go @@ -0,0 +1,51 @@ +package aws + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/patrickmn/go-cache" + "k8s.io/client-go/kubernetes" +) + +type AMIProvider struct { + cache *cache.Cache + ssm ssmiface.SSMAPI + clientSet *kubernetes.Clientset +} + +func NewAMIProvider(ssm ssmiface.SSMAPI, clientSet *kubernetes.Clientset) *AMIProvider { + return &AMIProvider{ + ssm: ssm, + clientSet: clientSet, + cache: cache.New(CacheTTL, CacheCleanupInterval), + } +} + +func (p *AMIProvider) Get(ctx context.Context, constraints *Constraints) (string, error) { + version, err := p.kubeServerVersion() + if err != nil { + return "", fmt.Errorf("kube server version, %w", err) + } + name := fmt.Sprintf("/aws/service/bottlerocket/aws-k8s-%s/%s/latest/image_id", version, KubeToAWSArchitectures[*constraints.Architecture]) + if id, ok := p.cache.Get(name); ok { + return id.(string), nil + } + output, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{Name: aws.String(name)}) + if err != nil { + return "", fmt.Errorf("getting ssm parameter, %w", err) + } + return aws.StringValue(output.Parameter.Value), nil +} + +func (p *AMIProvider) kubeServerVersion() (string, error) { + version, err := p.clientSet.Discovery().ServerVersion() + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", version.Major, strings.TrimSuffix(version.Minor, "+")), nil +} diff --git a/pkg/cloudprovider/aws/cloudprovider.go b/pkg/cloudprovider/aws/cloudprovider.go index 39df68fbeb32..5bc5ee1e7637 100644 --- a/pkg/cloudprovider/aws/cloudprovider.go +++ b/pkg/cloudprovider/aws/cloudprovider.go @@ -28,7 +28,6 @@ import ( "github.com/awslabs/karpenter/pkg/cloudprovider" 
"github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils" "github.com/awslabs/karpenter/pkg/utils/project" - "github.com/patrickmn/go-cache" "go.uber.org/zap" v1 "k8s.io/api/core/v1" "knative.dev/pkg/apis" @@ -76,13 +75,11 @@ func NewCloudProvider(options cloudprovider.Options) *CloudProvider { ec2api := ec2.New(sess) return &CloudProvider{ nodeAPI: &NodeFactory{ec2api: ec2api}, - launchTemplateProvider: &LaunchTemplateProvider{ - ec2api: ec2api, - cache: cache.New(CacheTTL, CacheCleanupInterval), - securityGroupProvider: NewSecurityGroupProvider(ec2api), - ssm: ssm.New(sess), - clientSet: options.ClientSet, - }, + launchTemplateProvider: NewLaunchTemplateProvider( + ec2api, + NewAMIProvider(ssm.New(sess), options.ClientSet), + NewSecurityGroupProvider(ec2api), + ), subnetProvider: NewSubnetProvider(ec2api), instanceTypeProvider: NewInstanceTypeProvider(ec2api), instanceProvider: &InstanceProvider{ec2api: ec2api}, diff --git a/pkg/cloudprovider/aws/constraints.go b/pkg/cloudprovider/aws/constraints.go index 929687369a77..9af53f2b3abb 100644 --- a/pkg/cloudprovider/aws/constraints.go +++ b/pkg/cloudprovider/aws/constraints.go @@ -38,12 +38,16 @@ var ( LaunchTemplateVersionLabel = AWSLabelPrefix + "launch-template-version" SubnetNameLabel = AWSLabelPrefix + "subnet-name" SubnetTagKeyLabel = AWSLabelPrefix + "subnet-tag-key" + SecurityGroupNameLabel = AWSLabelPrefix + "security-group-name" + SecurityGroupTagKeyLabel = AWSLabelPrefix + "security-group-tag-key" AllowedLabels = []string{ CapacityTypeLabel, LaunchTemplateIdLabel, LaunchTemplateVersionLabel, SubnetNameLabel, SubnetTagKeyLabel, + SecurityGroupNameLabel, + SecurityGroupTagKeyLabel, } AWSToKubeArchitectures = map[string]string{ "x86_64": v1alpha1.ArchitectureAmd64, @@ -66,8 +70,8 @@ func (c *Constraints) GetCapacityType() string { } type LaunchTemplate struct { - Id *string - Version *string + Id string + Version string } func (c *Constraints) GetLaunchTemplate() *LaunchTemplate { @@ -80,25 +84,41 @@ func (c 
*Constraints) GetLaunchTemplate() *LaunchTemplate { version = DefaultLaunchTemplateVersion } return &LaunchTemplate{ - Id: &id, - Version: &version, + Id: id, + Version: version, } } func (c *Constraints) GetSubnetName() *string { - subnetName, ok := c.Labels[SubnetNameLabel] + name, ok := c.Labels[SubnetNameLabel] if !ok { return nil } - return aws.String(subnetName) + return aws.String(name) } func (c *Constraints) GetSubnetTagKey() *string { - subnetTag, ok := c.Labels[SubnetTagKeyLabel] + tag, ok := c.Labels[SubnetTagKeyLabel] if !ok { return nil } - return aws.String(subnetTag) + return aws.String(tag) +} + +func (c *Constraints) GetSecurityGroupName() *string { + name, ok := c.Labels[SecurityGroupNameLabel] + if !ok { + return nil + } + return aws.String(name) +} + +func (c *Constraints) GetSecurityGroupTagKey() *string { + tag, ok := c.Labels[SecurityGroupTagKeyLabel] + if !ok { + return nil + } + return aws.String(tag) } func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) { diff --git a/pkg/cloudprovider/aws/fake/ec2api.go b/pkg/cloudprovider/aws/fake/ec2api.go index 0d87e854ceb5..f15752e96d2a 100644 --- a/pkg/cloudprovider/aws/fake/ec2api.go +++ b/pkg/cloudprovider/aws/fake/ec2api.go @@ -20,6 +20,7 @@ import ( "github.com/Pallinder/go-randomdata" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" @@ -29,7 +30,6 @@ import ( // EC2Behavior must be reset between tests otherwise tests will // pollute each other. 
type EC2Behavior struct { - CreateFleetOutput *ec2.CreateFleetOutput DescribeInstancesOutput *ec2.DescribeInstancesOutput DescribeLaunchTemplatesOutput *ec2.DescribeLaunchTemplatesOutput DescribeSubnetsOutput *ec2.DescribeSubnetsOutput @@ -37,9 +37,10 @@ type EC2Behavior struct { DescribeInstanceTypesOutput *ec2.DescribeInstanceTypesOutput DescribeInstanceTypeOfferingsOutput *ec2.DescribeInstanceTypeOfferingsOutput DescribeAvailabilityZonesOutput *ec2.DescribeAvailabilityZonesOutput - WantErr error - CalledWithCreateFleetInput []ec2.CreateFleetInput + CalledWithCreateFleetInput []*ec2.CreateFleetInput + CalledWithCreateLaunchTemplateInput []*ec2.CreateLaunchTemplateInput Instances []*ec2.Instance + LaunchTemplates []*ec2.LaunchTemplate } type EC2API struct { @@ -54,13 +55,7 @@ func (e *EC2API) Reset() { } func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) { - e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, *input) - if e.WantErr != nil { - return nil, e.WantErr - } - if e.CreateFleetOutput != nil { - return e.CreateFleetOutput, nil - } + e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, input) if input.LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateId == nil && input.LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName == nil { return nil, fmt.Errorf("missing launch template id or name") @@ -74,10 +69,14 @@ func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFl return &ec2.CreateFleetOutput{Instances: []*ec2.CreateFleetInstance{{InstanceIds: []*string{instance.InstanceId}}}}, nil } +func (e *EC2API) CreateLaunchTemplateWithContext(ctx context.Context, input *ec2.CreateLaunchTemplateInput, options ...request.Option) (*ec2.CreateLaunchTemplateOutput, error) { + e.CalledWithCreateLaunchTemplateInput = append(e.CalledWithCreateLaunchTemplateInput, input) + launchTemplate 
:= &ec2.LaunchTemplate{LaunchTemplateName: input.LaunchTemplateName, LaunchTemplateId: aws.String("test-launch-template-id")} + e.LaunchTemplates = append(e.LaunchTemplates, launchTemplate) + return &ec2.CreateLaunchTemplateOutput{LaunchTemplate: launchTemplate}, nil +} + func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) { - if e.WantErr != nil { - return nil, e.WantErr - } if e.DescribeInstancesOutput != nil { return e.DescribeInstancesOutput, nil } @@ -86,23 +85,25 @@ func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInst }, nil } -func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) { - if e.WantErr != nil { - return nil, e.WantErr - } +func (e *EC2API) DescribeLaunchTemplatesWithContext(ctx context.Context, input *ec2.DescribeLaunchTemplatesInput, options ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) { if e.DescribeLaunchTemplatesOutput != nil { return e.DescribeLaunchTemplatesOutput, nil } - return &ec2.DescribeLaunchTemplatesOutput{LaunchTemplates: []*ec2.LaunchTemplate{{ - LaunchTemplateName: aws.String("test-launch-template-name"), - LaunchTemplateId: aws.String("test-launch-template-id"), - }}}, nil + output := &ec2.DescribeLaunchTemplatesOutput{} + for _, wanted := range input.LaunchTemplateNames { + for _, launchTemplate := range e.LaunchTemplates { + if launchTemplate.LaunchTemplateName == wanted { + output.LaunchTemplates = append(output.LaunchTemplates, launchTemplate) + } + } + } + if len(output.LaunchTemplates) == 0 { + return nil, awserr.New("InvalidLaunchTemplateName.NotFoundException", "not found", nil) + } + return output, nil } func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) { - if e.WantErr != nil { 
- return nil, e.WantErr - } if e.DescribeSubnetsOutput != nil { return e.DescribeSubnetsOutput, nil } @@ -117,19 +118,17 @@ func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnet } func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { - if e.WantErr != nil { - return nil, e.WantErr - } if e.DescribeSecurityGroupsOutput != nil { return e.DescribeSecurityGroupsOutput, nil } - return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil + return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupId: aws.String("test-security-group-1"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-1")}}}, + {GroupId: aws.String("test-security-group-2"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-2")}}}, + {GroupId: aws.String("test-security-group-3"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, {Key: aws.String("TestTag")}}}, + }}, nil } func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { - if e.WantErr != nil { - return nil, e.WantErr - } if e.DescribeAvailabilityZonesOutput != nil { return e.DescribeAvailabilityZonesOutput, nil } @@ -141,9 +140,6 @@ func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.Desc } func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error { - if e.WantErr != nil { - return e.WantErr - } if e.DescribeInstanceTypesOutput != nil { fn(e.DescribeInstanceTypesOutput, false) return nil @@ -267,9 +263,6 @@ func (e *EC2API) 
DescribeInstanceTypesPagesWithContext(ctx context.Context, inpu } func (e *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error { - if e.WantErr != nil { - return e.WantErr - } if e.DescribeInstanceTypeOfferingsOutput != nil { fn(e.DescribeInstanceTypeOfferingsOutput, false) return nil diff --git a/pkg/cloudprovider/aws/fake/sqsqueue.go b/pkg/cloudprovider/aws/fake/sqsqueue.go deleted file mode 100644 index 0e18bcf29172..000000000000 --- a/pkg/cloudprovider/aws/fake/sqsqueue.go +++ /dev/null @@ -1,35 +0,0 @@ -/* -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package fake - -import ( - "github.com/aws/aws-sdk-go/service/sqs" - "github.com/aws/aws-sdk-go/service/sqs/sqsiface" -) - -type SQSAPI struct { - sqsiface.SQSAPI - QueueUrlOutput sqs.GetQueueUrlOutput - QueueAttributeOutput sqs.GetQueueAttributesOutput - WantErr error -} - -func (m SQSAPI) GetQueueUrl(*sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { - return &m.QueueUrlOutput, m.WantErr -} - -func (m SQSAPI) GetQueueAttributes(*sqs.GetQueueAttributesInput) (*sqs.GetQueueAttributesOutput, error) { - return &m.QueueAttributeOutput, m.WantErr -} diff --git a/pkg/cloudprovider/aws/instance.go b/pkg/cloudprovider/aws/instance.go index 6f005a9a9f46..715bdcf3af72 100644 --- a/pkg/cloudprovider/aws/instance.go +++ b/pkg/cloudprovider/aws/instance.go @@ -87,8 +87,8 @@ func (p *InstanceProvider) Create(ctx context.Context, }, LaunchTemplateConfigs: []*ec2.FleetLaunchTemplateConfigRequest{{ LaunchTemplateSpecification: &ec2.FleetLaunchTemplateSpecificationRequest{ - LaunchTemplateId: launchTemplate.Id, - Version: launchTemplate.Version, + LaunchTemplateId: aws.String(launchTemplate.Id), + Version: aws.String(launchTemplate.Version), }, Overrides: overrides, }}, diff --git a/pkg/cloudprovider/aws/launchtemplate.go b/pkg/cloudprovider/aws/launchtemplate.go index cbea03f7d302..c35e916883d9 100644 --- a/pkg/cloudprovider/aws/launchtemplate.go +++ b/pkg/cloudprovider/aws/launchtemplate.go @@ -19,47 +19,49 @@ import ( "context" "encoding/base64" "fmt" - "strings" "text/template" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "go.uber.org/zap" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/types" - "k8s.io/client-go/kubernetes" 
) const ( - launchTemplateNameFormat = "Karpenter-%s/%s/%s-%s" + launchTemplateNameFormat = "Karpenter-%s-%s" bottlerocketUserData = ` [settings.kubernetes] api-server = "{{.Cluster.Endpoint}}" cluster-certificate = "{{.Cluster.CABundle}}" cluster-name = "{{.Cluster.Name}}" -{{if .Labels }}[settings.kubernetes.node-labels]{{ end }} -{{ range $Key, $Value := .Labels }}"{{ $Key }}" = "{{ $Value }}" +{{if .Constraints.Labels }}[settings.kubernetes.node-labels]{{ end }} +{{ range $Key, $Value := .Constraints.Labels }}"{{ $Key }}" = "{{ $Value }}" {{ end }} -{{if .Taints }}[settings.kubernetes.node-taints]{{ end }} -{{ range $Taint := .Taints }}"{{ $Taint.Key }}" = "{{ $Taint.Value}}:{{ $Taint.Effect }}" +{{if .Constraints.Taints }}[settings.kubernetes.node-taints]{{ end }} +{{ range $Taint := .Constraints.Taints }}"{{ $Taint.Key }}" = "{{ $Taint.Value}}:{{ $Taint.Effect }}" {{ end }} ` ) type LaunchTemplateProvider struct { ec2api ec2iface.EC2API - cache *cache.Cache + amiProvider *AMIProvider securityGroupProvider *SecurityGroupProvider - ssm ssmiface.SSMAPI - clientSet *kubernetes.Clientset + cache *cache.Cache +} + +func NewLaunchTemplateProvider(ec2api ec2iface.EC2API, amiProvider *AMIProvider, securityGroupProvider *SecurityGroupProvider) *LaunchTemplateProvider { + return &LaunchTemplateProvider{ + ec2api: ec2api, + amiProvider: amiProvider, + securityGroupProvider: securityGroupProvider, + cache: cache.New(CacheTTL, CacheCleanupInterval), + } } func launchTemplateName(options *launchTemplateOptions) string { @@ -67,91 +69,80 @@ func launchTemplateName(options *launchTemplateOptions) string { if err != nil { zap.S().Panicf("hashing launch template, %w", err) } - return fmt.Sprintf(launchTemplateNameFormat, options.Cluster.Name, options.Provisioner.Name, options.Provisioner.Namespace, fmt.Sprint(hash)) + return fmt.Sprintf(launchTemplateNameFormat, options.Cluster.Name, fmt.Sprint(hash)) } // launchTemplateOptions is hashed and results in the creation of a real 
EC2 // LaunchTemplate. Do not change this struct without thinking through the impact // to the number of LaunchTemplates that will result from this change. type launchTemplateOptions struct { - Provisioner types.NamespacedName - Cluster v1alpha1.ClusterSpec - Architecture string - Labels map[string]string - Taints []v1.Taint + // Edge-triggered fields that will only change on kube events. + Cluster v1alpha1.ClusterSpec + UserData string + // Level-triggered fields that may change out of sync. + SecurityGroups []string + AMIID string } func (p *LaunchTemplateProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) (*LaunchTemplate, error) { - // If the customer specified a launch template then just use it + // 1. If the customer specified a launch template then just use it if result := constraints.GetLaunchTemplate(); result != nil { return result, nil } - options := launchTemplateOptions{ - Provisioner: types.NamespacedName{Name: provisioner.Name, Namespace: provisioner.Namespace}, - Cluster: *provisioner.Spec.Cluster, - Architecture: KubeToAWSArchitectures[*constraints.Architecture], - Labels: constraints.Labels, - Taints: constraints.Taints, - } - // See if we have a cached copy of the default one first, to avoid - // making an API call to EC2 - key, err := hashstructure.Hash(options, hashstructure.FormatV2, nil) + // 2. Get constrained AMI ID + amiID, err := p.amiProvider.Get(ctx, constraints) if err != nil { - return nil, fmt.Errorf("hashing launch template, %w", err) + return nil, err } - result := &LaunchTemplate{Version: aws.String(DefaultLaunchTemplateVersion)} - if cached, ok := p.cache.Get(fmt.Sprint(key)); ok { - result.Id = cached.(*ec2.LaunchTemplate).LaunchTemplateId - return result, nil + // 3. 
Get constrained security groups + securityGroups, err := p.getSecurityGroupIds(ctx, provisioner, constraints) + if err != nil { + return nil, err } - // Call EC2 to get launch template, creating if necessary - launchTemplate, err := p.getLaunchTemplate(ctx, &options) + // 4. Ensure the launch template exists, or create it + launchTemplate, err := p.ensureLaunchTemplate(ctx, &launchTemplateOptions{ + Cluster: *provisioner.Spec.Cluster, + UserData: p.getUserData(provisioner, constraints), + AMIID: amiID, + SecurityGroups: securityGroups, + }) if err != nil { return nil, err } - result.Id = launchTemplate.LaunchTemplateId - p.cache.Set(fmt.Sprint(key), launchTemplate, CacheTTL) - return result, nil + return &LaunchTemplate{ + Id: aws.StringValue(launchTemplate.LaunchTemplateId), + Version: fmt.Sprint(DefaultLaunchTemplateVersion), + }, nil } -// TODO, reconcile launch template if not equal to desired launch template (AMI upgrade, role changed, etc) -func (p *LaunchTemplateProvider) getLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { - describelaunchTemplateOutput, err := p.ec2api.DescribeLaunchTemplatesWithContext(ctx, &ec2.DescribeLaunchTemplatesInput{ - LaunchTemplateNames: []*string{aws.String(launchTemplateName(options))}, +func (p *LaunchTemplateProvider) ensureLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { + name := launchTemplateName(options) + output, err := p.ec2api.DescribeLaunchTemplatesWithContext(ctx, &ec2.DescribeLaunchTemplatesInput{ + LaunchTemplateNames: []*string{aws.String(name)}, }) + var launchTemplate *ec2.LaunchTemplate if aerr, ok := err.(awserr.Error); ok && aerr.Code() == "InvalidLaunchTemplateName.NotFoundException" { - return p.createLaunchTemplate(ctx, options) - } - if err != nil { + launchTemplate, err = p.createLaunchTemplate(ctx, options) + if err != nil { + return nil, fmt.Errorf("creating launch template, %w", err) + } + } else if 
err != nil { return nil, fmt.Errorf("describing launch templates, %w", err) + } else if len(output.LaunchTemplates) != 1 { + return nil, fmt.Errorf("expected to find one launch template, but found %d", len(output.LaunchTemplates)) + } else { + zap.S().Debugf("Successfully discovered launch template %s", name) + launchTemplate = output.LaunchTemplates[0] } - if length := len(describelaunchTemplateOutput.LaunchTemplates); length > 1 { - return nil, fmt.Errorf("expected to find one launch template, but found %d", length) - } - launchTemplate := describelaunchTemplateOutput.LaunchTemplates[0] - zap.S().Debugf("Successfully discovered launch template %s for %s/%s", *launchTemplate.LaunchTemplateName, options.Provisioner.Name, options.Provisioner.Namespace) + p.cache.Set(name, launchTemplate, CacheTTL) return launchTemplate, nil } func (p *LaunchTemplateProvider) createLaunchTemplate(ctx context.Context, options *launchTemplateOptions) (*ec2.LaunchTemplate, error) { - securityGroupIds, err := p.getSecurityGroupIds(ctx, options.Cluster.Name) - if err != nil { - return nil, fmt.Errorf("getting security groups, %w", err) - } - amiID, err := p.getAMIID(ctx, options.Architecture) - if err != nil { - return nil, fmt.Errorf("getting AMI ID, %w", err) - } - zap.S().Debugf("Successfully discovered AMI ID %s for architecture %s", *amiID, options.Architecture) - userData, err := p.getUserData(options) - if err != nil { - return nil, fmt.Errorf("getting user data, %w", err) - } - - output, err := p.ec2api.CreateLaunchTemplate(&ec2.CreateLaunchTemplateInput{ + output, err := p.ec2api.CreateLaunchTemplateWithContext(ctx, &ec2.CreateLaunchTemplateInput{ LaunchTemplateName: aws.String(launchTemplateName(options)), LaunchTemplateData: &ec2.RequestLaunchTemplateData{ IamInstanceProfile: &ec2.LaunchTemplateIamInstanceProfileSpecificationRequest{ @@ -170,57 +161,38 @@ func (p *LaunchTemplateProvider) createLaunchTemplate(ctx context.Context, optio }, }, }}, - SecurityGroupIds: 
securityGroupIds, - UserData: userData, - ImageId: amiID, + SecurityGroupIds: aws.StringSlice(options.SecurityGroups), + UserData: aws.String(options.UserData), + ImageId: aws.String(options.AMIID), }, }) if err != nil { - return nil, fmt.Errorf("creating launch template, %w", err) + return nil, err } zap.S().Debugf("Successfully created default launch template, %s", *output.LaunchTemplate.LaunchTemplateName) return output.LaunchTemplate, nil } -func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, clusterName string) ([]*string, error) { - securityGroupIds := []*string{} - securityGroups, err := p.securityGroupProvider.Get(ctx, clusterName) +func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) ([]string, error) { + securityGroupIds := []string{} + securityGroups, err := p.securityGroupProvider.Get(ctx, provisioner, constraints) if err != nil { - return nil, err + return nil, fmt.Errorf("getting security group ids, %w", err) } for _, securityGroup := range securityGroups { - securityGroupIds = append(securityGroupIds, securityGroup.GroupId) + securityGroupIds = append(securityGroupIds, aws.StringValue(securityGroup.GroupId)) } return securityGroupIds, nil } -func (p *LaunchTemplateProvider) getAMIID(ctx context.Context, arch string) (*string, error) { - version, err := p.kubeServerVersion() - if err != nil { - return nil, fmt.Errorf("kube server version, %w", err) - } - paramOutput, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{ - Name: aws.String(fmt.Sprintf("/aws/service/bottlerocket/aws-k8s-%s/%s/latest/image_id", version, arch)), - }) - if err != nil { - return nil, fmt.Errorf("getting ssm parameter, %w", err) - } - return paramOutput.Parameter.Value, nil -} - -func (p *LaunchTemplateProvider) getUserData(options *launchTemplateOptions) (*string, error) { +func (p *LaunchTemplateProvider) getUserData(provisioner *v1alpha1.Provisioner, 
constraints *Constraints) string { t := template.Must(template.New("userData").Parse(bottlerocketUserData)) var userData bytes.Buffer - if err := t.Execute(&userData, options); err != nil { - return nil, err - } - return aws.String(base64.StdEncoding.EncodeToString(userData.Bytes())), nil -} - -func (p *LaunchTemplateProvider) kubeServerVersion() (string, error) { - version, err := p.clientSet.Discovery().ServerVersion() - if err != nil { - return "", err + if err := t.Execute(&userData, struct { + Constraints *Constraints + Cluster v1alpha1.ClusterSpec + }{constraints, *provisioner.Spec.Cluster}); err != nil { + panic(fmt.Sprintf("Parsing user data from %v, %v, %s", provisioner, constraints, err.Error())) } - return fmt.Sprintf("%s.%s", version.Major, strings.TrimSuffix(version.Minor, "+")), nil + return base64.StdEncoding.EncodeToString(userData.Bytes()) } diff --git a/pkg/cloudprovider/aws/securitygroups.go b/pkg/cloudprovider/aws/securitygroups.go index 0e305debeceb..ac81a9eeeb9d 100644 --- a/pkg/cloudprovider/aws/securitygroups.go +++ b/pkg/cloudprovider/aws/securitygroups.go @@ -17,9 +17,12 @@ package aws import ( "context" "fmt" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" + "github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils/predicates" "github.com/patrickmn/go-cache" "go.uber.org/zap" ) @@ -36,26 +39,50 @@ func NewSecurityGroupProvider(ec2api ec2iface.EC2API) *SecurityGroupProvider { } } -func (s *SecurityGroupProvider) Get(ctx context.Context, clusterName string) ([]*ec2.SecurityGroup, error) { - if securityGroups, ok := s.cache.Get(clusterName); ok { - return securityGroups.([]*ec2.SecurityGroup), nil +func (s *SecurityGroupProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) ([]*ec2.SecurityGroup, error) { + // 1. 
Get Security Groups + securityGroups, err := s.getSecurityGroups(ctx, provisioner.Spec.Cluster.Name) + if err != nil { + return nil, err } - return s.getSecurityGroups(ctx, clusterName) + // 2. Filter by security group name if constrained + if name := constraints.GetSecurityGroupName(); name != nil { + securityGroups = filterSecurityGroups(securityGroups, withSecurityGroupTags(predicates.HasNameTag(*name))) + } + // 3. Filter by security group tag key if constrained + if tagKey := constraints.GetSecurityGroupTagKey(); tagKey != nil { + securityGroups = filterSecurityGroups(securityGroups, withSecurityGroupTags(predicates.HasTagKey(*tagKey))) + } + return securityGroups, nil } func (s *SecurityGroupProvider) getSecurityGroups(ctx context.Context, clusterName string) ([]*ec2.SecurityGroup, error) { - describeSecurityGroupOutput, err := s.ec2api.DescribeSecurityGroupsWithContext(ctx, &ec2.DescribeSecurityGroupsInput{ + if securityGroups, ok := s.cache.Get(clusterName); ok { + return securityGroups.([]*ec2.SecurityGroup), nil + } + output, err := s.ec2api.DescribeSecurityGroupsWithContext(ctx, &ec2.DescribeSecurityGroupsInput{ Filters: []*ec2.Filter{{ - Name: aws.String("tag-key"), + Name: aws.String("tag-key"), // Security Groups must be tagged for the cluster Values: []*string{aws.String(fmt.Sprintf(ClusterTagKeyFormat, clusterName))}, }}, }) if err != nil { return nil, fmt.Errorf("describing security groups with tag key %s, %w", fmt.Sprintf(ClusterTagKeyFormat, clusterName), err) } + s.cache.Set(clusterName, output.SecurityGroups, CacheTTL) + zap.S().Debugf("Successfully discovered %d security groups for cluster %s", len(output.SecurityGroups), clusterName) + return output.SecurityGroups, nil +} - securityGroups := describeSecurityGroupOutput.SecurityGroups - s.cache.Set(clusterName, securityGroups, CacheTTL) - zap.S().Debugf("Successfully discovered %d security groups for cluster %s", len(securityGroups), clusterName) - return securityGroups, nil +func 
filterSecurityGroups(securityGroups []*ec2.SecurityGroup, predicate func(securityGroup *ec2.SecurityGroup) bool) (result []*ec2.SecurityGroup) { + for _, securityGroup := range securityGroups { + if predicate(securityGroup) { + result = append(result, securityGroup) + } + } + return result +} + +func withSecurityGroupTags(predicate func([]*ec2.Tag) bool) func(securityGroup *ec2.SecurityGroup) bool { + return func(securityGroup *ec2.SecurityGroup) bool { return predicate(securityGroup.Tags) } } diff --git a/pkg/cloudprovider/aws/subnets.go b/pkg/cloudprovider/aws/subnets.go index acb615f24cfd..13afffd26ccd 100644 --- a/pkg/cloudprovider/aws/subnets.go +++ b/pkg/cloudprovider/aws/subnets.go @@ -22,6 +22,7 @@ import ( "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" + "github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils/predicates" "github.com/patrickmn/go-cache" "go.uber.org/zap" ) @@ -46,15 +47,15 @@ func (s *SubnetProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisio } // 2. Filter by subnet name if constrained if name := constraints.GetSubnetName(); name != nil { - subnets = filter(byName(aws.StringValue(name)), subnets) + subnets = filterSubnets(subnets, withSubnetTags(predicates.HasNameTag(*name))) } // 3. Filter by subnet tag key if constrained if tagKey := constraints.GetSubnetTagKey(); tagKey != nil { - subnets = filter(byTagKey(*tagKey), subnets) + subnets = filterSubnets(subnets, withSubnetTags(predicates.HasTagKey(*tagKey))) } // 4. 
Filter by zones if constrained if len(constraints.Zones) != 0 { - subnets = filter(byZones(constraints.Zones), subnets) + subnets = filterSubnets(subnets, withSubnetZone(predicates.WithinStrings(constraints.Zones))) } return subnets, nil } @@ -70,13 +71,12 @@ func (s *SubnetProvider) getSubnets(ctx context.Context, provisioner *v1alpha1.P if err != nil { return nil, fmt.Errorf("describing subnets, %w", err) } - zap.S().Debugf("Successfully discovered %d subnets for cluster %s", len(output.Subnets), provisioner.Spec.Cluster.Name) s.cache.Set(provisioner.Spec.Cluster.Name, output.Subnets, CacheTTL) + zap.S().Debugf("Successfully discovered %d subnets for cluster %s", len(output.Subnets), provisioner.Spec.Cluster.Name) return output.Subnets, nil } -func filter(predicate func(*ec2.Subnet) bool, subnets []*ec2.Subnet) []*ec2.Subnet { - result := []*ec2.Subnet{} +func filterSubnets(subnets []*ec2.Subnet, predicate func(subnet *ec2.Subnet) bool) (result []*ec2.Subnet) { for _, subnet := range subnets { if predicate(subnet) { result = append(result, subnet) @@ -85,35 +85,10 @@ func filter(predicate func(*ec2.Subnet) bool, subnets []*ec2.Subnet) []*ec2.Subn return result } -func byName(name string) func(*ec2.Subnet) bool { - return func(subnet *ec2.Subnet) bool { - for _, tag := range subnet.Tags { - if aws.StringValue(tag.Key) == "Name" { - return aws.StringValue(tag.Value) == name - } - } - return false - } -} - -func byTagKey(tagKey string) func(*ec2.Subnet) bool { - return func(subnet *ec2.Subnet) bool { - for _, tag := range subnet.Tags { - if aws.StringValue(tag.Key) == tagKey { - return true - } - } - return false - } +func withSubnetTags(predicate func([]*ec2.Tag) bool) func(subnet *ec2.Subnet) bool { + return func(subnet *ec2.Subnet) bool { return predicate(subnet.Tags) } } -func byZones(zones []string) func(*ec2.Subnet) bool { - return func(subnet *ec2.Subnet) bool { - for _, zone := range zones { - if aws.StringValue(subnet.AvailabilityZone) == zone { - return 
true - } - } - return false - } +func withSubnetZone(predicate func(string) bool) func(subnet *ec2.Subnet) bool { + return func(subnet *ec2.Subnet) bool { return predicate(aws.StringValue(subnet.AvailabilityZone)) } } diff --git a/pkg/cloudprovider/aws/suite_test.go b/pkg/cloudprovider/aws/suite_test.go index a5d6f2518ac9..7a3280d7981d 100644 --- a/pkg/cloudprovider/aws/suite_test.go +++ b/pkg/cloudprovider/aws/suite_test.go @@ -46,26 +46,22 @@ func TestAPIs(t *testing.T) { RunSpecs(t, "CloudProvider/AWS") } -var subnetCache = cache.New(CacheTTL, CacheCleanupInterval) var launchTemplateCache = cache.New(CacheTTL, CacheCleanupInterval) -var instanceProfileCache = cache.New(CacheTTL, CacheCleanupInterval) -var securityGroupCache = cache.New(CacheTTL, CacheCleanupInterval) var fakeEC2API *fake.EC2API var env = test.NewEnvironment(func(e *test.Environment) { clientSet := kubernetes.NewForConfigOrDie(e.Manager.GetConfig()) fakeEC2API = &fake.EC2API{} cloudProvider := &CloudProvider{ - nodeAPI: &NodeFactory{ec2api: fakeEC2API}, + nodeAPI: &NodeFactory{fakeEC2API}, launchTemplateProvider: &LaunchTemplateProvider{ - ec2api: fakeEC2API, - cache: launchTemplateCache, - securityGroupProvider: NewSecurityGroupProvider(fakeEC2API), - ssm: &fake.SSMAPI{}, - clientSet: clientSet, + fakeEC2API, + NewAMIProvider(&fake.SSMAPI{}, clientSet), + NewSecurityGroupProvider(fakeEC2API), + cache.New(CacheTTL, CacheCleanupInterval), }, subnetProvider: NewSubnetProvider(fakeEC2API), instanceTypeProvider: NewInstanceTypeProvider(fakeEC2API), - instanceProvider: &InstanceProvider{ec2api: fakeEC2API}, + instanceProvider: &InstanceProvider{fakeEC2API}, } registry.RegisterOrDie(cloudProvider) e.Manager.RegisterControllers( @@ -109,14 +105,7 @@ var _ = Describe("Allocation", func() { AfterEach(func() { fakeEC2API.Reset() ExpectCleanedUp(env.Client) - for _, cache := range []*cache.Cache{ - subnetCache, - launchTemplateCache, - instanceProfileCache, - securityGroupCache, - } { - cache.Flush() - } + 
launchTemplateCache.Flush() }) Context("Reconciliation", func() { @@ -463,6 +452,76 @@ var _ = Describe("Allocation", func() { Expect(pod.Spec.NodeName).To(BeEmpty()) }) }) + Context("Security Groups", func() { + It("should default to the clusters security groups", func() { + // Setup + pod := AttemptProvisioning(env.Client, provisioner, test.PendingPod()) + // Assertions + node := ExpectNodeExists(env.Client, pod.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput[0].LaunchTemplateData.SecurityGroupIds).To(ConsistOf( + aws.String("test-security-group-1"), + aws.String("test-security-group-2"), + aws.String("test-security-group-3"), + )) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupNameLabel)) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupTagKeyLabel)) + }) + It("should default to a provisioner's specified security groups name", func() { + // Setup + provisioner.Spec.Labels = map[string]string{SecurityGroupNameLabel: "test-security-group-2"} + pod := AttemptProvisioning(env.Client, provisioner, test.PendingPod()) + // Assertions + node := ExpectNodeExists(env.Client, pod.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput[0].LaunchTemplateData.SecurityGroupIds).To(ConsistOf( + aws.String("test-security-group-2"), + )) + Expect(node.Labels).To(HaveKeyWithValue(SecurityGroupNameLabel, provisioner.Spec.Labels[SecurityGroupNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupTagKeyLabel)) + }) + It("should default to a provisioner's specified security groups tag key", func() { + provisioner.Spec.Labels = map[string]string{SecurityGroupTagKeyLabel: "TestTag"} + pod := AttemptProvisioning(env.Client, provisioner, test.PendingPod()) + // Assertions + node := ExpectNodeExists(env.Client, pod.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput).To(HaveLen(1)) + 
Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput[0].LaunchTemplateData.SecurityGroupIds).To(ConsistOf( + aws.String("test-security-group-3"), + )) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SecurityGroupTagKeyLabel, provisioner.Spec.Labels[SecurityGroupTagKeyLabel])) + }) + It("should allow a pod to override the security groups name", func() { + // Setup + pod := AttemptProvisioning(env.Client, provisioner, + test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SecurityGroupNameLabel: "test-security-group-2"}}), + ) + // Assertions + node := ExpectNodeExists(env.Client, pod.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput[0].LaunchTemplateData.SecurityGroupIds).To(ConsistOf( + aws.String("test-security-group-2"), + )) + Expect(node.Labels).To(HaveKeyWithValue(SecurityGroupNameLabel, pod.Spec.NodeSelector[SecurityGroupNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupTagKeyLabel)) + }) + It("should allow a pod to override the security groups tags", func() { + pod := AttemptProvisioning(env.Client, provisioner, + test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SecurityGroupTagKeyLabel: "TestTag"}}), + ) + // Assertions + node := ExpectNodeExists(env.Client, pod.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput[0].LaunchTemplateData.SecurityGroupIds).To(ConsistOf( + aws.String("test-security-group-3"), + )) + Expect(node.Labels).ToNot(HaveKey(SecurityGroupNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SecurityGroupTagKeyLabel, pod.Spec.NodeSelector[SecurityGroupTagKeyLabel])) + }) + }) }) Context("Validation", func() { 
Context("ClusterSpec", func() { diff --git a/pkg/cloudprovider/aws/utils/predicates/strings.go b/pkg/cloudprovider/aws/utils/predicates/strings.go new file mode 100644 index 000000000000..44c4c6484673 --- /dev/null +++ b/pkg/cloudprovider/aws/utils/predicates/strings.go @@ -0,0 +1,13 @@ +package predicates + +// WithinStrings returns a func that returns true if string is within strings. +func WithinStrings(allowed []string) func(string) bool { + return func(actual string) bool { + for _, expected := range allowed { + if expected == actual { + return true + } + } + return false + } +} diff --git a/pkg/cloudprovider/aws/utils/predicates/tags.go b/pkg/cloudprovider/aws/utils/predicates/tags.go new file mode 100644 index 000000000000..da352d8e730d --- /dev/null +++ b/pkg/cloudprovider/aws/utils/predicates/tags.go @@ -0,0 +1,30 @@ +package predicates + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" +) + +// HasNameTag returns a func that returns true if name tag matches name +func HasNameTag(name string) func([]*ec2.Tag) bool { + return func(tags []*ec2.Tag) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == "Name" { + return aws.StringValue(tag.Value) == name + } + } + return false + } +} + +// HasTagKey returns a func that returns true if tag exists with tagKey +func HasTagKey(tagKey string) func([]*ec2.Tag) bool { + return func(tags []*ec2.Tag) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == tagKey { + return true + } + } + return false + } +}