diff --git a/iterator.go b/iterator.go index 3b7249059..887b049f1 100644 --- a/iterator.go +++ b/iterator.go @@ -718,6 +718,9 @@ func (it *Iterator) prefetch() { // smallest key greater than the provided key if iterating in the forward direction. // Behavior would be reversed if iterating backwards. func (it *Iterator) Seek(key []byte) { + if len(key) > 0 { + it.txn.addReadKey(key) + } for i := it.data.pop(); i != nil; i = it.data.pop() { i.wg.Wait() it.waste.push(i) diff --git a/txn_test.go b/txn_test.go index fe29da291..391f4304a 100644 --- a/txn_test.go +++ b/txn_test.go @@ -22,6 +22,7 @@ import ( "math/rand" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -865,3 +866,81 @@ func TestArmV7Issue311Fix(t *testing.T) { require.NoError(t, err) require.NoError(t, db.Close()) } + +// This test tries to perform a GetAndSet operation using multiple concurrent +// transaction and only one of the transactions should be successful. +// Regression test for https://github.com/dgraph-io/badger/issues/1289 +func TestConflict(t *testing.T) { + key := []byte("foo") + setCount := uint32(0) + + testAndSet := func(wg *sync.WaitGroup, db *DB) { + defer wg.Done() + txn := db.NewTransaction(true) + defer txn.Discard() + + _, err := txn.Get(key) + if err == ErrKeyNotFound { + // Unset the error. + err = nil + require.NoError(t, txn.Set(key, []byte("AA"))) + txn.CommitWith(func(err error) { + if err == nil { + require.LessOrEqual(t, uint32(1), atomic.AddUint32(&setCount, 1)) + } else { + + require.Error(t, err, ErrConflict) + } + }) + } + require.NoError(t, err) + } + testAndSetItr := func(wg *sync.WaitGroup, db *DB) { + defer wg.Done() + txn := db.NewTransaction(true) + defer txn.Discard() + + iopt := DefaultIteratorOptions + it := txn.NewIterator(iopt) + + found := false + for it.Seek(key); it.Valid(); it.Next() { + found = true + } + it.Close() + + if !found { + require.NoError(t, txn.Set(key, []byte("AA"))) + txn.CommitWith(func(err error) { + if err == nil { + require.LessOrEqual(t, atomic.AddUint32(&setCount, 1), uint32(1)) + } else { + require.Error(t, err, ErrConflict) + } + }) + } + } + + runTest := func(t *testing.T, fn func(wg *sync.WaitGroup, db *DB)) { + loop := 10 + numGo := 16 // This many concurrent transactions. + for i := 0; i < loop; i++ { + var wg sync.WaitGroup + wg.Add(numGo) + setCount = 0 + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + for j := 0; j < numGo; j++ { + go fn(&wg, db) + } + wg.Wait() + }) + require.Equal(t, uint32(1), atomic.LoadUint32(&setCount)) + } + } + t.Run("TxnGet", func(t *testing.T) { + runTest(t, testAndSet) + }) + t.Run("ItrSeek", func(t *testing.T) { + runTest(t, testAndSetItr) + }) +}