Skip to content

Commit

Permalink
Support for Security Group specification/override
Browse files Browse the repository at this point in the history
  • Loading branch information
ellistarn committed Jun 23, 2021
1 parent 5fdb59e commit 46a02ba
Show file tree
Hide file tree
Showing 13 changed files with 382 additions and 257 deletions.
55 changes: 55 additions & 0 deletions pkg/cloudprovider/aws/ami.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package aws

import (
"context"
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/patrickmn/go-cache"
"go.uber.org/zap"
"k8s.io/client-go/kubernetes"
)

type AMIProvider struct {
cache *cache.Cache
ssm ssmiface.SSMAPI
clientSet *kubernetes.Clientset
}

func NewAMIProvider(ssm ssmiface.SSMAPI, clientSet *kubernetes.Clientset) *AMIProvider {
return &AMIProvider{
ssm: ssm,
clientSet: clientSet,
cache: cache.New(CacheTTL, CacheCleanupInterval),
}
}

func (p *AMIProvider) Get(ctx context.Context, constraints *Constraints) (string, error) {
version, err := p.kubeServerVersion()
if err != nil {
return "", fmt.Errorf("kube server version, %w", err)
}
name := fmt.Sprintf("/aws/service/bottlerocket/aws-k8s-%s/%s/latest/image_id", version, KubeToAWSArchitectures[*constraints.Architecture])
if id, ok := p.cache.Get(name); ok {
return id.(string), nil
}
output, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{Name: aws.String(name)})
if err != nil {
return "", fmt.Errorf("getting ssm parameter, %w", err)
}
ami := aws.StringValue(output.Parameter.Value)
p.cache.Set(name, ami, CacheTTL)
zap.S().Debugf("Successfully discovered ami %s for query %s", ami, name)
return ami, nil
}

