Skip to content

Commit

Permalink
add heap set
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-kim committed Oct 4, 2023
1 parent a55cdce commit dac0d55
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 242 deletions.
56 changes: 56 additions & 0 deletions utils/heap/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package heap

// NewSet returns a heap without duplicates ordered by its values
func NewSet[T comparable](less func(a, b T) bool) Set[T] {
return Set[T]{
set: Map[T, struct{}]{
queue: &indexedQueue[T, struct{}]{
entries: make([]entry[T, struct{}], 0),
index: make(map[T]int),
less: func(a, b entry[T, struct{}]) bool {
return less(a.k, b.k)
},
},
},
}
}

type Set[T comparable] struct {
set Map[T, struct{}]
}

// Push returns if a value was overwritten
func (s Set[T]) Push(t T) bool {
_, ok := s.set.Push(t, struct{}{})
return ok
}

func (s Set[T]) Pop() (T, bool) {
pop, _, ok := s.set.Pop()
return pop, ok
}

func (s Set[T]) Peek() (T, bool) {
peek, _, ok := s.set.Peek()
return peek, ok
}

func (s Set[T]) Len() int {
return s.set.Len()
}

func (s Set[T]) Remove(i int) T {
remove, _ := s.set.Remove(i)
return remove
}

func (s Set[T]) Fix(i int) {
s.set.Fix(i)
}

func (s Set[T]) Index() map[T]int {
return s.set.queue.index
}
62 changes: 20 additions & 42 deletions vms/platformvm/state/merged_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@

package state

import "container/heap"
import "github.com/ava-labs/avalanchego/utils/heap"

var (
_ StakerIterator = (*mergedIterator)(nil)
_ heap.Interface = (*mergedIterator)(nil)
)
var _ StakerIterator = (*mergedIterator)(nil)

type mergedIterator struct {
initialized bool
// heap only contains iterators that have been initialized and are not
// exhausted.
heap []StakerIterator
heap heap.Queue[StakerIterator]
}

// Returns an iterator that returns all of the elements of [stakers] in order.
Expand All @@ -36,15 +33,19 @@ func NewMergedIterator(stakers ...StakerIterator) StakerIterator {
}

it := &mergedIterator{
heap: stakers,
heap: heap.OfQueue[StakerIterator](

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / test_e2e

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / Static analysis

undefined: heap.OfQueue) (typecheck)

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / test_e2e_persistent

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / test_upgrade

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-20.04)

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / build_unit_test (ubuntu-22.04)

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / build_unit_test (self-hosted, linux, ARM64, focal)

undefined: heap.OfQueue

Check failure on line 36 in vms/platformvm/state/merged_iterator.go

View workflow job for this annotation

GitHub Actions / build_unit_test (self-hosted, linux, ARM64, jammy)

undefined: heap.OfQueue
func(a, b StakerIterator) bool {
return a.Value().Less(b.Value())
},
stakers...,
),
}

heap.Init(it)
return it
}

