Skip to content

Commit

Permalink
Support for Subnets specification/override (#454)
Browse files Browse the repository at this point in the history
* Implemented a well known label for subnets

* PR Comments

* Fixed a bug with union string map
  • Loading branch information
ellistarn authored Jun 21, 2021
1 parent 9825587 commit 291e511
Show file tree
Hide file tree
Showing 21 changed files with 738 additions and 440 deletions.
31 changes: 12 additions & 19 deletions pkg/apis/provisioning/v1alpha1/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,20 @@ type ProvisionerList struct {
Items []Provisioner `json:"items"`
}

func (p *Provisioner) ConstraintsWithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: p.Spec.Taints,
Labels: p.Spec.Constraints.getLabels(p.Name, p.Namespace, pod),
Zones: p.Spec.Constraints.getZones(pod),
InstanceTypes: p.Spec.Constraints.getInstanceTypes(pod),
Architecture: p.Spec.Constraints.getArchitecture(pod),
OperatingSystem: p.Spec.Constraints.getOperatingSystem(pod),
}
func (c *Constraints) WithLabel(key string, value string) *Constraints {
c.Labels = functional.UnionStringMaps(c.Labels, map[string]string{key: value})
return c
}

func (c *Constraints) getLabels(name string, namespace string, pod *v1.Pod) map[string]string {
// These keys are guaranteed to not collide due to validation logic
return functional.UnionStringMaps(
c.Labels,
pod.Spec.NodeSelector,
map[string]string{
ProvisionerNameLabelKey: name,
ProvisionerNamespaceLabelKey: namespace,
},
)
func (c *Constraints) WithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: c.Taints,
Labels: functional.UnionStringMaps(c.Labels, pod.Spec.NodeSelector),
Zones: c.getZones(pod),
InstanceTypes: c.getInstanceTypes(pod),
Architecture: c.getArchitecture(pod),
OperatingSystem: c.getOperatingSystem(pod),
}
}

func (c *Constraints) getZones(pod *v1.Pod) []string {
Expand Down
110 changes: 70 additions & 40 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"

"github.com/awslabs/karpenter/pkg/utils/functional"
"k8s.io/apimachinery/pkg/util/validation"
"knative.dev/pkg/apis"
)

Expand All @@ -40,89 +41,118 @@ var (
SupportedOperatingSystems = []string{}
SupportedZones = []string{}
SupportedInstanceTypes = []string{}
ValidationHook func(ctx context.Context, spec *ProvisionerSpec) *apis.FieldError
ConstraintsValidationHook func(ctx context.Context, constraints *Constraints) *apis.FieldError
)

func (p *Provisioner) Validate(ctx context.Context) (errs *apis.FieldError) {
return errs.Also(
apis.ValidateObjectMetadata(p),
p.Spec.Validate(ctx),
apis.ValidateObjectMetadata(p).ViaField("metadata"),
p.Spec.validate(ctx).ViaField("spec"),
)
}

