Skip to content

Commit

Permalink
add heap set
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-kim committed Oct 5, 2023
1 parent a53f73a commit 6b3147c
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 199 deletions.
56 changes: 56 additions & 0 deletions utils/heap/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package heap

// NewSet returns a heap without duplicates ordered by its values
func NewSet[T comparable](less func(a, b T) bool) Set[T] {
return Set[T]{
set: Map[T, struct{}]{
queue: &indexedQueue[T, struct{}]{
entries: make([]entry[T, struct{}], 0),

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_e2e

unknown field entries in struct literal of type indexedQueue[T, struct{}]

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_upgrade

unknown field entries in struct literal of type indexedQueue[T, struct{}]

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_e2e_persistent

unknown field entries in struct literal of type indexedQueue[T, struct{}]

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (macos-12)

unknown field entries in struct literal of type indexedQueue[T, struct{}]

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-20.04)

unknown field entries in struct literal of type indexedQueue[T, struct{}]

Check failure on line 11 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-22.04)

unknown field entries in struct literal of type indexedQueue[T, struct{}]
index: make(map[T]int),
less: func(a, b entry[T, struct{}]) bool {

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_e2e

unknown field less in struct literal of type indexedQueue[T, struct{}]

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_upgrade

unknown field less in struct literal of type indexedQueue[T, struct{}]

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / test_e2e_persistent

unknown field less in struct literal of type indexedQueue[T, struct{}]

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (macos-12)

unknown field less in struct literal of type indexedQueue[T, struct{}]

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-20.04)

unknown field less in struct literal of type indexedQueue[T, struct{}]

Check failure on line 13 in utils/heap/set.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-22.04)

unknown field less in struct literal of type indexedQueue[T, struct{}]
return less(a.k, b.k)
},
},
},
}
}

type Set[T comparable] struct {
set Map[T, struct{}]
}

// Push returns if a value was overwritten
func (s Set[T]) Push(t T) bool {
_, ok := s.set.Push(t, struct{}{})
return ok
}

func (s Set[T]) Pop() (T, bool) {
pop, _, ok := s.set.Pop()
return pop, ok
}

func (s Set[T]) Peek() (T, bool) {
peek, _, ok := s.set.Peek()
return peek, ok
}

func (s Set[T]) Len() int {
return s.set.Len()
}

func (s Set[T]) Remove(i int) T {
remove, _ := s.set.Remove(i)
return remove
}

func (s Set[T]) Fix(i int) {
s.set.Fix(i)
}

func (s Set[T]) Index() map[T]int {
return s.set.queue.index
}
72 changes: 72 additions & 0 deletions utils/heap/set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package heap

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestSet(t *testing.T) {
tests := []struct {
name string
setup func(h Set[int])
expected []int
}{
{
name: "only push",
setup: func(h Set[int]) {
h.Push(1)
h.Push(2)
h.Push(3)
},
expected: []int{1, 2, 3},
},
{
name: "out of order pushes",
setup: func(h Set[int]) {
h.Push(1)
h.Push(5)
h.Push(2)
h.Push(4)
h.Push(3)
},
expected: []int{1, 2, 3, 4, 5},
},
{
name: "push and pop",
setup: func(h Set[int]) {
h.Push(1)
h.Push(5)
h.Push(2)
h.Push(4)
h.Push(3)
h.Pop()
h.Pop()
h.Pop()
},
expected: []int{4, 5},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)

h := NewSet[int](func(a, b int) bool {
return a < b
})

tt.setup(h)

require.Equal(len(tt.expected), h.Len())
for _, expected := range tt.expected {
got, ok := h.Pop()
require.True(ok)
require.Equal(expected, got)
}
})
}
}
110 changes: 34 additions & 76 deletions x/sync/workheap.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,38 @@ package sync

import (
"bytes"
"container/heap"

"github.com/ava-labs/avalanchego/utils/heap"
"github.com/ava-labs/avalanchego/utils/math"
"github.com/ava-labs/avalanchego/utils/maybe"

"github.com/google/btree"
)

var _ heap.Interface = (*innerHeap)(nil)

type heapItem struct {
workItem *workItem
heapIndex int
}

type innerHeap []*heapItem

// A priority queue of syncWorkItems.
// Note that work item ranges never overlap.
// Supports range merging and priority updating.
// Not safe for concurrent use.
type workHeap struct {
// Max heap of items by priority.
// i.e. heap.Pop returns highest priority item.
innerHeap innerHeap
innerHeap heap.Set[*workItem]
// The heap items sorted by range start.
// A Nothing start is considered to be the smallest.
sortedItems *btree.BTreeG[*heapItem]
sortedItems *btree.BTreeG[*workItem]
closed bool
}

