Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Heap Queue #2135

Merged
merged 14 commits into from
Oct 5, 2023
94 changes: 94 additions & 0 deletions utils/heap/queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package heap

import (
"container/heap"

"github.com/ava-labs/avalanchego/utils"
)

var _ heap.Interface = (*queue[int])(nil)

// NewQueue returns an empty heap. See QueueOf for more.
func NewQueue[T any](less func(a, b T) bool) Queue[T] {
return QueueOf(less)
}

// QueueOf returns a heap containing entries ordered by less.
func QueueOf[T any](less func(a, b T) bool, entries ...T) Queue[T] {
q := Queue[T]{
queue: &queue[T]{
entries: make([]T, len(entries)),
less: less,
},
}

copy(q.queue.entries, entries)
heap.Init(q.queue)
return q
}

type Queue[T any] struct {
queue *queue[T]
}

func (q *Queue[T]) Len() int {
return len(q.queue.entries)
}

func (q *Queue[T]) Push(t T) {
heap.Push(q.queue, t)
}

func (q *Queue[T]) Pop() (T, bool) {
if q.Len() == 0 {
return utils.Zero[T](), false
}

return heap.Pop(q.queue).(T), true
}

func (q *Queue[T]) Peek() (T, bool) {
if q.Len() == 0 {
return utils.Zero[T](), false
}

return q.queue.entries[0], true
}

func (q *Queue[T]) Fix(i int) {
heap.Fix(q.queue, i)
}

type queue[T any] struct {
entries []T
less func(a, b T) bool
}

func (q *queue[T]) Len() int {
return len(q.entries)
}

func (q *queue[T]) Less(i, j int) bool {
return q.less(q.entries[i], q.entries[j])
}

func (q *queue[T]) Swap(i, j int) {
q.entries[i], q.entries[j] = q.entries[j], q.entries[i]
}

func (q *queue[T]) Push(e any) {
q.entries = append(q.entries, e.(T))
}

func (q *queue[T]) Pop() any {
end := len(q.entries) - 1
dhrubabasu marked this conversation as resolved.
Show resolved Hide resolved

popped := q.entries[end]
q.entries[end] = utils.Zero[T]()
q.entries = q.entries[:end]

return popped
}
72 changes: 72 additions & 0 deletions utils/heap/queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package heap

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestHeap(t *testing.T) {
tests := []struct {
name string
setup func(h Queue[int])
expected []int
}{
{
name: "only push",
setup: func(h Queue[int]) {
h.Push(1)
h.Push(2)
h.Push(3)
},
expected: []int{1, 2, 3},
},
{
name: "out of order pushes",
setup: func(h Queue[int]) {
h.Push(1)
h.Push(5)
h.Push(2)
h.Push(4)
h.Push(3)
},
expected: []int{1, 2, 3, 4, 5},
},
{
name: "push and pop",
setup: func(h Queue[int]) {
h.Push(1)
h.Push(5)
h.Push(2)
h.Push(4)
h.Push(3)
h.Pop()
h.Pop()
h.Pop()
},
expected: []int{4, 5},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)

h := NewQueue[int](func(a, b int) bool {
return a < b
})

tt.setup(h)

require.Equal(len(tt.expected), h.Len())
for _, expected := range tt.expected {
got, ok := h.Pop()
require.True(ok)
require.Equal(expected, got)
}
})
}
}
62 changes: 20 additions & 42 deletions vms/platformvm/state/merged_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@

package state

import "container/heap"
import "github.com/ava-labs/avalanchego/utils/heap"

var (
_ StakerIterator = (*mergedIterator)(nil)
_ heap.Interface = (*mergedIterator)(nil)
)
var _ StakerIterator = (*mergedIterator)(nil)

type mergedIterator struct {
initialized bool
// heap only contains iterators that have been initialized and are not
// exhausted.
heap []StakerIterator
heap heap.Queue[StakerIterator]
}

// Returns an iterator that returns all of the elements of [stakers] in order.
Expand All @@ -36,15 +33,19 @@ func NewMergedIterator(stakers ...StakerIterator) StakerIterator {
}

it := &mergedIterator{
heap: stakers,
heap: heap.QueueOf(
func(a, b StakerIterator) bool {
return a.Value().Less(b.Value())
},
stakers...,
),
}

heap.Init(it)
return it
}

