Skip to content

Commit

Permalink
Merge pull request #8696 from planetscale/ds-fix-vstream-test
Browse files Browse the repository at this point in the history
tests: use AtomicInt32 instead of int to fix races
  • Loading branch information
deepthi authored Aug 27, 2021
2 parents 21f173b + 50b40e1 commit 016756e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
34 changes: 24 additions & 10 deletions go/vt/vtgate/vstream_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ type vstream struct {
journaler map[int64]*journalEvent

// err can only be set once.
once sync.Once
err error
// errMu protects err by ensuring its value is read or written by only one goroutine at a time.
once sync.Once
err error
errMu sync.Mutex

// Other input parameters
tabletType topodatapb.TabletType
Expand Down Expand Up @@ -238,7 +240,7 @@ func (vs *vstream) stream(ctx context.Context) error {
}
vs.wg.Wait()

return vs.err
return vs.getError()
}

func (vs *vstream) sendEvents(ctx context.Context) {
Expand All @@ -260,7 +262,7 @@ func (vs *vstream) sendEvents(ctx context.Context) {
send := func(evs []*binlogdatapb.VEvent) error {
if err := vs.send(evs); err != nil {
vs.once.Do(func() {
vs.err = err
vs.setError(err)
})
return err
}
Expand All @@ -270,13 +272,13 @@ func (vs *vstream) sendEvents(ctx context.Context) {
select {
case <-ctx.Done():
vs.once.Do(func() {
vs.err = fmt.Errorf("context canceled")
vs.setError(fmt.Errorf("context canceled"))
})
return
case evs := <-vs.eventCh:
if err := send(evs); err != nil {
vs.once.Do(func() {
vs.err = err
vs.setError(err)
})
return
}
Expand All @@ -290,7 +292,7 @@ func (vs *vstream) sendEvents(ctx context.Context) {
}}
if err := send(evs); err != nil {
vs.once.Do(func() {
vs.err = err
vs.setError(err)
})
return
}
Expand All @@ -309,7 +311,7 @@ func (vs *vstream) startOneStream(ctx context.Context, sgtid *binlogdatapb.Shard
if err != nil {
log.Errorf("Error in vstream for %+v: %s", sgtid, err)
vs.once.Do(func() {
vs.err = err
vs.setError(err)
vs.cancel()
})
}
Expand Down Expand Up @@ -592,8 +594,8 @@ func (vs *vstream) sendAll(sgtid *binlogdatapb.ShardGtid, eventss [][]*binlogdat

// Send all chunks while holding the lock.
for _, events := range eventss {
if vs.err != nil {
return vs.err
if err := vs.getError(); err != nil {
return err
}
// convert all gtids to vgtids. This should be done here while holding the lock.
for j, event := range events {
Expand Down Expand Up @@ -638,6 +640,18 @@ func (vs *vstream) sendAll(sgtid *binlogdatapb.ShardGtid, eventss [][]*binlogdat
return nil
}

func (vs *vstream) getError() error {
vs.errMu.Lock()
defer vs.errMu.Unlock()
return vs.err
}

func (vs *vstream) setError(err error) {
vs.errMu.Lock()
defer vs.errMu.Unlock()
vs.err = err
}

// getJournalEvent returns a journalEvent. The caller has to wait on its done channel.
// Once it closes, the caller has to return (end their stream).
// The function has three parts:
Expand Down
30 changes: 14 additions & 16 deletions go/vt/vtgate/vstream_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"testing"
"time"

"vitess.io/vitess/go/sync2"

"vitess.io/vitess/go/vt/topo"

vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
Expand Down Expand Up @@ -199,7 +201,6 @@ func TestVStreamEvents(t *testing.T) {
// TestVStreamChunks ensures that a transaction that's broken
// into chunks is sent together.
func TestVStreamChunks(t *testing.T) {
t.Skip("flaky test")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -222,8 +223,9 @@ func TestVStreamChunks(t *testing.T) {

rowEncountered := false
doneCounting := false
rowCount := 0
ddlCount := 0
var rowCount, ddlCount sync2.AtomicInt32
rowCount.Set(0)
ddlCount.Set(0)
vgtid := &binlogdatapb.VGtid{
ShardGtids: []*binlogdatapb.ShardGtid{{
Keyspace: ks,
Expand All @@ -243,7 +245,7 @@ func TestVStreamChunks(t *testing.T) {
return fmt.Errorf("unexpected event: %v", events[0])
}
rowEncountered = true
rowCount++
rowCount.Add(1)
case binlogdatapb.VEventType_COMMIT:
if !rowEncountered {
t.Errorf("Unexpected event, COMMIT after non-rows: %v", events[0])
Expand All @@ -255,22 +257,18 @@ func TestVStreamChunks(t *testing.T) {
t.Errorf("Unexpected event, DDL during ROW events: %v", events[0])
return fmt.Errorf("unexpected event: %v", events[0])
}
ddlCount++
ddlCount.Add(1)
default:
t.Errorf("Unexpected event: %v", events[0])
return fmt.Errorf("unexpected event: %v", events[0])
}
if rowCount == 100 && ddlCount == 100 {
if rowCount.Get() == int32(100) && ddlCount.Get() == int32(100) {
cancel()
}
return nil
})
if rowCount != 100 {
t.Errorf("rowCount: %d, want 100", rowCount)
}
if ddlCount != 100 {
t.Errorf("ddlCount: %d, want 100", ddlCount)
}
assert.Equal(t, int32(100), rowCount.Get())
assert.Equal(t, int32(100), ddlCount.Get())
}

func TestVStreamMulti(t *testing.T) {
Expand Down Expand Up @@ -336,7 +334,6 @@ func TestVStreamMulti(t *testing.T) {
}

func TestVStreamRetry(t *testing.T) {
t.Skip()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -358,7 +355,8 @@ func TestVStreamRetry(t *testing.T) {
sbc0.AddVStreamEvents(nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "bb"))
sbc0.AddVStreamEvents(nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "cc"))
sbc0.AddVStreamEvents(nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "final error"))
count := 0
var count sync2.AtomicInt32
count.Set(0)
vgtid := &binlogdatapb.VGtid{
ShardGtids: []*binlogdatapb.ShardGtid{{
Keyspace: ks,
Expand All @@ -367,15 +365,15 @@ func TestVStreamRetry(t *testing.T) {
}},
}
err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error {
count++
count.Add(1)
return nil
})
wantErr := "final error"
if err == nil || !strings.Contains(err.Error(), wantErr) {
t.Errorf("vstream end: %v, must contain %v", err.Error(), wantErr)
}
time.Sleep(100 * time.Millisecond) // wait for goroutine within VStream to finish
assert.Equal(t, 2, count)
assert.Equal(t, int32(2), count.Get())
}

func TestVStreamShouldNotSendSourceHeartbeats(t *testing.T) {
Expand Down

0 comments on commit 016756e

Please sign in to comment.