diff --git a/server/etcdserver/api/v3rpc/watch.go b/server/etcdserver/api/v3rpc/watch.go index 512825d32400..cd834aa3e860 100644 --- a/server/etcdserver/api/v3rpc/watch.go +++ b/server/etcdserver/api/v3rpc/watch.go @@ -144,6 +144,10 @@ type serverWatchStream struct { // records fragmented watch IDs fragment map[mvcc.WatchID]bool + // indicates whether we have an outstanding global progress + // notification to send + deferredProgress bool + // closec indicates the stream is closed. closec chan struct{} @@ -173,6 +177,8 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { prevKV: make(map[mvcc.WatchID]bool), fragment: make(map[mvcc.WatchID]bool), + deferredProgress: false, + closec: make(chan struct{}), } @@ -359,10 +365,16 @@ func (sws *serverWatchStream) recvLoop() error { } case *pb.WatchRequest_ProgressRequest: if uv.ProgressRequest != nil { - sws.ctrlStream <- &pb.WatchResponse{ - Header: sws.newResponseHeader(sws.watchStream.Rev()), - WatchId: clientv3.InvalidWatchID, // response is not associated with any WatchId and will be broadcast to all watch channels + sws.mu.Lock() + // Ignore if deferred progress notification is already in progress + if !sws.deferredProgress { + // Request progress for all watchers, + // force generation of a response + if !sws.watchStream.RequestProgressAll() { + sws.deferredProgress = true + } } + sws.mu.Unlock() } default: // we probably should not shutdown the entire stream when @@ -430,11 +442,15 @@ func (sws *serverWatchStream) sendLoop() { Canceled: canceled, } - if _, okID := ids[wresp.WatchID]; !okID { - // buffer if id not yet announced - wrs := append(pending[wresp.WatchID], wr) - pending[wresp.WatchID] = wrs - continue + // Progress notifications can have WatchID -1 + // if they announce on behalf of multiple watchers + if wresp.WatchID != clientv3.InvalidWatchID { + if _, okID := ids[wresp.WatchID]; !okID { + // buffer if id not yet announced + wrs := append(pending[wresp.WatchID], wr) + pending[wresp.WatchID] = wrs + continue + } } mvcc.ReportEventReceived(len(evs)) @@ -465,6 +481,11 @@ func (sws *serverWatchStream) sendLoop() { // elide next progress update if sent a key update sws.progress[wresp.WatchID] = false } + if sws.deferredProgress { + if sws.watchStream.RequestProgressAll() { + sws.deferredProgress = false + } + } sws.mu.Unlock() case c, ok := <-sws.ctrlStream: diff --git a/server/mvcc/watchable_store.go b/server/mvcc/watchable_store.go index 3a9fd344cca5..d3c0208d667c 100644 --- a/server/mvcc/watchable_store.go +++ b/server/mvcc/watchable_store.go @@ -15,6 +15,7 @@ package mvcc import ( + clientv3 "go.etcd.io/etcd/client/v3" "sync" "time" @@ -41,6 +42,7 @@ var ( type watchable interface { watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc) progress(w *watcher) + progressAll(watchers map[WatchID]*watcher) bool rev() int64 } @@ -477,14 +479,34 @@ func (s *watchableStore) addVictim(victim watcherBatch) { func (s *watchableStore) rev() int64 { return s.store.Rev() } func (s *watchableStore) progress(w *watcher) { + s.progressIfSync(map[WatchID]*watcher{w.id: w}, w.id) +} + +func (s *watchableStore) progressAll(watchers map[WatchID]*watcher) bool { + return s.progressIfSync(watchers, clientv3.InvalidWatchID) +} + +func (s *watchableStore) progressIfSync(watchers map[WatchID]*watcher, responseWatchID WatchID) bool { s.mu.RLock() defer s.mu.RUnlock() - if _, ok := s.synced.watchers[w]; ok { - w.send(WatchResponse{WatchID: w.id, Revision: s.rev()}) - // If the ch is full, this watcher is receiving events. - // We do not need to send progress at all. + // Any watcher unsynced? + for _, w := range watchers { + if _, ok := s.synced.watchers[w]; !ok { + return false + } + } + + // If all watchers are synchronised, send out progress + // notification on first watcher. Note that all watchers + // should have the same underlying stream, and the progress + // notification will be broadcasted client-side if required + // (see dispatchEvent in client/v3/watch.go) + for _, w := range watchers { + w.send(WatchResponse{WatchID: responseWatchID, Revision: s.rev()}) + return true } + return true } type watcher struct { diff --git a/server/mvcc/watcher.go b/server/mvcc/watcher.go index 7d2490b1d6e9..c67c21d61397 100644 --- a/server/mvcc/watcher.go +++ b/server/mvcc/watcher.go @@ -58,6 +58,13 @@ type WatchStream interface { // of the watchers since the watcher is currently synced. RequestProgress(id WatchID) + // RequestProgressAll requests a progress notification for all + // watchers sharing the stream. If all watchers are synced, a + // progress notification with watch ID -1 will be sent to an + // arbitrary watcher of this stream, and the function returns + // true. + RequestProgressAll() bool + // Cancel cancels a watcher by giving its ID. If watcher does not exist, an error will be // returned. Cancel(id WatchID) error @@ -188,3 +195,9 @@ func (ws *watchStream) RequestProgress(id WatchID) { } ws.watchable.progress(w) } + +func (ws *watchStream) RequestProgressAll() bool { + ws.mu.Lock() + defer ws.mu.Unlock() + return ws.watchable.progressAll(ws.watchers) +} diff --git a/server/mvcc/watcher_test.go b/server/mvcc/watcher_test.go index bbada4ed5dc5..81fb2aa4a293 100644 --- a/server/mvcc/watcher_test.go +++ b/server/mvcc/watcher_test.go @@ -17,12 +17,14 @@ package mvcc import ( "bytes" "fmt" + "go.uber.org/zap/zaptest" "os" "reflect" "testing" "time" "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/server/v3/lease" betesting "go.etcd.io/etcd/server/v3/mvcc/backend/testing" "go.uber.org/zap" @@ -342,6 +344,58 @@ func TestWatcherRequestProgress(t *testing.T) { } } +func TestWatcherRequestProgressAll(t *testing.T) { + b, tmpPath := betesting.NewDefaultTmpBackend(t) + + // manually create watchableStore instead of newWatchableStore + // because newWatchableStore automatically calls syncWatchers + // method to sync watchers in unsynced map. We want to keep watchers + // in unsynced to test if syncWatchers works as expected. + s := &watchableStore{ + store: NewStore(zaptest.NewLogger(t), b, &lease.FakeLessor{}, StoreConfig{}), + unsynced: newWatcherGroup(), + synced: newWatcherGroup(), + stopc: make(chan struct{}), + } + + defer func() { + s.store.Close() + os.Remove(tmpPath) + }() + + testKey := []byte("foo") + notTestKey := []byte("bad") + testValue := []byte("bar") + s.Put(testKey, testValue, lease.NoLease) + + // Create watch stream with watcher. We will not actually get + // any notifications on it specifically, but there needs to be + // at least one Watch for progress notifications to get + // generated. + w := s.NewWatchStream() + w.Watch(0, notTestKey, nil, 1) + + w.RequestProgressAll() + select { + case resp := <-w.Chan(): + t.Fatalf("unexpected %+v", resp) + default: + } + + s.syncWatchers() + + w.RequestProgressAll() + wrs := WatchResponse{WatchID: clientv3.InvalidWatchID, Revision: 2} + select { + case resp := <-w.Chan(): + if !reflect.DeepEqual(resp, wrs) { + t.Fatalf("got %+v, expect %+v", resp, wrs) + } + case <-time.After(time.Second): + t.Fatal("failed to receive progress") + } +} + func TestWatcherWatchWithFilter(t *testing.T) { b, tmpPath := betesting.NewDefaultTmpBackend(t) s := WatchableKV(newWatchableStore(zap.NewExample(), b, &lease.FakeLessor{}, StoreConfig{})) diff --git a/tests/integration/v3_watch_test.go b/tests/integration/v3_watch_test.go index 07561a278a92..518db37e2c26 100644 --- a/tests/integration/v3_watch_test.go +++ b/tests/integration/v3_watch_test.go @@ -1404,3 +1404,71 @@ func TestV3WatchCloseCancelRace(t *testing.T) { t.Fatalf("expected %s watch, got %s", expected, minWatches) } } + +// TestV3WatchProgressWaitsForSync checks that progress notifications +// don't get sent until the watcher is synchronised +func TestV3WatchProgressWaitsForSync(t *testing.T) { + + // Disable for gRPC proxy, as it does not support requesting + // progress notifications + if ThroughProxy { + t.Skip("grpc proxy currently does not support requesting progress notifications") + } + + BeforeTest(t) + + clus := NewClusterV3(t, &ClusterConfig{Size: 1}) + defer clus.Terminate(t) + + client := clus.RandClient() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Write a couple values into key to make sure there's a + // non-trivial amount of history. + count := 1001 + t.Logf("Writing key 'foo' %d times", count) + for i := 0; i < count; i++ { + _, err := client.Put(ctx, "foo", fmt.Sprintf("bar%d", i)) + require.NoError(t, err) + } + + // Create watch channel starting at revision 1 (i.e. it starts + // unsynced because of the update above) + wch := client.Watch(ctx, "foo", clientv3.WithRev(1)) + + // Immediately request a progress notification. As the client + // is unsynchronised, the server will have to defer the + // notification internally. + err := client.RequestProgress(ctx) + require.NoError(t, err) + + // Verify that we get the watch responses first. Note that + // events might be spread across multiple packets. + var event_count = 0 + for event_count < count { + wr := <-wch + if wr.Err() != nil { + t.Fatal(fmt.Errorf("watch error: %w", wr.Err())) + } + if wr.IsProgressNotify() { + t.Fatal("Progress notification from unsynced client!") + } + if wr.Header.Revision != int64(count+1) { + t.Fatal("Incomplete watch response!") + } + event_count += len(wr.Events) + } + + // ... followed by the requested progress notification + wr2 := <-wch + if wr2.Err() != nil { + t.Fatal(fmt.Errorf("watch error: %w", wr2.Err())) + } + if !wr2.IsProgressNotify() { + t.Fatal("Did not receive progress notification!") + } + if wr2.Header.Revision != int64(count+1) { + t.Fatal("Wrong revision in progress notification!") + } +}