func newWorkHeap() *workHeap {
return &workHeap{
innerHeap: heap.NewSet[*workItem](func(a, b *workItem) bool {
return a.priority > b.priority
}),
sortedItems: btree.NewG(
2,
func(a, b *heapItem) bool {
aNothing := a.workItem.start.IsNothing()
bNothing := b.workItem.start.IsNothing()
func(a, b *workItem) bool {
aNothing := a.start.IsNothing()
bNothing := b.start.IsNothing()
if aNothing {
// [a] is Nothing, so if [b] is Nothing, they're equal.
// Otherwise, [b] is greater.
Expand All @@ -53,7 +47,7 @@ func newWorkHeap() *workHeap {
return false
}
// [a] and [b] both contain values. Compare the values.
return bytes.Compare(a.workItem.start.Value(), b.workItem.start.Value()) < 0
return bytes.Compare(a.start.Value(), b.start.Value()) < 0
},
),
}
Expand All @@ -70,10 +64,8 @@ func (wh *workHeap) Insert(item *workItem) {
return
}

wrappedItem := &heapItem{workItem: item}

heap.Push(&wh.innerHeap, wrappedItem)
wh.sortedItems.ReplaceOrInsert(wrappedItem)
wh.innerHeap.Push(item)
wh.sortedItems.ReplaceOrInsert(item)
}

// Pops and returns a work item from the heap.
Expand All @@ -82,9 +74,9 @@ func (wh *workHeap) GetWork() *workItem {
if wh.closed || wh.Len() == 0 {
return nil
}
item := heap.Pop(&wh.innerHeap).(*heapItem)
item, _ := wh.innerHeap.Pop()
wh.sortedItems.Delete(item)
return item.workItem
return item
}

// Insert the item into the heap, merging it with existing items
Expand All @@ -99,25 +91,23 @@ func (wh *workHeap) MergeInsert(item *workItem) {
return
}

var mergedBefore, mergedAfter *heapItem
searchItem := &heapItem{
workItem: &workItem{
start: item.start,
},
var mergedBefore, mergedAfter *workItem
searchItem := &workItem{
start: item.start,
}

// Find the item with the greatest start range which is less than [item.start].
// Note that the iterator function will run at most once, since it always returns false.
wh.sortedItems.DescendLessOrEqual(
searchItem,
func(beforeItem *heapItem) bool {
if item.localRootID == beforeItem.workItem.localRootID &&
maybe.Equal(item.start, beforeItem.workItem.end, bytes.Equal) {
func(beforeItem *workItem) bool {
if item.localRootID == beforeItem.localRootID &&
maybe.Equal(item.start, beforeItem.end, bytes.Equal) {
// [beforeItem.start, beforeItem.end] and [item.start, item.end] are
// merged into [beforeItem.start, item.end]
beforeItem.workItem.end = item.end
beforeItem.workItem.priority = math.Max(item.priority, beforeItem.workItem.priority)
heap.Fix(&wh.innerHeap, beforeItem.heapIndex)
beforeItem.end = item.end
beforeItem.priority = math.Max(item.priority, beforeItem.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[beforeItem])
mergedBefore = beforeItem
}
return false
Expand All @@ -127,14 +117,14 @@ func (wh *workHeap) MergeInsert(item *workItem) {
// Note that the iterator function will run at most once, since it always returns false.
wh.sortedItems.AscendGreaterOrEqual(
searchItem,
func(afterItem *heapItem) bool {
if item.localRootID == afterItem.workItem.localRootID &&
maybe.Equal(item.end, afterItem.workItem.start, bytes.Equal) {
func(afterItem *workItem) bool {
if item.localRootID == afterItem.localRootID &&
maybe.Equal(item.end, afterItem.start, bytes.Equal) {
// [item.start, item.end] and [afterItem.start, afterItem.end] are merged into
// [item.start, afterItem.end].
afterItem.workItem.start = item.start
afterItem.workItem.priority = math.Max(item.priority, afterItem.workItem.priority)
heap.Fix(&wh.innerHeap, afterItem.heapIndex)
afterItem.start = item.start
afterItem.priority = math.Max(item.priority, afterItem.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[afterItem])
mergedAfter = afterItem
}
return false
Expand All @@ -144,12 +134,12 @@ func (wh *workHeap) MergeInsert(item *workItem) {
// we can combine the before item with the after item
if mergedBefore != nil && mergedAfter != nil {
// combine the two ranges
mergedBefore.workItem.end = mergedAfter.workItem.end
mergedBefore.end = mergedAfter.end
// remove the second range since it is now covered by the first
wh.remove(mergedAfter)
// update the priority
mergedBefore.workItem.priority = math.Max(mergedBefore.workItem.priority, mergedAfter.workItem.priority)
heap.Fix(&wh.innerHeap, mergedBefore.heapIndex)
mergedBefore.priority = math.Max(mergedBefore.priority, mergedAfter.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[mergedBefore])
}

// nothing was merged, so add new item to the heap
Expand All @@ -160,43 +150,11 @@ func (wh *workHeap) MergeInsert(item *workItem) {
}

// Deletes [item] from the heap.
func (wh *workHeap) remove(item *heapItem) {
heap.Remove(&wh.innerHeap, item.heapIndex)

func (wh *workHeap) remove(item *workItem) {
wh.innerHeap.Remove(wh.innerHeap.Index()[item])
wh.sortedItems.Delete(item)
}

func (wh *workHeap) Len() int {
return wh.innerHeap.Len()
}

// below this line are the implementations required for heap.Interface

func (h innerHeap) Len() int {
return len(h)
}

func (h innerHeap) Less(i int, j int) bool {
return h[i].workItem.priority > h[j].workItem.priority
}

func (h innerHeap) Swap(i int, j int) {
h[i], h[j] = h[j], h[i]
h[i].heapIndex = i
h[j].heapIndex = j
}

func (h *innerHeap) Pop() interface{} {
old := *h
n := len(old)
item := old[n-1]
old[n-1] = nil
*h = old[0 : n-1]
return item
}

func (h *innerHeap) Push(x interface{}) {
item := x.(*heapItem)
item.heapIndex = len(*h)
*h = append(*h, item)
}
Loading

0 comments on commit 6b3147c

Please sign in to comment.