Skip to content

Commit

Permalink
Implemented a well known label for subnets
Browse files Browse the repository at this point in the history
  • Loading branch information
ellistarn committed Jun 18, 2021
1 parent 5dbe0d9 commit b15ce16
Show file tree
Hide file tree
Showing 19 changed files with 769 additions and 424 deletions.
2 changes: 1 addition & 1 deletion cmd/webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func main() {
ServiceName: "karpenter-webhook",
SecretName: "karpenter-webhook-cert",
}),
"Karpenter Webhooks",
"karpenter.webhooks",
config,
certificates.NewController,
NewCRDDefaultingWebhook,
Expand Down
31 changes: 12 additions & 19 deletions pkg/apis/provisioning/v1alpha1/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,20 @@ type ProvisionerList struct {
Items []Provisioner `json:"items"`
}

func (p *Provisioner) ConstraintsWithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: p.Spec.Taints,
Labels: p.Spec.Constraints.getLabels(p.Name, p.Namespace, pod),
Zones: p.Spec.Constraints.getZones(pod),
InstanceTypes: p.Spec.Constraints.getInstanceTypes(pod),
Architecture: p.Spec.Constraints.getArchitecture(pod),
OperatingSystem: p.Spec.Constraints.getOperatingSystem(pod),
}
func (c *Constraints) AddLabels(labels map[string]string) *Constraints {
c.Labels = functional.UnionStringMaps(c.Labels, labels)
return c
}

func (c *Constraints) getLabels(name string, namespace string, pod *v1.Pod) map[string]string {
// These keys are guaranteed to not collide due to validation logic
return functional.UnionStringMaps(
c.Labels,
pod.Spec.NodeSelector,
map[string]string{
ProvisionerNameLabelKey: name,
ProvisionerNamespaceLabelKey: namespace,
},
)
func (c *Constraints) WithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: c.Taints,
Labels: c.Labels,
Zones: c.getZones(pod),
InstanceTypes: c.getInstanceTypes(pod),
Architecture: c.getArchitecture(pod),
OperatingSystem: c.getOperatingSystem(pod),
}
}

func (c *Constraints) getZones(pod *v1.Pod) []string {
Expand Down
95 changes: 54 additions & 41 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"

"github.com/awslabs/karpenter/pkg/utils/functional"
"k8s.io/apimachinery/pkg/util/validation"
"knative.dev/pkg/apis"
)

Expand All @@ -40,89 +41,101 @@ var (
SupportedOperatingSystems = []string{}
SupportedZones = []string{}
SupportedInstanceTypes = []string{}
ValidationHook func(ctx context.Context, spec *ProvisionerSpec) *apis.FieldError
ConstraintsValidationHook func(ctx context.Context, constraints *Constraints) *apis.FieldError
)

func (p *Provisioner) Validate(ctx context.Context) (errs *apis.FieldError) {
return errs.Also(
apis.ValidateObjectMetadata(p),
p.Spec.Validate(ctx),
apis.ValidateObjectMetadata(p).ViaField("metadata"),
p.Spec.Validate(ctx).ViaField("spec"),
)
}

func (s *ProvisionerSpec) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
s.validateClusterSpec(ctx),
s.validateLabels(ctx),
s.validateZones(ctx),
s.validateInstanceTypes(ctx),
s.validateArchitecture(ctx),
s.validateOperatingSystem(ctx),
return errs.Also(
s.Cluster.Validate(ctx).ViaField("cluster"),
s.Constraints.Validate(ctx),
)
if ValidationHook != nil {
errs = errs.Also(ValidationHook(ctx, s))
}
return errs
}

