diff --git a/pkg/cloudprovider/aws/fake/ec2api.go b/pkg/cloudprovider/aws/fake/ec2api.go index 2f4f2ffcfd8b..d31385a77797 100644 --- a/pkg/cloudprovider/aws/fake/ec2api.go +++ b/pkg/cloudprovider/aws/fake/ec2api.go @@ -25,8 +25,9 @@ import ( "github.com/aws/aws-sdk-go/service/ec2/ec2iface" ) -type EC2API struct { - ec2iface.EC2API +// EC2Behavior must be reset between tests otherwise tests will +// pollute each other. +type EC2Behavior struct { CreateFleetOutput *ec2.CreateFleetOutput DescribeInstancesOutput *ec2.DescribeInstancesOutput DescribeLaunchTemplatesOutput *ec2.DescribeLaunchTemplatesOutput @@ -36,82 +37,92 @@ type EC2API struct { DescribeInstanceTypeOfferingsOutput *ec2.DescribeInstanceTypeOfferingsOutput DescribeAvailabilityZonesOutput *ec2.DescribeAvailabilityZonesOutput WantErr error + CalledWithCreateFleetInput []ec2.CreateFleetInput + Instances []*ec2.Instance +} - CalledWithCreateFleetInput []ec2.CreateFleetInput - Instances []*ec2.Instance +type EC2API struct { + ec2iface.EC2API + EC2Behavior } -func (a *EC2API) Reset() { - a.CalledWithCreateFleetInput = nil +// Reset must be called between tests otherwise tests will pollute +// each other. +func (e *EC2API) Reset() { + e.EC2Behavior = EC2Behavior{} } -func (a *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) { - a.CalledWithCreateFleetInput = append(a.CalledWithCreateFleetInput, *input) - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) { + e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, *input) + if e.WantErr != nil { + return nil, e.WantErr } - if a.CreateFleetOutput != nil { - return a.CreateFleetOutput, nil + if e.CreateFleetOutput != nil { + return e.CreateFleetOutput, nil } instance := &ec2.Instance{ InstanceId: aws.String(randomdata.SillyName()), - Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone")}, - PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(a.Instances))), + Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone-1a")}, + PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(e.Instances))), } - a.Instances = append(a.Instances, instance) + e.Instances = append(e.Instances, instance) return &ec2.CreateFleetOutput{Instances: []*ec2.CreateFleetInstance{{InstanceIds: []*string{instance.InstanceId}}}}, nil } -func (a *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) { - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) { + if e.WantErr != nil { + return nil, e.WantErr } - if a.DescribeInstancesOutput != nil { - return a.DescribeInstancesOutput, nil + if e.DescribeInstancesOutput != nil { + return e.DescribeInstancesOutput, nil } return &ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{{Instances: a.Instances}}, + Reservations: []*ec2.Reservation{{Instances: e.Instances}}, }, nil } -func (a *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) { - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) { + if e.WantErr != nil { + return nil, e.WantErr } - if a.DescribeLaunchTemplatesOutput != nil { - return a.DescribeLaunchTemplatesOutput, nil + if e.DescribeLaunchTemplatesOutput != nil { + return e.DescribeLaunchTemplatesOutput, nil } return &ec2.DescribeLaunchTemplatesOutput{LaunchTemplates: []*ec2.LaunchTemplate{{ LaunchTemplateName: aws.String("test-launch-template"), }}}, nil } -func (a *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) { - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) { + if e.WantErr != nil { + return nil, e.WantErr } - if a.DescribeSubnetsOutput != nil { - return a.DescribeSubnetsOutput, nil + if e.DescribeSubnetsOutput != nil { + return e.DescribeSubnetsOutput, nil } - return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{{SubnetId: aws.String("test-subnet"), AvailabilityZone: aws.String("test-zone")}}}, nil + return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")}, + {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")}, + {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")}, + }}, nil } -func (a *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { + if e.WantErr != nil { + return nil, e.WantErr } - if a.DescribeSecurityGroupsOutput != nil { - return a.DescribeSecurityGroupsOutput, nil + if e.DescribeSecurityGroupsOutput != nil { + return e.DescribeSecurityGroupsOutput, nil } return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil } -func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { - if a.WantErr != nil { - return nil, a.WantErr +func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { + if e.WantErr != nil { + return nil, e.WantErr } - if a.DescribeAvailabilityZonesOutput != nil { - return a.DescribeAvailabilityZonesOutput, nil + if e.DescribeAvailabilityZonesOutput != nil { + return e.DescribeAvailabilityZonesOutput, nil } return &ec2.DescribeAvailabilityZonesOutput{AvailabilityZones: []*ec2.AvailabilityZone{ {ZoneName: aws.String("test-zone-1a"), ZoneId: aws.String("testzone1a")}, @@ -120,12 +131,12 @@ func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.Desc }}, nil } -func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error { - if a.WantErr != nil { - return a.WantErr +func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error { + if e.WantErr != nil { + return e.WantErr } - if a.DescribeInstanceTypesOutput != nil { - fn(a.DescribeInstanceTypesOutput, false) + if e.DescribeInstanceTypesOutput != nil { + fn(e.DescribeInstanceTypesOutput, false) return nil } fn(&ec2.DescribeInstanceTypesOutput{ @@ -175,12 +186,12 @@ func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, inpu return nil } -func (a *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error { - if a.WantErr != nil { - return a.WantErr +func (e *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error { + if e.WantErr != nil { + return e.WantErr } - if a.DescribeInstanceTypeOfferingsOutput != nil { - fn(a.DescribeInstanceTypeOfferingsOutput, false) + if e.DescribeInstanceTypeOfferingsOutput != nil { + fn(e.DescribeInstanceTypeOfferingsOutput, false) return nil } fn(&ec2.DescribeInstanceTypeOfferingsOutput{ diff --git a/pkg/cloudprovider/aws/instancetypes_test.go b/pkg/cloudprovider/aws/instancetypes_test.go index 5e9e7ae5ae91..5c3549eb851f 100644 --- a/pkg/cloudprovider/aws/instancetypes_test.go +++ b/pkg/cloudprovider/aws/instancetypes_test.go @@ -130,8 +130,10 @@ var _ = Describe("InstanceTypes", func() { func getInstanceTypeProviderMocks(zones []string, instanceTypes []string) ec2iface.EC2API { ec2api := &fake.EC2API{ - DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{}, - DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{}, + EC2Behavior: fake.EC2Behavior{ + DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{}, + DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{}, + }, } for _, instanceType := range instanceTypes { diff --git a/pkg/cloudprovider/aws/suite_test.go b/pkg/cloudprovider/aws/suite_test.go index f489e7715a58..c95f9d76a153 100644 --- a/pkg/cloudprovider/aws/suite_test.go +++ b/pkg/cloudprovider/aws/suite_test.go @@ -136,11 +136,6 @@ var _ = Describe("Allocation", func() { Context("Reconciliation", func() { It("should default to a cluster zone", func() { - fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")}, - {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")}, - }} // Setup pod := test.PendingPod() ExpectCreatedWithStatus(env.Client, pod) @@ -167,11 +162,6 @@ var _ = Describe("Allocation", func() { )) }) It("should default to a provisioner's zone", func() { - fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")}, - {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")}, - }} // Setup provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} pod := test.PendingPod() @@ -196,11 +186,6 @@ var _ = Describe("Allocation", func() { ) }) It("should allow pod to override default zone", func() { - fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")}, - {SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")}, - }} // Setup provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"} pod := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1c"}})