func (p *AMIProvider) kubeServerVersion() (string, error) {
version, err := p.clientSet.Discovery().ServerVersion()
if err != nil {
return "", err
}
return fmt.Sprintf("%s.%s", version.Major, strings.TrimSuffix(version.Minor, "+")), nil
}
13 changes: 5 additions & 8 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/awslabs/karpenter/pkg/cloudprovider"
"github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils"
"github.com/awslabs/karpenter/pkg/utils/project"
"github.com/patrickmn/go-cache"
"go.uber.org/zap"
v1 "k8s.io/api/core/v1"
"knative.dev/pkg/apis"
Expand Down Expand Up @@ -76,13 +75,11 @@ func NewCloudProvider(options cloudprovider.Options) *CloudProvider {
ec2api := ec2.New(sess)
return &CloudProvider{
nodeAPI: &NodeFactory{ec2api: ec2api},
launchTemplateProvider: &LaunchTemplateProvider{
ec2api: ec2api,
cache: cache.New(CacheTTL, CacheCleanupInterval),
securityGroupProvider: NewSecurityGroupProvider(ec2api),
ssm: ssm.New(sess),
clientSet: options.ClientSet,
},
launchTemplateProvider: NewLaunchTemplateProvider(
ec2api,
NewAMIProvider(ssm.New(sess), options.ClientSet),
NewSecurityGroupProvider(ec2api),
),
subnetProvider: NewSubnetProvider(ec2api),
instanceTypeProvider: NewInstanceTypeProvider(ec2api),
instanceProvider: &InstanceProvider{ec2api: ec2api},
Expand Down
36 changes: 28 additions & 8 deletions pkg/cloudprovider/aws/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ var (
LaunchTemplateVersionLabel = AWSLabelPrefix + "launch-template-version"
SubnetNameLabel = AWSLabelPrefix + "subnet-name"
SubnetTagKeyLabel = AWSLabelPrefix + "subnet-tag-key"
SecurityGroupNameLabel = AWSLabelPrefix + "security-group-name"
SecurityGroupTagKeyLabel = AWSLabelPrefix + "security-group-tag-key"
AllowedLabels = []string{
CapacityTypeLabel,
LaunchTemplateIdLabel,
LaunchTemplateVersionLabel,
SubnetNameLabel,
SubnetTagKeyLabel,
SecurityGroupNameLabel,
SecurityGroupTagKeyLabel,
}
AWSToKubeArchitectures = map[string]string{
"x86_64": v1alpha1.ArchitectureAmd64,
Expand All @@ -66,8 +70,8 @@ func (c *Constraints) GetCapacityType() string {
}

type LaunchTemplate struct {
Id *string
Version *string
Id string
Version string
}

func (c *Constraints) GetLaunchTemplate() *LaunchTemplate {
Expand All @@ -80,25 +84,41 @@ func (c *Constraints) GetLaunchTemplate() *LaunchTemplate {
version = DefaultLaunchTemplateVersion
}
return &LaunchTemplate{
Id: &id,
Version: &version,
Id: id,
Version: version,
}
}

func (c *Constraints) GetSubnetName() *string {
subnetName, ok := c.Labels[SubnetNameLabel]
name, ok := c.Labels[SubnetNameLabel]
if !ok {
return nil
}
return aws.String(subnetName)
return aws.String(name)
}

func (c *Constraints) GetSubnetTagKey() *string {
subnetTag, ok := c.Labels[SubnetTagKeyLabel]
tag, ok := c.Labels[SubnetTagKeyLabel]
if !ok {
return nil
}
return aws.String(subnetTag)
return aws.String(tag)
}

func (c *Constraints) GetSecurityGroupName() *string {
name, ok := c.Labels[SecurityGroupNameLabel]
if !ok {
return nil
}
return aws.String(name)
}

func (c *Constraints) GetSecurityGroupTagKey() *string {
tag, ok := c.Labels[SecurityGroupTagKeyLabel]
if !ok {
return nil
}
return aws.String(tag)
}

func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) {
Expand Down
68 changes: 31 additions & 37 deletions pkg/cloudprovider/aws/fake/ec2api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/Pallinder/go-randomdata"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
Expand All @@ -29,17 +30,17 @@ import (
// EC2Behavior must be reset between tests otherwise tests will
// pollute each other.
type EC2Behavior struct {
CreateFleetOutput *ec2.CreateFleetOutput
DescribeInstancesOutput *ec2.DescribeInstancesOutput
DescribeLaunchTemplatesOutput *ec2.DescribeLaunchTemplatesOutput
DescribeSubnetsOutput *ec2.DescribeSubnetsOutput
DescribeSecurityGroupsOutput *ec2.DescribeSecurityGroupsOutput
DescribeInstanceTypesOutput *ec2.DescribeInstanceTypesOutput
DescribeInstanceTypeOfferingsOutput *ec2.DescribeInstanceTypeOfferingsOutput
DescribeAvailabilityZonesOutput *ec2.DescribeAvailabilityZonesOutput
WantErr error
CalledWithCreateFleetInput []ec2.CreateFleetInput
CalledWithCreateFleetInput []*ec2.CreateFleetInput
CalledWithCreateLaunchTemplateInput []*ec2.CreateLaunchTemplateInput
Instances []*ec2.Instance
LaunchTemplates []*ec2.LaunchTemplate
}

type EC2API struct {
Expand All @@ -54,13 +55,7 @@ func (e *EC2API) Reset() {
}

func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) {
e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, *input)
if e.WantErr != nil {
return nil, e.WantErr
}
if e.CreateFleetOutput != nil {
return e.CreateFleetOutput, nil
}
e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, input)
if input.LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateId == nil &&
input.LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName == nil {
return nil, fmt.Errorf("missing launch template id or name")
Expand All @@ -69,15 +64,20 @@ func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFl
InstanceId: aws.String(randomdata.SillyName()),
Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone-1a")},
PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(e.Instances))),
InstanceType: input.LaunchTemplateConfigs[0].Overrides[0].InstanceType,
}
e.Instances = append(e.Instances, instance)
return &ec2.CreateFleetOutput{Instances: []*ec2.CreateFleetInstance{{InstanceIds: []*string{instance.InstanceId}}}}, nil
}

func (e *EC2API) CreateLaunchTemplateWithContext(ctx context.Context, input *ec2.CreateLaunchTemplateInput, options ...request.Option) (*ec2.CreateLaunchTemplateOutput, error) {
e.CalledWithCreateLaunchTemplateInput = append(e.CalledWithCreateLaunchTemplateInput, input)
launchTemplate := &ec2.LaunchTemplate{LaunchTemplateName: input.LaunchTemplateName, LaunchTemplateId: aws.String("test-launch-template-id")}
e.LaunchTemplates = append(e.LaunchTemplates, launchTemplate)
return &ec2.CreateLaunchTemplateOutput{LaunchTemplate: launchTemplate}, nil
}

func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if e.DescribeInstancesOutput != nil {
return e.DescribeInstancesOutput, nil
}
Expand All @@ -86,23 +86,25 @@ func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInst
}, nil
}

