Skip to content

Commit

Permalink
chore: fix unit test data races (#3478) (#3479)
Browse files Browse the repository at this point in the history
chore: fix unit test data races

Signed-off-by: Jonathan West <[email protected]>
  • Loading branch information
jgwest authored Mar 28, 2024
1 parent 4d6e12b commit 149ff1e
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 84 deletions.
21 changes: 15 additions & 6 deletions analysis/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -50,9 +51,13 @@ type fixture struct {
// Actions expected to happen on the client.
actions []core.Action
// Objects from here preloaded into NewSimpleFake.
objects []runtime.Object
enqueuedObjects map[string]int
unfreezeTime func() error
objects []runtime.Object

// Acquire 'enqueuedObjectMutex' before accessing enqueuedObjects
enqueuedObjects map[string]int
enqueuedObjectMutex sync.Mutex

unfreezeTime func() error
// fake provider
provider *mocks.Provider

Expand All @@ -66,11 +71,11 @@ func newFixture(t *testing.T) *fixture {
f.objects = []runtime.Object{}
f.enqueuedObjects = make(map[string]int)
f.now = time.Now()
timeutil.Now = func() time.Time {
timeutil.SetNowTimeFunc(func() time.Time {
return f.now
}
})
f.unfreezeTime = func() error {
timeutil.Now = time.Now
timeutil.SetNowTimeFunc(time.Now)
return nil
}
return f
Expand Down Expand Up @@ -122,6 +127,10 @@ func (f *fixture) newController(resync resyncFunc) (*Controller, informers.Share
if key, err = cache.MetaNamespaceKeyFunc(obj); err != nil {
panic(err)
}

f.enqueuedObjectMutex.Lock()
defer f.enqueuedObjectMutex.Unlock()

count, ok := f.enqueuedObjects[key]
if !ok {
count = 0
Expand Down
8 changes: 5 additions & 3 deletions experiments/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,15 @@ func newFixture(t *testing.T, objects ...runtime.Object) *fixture {
f.kubeclient = k8sfake.NewSimpleClientset(f.kubeobjects...)
f.enqueuedObjects = make(map[string]int)
now := time.Now()
timeutil.Now = func() time.Time {

timeutil.SetNowTimeFunc(func() time.Time {
return now
}
})
f.unfreezeTime = func() error {
timeutil.Now = time.Now
timeutil.SetNowTimeFunc(time.Now)
return nil
}

return f
}

Expand Down
7 changes: 7 additions & 0 deletions pkg/kubectl-argo-rollouts/cmd/get/get_rollout.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"strings"
"sync"
"time"

"github.com/juju/ansiterm"
Expand Down Expand Up @@ -59,7 +60,11 @@ func NewCmdGetRollout(o *options.ArgoRolloutsOptions) *cobra.Command {
getOptions.PrintRollout(ri)
} else {
rolloutUpdates := make(chan *rollout.RolloutInfo)
var rolloutUpdatesMutex sync.Mutex

controller.RegisterCallback(func(roInfo *rollout.RolloutInfo) {
rolloutUpdatesMutex.Lock()
defer rolloutUpdatesMutex.Unlock()
rolloutUpdates <- roInfo
})
stopCh := ctx.Done()
Expand All @@ -72,6 +77,8 @@ func NewCmdGetRollout(o *options.ArgoRolloutsOptions) *cobra.Command {
}
go getOptions.WatchRollout(stopCh, rolloutUpdates)
controller.Run(ctx)
rolloutUpdatesMutex.Lock()
defer rolloutUpdatesMutex.Unlock()
close(rolloutUpdates)
}
return nil
Expand Down
22 changes: 19 additions & 3 deletions pkg/kubectl-argo-rollouts/viewcontroller/viewcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package viewcontroller
import (
"context"
"reflect"
"sync"
"time"

"github.com/argoproj/argo-rollouts/utils/queue"
Expand All @@ -11,7 +12,6 @@ import (
v1 "k8s.io/api/apps/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/informers"
kubeinformers "k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
appslisters "k8s.io/client-go/listers/apps/v1"
Expand All @@ -32,7 +32,7 @@ type viewController struct {
name string
namespace string

kubeInformerFactory informers.SharedInformerFactory
kubeInformerFactory kubeinformers.SharedInformerFactory
rolloutsInformerFactory rolloutinformers.SharedInformerFactory

replicaSetLister appslisters.ReplicaSetNamespaceLister
Expand All @@ -48,6 +48,8 @@ type viewController struct {
prevObj any
getObj func() (any, error)
callbacks []func(any)
// acquire 'callbacksLock' before reading/writing to 'callbacks'
callbacksLock sync.Mutex
}

type RolloutViewController struct {
Expand Down Expand Up @@ -164,7 +166,13 @@ func (c *viewController) processNextWorkItem() bool {
return true
}
if !reflect.DeepEqual(c.prevObj, newObj) {
for _, cb := range c.callbacks {

// Acquire the mutex and make a thread-local copy of the list of callbacks
c.callbacksLock.Lock()
callbacks := append(make([]func(any), 0), c.callbacks...)
c.callbacksLock.Unlock()

for _, cb := range callbacks {
cb(newObj)
}
c.prevObj = newObj
Expand All @@ -173,6 +181,9 @@ func (c *viewController) processNextWorkItem() bool {
}

func (c *viewController) DeregisterCallbacks() {
c.callbacksLock.Lock()
defer c.callbacksLock.Unlock()

c.callbacks = nil
}

Expand Down Expand Up @@ -218,6 +229,9 @@ func (c *RolloutViewController) RegisterCallback(callback RolloutInfoCallback) {
cb := func(i any) {
callback(i.(*rollout.RolloutInfo))
}
c.callbacksLock.Lock()
defer c.callbacksLock.Unlock()

c.callbacks = append(c.callbacks, cb)
}

Expand Down Expand Up @@ -246,5 +260,7 @@ func (c *ExperimentViewController) RegisterCallback(callback ExperimentInfoCallb
cb := func(i any) {
callback(i.(*rollout.ExperimentInfo))
}
c.callbacksLock.Lock()
defer c.callbacksLock.Unlock()
c.callbacks = append(c.callbacks, cb)
}
22 changes: 20 additions & 2 deletions pkg/kubectl-argo-rollouts/viewcontroller/viewcontroller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package viewcontroller

import (
"context"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -53,7 +54,11 @@ func TestRolloutControllerCallback(t *testing.T) {
}

callbackCalled := false
var callbackCalledLock sync.Mutex // acquire before accessing callbackCalled

cb := func(roInfo *rollout.RolloutInfo) {
callbackCalledLock.Lock()
defer callbackCalledLock.Unlock()
callbackCalled = true
assert.Equal(t, roInfo.ObjectMeta.Name, "foo")
}
Expand All @@ -67,11 +72,16 @@ func TestRolloutControllerCallback(t *testing.T) {
go c.Run(ctx)
time.Sleep(time.Second)
for i := 0; i < 100; i++ {
if callbackCalled {
callbackCalledLock.Lock()
isCallbackCalled := callbackCalled
callbackCalledLock.Unlock()
if isCallbackCalled {
break
}
time.Sleep(10 * time.Millisecond)
}
callbackCalledLock.Lock()
defer callbackCalledLock.Unlock()
assert.True(t, callbackCalled)
}

Expand Down Expand Up @@ -100,8 +110,11 @@ func TestExperimentControllerCallback(t *testing.T) {
},
}

var callbackCalledLock sync.Mutex // acquire before accessing callbackCalled
callbackCalled := false
cb := func(expInfo *rollout.ExperimentInfo) {
callbackCalledLock.Lock()
defer callbackCalledLock.Unlock()
callbackCalled = true
assert.Equal(t, expInfo.ObjectMeta.Name, "foo")
}
Expand All @@ -115,10 +128,15 @@ func TestExperimentControllerCallback(t *testing.T) {
go c.Run(ctx)
time.Sleep(time.Second)
for i := 0; i < 100; i++ {
if callbackCalled {
callbackCalledLock.Lock()
isCallbackCalled := callbackCalled
callbackCalledLock.Unlock()
if isCallbackCalled {
break
}
time.Sleep(10 * time.Millisecond)
}
callbackCalledLock.Lock()
defer callbackCalledLock.Unlock()
assert.True(t, callbackCalled)
}
4 changes: 4 additions & 0 deletions rollout/canary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,8 @@ func TestSyncRolloutWaitAddToQueue(t *testing.T) {
f.runController(key, true, false, c, i, k8sI)

// When the controller starts, it will enqueue the rollout while syncing the informer and during the reconciliation step
f.enqueuedObjectsLock.Lock()
defer f.enqueuedObjectsLock.Unlock()
assert.Equal(t, 2, f.enqueuedObjects[key])
}

Expand Down Expand Up @@ -1204,6 +1206,8 @@ func TestSyncRolloutIgnoreWaitOutsideOfReconciliationPeriod(t *testing.T) {
c, i, k8sI := f.newController(func() time.Duration { return 30 * time.Minute })
f.runController(key, true, false, c, i, k8sI)
// When the controller starts, it will enqueue the rollout so we expect the rollout to enqueue at least once.
f.enqueuedObjectsLock.Lock()
defer f.enqueuedObjectsLock.Unlock()
assert.Equal(t, 1, f.enqueuedObjects[key])
}

Expand Down
26 changes: 16 additions & 10 deletions rollout/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ type fixture struct {
kubeactions []core.Action
actions []core.Action
// Objects from here preloaded into NewSimpleFake.
kubeobjects []runtime.Object
objects []runtime.Object
enqueuedObjects map[string]int
unfreezeTime func() error
kubeobjects []runtime.Object
objects []runtime.Object
// Acquire 'enqueuedObjectsLock' before accessing enqueuedObjects
enqueuedObjects map[string]int
enqueuedObjectsLock sync.Mutex
unfreezeTime func() error

// events holds all the K8s Event Reasons emitted during the run
events []string
Expand All @@ -116,9 +118,12 @@ func newFixture(t *testing.T) *fixture {
f.kubeobjects = []runtime.Object{}
f.enqueuedObjects = make(map[string]int)
now := time.Now()
timeutil.Now = func() time.Time { return now }

timeutil.SetNowTimeFunc(func() time.Time {
return now
})
f.unfreezeTime = func() error {
timeutil.Now = time.Now
timeutil.SetNowTimeFunc(time.Now)
return nil
}

Expand Down Expand Up @@ -598,15 +603,15 @@ func (f *fixture) newController(resync resyncFunc) (*Controller, informers.Share
RefResolver: &FakeWorkloadRefResolver{},
})

var enqueuedObjectsLock sync.Mutex
c.enqueueRollout = func(obj any) {
var key string
var err error
if key, err = cache.MetaNamespaceKeyFunc(obj); err != nil {
panic(err)
}
enqueuedObjectsLock.Lock()
defer enqueuedObjectsLock.Unlock()

f.enqueuedObjectsLock.Lock()
defer f.enqueuedObjectsLock.Unlock()
count, ok := f.enqueuedObjects[key]
if !ok {
count = 0
Expand Down Expand Up @@ -720,7 +725,8 @@ func (f *fixture) runController(rolloutName string, startInformers bool, expectE
f.t.Errorf("%d expected actions did not happen:%+v", len(f.kubeactions)-len(k8sActions), f.kubeactions[len(k8sActions):])
}
fakeRecorder := c.recorder.(*record.FakeEventRecorder)
f.events = fakeRecorder.Events

f.events = fakeRecorder.Events()
return c
}

Expand Down
2 changes: 1 addition & 1 deletion rollout/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func TestSendStateChangeEvents(t *testing.T) {
recorder := record.NewFakeEventRecorder()
roCtx.recorder = recorder
roCtx.sendStateChangeEvents(&test.prevStatus, &test.newStatus)
assert.Equal(t, test.expectedEventReasons, recorder.Events)
assert.Equal(t, test.expectedEventReasons, recorder.Events())
}
}

Expand Down
6 changes: 4 additions & 2 deletions rollout/trafficrouting/ambassador/ambassador_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,18 @@ type getReturn struct {
func (f *fakeClient) Get(ctx context.Context, name string, options metav1.GetOptions, subresources ...string) (*unstructured.Unstructured, error) {
invokation := &getInvokation{name: name}
f.mu.Lock()
defer f.mu.Unlock()
f.getInvokations = append(f.getInvokations, invokation)
f.mu.Unlock()

if len(f.getReturns) == 0 {
return nil, nil
}
ret := f.getReturns[0]
if len(f.getReturns) >= len(f.getInvokations) {
ret = f.getReturns[len(f.getInvokations)-1]
}
return ret.obj, ret.err
// We clone the object before returning it, to prevent modification of the fake object in memory by the calling function
return ret.obj.DeepCopy(), ret.err
}

func (f *fakeClient) Create(ctx context.Context, obj *unstructured.Unstructured, options metav1.CreateOptions, subresources ...string) (*unstructured.Unstructured, error) {
Expand Down
Loading

0 comments on commit 149ff1e

Please sign in to comment.