diff --git a/mvcc/kv_test.go b/mvcc/kv_test.go index d6f49ee14a9..2d7dc01ff7f 100644 --- a/mvcc/kv_test.go +++ b/mvcc/kv_test.go @@ -716,7 +716,7 @@ func TestWatchableKVWatch(t *testing.T) { w := s.NewWatchStream() defer w.Close() - wid := w.Watch([]byte("foo"), []byte("fop"), 0) + wid, _ := w.Watch(0, []byte("foo"), []byte("fop"), 0) wev := []mvccpb.Event{ {Type: mvccpb.PUT, @@ -783,7 +783,7 @@ func TestWatchableKVWatch(t *testing.T) { } w = s.NewWatchStream() - wid = w.Watch([]byte("foo1"), []byte("foo2"), 3) + wid, _ = w.Watch(0, []byte("foo1"), []byte("foo2"), 3) select { case resp := <-w.Chan(): diff --git a/mvcc/watchable_store_bench_test.go b/mvcc/watchable_store_bench_test.go index 769d1bc38a8..198fea6bb42 100644 --- a/mvcc/watchable_store_bench_test.go +++ b/mvcc/watchable_store_bench_test.go @@ -78,7 +78,7 @@ func BenchmarkWatchableStoreWatchSyncPut(b *testing.B) { watchIDs := make([]WatchID, b.N) for i := range watchIDs { // non-0 value to keep watchers in unsynced - watchIDs[i] = w.Watch(k, nil, 1) + watchIDs[i], _ = w.Watch(0, k, nil, 1) } b.ResetTimer() @@ -142,7 +142,7 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) { watchIDs := make([]WatchID, watcherN) for i := 0; i < watcherN; i++ { // non-0 value to keep watchers in unsynced - watchIDs[i] = w.Watch(testKey, nil, 1) + watchIDs[i], _ = w.Watch(0, testKey, nil, 1) } // random-cancel N watchers to make it not biased towards @@ -182,7 +182,7 @@ func BenchmarkWatchableStoreSyncedCancel(b *testing.B) { watchIDs := make([]WatchID, watcherN) for i := 0; i < watcherN; i++ { // 0 for startRev to keep watchers in synced - watchIDs[i] = w.Watch(testKey, nil, 0) + watchIDs[i], _ = w.Watch(0, testKey, nil, 0) } // randomly cancel watchers to make it not biased towards diff --git a/mvcc/watchable_store_test.go b/mvcc/watchable_store_test.go index 52e1b90c0c0..c36a541a5f3 100644 --- a/mvcc/watchable_store_test.go +++ b/mvcc/watchable_store_test.go @@ -42,7 +42,7 @@ func TestWatch(t *testing.T) { s.Put(testKey, testValue, lease.NoLease) w := s.NewWatchStream() - w.Watch(testKey, nil, 0) + w.Watch(0, testKey, nil, 0) if !s.synced.contains(string(testKey)) { // the key must have had an entry in synced @@ -63,7 +63,7 @@ func TestNewWatcherCancel(t *testing.T) { s.Put(testKey, testValue, lease.NoLease) w := s.NewWatchStream() - wt := w.Watch(testKey, nil, 0) + wt, _ := w.Watch(0, testKey, nil, 0) if err := w.Cancel(wt); err != nil { t.Error(err) @@ -114,7 +114,7 @@ func TestCancelUnsynced(t *testing.T) { watchIDs := make([]WatchID, watcherN) for i := 0; i < watcherN; i++ { // use 1 to keep watchers in unsynced - watchIDs[i] = w.Watch(testKey, nil, 1) + watchIDs[i], _ = w.Watch(0, testKey, nil, 1) } for _, idx := range watchIDs { @@ -160,7 +160,7 @@ func TestSyncWatchers(t *testing.T) { for i := 0; i < watcherN; i++ { // specify rev as 1 to keep watchers in unsynced - w.Watch(testKey, nil, 1) + w.Watch(0, testKey, nil, 1) } // Before running s.syncWatchers() synced should be empty because we manually @@ -242,7 +242,7 @@ func TestWatchCompacted(t *testing.T) { } w := s.NewWatchStream() - wt := w.Watch(testKey, nil, compactRev-1) + wt, _ := w.Watch(0, testKey, nil, compactRev-1) select { case resp := <-w.Chan(): @@ -271,7 +271,7 @@ func TestWatchFutureRev(t *testing.T) { w := s.NewWatchStream() wrev := int64(10) - w.Watch(testKey, nil, wrev) + w.Watch(0, testKey, nil, wrev) for i := 0; i < 10; i++ { rev := s.Put(testKey, testValue, lease.NoLease) @@ -310,7 +310,7 @@ func TestWatchRestore(t *testing.T) { defer cleanup(newStore, newBackend, newPath) w := newStore.NewWatchStream() - w.Watch(testKey, nil, rev-1) + w.Watch(0, testKey, nil, rev-1) newStore.Restore(b) select { @@ -349,7 +349,7 @@ func TestWatchBatchUnsynced(t *testing.T) { } w := s.NewWatchStream() - w.Watch(v, nil, 1) + w.Watch(0, v, nil, 1) for i := 0; i < batches; i++ { if resp := <-w.Chan(); len(resp.Events) != watchBatchMaxRevs { t.Fatalf("len(events) = %d, want %d", len(resp.Events), watchBatchMaxRevs) @@ -485,7 +485,7 @@ func TestWatchVictims(t *testing.T) { for i := 0; i < numWatches; i++ { go func() { w := s.NewWatchStream() - w.Watch(testKey, nil, 1) + w.Watch(0, testKey, nil, 1) defer func() { w.Close() wg.Done() @@ -561,7 +561,7 @@ func TestStressWatchCancelClose(t *testing.T) { w := s.NewWatchStream() ids := make([]WatchID, 10) for i := range ids { - ids[i] = w.Watch(testKey, nil, 0) + ids[i], _ = w.Watch(0, testKey, nil, 0) } <-readyc wg.Add(1 + len(ids)/2) diff --git a/mvcc/watcher.go b/mvcc/watcher.go index bc0c6322fd1..886b87d5a47 100644 --- a/mvcc/watcher.go +++ b/mvcc/watcher.go @@ -22,8 +22,14 @@ import ( "github.com/coreos/etcd/mvcc/mvccpb" ) +// AutoWatchID is the watcher ID passed in WatchStream.Watch when no +// user-provided ID is available. If pass, an ID will automatically be assigned. +const AutoWatchID WatchID = 0 + var ( - ErrWatcherNotExist = errors.New("mvcc: watcher does not exist") + ErrWatcherNotExist = errors.New("mvcc: watcher does not exist") + ErrEmptyWatcherRange = errors.New("mvcc: watcher range is empty") + ErrWatcherDuplicateID = errors.New("mvcc: duplicate watch ID provided on the WatchStream") ) type WatchID int64 @@ -36,12 +42,13 @@ type WatchStream interface { // happened on the given key or range [key, end) from the given startRev. // // The whole event history can be watched unless compacted. - // If `startRev` <=0, watch observes events after currentRev. + // If "startRev" <=0, watch observes events after currentRev. // - // The returned `id` is the ID of this watcher. It appears as WatchID + // The returned "id" is the ID of this watcher. It appears as WatchID // in events that are sent to the created watcher through stream channel. - // - Watch(key, end []byte, startRev int64, fcs ...FilterFunc) WatchID + // The watch ID is used when it's not equal to AutoWatchID. Otherwise, + // an auto-generated watch ID is returned. + Watch(id WatchID, key, end []byte, startRev int64, fcs ...FilterFunc) (WatchID, error) // Chan returns a chan. All watch response will be sent to the returned chan. Chan() <-chan WatchResponse @@ -98,28 +105,34 @@ type watchStream struct { } // Watch creates a new watcher in the stream and returns its WatchID. -// TODO: return error if ws is closed? -func (ws *watchStream) Watch(key, end []byte, startRev int64, fcs ...FilterFunc) WatchID { +func (ws *watchStream) Watch(id WatchID, key, end []byte, startRev int64, fcs ...FilterFunc) (WatchID, error) { // prevent wrong range where key >= end lexicographically // watch request with 'WithFromKey' has empty-byte range end if len(end) != 0 && bytes.Compare(key, end) != -1 { - return -1 + return -1, ErrEmptyWatcherRange } ws.mu.Lock() defer ws.mu.Unlock() if ws.closed { - return -1 + return -1, ErrEmptyWatcherRange } - id := ws.nextID - ws.nextID++ + if id == AutoWatchID { + for ws.watchers[ws.nextID] != nil { + ws.nextID++ + } + id = ws.nextID + ws.nextID++ + } else if _, ok := ws.watchers[id]; ok { + return -1, ErrWatcherDuplicateID + } w, c := ws.watchable.watch(key, end, startRev, id, ws.ch, fcs...) ws.cancels[id] = c ws.watchers[id] = w - return id + return id, nil } func (ws *watchStream) Chan() <-chan WatchResponse { diff --git a/mvcc/watcher_bench_test.go b/mvcc/watcher_bench_test.go index 8a4242f3f20..86cbea7df2e 100644 --- a/mvcc/watcher_bench_test.go +++ b/mvcc/watcher_bench_test.go @@ -33,6 +33,6 @@ func BenchmarkKVWatcherMemoryUsage(b *testing.B) { b.ReportAllocs() b.StartTimer() for i := 0; i < b.N; i++ { - w.Watch([]byte(fmt.Sprint("foo", i)), nil, 0) + w.Watch(0, []byte(fmt.Sprint("foo", i)), nil, 0) } } diff --git a/mvcc/watcher_test.go b/mvcc/watcher_test.go index 3d259d1f160..f08e7db099b 100644 --- a/mvcc/watcher_test.go +++ b/mvcc/watcher_test.go @@ -40,7 +40,7 @@ func TestWatcherWatchID(t *testing.T) { idm := make(map[WatchID]struct{}) for i := 0; i < 10; i++ { - id := w.Watch([]byte("foo"), nil, 0) + id, _ := w.Watch(0, []byte("foo"), nil, 0) if _, ok := idm[id]; ok { t.Errorf("#%d: id %d exists", i, id) } @@ -62,7 +62,7 @@ func TestWatcherWatchID(t *testing.T) { // unsynced watchers for i := 10; i < 20; i++ { - id := w.Watch([]byte("foo2"), nil, 1) + id, _ := w.Watch(0, []byte("foo2"), nil, 1) if _, ok := idm[id]; ok { t.Errorf("#%d: id %d exists", i, id) } @@ -79,6 +79,41 @@ func TestWatcherWatchID(t *testing.T) { } } +func TestWatcherRequestsCustomID(t *testing.T) { + b, tmpPath := backend.NewDefaultTmpBackend() + s := WatchableKV(newWatchableStore(b, &lease.FakeLessor{}, nil)) + defer cleanup(s, b, tmpPath) + + w := s.NewWatchStream() + defer w.Close() + + // - Request specifically ID #1 + // - Try to duplicate it, get an error + // - Make sure the auto-assignment skips over things we manually assigned + + tt := []struct { + GivenID WatchID + ExpectedID WatchID + ExpectedErr error + }{ + {1, 1, nil}, + {1, 0, ErrWatcherDuplicateID}, + {0, 0, nil}, + {0, 2, nil}, + } + + for i, tcase := range tt { + id, err := w.Watch(tcase.GivenID, []byte("foo"), nil, 0) + if tcase.ExpectedErr != nil || err != nil { + if err != tcase.ExpectedErr { + t.Errorf("expected get error %q in test case %q, got %q", tcase.ExpectedErr, i, err) + } + } else if tcase.ExpectedID != id { + t.Errorf("expected to create ID %d, got %d in test case %d", tcase.ExpectedID, id, i) + } + } +} + // TestWatcherWatchPrefix tests if Watch operation correctly watches // and returns events with matching prefixes. func TestWatcherWatchPrefix(t *testing.T) { @@ -95,7 +130,7 @@ func TestWatcherWatchPrefix(t *testing.T) { keyWatch, keyEnd, keyPut := []byte("foo"), []byte("fop"), []byte("foobar") for i := 0; i < 10; i++ { - id := w.Watch(keyWatch, keyEnd, 0) + id, _ := w.Watch(0, keyWatch, keyEnd, 0) if _, ok := idm[id]; ok { t.Errorf("#%d: unexpected duplicated id %x", i, id) } @@ -127,7 +162,7 @@ func TestWatcherWatchPrefix(t *testing.T) { // unsynced watchers for i := 10; i < 15; i++ { - id := w.Watch(keyWatch1, keyEnd1, 1) + id, _ := w.Watch(0, keyWatch1, keyEnd1, 1) if _, ok := idm[id]; ok { t.Errorf("#%d: id %d exists", i, id) } @@ -163,14 +198,14 @@ func TestWatcherWatchWrongRange(t *testing.T) { w := s.NewWatchStream() defer w.Close() - if id := w.Watch([]byte("foa"), []byte("foa"), 1); id != -1 { - t.Fatalf("key == end range given; id expected -1, got %d", id) + if _, err := w.Watch(0, []byte("foa"), []byte("foa"), 1); err != ErrEmptyWatcherRange { + t.Fatalf("key == end range given; expected ErrEmptyWatcherRange, got %+v", err) } - if id := w.Watch([]byte("fob"), []byte("foa"), 1); id != -1 { - t.Fatalf("key > end range given; id expected -1, got %d", id) + if _, err := w.Watch(0, []byte("fob"), []byte("foa"), 1); err != ErrEmptyWatcherRange { + t.Fatalf("key > end range given; expected ErrEmptyWatcherRange, got %+v", err) } // watch request with 'WithFromKey' has empty-byte range end - if id := w.Watch([]byte("foo"), []byte{}, 1); id != 0 { + if id, _ := w.Watch(0, []byte("foo"), []byte{}, 1); id != 0 { t.Fatalf("\x00 is range given; id expected 0, got %d", id) } } @@ -192,7 +227,7 @@ func TestWatchDeleteRange(t *testing.T) { w := s.NewWatchStream() from, to := []byte(testKeyPrefix), []byte(fmt.Sprintf("%s_%d", testKeyPrefix, 99)) - w.Watch(from, to, 0) + w.Watch(0, from, to, 0) s.DeleteRange(from, to) @@ -222,7 +257,7 @@ func TestWatchStreamCancelWatcherByID(t *testing.T) { w := s.NewWatchStream() defer w.Close() - id := w.Watch([]byte("foo"), nil, 0) + id, _ := w.Watch(0, []byte("foo"), nil, 0) tests := []struct { cancelID WatchID @@ -284,7 +319,7 @@ func TestWatcherRequestProgress(t *testing.T) { default: } - id := w.Watch(notTestKey, nil, 1) + id, _ := w.Watch(0, notTestKey, nil, 1) w.RequestProgress(id) select { case resp := <-w.Chan(): @@ -295,7 +330,7 @@ func TestWatcherRequestProgress(t *testing.T) { s.syncWatchers() w.RequestProgress(id) - wrs := WatchResponse{WatchID: 0, Revision: 2} + wrs := WatchResponse{WatchID: id, Revision: 2} select { case resp := <-w.Chan(): if !reflect.DeepEqual(resp, wrs) { @@ -318,7 +353,7 @@ func TestWatcherWatchWithFilter(t *testing.T) { return e.Type == mvccpb.PUT } - w.Watch([]byte("foo"), nil, 0, filterPut) + w.Watch(0, []byte("foo"), nil, 0, filterPut) done := make(chan struct{}) go func() {