func (s *ProvisionerSpec) validateClusterSpec(ctx context.Context) (errs *apis.FieldError) {
if s.Cluster == nil {
return errs.Also(apis.ErrMissingField("spec.cluster"))
func (s *ClusterSpec) Validate(ctx context.Context) (errs *apis.FieldError) {
if s == nil {
return errs.Also(apis.ErrMissingField())
}
if len(s.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("name"))
}
if len(s.Cluster.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.name"))
if len(s.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("endpoint"))
}
if len(s.Cluster.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.endpoint"))
if len(s.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("caBundle"))
}
if len(s.Cluster.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.caBundle"))
return errs
}

func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
c.validateLabels(ctx),
c.validateArchitecture(ctx),
c.validateOperatingSystem(ctx),
c.validateZones(ctx),
c.validateInstanceTypes(ctx),
)
if ConstraintsValidationHook != nil {
errs = errs.Also(ConstraintsValidationHook(ctx, c))
}
return errs
}

func (s *ProvisionerSpec) validateLabels(ctx context.Context) (errs *apis.FieldError) {
for _, restricted := range RestrictedLabels {
if _, ok := s.Labels[restricted]; ok {
errs = errs.Also(apis.ErrInvalidKeyName(restricted, "spec.labels"))
func (c *Constraints) validateLabels(ctx context.Context) (errs *apis.FieldError) {
for key, value := range c.Labels {
if functional.ContainsString(RestrictedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels"))
}
for _, err := range validation.IsQualifiedName(key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels", err))
}
for _, err := range validation.IsValidLabelValue(value) {
errs = errs.Also(apis.ErrInvalidValue(value + ", " + err, "labels"))
}
}
return errs
}

func (s *ProvisionerSpec) validateArchitecture(ctx context.Context) (errs *apis.FieldError) {
if s.Architecture == nil {
func (c *Constraints) validateArchitecture(ctx context.Context) (errs *apis.FieldError) {
if c.Architecture == nil {
return nil
}
if !functional.ContainsString(SupportedArchitectures, *s.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.Architecture, SupportedArchitectures), "spec.architecture"))
if !functional.ContainsString(SupportedArchitectures, *c.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.Architecture, SupportedArchitectures), "architecture"))
}
return errs
}

func (s *ProvisionerSpec) validateOperatingSystem(ctx context.Context) (errs *apis.FieldError) {
if s.OperatingSystem == nil {
func (c *Constraints) validateOperatingSystem(ctx context.Context) (errs *apis.FieldError) {
if c.OperatingSystem == nil {
return nil
}
if !functional.ContainsString(SupportedOperatingSystems, *s.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.OperatingSystem, SupportedOperatingSystems), "spec.operatingSystem"))
if !functional.ContainsString(SupportedOperatingSystems, *c.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.OperatingSystem, SupportedOperatingSystems), "operatingSystem"))
}
return errs
}

func (s *ProvisionerSpec) validateZones(ctx context.Context) (errs *apis.FieldError) {
for i, zone := range s.Zones {
func (c *Constraints) validateZones(ctx context.Context) (errs *apis.FieldError) {
for i, zone := range c.Zones {
if !functional.ContainsString(SupportedZones, zone) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "spec.zones", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "zones", i))
}
}
return errs
}

func (s *ProvisionerSpec) validateInstanceTypes(ctx context.Context) (errs *apis.FieldError) {
for i, instanceType := range s.InstanceTypes {
func (c *Constraints) validateInstanceTypes(ctx context.Context) (errs *apis.FieldError) {
for i, instanceType := range c.InstanceTypes {
if !functional.ContainsString(SupportedInstanceTypes, instanceType) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "spec.instanceTypes", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "instanceTypes", i))
}
}
return errs
Expand Down
37 changes: 23 additions & 14 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,31 @@ var _ = Describe("Validation", func() {
}
})

