diff --git a/pkg/kv/kvserver/client_raft_test.go b/pkg/kv/kvserver/client_raft_test.go index e0a70de719ce..a70bdb75c5bb 100644 --- a/pkg/kv/kvserver/client_raft_test.go +++ b/pkg/kv/kvserver/client_raft_test.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/storepool" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvflowcontrol/kvflowdispatch" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverbase" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverpb" @@ -6689,3 +6690,113 @@ func TestRaftUnquiesceLeaderNoProposal(t *testing.T) { require.Equal(t, initialStatus.Progress[1].Match, status.Progress[1].Match) t.Logf("n1 still leader with no new proposals at log index %d", status.Progress[1].Match) } + +// getMapsDiff returns the difference between the values of corresponding +// metrics in two maps. Assumption: beforeMap and afterMap contain the same set +// of keys. +func getMapsDiff(beforeMap map[string]int64, afterMap map[string]int64) map[string]int64 { + diffMap := make(map[string]int64) + for metricName, beforeValue := range beforeMap { + if v, ok := afterMap[metricName]; ok { + diffMap[metricName] = v - beforeValue + } + } + return diffMap +} + +// TestStoreMetricsOnIncomingOutgoingMsg verifies that HandleRaftRequest() and +// HandleRaftRequestSent() correctly update metrics for incoming and outgoing +// raft messages. 
+func TestStoreMetricsOnIncomingOutgoingMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + clock := hlc.NewClockForTesting(timeutil.NewManualTime(timeutil.Unix(0, 123))) + cfg := kvserver.TestStoreConfig(clock) + var stopper *stop.Stopper + stopper, _, _, cfg.StorePool, _ = storepool.CreateTestStorePool(ctx, cfg.Settings, + liveness.TestTimeUntilNodeDead, false, /* deterministic */ + func() int { return 1 }, /* nodeCount */ + livenesspb.NodeLivenessStatus_DEAD) + defer stopper.Stop(ctx) + + // Create a noop store and request. + node := roachpb.NodeDescriptor{NodeID: roachpb.NodeID(1)} + eng := storage.NewDefaultInMemForTesting() + stopper.AddCloser(eng) + cfg.Transport = kvserver.NewDummyRaftTransport(cfg.Settings, cfg.AmbientCtx.Tracer) + store := kvserver.NewStore(ctx, cfg, eng, &node) + store.Ident = &roachpb.StoreIdent{ + ClusterID: uuid.Nil, + StoreID: 1, + NodeID: 1, + } + request := &kvserverpb.RaftMessageRequest{ + RangeID: 1, + FromReplica: roachpb.ReplicaDescriptor{}, + ToReplica: roachpb.ReplicaDescriptor{}, + Message: raftpb.Message{ + From: 1, + To: 2, + Type: raftpb.MsgTimeoutNow, + Term: 1, + }, + } + + metricsNames := []string{ + "raft.rcvd.bytes", + "raft.rcvd.cross_region.bytes", + "raft.rcvd.cross_zone.bytes", + "raft.sent.bytes", + "raft.sent.cross_region.bytes", + "raft.sent.cross_zone.bytes"} + stream := noopRaftMessageResponseStream{} + expectedSize := int64(request.Size()) + + t.Run("received raft message", func(t *testing.T) { + before, metricsErr := store.Metrics().GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + if err := store.HandleRaftRequest(context.Background(), request, stream); err != nil { + t.Fatalf("HandleRaftRequest returned err %s", err) + } + after, metricsErr := store.Metrics().GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + actual := getMapsDiff(before, after) + expected := map[string]int64{ + 
"raft.rcvd.bytes": expectedSize, + "raft.rcvd.cross_region.bytes": 0, + "raft.rcvd.cross_zone.bytes": 0, + "raft.sent.bytes": 0, + "raft.sent.cross_region.bytes": 0, + "raft.sent.cross_zone.bytes": 0, + } + require.Equal(t, expected, actual) + }) + + t.Run("sent raft message", func(t *testing.T) { + before, metricsErr := store.Metrics().GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + store.HandleRaftRequestSent(context.Background(), + request.FromReplica.NodeID, request.ToReplica.NodeID, int64(request.Size())) + after, metricsErr := store.Metrics().GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + actual := getMapsDiff(before, after) + expected := map[string]int64{ + "raft.rcvd.bytes": 0, + "raft.rcvd.cross_region.bytes": 0, + "raft.rcvd.cross_zone.bytes": 0, + "raft.sent.bytes": expectedSize, + "raft.sent.cross_region.bytes": 0, + "raft.sent.cross_zone.bytes": 0, + } + require.Equal(t, expected, actual) + }) +} diff --git a/pkg/kv/kvserver/helpers_test.go b/pkg/kv/kvserver/helpers_test.go index 1039d7383f39..55f5f545094f 100644 --- a/pkg/kv/kvserver/helpers_test.go +++ b/pkg/kv/kvserver/helpers_test.go @@ -42,6 +42,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/circuit" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/quotapool" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" @@ -226,6 +227,51 @@ func (s *Store) RaftSchedulerPriorityIDs() []roachpb.RangeID { return s.scheduler.PriorityIDs() } +// GetStoreMetric retrieves the count of the store metric whose metadata name +// matches with the given name parameter. If the specified metric cannot be +// found, the function will return an error. 
+func (sm *StoreMetrics) GetStoreMetric(name string) (int64, error) { + var c int64 + var found bool + sm.registry.Each(func(n string, v interface{}) { + if name == n { + switch t := v.(type) { + case *metric.Counter: + c = t.Count() + found = true + case *metric.Gauge: + c = t.Value() + found = true + } + } + }) + if !found { + return -1, errors.Errorf("cannot find metric for %s", name) + } + return c, nil +} + +// GetStoreMetrics fetches the count of each specified Store metric from the +// `metricsNames` parameter and returns the result as a map. The keys in the map +// represent the metric metadata names, while the corresponding values indicate +// the count of each metric. If any of the specified metrics cannot be found or +// is neither a counter nor a gauge, the function will return an error. +// +// Assumptions: 1. The metricsNames parameter should consist of string literals +// that match the metadata names used for metrics. 2. Each metric name +// provided in `metricsNames` must exist, be unique, and be a counter or gauge. +func (sm *StoreMetrics) GetStoreMetrics(metricsNames []string) (map[string]int64, error) { + metrics := make(map[string]int64) + for _, metricName := range metricsNames { + count, err := sm.GetStoreMetric(metricName) + if err != nil { + return map[string]int64{}, errors.Errorf("cannot find metric for %s", metricName) + } + metrics[metricName] = count + } + return metrics, nil +} + func NewTestStorePool(cfg StoreConfig) *storepool.StorePool { liveness.TimeUntilNodeDead.Override(context.Background(), &cfg.Settings.SV, liveness.TestTimeUntilNodeDeadOff) return storepool.NewStorePool( diff --git a/pkg/kv/kvserver/metrics.go b/pkg/kv/kvserver/metrics.go index a82dec31be5f..ccfc2120ac8d 100644 --- a/pkg/kv/kvserver/metrics.go +++ b/pkg/kv/kvserver/metrics.go @@ -1381,6 +1381,59 @@ handling consumes writes. Unit: metric.Unit_BYTES, } + metaRaftRcvdBytes = metric.Metadata{ + Name: "raft.rcvd.bytes", + Help: `Number of bytes in Raft messages received by this store.
Note + that this does not include raft snapshot received.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftRcvdCrossRegionBytes = metric.Metadata{ + Name: "raft.rcvd.cross_region.bytes", + Help: `Number of bytes received by this store for cross region Raft messages + (when region tiers are configured). Note that this does not include raft + snapshot received.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftRcvdCrossZoneBytes = metric.Metadata{ + Name: "raft.rcvd.cross_zone.bytes", + Help: `Number of bytes received by this store for cross zone, same region + Raft messages (when region and zone tiers are configured). If region tiers + are not configured, this count may include data sent between different + regions. To ensure accurate monitoring of transmitted data, it is important + to set up a consistent locality configuration across nodes. Note that this + does not include raft snapshot received.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftSentBytes = metric.Metadata{ + Name: "raft.sent.bytes", + Help: `Number of bytes in Raft messages sent by this store. Note that + this does not include raft snapshot sent.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftSentCrossRegionBytes = metric.Metadata{ + Name: "raft.sent.cross_region.bytes", + Help: `Number of bytes sent by this store for cross region Raft messages + (when region tiers are configured). Note that this does not include raft + snapshot sent.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftSentCrossZoneBytes = metric.Metadata{ + Name: "raft.sent.cross_zone.bytes", + Help: `Number of bytes sent by this store for cross zone, same region Raft + messages (when region and zone tiers are configured). If region tiers are + not configured, this count may include data sent between different regions. 
+ To ensure accurate monitoring of transmitted data, it is important to set up + a consistent locality configuration across nodes. Note that this does not + include raft snapshot sent.`, + Measurement: "Bytes", + Unit: metric.Unit_BYTES, + } + metaRaftCoalescedHeartbeatsPending = metric.Metadata{ Name: "raft.heartbeats.pending", Help: "Number of pending heartbeats and responses waiting to be coalesced", @@ -2290,11 +2343,17 @@ type StoreMetrics struct { // Raft message metrics. // // An array for conveniently finding the appropriate metric. - RaftRcvdMessages [maxRaftMsgType + 1]*metric.Counter - RaftRcvdDropped *metric.Counter - RaftRcvdDroppedBytes *metric.Counter - RaftRcvdQueuedBytes *metric.Gauge - RaftRcvdSteppedBytes *metric.Counter + RaftRcvdMessages [maxRaftMsgType + 1]*metric.Counter + RaftRcvdDropped *metric.Counter + RaftRcvdDroppedBytes *metric.Counter + RaftRcvdQueuedBytes *metric.Gauge + RaftRcvdSteppedBytes *metric.Counter + RaftRcvdBytes *metric.Counter + RaftRcvdCrossRegionBytes *metric.Counter + RaftRcvdCrossZoneBytes *metric.Counter + RaftSentBytes *metric.Counter + RaftSentCrossRegionBytes *metric.Counter + RaftSentCrossZoneBytes *metric.Counter // Raft log metrics. 
RaftLogFollowerBehindCount *metric.Gauge @@ -2962,10 +3021,16 @@ func newStoreMetrics(histogramWindow time.Duration) *StoreMetrics { raftpb.MsgTransferLeader: metric.NewCounter(metaRaftRcvdTransferLeader), raftpb.MsgTimeoutNow: metric.NewCounter(metaRaftRcvdTimeoutNow), }, - RaftRcvdDropped: metric.NewCounter(metaRaftRcvdDropped), - RaftRcvdDroppedBytes: metric.NewCounter(metaRaftRcvdDroppedBytes), - RaftRcvdQueuedBytes: metric.NewGauge(metaRaftRcvdQueuedBytes), - RaftRcvdSteppedBytes: metric.NewCounter(metaRaftRcvdSteppedBytes), + RaftRcvdDropped: metric.NewCounter(metaRaftRcvdDropped), + RaftRcvdDroppedBytes: metric.NewCounter(metaRaftRcvdDroppedBytes), + RaftRcvdQueuedBytes: metric.NewGauge(metaRaftRcvdQueuedBytes), + RaftRcvdSteppedBytes: metric.NewCounter(metaRaftRcvdSteppedBytes), + RaftRcvdBytes: metric.NewCounter(metaRaftRcvdBytes), + RaftRcvdCrossRegionBytes: metric.NewCounter(metaRaftRcvdCrossRegionBytes), + RaftRcvdCrossZoneBytes: metric.NewCounter(metaRaftRcvdCrossZoneBytes), + RaftSentBytes: metric.NewCounter(metaRaftSentBytes), + RaftSentCrossRegionBytes: metric.NewCounter(metaRaftSentCrossRegionBytes), + RaftSentCrossZoneBytes: metric.NewCounter(metaRaftSentCrossZoneBytes), // Raft log metrics. RaftLogFollowerBehindCount: metric.NewGauge(metaRaftLogFollowerBehindCount), @@ -3272,6 +3337,47 @@ func (sm *StoreMetrics) updateCrossLocalityMetricsOnSnapshotRcvd( } } +// updateCrossLocalityMetricsOnIncomingRaftMsg updates store metrics for raft +// messages that have been received via HandleRaftRequest. In the cases of +// messages containing heartbeats or heartbeat_resps, they capture the byte +// count of requests with coalesced heartbeats before any uncoalescing happens. +// The metrics being updated include 1. total byte count of messages received 2. +// cross-region metrics, which monitor activities across different regions, and +// 3. 
cross-zone metrics, which monitor activities across different zones within +// the same region or in cases where region tiers are not configured. +func (sm *StoreMetrics) updateCrossLocalityMetricsOnIncomingRaftMsg( + comparisonResult roachpb.LocalityComparisonType, msgSize int64, +) { + sm.RaftRcvdBytes.Inc(msgSize) + switch comparisonResult { + case roachpb.LocalityComparisonType_CROSS_REGION: + sm.RaftRcvdCrossRegionBytes.Inc(msgSize) + case roachpb.LocalityComparisonType_SAME_REGION_CROSS_ZONE: + sm.RaftRcvdCrossZoneBytes.Inc(msgSize) + } +} + +// updateCrossLocalityMetricsOnOutgoingRaftMsg updates store metrics for raft +// messages that are about to be sent via raftSendQueue. In the cases of +// messages containing heartbeats or heartbeat_resps, they capture the byte +// count of requests with coalesced heartbeats. The metrics being updated +// include 1. total byte count of messages sent 2. cross-region metrics, which +// monitor activities across different regions, and 3. cross-zone metrics, which +// monitor activities across different zones within the same region or in cases +// where region tiers are not configured. Note that these metrics may include +// messages that get dropped by `SendAsync` due to a full outgoing queue. 
+func (sm *StoreMetrics) updateCrossLocalityMetricsOnOutgoingRaftMsg( + comparisonResult roachpb.LocalityComparisonType, msgSize int64, +) { + sm.RaftSentBytes.Inc(msgSize) + switch comparisonResult { + case roachpb.LocalityComparisonType_CROSS_REGION: + sm.RaftSentCrossRegionBytes.Inc(msgSize) + case roachpb.LocalityComparisonType_SAME_REGION_CROSS_ZONE: + sm.RaftSentCrossZoneBytes.Inc(msgSize) + } +} + func (sm *StoreMetrics) updateEnvStats(stats storage.EnvStats) { sm.EncryptionAlgorithm.Update(int64(stats.EncryptionType)) } diff --git a/pkg/kv/kvserver/raft_transport.go b/pkg/kv/kvserver/raft_transport.go index 8458cf2f8a4f..3d0a1c19f4ef 100644 --- a/pkg/kv/kvserver/raft_transport.go +++ b/pkg/kv/kvserver/raft_transport.go @@ -136,6 +136,22 @@ type IncomingRaftMessageHandler interface { ) *kvserverpb.DelegateSnapshotResponse } +// OutgoingRaftMessageHandler is the interface that must be implemented by +// arguments to RaftTransport.ListenOutgoingMessage. +type OutgoingRaftMessageHandler interface { + // HandleRaftRequestSent is called synchronously for every Raft message right + // before it is sent to raftSendQueue in RaftTransport.SendAsync(). Note that + // the message may not be successfully queued if it gets dropped by SendAsync + // due to a full outgoing queue. + // + // As of now, the only use case of this function is for metrics updates on + // messages sent, which is why it only takes specific properties of the request + // as arguments. But it can be easily extended to take the complete request if + // needed. + HandleRaftRequestSent(ctx context.Context, + fromNodeID roachpb.NodeID, toNodeID roachpb.NodeID, msgSize int64) +} + // RaftTransport handles the rpc messages for raft.
// // The raft transport is asynchronous with respect to the caller, and @@ -163,6 +179,7 @@ type RaftTransport struct { dialer *nodedialer.Dialer incomingMessageHandlers syncutil.IntMap // map[roachpb.StoreID]*IncomingRaftMessageHandler + outgoingMessageHandlers syncutil.IntMap // map[roachpb.StoreID]*OutgoingRaftMessageHandler kvflowControl struct { // Everything nested under this struct is used to return flow tokens @@ -372,6 +389,9 @@ func (t *RaftTransport) queueByteSize() int64 { return size } +// getIncomingRaftMessageHandler returns the registered +// IncomingRaftMessageHandler for the given StoreID. If no handlers are +// registered for the StoreID, it returns (nil, false). func (t *RaftTransport) getIncomingRaftMessageHandler( storeID roachpb.StoreID, ) (IncomingRaftMessageHandler, bool) { @@ -381,6 +401,18 @@ func (t *RaftTransport) getIncomingRaftMessageHandler( return nil, false } +// getOutgoingMessageHandler returns the registered OutgoingRaftMessageHandler +// for the given StoreID. If no handlers are registered for the StoreID, it +// returns (nil, false). +func (t *RaftTransport) getOutgoingMessageHandler( + storeID roachpb.StoreID, +) (OutgoingRaftMessageHandler, bool) { + if value, ok := t.outgoingMessageHandlers.Load(int64(storeID)); ok { + return *(*OutgoingRaftMessageHandler)(value), true + } + return nil, false +} + // handleRaftRequest proxies a request to the listening server interface. func (t *RaftTransport) handleRaftRequest( ctx context.Context, req *kvserverpb.RaftMessageRequest, respStream RaftMessageResponseStream, @@ -593,6 +625,19 @@ func (t *RaftTransport) StopIncomingRaftMessages(storeID roachpb.StoreID) { t.incomingMessageHandlers.Delete(int64(storeID)) } +// ListenOutgoingMessage registers an OutgoingRaftMessageHandler to capture +// messages right before they are sent through the raftSendQueue. 
+func (t *RaftTransport) ListenOutgoingMessage( + storeID roachpb.StoreID, handler OutgoingRaftMessageHandler, +) { + t.outgoingMessageHandlers.Store(int64(storeID), unsafe.Pointer(&handler)) +} + +// StopOutgoingMessage unregisters an OutgoingRaftMessageHandler. +func (t *RaftTransport) StopOutgoingMessage(storeID roachpb.StoreID) { + t.outgoingMessageHandlers.Delete(int64(storeID)) +} + // processQueue opens a Raft client stream and sends messages from the // designated queue (ch) via that stream, exiting when an error is received or // when it idles out. All messages remaining in the queue at that point are @@ -867,6 +912,11 @@ func (t *RaftTransport) SendAsync( // Note: computing the size of the request *before* sending it to the queue, // because the receiver takes ownership of, and can modify it. size := int64(req.Size()) + if outgoingMessageHandler, ok := t.getOutgoingMessageHandler(req.FromReplica.StoreID); ok { + // Capture outgoing Raft messages only when the sender's store has an + // OutgoingRaftMessageHandler registered. 
+ outgoingMessageHandler.HandleRaftRequestSent(context.Background(), req.FromReplica.NodeID, req.ToReplica.NodeID, size) + } select { case q.reqs <- req: q.bytes.Add(size) diff --git a/pkg/kv/kvserver/replica_learner_test.go b/pkg/kv/kvserver/replica_learner_test.go index a9e5f38b20e9..401809b191ec 100644 --- a/pkg/kv/kvserver/replica_learner_test.go +++ b/pkg/kv/kvserver/replica_learner_test.go @@ -42,7 +42,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" - "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" @@ -127,25 +126,11 @@ func getFirstStoreMetric(t *testing.T, s serverutils.TestServerInterface, name s t.Helper() store, err := s.GetStores().(*kvserver.Stores).GetStore(s.GetFirstStoreID()) require.NoError(t, err) - - var c int64 - var found bool - store.Registry().Each(func(n string, v interface{}) { - if name == n { - switch t := v.(type) { - case *metric.Counter: - c = t.Count() - found = true - case *metric.Gauge: - c = t.Value() - found = true - } - } - }) - if !found { - panic(fmt.Sprintf("couldn't find metric %s", name)) + count, err := store.Metrics().GetStoreMetric(name) + if err != nil { + panic(err) } - return c + return count } func TestAddReplicaViaLearner(t *testing.T) { diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index 1f4847e90ce2..59f739bdd1d6 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -1893,7 +1893,7 @@ func (h delayingRaftMessageHandler) HandleRaftRequest( } go func() { time.Sleep(raftDelay) - err := h.IncomingRaftMessageHandler.HandleRaftRequest(ctx, req, respStream) + err := h.IncomingRaftMessageHandler.HandleRaftRequest(context.Background(), req, 
respStream) if err != nil { log.Infof(ctx, "HandleRaftRequest returned err %s", err) } diff --git a/pkg/kv/kvserver/store.go b/pkg/kv/kvserver/store.go index aea8ad626276..c39a5004c8f7 100644 --- a/pkg/kv/kvserver/store.go +++ b/pkg/kv/kvserver/store.go @@ -1024,6 +1024,8 @@ type Store struct { } var _ kv.Sender = &Store{} +var _ IncomingRaftMessageHandler = &Store{} +var _ OutgoingRaftMessageHandler = &Store{} // A StoreConfig encompasses the auxiliary objects and configuration // required to create a store. @@ -2062,6 +2064,7 @@ func (s *Store) Start(ctx context.Context, stopper *stop.Stopper) error { // Start Raft processing goroutines. s.cfg.Transport.ListenIncomingRaftMessages(s.StoreID(), s) + s.cfg.Transport.ListenOutgoingMessage(s.StoreID(), s) s.processRaft(ctx) // Register a callback to unquiesce any ranges with replicas on a diff --git a/pkg/kv/kvserver/store_raft.go b/pkg/kv/kvserver/store_raft.go index da6bfb7d24d9..dde432399287 100644 --- a/pkg/kv/kvserver/store_raft.go +++ b/pkg/kv/kvserver/store_raft.go @@ -269,6 +269,8 @@ func (s *Store) uncoalesceBeats( func (s *Store) HandleRaftRequest( ctx context.Context, req *kvserverpb.RaftMessageRequest, respStream RaftMessageResponseStream, ) *kvpb.Error { + comparisonResult := s.getLocalityComparison(ctx, req.FromReplica.NodeID, req.ToReplica.NodeID) + s.metrics.updateCrossLocalityMetricsOnIncomingRaftMsg(comparisonResult, int64(req.Size())) // NB: unlike the other two IncomingRaftMessageHandler methods implemented by // Store, this one doesn't need to directly run through a Stopper task because // it delegates all work through a raftScheduler, whose workers' lifetimes are @@ -320,6 +322,18 @@ func (s *Store) HandleRaftUncoalescedRequest( return enqueue } +// HandleRaftRequestSent is called to capture outgoing Raft messages just prior +// to their transmission to the raftSendQueue. Note that the message might not +// be successfully queued if it gets dropped by SendAsync due to a full outgoing +// queue. 
Currently, this is only used for metrics update which is why it only +// takes specific properties of the request as arguments. +func (s *Store) HandleRaftRequestSent( + ctx context.Context, fromNodeID roachpb.NodeID, toNodeID roachpb.NodeID, msgSize int64, +) { + comparisonResult := s.getLocalityComparison(ctx, fromNodeID, toNodeID) + s.metrics.updateCrossLocalityMetricsOnOutgoingRaftMsg(comparisonResult, msgSize) +} + // withReplicaForRequest calls the supplied function with the (lazily // initialized) Replica specified in the request. The replica passed to // the function will have its Replica.raftMu locked. @@ -721,6 +735,7 @@ func (s *Store) processRaft(ctx context.Context) { _ = s.stopper.RunAsyncTask(ctx, "coalesced-hb-loop", s.coalescedHeartbeatsLoop) s.stopper.AddCloser(stop.CloserFn(func() { s.cfg.Transport.StopIncomingRaftMessages(s.StoreID()) + s.cfg.Transport.StopOutgoingMessage(s.StoreID()) })) s.syncWaiter.Start(ctx, s.stopper) diff --git a/pkg/kv/kvserver/store_raft_test.go b/pkg/kv/kvserver/store_raft_test.go index f410b0506d54..30217493f4c4 100644 --- a/pkg/kv/kvserver/store_raft_test.go +++ b/pkg/kv/kvserver/store_raft_test.go @@ -13,15 +13,25 @@ package kvserver import ( "context" + "fmt" "math" "testing" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/allocator/storepool" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverpb" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/storage" + "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/mon" + "github.com/cockroachdb/cockroach/pkg/util/stop" + 
"github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/stretchr/testify/require" "go.etcd.io/raft/v3/raftpb" ) @@ -124,3 +134,108 @@ func TestRaftReceiveQueue(t *testing.T) { require.Equal(t, n5, q5.acc.Used()) // we didn't touch q5 } } + +// getMapsDiff returns the difference between the values of corresponding +// metrics in two maps. Assumption: beforeMap and afterMap contain the same set +// of keys. +func getMapsDiff(beforeMap map[string]int64, afterMap map[string]int64) map[string]int64 { + diffMap := make(map[string]int64) + for metricName, beforeValue := range beforeMap { + if v, ok := afterMap[metricName]; ok { + diffMap[metricName] = v - beforeValue + } + } + return diffMap +} + +// TestRaftCrossLocalityMetrics verifies that +// updateCrossLocalityMetricsOn{Incoming|Outgoing}RaftMsg correctly updates +// cross-region, cross-zone byte count metrics for incoming and outgoing raft +// msg. +func TestRaftCrossLocalityMetrics(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + clock := hlc.NewClockForTesting(timeutil.NewManualTime(timeutil.Unix(0, 123))) + cfg := TestStoreConfig(clock) + var stopper *stop.Stopper + stopper, _, _, cfg.StorePool, _ = storepool.CreateTestStorePool(ctx, cfg.Settings, + liveness.TestTimeUntilNodeDead, false, /* deterministic */ + func() int { return 0 }, /* nodeCount */ + livenesspb.NodeLivenessStatus_DEAD) + defer stopper.Stop(ctx) + + // Create a noop store. 
+ node := roachpb.NodeDescriptor{NodeID: roachpb.NodeID(1)} + eng := storage.NewDefaultInMemForTesting() + stopper.AddCloser(eng) + cfg.Transport = NewDummyRaftTransport(cfg.Settings, cfg.AmbientCtx.Tracer) + store := NewStore(ctx, cfg, eng, &node) + store.Ident = &roachpb.StoreIdent{ + ClusterID: uuid.Nil, + StoreID: 1, + NodeID: 1, + } + + const expectedInc = 10 + metricsNames := []string{ + "raft.rcvd.bytes", + "raft.rcvd.cross_region.bytes", + "raft.rcvd.cross_zone.bytes", + "raft.sent.bytes", + "raft.sent.cross_region.bytes", + "raft.sent.cross_zone.bytes"} + for _, tc := range []struct { + crossLocalityType roachpb.LocalityComparisonType + expectedMetricChange [6]int64 + forRequest bool + }{ + {crossLocalityType: roachpb.LocalityComparisonType_CROSS_REGION, + expectedMetricChange: [6]int64{expectedInc, expectedInc, 0, 0, 0, 0}, + forRequest: true, + }, + {crossLocalityType: roachpb.LocalityComparisonType_SAME_REGION_CROSS_ZONE, + expectedMetricChange: [6]int64{expectedInc, 0, expectedInc, 0, 0, 0}, + forRequest: true, + }, + {crossLocalityType: roachpb.LocalityComparisonType_SAME_REGION_SAME_ZONE, + expectedMetricChange: [6]int64{expectedInc, 0, 0, 0, 0, 0}, + forRequest: true, + }, + {crossLocalityType: roachpb.LocalityComparisonType_CROSS_REGION, + expectedMetricChange: [6]int64{0, 0, 0, expectedInc, expectedInc, 0}, + forRequest: false, + }, + {crossLocalityType: roachpb.LocalityComparisonType_SAME_REGION_CROSS_ZONE, + expectedMetricChange: [6]int64{0, 0, 0, expectedInc, 0, expectedInc}, + forRequest: false, + }, + {crossLocalityType: roachpb.LocalityComparisonType_SAME_REGION_SAME_ZONE, + expectedMetricChange: [6]int64{0, 0, 0, expectedInc, 0, 0}, + forRequest: false, + }, + } { + t.Run(fmt.Sprintf("%-v", tc.crossLocalityType), func(t *testing.T) { + beforeMetrics, metricsErr := store.metrics.GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + if tc.forRequest { + 
store.Metrics().updateCrossLocalityMetricsOnIncomingRaftMsg(tc.crossLocalityType, expectedInc) + } else { + store.Metrics().updateCrossLocalityMetricsOnOutgoingRaftMsg(tc.crossLocalityType, expectedInc) + } + + afterMetrics, metricsErr := store.metrics.GetStoreMetrics(metricsNames) + if metricsErr != nil { + t.Error(metricsErr) + } + metricsDiff := getMapsDiff(beforeMetrics, afterMetrics) + expectedDiff := make(map[string]int64, 6) + for i, inc := range tc.expectedMetricChange { + expectedDiff[metricsNames[i]] = inc + } + require.Equal(t, metricsDiff, expectedDiff) + }) + } +}