Skip to content

Commit

Permalink
Simplified Cloudprovider Create API and delegate batching to the provider…
Browse files Browse the repository at this point in the history
… specific implementation (#1575)
  • Loading branch information
ellistarn authored Mar 25, 2022
1 parent b61c9e9 commit 74c2601
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 225 deletions.
43 changes: 4 additions & 39 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (
"fmt"
"time"

"github.com/aws/karpenter/pkg/utils/resources"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
Expand All @@ -38,7 +36,6 @@ import (
"github.com/aws/karpenter/pkg/utils/injection"
"github.com/aws/karpenter/pkg/utils/project"

"go.uber.org/multierr"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/transport"
"knative.dev/pkg/apis"
Expand Down Expand Up @@ -101,24 +98,12 @@ func NewCloudProvider(ctx context.Context, options cloudprovider.Options) *Cloud
}

// Create a node given the constraints.
func (c *CloudProvider) Create(ctx context.Context, constraints *v1alpha5.Constraints, instanceTypes []cloudprovider.InstanceType, quantity int, callback func(*v1.Node) error) error {
vendorConstraints, err := v1alpha1.Deserialize(constraints)
if err != nil {
return err
}
instanceTypes = c.filterInstanceTypes(instanceTypes)

// Create will only return an error if zero nodes could be launched.
// Partial fulfillment will be logged
nodes, err := c.instanceProvider.Create(ctx, vendorConstraints, instanceTypes, quantity)
func (c *CloudProvider) Create(ctx context.Context, nodeRequest *cloudprovider.NodeRequest) (*v1.Node, error) {
vendorConstraints, err := v1alpha1.Deserialize(nodeRequest.Constraints)
if err != nil {
return fmt.Errorf("launching instances, %w", err)
return nil, err
}
var errs error
for _, node := range nodes {
errs = multierr.Append(errs, callback(node))
}
return errs
return c.instanceProvider.Create(ctx, vendorConstraints, nodeRequest.InstanceTypeOptions)
}

// GetInstanceTypes returns all available InstanceTypes despite accepting a Constraints struct (note that it does not utilize Requirements)
Expand Down Expand Up @@ -161,26 +146,6 @@ func (c *CloudProvider) Name() string {
return "aws"
}

// filterInstanceTypes narrows the candidate list to instance types that report zero AWS Neuron,
// AMD GPU and NVIDIA GPU resources. When at least one such type exists, only those types are
// returned (in their original order); when none exists, the input list is handed back untouched.
func (c *CloudProvider) filterInstanceTypes(instanceTypes []cloudprovider.InstanceType) []cloudprovider.InstanceType {
	var withoutGPU []cloudprovider.InstanceType
	for i := range instanceTypes {
		capacity := instanceTypes[i].Resources()
		// Skip any type that advertises one of the three special resources.
		usesSpecialHardware := !resources.IsZero(capacity[v1alpha1.ResourceAWSNeuron]) ||
			!resources.IsZero(capacity[v1alpha1.ResourceAMDGPU]) ||
			!resources.IsZero(capacity[v1alpha1.ResourceNVIDIAGPU])
		if usesSpecialHardware {
			continue
		}
		withoutGPU = append(withoutGPU, instanceTypes[i])
	}
	// Nothing generic was found, so fall back to the caller's full list.
	if len(withoutGPU) == 0 {
		return instanceTypes
	}
	return withoutGPU
}

// get the current region from EC2 IMDS
func getRegionFromIMDS(sess *session.Session) string {
region, err := ec2metadata.New(sess).Region()
Expand Down
126 changes: 54 additions & 72 deletions pkg/cloudprovider/aws/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/aws/karpenter/pkg/cloudprovider/aws/apis/v1alpha1"
"github.com/aws/karpenter/pkg/utils/injection"
"github.com/aws/karpenter/pkg/utils/options"
"github.com/aws/karpenter/pkg/utils/resources"
)

const (
Expand Down Expand Up @@ -68,45 +69,33 @@ func NewInstanceProvider(ec2api ec2iface.EC2API, instanceTypeProvider *InstanceT
// instanceTypes should be sorted by priority for spot capacity type.
// If spot is not used, the instanceTypes are not required to be sorted
// because we are using ec2 fleet's lowest-price OD allocation strategy
func (p *InstanceProvider) Create(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType, quantity int) ([]*v1.Node, error) {
func (p *InstanceProvider) Create(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType) (*v1.Node, error) {
// Launch Instance
ids, err := p.launchInstances(ctx, constraints, instanceTypes, quantity)
instanceTypes = p.filterInstanceTypes(instanceTypes)
id, err := p.launchInstance(ctx, constraints, instanceTypes)
if err != nil {
return nil, err
}
// Get Instance with backoff retry since EC2 is eventually consistent
instances := []*ec2.Instance{}
instance := &ec2.Instance{}
if err := retry.Do(
func() (err error) { instances, err = p.getInstances(ctx, ids); return err },
func() (err error) { instance, err = p.getInstance(ctx, aws.StringValue(id)); return err },
retry.Delay(1*time.Second),
retry.Attempts(6),
); err != nil && len(instances) == 0 {
); err != nil {
return nil, err
} else if err != nil {
logging.FromContext(ctx).Errorf("retrieving node name for %d/%d instances", quantity-len(instances), quantity)
}

nodes := []*v1.Node{}
for _, instance := range instances {
logging.FromContext(ctx).Infof("Launched instance: %s, hostname: %s, type: %s, zone: %s, capacityType: %s",
aws.StringValue(instance.InstanceId),
aws.StringValue(instance.PrivateDnsName),
aws.StringValue(instance.InstanceType),
aws.StringValue(instance.Placement.AvailabilityZone),
getCapacityType(instance),
)
// Convert Instance to Node
node, err := p.instanceToNode(ctx, instance, instanceTypes)
if err != nil {
logging.FromContext(ctx).Errorf("creating Node from an EC2 Instance: %s", err)
continue
}
nodes = append(nodes, node)
}
if len(nodes) == 0 {
return nil, fmt.Errorf("zero nodes were created")
}
return nodes, nil
logging.FromContext(ctx).Errorf("retrieving node name for instance %s", aws.StringValue(instance.InstanceId))
}
logging.FromContext(ctx).Infof("Launched instance: %s, hostname: %s, type: %s, zone: %s, capacityType: %s",
aws.StringValue(instance.InstanceId),
aws.StringValue(instance.PrivateDnsName),
aws.StringValue(instance.InstanceType),
aws.StringValue(instance.Placement.AvailabilityZone),
getCapacityType(instance),
)
// Convert Instance to Node
return p.instanceToNode(ctx, instance, instanceTypes), nil
}

func (p *InstanceProvider) Terminate(ctx context.Context, node *v1.Node) error {
Expand All @@ -125,9 +114,8 @@ func (p *InstanceProvider) Terminate(ctx context.Context, node *v1.Node) error {
return nil
}

func (p *InstanceProvider) launchInstances(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType, quantity int) ([]*string, error) {
func (p *InstanceProvider) launchInstance(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType) (*string, error) {
capacityType := p.getCapacityType(constraints, instanceTypes)

// Get Launch Template Configs, which may differ due to GPU or Architecture requirements
launchTemplateConfigs, err := p.getLaunchTemplateConfigs(ctx, constraints, instanceTypes, capacityType)
if err != nil {
Expand All @@ -140,7 +128,7 @@ func (p *InstanceProvider) launchInstances(ctx context.Context, constraints *v1a
LaunchTemplateConfigs: launchTemplateConfigs,
TargetCapacitySpecification: &ec2.TargetCapacitySpecificationRequest{
DefaultTargetCapacityType: aws.String(capacityType),
TotalTargetCapacity: aws.Int64(int64(quantity)),
TotalTargetCapacity: aws.Int64(1),
},
TagSpecifications: []*ec2.TagSpecification{
{ResourceType: aws.String(ec2.ResourceTypeInstance), Tags: tags},
Expand All @@ -157,14 +145,10 @@ func (p *InstanceProvider) launchInstances(ctx context.Context, constraints *v1a
return nil, fmt.Errorf("creating fleet %w", err)
}
p.updateUnavailableOfferingsCache(ctx, createFleetOutput.Errors, capacityType)
instanceIds := combineFleetInstances(*createFleetOutput)
if len(instanceIds) == 0 {
if len(createFleetOutput.Instances) == 0 || len(createFleetOutput.Instances[0].InstanceIds) == 0 {
return nil, combineFleetErrors(createFleetOutput.Errors)
} else if len(instanceIds) != quantity {
logging.FromContext(ctx).Errorf("Failed to launch %d EC2 instances out of the %d EC2 instances requested: %s",
quantity-len(instanceIds), quantity, combineFleetErrors(createFleetOutput.Errors).Error())
}
return instanceIds, nil
return createFleetOutput.Instances[0].InstanceIds[0], nil
}

func (p *InstanceProvider) getLaunchTemplateConfigs(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType, capacityType string) ([]*ec2.FleetLaunchTemplateConfigRequest, error) {
Expand Down Expand Up @@ -239,34 +223,28 @@ func (p *InstanceProvider) getOverrides(instanceTypeOptions []cloudprovider.Inst
return overrides
}

func (p *InstanceProvider) getInstances(ctx context.Context, ids []*string) ([]*ec2.Instance, error) {
describeInstancesOutput, err := p.ec2api.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{InstanceIds: ids})
func (p *InstanceProvider) getInstance(ctx context.Context, id string) (*ec2.Instance, error) {
describeInstancesOutput, err := p.ec2api.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice([]string{id})})
if isNotFound(err) {
return nil, err
}
if err != nil {
return nil, fmt.Errorf("failed to describe ec2 instances, %w", err)
}
describedInstances := combineReservations(describeInstancesOutput.Reservations)
if len(describedInstances) != len(ids) {
return nil, fmt.Errorf("expected %d instance(s), but got %d", len(ids), len(describedInstances))
if len(describeInstancesOutput.Reservations) != 1 || len(describeInstancesOutput.Reservations[0].Instances) != 1 {
return nil, fmt.Errorf("expected instance but got 0")
}
instance := describeInstancesOutput.Reservations[0].Instances[0]
if injection.GetOptions(ctx).GetAWSNodeNameConvention() == options.ResourceName {
return describedInstances, nil
return instance, nil
}

instances := []*ec2.Instance{}
for _, instance := range describedInstances {
if len(aws.StringValue(instance.PrivateDnsName)) == 0 {
err = multierr.Append(err, fmt.Errorf("got instance %s but PrivateDnsName was not set", aws.StringValue(instance.InstanceId)))
continue
}
instances = append(instances, instance)
if len(aws.StringValue(instance.PrivateDnsName)) == 0 {
return nil, multierr.Append(err, fmt.Errorf("got instance %s but PrivateDnsName was not set", aws.StringValue(instance.InstanceId)))
}
return instances, err
return instance, nil
}

func (p *InstanceProvider) instanceToNode(ctx context.Context, instance *ec2.Instance, instanceTypes []cloudprovider.InstanceType) (*v1.Node, error) {
func (p *InstanceProvider) instanceToNode(ctx context.Context, instance *ec2.Instance, instanceTypes []cloudprovider.InstanceType) *v1.Node {
for _, instanceType := range instanceTypes {
if instanceType.Name() == aws.StringValue(instance.InstanceType) {
nodeName := strings.ToLower(aws.StringValue(instance.PrivateDnsName))
Expand Down Expand Up @@ -310,10 +288,10 @@ func (p *InstanceProvider) instanceToNode(ctx context.Context, instance *ec2.Ins
OperatingSystem: v1alpha5.OperatingSystemLinux,
},
},
}, nil
}
}
}
return nil, fmt.Errorf("unrecognized instance type %s", aws.StringValue(instance.InstanceType))
panic(fmt.Sprintf("unrecognized instance type %s", aws.StringValue(instance.InstanceType)))
}

func (p *InstanceProvider) updateUnavailableOfferingsCache(ctx context.Context, errors []*ec2.CreateFleetError, capacityType string) {
Expand All @@ -340,6 +318,26 @@ func (p *InstanceProvider) getCapacityType(constraints *v1alpha1.Constraints, in
return v1alpha1.CapacityTypeOnDemand
}

// filterInstanceTypes drops GPU-style instance types whenever a plain instance type is also
// available. A type counts as plain when its AWS Neuron, AMD GPU and NVIDIA GPU resources are
// all zero. If the input mixes plain and non-plain types, only the plain ones are returned;
// if no plain type is present, the input slice is returned as-is.
func (p *InstanceProvider) filterInstanceTypes(instanceTypes []cloudprovider.InstanceType) []cloudprovider.InstanceType {
	var generic []cloudprovider.InstanceType
	for _, candidate := range instanceTypes {
		available := candidate.Resources()
		switch {
		case !resources.IsZero(available[v1alpha1.ResourceAWSNeuron]):
			// has Neuron devices, not generic
		case !resources.IsZero(available[v1alpha1.ResourceAMDGPU]):
			// has AMD GPUs, not generic
		case !resources.IsZero(available[v1alpha1.ResourceNVIDIAGPU]):
			// has NVIDIA GPUs, not generic
		default:
			generic = append(generic, candidate)
		}
	}
	// Prefer the generic subset when there is one; otherwise keep everything.
	if len(generic) > 0 {
		return generic
	}
	return instanceTypes
}

func getInstanceID(node *v1.Node) (*string, error) {
id := strings.Split(node.Spec.ProviderID, "/")
if len(id) < 5 {
Expand All @@ -365,19 +363,3 @@ func getCapacityType(instance *ec2.Instance) string {
}
return v1alpha1.CapacityTypeOnDemand
}

// combineFleetInstances flattens the InstanceIds of every entry in
// createFleetOutput.Instances into a single slice, keeping their order.
// The returned slice is non-nil even when no IDs are present.
func combineFleetInstances(createFleetOutput ec2.CreateFleetOutput) []*string {
	ids := make([]*string, 0)
	for i := range createFleetOutput.Instances {
		ids = append(ids, createFleetOutput.Instances[i].InstanceIds...)
	}
	return ids
}

// combineReservations flattens the Instances of every reservation into a
// single slice, keeping their order. The returned slice is non-nil even when
// there are no reservations.
func combineReservations(reservations []*ec2.Reservation) []*ec2.Instance {
	all := make([]*ec2.Instance, 0)
	for i := range reservations {
		all = append(all, reservations[i].Instances...)
	}
	return all
}
4 changes: 4 additions & 0 deletions pkg/cloudprovider/aws/instancetypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package aws
import (
"context"
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -42,6 +43,7 @@ const (
)

type InstanceTypeProvider struct {
sync.Mutex
ec2api ec2iface.EC2API
subnetProvider *SubnetProvider
// Has two entries: one for all the instance types and one for all zones; values cached *before* considering insufficient capacity errors
Expand All @@ -62,6 +64,8 @@ func NewInstanceTypeProvider(ec2api ec2iface.EC2API, subnetProvider *SubnetProvi

// Get all instance type options (the constraints are only used for tag filtering on subnets, not for Requirements filtering)
func (p *InstanceTypeProvider) Get(ctx context.Context, provider *v1alpha1.AWS) ([]cloudprovider.InstanceType, error) {
p.Lock()
defer p.Unlock()
// Get InstanceTypes from EC2
instanceTypes, err := p.getInstanceTypes(ctx)
if err != nil {
Expand Down
6 changes: 2 additions & 4 deletions pkg/cloudprovider/aws/launchtemplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ func launchTemplateName(options *amifamily.LaunchTemplate) string {
}

func (p *LaunchTemplateProvider) Get(ctx context.Context, constraints *v1alpha1.Constraints, instanceTypes []cloudprovider.InstanceType, additionalLabels map[string]string) (map[string][]cloudprovider.InstanceType, error) {
p.Lock()
defer p.Unlock()
// If Launch Template is directly specified then just use it
if constraints.LaunchTemplateName != nil {
return map[string][]cloudprovider.InstanceType{ptr.StringValue(constraints.LaunchTemplateName): instanceTypes}, nil
Expand Down Expand Up @@ -124,10 +126,6 @@ func (p *LaunchTemplateProvider) Get(ctx context.Context, constraints *v1alpha1.
}

func (p *LaunchTemplateProvider) ensureLaunchTemplate(ctx context.Context, options *amifamily.LaunchTemplate) (*ec2.LaunchTemplate, error) {
// Ensure that multiple threads don't attempt to create the same launch template
p.Lock()
defer p.Unlock()

var launchTemplate *ec2.LaunchTemplate
name := launchTemplateName(options)
// Read from cache
Expand Down
4 changes: 4 additions & 0 deletions pkg/cloudprovider/aws/securitygroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package aws
import (
"context"
"fmt"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
Expand All @@ -29,6 +30,7 @@ import (
)

// SecurityGroupProvider bundles an EC2 API client with a cache.
// The embedded sync.Mutex is locked by Get for the duration of the call.
type SecurityGroupProvider struct {
sync.Mutex
// ec2api is the EC2 client interface held by this provider.
ec2api ec2iface.EC2API
// cache presumably holds previously fetched security groups; its keys and
// expiry are not visible here — confirm against NewSecurityGroupProvider.
cache *cache.Cache
}
Expand All @@ -41,6 +43,8 @@ func NewSecurityGroupProvider(ec2api ec2iface.EC2API) *SecurityGroupProvider {
}

func (p *SecurityGroupProvider) Get(ctx context.Context, constraints *v1alpha1.Constraints) ([]string, error) {
p.Lock()
defer p.Unlock()
// Get SecurityGroups
securityGroups, err := p.getSecurityGroups(ctx, p.getFilters(constraints))
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/cloudprovider/aws/subnets.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package aws
import (
"context"
"fmt"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
Expand All @@ -30,6 +31,7 @@ import (
)

// SubnetProvider bundles an EC2 API client with a cache.
// The embedded sync.Mutex is locked by Get for the duration of the call.
type SubnetProvider struct {
sync.Mutex
// ec2api is the EC2 client interface held by this provider.
ec2api ec2iface.EC2API
// cache presumably holds subnet results keyed by a hash of the filters
// (Get hashes its filters) — confirm against the rest of Get.
cache *cache.Cache
}
Expand All @@ -42,6 +44,8 @@ func NewSubnetProvider(ec2api ec2iface.EC2API) *SubnetProvider {
}

func (p *SubnetProvider) Get(ctx context.Context, constraints *v1alpha1.AWS) ([]*ec2.Subnet, error) {
p.Lock()
defer p.Unlock()
filters := getFilters(constraints)
hash, err := hashstructure.Hash(filters, hashstructure.FormatV2, nil)
if err != nil {
Expand Down
Loading

0 comments on commit 74c2601

Please sign in to comment.