It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ProvisionerTTLKey,
ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Context("Labels", func() {
It("should fail for invalid label keys", func() {
provisioner.Spec.Labels = map[string]string{"spaces are not allowed": randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
It("should fail for invalid label values", func() {
provisioner.Spec.Labels = map[string]string{randomdata.SillyName(): "/ is not allowed"}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
})
It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ProvisionerTTLKey,
ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
})

Context("Zones", func() {
SupportedZones = append(SupportedZones, "test-zone-1")
It("should succeed if unspecified", func() {
Expand Down
67 changes: 13 additions & 54 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package aws
import (
"context"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -28,7 +27,6 @@ import (
"github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1"
"github.com/awslabs/karpenter/pkg/cloudprovider"
"github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils"
"github.com/awslabs/karpenter/pkg/utils/functional"
"github.com/awslabs/karpenter/pkg/utils/project"
"github.com/patrickmn/go-cache"
"go.uber.org/zap"
Expand Down Expand Up @@ -108,29 +106,23 @@ func withUserAgent(sess *session.Session) *session.Session {
}

// Create a set of nodes given the constraints.
func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
func (c *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
instanceIDs := []*string{}
instancePackings := map[string]*cloudprovider.Packing{}
for _, packing := range packings {
constraints := Constraints(*packing.Constraints)
constraints := Constraints{*packing.Constraints}
// 1. Get Subnets and constrain by zones
zonalSubnets, err := a.subnetProvider.GetZonalSubnets(ctx, provisioner.Spec.Cluster.Name)
subnets, err := c.subnetProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting zonal subnets, %w", err)
}
zonalSubnetOptions := map[string][]*ec2.Subnet{}
for zone, subnets := range zonalSubnets {
if len(constraints.Zones) == 0 || functional.ContainsString(constraints.Zones, zone) {
zonalSubnetOptions[zone] = subnets
}
}
// 2. Get Launch Template
launchTemplate, err := a.launchTemplateProvider.Get(ctx, provisioner, &constraints)
launchTemplate, err := c.launchTemplateProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting launch template, %w", err)
}
// 3. Create instance
instanceID, err := a.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, zonalSubnets, constraints.GetCapacityType())
instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType())
if err != nil {
// TODO Aggregate errors and continue
return nil, fmt.Errorf("creating capacity %w", err)
Expand All @@ -140,7 +132,7 @@ func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provis
}

// 4. Convert to Nodes
nodes, err := a.nodeAPI.For(ctx, instanceIDs)
nodes, err := c.nodeAPI.For(ctx, instanceIDs)
if err != nil {
return nil, fmt.Errorf("determining nodes, %w", err)
}
Expand All @@ -158,49 +150,16 @@ func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provis
return packedNodes, nil
}

func (a *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return a.instanceTypeProvider.Get(ctx)
func (c *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return c.instanceTypeProvider.Get(ctx)
}

func (a *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return a.instanceProvider.Terminate(ctx, nodes)
func (c *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return c.instanceProvider.Terminate(ctx, nodes)
}

// Validate cloud provider specific components of the cluster spec
func (a *CloudProvider) Validate(ctx context.Context, spec *v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
return errs.Also(
validateAllowedLabels(*spec),
validateCapacityTypeLabel(*spec),
validateLaunchTemplateLabels(*spec),
)
}

func validateAllowedLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
for key := range spec.Labels {
if strings.HasPrefix(key, AWSLabelPrefix) && !functional.ContainsString(AllowedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "spec.labels"))
}
}
return errs
}

func validateCapacityTypeLabel(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
capacityType, ok := spec.Labels[CapacityTypeLabel]
if !ok {
return nil
}
capacityTypes := []string{CapacityTypeSpot, CapacityTypeOnDemand}
if !functional.ContainsString(capacityTypes, capacityType) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", capacityType, capacityTypes), fmt.Sprintf("spec.labels[%s]", CapacityTypeLabel)))
}
return errs
}

func validateLaunchTemplateLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
if _, versionExists := spec.Labels[LaunchTemplateVersionLabel]; versionExists {
if _, bothExist := spec.Labels[LaunchTemplateIdLabel]; !bothExist {
return errs.Also(apis.ErrMissingField(fmt.Sprintf("spec.labels[%s]", LaunchTemplateIdLabel)))
}
}
return errs
func (c *CloudProvider) Validate(ctx context.Context, constraints *v1alpha1.Constraints) (errs *apis.FieldError) {
awsConstraints := Constraints{*constraints}
return awsConstraints.Validate(ctx)
}
Loading

0 comments on commit b15ce16

Please sign in to comment.