diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go index 074e7b1c3..a0ba68e21 100644 --- a/tests/fuzzers/stacktrie/trie_fuzzer.go +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "golang.org/x/crypto/sha3" @@ -213,47 +214,47 @@ func (f *fuzzer) fuzz() int { } // Ensure all the nodes are persisted correctly // Need tracked deleted nodes. - // var ( - // nodeset = make(map[string][]byte) // path -> blob - // trieC = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { - // if crypto.Keccak256Hash(blob) != hash { - // panic("invalid node blob") - // } - // if owner != (common.Hash{}) { - // panic("invalid node owner") - // } - // nodeset[string(path)] = common.CopyBytes(blob) - // }) - // checked int - // ) - // for _, kv := range vals { - // trieC.Update(kv.k, kv.v) - // } - // rootC, _ := trieC.Commit() - // if rootA != rootC { - // panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC)) - // } - // trieA, _ = trie.New(trie.TrieID(rootA), dbA) - // iterA := trieA.NodeIterator(nil) - // for iterA.Next(true) { - // if iterA.Hash() == (common.Hash{}) { - // if _, present := nodeset[string(iterA.Path())]; present { - // panic("unexpected tiny node") - // } - // continue - // } - // nodeBlob, present := nodeset[string(iterA.Path())] - // if !present { - // panic("missing node") - // } - // if !bytes.Equal(nodeBlob, iterA.NodeBlob()) { - // panic("node blob is not matched") - // } - // checked += 1 - // } - // if checked != len(nodeset) { - // panic("node number is not matched") - // } + var ( + nodeset = make(map[string][]byte) // path -> blob + trieC = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + if crypto.Keccak256Hash(blob) != hash { + panic("invalid node blob") + } + if owner != (common.Hash{}) { + panic("invalid node owner") + } + nodeset[string(path)] = common.CopyBytes(blob) + }) + checked int + ) + for _, kv := range vals { + trieC.Update(kv.k, kv.v) + } + rootC, _ := trieC.Commit() + if rootA != rootC { + panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC)) + } + trieA, _ = trie.New(trie.TrieID(rootA), dbA) + iterA := trieA.NodeIterator(nil) + for iterA.Next(true) { + if iterA.Hash() == (common.Hash{}) { + if _, present := nodeset[string(iterA.Path())]; present { + panic("unexpected tiny node") + } + continue + } + nodeBlob, present := nodeset[string(iterA.Path())] + if !present { + panic("missing node") + } + if !bytes.Equal(nodeBlob, iterA.NodeBlob()) { + panic("node blob is not matched") + } + checked += 1 + } + if checked != len(nodeset) { + panic("node number is not matched") + } return 1 } diff --git a/trie/iterator.go b/trie/iterator.go index 39a9ebcef..20c4d44fb 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -86,6 +86,10 @@ type NodeIterator interface { // For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. Path() []byte + // NodeBlob returns the rlp-encoded value of the current iterated node. + // If the node is an embedded node in its parent, nil is returned then. + NodeBlob() []byte + // Leaf returns true iff the current node is a leaf node. Leaf() bool @@ -227,6 +231,18 @@ func (it *nodeIterator) Path() []byte { return it.path } +func (it *nodeIterator) NodeBlob() []byte { + if it.Hash() == (common.Hash{}) { + return nil // skip the non-standalone node + } + blob, err := it.resolveBlob(it.Hash().Bytes(), it.Path()) + if err != nil { + it.err = err + return nil + } + return blob +} + func (it *nodeIterator) Error() error { if it.err == errIteratorEnd { return nil @@ -369,6 +385,20 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { return it.trie.reader.node(path, common.BytesToHash(hash)) } +func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { + if it.resolver != nil { + if blob, err := it.resolver.Get(hash); err == nil && len(blob) > 0 { + return blob, nil + } + } + // Retrieve the specified node from the underlying node reader. + // it.trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. + return it.trie.reader.nodeBlob(path, common.BytesToHash(hash)) +} + func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { if hash, ok := st.node.(hashNode); ok { resolved, err := it.resolveHash(hash, path) @@ -557,6 +587,10 @@ func (it *differenceIterator) Path() []byte { return it.b.Path() } +func (it *differenceIterator) NodeBlob() []byte { + return it.b.NodeBlob() +} + func (it *differenceIterator) AddResolver(resolver ethdb.KeyValueStore) { panic("not implemented") } @@ -668,6 +702,10 @@ func (it *unionIterator) Path() []byte { return (*it.items)[0].Path() } +func (it *unionIterator) NodeBlob() []byte { + return (*it.items)[0].NodeBlob() +} + func (it *unionIterator) AddResolver(resolver ethdb.KeyValueStore) { panic("not implemented") } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index d0e9b7f12..6fc6eea78 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -563,3 +563,54 @@ func TestNodeIteratorLargeTrie(t *testing.T) { t.Fatalf("Too many lookups during seek, have %d want %d", have, want) } } + +func TestIteratorNodeBlob(t *testing.T) { + var ( + db = rawdb.NewMemoryDatabase() + triedb = NewDatabase(db) + trie = NewEmpty(triedb) + ) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + all := make(map[string]string) + for _, val := range vals { + all[val.k] = val.v + trie.Update([]byte(val.k), []byte(val.v)) + } + trie.Commit(false) + triedb.Cap(0) + + found := make(map[common.Hash][]byte) + it := trie.NodeIterator(nil) + for it.Next(true) { + if it.Hash() == (common.Hash{}) { + continue + } + found[it.Hash()] = it.NodeBlob() + } + + dbIter := db.NewIterator(nil, nil) + defer dbIter.Release() + + var count int + for dbIter.Next() { + got, present := found[common.BytesToHash(dbIter.Key())] + if !present { + t.Fatalf("Miss trie node %v", dbIter.Key()) + } + if !bytes.Equal(got, dbIter.Value()) { + t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) + } + count += 1 + } + if count != len(found) { + t.Fatal("Find extra trie node via iterator") + } +} diff --git a/trie/utils_test.go b/trie/utils_test.go index d9e229544..011d93967 100644 --- a/trie/utils_test.go +++ b/trie/utils_test.go @@ -173,10 +173,7 @@ func TestTrieTracePrevValue(t *testing.T) { if iter.Hash() == (common.Hash{}) { continue } - blob, err := trie.reader.nodeBlob(iter.Path(), iter.Hash()) - if err != nil { - t.Fatal(err) - } + blob := iter.NodeBlob() seen[string(iter.Path())] = common.CopyBytes(blob) }