func (s *ProvisionerSpec) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
s.validateClusterSpec(ctx),
s.validateLabels(ctx),
s.validateZones(ctx),
s.validateInstanceTypes(ctx),
s.validateArchitecture(ctx),
s.validateOperatingSystem(ctx),
func (s *ProvisionerSpec) validate(ctx context.Context) (errs *apis.FieldError) {
return errs.Also(
s.Cluster.validate().ViaField("cluster"),
// This validation is on the ProvisionerSpec despite the fact that
// labels are a property of Constraints. This is necessary because
// validation is applied to constraints that include pod overrides.
// These labels are restricted when creating provisioners, but are not
// restricted for pods since they're necessary to override constraints.
s.validateRestrictedLabels(),
s.Constraints.Validate(ctx),
)
if ValidationHook != nil {
errs = errs.Also(ValidationHook(ctx, s))
}

func (s *ProvisionerSpec) validateRestrictedLabels() (errs *apis.FieldError) {
for key := range s.Labels {
if functional.ContainsString(RestrictedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels"))
}
}
return errs
}

func (s *ProvisionerSpec) validateClusterSpec(ctx context.Context) (errs *apis.FieldError) {
if s.Cluster == nil {
return errs.Also(apis.ErrMissingField("spec.cluster"))
func (s *ClusterSpec) validate() (errs *apis.FieldError) {
if s == nil {
return errs.Also(apis.ErrMissingField())
}
if len(s.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("name"))
}
if len(s.Cluster.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.name"))
if len(s.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("endpoint"))
}
if len(s.Cluster.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.endpoint"))
if len(s.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("caBundle"))
}
if len(s.Cluster.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.caBundle"))
return errs
}

// Validate constraints subresource. This validation logic is used both upon
// creation of a provisioner as well as when a pod is attempting to be
// provisioned. If a provisioner fails validation, it will be rejected by the
// API Server. If constraints.WithOverrides(pod) fails validation, the pod will
// be ignored for provisioning.
func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
c.validateLabels(),
c.validateArchitecture(),
c.validateOperatingSystem(),
c.validateZones(),
c.validateInstanceTypes(),
)
if ConstraintsValidationHook != nil {
errs = errs.Also(ConstraintsValidationHook(ctx, c))
}
return errs
}

func (s *ProvisionerSpec) validateLabels(ctx context.Context) (errs *apis.FieldError) {
for _, restricted := range RestrictedLabels {
if _, ok := s.Labels[restricted]; ok {
errs = errs.Also(apis.ErrInvalidKeyName(restricted, "spec.labels"))
func (c *Constraints) validateLabels() (errs *apis.FieldError) {
for key, value := range c.Labels {
for _, err := range validation.IsQualifiedName(key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels", err))
}
for _, err := range validation.IsValidLabelValue(value) {
errs = errs.Also(apis.ErrInvalidValue(value+", "+err, "labels"))
}
}
return errs
}

func (s *ProvisionerSpec) validateArchitecture(ctx context.Context) (errs *apis.FieldError) {
if s.Architecture == nil {
func (c *Constraints) validateArchitecture() (errs *apis.FieldError) {
if c.Architecture == nil {
return nil
}
if !functional.ContainsString(SupportedArchitectures, *s.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.Architecture, SupportedArchitectures), "spec.architecture"))
if !functional.ContainsString(SupportedArchitectures, *c.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.Architecture, SupportedArchitectures), "architecture"))
}
return errs
}

func (s *ProvisionerSpec) validateOperatingSystem(ctx context.Context) (errs *apis.FieldError) {
if s.OperatingSystem == nil {
func (c *Constraints) validateOperatingSystem() (errs *apis.FieldError) {
if c.OperatingSystem == nil {
return nil
}
if !functional.ContainsString(SupportedOperatingSystems, *s.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.OperatingSystem, SupportedOperatingSystems), "spec.operatingSystem"))
if !functional.ContainsString(SupportedOperatingSystems, *c.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.OperatingSystem, SupportedOperatingSystems), "operatingSystem"))
}
return errs
}

func (s *ProvisionerSpec) validateZones(ctx context.Context) (errs *apis.FieldError) {
for i, zone := range s.Zones {
func (c *Constraints) validateZones() (errs *apis.FieldError) {
for i, zone := range c.Zones {
if !functional.ContainsString(SupportedZones, zone) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "spec.zones", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "zones", i))
}
}
return errs
}

func (s *ProvisionerSpec) validateInstanceTypes(ctx context.Context) (errs *apis.FieldError) {
for i, instanceType := range s.InstanceTypes {
func (c *Constraints) validateInstanceTypes() (errs *apis.FieldError) {
for i, instanceType := range c.InstanceTypes {
if !functional.ContainsString(SupportedInstanceTypes, instanceType) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "spec.instanceTypes", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "instanceTypes", i))
}
}
return errs
Expand Down
36 changes: 22 additions & 14 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,30 @@ var _ = Describe("Validation", func() {
}
})

