Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in ec2 api mock #354

Merged
merged 5 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 63 additions & 52 deletions pkg/cloudprovider/aws/fake/ec2api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

type EC2API struct {
ec2iface.EC2API
// EC2Behavior must be reset between tests otherwise tests will
// pollute each other.
type EC2Behavior struct {
CreateFleetOutput *ec2.CreateFleetOutput
DescribeInstancesOutput *ec2.DescribeInstancesOutput
DescribeLaunchTemplatesOutput *ec2.DescribeLaunchTemplatesOutput
Expand All @@ -36,82 +37,92 @@ type EC2API struct {
DescribeInstanceTypeOfferingsOutput *ec2.DescribeInstanceTypeOfferingsOutput
DescribeAvailabilityZonesOutput *ec2.DescribeAvailabilityZonesOutput
WantErr error
CalledWithCreateFleetInput []ec2.CreateFleetInput
Instances []*ec2.Instance
}

CalledWithCreateFleetInput []ec2.CreateFleetInput
Instances []*ec2.Instance
type EC2API struct {
ec2iface.EC2API
EC2Behavior
}

func (a *EC2API) Reset() {
a.CalledWithCreateFleetInput = nil
// Reset must be called between tests otherwise tests will pollute
// each other.
func (e *EC2API) Reset() {
e.EC2Behavior = EC2Behavior{}
JacobGabrielson marked this conversation as resolved.
Show resolved Hide resolved
}

func (a *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) {
a.CalledWithCreateFleetInput = append(a.CalledWithCreateFleetInput, *input)
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) CreateFleetWithContext(ctx context.Context, input *ec2.CreateFleetInput, options ...request.Option) (*ec2.CreateFleetOutput, error) {
e.CalledWithCreateFleetInput = append(e.CalledWithCreateFleetInput, *input)
if e.WantErr != nil {
return nil, e.WantErr
}
if a.CreateFleetOutput != nil {
return a.CreateFleetOutput, nil
if e.CreateFleetOutput != nil {
return e.CreateFleetOutput, nil
}
instance := &ec2.Instance{
InstanceId: aws.String(randomdata.SillyName()),
Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone")},
PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(a.Instances))),
Placement: &ec2.Placement{AvailabilityZone: aws.String("test-zone-1a")},
PrivateDnsName: aws.String(fmt.Sprintf("test-instance-%d.example.com", len(e.Instances))),
}
a.Instances = append(a.Instances, instance)
e.Instances = append(e.Instances, instance)
return &ec2.CreateFleetOutput{Instances: []*ec2.CreateFleetInstance{{InstanceIds: []*string{instance.InstanceId}}}}, nil
}

func (a *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeInstancesOutput != nil {
return a.DescribeInstancesOutput, nil
if e.DescribeInstancesOutput != nil {
return e.DescribeInstancesOutput, nil
}
return &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{{Instances: a.Instances}},
Reservations: []*ec2.Reservation{{Instances: e.Instances}},
}, nil
}

func (a *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeLaunchTemplatesWithContext(context.Context, *ec2.DescribeLaunchTemplatesInput, ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeLaunchTemplatesOutput != nil {
return a.DescribeLaunchTemplatesOutput, nil
if e.DescribeLaunchTemplatesOutput != nil {
return e.DescribeLaunchTemplatesOutput, nil
}
return &ec2.DescribeLaunchTemplatesOutput{LaunchTemplates: []*ec2.LaunchTemplate{{
LaunchTemplateName: aws.String("test-launch-template"),
}}}, nil
}

func (a *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeSubnetsWithContext(context.Context, *ec2.DescribeSubnetsInput, ...request.Option) (*ec2.DescribeSubnetsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeSubnetsOutput != nil {
return a.DescribeSubnetsOutput, nil
if e.DescribeSubnetsOutput != nil {
return e.DescribeSubnetsOutput, nil
}
return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{{SubnetId: aws.String("test-subnet"), AvailabilityZone: aws.String("test-zone")}}}, nil
return &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}, nil
}

func (a *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeSecurityGroupsWithContext(context.Context, *ec2.DescribeSecurityGroupsInput, ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeSecurityGroupsOutput != nil {
return a.DescribeSecurityGroupsOutput, nil
if e.DescribeSecurityGroupsOutput != nil {
return e.DescribeSecurityGroupsOutput, nil
}
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{{GroupId: aws.String("test-group")}}}, nil
}

func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) {
if a.WantErr != nil {
return nil, a.WantErr
func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) {
if e.WantErr != nil {
return nil, e.WantErr
}
if a.DescribeAvailabilityZonesOutput != nil {
return a.DescribeAvailabilityZonesOutput, nil
if e.DescribeAvailabilityZonesOutput != nil {
return e.DescribeAvailabilityZonesOutput, nil
}
return &ec2.DescribeAvailabilityZonesOutput{AvailabilityZones: []*ec2.AvailabilityZone{
{ZoneName: aws.String("test-zone-1a"), ZoneId: aws.String("testzone1a")},
Expand All @@ -120,12 +131,12 @@ func (a *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.Desc
}}, nil
}

func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
if a.WantErr != nil {
return a.WantErr
func (e *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if a.DescribeInstanceTypesOutput != nil {
fn(a.DescribeInstanceTypesOutput, false)
if e.DescribeInstanceTypesOutput != nil {
fn(e.DescribeInstanceTypesOutput, false)
return nil
}
fn(&ec2.DescribeInstanceTypesOutput{
Expand Down Expand Up @@ -175,12 +186,12 @@ func (a *EC2API) DescribeInstanceTypesPagesWithContext(ctx context.Context, inpu
return nil
}

func (a *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error {
if a.WantErr != nil {
return a.WantErr
func (e *EC2API) DescribeInstanceTypeOfferingsPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypeOfferingsInput, fn func(*ec2.DescribeInstanceTypeOfferingsOutput, bool) bool, opts ...request.Option) error {
if e.WantErr != nil {
return e.WantErr
}
if a.DescribeInstanceTypeOfferingsOutput != nil {
fn(a.DescribeInstanceTypeOfferingsOutput, false)
if e.DescribeInstanceTypeOfferingsOutput != nil {
fn(e.DescribeInstanceTypeOfferingsOutput, false)
return nil
}
fn(&ec2.DescribeInstanceTypeOfferingsOutput{
Expand Down
6 changes: 4 additions & 2 deletions pkg/cloudprovider/aws/instancetypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,10 @@ var _ = Describe("InstanceTypes", func() {

func getInstanceTypeProviderMocks(zones []string, instanceTypes []string) ec2iface.EC2API {
ec2api := &fake.EC2API{
DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{},
DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{},
EC2Behavior: fake.EC2Behavior{
DescribeInstanceTypesOutput: &ec2.DescribeInstanceTypesOutput{},
DescribeInstanceTypeOfferingsOutput: &ec2.DescribeInstanceTypeOfferingsOutput{},
},
}

for _, instanceType := range instanceTypes {
Expand Down
15 changes: 0 additions & 15 deletions pkg/cloudprovider/aws/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,6 @@ var _ = Describe("Allocation", func() {

Context("Reconciliation", func() {
It("should default to a cluster zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
pod := test.PendingPod()
ExpectCreatedWithStatus(env.Client, pod)
Expand All @@ -167,11 +162,6 @@ var _ = Describe("Allocation", func() {
))
})
It("should default to a provisioner's zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"}
pod := test.PendingPod()
Expand All @@ -196,11 +186,6 @@ var _ = Describe("Allocation", func() {
)
})
It("should allow pod to override default zone", func() {
fakeEC2API.DescribeSubnetsOutput = &ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a")},
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b")},
{SubnetId: aws.String("test-subnet-3"), AvailabilityZone: aws.String("test-zone-1c")},
}}
// Setup
provisioner.Spec.Zones = []string{"test-zone-1a", "test-zone-1b"}
pod := test.PendingPodWith(test.PodOptions{NodeSelector: map[string]string{v1alpha1.ZoneLabelKey: "test-zone-1c"}})
Expand Down