From ff1ba6cf22672b726a122822aaea73df50e61fce Mon Sep 17 00:00:00 2001 From: sado Date: Tue, 11 Apr 2023 11:02:05 +0800 Subject: [PATCH] finish consistent hash --- kit/hash/consistent.go | 144 +++++++++++++++++++++++++++++++++++- kit/hash/consistent_test.go | 96 ++++++++++++++++++++++++ 2 files changed, 237 insertions(+), 3 deletions(-) create mode 100644 kit/hash/consistent_test.go diff --git a/kit/hash/consistent.go b/kit/hash/consistent.go index a9176e4..802090b 100644 --- a/kit/hash/consistent.go +++ b/kit/hash/consistent.go @@ -1,6 +1,9 @@ package hash import ( + "fmt" + "sort" + "strconv" "sync" "github.com/spaolacci/murmur3" @@ -15,8 +18,10 @@ const ( type ( Consistent interface { Add(node interface{}, opts ...ConsistentAddWith) - Get(v interface{}) (node interface{}, has bool) + Get(node interface{}) (value interface{}, has bool) Remove(node interface{}) + // Cascade 多次hash排序后的hash数组 + Cascade(node interface{}, opts ...ConsistentAddWith) []uint64 } consistent struct { @@ -88,17 +93,150 @@ func NewConsistent(withs ...ConsistentWith) Consistent { dft.hash = murmur3.Sum64 } - return &consistent{opt: dft} + return &consistent{ + opt: dft, + vtrNum: dft.vtr, + vtrKeys: make([]uint64, 0), + vtrRing: make(map[uint64][]interface{}), + nodes: make(map[string]struct{}), + } +} + +func (c *consistent) Cascade(node interface{}, opts ...ConsistentAddWith) []uint64 { + dft := &optionAdd{vtr: c.vtrNum, weight: 0} + for i := range opts { + opts[i](dft) + } + + vtr := c.vtrNum + if dft.weight > 0 { + vtr = c.vtrNum * dft.weight / maxWeight + } + + if vtr > c.vtrNum { + vtr = c.vtrNum + } + + nodeExpr := c.marshal(node) + c.lock.Lock() + defer c.lock.Unlock() + + keys := make([]uint64, 0) + for i := int64(0); i < vtr; i++ { + hashV := c.opt.hash([]byte(nodeExpr + strconv.Itoa(int(i)))) + keys = append(keys, hashV) + } + + sort.Slice(keys, func(i, j int) bool { + return keys[i] < keys[j] + }) + + return keys } func (c *consistent) Add(node interface{}, opts ...ConsistentAddWith) { + dft := &optionAdd{vtr: c.vtrNum, weight: 0} + for i := range opts { + opts[i](dft) + } + + vtr := c.vtrNum + if dft.weight > 0 { + vtr = c.vtrNum * dft.weight / maxWeight + } + + if vtr > c.vtrNum { + vtr = c.vtrNum + } + nodeExpr := c.marshal(node) + c.lock.Lock() + defer c.lock.Unlock() + c.addNode(nodeExpr) + + for i := int64(0); i < vtr; i++ { + hashV := c.opt.hash([]byte(nodeExpr + strconv.Itoa(int(i)))) + c.vtrKeys = append(c.vtrKeys, hashV) + c.vtrRing[hashV] = append(c.vtrRing[hashV], node) + } + + sort.Slice(c.vtrKeys, func(i, j int) bool { + return c.vtrKeys[i] < c.vtrKeys[j] + }) } -func (c *consistent) Get(v interface{}) (node interface{}, has bool) { +func (c *consistent) Get(node interface{}) (value interface{}, has bool) { + c.lock.Lock() + defer c.lock.Unlock() + + if len(c.vtrRing) == 0 { + return nil, false + } + nodeExpr := c.marshal(node) + hashV := c.opt.hash([]byte(nodeExpr)) + index := sort.Search(len(c.vtrKeys), func(i int) bool { + return c.vtrKeys[i] >= hashV + }) % len(c.vtrKeys) + + nodes := c.vtrRing[c.vtrKeys[index]] + switch len(nodes) { + case 0: + return nil, false + case 1: + return nodes[0], true + default: + index := c.opt.hash([]byte(c.innerMarshal(node))) + pos := int(index % uint64(len(nodes))) + return nodes[pos], true + } } func (c *consistent) Remove(node interface{}) { + nodeExpr := c.marshal(node) + + c.lock.Lock() + defer c.lock.Unlock() + + if _, ok := c.nodes[nodeExpr]; !ok { + return + } + + for i := int64(0); i < c.vtrNum; i++ { + hashV := c.opt.hash([]byte(nodeExpr + strconv.Itoa(int(i)))) + index := sort.Search(len(c.vtrKeys), func(i int) bool { + return c.vtrKeys[i] >= hashV + }) + if index < len(c.vtrKeys) && c.vtrKeys[index] == hashV { + c.vtrKeys = append(c.vtrKeys[:index], c.vtrKeys[index+1:]...) + } + + if _, ok := c.vtrRing[hashV]; ok { + newNodes := c.vtrRing[hashV][:0] + for _, node := range c.vtrRing[hashV] { + if c.marshal(node) != nodeExpr { + newNodes = append(newNodes, node) + } + } + if len(newNodes) > 0 { + c.vtrRing[hashV] = newNodes + } else { + delete(c.vtrRing, hashV) + } + } + } + + delete(c.nodes, nodeExpr) +} + +func (c *consistent) addNode(key string) { + c.nodes[key] = struct{}{} +} + +func (c *consistent) marshal(v interface{}) string { + return fmt.Sprintf("%v", v) +} +func (c *consistent) innerMarshal(node interface{}) string { + return fmt.Sprintf("%d:%v", prime, node) } diff --git a/kit/hash/consistent_test.go b/kit/hash/consistent_test.go new file mode 100644 index 0000000..360226c --- /dev/null +++ b/kit/hash/consistent_test.go @@ -0,0 +1,96 @@ +package hash + +import ( + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewConsistent(t *testing.T) { + assert.NotPanics(t, func() { + NewConsistent() + }) + + assert.NotPanics(t, func() { + NewConsistent(ConsistentWithHash(nil)) + }) + + assert.NotPanics(t, func() { + NewConsistent(ConsistentWithVtr(-1)) + }) +} + +func Test_Consistent_Get(t *testing.T) { + ch := NewConsistent() + for i := 0; i < 20; i++ { + ch.Add("prefix" + strconv.Itoa(i)) + } + + keys := make(map[int]string, 1000) + for i := 0; i < 1000; i++ { + key, ok := ch.Get(1000 + i) + assert.True(t, ok) + assert.NotNil(t, key) + keys[i] = key.(string) + } + +} + +func Test_Consistent_Cascade(t *testing.T) { + hash := NewConsistent(ConsistentWithVtr(10)) + + cascade := hash.Cascade("foo") + t.Logf("cascade: %v", cascade) + assert.True(t, len(cascade) == 100) +} + +func TestConsistentHash(t *testing.T) { + ch := NewConsistent() + val, ok := ch.Get("any") + assert.False(t, ok) + assert.Nil(t, val) + + for i := 0; i < 20; i++ { + ch.Add("localhost:"+strconv.Itoa(i), ConsistentAddWithVtr(minVirtual<<1)) + } + + keys := make(map[string]int) + for i := 0; i < 1000; i++ { + key, ok := ch.Get(1000 + i) + assert.True(t, ok) + keys[key.(string)]++ + } + + mi := make(map[interface{}]int, len(keys)) + for k, v := range keys { + mi[k] = v + } + entropy := calcEntropy(mi) + assert.True(t, entropy > .95) +} + +const epsilon = 1e-6 + +func calcEntropy(m map[interface{}]int) float64 { + if len(m) == 0 || len(m) == 1 { + return 1 + } + + var entropy float64 + var total int + for _, v := range m { + total += v + } + + for _, v := range m { + proba := float64(v) / float64(total) + if proba < epsilon { + proba = epsilon + } + entropy -= proba * math.Log2(proba) + } + + return entropy / math.Log2(float64(len(m))) +}