diff --git a/pkg/cloudprovider/aws/fake/ec2api.go b/pkg/cloudprovider/aws/fake/ec2api.go index f91851a63b37..8e5438414549 100644 --- a/pkg/cloudprovider/aws/fake/ec2api.go +++ b/pkg/cloudprovider/aws/fake/ec2api.go @@ -17,6 +17,7 @@ package fake import ( "context" "fmt" + "strings" "github.com/Pallinder/go-randomdata" "github.com/aws/aws-sdk-go/aws" @@ -99,18 +100,81 @@ func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.Descri }}}, nil } -func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) { +func (e *EC2API) DescribeSubnetsWithContext(ctx context.Context, input *ec2.DescribeSubnetsInput, options ...request.Option) (*ec2.DescribeSubnetsOutput, error) { if e.WantErr != nil { return nil, e.WantErr } if e.DescribeSubnetsOutput != nil { return e.DescribeSubnetsOutput, nil } - return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + subnets := []*ec2.Subnet{} + + for _, subnet := range []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")}, - {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")}, - }}, nil + {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("foo")}}}, + {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c"), Tags: []*ec2.Tag{{Key: aws.String("TestTag")}}}, + } { + if matchesNameFilter(subnet, input.Filters) && matchesTagKeyFilter(subnet, input.Filters) { + subnets = append(subnets, subnet) + } + } + return &ec2.DescribeSubnetsOutput{Subnets: subnets}, nil +} + +func matchesNameFilter(subnet *ec2.Subnet, filters []*ec2.Filter) bool { + for _, filter := range filters { + if aws.StringValue(filter.Name) == "tag:Name" { + // Attempt to find name tag + for _, value := range filter.Values { + if !hasNameTag(subnet.Tags, aws.StringValue(value)) { + return false + } + } + // Fail otherwise + return false + } + } + // Succeed if it hasn't been filtered + return true +} + +func matchesTagKeyFilter(subnet *ec2.Subnet, filters []*ec2.Filter) bool { + for _, filter := range filters { + if aws.StringValue(filter.Name) == "tag-key" { + for _, value := range aws.StringValueSlice(filter.Values) { + // Ignore cluster tag when filtering subnets for testing since its set dynamically. + if strings.HasPrefix(value, "kubernetes.io/cluster/") { + continue + } + // Attempt to find tag with matching key + if !hasTagKey(subnet.Tags, value) { + return false + } + // Fail otherwise + return false + } + } + } + // Succeed if it hasn't been filtered + return true +} + +func hasNameTag(tags []*ec2.Tag, name string) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == "tag:Name" && aws.StringValue(tag.Value) == name { + return true + } + } + return false +} + +func hasTagKey(tags []*ec2.Tag, key string) bool { + for _, tag := range tags { + if aws.StringValue(tag.Key) == key { + return true + } + } + return false } func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { diff --git a/pkg/cloudprovider/aws/suite_test.go b/pkg/cloudprovider/aws/suite_test.go index 93f60bab1f7a..bf26ba531778 100644 --- a/pkg/cloudprovider/aws/suite_test.go +++ b/pkg/cloudprovider/aws/suite_test.go @@ -36,7 +36,6 @@ import ( "github.com/awslabs/karpenter/pkg/test" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - . "github.com/onsi/gomega/gstruct" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -475,6 +474,7 @@ var _ = Describe("Allocation", func() { // Setup pod := test.PendingPod() ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks ExpectCreated(env.Client, provisioner) ExpectEventuallyReconciled(env.Client, provisioner) // Assertions @@ -482,17 +482,86 @@ var _ = Describe("Allocation", func() { node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) - Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ContainElements( - MatchFields(IgnoreMissing&AllowDuplicates, Fields{"SubnetId": Not(BeNil())}), - MatchFields(IgnoreMissing&AllowDuplicates, Fields{"SubnetId": Not(BeNil())}), + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-1"), InstanceType: aws.String("m5.large")}, + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, )) Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) }) - It("should default to a provisioner's specified subnet name", func() {}) - It("should default to a provisioner's specified subnet tag key", func() {}) - It("should allow a pod to override the subnet name", func() {}) - It("should allow a pod to override the subnet tags", func() {}) + It("should default to a provisioner's specified subnet name", func() { + // Setup + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{SubnetNameLabel: "foo"} + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).To(HaveKeyWithValue(SubnetNameLabel, provisioner.Spec.Labels[SubnetNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) + }) + It("should default to a provisioner's specified subnet tag key", func() { + pod := test.PendingPod() + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.Labels = map[string]string{SubnetTagKeyLabel: "TestTag"} + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SubnetTagKeyLabel, provisioner.Spec.Labels[SubnetTagKeyLabel])) + }) + It("should allow a pod to override the subnet name", func() { + // Setup + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SubnetNameLabel: "foo"}}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-2"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).To(HaveKeyWithValue(SubnetNameLabel, provisioner.Spec.Labels[SubnetNameLabel])) + Expect(node.Labels).ToNot(HaveKey(SubnetTagKeyLabel)) + }) + It("should allow a pod to override the subnet tags", func() { + pod := test.PendingPod(test.PodOptions{NodeSelector: map[string]string{SubnetTagKeyLabel: "TestTag"}}) + ExpectCreatedWithStatus(env.Client, pod) + provisioner.Spec.InstanceTypes = []string{"m5.large"} // limit instance type to simplify ConsistOf checks + ExpectCreated(env.Client, provisioner) + ExpectEventuallyReconciled(env.Client, provisioner) + // Assertions + scheduled := ExpectPodExists(env.Client, pod.GetName(), pod.GetNamespace()) + node := ExpectNodeExists(env.Client, scheduled.Spec.NodeName) + Expect(fakeEC2API.CalledWithCreateFleetInput).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs).To(HaveLen(1)) + Expect(fakeEC2API.CalledWithCreateFleetInput[0].LaunchTemplateConfigs[0].Overrides).To(ConsistOf( + &ec2.FleetLaunchTemplateOverridesRequest{SubnetId: aws.String("test-subnet-3"), InstanceType: aws.String("m5.large")}, + )) + Expect(node.Labels).ToNot(HaveKey(SubnetNameLabel)) + Expect(node.Labels).To(HaveKeyWithValue(SubnetTagKeyLabel, provisioner.Spec.Labels[SubnetTagKeyLabel])) + }) }) }) Context("Validation", func() {