From ab3298b48ea3d204a5179e8d68d03c5924f44efe Mon Sep 17 00:00:00 2001
From: Jayant Shrivastava <jayants@cockroachlabs.com>
Date: Wed, 29 Nov 2023 16:56:29 -0500
Subject: [PATCH] changefeedccl: reduce rebalancing memory usage from O(ranges)
 to O(spans)

Previously, `rebalanceSpanPartitions` used O(ranges) memory because it expanded
every span into its constituent ranges. This change rewrites it to use range
iterators, reducing memory usage to O(spans).

This change also adds a randomized test asserting that all spans are accounted
for after rebalancing, as well as one more unit test.

Informs: #113898
Epic: None
---
 pkg/ccl/changefeedccl/BUILD.bazel             |   1 +
 pkg/ccl/changefeedccl/changefeed_dist.go      | 168 ++++++++++------
 pkg/ccl/changefeedccl/changefeed_dist_test.go | 182 ++++++++++++++----
 3 files changed, 259 insertions(+), 92 deletions(-)

diff --git a/pkg/ccl/changefeedccl/BUILD.bazel b/pkg/ccl/changefeedccl/BUILD.bazel
index 0fa3d9dafb0f..1c5ec6c8f074 100644
--- a/pkg/ccl/changefeedccl/BUILD.bazel
+++ b/pkg/ccl/changefeedccl/BUILD.bazel
@@ -299,6 +299,7 @@ go_test(
         "//pkg/util/ctxgroup",
         "//pkg/util/encoding",
         "//pkg/util/hlc",
+        "//pkg/util/intsets",
         "//pkg/util/json",
         "//pkg/util/leaktest",
         "//pkg/util/log",
diff --git a/pkg/ccl/changefeedccl/changefeed_dist.go b/pkg/ccl/changefeedccl/changefeed_dist.go
index c25b5bcacbd4..79863e53a570 100644
--- a/pkg/ccl/changefeedccl/changefeed_dist.go
+++ b/pkg/ccl/changefeedccl/changefeed_dist.go
@@ -10,12 +10,14 @@ package changefeedccl
 
 import (
 	"context"
+	"math"
 	"sort"
 
 	"github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdceval"
 	"github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/changefeedbase"
 	"github.com/cockroachdb/cockroach/pkg/jobs/jobspb"
 	"github.com/cockroachdb/cockroach/pkg/jobs/jobsprofiler"
+	"github.com/cockroachdb/cockroach/pkg/keys"
 	"github.com/cockroachdb/cockroach/pkg/kv"
 	"github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvcoord"
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
@@ -400,9 +402,9 @@ func makePlan(
 			}
 			sender := execCtx.ExecCfg().DB.NonTransactionalSender()
 			distSender := sender.(*kv.CrossRangeTxnWrapperSender).Wrapped().(*kvcoord.DistSender)
-
+			ri := kvcoord.MakeRangeIterator(distSender)
 			spanPartitions, err = rebalanceSpanPartitions(
-				ctx, &distResolver{distSender}, rebalanceThreshold.Get(sv), spanPartitions)
+				ctx, &ri, rebalanceThreshold.Get(sv), spanPartitions)
 			if err != nil {
 				return nil, nil, err
 			}
@@ -549,6 +551,7 @@ func (w *changefeedResultWriter) Err() error {
 	return w.err
 }
 
+// TODO(#120427): improve this to be more useful.
 var rebalanceThreshold = settings.RegisterFloatSetting(
 	settings.ApplicationLevel,
 	"changefeed.balance_range_distribution.sensitivity",
@@ -557,80 +560,127 @@ var rebalanceThreshold = settings.RegisterFloatSetting(
 	settings.PositiveFloat,
 )
 
-type rangeResolver interface {
-	getRangesForSpans(ctx context.Context, spans []roachpb.Span) ([]roachpb.Span, error)
-}
-
-type distResolver struct {
-	*kvcoord.DistSender
+type rangeIterator interface {
+	Desc() *roachpb.RangeDescriptor
+	NeedAnother(rs roachpb.RSpan) bool
+	Valid() bool
+	Error() error
+	Next(ctx context.Context)
+	Seek(ctx context.Context, key roachpb.RKey, scanDir kvcoord.ScanDirection)
 }
 
-func (r *distResolver) getRangesForSpans(
-	ctx context.Context, spans []roachpb.Span,
-) ([]roachpb.Span, error) {
-	spans, _, err := r.DistSender.AllRangeSpans(ctx, spans)
-	return spans, err
+// rebalancingPartition is a container used to store a partition undergoing
+// rebalancing.
+type rebalancingPartition struct {
+	// These fields store the current number of ranges and spans in this partition.
+	// They are initialized corresponding to the sql.SpanPartition partition below
+	// and mutated during rebalancing.
+	numRanges int
+	group     roachpb.SpanGroup
+
+	// The original span partition corresponding to this bucket and its
+	// index in the original []sql.SpanPartition.
+	part sql.SpanPartition
+	pIdx int
 }
 
 func rebalanceSpanPartitions(
-	ctx context.Context, r rangeResolver, sensitivity float64, p []sql.SpanPartition,
+	ctx context.Context, ri rangeIterator, sensitivity float64, partitions []sql.SpanPartition,
 ) ([]sql.SpanPartition, error) {
-	if len(p) <= 1 {
-		return p, nil
+	if len(partitions) <= 1 {
+		return partitions, nil
 	}
 
-	// Explode set of spans into set of ranges.
-	// TODO(yevgeniy): This might not be great if the tables are huge.
-	numRanges := 0
-	for i := range p {
-		spans, err := r.getRangesForSpans(ctx, p[i].Spans)
-		if err != nil {
-			return nil, err
+	// Create partition builder structs for the partitions array above.
+	var builders = make([]rebalancingPartition, len(partitions))
+	var totalRanges int
+	for i, p := range partitions {
+		builders[i].part = p
+		builders[i].pIdx = i
+		nRanges, ok := p.NumRanges()
+		// We cannot rebalance if we're missing range information.
+		if !ok {
+			log.Warning(ctx, "skipping rebalance due to missing range info")
+			return partitions, nil
 		}
-		p[i].Spans = spans
-		numRanges += len(spans)
+		builders[i].numRanges = nRanges
+		totalRanges += nRanges
+		builders[i].group.Add(p.Spans...)
 	}
 
 	// Sort descending based on the number of ranges.
-	sort.Slice(p, func(i, j int) bool {
-		return len(p[i].Spans) > len(p[j].Spans)
+	sort.Slice(builders, func(i, j int) bool {
+		return builders[i].numRanges > builders[j].numRanges
 	})
 
-	targetRanges := int((1 + sensitivity) * float64(numRanges) / float64(len(p)))
-
-	for i, j := 0, len(p)-1; i < j && len(p[i].Spans) > targetRanges && len(p[j].Spans) < targetRanges; {
-		from, to := i, j
-
-		// Figure out how many ranges we can move.
-		numToMove := len(p[from].Spans) - targetRanges
-		canMove := targetRanges - len(p[to].Spans)
-		if numToMove <= canMove {
-			i++
-		}
-		if canMove <= numToMove {
-			numToMove = canMove
-			j--
-		}
-		if numToMove == 0 {
-			break
+	targetRanges := int(math.Ceil((1 + sensitivity) * float64(totalRanges) / float64(len(partitions))))
+	to := len(builders) - 1
+	from := 0
+
+	// In each iteration of the outer loop, check if `from` has too many ranges.
+	// If so, move them to other partitions which need more ranges
+	// starting from `to` and moving down. Otherwise, increment `from` and check
+	// again.
+	for ; from < to && builders[from].numRanges > targetRanges; from++ {
+		// numToMove is the number of ranges which need to be moved out of `from`
+		// to other partitions.
+		numToMove := builders[from].numRanges - targetRanges
+		count := 0
+		needMore := func() bool {
+			return count < numToMove
 		}
+		// Iterate over all the spans in `from`.
+		for spanIdx := 0; from < to && needMore() && spanIdx < len(builders[from].part.Spans); spanIdx++ {
+			sp := builders[from].part.Spans[spanIdx]
+			rSpan, err := keys.SpanAddr(sp)
+			if err != nil {
+				return nil, err
+			}
+			// Iterate over the ranges in the current span.
+			for ri.Seek(ctx, rSpan.Key, kvcoord.Ascending); from < to && needMore(); ri.Next(ctx) {
+				// Error check.
+				if !ri.Valid() {
+					return nil, ri.Error()
+				}
 
-		// Move numToMove spans from 'from' to 'to'.
-		idx := len(p[from].Spans) - numToMove
-		p[to].Spans = append(p[to].Spans, p[from].Spans[idx:]...)
-		p[from].Spans = p[from].Spans[:idx]
+				// Move one range from `from` to `to`.
+				count += 1
+				builders[from].numRanges -= 1
+				builders[to].numRanges += 1
+				// If the range boundaries are outside the original span, trim
+				// the range.
+				startKey := ri.Desc().StartKey
+				if startKey.Compare(rSpan.Key) == -1 {
+					startKey = rSpan.Key
+				}
+				endKey := ri.Desc().EndKey
+				if endKey.Compare(rSpan.EndKey) == 1 {
+					endKey = rSpan.EndKey
+				}
+				diff := roachpb.Span{
+					Key: startKey.AsRawKey(), EndKey: endKey.AsRawKey(),
+				}
+				builders[from].group.Sub(diff)
+				builders[to].group.Add(diff)
+
+				// Since we moved a range, `to` may have enough ranges.
+				// Decrement `to` until we find a new partition which needs more
+				// ranges.
+				for from < to && builders[to].numRanges >= targetRanges {
+					to--
+				}
+				// No more ranges in this span.
+				if !ri.NeedAnother(rSpan) {
+					break
+				}
+			}
+		}
 	}
 
-	// Collapse ranges into nice set of contiguous spans.
-	for i := range p {
-		var g roachpb.SpanGroup
-		g.Add(p[i].Spans...)
-		p[i].Spans = g.Slice()
+	// Overwrite the original partitions slice with the balanced partitions.
+	for _, b := range builders {
+		partitions[b.pIdx] = sql.MakeSpanPartitionWithRangeCount(
+			b.part.SQLInstanceID, b.group.Slice(), b.numRanges)
 	}
-
-	// Finally, re-sort based on the node id.
-	sort.Slice(p, func(i, j int) bool {
-		return p[i].SQLInstanceID < p[j].SQLInstanceID
-	})
-	return p, nil
+	return partitions, nil
 }
diff --git a/pkg/ccl/changefeedccl/changefeed_dist_test.go b/pkg/ccl/changefeedccl/changefeed_dist_test.go
index 656802bfa14c..d8f04a26813b 100644
--- a/pkg/ccl/changefeedccl/changefeed_dist_test.go
+++ b/pkg/ccl/changefeedccl/changefeed_dist_test.go
@@ -13,6 +13,7 @@ import (
 	"fmt"
 	"math"
 	"reflect"
+	"sort"
 	"strings"
 	"testing"
 
@@ -30,18 +31,70 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/testutils/skip"
 	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
 	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
+	"github.com/cockroachdb/cockroach/pkg/util/intsets"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/cockroachdb/cockroach/pkg/util/randutil"
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/require"
 )
 
+// mockRangeIterator iterates over ranges in a span assuming that each range
+// contains one character.
+type mockRangeIterator struct {
+	rangeDesc *roachpb.RangeDescriptor
+}
+
+var _ rangeIterator = (*mockRangeIterator)(nil)
+
+func nextKey(startKey []byte) []byte {
+	return []byte{startKey[0] + 1}
+}
+
+// Desc implements the rangeIterator interface.
+func (ri *mockRangeIterator) Desc() *roachpb.RangeDescriptor {
+	return ri.rangeDesc
+}
+
+// NeedAnother implements the rangeIterator interface.
+func (ri *mockRangeIterator) NeedAnother(rs roachpb.RSpan) bool {
+	return ri.rangeDesc.EndKey.Less(rs.EndKey)
+}
+
+// Valid implements the rangeIterator interface.
+func (ri *mockRangeIterator) Valid() bool {
+	return true
+}
+
+// Error implements the rangeIterator interface.
+func (ri *mockRangeIterator) Error() error {
+	panic("unexpected call to Error()")
+}
+
+// Next implements the rangeIterator interface.
+func (ri *mockRangeIterator) Next(ctx context.Context) {
+	ri.rangeDesc.StartKey = nextKey(ri.rangeDesc.StartKey)
+	ri.rangeDesc.EndKey = nextKey(ri.rangeDesc.EndKey)
+}
+
+// Seek implements the rangeIterator interface.
+func (ri *mockRangeIterator) Seek(_ context.Context, key roachpb.RKey, _ kvcoord.ScanDirection) {
+	ri.rangeDesc = &roachpb.RangeDescriptor{
+		StartKey: key,
+		EndKey:   nextKey(key),
+	}
+}
+
 var partitions = func(p ...sql.SpanPartition) []sql.SpanPartition {
 	return p
 }
 
 var mkPart = func(n base.SQLInstanceID, spans ...roachpb.Span) sql.SpanPartition {
-	return sql.SpanPartition{SQLInstanceID: n, Spans: spans}
+	var count int
+	for _, sp := range spans {
+		count += int(rune(sp.EndKey[0]) - rune(sp.Key[0]))
+	}
+	return sql.MakeSpanPartitionWithRangeCount(n, spans, count)
 }
 
 // mkRange makes a range containing a single rune.
@@ -67,30 +120,19 @@ var mkSingleLetterRanges = func(start, end rune) (result []roachpb.Span) {
 	return result
 }
 
-// letterRangeResolver resolves spans such that each letter is a range.
-type letterRangeResolver struct{}
-
-func (r *letterRangeResolver) getRangesForSpans(
-	_ context.Context, inSpans []roachpb.Span,
-) (spans []roachpb.Span, _ error) {
-	for _, sp := range inSpans {
-		spans = append(spans, mkSingleLetterRanges(rune(sp.Key[0]), rune(sp.EndKey[0]))...)
-	}
-	return spans, nil
-}
-
 // TestPartitionSpans unit tests the rebalanceSpanPartitions function.
 func TestPartitionSpans(t *testing.T) {
 	defer leaktest.AfterTest(t)()
-	const sensitivity = 0.01
+	defer log.Scope(t).Close(t)
 
 	// 26 nodes, 1 range per node.
 	make26NodesBalanced := func() (p []sql.SpanPartition) {
 		for i := rune(0); i < 26; i += 1 {
-			p = append(p, sql.SpanPartition{
-				SQLInstanceID: base.SQLInstanceID(i + 1),
-				Spans:         []roachpb.Span{mkRange('a' + i)},
-			})
+			p = append(p, sql.MakeSpanPartitionWithRangeCount(
+				base.SQLInstanceID(i+1),
+				[]roachpb.Span{mkRange('z' - i)},
+				1,
+			))
 		}
 		return p
 	}
@@ -98,13 +140,13 @@ func TestPartitionSpans(t *testing.T) {
 	// 26 nodes. All empty except for the first, which has 26 ranges.
 	make26NodesImBalanced := func() (p []sql.SpanPartition) {
 		for i := rune(0); i < 26; i += 1 {
-			sp := sql.SpanPartition{
-				SQLInstanceID: base.SQLInstanceID(i + 1),
-			}
 			if i == 0 {
-				sp.Spans = append(sp.Spans, mkSpan('a', 'z'+1))
+				p = append(p, sql.MakeSpanPartitionWithRangeCount(
+					base.SQLInstanceID(i+1), []roachpb.Span{mkSpan('a', 'z'+1)}, 26))
+			} else {
+				p = append(p, sql.MakeSpanPartitionWithRangeCount(base.SQLInstanceID(i+1), []roachpb.Span{}, 0))
 			}
-			p = append(p, sp)
+
 		}
 		return p
 	}
@@ -122,9 +164,9 @@ func TestPartitionSpans(t *testing.T) {
 				mkPart(3, mkSpan('q', 'z'+1)), // 10
 			),
 			expect: partitions(
-				mkPart(1, mkSpan('a', 'j')),               // 9
-				mkPart(2, mkSpan('j', 'q'), mkRange('z')), // 8
-				mkPart(3, mkSpan('q', 'z')),               // 9
+				mkPart(1, mkSpan('a', 'j')),   // 9
+				mkPart(2, mkSpan('j', 'r')),   // 8
+				mkPart(3, mkSpan('r', 'z'+1)), // 9
 			),
 		},
 		{
@@ -135,9 +177,9 @@ func TestPartitionSpans(t *testing.T) {
 				mkPart(3, mkSpan('c', 'e'), mkSpan('p', 'r')), // 4
 			),
 			expect: partitions(
-				mkPart(1, mkSpan('a', 'c'), mkSpan('e', 'l')), // 9
-				mkPart(2, mkSpan('r', 'z')),                   // 8
-				mkPart(3, mkSpan('c', 'e'), mkSpan('l', 'r')), // 8
+				mkPart(1, mkSpan('o', 'p'), mkSpan('r', 'z')),                   // 9
+				mkPart(2, mkSpan('a', 'c'), mkSpan('e', 'l')),                   // 9
+				mkPart(3, mkSpan('c', 'e'), mkSpan('l', 'o'), mkSpan('p', 'r')), // 7
 			),
 		},
 		{
@@ -148,9 +190,9 @@ func TestPartitionSpans(t *testing.T) {
 				mkPart(3, mkRange('z')),                      // 1
 			),
 			expect: partitions(
-				mkPart(1, mkSpan('a', 'k')),                   // 10
-				mkPart(2, mkSpan('k', 'r'), mkSpan('y', 'z')), // 8
-				mkPart(3, mkSpan('r', 'y'), mkRange('z')),     // 7
+				mkPart(1, mkSpan('p', 'y')),                   // 9
+				mkPart(2, mkSpan('i', 'p'), mkSpan('y', 'z')), // 8
+				mkPart(3, mkSpan('a', 'i'), mkRange('z')),     // 9
 			),
 		},
 		{
@@ -190,7 +232,7 @@ func TestPartitionSpans(t *testing.T) {
 	} {
 		t.Run(tc.name, func(t *testing.T) {
 			sp, err := rebalanceSpanPartitions(context.Background(),
-				&letterRangeResolver{}, sensitivity, tc.input)
+				&mockRangeIterator{}, 0.00, tc.input)
 			t.Log("expected partitions")
 			for _, p := range tc.expect {
 				t.Log(p)
@@ -203,6 +245,80 @@ func TestPartitionSpans(t *testing.T) {
 			require.Equal(t, tc.expect, sp)
 		})
 	}
+
+	dedupe := func(in []int) []int {
+		ret := intsets.Fast{}
+		for _, id := range in {
+			ret.Add(id)
+		}
+		return ret.Ordered()
+	}
+	copySpans := func(partitions []sql.SpanPartition) (g roachpb.SpanGroup) {
+		for _, p := range partitions {
+			for _, sp := range p.Spans {
+				g.Add(sp)
+			}
+		}
+		return
+	}
+	// Create a random input and assert that the output has the same
+	// spans as the input.
+	t.Run("random", func(t *testing.T) {
+		rng, _ := randutil.NewTestRand()
+		numPartitions := rng.Intn(8) + 1
+		numSpans := rng.Intn(25) + 1
+
+		// Randomly create spans and assign them to nodes. For example,
+		// {1 {h-i}, {m-n}, {t-u}}
+		// {2 {a-c}, {d-f}, {l-m}, {s-t}, {x-z}}
+		// {3 {c-d}, {i-j}, {u-w}}
+		// {4 {w-x}}
+		// {5 {f-h}, {p-s}}
+		// {6 {j-k}, {k-l}, {n-o}, {o-p}}
+
+		// First, select some indexes in ['a' ... 'z'] to partition at.
+		spanIdxs := make([]int, numSpans)
+		for i := range spanIdxs {
+			spanIdxs[i] = rng.Intn((int('z')-int('a'))-1) + int('a') + 1
+		}
+		sort.Slice(spanIdxs, func(i int, j int) bool {
+			return spanIdxs[i] < spanIdxs[j]
+		})
+		// Make sure indexes are unique.
+		spanIdxs = dedupe(spanIdxs)
+
+		// Generate spans and assign them randomly to partitions.
+		input := make([]sql.SpanPartition, numPartitions)
+		for i, key := range spanIdxs {
+			assignTo := rng.Intn(numPartitions)
+			if i == 0 {
+				input[assignTo].Spans = append(input[assignTo].Spans, mkSpan('a', (rune(key))))
+			} else {
+				input[assignTo].Spans = append(input[assignTo].Spans, mkSpan((rune(spanIdxs[i-1])), rune(key)))
+			}
+		}
+		last := rng.Intn(numPartitions)
+		input[last].Spans = append(input[last].Spans, mkSpan(rune(spanIdxs[len(spanIdxs)-1]), 'z'))
+
+		// Populate the remaining fields in the partitions.
+		for i := range input {
+			input[i] = mkPart(base.SQLInstanceID(i+1), input[i].Spans...)
+		}
+
+		t.Log(input)
+
+		// Ensure the set of input spans matches the set of output spans.
+		g1 := copySpans(input)
+		output, err := rebalanceSpanPartitions(context.Background(),
+			&mockRangeIterator{}, 0.00, input)
+		require.NoError(t, err)
+
+		t.Log(output)
+
+		g2 := copySpans(output)
+		require.True(t, g1.Encloses(g2.Slice()...))
+		require.True(t, g2.Encloses(g1.Slice()...))
+	})
 }
 
 type rangeDistributionTester struct {