Skip to content

Commit

Permalink
Update AveragerHeap#Add to support updates (ava-labs#1559)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Jun 8, 2022
1 parent b363d26 commit 75d3da1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
18 changes: 12 additions & 6 deletions utils/math/averager_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ var (

// AveragerHeap maintains a heap of the averagers.
type AveragerHeap interface {
// Add will do nothing if [nodeID] is already in the heap
Add(nodeID ids.NodeID, averager Averager)
// Add the average to the heap. If [nodeID] is already in the heap, the
// average will be replaced and the old average will be returned. If there
// was not an old average, false will be returned.
Add(nodeID ids.NodeID, averager Averager) (Averager, bool)
// Remove attempts to remove the average that was added with the provided
// [nodeID], if none is contained in the heap, [false] will be returned
// [nodeID], if none is contained in the heap, [false] will be returned.
Remove(nodeID ids.NodeID) (Averager, bool)
// Pop attempts to remove the node with either the largest or smallest
// average, depending on if this is a max heap or a min heap, respectively.
Expand Down Expand Up @@ -64,15 +66,19 @@ func NewMaxAveragerHeap() AveragerHeap {
}}
}

func (h averagerHeap) Add(nodeID ids.NodeID, averager Averager) {
if _, exists := h.b.nodeIDToEntry[nodeID]; exists {
return
func (h averagerHeap) Add(nodeID ids.NodeID, averager Averager) (Averager, bool) {
if e, exists := h.b.nodeIDToEntry[nodeID]; exists {
oldAverager := e.averager
e.averager = averager
heap.Fix(h.b, e.index)
return oldAverager, true
}

heap.Push(h.b, &averagerHeapEntry{
nodeID: nodeID,
averager: averager,
})
return nil, false
}

func (h averagerHeap) Remove(nodeID ids.NodeID) (Averager, bool) {
Expand Down
23 changes: 19 additions & 4 deletions utils/math/averager_heap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func TestAveragerHeap(t *testing.T) {
l := test.h.Len()
assert.Zero(l)

test.h.Add(n1, test.a[1])
_, ok = test.h.Add(n1, test.a[1])
assert.False(ok)

n, a, ok := test.h.Peek()
assert.True(ok)
Expand All @@ -61,13 +62,18 @@ func TestAveragerHeap(t *testing.T) {
l = test.h.Len()
assert.Equal(1, l)

test.h.Add(n1, test.a[1])
a, ok = test.h.Add(n1, test.a[1])
assert.True(ok)
assert.Equal(test.a[1], a)

l = test.h.Len()
assert.Equal(1, l)

test.h.Add(n0, test.a[0])
test.h.Add(n2, test.a[2])
_, ok = test.h.Add(n0, test.a[0])
assert.False(ok)

_, ok = test.h.Add(n2, test.a[2])
assert.False(ok)

n, a, ok = test.h.Pop()
assert.True(ok)
Expand All @@ -89,5 +95,14 @@ func TestAveragerHeap(t *testing.T) {

l = test.h.Len()
assert.Equal(1, l)

a, ok = test.h.Add(n2, test.a[0])
assert.True(ok)
assert.Equal(test.a[2], a)

n, a, ok = test.h.Pop()
assert.True(ok)
assert.Equal(n2, n)
assert.Equal(test.a[0], a)
}
}

0 comments on commit 75d3da1

Please sign in to comment.