From dac0d55ee7cea84702929ee4b8d55edd69ecf09e Mon Sep 17 00:00:00 2001 From: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> Date: Wed, 4 Oct 2023 17:37:28 -0400 Subject: [PATCH] add heap set --- utils/heap/set.go | 56 +++++++++ vms/platformvm/state/merged_iterator.go | 62 ++++------ x/sync/workheap.go | 112 ++++++------------ x/sync/workheap_test.go | 148 ++++-------------------- 4 files changed, 136 insertions(+), 242 deletions(-) create mode 100644 utils/heap/set.go diff --git a/utils/heap/set.go b/utils/heap/set.go new file mode 100644 index 000000000000..467a1528fb26 --- /dev/null +++ b/utils/heap/set.go @@ -0,0 +1,56 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package heap + +// NewSet returns a heap without duplicates ordered by its values +func NewSet[T comparable](less func(a, b T) bool) Set[T] { + return Set[T]{ + set: Map[T, struct{}]{ + queue: &indexedQueue[T, struct{}]{ + entries: make([]entry[T, struct{}], 0), + index: make(map[T]int), + less: func(a, b entry[T, struct{}]) bool { + return less(a.k, b.k) + }, + }, + }, + } +} + +type Set[T comparable] struct { + set Map[T, struct{}] +} + +// Push returns if a value was overwritten +func (s Set[T]) Push(t T) bool { + _, ok := s.set.Push(t, struct{}{}) + return ok +} + +func (s Set[T]) Pop() (T, bool) { + pop, _, ok := s.set.Pop() + return pop, ok +} + +func (s Set[T]) Peek() (T, bool) { + peek, _, ok := s.set.Peek() + return peek, ok +} + +func (s Set[T]) Len() int { + return s.set.Len() +} + +func (s Set[T]) Remove(i int) T { + remove, _ := s.set.Remove(i) + return remove +} + +func (s Set[T]) Fix(i int) { + s.set.Fix(i) +} + +func (s Set[T]) Index() map[T]int { + return s.set.queue.index +} diff --git a/vms/platformvm/state/merged_iterator.go b/vms/platformvm/state/merged_iterator.go index 2be90a1a3106..08a7e34811b2 100644 --- a/vms/platformvm/state/merged_iterator.go +++ b/vms/platformvm/state/merged_iterator.go @@ -3,18 +3,15 @@ package state -import "container/heap" +import "github.com/ava-labs/avalanchego/utils/heap" -var ( - _ StakerIterator = (*mergedIterator)(nil) - _ heap.Interface = (*mergedIterator)(nil) -) +var _ StakerIterator = (*mergedIterator)(nil) type mergedIterator struct { initialized bool // heap only contains iterators that have been initialized and are not // exhausted. - heap []StakerIterator + heap heap.Queue[StakerIterator] } // Returns an iterator that returns all of the elements of [stakers] in order. @@ -36,15 +33,19 @@ func NewMergedIterator(stakers ...StakerIterator) StakerIterator { } it := &mergedIterator{ - heap: stakers, + heap: heap.OfQueue[StakerIterator]( + func(a, b StakerIterator) bool { + return a.Value().Less(b.Value()) + }, + stakers..., + ), } - heap.Init(it) return it } func (it *mergedIterator) Next() bool { - if len(it.heap) == 0 { + if it.heap.Len() == 0 { return false } @@ -57,54 +58,31 @@ func (it *mergedIterator) Next() bool { } // Update the heap root. - current := it.heap[0] + current, _ := it.heap.Peek() if current.Next() { // Calling Next() above modifies [current] so we fix the heap. - heap.Fix(it, 0) + it.heap.Fix(0) return true } // The old root is exhausted. Remove it from the heap. current.Release() - heap.Pop(it) - return len(it.heap) > 0 + it.heap.Pop() + return it.heap.Len() > 0 } func (it *mergedIterator) Value() *Staker { - return it.heap[0].Value() + peek, _ := it.heap.Peek() + return peek.Value() } -// When Release() returns, Release() has been called on each element of -// [stakers]. func (it *mergedIterator) Release() { - for _, it := range it.heap { - it.Release() + for it.heap.Len() > 0 { + removed, _ := it.heap.Pop() + removed.Release() } - it.heap = nil } -// Returns the number of sub-iterators in [it]. func (it *mergedIterator) Len() int { - return len(it.heap) -} - -func (it *mergedIterator) Less(i, j int) bool { - return it.heap[i].Value().Less(it.heap[j].Value()) -} - -func (it *mergedIterator) Swap(i, j int) { - it.heap[j], it.heap[i] = it.heap[i], it.heap[j] -} - -// Push is never actually used - but we need it to implement heap.Interface. -func (it *mergedIterator) Push(value interface{}) { - it.heap = append(it.heap, value.(StakerIterator)) -} - -func (it *mergedIterator) Pop() interface{} { - newLength := len(it.heap) - 1 - value := it.heap[newLength] - it.heap[newLength] = nil - it.heap = it.heap[:newLength] - return value + return it.heap.Len() } diff --git a/x/sync/workheap.go b/x/sync/workheap.go index 36b27e229296..cd5132217db3 100644 --- a/x/sync/workheap.go +++ b/x/sync/workheap.go @@ -5,23 +5,13 @@ package sync import ( "bytes" - "container/heap" - + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/utils/maybe" "github.com/google/btree" ) -var _ heap.Interface = (*innerHeap)(nil) - -type heapItem struct { - workItem *workItem - heapIndex int -} - -type innerHeap []*heapItem - // A priority queue of syncWorkItems. // Note that work item ranges never overlap. // Supports range merging and priority updating. @@ -29,20 +19,23 @@ type innerHeap []*heapItem type workHeap struct { // Max heap of items by priority. // i.e. heap.Pop returns highest priority item. - innerHeap innerHeap + innerHeap heap.Set[*workItem] // The heap items sorted by range start. // A Nothing start is considered to be the smallest. - sortedItems *btree.BTreeG[*heapItem] + sortedItems *btree.BTreeG[*workItem] closed bool } func newWorkHeap() *workHeap { return &workHeap{ + innerHeap: heap.NewSet[*workItem](func(a, b *workItem) bool { + return a.priority > b.priority + }), sortedItems: btree.NewG( 2, - func(a, b *heapItem) bool { - aNothing := a.workItem.start.IsNothing() - bNothing := b.workItem.start.IsNothing() + func(a, b *workItem) bool { + aNothing := a.start.IsNothing() + bNothing := b.start.IsNothing() if aNothing { // [a] is Nothing, so if [b] is Nothing, they're equal. // Otherwise, [b] is greater. @@ -53,9 +46,10 @@ func newWorkHeap() *workHeap { return false } // [a] and [b] both contain values. Compare the values. - return bytes.Compare(a.workItem.start.Value(), b.workItem.start.Value()) < 0 + return bytes.Compare(a.start.Value(), b.start.Value()) < 0 }, ), + closed: false, } } @@ -70,10 +64,8 @@ func (wh *workHeap) Insert(item *workItem) { return } - wrappedItem := &heapItem{workItem: item} - - heap.Push(&wh.innerHeap, wrappedItem) - wh.sortedItems.ReplaceOrInsert(wrappedItem) + wh.innerHeap.Push(item) + wh.sortedItems.ReplaceOrInsert(item) } // Pops and returns a work item from the heap. @@ -82,9 +74,9 @@ func (wh *workHeap) GetWork() *workItem { if wh.closed || wh.Len() == 0 { return nil } - item := heap.Pop(&wh.innerHeap).(*heapItem) + item, _ := wh.innerHeap.Pop() wh.sortedItems.Delete(item) - return item.workItem + return item } // Insert the item into the heap, merging it with existing items @@ -99,25 +91,23 @@ func (wh *workHeap) MergeInsert(item *workItem) { return } - var mergedBefore, mergedAfter *heapItem - searchItem := &heapItem{ - workItem: &workItem{ - start: item.start, - }, + var mergedBefore, mergedAfter *workItem + searchItem := &workItem{ + start: item.start, } // Find the item with the greatest start range which is less than [item.start]. // Note that the iterator function will run at most once, since it always returns false. wh.sortedItems.DescendLessOrEqual( searchItem, - func(beforeItem *heapItem) bool { - if item.localRootID == beforeItem.workItem.localRootID && - maybe.Equal(item.start, beforeItem.workItem.end, bytes.Equal) { + func(beforeItem *workItem) bool { + if item.localRootID == beforeItem.localRootID && + maybe.Equal(item.start, beforeItem.end, bytes.Equal) { // [beforeItem.start, beforeItem.end] and [item.start, item.end] are // merged into [beforeItem.start, item.end] - beforeItem.workItem.end = item.end - beforeItem.workItem.priority = math.Max(item.priority, beforeItem.workItem.priority) - heap.Fix(&wh.innerHeap, beforeItem.heapIndex) + beforeItem.end = item.end + beforeItem.priority = math.Max(item.priority, beforeItem.priority) + wh.innerHeap.Fix(wh.innerHeap.Index()[beforeItem]) mergedBefore = beforeItem } return false @@ -127,14 +117,14 @@ func (wh *workHeap) MergeInsert(item *workItem) { // Note that the iterator function will run at most once, since it always returns false. wh.sortedItems.AscendGreaterOrEqual( searchItem, - func(afterItem *heapItem) bool { - if item.localRootID == afterItem.workItem.localRootID && - maybe.Equal(item.end, afterItem.workItem.start, bytes.Equal) { + func(afterItem *workItem) bool { + if item.localRootID == afterItem.localRootID && + maybe.Equal(item.end, afterItem.start, bytes.Equal) { // [item.start, item.end] and [afterItem.start, afterItem.end] are merged into // [item.start, afterItem.end]. - afterItem.workItem.start = item.start - afterItem.workItem.priority = math.Max(item.priority, afterItem.workItem.priority) - heap.Fix(&wh.innerHeap, afterItem.heapIndex) + afterItem.start = item.start + afterItem.priority = math.Max(item.priority, afterItem.priority) + wh.innerHeap.Fix(wh.innerHeap.Index()[afterItem]) mergedAfter = afterItem } return false @@ -144,12 +134,12 @@ func (wh *workHeap) MergeInsert(item *workItem) { // we can combine the before item with the after item if mergedBefore != nil && mergedAfter != nil { // combine the two ranges - mergedBefore.workItem.end = mergedAfter.workItem.end + mergedBefore.end = mergedAfter.end // remove the second range since it is now covered by the first wh.remove(mergedAfter) // update the priority - mergedBefore.workItem.priority = math.Max(mergedBefore.workItem.priority, mergedAfter.workItem.priority) - heap.Fix(&wh.innerHeap, mergedBefore.heapIndex) + mergedBefore.priority = math.Max(mergedBefore.priority, mergedAfter.priority) + wh.innerHeap.Fix(wh.innerHeap.Index()[mergedBefore]) } // nothing was merged, so add new item to the heap @@ -160,43 +150,11 @@ func (wh *workHeap) MergeInsert(item *workItem) { } // Deletes [item] from the heap. -func (wh *workHeap) remove(item *heapItem) { - heap.Remove(&wh.innerHeap, item.heapIndex) - +func (wh *workHeap) remove(item *workItem) { + wh.innerHeap.Remove(wh.innerHeap.Index()[item]) wh.sortedItems.Delete(item) } func (wh *workHeap) Len() int { return wh.innerHeap.Len() } - -// below this line are the implementations required for heap.Interface - -func (h innerHeap) Len() int { - return len(h) -} - -func (h innerHeap) Less(i int, j int) bool { - return h[i].workItem.priority > h[j].workItem.priority -} - -func (h innerHeap) Swap(i int, j int) { - h[i], h[j] = h[j], h[i] - h[i].heapIndex = i - h[j].heapIndex = j -} - -func (h *innerHeap) Pop() interface{} { - old := *h - n := len(old) - item := old[n-1] - old[n-1] = nil - *h = old[0 : n-1] - return item -} - -func (h *innerHeap) Push(x interface{}) { - item := x.(*heapItem) - item.heapIndex = len(*h) - *h = append(*h, item) -} diff --git a/x/sync/workheap_test.go b/x/sync/workheap_test.go index 7f50468a1fbd..0a3262a9310f 100644 --- a/x/sync/workheap_test.go +++ b/x/sync/workheap_test.go @@ -17,102 +17,6 @@ import ( "github.com/ava-labs/avalanchego/utils/maybe" ) -// Tests heap.Interface methods Push, Pop, Swap, Len, Less. -func Test_WorkHeap_InnerHeap(t *testing.T) { - require := require.New(t) - - lowPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{1}), - end: maybe.Some([]byte{2}), - priority: lowPriority, - localRootID: ids.GenerateTestID(), - }, - } - - mediumPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{3}), - end: maybe.Some([]byte{4}), - priority: medPriority, - localRootID: ids.GenerateTestID(), - }, - } - - highPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{5}), - end: maybe.Some([]byte{6}), - priority: highPriority, - localRootID: ids.GenerateTestID(), - }, - } - - h := innerHeap{} - require.Zero(h.Len()) - - // Note we're calling Push and Pop on the heap directly, - // not using heap.Push and heap.Pop. - h.Push(lowPriorityItem) - // Heap has [lowPriorityItem] - require.Equal(1, h.Len()) - require.Equal(lowPriorityItem, h[0]) - - got := h.Pop() - // Heap has [] - require.Equal(lowPriorityItem, got) - require.Zero(h.Len()) - - h.Push(lowPriorityItem) - h.Push(mediumPriorityItem) - // Heap has [lowPriorityItem, mediumPriorityItem] - require.Equal(2, h.Len()) - require.Equal(lowPriorityItem, h[0]) - require.Equal(mediumPriorityItem, h[1]) - - got = h.Pop() - // Heap has [lowPriorityItem] - require.Equal(mediumPriorityItem, got) - require.Equal(1, h.Len()) - - got = h.Pop() - // Heap has [] - require.Equal(lowPriorityItem, got) - require.Zero(h.Len()) - - h.Push(mediumPriorityItem) - h.Push(lowPriorityItem) - h.Push(highPriorityItem) - // Heap has [mediumPriorityItem, lowPriorityItem, highPriorityItem] - require.Equal(mediumPriorityItem, h[0]) - require.Equal(lowPriorityItem, h[1]) - require.Equal(highPriorityItem, h[2]) - - h.Swap(0, 1) - // Heap has [lowPriorityItem, mediumPriorityItem, highPriorityItem] - require.Equal(lowPriorityItem, h[0]) - require.Equal(mediumPriorityItem, h[1]) - require.Equal(highPriorityItem, h[2]) - - h.Swap(1, 2) - // Heap has [lowPriorityItem, highPriorityItem, mediumPriorityItem] - require.Equal(lowPriorityItem, h[0]) - require.Equal(highPriorityItem, h[1]) - require.Equal(mediumPriorityItem, h[2]) - - h.Swap(0, 2) - // Heap has [mediumPriorityItem, highPriorityItem, lowPriorityItem] - require.Equal(mediumPriorityItem, h[0]) - require.Equal(highPriorityItem, h[1]) - require.Equal(lowPriorityItem, h[2]) - require.False(h.Less(0, 1)) - require.True(h.Less(1, 0)) - require.True(h.Less(1, 2)) - require.False(h.Less(2, 1)) - require.True(h.Less(0, 2)) - require.False(h.Less(2, 0)) -} - // Tests Insert and GetWork func Test_WorkHeap_Insert_GetWork(t *testing.T) { require := require.New(t) @@ -144,8 +48,8 @@ func Test_WorkHeap_Insert_GetWork(t *testing.T) { // Ensure [sortedItems] is in right order. got := []*workItem{} h.sortedItems.Ascend( - func(i *heapItem) bool { - got = append(got, i.workItem) + func(i *workItem) bool { + got = append(got, i) return true }, ) @@ -195,40 +99,42 @@ func Test_WorkHeap_remove(t *testing.T) { h.Insert(lowPriorityItem) - wrappedLowPriorityItem := h.innerHeap[0] + wrappedLowPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) h.remove(wrappedLowPriorityItem) require.Zero(h.Len()) - require.Empty(h.innerHeap) require.Zero(h.sortedItems.Len()) h.Insert(lowPriorityItem) h.Insert(mediumPriorityItem) h.Insert(highPriorityItem) - wrappedhighPriorityItem := h.innerHeap[0] - require.Equal(highPriorityItem, wrappedhighPriorityItem.workItem) + wrappedhighPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(highPriorityItem, wrappedhighPriorityItem) h.remove(wrappedhighPriorityItem) require.Equal(2, h.Len()) - require.Len(h.innerHeap, 2) require.Equal(2, h.sortedItems.Len()) - require.Zero(h.innerHeap[0].heapIndex) - require.Equal(mediumPriorityItem, h.innerHeap[0].workItem) + got, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(mediumPriorityItem, got) - wrappedMediumPriorityItem := h.innerHeap[0] - require.Equal(mediumPriorityItem, wrappedMediumPriorityItem.workItem) + wrappedMediumPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(mediumPriorityItem, wrappedMediumPriorityItem) h.remove(wrappedMediumPriorityItem) require.Equal(1, h.Len()) - require.Len(h.innerHeap, 1) require.Equal(1, h.sortedItems.Len()) - require.Zero(h.innerHeap[0].heapIndex) - require.Equal(lowPriorityItem, h.innerHeap[0].workItem) + got, ok = h.innerHeap.Peek() + require.True(ok) + require.Equal(lowPriorityItem, got) - wrappedLowPriorityItem = h.innerHeap[0] - require.Equal(lowPriorityItem, wrappedLowPriorityItem.workItem) + wrappedLowPriorityItem, ok = h.innerHeap.Peek() + require.True(ok) + require.Equal(lowPriorityItem, wrappedLowPriorityItem) h.remove(wrappedLowPriorityItem) require.Zero(h.Len()) - require.Empty(h.innerHeap) require.Zero(h.sortedItems.Len()) } @@ -367,13 +273,11 @@ func TestWorkHeapMergeInsertRandom(t *testing.T) { start = maybe.Nothing[[]byte]() } // Make sure end is updated - got, ok := h.sortedItems.Get(&heapItem{ - workItem: &workItem{ - start: start, - }, + got, ok := h.sortedItems.Get(&workItem{ + start: start, }) require.True(ok) - require.Equal(newEnd, got.workItem.end.Value()) + require.Equal(newEnd, got.end.Value()) } } @@ -397,13 +301,11 @@ func TestWorkHeapMergeInsertRandom(t *testing.T) { require.Equal(len(ranges), h.Len()) // Make sure start is updated - got, ok := h.sortedItems.Get(&heapItem{ - workItem: &workItem{ - start: newStart, - }, + got, ok := h.sortedItems.Get(&workItem{ + start: newStart, }) require.True(ok) - require.Equal(newStartBytes, got.workItem.start.Value()) + require.Equal(newStartBytes, got.start.Value()) } } }