Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests, trie: use slices package for sorting #27496

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions tests/fuzzers/rangeproof/rangeproof-fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,19 @@ import (
"encoding/binary"
"fmt"
"io"
"sort"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/trie"
"golang.org/x/exp/slices"
)

type kv struct {
k, v []byte
t bool
}

type entrySlice []*kv

func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

type fuzzer struct {
input io.Reader
exhausted bool
Expand Down Expand Up @@ -97,14 +91,16 @@ func (f *fuzzer) fuzz() int {
if f.exhausted {
return 0 // input too short
}
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
if len(entries) <= 1 {
return 0
}
sort.Sort(entries)
slices.SortFunc(entries, func(a, b *kv) bool {
return bytes.Compare(a.k, b.k) < 0
})

var ok = 0
for {
Expand Down
21 changes: 5 additions & 16 deletions tests/fuzzers/stacktrie/trie_fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"fmt"
"hash"
"io"
"sort"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
"golang.org/x/crypto/sha3"
"golang.org/x/exp/slices"
)

type fuzzer struct {
Expand Down Expand Up @@ -104,19 +104,6 @@ func (b *spongeBatch) Replay(w ethdb.KeyValueWriter) error { return nil }
type kv struct {
k, v []byte
}
type kvs []kv

func (k kvs) Len() int {
return len(k)
}

func (k kvs) Less(i, j int) bool {
return bytes.Compare(k[i].k, k[j].k) < 0
}

func (k kvs) Swap(i, j int) {
k[j], k[i] = k[i], k[j]
}

// Fuzz is the fuzzing entry-point.
// The function must return
Expand Down Expand Up @@ -156,7 +143,7 @@ func (f *fuzzer) fuzz() int {
trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(spongeB, owner, path, hash, blob, dbB.Scheme())
})
vals kvs
vals []kv
useful bool
maxElements = 10000
// operate on unique keys only
Expand Down Expand Up @@ -192,7 +179,9 @@ func (f *fuzzer) fuzz() int {
dbA.Commit(rootA, false)

// Stacktrie requires sorted insertion
sort.Sort(vals)
slices.SortFunc(vals, func(a, b kv) bool {
return bytes.Compare(a.k, b.k) < 0
})
for _, kv := range vals {
if f.debugging {
fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v)
Expand Down
4 changes: 4 additions & 0 deletions trie/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ type kv struct {
t bool
}

func (k *kv) less(other *kv) bool {
return bytes.Compare(k.k, other.k) < 0
}

func TestIteratorLargeData(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
Expand Down
72 changes: 33 additions & 39 deletions trie/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import (
"encoding/binary"
"fmt"
mrand "math/rand"
"sort"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"golang.org/x/exp/slices"
)

// Prng is a pseudo random number generator seeded by strong randomness.
Expand Down Expand Up @@ -165,21 +165,15 @@ func TestMissingKeyProof(t *testing.T) {
}
}

type entrySlice []*kv

func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

// TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)
for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1
Expand Down Expand Up @@ -208,11 +202,11 @@ func TestRangeProof(t *testing.T) {
// The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)
for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1
Expand Down Expand Up @@ -280,11 +274,11 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

// Case 1
start, end := 100, 200
Expand Down Expand Up @@ -337,11 +331,11 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// non-existent one.
func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

// One element with existent edge proof, both edge proofs
// point to the SAME key.
Expand Down Expand Up @@ -424,11 +418,11 @@ func TestOneElementRangeProof(t *testing.T) {
// The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var k [][]byte
var v [][]byte
Expand Down Expand Up @@ -474,13 +468,13 @@ func TestAllElementsProof(t *testing.T) {
func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
var entries []*kv
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
Expand Down Expand Up @@ -509,13 +503,13 @@ func TestSingleSideRangeProof(t *testing.T) {
func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
var entries []*kv
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
Expand Down Expand Up @@ -545,11 +539,11 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries))
Expand Down Expand Up @@ -648,11 +642,11 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

pos := 1000
first := decreaseKey(common.CopyBytes(entries[pos].k))
Expand Down Expand Up @@ -690,13 +684,13 @@ func TestSameSideProofs(t *testing.T) {

func TestHasRightElement(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
var entries []*kv
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v)
entries = append(entries, value)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var cases = []struct {
start int
Expand Down Expand Up @@ -764,11 +758,11 @@ func TestHasRightElement(t *testing.T) {
// The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(4096)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var cases = []struct {
pos int
Expand Down Expand Up @@ -799,11 +793,11 @@ func TestEmptyRangeProof(t *testing.T) {
func TestBloatedProof(t *testing.T) {
// Use a small trie
trie, kvs := nonRandomTrie(100)
var entries entrySlice
var entries []*kv
for _, kv := range kvs {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)
var keys [][]byte
var vals [][]byte

Expand Down Expand Up @@ -833,11 +827,11 @@ func TestBloatedProof(t *testing.T) {
// noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512)
var entries entrySlice
var entries []*kv
for _, kv := range values {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

// Create a new entry with a slightly modified key
mid := len(entries) / 2
Expand Down Expand Up @@ -877,11 +871,11 @@ func TestEmptyValueRangeProof(t *testing.T) {
// practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512)
var entries entrySlice
var entries []*kv
for _, kv := range values {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

// Create a new entry with a slightly modified key
mid := len(entries) / 2
Expand Down Expand Up @@ -983,11 +977,11 @@ func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b,

func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(8192)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

start := 2
end := start + size
Expand Down Expand Up @@ -1020,11 +1014,11 @@ func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof

func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(size)
var entries entrySlice
var entries []*kv
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
slices.SortFunc(entries, (*kv).less)

var keys [][]byte
var values [][]byte
Expand Down
8 changes: 5 additions & 3 deletions trie/trienode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ package trienode

import (
"fmt"
"sort"
"strings"

"github.com/ethereum/go-ethereum/common"
"golang.org/x/exp/slices"
)

// Node is a wrapper which contains the encoded blob of the trie node and its
Expand Down Expand Up @@ -100,12 +100,14 @@ func NewNodeSet(owner common.Hash) *NodeSet {
// ForEachWithOrder iterates the nodes with the order from bottom to top,
// right to left, nodes with the longest path will be iterated first.
func (set *NodeSet) ForEachWithOrder(callback func(path string, n *Node)) {
var paths sort.StringSlice
var paths []string
for path := range set.Nodes {
paths = append(paths, path)
}
// Bottom-up, longest path first
sort.Sort(sort.Reverse(paths))
slices.SortFunc(paths, func(a, b string) bool {
return a > b // Sort in reverse order
})
for _, path := range paths {
callback(path, set.Nodes[path].Unwrap())
}
Expand Down