Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Subnets specification/override #454

Merged
merged 3 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func main() {
ServiceName: "karpenter-webhook",
SecretName: "karpenter-webhook-cert",
}),
"Karpenter Webhooks",
"karpenter.webhooks",
config,
certificates.NewController,
NewCRDDefaultingWebhook,
Expand Down
31 changes: 12 additions & 19 deletions pkg/apis/provisioning/v1alpha1/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,20 @@ type ProvisionerList struct {
Items []Provisioner `json:"items"`
}

func (p *Provisioner) ConstraintsWithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: p.Spec.Taints,
Labels: p.Spec.Constraints.getLabels(p.Name, p.Namespace, pod),
Zones: p.Spec.Constraints.getZones(pod),
InstanceTypes: p.Spec.Constraints.getInstanceTypes(pod),
Architecture: p.Spec.Constraints.getArchitecture(pod),
OperatingSystem: p.Spec.Constraints.getOperatingSystem(pod),
}
func (c *Constraints) WithLabel(key string, value string) *Constraints {
c.Labels = functional.UnionStringMaps(c.Labels, map[string]string{key: value})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this also be done with c.Labels[key] = value?

functional.UnionStringMaps() works with "last write wins", so it seems like we overwrite an existing entry in c.labels with either versions, so it might be cleaner just to add it in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oopsie, nice call!

return c
}

func (c *Constraints) getLabels(name string, namespace string, pod *v1.Pod) map[string]string {
// These keys are guaranteed to not collide due to validation logic
return functional.UnionStringMaps(
c.Labels,
pod.Spec.NodeSelector,
map[string]string{
ProvisionerNameLabelKey: name,
ProvisionerNamespaceLabelKey: namespace,
},
)
func (c *Constraints) WithOverrides(pod *v1.Pod) *Constraints {
return &Constraints{
Taints: c.Taints,
Labels: functional.UnionStringMaps(c.Labels, pod.Spec.NodeSelector),
Zones: c.getZones(pod),
InstanceTypes: c.getInstanceTypes(pod),
Architecture: c.getArchitecture(pod),
OperatingSystem: c.getOperatingSystem(pod),
}
}

func (c *Constraints) getZones(pod *v1.Pod) []string {
Expand Down
110 changes: 70 additions & 40 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"

"github.com/awslabs/karpenter/pkg/utils/functional"
"k8s.io/apimachinery/pkg/util/validation"
"knative.dev/pkg/apis"
)

Expand All @@ -40,89 +41,118 @@ var (
SupportedOperatingSystems = []string{}
SupportedZones = []string{}
SupportedInstanceTypes = []string{}
ValidationHook func(ctx context.Context, spec *ProvisionerSpec) *apis.FieldError
ConstraintsValidationHook func(ctx context.Context, constraints *Constraints) *apis.FieldError
)

func (p *Provisioner) Validate(ctx context.Context) (errs *apis.FieldError) {
return errs.Also(
apis.ValidateObjectMetadata(p),
p.Spec.Validate(ctx),
apis.ValidateObjectMetadata(p).ViaField("metadata"),
p.Spec.validate(ctx).ViaField("spec"),
njtran marked this conversation as resolved.
Show resolved Hide resolved
)
}

func (s *ProvisionerSpec) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
s.validateClusterSpec(ctx),
s.validateLabels(ctx),
s.validateZones(ctx),
s.validateInstanceTypes(ctx),
s.validateArchitecture(ctx),
s.validateOperatingSystem(ctx),
func (s *ProvisionerSpec) validate(ctx context.Context) (errs *apis.FieldError) {
return errs.Also(
s.Cluster.validate(ctx).ViaField("cluster"),
// This validation is on the ProvisionerSpec despire the fact that
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/despire/despite/

// labels are a property of Constraints. This is necessary because
// validation is applied to constraints that include pod overrides.
// These labels are restricted when creating provisioners, but are not
// restricted for pods since they're necessary to override constraints.
s.validateRestrictedLabels(ctx),
s.Constraints.Validate(ctx),
)
if ValidationHook != nil {
errs = errs.Also(ValidationHook(ctx, s))
}

func (s *ProvisionerSpec) validateRestrictedLabels(ctx context.Context) (errs *apis.FieldError) {
for key := range s.Labels {
if functional.ContainsString(RestrictedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels"))
}
}
return errs
}

func (s *ProvisionerSpec) validateClusterSpec(ctx context.Context) (errs *apis.FieldError) {
if s.Cluster == nil {
return errs.Also(apis.ErrMissingField("spec.cluster"))
func (s *ClusterSpec) validate(ctx context.Context) (errs *apis.FieldError) {
if s == nil {
return errs.Also(apis.ErrMissingField())
}
if len(s.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("name"))
}
if len(s.Cluster.Name) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.name"))
if len(s.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("endpoint"))
}
if len(s.Cluster.Endpoint) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.endpoint"))
if len(s.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("caBundle"))
}
if len(s.Cluster.CABundle) == 0 {
errs = errs.Also(apis.ErrMissingField("spec.cluster.caBundle"))
return errs
}

// Validate constraints subresource. This validation logic is used both upon
// creation of a provisioner as well as when a pod is attempting to be
// provisioned. If a provisioner fails validation, it will be rejected by the
// API Server. If constraints.WithOverrides(pod) fails validation, the pod will
// be ignored for provisioning.
func (c *Constraints) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(
c.validateLabels(ctx),
c.validateArchitecture(ctx),
c.validateOperatingSystem(ctx),
c.validateZones(ctx),
c.validateInstanceTypes(ctx),
)
if ConstraintsValidationHook != nil {
errs = errs.Also(ConstraintsValidationHook(ctx, c))
}
return errs
}

func (s *ProvisionerSpec) validateLabels(ctx context.Context) (errs *apis.FieldError) {
for _, restricted := range RestrictedLabels {
if _, ok := s.Labels[restricted]; ok {
errs = errs.Also(apis.ErrInvalidKeyName(restricted, "spec.labels"))
func (c *Constraints) validateLabels(ctx context.Context) (errs *apis.FieldError) {
for key, value := range c.Labels {
for _, err := range validation.IsQualifiedName(key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "labels", err))
}
for _, err := range validation.IsValidLabelValue(value) {
errs = errs.Also(apis.ErrInvalidValue(value+", "+err, "labels"))
}
}
return errs
}

func (s *ProvisionerSpec) validateArchitecture(ctx context.Context) (errs *apis.FieldError) {
if s.Architecture == nil {
func (c *Constraints) validateArchitecture(ctx context.Context) (errs *apis.FieldError) {
if c.Architecture == nil {
return nil
}
if !functional.ContainsString(SupportedArchitectures, *s.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.Architecture, SupportedArchitectures), "spec.architecture"))
if !functional.ContainsString(SupportedArchitectures, *c.Architecture) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.Architecture, SupportedArchitectures), "architecture"))
}
return errs
}

