From 33faea8c2afbb06d1d04aebf0d75bca101a0c6d8 Mon Sep 17 00:00:00 2001
From: Arvind Bright <arvind.bright100@gmail.com>
Date: Wed, 29 May 2024 16:50:05 -0700
Subject: [PATCH] ringhash: fix normalizeWeights (#7156)

---
 xds/internal/balancer/ringhash/ring.go        | 43 +++++++++++--------
 xds/internal/balancer/ringhash/ring_test.go   |  6 +--
 .../balancer/ringhash/ringhash_test.go        | 36 ++++++++++++++--
 3 files changed, 61 insertions(+), 24 deletions(-)

diff --git a/xds/internal/balancer/ringhash/ring.go b/xds/internal/balancer/ringhash/ring.go
index 4d7fdb35e722..eac89b5b4d05 100644
--- a/xds/internal/balancer/ringhash/ring.go
+++ b/xds/internal/balancer/ringhash/ring.go
@@ -116,30 +116,37 @@ func newRing(subConns *resolver.AddressMap, minRingSize, maxRingSize uint64, log
 	return &ring{items: items}
 }
 
-// normalizeWeights divides all the weights by the sum, so that the total weight
-// is 1.
+// normalizeWeights calculates the normalized weights for each subConn in the
+// given subConns map. It returns a slice of subConnWithWeight structs, where
+// each struct contains a subConn and its corresponding weight. The function
+// also returns the minimum weight among all subConns.
+//
+// The normalized weight of each subConn is calculated by dividing its weight
+// attribute by the sum of all subConn weights. If the weight attribute is not
+// found on the address, a default weight of 1 is used.
+//
+// The addresses are sorted in ascending order to ensure consistent results.
 //
 // Must be called with a non-empty subConns map.
 func normalizeWeights(subConns *resolver.AddressMap) ([]subConnWithWeight, float64) {
 	var weightSum uint32
-	keys := subConns.Keys()
-	for _, a := range keys {
-		weightSum += getWeightAttribute(a)
+	// Since attributes are explicitly ignored in the AddressMap key, we need to
+	// iterate over the values to get the weights.
+	scVals := subConns.Values()
+	for _, a := range scVals {
+		weightSum += a.(*subConn).weight
 	}
-	ret := make([]subConnWithWeight, 0, len(keys))
-	min := float64(1.0)
-	for _, a := range keys {
-		v, _ := subConns.Get(a)
-		scInfo := v.(*subConn)
-		// getWeightAttribute() returns 1 if the weight attribute is not found
-		// on the address. And since this function is guaranteed to be called
-		// with a non-empty subConns map, weightSum is guaranteed to be
-		// non-zero. So, we need not worry about divide a by zero error here.
-		nw := float64(getWeightAttribute(a)) / float64(weightSum)
+	ret := make([]subConnWithWeight, 0, subConns.Len())
+	min := 1.0
+	for _, a := range scVals {
+		scInfo := a.(*subConn)
+		// (*subConn).weight is set to 1 if the weight attribute is not found on
+		// the address. And since this function is guaranteed to be called with
+		// a non-empty subConns map, weightSum is guaranteed to be non-zero. So,
+		// we need not worry about divide by zero error here.
+		nw := float64(scInfo.weight) / float64(weightSum)
 		ret = append(ret, subConnWithWeight{sc: scInfo, weight: nw})
-		if nw < min {
-			min = nw
-		}
+		min = math.Min(min, nw)
 	}
 	// Sort the addresses to return consistent results.
 	//
diff --git a/xds/internal/balancer/ringhash/ring_test.go b/xds/internal/balancer/ringhash/ring_test.go
index 9c6eb0c242ff..1c3a1985b964 100644
--- a/xds/internal/balancer/ringhash/ring_test.go
+++ b/xds/internal/balancer/ringhash/ring_test.go
@@ -38,9 +38,9 @@ func init() {
 		testAddr("c", 4),
 	}
 	testSubConnMap = resolver.NewAddressMap()
-	testSubConnMap.Set(testAddrs[0], &subConn{addr: "a"})
-	testSubConnMap.Set(testAddrs[1], &subConn{addr: "b"})
-	testSubConnMap.Set(testAddrs[2], &subConn{addr: "c"})
+	testSubConnMap.Set(testAddrs[0], &subConn{addr: "a", weight: 3})
+	testSubConnMap.Set(testAddrs[1], &subConn{addr: "b", weight: 3})
+	testSubConnMap.Set(testAddrs[2], &subConn{addr: "c", weight: 4})
 }
 
 func testAddr(addr string, weight uint32) resolver.Address {
diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go
index a1edfe5d228a..f6778d832f8c 100644
--- a/xds/internal/balancer/ringhash/ringhash_test.go
+++ b/xds/internal/balancer/ringhash/ringhash_test.go
@@ -344,17 +344,26 @@ func (s) TestThreeSubConnsAffinityMultiple(t *testing.T) {
 	}
 }
 
+// TestAddrWeightChange covers the following scenarios after setting up the
+// balancer with 3 addresses [A, B, C]:
+//   - updates balancer with [A, B, C], a new Picker should not be sent.
+//   - updates balancer with [A, B] (C removed), a new Picker is sent and the
+//     ring is updated.
+//   - updates balancer with [A, B], but B has a weight of 2, a new Picker is
+//     sent.  And the new ring should contain the correct number of entries
+//     and weights.
 func (s) TestAddrWeightChange(t *testing.T) {
-	wantAddrs := []resolver.Address{
+	addrs := []resolver.Address{
 		{Addr: testBackendAddrStrs[0]},
 		{Addr: testBackendAddrStrs[1]},
 		{Addr: testBackendAddrStrs[2]},
 	}
-	cc, b, p0 := setupTest(t, wantAddrs)
+	cc, b, p0 := setupTest(t, addrs)
 	ring0 := p0.(*picker).ring
 
+	// Update with the same addresses, should not send a new Picker.
 	if err := b.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState:  resolver.State{Addresses: wantAddrs},
+		ResolverState:  resolver.State{Addresses: addrs},
 		BalancerConfig: testConfig,
 	}); err != nil {
 		t.Fatalf("UpdateClientConnState returned err: %v", err)
@@ -407,6 +416,27 @@ func (s) TestAddrWeightChange(t *testing.T) {
 	if p2.(*picker).ring == ring1 {
 		t.Fatalf("new picker after changing address weight has the same ring as before, want different")
 	}
+	// With the new update, the ring must look like this:
+	//   [
+	//     {idx:0 sc: {addr: testBackendAddrStrs[0], weight: 1}},
+	//     {idx:1 sc: {addr: testBackendAddrStrs[1], weight: 2}},
+	//     {idx:2 sc: {addr: testBackendAddrStrs[2], weight: 2}},
+	//   ].
+	if len(p2.(*picker).ring.items) != 3 {
+		t.Fatalf("new picker after changing address weight has %d entries, want 3", len(p2.(*picker).ring.items))
+	}
+	for _, i := range p2.(*picker).ring.items {
+		if i.sc.addr == testBackendAddrStrs[0] {
+			if i.sc.weight != 1 {
+				t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.sc.weight, i.sc.addr)
+			}
+		}
+		if i.sc.addr == testBackendAddrStrs[1] {
+			if i.sc.weight != 2 {
+				t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.sc.weight, i.sc.addr)
+			}
+		}
+	}
 }
 
 // TestSubConnToConnectWhenOverallTransientFailure covers the situation when the