Skip to content

Commit

Permalink
fix: Ensure shallow copy of data when returning back cached data (aws…
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-innis committed May 13, 2024
1 parent 7acfd40 commit 641258c
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 10 deletions.
8 changes: 6 additions & 2 deletions pkg/providers/amifamily/ami.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ func (p *Provider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, opt

func (p *Provider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (res AMIs, err error) {
if images, ok := p.cache.Get(lo.FromPtr(nodeClass.Spec.AMIFamily)); ok {
return images.(AMIs), nil
// Ensure what's returned from this function is a deep-copy of AMIs so alterations
// to the data don't affect the original
return append(AMIs{}, images.(AMIs)...), nil
}
amiFamily := GetAMIFamily(nodeClass.Spec.AMIFamily, options)
kubernetesVersion, err := p.versionProvider.Get(ctx)
Expand Down Expand Up @@ -187,7 +189,9 @@ func (p *Provider) getAMIs(ctx context.Context, terms []v1beta1.AMISelectorTerm)
return nil, err
}
if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok {
return images.(AMIs), nil
// Ensure what's returned from this function is a deep-copy of AMIs so alterations
// to the data don't affect the original
return append(AMIs{}, images.(AMIs)...), nil
}
images := map[uint64]AMI{}
for _, filtersAndOwners := range filterAndOwnerSets {
Expand Down
52 changes: 48 additions & 4 deletions pkg/providers/amifamily/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"sort"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -74,7 +75,7 @@ var _ = BeforeEach(func() {
{
Name: aws.String(amd64AMI),
ImageId: aws.String("amd64-ami-id"),
CreationDate: aws.String(time.Now().Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Format(time.RFC3339)),
Architecture: aws.String("x86_64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(amd64AMI)},
Expand All @@ -84,7 +85,7 @@ var _ = BeforeEach(func() {
{
Name: aws.String(arm64AMI),
ImageId: aws.String("arm64-ami-id"),
CreationDate: aws.String(time.Now().Add(time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(time.Minute).Format(time.RFC3339)),
Architecture: aws.String("arm64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(arm64AMI)},
Expand All @@ -94,7 +95,7 @@ var _ = BeforeEach(func() {
{
Name: aws.String(amd64NvidiaAMI),
ImageId: aws.String("amd64-nvidia-ami-id"),
CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)),
Architecture: aws.String("x86_64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(amd64NvidiaAMI)},
Expand All @@ -104,7 +105,7 @@ var _ = BeforeEach(func() {
{
Name: aws.String(arm64NvidiaAMI),
ImageId: aws.String("arm64-nvidia-ami-id"),
CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)),
Architecture: aws.String("arm64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(arm64NvidiaAMI)},
Expand Down Expand Up @@ -196,6 +197,49 @@ var _ = Describe("AMIProvider", func() {
Expect(err).ToNot(HaveOccurred())
Expect(amis).To(HaveLen(0))
})
It("should not cause data races when calling Get() simultaneously", func() {
nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{
{
ID: "amd64-ami-id",
},
{
ID: "arm64-ami-id",
},
}
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
images, err := awsEnv.AMIProvider.Get(ctx, nodeClass, &amifamily.Options{})
Expect(err).ToNot(HaveOccurred())

Expect(images).To(HaveLen(2))
// Sort everything in parallel and ensure that we don't get data races
images.Sort()
Expect(images).To(BeEquivalentTo([]amifamily.AMI{
{
Name: arm64AMI,
AmiID: "arm64-ami-id",
CreationDate: time.Time{}.Add(time.Minute).Format(time.RFC3339),
Requirements: scheduling.NewLabelRequirements(map[string]string{
v1.LabelArchStable: corev1beta1.ArchitectureArm64,
}),
},
{
Name: amd64AMI,
AmiID: "amd64-ami-id",
CreationDate: time.Time{}.Format(time.RFC3339),
Requirements: scheduling.NewLabelRequirements(map[string]string{
v1.LabelArchStable: corev1beta1.ArchitectureAmd64,
}),
},
}))
}()
}
wg.Wait()
})
Context("SSM Alias Missing", func() {
It("should succeed to partially resolve AMIs if all SSM aliases don't exist (Al2)", func() {
nodeClass.Spec.AMIFamily = &v1beta1.AMIFamilyAL2
Expand Down
4 changes: 3 additions & 1 deletion pkg/providers/instancetype/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ func (p *Provider) List(ctx context.Context, kc *corev1beta1.KubeletConfiguratio
systemReservedHash,
)
if item, ok := p.cache.Get(key); ok {
return item.([]*cloudprovider.InstanceType), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*cloudprovider.InstanceType{}, item.([]*cloudprovider.InstanceType)...), nil
}
result := lo.Map(instanceTypes, func(i *ec2.InstanceTypeInfo, _ int) *cloudprovider.InstanceType {
return NewInstanceType(ctx, i, kc, p.region, nodeClass, p.createOfferings(ctx, i, instanceTypeOfferings[aws.StringValue(i.InstanceType)], zones, subnetZones))
Expand Down
37 changes: 36 additions & 1 deletion pkg/providers/instancetype/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net"
"sort"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -783,7 +784,6 @@ var _ = Describe("InstanceTypes", func() {
ExpectScheduled(ctx, env.Client, pod)

})

Context("Overhead", func() {
var info *ec2.InstanceTypeInfo
BeforeEach(func() {
Expand Down Expand Up @@ -1676,6 +1676,41 @@ var _ = Describe("InstanceTypes", func() {
})
})
})
It("should not cause data races when calling List() simultaneously", func() {
mu := sync.RWMutex{}
var instanceTypeOrder []string
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
instanceTypes, err := awsEnv.InstanceTypesProvider.List(ctx, &corev1beta1.KubeletConfiguration{}, nodeClass)
Expect(err).ToNot(HaveOccurred())

// Sort everything in parallel and ensure that we don't get data races
sort.Slice(instanceTypes, func(i, j int) bool {
return instanceTypes[i].Name < instanceTypes[j].Name
})
// Get the ordering of the instance types based on name
tempInstanceTypeOrder := lo.Map(instanceTypes, func(i *corecloudprovider.InstanceType, _ int) string {
return i.Name
})
// Expect that all the elements in the instance type list are unique
Expect(lo.Uniq(tempInstanceTypeOrder)).To(HaveLen(len(tempInstanceTypeOrder)))

// We have to lock since we are doing simultaneous access to this value
mu.Lock()
if len(instanceTypeOrder) == 0 {
instanceTypeOrder = tempInstanceTypeOrder
} else {
Expect(tempInstanceTypeOrder).To(BeEquivalentTo(instanceTypeOrder))
}
mu.Unlock()
}()
}
wg.Wait()
})
})

// generateSpotPricing creates a spot price history output for use in a mock that has all spot offerings discounted by 50%
Expand Down
4 changes: 3 additions & 1 deletion pkg/providers/securitygroup/securitygroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ func (p *Provider) getSecurityGroups(ctx context.Context, filterSets [][]*ec2.Fi
return nil, err
}
if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok {
return sg.([]*ec2.SecurityGroup), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*ec2.SecurityGroup{}, sg.([]*ec2.SecurityGroup)...), nil
}
securityGroups := map[string]*ec2.SecurityGroup{}
for _, filters := range filterSets {
Expand Down
68 changes: 68 additions & 0 deletions pkg/providers/securitygroup/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package securitygroup_test

import (
"context"
"sort"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -265,6 +267,72 @@ var _ = Describe("SecurityGroupProvider", func() {
},
}, securityGroups)
})
It("should not cause data races when calling List() simultaneously", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass)
Expect(err).ToNot(HaveOccurred())

Expect(securityGroups).To(HaveLen(3))
// Sort everything in parallel and ensure that we don't get data races
sort.Slice(securityGroups, func(i, j int) bool {
return *securityGroups[i].GroupId < *securityGroups[j].GroupId
})
Expect(securityGroups).To(BeEquivalentTo([]*ec2.SecurityGroup{
{
GroupId: lo.ToPtr("sg-test1"),
GroupName: lo.ToPtr("securityGroup-test1"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-1"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
GroupId: lo.ToPtr("sg-test2"),
GroupName: lo.ToPtr("securityGroup-test2"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-2"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
GroupId: lo.ToPtr("sg-test3"),
GroupName: lo.ToPtr("securityGroup-test3"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-3"),
},
{
Key: lo.ToPtr("TestTag"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
}))
}()
}
wg.Wait()
})
})

func ExpectConsistsOfSecurityGroups(expected, actual []*ec2.SecurityGroup) {
Expand Down
4 changes: 3 additions & 1 deletion pkg/providers/subnet/subnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ func (p *Provider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) ([
return nil, err
}
if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok {
return subnets.([]*ec2.Subnet), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*ec2.Subnet{}, subnets.([]*ec2.Subnet)...), nil
}

// Ensure that all the subnets that are returned here are unique
Expand Down
89 changes: 89 additions & 0 deletions pkg/providers/subnet/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package subnet_test

import (
"context"
"sort"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -240,6 +242,93 @@ var _ = Describe("SubnetProvider", func() {
Expect(onlyPrivate).To(BeTrue())
})
})
It("should not cause data races when calling List() simultaneously", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass)
Expect(err).ToNot(HaveOccurred())

Expect(subnets).To(HaveLen(4))
// Sort everything in parallel and ensure that we don't get data races
sort.Slice(subnets, func(i, j int) bool {
if int(*subnets[i].AvailableIpAddressCount) != int(*subnets[j].AvailableIpAddressCount) {
return int(*subnets[i].AvailableIpAddressCount) > int(*subnets[j].AvailableIpAddressCount)
}
return *subnets[i].SubnetId < *subnets[j].SubnetId
})
Expect(subnets).To(BeEquivalentTo([]*ec2.Subnet{
{
AvailabilityZone: lo.ToPtr("test-zone-1a"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test1"),
MapPublicIpOnLaunch: lo.ToPtr(false),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-1"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1b"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
MapPublicIpOnLaunch: lo.ToPtr(true),
SubnetId: lo.ToPtr("subnet-test2"),

Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-2"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1c"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test3"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-3"),
},
{
Key: lo.ToPtr("TestTag"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1a-local"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test4"),
MapPublicIpOnLaunch: lo.ToPtr(true),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-4"),
},
},
},
}))
}()
}
wg.Wait()
})
})

func ExpectConsistsOfSubnets(expected, actual []*ec2.Subnet) {
Expand Down

0 comments on commit 641258c

Please sign in to comment.