func (it *mergedIterator) Next() bool {
if len(it.heap) == 0 {
if it.heap.Len() == 0 {
return false
}

Expand All @@ -57,54 +58,31 @@ func (it *mergedIterator) Next() bool {
}

// Update the heap root.
current := it.heap[0]
current, _ := it.heap.Peek()
if current.Next() {
// Calling Next() above modifies [current] so we fix the heap.
heap.Fix(it, 0)
it.heap.Fix(0)
return true
}

// The old root is exhausted. Remove it from the heap.
current.Release()
heap.Pop(it)
return len(it.heap) > 0
it.heap.Pop()
return it.heap.Len() > 0
}

func (it *mergedIterator) Value() *Staker {
return it.heap[0].Value()
peek, _ := it.heap.Peek()
return peek.Value()
}

// When Release() returns, Release() has been called on each element of
// [stakers].
func (it *mergedIterator) Release() {
for _, it := range it.heap {
it.Release()
for it.heap.Len() > 0 {
removed, _ := it.heap.Pop()
removed.Release()
}
it.heap = nil
}

// Returns the number of sub-iterators in [it].
func (it *mergedIterator) Len() int {
return len(it.heap)
}

func (it *mergedIterator) Less(i, j int) bool {
return it.heap[i].Value().Less(it.heap[j].Value())
}

func (it *mergedIterator) Swap(i, j int) {
it.heap[j], it.heap[i] = it.heap[i], it.heap[j]
}

// Push is never actually used - but we need it to implement heap.Interface.
func (it *mergedIterator) Push(value interface{}) {
it.heap = append(it.heap, value.(StakerIterator))
}

func (it *mergedIterator) Pop() interface{} {
newLength := len(it.heap) - 1
value := it.heap[newLength]
it.heap[newLength] = nil
it.heap = it.heap[:newLength]
return value
return it.heap.Len()
}
47 changes: 12 additions & 35 deletions vms/platformvm/state/staker_diff_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
package state

import (
"container/heap"

"github.com/ava-labs/avalanchego/utils/heap"
"github.com/ava-labs/avalanchego/vms/platformvm/txs"
)

var (
_ StakerDiffIterator = (*stakerDiffIterator)(nil)
_ StakerIterator = (*mutableStakerIterator)(nil)
_ heap.Interface = (*mutableStakerIterator)(nil)
)

// StakerDiffIterator is an iterator that iterates over the events that will be
Expand Down Expand Up @@ -114,74 +112,53 @@ func (it *stakerDiffIterator) advancePending() {
type mutableStakerIterator struct {
iteratorExhausted bool
iterator StakerIterator
heap []*Staker
heap heap.Queue[*Staker]
}

func newMutableStakerIterator(iterator StakerIterator) *mutableStakerIterator {
return &mutableStakerIterator{
iteratorExhausted: !iterator.Next(),
iterator: iterator,
heap: heap.NewQueue((*Staker).Less),
}
}

// Add should not be called until after Next has been called at least once.
func (it *mutableStakerIterator) Add(staker *Staker) {
heap.Push(it, staker)
it.heap.Push(staker)
}

func (it *mutableStakerIterator) Next() bool {
// The only time the heap should be empty - is when the iterator is
// exhausted or uninitialized.
if len(it.heap) > 0 {
heap.Pop(it)
if it.heap.Len() > 0 {
it.heap.Pop()
}

// If the iterator is exhausted, the only elements left to iterate over are
// in the heap.
if it.iteratorExhausted {
return len(it.heap) > 0
return it.heap.Len() > 0
}

// If the heap doesn't contain the next staker to return, we need to move
// the next element from the iterator into the heap.
nextIteratorStaker := it.iterator.Value()
if len(it.heap) == 0 || nextIteratorStaker.Less(it.heap[0]) {
peek, ok := it.heap.Peek()
if !ok || nextIteratorStaker.Less(peek) {
it.Add(nextIteratorStaker)
it.iteratorExhausted = !it.iterator.Next()
}
return true
}

func (it *mutableStakerIterator) Value() *Staker {
return it.heap[0]
peek, _ := it.heap.Peek()
return peek
}

func (it *mutableStakerIterator) Release() {
it.iteratorExhausted = true
it.iterator.Release()
it.heap = nil
}

func (it *mutableStakerIterator) Len() int {
return len(it.heap)
}

func (it *mutableStakerIterator) Less(i, j int) bool {
return it.heap[i].Less(it.heap[j])
}

func (it *mutableStakerIterator) Swap(i, j int) {
it.heap[j], it.heap[i] = it.heap[i], it.heap[j]
}

func (it *mutableStakerIterator) Push(value interface{}) {
it.heap = append(it.heap, value.(*Staker))
}

func (it *mutableStakerIterator) Pop() interface{} {
newLength := len(it.heap) - 1
value := it.heap[newLength]
it.heap[newLength] = nil
it.heap = it.heap[:newLength]
return value
it.heap = heap.NewQueue((*Staker).Less)
}
Loading