Skip to content

Commit

Permalink
Fix bug in ec2 api mock (#354)
Browse files Browse the repository at this point in the history
More fields needed to be Reset
  • Loading branch information
Jacob Gabrielson authored Apr 9, 2021
1 parent 176ba09 commit e335f87
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 69 deletions.
115 changes: 63 additions & 52 deletions pkg/cloudprovider/aws/fake/ec2api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

type EC2API struct {
ec2iface.EC2API
// EC2Behavior must be reset between tests otherwise tests will
// pollute each other.
type EC2Behavior struct {
CreateFleetOutput *ec2.CreateFleetOutput
DescribeInstancesOutput *ec2.DescribeInstancesOutput
DescribeLaunchTemplatesOutput *ec2.DescribeLaunchTemplatesOutput
Expand All @@ -36,82 +37,92 @@ type EC2API struct {
DescribeInstanceTypeOfferingsOutput *ec2.DescribeInstanceTypeOfferingsOutput
DescribeAvailabilityZonesOutput *ec2.DescribeAvailabilityZonesOutput
WantErr error
CalledWithCreateFleetInput []ec2.CreateFleetInput
Instances []*ec2.Instance
}

CalledWithCreateFleetInput []ec2.CreateFleetInput
Instances []*ec2.Instance
type EC2API struct {
ec2iface.EC2API
EC2Behavior
}

func (a *EC2API) Reset() {
a.CalledWithCreateFleetInput = nil
// Reset must be called between tests otherwise tests will pollute
// each other.
func (e *EC2API) Reset() {
e.EC2Behavior = EC2Behavior{}
}

func (a *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) {
a.CalledWithCreateFleetInput = append(a.CalledWithCreateFleetInput, *input)
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) {
e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, *input)
if e.WantErr != nil {
return nil, e.WantErr
}
if a.CreateFleetOutput != nil {
return a.CreateFleetOutput, nil
if e.CreateFleetOutput != nil {
return e.CreateFleetOutput, nil
}
instance := &ec2.Instance{
InstanceId: aws.String(randomdata.SillyName()),
Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone")},
PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(a.Instances))),
Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone-1a")},
PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(e.Instances))),
}
a.Instances = append(a.Instances, instance)
e.Instances = append(e.Instances, instance)
return &ec2.CreateFleetOutput{Instances: []*ec2.CreateFleetInstance{{InstanceIds: []*string{instance.InstanceId}}}}, nil
}

func (a *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeInstancesOutput != nil {
return a.DescribeInstancesOutput, nil
if e.DescribeInstancesOutput != nil {
return e.DescribeInstancesOutput, nil
}
return &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{{Instances: a.Instances}},
Reservations: []*ec2.Reservation{{Instances: e.Instances}},
}, nil
}

func (a *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeLaunchTemplatesOutput != nil {
return a.DescribeLaunchTemplatesOutput, nil
if e.DescribeLaunchTemplatesOutput != nil {
return e.DescribeLaunchTemplatesOutput, nil
}
return &ec2.DescribeLaunchTemplatesOutput{LaunchTemplates: []*ec2.LaunchTemplate{{
LaunchTemplateName: aws.String("test-launch-template"),
}}}, nil
}

func (a *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeSubnetsOutput != nil {
return a.DescribeSubnetsOutput, nil
if e.DescribeSubnetsOutput != nil {
return e.DescribeSubnetsOutput, nil
}
return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{{SubnetId: aws.String("test-subnet"), AvailabilityZone: aws.String("test-zone")}}}, nil
return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}, nil
}

func (a *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeSecurityGroupsOutput != nil {
return a.DescribeSecurityGroupsOutput, nil
if e.DescribeSecurityGroupsOutput != nil {
return e.DescribeSecurityGroupsOutput, nil
}
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil
}

func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeAvailabilityZonesOutput != nil {
return a.DescribeAvailabilityZonesOutput, nil
if e.DescribeAvailabilityZonesOutput != nil {
return e.DescribeAvailabilityZonesOutput, nil
}
return &ec2.DescribeAvailabilityZonesOutput{AvailabilityZones: []*ec2.AvailabilityZone{
{ZoneName: aws.String("test-zone-1a"), ZoneId: aws.String("testzone1a")},
Expand All @@ -120,12 +131,12 @@ func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.Desc
}}, nil
}

func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
if a.WantErr != nil {
return a.WantErr
func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if a.DescribeInstanceTypesOutput != nil {
fn(a.DescribeInstanceTypesOutput, false)
if e.DescribeInstanceTypesOutput != nil {
fn(e.DescribeInstanceTypesOutput, false)
return nil
}
fn(&ec2.DescribeInstanceTypesOutput{
Expand Down Expand Up @@ -175,12 +186,12 @@ func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, inpu
return nil
}

func (a *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error {
if a.WantErr != nil {
return a.WantErr
func (e *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if a.DescribeInstanceTypeOfferingsOutput != nil {
fn(a.DescribeInstanceTypeOfferingsOutput, false)
if e.DescribeInstanceTypeOfferingsOutput != nil {
fn(e.DescribeInstanceTypeOfferingsOutput, false)
return nil
}
fn(&ec2.DescribeInstanceTypeOfferingsOutput{
Expand Down
6 changes: 4 additions & 2 deletions pkg/cloudprovider/aws/instancetypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,10 @@ var _ = Describe("InstanceTypes", func() {

func getInstanceTypeProviderMocks(zones []string, instanceTypes []string) ec2iface.EC2API {
ec2api := &fake.EC2API{
DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{},
DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{},
EC2Behavior: fake.EC2Behavior{
DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{},
DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{},
},
}

for _, instanceType := range instanceTypes {
Expand Down
15 changes: 0 additions & 15 deletions pkg/cloudprovider/aws/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,6 @@ var _ = Describe("Allocation", func() {

Context("Reconciliation", func() {
It("should default to a cluster zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
pod := test.PendingPod()
ExpectCreatedWithStatus(env.Client, pod)
Expand All @@ -167,11 +162,6 @@ var _ = Describe("Allocation", func() {
))
})
It("should default to a provisioner's zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"}
pod := test.PendingPod()
Expand All @@ -196,11 +186,6 @@ var _ = Describe("Allocation", func() {
)
})
It("should allow pod to override default zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"}
pod := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1c"}})
Expand Down

0 comments on commit e335f87

Please sign in to comment.