From 4d50be52bea96ec817479ce77c6c8cba404872e5 Mon Sep 17 00:00:00 2001 From: Ellis Tarn Date: Mon, 7 Jun 2021 16:31:07 -0700 Subject: [PATCH] Implemented a well known label for subnets --- go.sum | 2 - pkg/cloudprovider/aws/capacity.go | 18 +- pkg/cloudprovider/aws/constraints.go | 33 +- pkg/cloudprovider/aws/instance.go | 34 +- pkg/cloudprovider/aws/launchtemplate.go | 2 +- pkg/cloudprovider/aws/subnets.go | 74 ++- pkg/cloudprovider/aws/suite_test.go | 591 +++++++++++++----- pkg/cloudprovider/aws/validation.go | 9 + .../v1alpha1/allocation/suite_test.go | 38 +- .../v1alpha1/reallocation/suite_test.go | 6 +- pkg/test/pods.go | 37 +- 11 files changed, 575 insertions(+), 269 deletions(-) diff --git a/go.sum b/go.sum index 9c1f30629c78..31e6ae3dc869 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,6 @@ github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQ github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.31.12/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= -github.com/aws/aws-sdk-go v1.38.11 h1:jmxKh557ZRc+Z8fALnGrL01Ctjks2aSUFLb7n/BZoEs= -github.com/aws/aws-sdk-go v1.38.11/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.38.62 h1:w7r48cTciWCJK//YH+oN8HhNXzPDdlucV3XT6KGDMjE= github.com/aws/aws-sdk-go v1.38.62/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= diff --git a/pkg/cloudprovider/aws/capacity.go b/pkg/cloudprovider/aws/capacity.go index b8f07f112e3e..55d54edd1755 100644 --- a/pkg/cloudprovider/aws/capacity.go +++ b/pkg/cloudprovider/aws/capacity.go @@ -18,10 +18,8 @@ import ( "context" "fmt" - "github.com/aws/aws-sdk-go/service/ec2" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" "github.com/awslabs/karpenter/pkg/cloudprovider" - 
"github.com/awslabs/karpenter/pkg/utils/functional" ) // Capacity cloud provider implementation using AWS Fleet. @@ -41,23 +39,23 @@ func (c *Capacity) Create(ctx context.Context, packings []*cloudprovider.Packing for _, packing := range packings { constraints := Constraints(*packing.Constraints) // 1. Get Subnets and constrain by zones - zonalSubnets, err := c.subnetProvider.GetZonalSubnets(ctx, c.provisioner.Spec.Cluster.Name) + subnets, err := c.subnetProvider.Get(ctx, c.provisioner, &constraints) if err != nil { return nil, fmt.Errorf("getting zonal subnets, %w", err) } - zonalSubnetOptions := map[string][]*ec2.Subnet{} - for zone, subnets := range zonalSubnets { - if len(constraints.Zones) == 0 || functional.ContainsString(constraints.Zones, zone) { - zonalSubnetOptions[zone] = subnets - } - } + // zonalSubnetOptions := map[string][]*ec2.Subnet{} + // for zone, subnets := range zonalSubnets { + // if len(constraints.Zones) == 0 || functional.ContainsString(constraints.Zones, zone) { + // zonalSubnetOptions[zone] = subnets + // } + // } // 2. Get Launch Template launchTemplate, err := c.launchTemplateProvider.Get(ctx, c.provisioner, &constraints) if err != nil { return nil, fmt.Errorf("getting launch template, %w", err) } // 3. Create instance - instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, zonalSubnets, constraints.GetCapacityType()) + instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType()) if err != nil { // TODO Aggregate errors and continue return nil, fmt.Errorf("creating capacity %w", err) diff --git a/pkg/cloudprovider/aws/constraints.go b/pkg/cloudprovider/aws/constraints.go index dcd5e1ae4784..879fb3cf842a 100644 --- a/pkg/cloudprovider/aws/constraints.go +++ b/pkg/cloudprovider/aws/constraints.go @@ -15,6 +15,7 @@ limitations under the License. 
package aws import ( + "github.com/aws/aws-sdk-go/aws" "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" "github.com/awslabs/karpenter/pkg/utils/functional" ) @@ -22,7 +23,7 @@ import ( const ( CapacityTypeSpot = "spot" CapacityTypeOnDemand = "on-demand" - defaultLaunchTemplateVersion = "$Default" + DefaultLaunchTemplateVersion = "$Default" ) var ( @@ -30,8 +31,16 @@ var ( CapacityTypeLabel = AWSLabelPrefix + "capacity-type" LaunchTemplateIdLabel = AWSLabelPrefix + "launch-template-id" LaunchTemplateVersionLabel = AWSLabelPrefix + "launch-template-version" - AllowedLabels = []string{CapacityTypeLabel, LaunchTemplateIdLabel, LaunchTemplateVersionLabel} - AWSToKubeArchitectures = map[string]string{ + SubnetNameLabel = AWSLabelPrefix + "subnet" + SubnetTagKeyLabel = AWSLabelPrefix + "subnet-tag-key" + AllowedLabels = []string{ + CapacityTypeLabel, + LaunchTemplateIdLabel, + LaunchTemplateVersionLabel, + SubnetNameLabel, + SubnetTagKeyLabel, + } + AWSToKubeArchitectures = map[string]string{ "x86_64": v1alpha1.ArchitectureAmd64, v1alpha1.ArchitectureArm64: v1alpha1.ArchitectureArm64, } @@ -61,10 +70,26 @@ func (c *Constraints) GetLaunchTemplate() *LaunchTemplate { } version, ok := c.Labels[LaunchTemplateVersionLabel] if !ok { - version = defaultLaunchTemplateVersion + version = DefaultLaunchTemplateVersion } return &LaunchTemplate{ Id: &id, Version: &version, } } + +func (c *Constraints) GetSubnetName() *string { + subnetName, ok := c.Labels[SubnetNameLabel] + if !ok { + return nil + } + return aws.String(subnetName) +} + +func (c *Constraints) GetSubnetTagKey() *string { + subnetTag, ok := c.Labels[SubnetTagKeyLabel] + if !ok { + return nil + } + return aws.String(subnetTag) +} diff --git a/pkg/cloudprovider/aws/instance.go b/pkg/cloudprovider/aws/instance.go index 10a3ddc56044..6086b77ff053 100644 --- a/pkg/cloudprovider/aws/instance.go +++ b/pkg/cloudprovider/aws/instance.go @@ -17,7 +17,6 @@ package aws import ( "context" "fmt" - "math/rand" "strings" 
"github.com/aws/aws-sdk-go/aws" @@ -40,29 +39,30 @@ type InstanceProvider struct { func (p *InstanceProvider) Create(ctx context.Context, launchTemplate *LaunchTemplate, instanceTypeOptions []cloudprovider.InstanceType, - zonalSubnetOptions map[string][]*ec2.Subnet, + subnets []*ec2.Subnet, capacityType string, ) (*string, error) { // 1. Construct override options. var overrides []*ec2.FleetLaunchTemplateOverridesRequest for i, instanceType := range instanceTypeOptions { for _, zone := range instanceType.Zones() { - subnets := zonalSubnetOptions[zone] - if len(subnets) == 0 { - continue + for _, subnet := range subnets { + if aws.StringValue(subnet.AvailabilityZone) == zone { + override := &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String(instanceType.Name()), + SubnetId: subnet.SubnetId, + } + // Add a priority for spot requests since we are using the capacity-optimized-prioritized spot allocation strategy + // to reduce the likelihood of getting an excessively large instance type. + // instanceTypeOptions are sorted by vcpus and memory so this prioritizes smaller instance types. + if capacityType == CapacityTypeSpot { + override.Priority = aws.Float64(float64(i)) + } + overrides = append(overrides, override) + // FleetAPI cannot span subnets from the same AZ, so break after the first one. + break + } } - override := &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String(instanceType.Name()), - // FleetAPI cannot span subnets from the same AZ, so randomize. - SubnetId: aws.String(*subnets[rand.Intn(len(subnets))].SubnetId), - } - // Add a priority for spot requests since we are using the capacity-optimized-prioritized spot allocation strategy - // to reduce the likelihood of getting an excessively large instance type. - // instanceTypeOptions are sorted by vcpus and memory so this prioritizes smaller instance types. 
- if capacityType == CapacityTypeSpot { - override.Priority = aws.Float64(float64(i)) - } - overrides = append(overrides, override) } } // 2. Create fleet diff --git a/pkg/cloudprovider/aws/launchtemplate.go b/pkg/cloudprovider/aws/launchtemplate.go index df7821f62483..cbea03f7d302 100644 --- a/pkg/cloudprovider/aws/launchtemplate.go +++ b/pkg/cloudprovider/aws/launchtemplate.go @@ -101,7 +101,7 @@ func (p *LaunchTemplateProvider) Get(ctx context.Context, provisioner *v1alpha1. return nil, fmt.Errorf("hashing launch template, %w", err) } - result := &LaunchTemplate{Version: aws.String(defaultLaunchTemplateVersion)} + result := &LaunchTemplate{Version: aws.String(DefaultLaunchTemplateVersion)} if cached, ok := p.cache.Get(fmt.Sprint(key)); ok { result.Id = cached.(*ec2.LaunchTemplate).LaunchTemplateId return result, nil diff --git a/pkg/cloudprovider/aws/subnets.go b/pkg/cloudprovider/aws/subnets.go index 0483cec9b9cb..379bea1fb1d0 100644 --- a/pkg/cloudprovider/aws/subnets.go +++ b/pkg/cloudprovider/aws/subnets.go @@ -21,6 +21,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha1" "github.com/patrickmn/go-cache" "go.uber.org/zap" ) @@ -37,32 +38,67 @@ func NewSubnetProvider(ec2api ec2iface.EC2API) *SubnetProvider { } } -func (s *SubnetProvider) GetZonalSubnets(ctx context.Context, clusterName string) (map[string][]*ec2.Subnet, error) { - if zonalSubnets, ok := s.cache.Get(clusterName); ok { - return zonalSubnets.(map[string][]*ec2.Subnet), nil - } - zonalSubnets, err := s.getZonalSubnets(ctx, clusterName) +func (s *SubnetProvider) Get(ctx context.Context, provisioner *v1alpha1.Provisioner, constraints *Constraints) ([]*ec2.Subnet, error) { + // 1. 
Get all viable subnets for this provisioner + subnets, err := s.getSubnets(ctx, provisioner) if err != nil { return nil, err } - s.cache.Set(clusterName, zonalSubnets, CacheTTL) - zap.S().Debugf("Successfully discovered subnets in %d zones for cluster %s", len(zonalSubnets), clusterName) - return zonalSubnets, nil + // 2. Filter by subnet name if constrained + if name := constraints.GetSubnetName(); name != nil { + subnets = filter(byName(aws.StringValue(name)), subnets) + } + // 2. Filter by subnet tag key if constrained + if tagKey := constraints.GetSubnetTagKey(); tagKey != nil { + subnets = filter(byTagKey(*tagKey), subnets) + } + return subnets, nil } -func (s *SubnetProvider) getZonalSubnets(ctx context.Context, clusterName string) (map[string][]*ec2.Subnet, error) { - describeSubnetOutput, err := s.ec2api.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{{ - Name: aws.String("tag-key"), - Values: []*string{aws.String(fmt.Sprintf(ClusterTagKeyFormat, clusterName))}, - }}, - }) +func (s *SubnetProvider) getSubnets(ctx context.Context, provisioner *v1alpha1.Provisioner) ([]*ec2.Subnet, error) { + if subnets, ok := s.cache.Get(provisioner.Spec.Cluster.Name); ok { + return subnets.([]*ec2.Subnet), nil + } + output, err := s.ec2api.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{{ + Name: aws.String("tag-key"), // Subnets must be tagged for the cluster + Values: []*string{aws.String(fmt.Sprintf(ClusterTagKeyFormat, provisioner.Spec.Cluster.Name))}, + }}}) if err != nil { return nil, fmt.Errorf("describing subnets, %w", err) } - zonalSubnetMap := map[string][]*ec2.Subnet{} - for _, subnet := range describeSubnetOutput.Subnets { - zonalSubnetMap[*subnet.AvailabilityZone] = append(zonalSubnetMap[*subnet.AvailabilityZone], subnet) + zap.S().Debugf("Successfully discovered %d subnets for cluster %s", len(output.Subnets), provisioner.Spec.Cluster.Name) + s.cache.Set(provisioner.Spec.Cluster.Name, 
output.Subnets, CacheTTL) + return output.Subnets, nil +} + +func filter(predicate func(*ec2.Subnet) bool, subnets []*ec2.Subnet) []*ec2.Subnet { + result := []*ec2.Subnet{} + for _, subnet := range subnets { + if predicate(subnet) { + result = append(result, subnet) + } + } + return result +} + +func byName(name string) func(*ec2.Subnet) bool { + return func(subnet *ec2.Subnet) bool { + for _, tag := range subnet.Tags { + if aws.StringValue(tag.Key) == "Name" { + return aws.StringValue(tag.Value) == name + } + } + return false + } +} + +func byTagKey(tagKey string) func(*ec2.Subnet) bool { + return func(subnet *ec2.Subnet) bool { + for _, tag := range subnet.Tags { + if aws.StringValue(tag.Key) == tagKey { + return true + } + } + return false } - return zonalSubnetMap, nil } diff --git a/pkg/cloudprovider/aws/suite_test.go b/pkg/cloudprovider/aws/suite_test.go index eab715edcd05..bf26ba531778 100644 --- a/pkg/cloudprovider/aws/suite_test.go +++ b/pkg/cloudprovider/aws/suite_test.go @@ -129,194 +129,439 @@ var _ = Describe("Allocation", func() { }) Context("Reconciliation", func() { - It("should default to a cluster zone", func() { - // Setup - pod := test.PendingPod() - ExpectCreatedWithStatus(env.Client, pod) - ExpectCreated(env.Client, provisioner) - ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) - ExpectNodeExists(env.Client, scheduled.Spec.NodeName) - Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-1"), + Context("Zone", func() { + It("should default to a cluster zone", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, provisioner) + 
ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-1"), + }, + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-2"), + }, + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-3"), + }, + )) + }) + It("should default to a provisioner's zone", func() { + // Setup + provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-1"), + }, + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-2"), + }, + ), + ) + }) + It("should allow a pod to override the zone", func() { + // Setup + provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1c"}}) + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, 
provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("m5.large"), + SubnetId: aws.String("test-subnet-3"), + }, + ), + ) + }) + }) + Context("InstanceType", func() { + It("should launch instances for Nvidia GPU resource requests", func() { + // Setup + pod1 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("1")}, + Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("1")}, }, - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-2"), + }) + // Should pack onto same instance + pod2 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("2")}, + Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("2")}, }, - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-3"), + }) + // Should pack onto a separate instance + pod3 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("4")}, + Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("4")}, }, - )) - }) - It("should default to a provisioner's zone", func() { - // Setup - provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} - pod := test.PendingPod() - ExpectCreatedWithStatus(env.Client, pod) - ExpectCreated(env.Client, provisioner) - 
ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) - ExpectNodeExists(env.Client, scheduled.Spec.NodeName) - Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-1"), + }) + ExpectCreatedWithStatus(env.Client, pod1, pod2, pod3) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled1 := ExpectPodExists(env.Client, pod1.GetName(), pod1.GetNamespace()) + scheduled2 := ExpectPodExists(env.Client, pod2.GetName(), pod2.GetNamespace()) + scheduled3 := ExpectPodExists(env.Client, pod3.GetName(), pod3.GetNamespace()) + Expect(scheduled1.Spec.NodeName).To(Equal(scheduled2.Spec.NodeName)) + Expect(scheduled1.Spec.NodeName).ToNot(Equal(scheduled3.Spec.NodeName)) + ExpectNodeExists(env.Client, scheduled1.Spec.NodeName) + ExpectNodeExists(env.Client, scheduled3.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("p3.8xlarge"), + SubnetId: aws.String("test-subnet-1"), + }, + ), + ) + Expect(fakeEC2API.CalledWithCreateFleetInput[1].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("p3.8xlarge"), + SubnetId: aws.String("test-subnet-1"), + }, + ), + ) + }) + It("should launch instances for AWS Neuron resource requests", func() { + // Setup + pod1 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.AWSNeuron: resource.MustParse("1")}, + Limits: v1.ResourceList{resources.AWSNeuron: 
resource.MustParse("1")}, }, - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-2"), + }) + // Should pack onto same instance + pod2 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.AWSNeuron: resource.MustParse("2")}, + Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("2")}, }, - ), - ) - }) - It("should allow pod to override default zone", func() { - // Setup - provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} - pod := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1c"}}) - ExpectCreatedWithStatus(env.Client, pod) - ExpectCreated(env.Client, provisioner) - ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) - ExpectNodeExists(env.Client, scheduled.Spec.NodeName) - Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("m5.large"), - SubnetId: aws.String("test-subnet-3"), + }) + // Should pack onto a separate instance + pod3 := test.PendingPod(test.PodOptions{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{resources.AWSNeuron: resource.MustParse("4")}, + Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("4")}, }, - ), - ) + }) + ExpectCreatedWithStatus(env.Client, pod1, pod2, pod3) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled1 := ExpectPodExists(env.Client, pod1.GetName(), pod1.GetNamespace()) + scheduled2 := ExpectPodExists(env.Client, pod2.GetName(), pod2.GetNamespace()) + scheduled3 := ExpectPodExists(env.Client, pod3.GetName(), 
pod3.GetNamespace()) + Expect(scheduled1.Spec.NodeName).To(Equal(scheduled2.Spec.NodeName)) + Expect(scheduled1.Spec.NodeName).ToNot(Equal(scheduled3.Spec.NodeName)) + ExpectNodeExists(env.Client, scheduled1.Spec.NodeName) + ExpectNodeExists(env.Client, scheduled3.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("inf1.6xlarge"), + SubnetId: aws.String("test-subnet-1"), + }, + ), + ) + Expect(fakeEC2API.CalledWithCreateFleetInput[1].LaunchTemplateConfigs[0].Overrides).To( + ContainElements( + &ec2.FleetLaunchTemplateOverridesRequest{ + InstanceType: aws.String("inf1.6xlarge"), + SubnetId: aws.String("test-subnet-1"), + }, + ), + ) + }) }) - It("should launch nodes for pods with different node selectors", func() { - // Setup - lt1 := "abc123" - lt2 := "34sy4s" - pod1 := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{LaunchTemplateIdLabel: lt1}}) - pod2 := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{LaunchTemplateIdLabel: lt2}}) - ExpectCreatedWithStatus(env.Client, pod1, pod2) - ExpectCreated(env.Client, provisioner) - ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled1 := ExpectPodExists(env.Client, pod1.GetName(), pod1.GetNamespace()) - scheduled2 := ExpectPodExists(env.Client, pod2.GetName(), pod2.GetNamespace()) - node1 := ExpectNodeExists(env.Client, scheduled1.Spec.NodeName) - node2 := ExpectNodeExists(env.Client, scheduled2.Spec.NodeName) - Expect(scheduled1.Spec.NodeName).NotTo(Equal(scheduled2.Spec.NodeName)) - Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(2)) - Expect(node1.ObjectMeta.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, lt1)) - Expect(node2.ObjectMeta.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, lt2)) + Context("CapacityType", func() { + It("should default to on demand", func() { + // Setup + pod := 
test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(aws.StringValue(fakeEC2API.CalledWithCreateFleetInput[0].TargetCapacitySpecification.DefaultTargetCapacityType)).To(Equal(CapacityTypeOnDemand)) + Expect(node.Labels).ToNot(HaveKey(CapacityTypeLabel)) + }) + It("should default to a provisioner's specified capacity type", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{CapacityTypeLabel: CapacityTypeSpot} + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(aws.StringValue(fakeEC2API.CalledWithCreateFleetInput[0].TargetCapacitySpecification.DefaultTargetCapacityType)).To(Equal(CapacityTypeSpot)) + Expect(node.Labels).To(HaveKeyWithValue(CapacityTypeLabel, CapacityTypeSpot)) + }) + It("should allow a pod to override the capacity type", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{CapacityTypeLabel: CapacityTypeSpot}}) + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := 
ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(aws.StringValue(fakeEC2API.CalledWithCreateFleetInput[0].TargetCapacitySpecification.DefaultTargetCapacityType)).To(Equal(CapacityTypeSpot)) + Expect(node.Labels).To(HaveKeyWithValue(CapacityTypeLabel, CapacityTypeSpot)) + }) }) - It("should launch instances for Nvidia GPU resource requests", func() { - // Setup - pod1 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("1")}, - Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("1")}, - }, + Context("LaunchTemplates", func() { + It("should default to a generated launch template", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal("test-launch-template-id")) + Expect(aws.StringValue(launchTemplate.Version)).To(Equal(DefaultLaunchTemplateVersion)) + Expect(node.Labels).ToNot(HaveKey(LaunchTemplateIdLabel)) + Expect(node.Labels).ToNot(HaveKey(LaunchTemplateVersionLabel)) }) - // Should pack onto same instance - pod2 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("2")}, - 
Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("2")}, - }, + It("should default to a provisioner's launch template id and version", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{ + LaunchTemplateIdLabel: randomdata.SillyName(), + LaunchTemplateVersionLabel: randomdata.SillyName(), + } + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal(provisioner.Spec.Labels[LaunchTemplateIdLabel])) + Expect(aws.StringValue(launchTemplate.Version)).To(Equal(provisioner.Spec.Labels[LaunchTemplateVersionLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, provisioner.Spec.Labels[LaunchTemplateIdLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateVersionLabel, provisioner.Spec.Labels[LaunchTemplateVersionLabel])) }) - // Should pack onto a separate instance - pod3 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("4")}, - Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("4")}, - }, + It("should default to a provisioner's launch template and the default launch template version", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{LaunchTemplateIdLabel: randomdata.SillyName()} + ExpectCreated(env.Client, provisioner) + 
ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal(provisioner.Spec.Labels[LaunchTemplateIdLabel])) + Expect(aws.StringValue(launchTemplate.Version)).To(Equal(DefaultLaunchTemplateVersion)) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, provisioner.Spec.Labels[LaunchTemplateIdLabel])) + Expect(node.Labels).ToNot(HaveKey(LaunchTemplateVersionLabel)) + }) + It("should allow a pod to override the launch template id and version", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{ + LaunchTemplateIdLabel: randomdata.SillyName(), + LaunchTemplateVersionLabel: randomdata.SillyName(), + }}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{ + LaunchTemplateIdLabel: randomdata.SillyName(), + LaunchTemplateVersionLabel: randomdata.SillyName(), + } + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal(pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + 
Expect(aws.StringValue(launchTemplate.Version)).To(Equal(pod.Spec.NodeSelector[LaunchTemplateVersionLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateVersionLabel, pod.Spec.NodeSelector[LaunchTemplateVersionLabel])) + }) + It("should allow a pod to override the launch template id and use the default launch template version", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{LaunchTemplateIdLabel: randomdata.SillyName()}}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{LaunchTemplateIdLabel: randomdata.SillyName()} + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal(pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + Expect(aws.StringValue(launchTemplate.Version)).To(Equal(DefaultLaunchTemplateVersion)) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + Expect(node.Labels).ToNot(HaveKey(LaunchTemplateVersionLabel)) + }) + It("should allow a pod to override the launch template id and use the provisioner's launch template version", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{LaunchTemplateIdLabel: randomdata.SillyName()}}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{ + 
LaunchTemplateIdLabel: randomdata.SillyName(), + LaunchTemplateVersionLabel: randomdata.SillyName(), + } + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + launchTemplate := fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].LaunchTemplateSpecification + Expect(aws.StringValue(launchTemplate.LaunchTemplateId)).To(Equal(pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + Expect(aws.StringValue(launchTemplate.Version)).To(Equal(provisioner.Spec.Labels[LaunchTemplateVersionLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateIdLabel, pod.Spec.NodeSelector[LaunchTemplateIdLabel])) + Expect(node.Labels).To(HaveKeyWithValue(LaunchTemplateVersionLabel, provisioner.Spec.Labels[LaunchTemplateVersionLabel])) }) - ExpectCreatedWithStatus(env.Client, pod1, pod2, pod3) - ExpectCreated(env.Client, provisioner) - ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled1 := ExpectPodExists(env.Client, pod1.GetName(), pod1.GetNamespace()) - scheduled2 := ExpectPodExists(env.Client, pod2.GetName(), pod2.GetNamespace()) - scheduled3 := ExpectPodExists(env.Client, pod3.GetName(), pod3.GetNamespace()) - Expect(scheduled1.Spec.NodeName).To(Equal(scheduled2.Spec.NodeName)) - Expect(scheduled1.Spec.NodeName).ToNot(Equal(scheduled3.Spec.NodeName)) - ExpectNodeExists(env.Client, scheduled1.Spec.NodeName) - ExpectNodeExists(env.Client, scheduled3.Spec.NodeName) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("p3.8xlarge"), - SubnetId: aws.String("test-subnet-1"), - }, 
- ), - ) - Expect(fakeEC2API.CalledWithCreateFleetInput[1].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("p3.8xlarge"), - SubnetId: aws.String("test-subnet-1"), - }, - ), - ) }) - It("should launch instances for AWS Neuron resource requests", func() { - // Setup - pod1 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{resources.AWSNeuron: resource.MustParse("1")}, - Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("1")}, - }, + Context("Subnets", func() { + It("should default to the clusters subnets", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-1"), InstanceType: aws.String("m5.large")}, + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) + Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) }) - // Should pack onto same instance - pod2 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: 
v1.ResourceList{resources.AWSNeuron: resource.MustParse("2")}, - Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("2")}, - }, + It("should default to a provisioner's specified subnet name", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{SubnetNameLabel: "foo"} + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).To(HaveKeyWithValue(SubnetNameLabel, provisioner.Spec.Labels[SubnetNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) }) - // Should pack onto a separate instance - pod3 := test.PendingPodWith(test.PodOptions{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{resources.AWSNeuron: resource.MustParse("4")}, - Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("4")}, - }, + It("should default to a provisioner's specified subnet tag key", func() { + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{SubnetTagKeyLabel: "TestTag"} + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + 
scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SubnetTagKeyLabel, provisioner.Spec.Labels[SubnetTagKeyLabel])) + }) + It("should allow a pod to override the subnet name", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SubnetNameLabel: "foo"}}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).To(HaveKeyWithValue(SubnetNameLabel, pod.Spec.NodeSelector[SubnetNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) + }) + It("should allow a pod to override the subnet tags", func() { + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SubnetTagKeyLabel: "TestTag"}}) +
ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SubnetTagKeyLabel, pod.Spec.NodeSelector[SubnetTagKeyLabel])) }) - ExpectCreatedWithStatus(env.Client, pod1, pod2, pod3) - ExpectCreated(env.Client, provisioner) - ExpectEventuallyReconciled(env.Client, provisioner) - // Assertions - scheduled1 := ExpectPodExists(env.Client, pod1.GetName(), pod1.GetNamespace()) - scheduled2 := ExpectPodExists(env.Client, pod2.GetName(), pod2.GetNamespace()) - scheduled3 := ExpectPodExists(env.Client, pod3.GetName(), pod3.GetNamespace()) - Expect(scheduled1.Spec.NodeName).To(Equal(scheduled2.Spec.NodeName)) - Expect(scheduled1.Spec.NodeName).ToNot(Equal(scheduled3.Spec.NodeName)) - ExpectNodeExists(env.Client, scheduled1.Spec.NodeName) - ExpectNodeExists(env.Client, scheduled3.Spec.NodeName) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( - &ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("inf1.6xlarge"), - SubnetId: aws.String("test-subnet-1"), - }, - ), - ) - Expect(fakeEC2API.CalledWithCreateFleetInput[1].LaunchTemplateConfigs[0].Overrides).To( - ContainElements( -
&ec2.FleetLaunchTemplateOverridesRequest{ - InstanceType: aws.String("inf1.6xlarge"), - SubnetId: aws.String("test-subnet-1"), - }, - ), - ) }) }) Context("Validation", func() { diff --git a/pkg/cloudprovider/aws/validation.go b/pkg/cloudprovider/aws/validation.go index 3dd89693230a..ab1edd242a85 100644 --- a/pkg/cloudprovider/aws/validation.go +++ b/pkg/cloudprovider/aws/validation.go @@ -30,6 +30,7 @@ func (c *Capacity) Validate(ctx context.Context) (errs *apis.FieldError) { validateAllowedLabels(c.provisioner.Spec), validateCapacityTypeLabel(c.provisioner.Spec), validateLaunchTemplateLabels(c.provisioner.Spec), + validateSubnetLabels(c.provisioner.Spec), ) } func validateAllowedLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) { @@ -61,3 +62,11 @@ func validateLaunchTemplateLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.Fie } return errs } + +func validateSubnetLabels(spec v1alpha1.ProvisionerSpec) (errs *apis.FieldError) { + constraints := Constraints(spec.Constraints) + if constraints.GetSubnetName() != nil && constraints.GetSubnetTagKey() != nil { + errs = errs.Also(apis.ErrMultipleOneOf(fmt.Sprintf("spec.labels[%s]", SubnetNameLabel), fmt.Sprintf("spec.labels[%s]", SubnetTagKeyLabel))) + } + return errs +} diff --git a/pkg/controllers/provisioning/v1alpha1/allocation/suite_test.go b/pkg/controllers/provisioning/v1alpha1/allocation/suite_test.go index 3a83846f1b5b..63ed481f7d1e 100644 --- a/pkg/controllers/provisioning/v1alpha1/allocation/suite_test.go +++ b/pkg/controllers/provisioning/v1alpha1/allocation/suite_test.go @@ -104,35 +104,35 @@ var _ = Describe("Allocation", func() { // Unconstrained test.PendingPod(), // Constrained by provisioner - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.ProvisionerNameLabelKey: provisioner.Name, v1alpha1.ProvisionerNamespaceLabelKey: provisioner.Namespace}, }), } schedulable := []client.Object{ // Constrained by zone - 
test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1"}, }), // Constrained by instanceType - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.InstanceTypeLabelKey: "default-instance-type"}, }), // Constrained by architecture - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.ArchitectureLabelKey: "arm64"}, }), // Constrained by operating system - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.OperatingSystemLabelKey: "windows"}, }), // Constrained by arbitrary label - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{"foo": "bar"}, }), } unschedulable := []client.Object{ // Ignored, matches another provisioner - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ NodeSelector: map[string]string{v1alpha1.ProvisionerNameLabelKey: "test", v1alpha1.ProvisionerNamespaceLabelKey: "test"}, }), } @@ -161,11 +161,11 @@ var _ = Describe("Allocation", func() { provisioner.Spec.Taints = []v1.Taint{{Key: "test-key", Value: "test-value", Effect: v1.TaintEffectNoSchedule}} schedulable := []client.Object{ // Tolerates with OpExists - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ Tolerations: []v1.Toleration{{Key: "test-key", Operator: v1.TolerationOpExists}}, }), // Tolerates with OpEqual - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ Tolerations: []v1.Toleration{{Key: "test-key", Value: "test-value", Operator: v1.TolerationOpEqual}}, }), } @@ -173,15 +173,15 @@ var _ = Describe("Allocation", func() { // Missing toleration test.PendingPod(), // key mismatch with OpExists - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ Tolerations: []v1.Toleration{{Key: 
"invalid", Operator: v1.TolerationOpExists}}, }), // value mismatch with OpEqual - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ Tolerations: []v1.Toleration{{Key: "test-key", Value: "invalid", Operator: v1.TolerationOpEqual}}, }), // key mismatch with OpEqual - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ Tolerations: []v1.Toleration{{Key: "invalid", Value: "test-value", Operator: v1.TolerationOpEqual}}, }), } @@ -204,13 +204,13 @@ var _ = Describe("Allocation", func() { }) It("should provision nodes for accelerators", func() { schedulable := []client.Object{ - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Limits: v1.ResourceList{resources.NvidiaGPU: resource.MustParse("1")}}, }), - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Limits: v1.ResourceList{resources.AMDGPU: resource.MustParse("1")}}, }), - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Limits: v1.ResourceList{resources.AWSNeuron: resource.MustParse("1")}}, }), } @@ -234,20 +234,20 @@ var _ = Describe("Allocation", func() { Selector: &metav1.LabelSelector{MatchLabels: map[string]string{"app": "test"}}, Template: v1.PodTemplateSpec{ ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{"app": "test"}}, - Spec: test.PendingPodWith(test.PodOptions{ + Spec: test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Requests: v1.ResourceList{v1.ResourceCPU: resource.MustParse("1"), v1.ResourceMemory: resource.MustParse("1Gi")}}, }).Spec, }}, }, } schedulable := []client.Object{ - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Requests: v1.ResourceList{v1.ResourceCPU: resource.MustParse("1"), v1.ResourceMemory: resource.MustParse("1Gi")}}, }), - 
test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Requests: v1.ResourceList{v1.ResourceCPU: resource.MustParse("1"), v1.ResourceMemory: resource.MustParse("1Gi")}}, }), - test.PendingPodWith(test.PodOptions{ + test.PendingPod(test.PodOptions{ ResourceRequirements: v1.ResourceRequirements{Requests: v1.ResourceList{v1.ResourceCPU: resource.MustParse("1"), v1.ResourceMemory: resource.MustParse("1Gi")}}, }), } diff --git a/pkg/controllers/provisioning/v1alpha1/reallocation/suite_test.go b/pkg/controllers/provisioning/v1alpha1/reallocation/suite_test.go index 223f18120ab8..7a89aa88a56a 100644 --- a/pkg/controllers/provisioning/v1alpha1/reallocation/suite_test.go +++ b/pkg/controllers/provisioning/v1alpha1/reallocation/suite_test.go @@ -112,7 +112,7 @@ var _ = Describe("Reallocation", func() { v1alpha1.ProvisionerTTLKey: time.Now().Add(time.Duration(100) * time.Second).Format(time.RFC3339), }, }) - pod := test.PendingPodWith(test.PodOptions{ + pod := test.Pod(test.PodOptions{ Name: strings.ToLower(randomdata.SillyName()), Namespace: provisioner.Namespace, NodeName: node.Name, diff --git a/pkg/test/pods.go b/pkg/test/pods.go index f83230113364..3d77fb4897aa 100644 --- a/pkg/test/pods.go +++ b/pkg/test/pods.go @@ -37,7 +37,15 @@ type PodOptions struct { Conditions []v1.PodCondition } -func defaults(options PodOptions) *v1.Pod { +// Pod creates a test pod with defaults that can be overridden by PodOptions.
+// Overrides are applied in order, with last-write-wins semantics. +func Pod(optionss ...PodOptions) *v1.Pod { + options := PodOptions{} + for _, opts := range optionss { + if err := mergo.Merge(&options, opts, mergo.WithOverride); err != nil { + panic(fmt.Sprintf("Failed to merge pod options: %s", err.Error())) + } + } if options.Name == "" { options.Name = strings.ToLower(randomdata.SillyName()) } @@ -47,9 +55,6 @@ func defaults(options PodOptions) *v1.Pod { if options.Image == "" { options.Image = "k8s.gcr.io/pause" } - if len(options.Conditions) == 0 { - options.Conditions = []v1.PodCondition{{Type: v1.PodScheduled, Reason: v1.PodReasonUnschedulable, Status: v1.ConditionFalse}} - } return &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: options.Name, @@ -70,23 +75,9 @@ func defaults(options PodOptions) *v1.Pod { } } -// PendingPod creates a pending test pod with the minimal set of other -// fields defaulted to something sane. -func PendingPod() *v1.Pod { - return defaults(PodOptions{}) -} - -// PendingPodWith creates a pending test pod with fields overridden by -// options. -func PendingPodWith(options PodOptions) *v1.Pod { - return PodWith(PendingPod(), options) -} - -// PodWith overrides, in-place, pod with any non-zero elements of -// options. It returns the same pod simply for ease of use. -func PodWith(pod *v1.Pod, options PodOptions) *v1.Pod { - if err := mergo.Merge(pod, defaults(options), mergo.WithOverride); err != nil { - panic(fmt.Sprintf("unexpected error in test code: %v", err)) - } - return pod +// PendingPod creates a test pod with a pending scheduling status condition. The condition is appended after the caller's options, so it takes precedence over any Conditions they set. +func PendingPod(options ...PodOptions) *v1.Pod { + return Pod(append(options, PodOptions{ + Conditions: []v1.PodCondition{{Type: v1.PodScheduled, Reason: v1.PodReasonUnschedulable, Status: v1.ConditionFalse}}, + })...) }