fix: Ensure shallow copy of data when returning back cached data #6167

Merged: 2 commits, May 10, 2024
8 changes: 6 additions & 2 deletions pkg/providers/amifamily/ami.go
@@ -139,7 +139,9 @@ func (p *DefaultProvider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeCla

func (p *DefaultProvider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (res AMIs, err error) {
if images, ok := p.cache.Get(lo.FromPtr(nodeClass.Spec.AMIFamily)); ok {
return images.(AMIs), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append(AMIs{}, images.(AMIs)...), nil
}
amiFamily := GetAMIFamily(nodeClass.Spec.AMIFamily, options)
kubernetesVersion, err := p.versionProvider.Get(ctx)
@@ -191,7 +193,9 @@ func (p *DefaultProvider) getAMIs(ctx context.Context, terms []v1beta1.AMISelect
return nil, err
}
if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok {
return images.(AMIs), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append(AMIs{}, images.(AMIs)...), nil
}
images := map[uint64]AMI{}
for _, filtersAndOwners := range filterAndOwnerSets {
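The hunks above are the heart of the fix: the cache stores a single AMIs slice, and previously every cache hit handed that same slice to the caller, so two goroutines sorting their "own" results were in fact reordering shared backing memory. A minimal, standalone sketch of the copy-on-read idea (the AMI fields and amiCache type below are illustrative stand-ins, not Karpenter's actual definitions):

package main

import (
	"fmt"
	"sort"
	"sync"
)

// AMI and amiCache are simplified stand-ins for the provider's types.
type AMI struct {
	Name         string
	CreationDate string
}

type amiCache struct {
	mu    sync.RWMutex
	items map[string][]AMI
}

// Get returns a fresh slice on every cache hit, so callers can sort or
// truncate the result without mutating cached state.
func (c *amiCache) Get(key string) ([]AMI, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	cached, ok := c.items[key]
	if !ok {
		return nil, false
	}
	// Appending into an empty slice copies the elements into new backing memory.
	return append([]AMI{}, cached...), true
}

func main() {
	c := &amiCache{items: map[string][]AMI{
		"AL2": {{Name: "b"}, {Name: "a"}},
	}}
	amis, _ := c.Get("AL2")
	sort.Slice(amis, func(i, j int) bool { return amis[i].Name < amis[j].Name })
	cached, _ := c.Get("AL2")
	fmt.Println(amis[0].Name, cached[0].Name) // "a b": the cached order is untouched
}

Without the append, concurrent sort calls on the returned value reorder the same backing array, which Go's race detector flags and which can surface as nondeterministic AMI ordering between callers.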
52 changes: 48 additions & 4 deletions pkg/providers/amifamily/suite_test.go
@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"sort"
"sync"
"testing"
"time"

@@ -74,7 +75,7 @@ var _ = BeforeEach(func() {
{
Name: aws.String(amd64AMI),
ImageId: aws.String("amd64-ami-id"),
CreationDate: aws.String(time.Now().Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Format(time.RFC3339)),
Architecture: aws.String("x86_64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(amd64AMI)},
@@ -84,7 +85,7 @@
{
Name: aws.String(arm64AMI),
ImageId: aws.String("arm64-ami-id"),
CreationDate: aws.String(time.Now().Add(time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(time.Minute).Format(time.RFC3339)),
Architecture: aws.String("arm64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(arm64AMI)},
@@ -94,7 +95,7 @@
{
Name: aws.String(amd64NvidiaAMI),
ImageId: aws.String("amd64-nvidia-ami-id"),
CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)),
Architecture: aws.String("x86_64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(amd64NvidiaAMI)},
@@ -104,7 +105,7 @@
{
Name: aws.String(arm64NvidiaAMI),
ImageId: aws.String("arm64-nvidia-ami-id"),
CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)),
CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)),
Architecture: aws.String("arm64"),
Tags: []*ec2.Tag{
{Key: aws.String("Name"), Value: aws.String(arm64NvidiaAMI)},
@@ -196,6 +197,49 @@ var _ = Describe("AMIProvider", func() {
Expect(err).ToNot(HaveOccurred())
Expect(amis).To(HaveLen(0))
})
It("should not cause data races when calling Get() simultaneously", func() {
nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{
{
ID: "amd64-ami-id",
},
{
ID: "arm64-ami-id",
},
}
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
images, err := awsEnv.AMIProvider.Get(ctx, nodeClass, &amifamily.Options{})
Expect(err).ToNot(HaveOccurred())

Expect(images).To(HaveLen(2))
// Sort everything in parallel and ensure that we don't get data races
images.Sort()
Expect(images).To(BeEquivalentTo([]amifamily.AMI{
{
Name: arm64AMI,
AmiID: "arm64-ami-id",
CreationDate: time.Time{}.Add(time.Minute).Format(time.RFC3339),
Requirements: scheduling.NewLabelRequirements(map[string]string{
v1.LabelArchStable: corev1beta1.ArchitectureArm64,
}),
},
{
Name: amd64AMI,
AmiID: "amd64-ami-id",
CreationDate: time.Time{}.Format(time.RFC3339),
Requirements: scheduling.NewLabelRequirements(map[string]string{
v1.LabelArchStable: corev1beta1.ArchitectureAmd64,
}),
},
}))
}()
}
wg.Wait()
})
Context("SSM Alias Missing", func() {
It("should succeed to partially resolve AMIs if all SSM aliases don't exist (Al2)", func() {
nodeClass.Spec.AMIFamily = &v1beta1.AMIFamilyAL2
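Alongside the copy fix, the fixtures above swap time.Now() for the zero value time.Time{} so the fake AMIs get deterministic creation dates: the concurrent test asserts an exact sorted order, which only holds if every run formats the same instants. The zero time formats reproducibly, as this small standalone example shows:

package main

import (
	"fmt"
	"time"
)

func main() {
	// The zero time is a fixed instant, so formatted fixtures are identical on every run.
	fmt.Println(time.Time{}.Format(time.RFC3339))                  // 0001-01-01T00:00:00Z
	fmt.Println(time.Time{}.Add(time.Minute).Format(time.RFC3339)) // 0001-01-01T00:01:00Z
}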
4 changes: 3 additions & 1 deletion pkg/providers/instancetype/instancetype.go
@@ -133,7 +133,9 @@ func (p *DefaultProvider) List(ctx context.Context, kc *corev1beta1.KubeletConfi
aws.StringValue(nodeClass.Spec.AMIFamily),
)
if item, ok := p.instanceTypesCache.Get(key); ok {
return item.([]*cloudprovider.InstanceType), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*cloudprovider.InstanceType{}, item.([]*cloudprovider.InstanceType)...), nil
}

// Get all zones across all offerings
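Note the hedge in the comment above: for a []*cloudprovider.InstanceType, the append duplicates only the slice header and the pointer values, not the InstanceType structs they reference. Reordering, which is what callers actually do, is therefore safe, while a write through one of the pointers would still be visible to every other cache reader. A small standalone illustration of the distinction (the instanceType struct here is hypothetical, not Karpenter's type):

package main

import (
	"fmt"
	"sort"
)

type instanceType struct{ Name string }

func main() {
	cached := []*instanceType{{Name: "m5.large"}, {Name: "c5.large"}}

	// Shallow copy: new slice header and backing array, same pointees.
	cp := append([]*instanceType{}, cached...)

	// Safe: sorting reorders only the copied pointer slice.
	sort.Slice(cp, func(i, j int) bool { return cp[i].Name < cp[j].Name })
	fmt.Println(cached[0].Name, cp[0].Name) // m5.large c5.large

	// Not isolated: mutating a pointee is visible through both slices.
	cp[0].Name = "mutated"
	fmt.Println(cached[1].Name) // mutated (cp[0] and cached[1] point to the same struct)
}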
37 changes: 36 additions & 1 deletion pkg/providers/instancetype/suite_test.go
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"

@@ -909,7 +910,6 @@ var _ = Describe("InstanceTypeProvider", func() {
ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod)
ExpectScheduled(ctx, env.Client, pod)
})

Context("Overhead", func() {
var info *ec2.InstanceTypeInfo
BeforeEach(func() {
@@ -2310,6 +2310,41 @@ var _ = Describe("InstanceTypeProvider", func() {
uniqueInstanceTypeList(instanceTypeResult)
})
})
It("should not cause data races when calling List() simultaneously", func() {
mu := sync.RWMutex{}
var instanceTypeOrder []string
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
instanceTypes, err := awsEnv.InstanceTypesProvider.List(ctx, &corev1beta1.KubeletConfiguration{}, nodeClass)
Expect(err).ToNot(HaveOccurred())

// Sort everything in parallel and ensure that we don't get data races
sort.Slice(instanceTypes, func(i, j int) bool {
return instanceTypes[i].Name < instanceTypes[j].Name
})
// Get the ordering of the instance types based on name
tempInstanceTypeOrder := lo.Map(instanceTypes, func(i *corecloudprovider.InstanceType, _ int) string {
return i.Name
})
// Expect that all the elements in the instance type list are unique
Expect(lo.Uniq(tempInstanceTypeOrder)).To(HaveLen(len(tempInstanceTypeOrder)))

// We have to lock since we are doing simultaneous access to this value
mu.Lock()
if len(instanceTypeOrder) == 0 {
instanceTypeOrder = tempInstanceTypeOrder
} else {
Expect(tempInstanceTypeOrder).To(BeEquivalentTo(instanceTypeOrder))
}
mu.Unlock()
}()
}
wg.Wait()
})
})

func uniqueInstanceTypeList(instanceTypesLists [][]*corecloudprovider.InstanceType) {
4 changes: 3 additions & 1 deletion pkg/providers/securitygroup/securitygroup.go
@@ -78,7 +78,9 @@ func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][]
return nil, err
}
if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok {
return sg.([]*ec2.SecurityGroup), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*ec2.SecurityGroup{}, sg.([]*ec2.SecurityGroup)...), nil
}
securityGroups := map[string]*ec2.SecurityGroup{}
for _, filters := range filterSets {
68 changes: 68 additions & 0 deletions pkg/providers/securitygroup/suite_test.go
@@ -16,6 +16,8 @@ package securitygroup_test

import (
"context"
"sort"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
@@ -330,6 +332,72 @@ var _ = Describe("SecurityGroupProvider", func() {
}
})
})
It("should not cause data races when calling List() simultaneously", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass)
Expect(err).ToNot(HaveOccurred())

Expect(securityGroups).To(HaveLen(3))
// Sort everything in parallel and ensure that we don't get data races
sort.Slice(securityGroups, func(i, j int) bool {
return *securityGroups[i].GroupId < *securityGroups[j].GroupId
})
Expect(securityGroups).To(BeEquivalentTo([]*ec2.SecurityGroup{
{
GroupId: lo.ToPtr("sg-test1"),
GroupName: lo.ToPtr("securityGroup-test1"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-1"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
GroupId: lo.ToPtr("sg-test2"),
GroupName: lo.ToPtr("securityGroup-test2"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-2"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
GroupId: lo.ToPtr("sg-test3"),
GroupName: lo.ToPtr("securityGroup-test3"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-security-group-3"),
},
{
Key: lo.ToPtr("TestTag"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
}))
}()
}
wg.Wait()
})
})

func ExpectConsistsOfSecurityGroups(expected, actual []*ec2.SecurityGroup) {
4 changes: 3 additions & 1 deletion pkg/providers/subnet/subnet.go
@@ -84,7 +84,9 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl
return nil, err
}
if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok {
return subnets.([]*ec2.Subnet), nil
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
return append([]*ec2.Subnet{}, subnets.([]*ec2.Subnet)...), nil
}

// Ensure that all the subnets that are returned here are unique
89 changes: 89 additions & 0 deletions pkg/providers/subnet/suite_test.go
@@ -16,6 +16,8 @@ package subnet_test

import (
"context"
"sort"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
@@ -294,6 +296,93 @@ var _ = Describe("SubnetProvider", func() {
}
})
})
It("should not cause data races when calling List() simultaneously", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass)
Expect(err).ToNot(HaveOccurred())

Expect(subnets).To(HaveLen(4))
// Sort everything in parallel and ensure that we don't get data races
sort.Slice(subnets, func(i, j int) bool {
if int(*subnets[i].AvailableIpAddressCount) != int(*subnets[j].AvailableIpAddressCount) {
return int(*subnets[i].AvailableIpAddressCount) > int(*subnets[j].AvailableIpAddressCount)
}
return *subnets[i].SubnetId < *subnets[j].SubnetId
})
Expect(subnets).To(BeEquivalentTo([]*ec2.Subnet{
{
AvailabilityZone: lo.ToPtr("test-zone-1a"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test1"),
MapPublicIpOnLaunch: lo.ToPtr(false),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-1"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1b"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
MapPublicIpOnLaunch: lo.ToPtr(true),
SubnetId: lo.ToPtr("subnet-test2"),

Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-2"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1c"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test3"),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-3"),
},
{
Key: lo.ToPtr("TestTag"),
},
{
Key: lo.ToPtr("foo"),
Value: lo.ToPtr("bar"),
},
},
},
{
AvailabilityZone: lo.ToPtr("test-zone-1a-local"),
AvailableIpAddressCount: lo.ToPtr[int64](100),
SubnetId: lo.ToPtr("subnet-test4"),
MapPublicIpOnLaunch: lo.ToPtr(true),
Tags: []*ec2.Tag{
{
Key: lo.ToPtr("Name"),
Value: lo.ToPtr("test-subnet-4"),
},
},
},
}))
}()
}
wg.Wait()
})
})

func ExpectConsistsOfSubnets(expected, actual []*ec2.Subnet) {