It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ProvisionerTTLKey,
ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Context("Labels", func() {
It("should fail for invalid label keys", func() {
provisioner.Spec.Labels = map[string]string{"spaces are not allowed": randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
It("should fail for invalid label values", func() {
provisioner.Spec.Labels = map[string]string{randomdata.SillyName(): "/ is not allowed"}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
})
It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
})

Context("Zones", func() {
SupportedZones = append(SupportedZones, "test-zone-1")
It("should succeed if unspecified", func() {
Expand Down
71 changes: 15 additions & 56 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package aws
import (
"context"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -28,7 +27,6 @@ import (
"github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1"
"github.com/awslabs/karpenter/pkg/cloudprovider"
"github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils"
"github.com/awslabs/karpenter/pkg/utils/functional"
"github.com/awslabs/karpenter/pkg/utils/project"
"github.com/patrickmn/go-cache"
"go.uber.org/zap"
Expand Down Expand Up @@ -108,39 +106,33 @@ func withUserAgent(sess *session.Session) *session.Session {
}

// Create a set of nodes given the constraints.
func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
func (c *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
instanceIDs := []*string{}
instancePackings := map[string]*cloudprovider.Packing{}
for _, packing := range packings {
constraints := Constraints(*packing.Constraints)
constraints := Constraints{*packing.Constraints}
// 1. Get Subnets and constrain by zones
zonalSubnets, err := a.subnetProvider.GetZonalSubnets(ctx, provisioner.Spec.Cluster.Name)
subnets, err := c.subnetProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting zonal subnets, %w", err)
}
zonalSubnetOptions := map[string][]*ec2.Subnet{}
for zone, subnets := range zonalSubnets {
if len(constraints.Zones) == 0 || functional.ContainsString(constraints.Zones, zone) {
zonalSubnetOptions[zone] = subnets
}
}
// 2. Get Launch Template
launchTemplate, err := a.launchTemplateProvider.Get(ctx, provisioner, &constraints)
launchTemplate, err := c.launchTemplateProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting launch template, %w", err)
}
// 3. Create instance
instanceID, err := a.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, zonalSubnets, constraints.GetCapacityType())
instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType())
if err != nil {
// TODO Aggregate errors and continue
return nil, fmt.Errorf("creating capacity %w", err)
zap.S().Errorf("Continuing after failing to launch instances, %s", err.Error())
continue
}
instancePackings[*instanceID] = packing
instanceIDs = append(instanceIDs, instanceID)
}

// 4. Convert to Nodes
nodes, err := a.nodeAPI.For(ctx, instanceIDs)
nodes, err := c.nodeAPI.For(ctx, instanceIDs)
if err != nil {
return nil, fmt.Errorf("determining nodes, %w", err)
}
Expand All @@ -158,49 +150,16 @@ func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provis
return packedNodes, nil
}

func (a *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return a.instanceTypeProvider.Get(ctx)
func (c *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return c.instanceTypeProvider.Get(ctx)
}

func (a *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return a.instanceProvider.Terminate(ctx, nodes)
func (c *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return c.instanceProvider.Terminate(ctx, nodes)
}

// Validate cloud provider specific components of the cluster spec
func (a *CloudProvider) Validate(ctx context.Context, spec *v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
return errs.Also(
validateAllowedLabels(*spec),
validateCapacityTypeLabel(*spec),
validateLaunchTemplateLabels(*spec),
)
}

func validateAllowedLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
for key := range spec.Labels {
if strings.HasPrefix(key, AWSLabelPrefix) && !functional.ContainsString(AllowedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "spec.labels"))
}
}
return errs
}

func validateCapacityTypeLabel(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
capacityType, ok := spec.Labels[CapacityTypeLabel]
if !ok {
return nil
}
capacityTypes := []string{CapacityTypeSpot, CapacityTypeOnDemand}
if !functional.ContainsString(capacityTypes, capacityType) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", capacityType, capacityTypes), fmt.Sprintf("spec.labels[%s]", CapacityTypeLabel)))
}
return errs
}

func validateLaunchTemplateLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
if _, versionExists := spec.Labels[LaunchTemplateVersionLabel]; versionExists {
if _, bothExist := spec.Labels[LaunchTemplateIdLabel]; !bothExist {
return errs.Also(apis.ErrMissingField(fmt.Sprintf("spec.labels[%s]", LaunchTemplateIdLabel)))
}
}
return errs
func (c *CloudProvider) Validate(ctx context.Context, constraints *v1alpha1.Constraints) (errs *apis.FieldError) {
awsConstraints := Constraints{*constraints}
return awsConstraints.Validate(ctx)
}
Loading

0 comments on commit 291e511

Please sign in to comment.