func (it *mergedIterator) Next() bool {
if len(it.heap) == 0 {
if it.heap.Len() == 0 {
return false
}

Expand All @@ -57,54 +58,31 @@ func (it *mergedIterator) Next() bool {
}

// Update the heap root.
current := it.heap[0]
current, _ := it.heap.Peek()
if current.Next() {
// Calling Next() above modifies [current] so we fix the heap.
heap.Fix(it, 0)
it.heap.Fix(0)
return true
}

// The old root is exhausted. Remove it from the heap.
current.Release()
heap.Pop(it)
return len(it.heap) > 0
it.heap.Pop()
return it.heap.Len() > 0
}

func (it *mergedIterator) Value() *Staker {
return it.heap[0].Value()
peek, _ := it.heap.Peek()
return peek.Value()
}

// When Release() returns, Release() has been called on each element of
// [stakers].
func (it *mergedIterator) Release() {
for _, it := range it.heap {
it.Release()
for it.heap.Len() > 0 {
removed, _ := it.heap.Pop()
removed.Release()
}
it.heap = nil
}

// Returns the number of sub-iterators in [it].
func (it *mergedIterator) Len() int {
return len(it.heap)
}

func (it *mergedIterator) Less(i, j int) bool {
return it.heap[i].Value().Less(it.heap[j].Value())
}

func (it *mergedIterator) Swap(i, j int) {
it.heap[j], it.heap[i] = it.heap[i], it.heap[j]
}

// Push is never actually used - but we need it to implement heap.Interface.
func (it *mergedIterator) Push(value interface{}) {
it.heap = append(it.heap, value.(StakerIterator))
}

func (it *mergedIterator) Pop() interface{} {
newLength := len(it.heap) - 1
value := it.heap[newLength]
it.heap[newLength] = nil
it.heap = it.heap[:newLength]
return value
return it.heap.Len()
}
112 changes: 35 additions & 77 deletions x/sync/workheap.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,37 @@ package sync

import (
"bytes"
"container/heap"

"github.com/ava-labs/avalanchego/utils/heap"
"github.com/ava-labs/avalanchego/utils/math"
"github.com/ava-labs/avalanchego/utils/maybe"

"github.com/google/btree"
)

var _ heap.Interface = (*innerHeap)(nil)

type heapItem struct {
workItem *workItem
heapIndex int
}

type innerHeap []*heapItem

// A priority queue of syncWorkItems.
// Note that work item ranges never overlap.
// Supports range merging and priority updating.
// Not safe for concurrent use.
type workHeap struct {
// Max heap of items by priority.
// i.e. heap.Pop returns highest priority item.
innerHeap innerHeap
innerHeap heap.Set[*workItem]
// The heap items sorted by range start.
// A Nothing start is considered to be the smallest.
sortedItems *btree.BTreeG[*heapItem]
sortedItems *btree.BTreeG[*workItem]
closed bool
}

func newWorkHeap() *workHeap {
return &workHeap{
innerHeap: heap.NewSet[*workItem](func(a, b *workItem) bool {
return a.priority > b.priority
}),
sortedItems: btree.NewG(
2,
func(a, b *heapItem) bool {
aNothing := a.workItem.start.IsNothing()
bNothing := b.workItem.start.IsNothing()
func(a, b *workItem) bool {
aNothing := a.start.IsNothing()
bNothing := b.start.IsNothing()
if aNothing {
// [a] is Nothing, so if [b] is Nothing, they're equal.
// Otherwise, [b] is greater.
Expand All @@ -53,9 +46,10 @@ func newWorkHeap() *workHeap {
return false
}
// [a] and [b] both contain values. Compare the values.
return bytes.Compare(a.workItem.start.Value(), b.workItem.start.Value()) < 0
return bytes.Compare(a.start.Value(), b.start.Value()) < 0
},
),
closed: false,
}
}

Expand All @@ -70,10 +64,8 @@ func (wh *workHeap) Insert(item *workItem) {
return
}

wrappedItem := &heapItem{workItem: item}

heap.Push(&wh.innerHeap, wrappedItem)
wh.sortedItems.ReplaceOrInsert(wrappedItem)
wh.innerHeap.Push(item)
wh.sortedItems.ReplaceOrInsert(item)
}

// Pops and returns a work item from the heap.
Expand All @@ -82,9 +74,9 @@ func (wh *workHeap) GetWork() *workItem {
if wh.closed || wh.Len() == 0 {
return nil
}
item := heap.Pop(&wh.innerHeap).(*heapItem)
item, _ := wh.innerHeap.Pop()
wh.sortedItems.Delete(item)
return item.workItem
return item
}

// Insert the item into the heap, merging it with existing items
Expand All @@ -99,25 +91,23 @@ func (wh *workHeap) MergeInsert(item *workItem) {
return
}

var mergedBefore, mergedAfter *heapItem
searchItem := &heapItem{
workItem: &workItem{
start: item.start,
},
var mergedBefore, mergedAfter *workItem
searchItem := &workItem{
start: item.start,
}

// Find the item with the greatest start range which is less than [item.start].
// Note that the iterator function will run at most once, since it always returns false.
wh.sortedItems.DescendLessOrEqual(
searchItem,
func(beforeItem *heapItem) bool {
if item.localRootID == beforeItem.workItem.localRootID &&
maybe.Equal(item.start, beforeItem.workItem.end, bytes.Equal) {
func(beforeItem *workItem) bool {
if item.localRootID == beforeItem.localRootID &&
maybe.Equal(item.start, beforeItem.end, bytes.Equal) {
// [beforeItem.start, beforeItem.end] and [item.start, item.end] are
// merged into [beforeItem.start, item.end]
beforeItem.workItem.end = item.end
beforeItem.workItem.priority = math.Max(item.priority, beforeItem.workItem.priority)
heap.Fix(&wh.innerHeap, beforeItem.heapIndex)
beforeItem.end = item.end
beforeItem.priority = math.Max(item.priority, beforeItem.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[beforeItem])
mergedBefore = beforeItem
}
return false
Expand All @@ -127,14 +117,14 @@ func (wh *workHeap) MergeInsert(item *workItem) {
// Note that the iterator function will run at most once, since it always returns false.
wh.sortedItems.AscendGreaterOrEqual(
searchItem,
func(afterItem *heapItem) bool {
if item.localRootID == afterItem.workItem.localRootID &&
maybe.Equal(item.end, afterItem.workItem.start, bytes.Equal) {
func(afterItem *workItem) bool {
if item.localRootID == afterItem.localRootID &&
maybe.Equal(item.end, afterItem.start, bytes.Equal) {
// [item.start, item.end] and [afterItem.start, afterItem.end] are merged into
// [item.start, afterItem.end].
afterItem.workItem.start = item.start
afterItem.workItem.priority = math.Max(item.priority, afterItem.workItem.priority)
heap.Fix(&wh.innerHeap, afterItem.heapIndex)
afterItem.start = item.start
afterItem.priority = math.Max(item.priority, afterItem.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[afterItem])
mergedAfter = afterItem
}
return false
Expand All @@ -144,12 +134,12 @@ func (wh *workHeap) MergeInsert(item *workItem) {
// we can combine the before item with the after item
if mergedBefore != nil && mergedAfter != nil {
// combine the two ranges
mergedBefore.workItem.end = mergedAfter.workItem.end
mergedBefore.end = mergedAfter.end
// remove the second range since it is now covered by the first
wh.remove(mergedAfter)
// update the priority
mergedBefore.workItem.priority = math.Max(mergedBefore.workItem.priority, mergedAfter.workItem.priority)
heap.Fix(&wh.innerHeap, mergedBefore.heapIndex)
mergedBefore.priority = math.Max(mergedBefore.priority, mergedAfter.priority)
wh.innerHeap.Fix(wh.innerHeap.Index()[mergedBefore])
}

// nothing was merged, so add new item to the heap
Expand All @@ -160,43 +150,11 @@ func (wh *workHeap) MergeInsert(item *workItem) {
}

// Deletes [item] from the heap.
func (wh *workHeap) remove(item *heapItem) {
heap.Remove(&wh.innerHeap, item.heapIndex)

func (wh *workHeap) remove(item *workItem) {
wh.innerHeap.Remove(wh.innerHeap.Index()[item])
wh.sortedItems.Delete(item)
}

func (wh *workHeap) Len() int {
return wh.innerHeap.Len()
}

// below this line are the implementations required for heap.Interface

func (h innerHeap) Len() int {
return len(h)
}

func (h innerHeap) Less(i int, j int) bool {
return h[i].workItem.priority > h[j].workItem.priority
}

func (h innerHeap) Swap(i int, j int) {
h[i], h[j] = h[j], h[i]
h[i].heapIndex = i
h[j].heapIndex = j
}

func (h *innerHeap) Pop() interface{} {
old := *h
n := len(old)
item := old[n-1]
old[n-1] = nil
*h = old[0 : n-1]
return item
}

func (h *innerHeap) Push(x interface{}) {
item := x.(*heapItem)
item.heapIndex = len(*h)
*h = append(*h, item)
}
Loading

0 comments on commit dac0d55

Please sign in to comment.