func (s *ProvisionerSpec) validateOperatingSystem(ctx context.Context) (errs *apis.FieldError) {
if s.OperatingSystem == nil {
func (c *Constraints) validateOperatingSystem(ctx context.Context) (errs *apis.FieldError) {
if c.OperatingSystem == nil {
return nil
}
if !functional.ContainsString(SupportedOperatingSystems, *s.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *s.OperatingSystem, SupportedOperatingSystems), "spec.operatingSystem"))
if !functional.ContainsString(SupportedOperatingSystems, *c.OperatingSystem) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", *c.OperatingSystem, SupportedOperatingSystems), "operatingSystem"))
}
return errs
}

func (s *ProvisionerSpec) validateZones(ctx context.Context) (errs *apis.FieldError) {
for i, zone := range s.Zones {
func (c *Constraints) validateZones(ctx context.Context) (errs *apis.FieldError) {
for i, zone := range c.Zones {
if !functional.ContainsString(SupportedZones, zone) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "spec.zones", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", zone, SupportedZones), "zones", i))
}
}
return errs
}

func (s *ProvisionerSpec) validateInstanceTypes(ctx context.Context) (errs *apis.FieldError) {
for i, instanceType := range s.InstanceTypes {
func (c *Constraints) validateInstanceTypes(ctx context.Context) (errs *apis.FieldError) {
for i, instanceType := range c.InstanceTypes {
if !functional.ContainsString(SupportedInstanceTypes, instanceType) {
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "spec.instanceTypes", i))
errs = errs.Also(apis.ErrInvalidArrayValue(fmt.Sprintf("%s not in %v", instanceType, SupportedInstanceTypes), "instanceTypes", i))
}
}
return errs
Expand Down
37 changes: 23 additions & 14 deletions pkg/apis/provisioning/v1alpha1/provisioner_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,31 @@ var _ = Describe("Validation", func() {
}
})

It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ProvisionerTTLKey,
ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Context("Labels", func() {
It("should fail for invalid label keys", func() {
provisioner.Spec.Labels = map[string]string{"spaces are not allowed": randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
It("should fail for invalid label values", func() {
provisioner.Spec.Labels = map[string]string{randomdata.SillyName(): "/ is not allowed"}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
})
It("should fail for restricted labels", func() {
for _, label := range []string{
ArchitectureLabelKey,
OperatingSystemLabelKey,
ProvisionerNameLabelKey,
ProvisionerNamespaceLabelKey,
ProvisionerPhaseLabel,
ProvisionerTTLKey,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an annotation, so you may not need to check this in labels.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

ZoneLabelKey,
InstanceTypeLabelKey,
} {
provisioner.Spec.Labels = map[string]string{label: randomdata.SillyName()}
Expect(provisioner.Validate(ctx)).ToNot(Succeed())
}
})
})

Context("Zones", func() {
SupportedZones = append(SupportedZones, "test-zone-1")
It("should succeed if unspecified", func() {
Expand Down
71 changes: 15 additions & 56 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package aws
import (
"context"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -28,7 +27,6 @@ import (
"github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1"
"github.com/awslabs/karpenter/pkg/cloudprovider"
"github.com/awslabs/karpenter/pkg/cloudprovider/aws/utils"
"github.com/awslabs/karpenter/pkg/utils/functional"
"github.com/awslabs/karpenter/pkg/utils/project"
"github.com/patrickmn/go-cache"
"go.uber.org/zap"
Expand Down Expand Up @@ -108,39 +106,33 @@ func withUserAgent(sess *session.Session) *session.Session {
}

// Create a set of nodes given the constraints.
func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
func (c *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provisioner, packings []*cloudprovider.Packing) ([]*cloudprovider.PackedNode, error) {
instanceIDs := []*string{}
instancePackings := map[string]*cloudprovider.Packing{}
for _, packing := range packings {
constraints := Constraints(*packing.Constraints)
constraints := Constraints{*packing.Constraints}
// 1. Get Subnets and constrain by zones
zonalSubnets, err := a.subnetProvider.GetZonalSubnets(ctx, provisioner.Spec.Cluster.Name)
subnets, err := c.subnetProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting zonal subnets, %w", err)
}
zonalSubnetOptions := map[string][]*ec2.Subnet{}
for zone, subnets := range zonalSubnets {
if len(constraints.Zones) == 0 || functional.ContainsString(constraints.Zones, zone) {
zonalSubnetOptions[zone] = subnets
}
}
// 2. Get Launch Template
launchTemplate, err := a.launchTemplateProvider.Get(ctx, provisioner, &constraints)
launchTemplate, err := c.launchTemplateProvider.Get(ctx, provisioner, &constraints)
if err != nil {
return nil, fmt.Errorf("getting launch template, %w", err)
}
// 3. Create instance
instanceID, err := a.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, zonalSubnets, constraints.GetCapacityType())
instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType())
if err != nil {
// TODO Aggregate errors and continue
return nil, fmt.Errorf("creating capacity %w", err)
zap.S().Errorf("Continuing after failing to launch instances, %s", err.Error())
continue
}
instancePackings[*instanceID] = packing
instanceIDs = append(instanceIDs, instanceID)
}

// 4. Convert to Nodes
nodes, err := a.nodeAPI.For(ctx, instanceIDs)
nodes, err := c.nodeAPI.For(ctx, instanceIDs)
if err != nil {
return nil, fmt.Errorf("determining nodes, %w", err)
}
Expand All @@ -158,49 +150,16 @@ func (a *CloudProvider) Create(ctx context.Context, provisioner *v1alpha1.Provis
return packedNodes, nil
}

func (a *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return a.instanceTypeProvider.Get(ctx)
func (c *CloudProvider) GetInstanceTypes(ctx context.Context) ([]cloudprovider.InstanceType, error) {
return c.instanceTypeProvider.Get(ctx)
}

func (a *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return a.instanceProvider.Terminate(ctx, nodes)
func (c *CloudProvider) Terminate(ctx context.Context, nodes []*v1.Node) error {
return c.instanceProvider.Terminate(ctx, nodes)
}

// Validate cloud provider specific components of the cluster spec
func (a *CloudProvider) Validate(ctx context.Context, spec *v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
return errs.Also(
validateAllowedLabels(*spec),
validateCapacityTypeLabel(*spec),
validateLaunchTemplateLabels(*spec),
)
}

func validateAllowedLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
for key := range spec.Labels {
if strings.HasPrefix(key, AWSLabelPrefix) && !functional.ContainsString(AllowedLabels, key) {
errs = errs.Also(apis.ErrInvalidKeyName(key, "spec.labels"))
}
}
return errs
}

func validateCapacityTypeLabel(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
capacityType, ok := spec.Labels[CapacityTypeLabel]
if !ok {
return nil
}
capacityTypes := []string{CapacityTypeSpot, CapacityTypeOnDemand}
if !functional.ContainsString(capacityTypes, capacityType) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("%s not in %v", capacityType, capacityTypes), fmt.Sprintf("spec.labels[%s]", CapacityTypeLabel)))
}
return errs
}

func validateLaunchTemplateLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) {
if _, versionExists := spec.Labels[LaunchTemplateVersionLabel]; versionExists {
if _, bothExist := spec.Labels[LaunchTemplateIdLabel]; !bothExist {
return errs.Also(apis.ErrMissingField(fmt.Sprintf("spec.labels[%s]", LaunchTemplateIdLabel)))
}
}
return errs
func (c *CloudProvider) Validate(ctx context.Context, constraints *v1alpha1.Constraints) (errs *apis.FieldError) {
awsConstraints := Constraints{*constraints}
return awsConstraints.Validate(ctx)
}
Loading