func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
func (e *EC2API) DescribeLaunchTemplatesWithContext(ctx context.Context, input *ec2.DescribeLaunchTemplatesInput, options ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if e.DescribeLaunchTemplatesOutput != nil {
return e.DescribeLaunchTemplatesOutput, nil
}
return &ec2.DescribeLaunchTemplatesOutput{LaunchTemplates: []*ec2.LaunchTemplate{{
LaunchTemplateName: aws.String("test-launch-template-name"),
LaunchTemplateId: aws.String("test-launch-template-id"),
}}}, nil
output := &ec2.DescribeLaunchTemplatesOutput{}
for _, wanted := range input.LaunchTemplateNames {
for _, launchTemplate := range e.LaunchTemplates {
if launchTemplate.LaunchTemplateName == wanted {
output.LaunchTemplates = append(output.LaunchTemplates, launchTemplate)
}
}
}
if len(output.LaunchTemplates) == 0 {
return nil, awserr.New("InvalidLaunchTemplateName.NotFoundException", "not found", nil)
}
return output, nil
}

func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if e.DescribeSubnetsOutput != nil {
return e.DescribeSubnetsOutput, nil
}
Expand All @@ -117,19 +119,17 @@ func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnet
}

func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if e.DescribeSecurityGroupsOutput != nil {
return e.DescribeSecurityGroupsOutput, nil
}
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{
{GroupId: aws.String("test-security-group-1"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-1")}}},
{GroupId: aws.String("test-security-group-2"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-2")}}},
{GroupId: aws.String("test-security-group-3"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, {Key: aws.String("TestTag")}}},
}}, nil
}

func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if e.DescribeAvailabilityZonesOutput != nil {
return e.DescribeAvailabilityZonesOutput, nil
}
Expand All @@ -141,9 +141,6 @@ func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.Desc
}

func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if e.DescribeInstanceTypesOutput != nil {
fn(e.DescribeInstanceTypesOutput, false)
return nil
Expand Down Expand Up @@ -267,9 +264,6 @@ func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, inpu
}

func (e *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if e.DescribeInstanceTypeOfferingsOutput != nil {
fn(e.DescribeInstanceTypeOfferingsOutput, false)
return nil
Expand Down
35 changes: 0 additions & 35 deletions pkg/cloudprovider/aws/fake/sqsqueue.go

This file was deleted.

5 changes: 2 additions & 3 deletions pkg/cloudprovider/aws/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func (p *InstanceProvider) Create(ctx context.Context,
},
LaunchTemplateConfigs: []*ec2.FleetLaunchTemplateConfigRequest{{
LaunchTemplateSpecification: &ec2.FleetLaunchTemplateSpecificationRequest{
LaunchTemplateId: launchTemplate.Id,
Version: launchTemplate.Version,
LaunchTemplateId: aws.String(launchTemplate.Id),
Version: aws.String(launchTemplate.Version),
},
Overrides: overrides,
}},
Expand All @@ -102,7 +102,6 @@ func (p *InstanceProvider) Create(ctx context.Context,
if count := len(createFleetOutput.Instances[0].InstanceIds); count != 1 {
return nil, fmt.Errorf("expected 1 instance ids, but got %d due to errors %v", count, createFleetOutput.Errors)
}
// TODO aggregate errors
if count := len(createFleetOutput.Errors); count > 0 {
zap.S().Warnf("CreateFleet encountered %d errors, but still launched instances, %v", count, createFleetOutput.Errors)
}
Expand Down
Loading

0 comments on commit 46a02ba

Please sign in to comment.