Skip to content

Commit

Permalink
reverse check
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Feb 4, 2025
1 parent ddede19 commit da175b8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
13 changes: 7 additions & 6 deletions trie/zk_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,15 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
}
}

func (t *ZkTrie) CountLeaves() uint64 {
func (t *ZkTrie) CountLeaves(cb func(key, value []byte)) uint64 {
root, err := t.ZkTrie.Tree().Root()
if err != nil {
panic("CountLeaves cannot get root")
}
return t.countLeaves(root, 0)
return t.countLeaves(root, cb, 0)
}

func (t *ZkTrie) countLeaves(root *zkt.Hash, depth int) uint64 {
func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int) uint64 {
if root == nil {
return 0
}
Expand All @@ -253,19 +253,20 @@ func (t *ZkTrie) countLeaves(root *zkt.Hash, depth int) uint64 {
}

if rootNode.Type == zktrie.NodeTypeLeaf_New {
cb(rootNode.NodeKey.Bytes(), rootNode.Data())
return 1
} else {
count := make(chan uint64)
if depth < 5 {
go func() {
count <- t.countLeaves(rootNode.ChildL, depth+1)
count <- t.countLeaves(rootNode.ChildL, cb, depth+1)
}()
go func() {
count <- t.countLeaves(rootNode.ChildR, depth+1)
count <- t.countLeaves(rootNode.ChildR, cb, depth+1)
}()
return <-count + <-count
} else {
return t.countLeaves(rootNode.ChildL, depth+1) + t.countLeaves(rootNode.ChildR, depth+1)
return t.countLeaves(rootNode.ChildL, cb, depth+1) + t.countLeaves(rootNode.ChildR, cb, depth+1)
}
}
}
Expand Down
22 changes: 8 additions & 14 deletions trie/zk_trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ type dbs struct {
mptDb *leveldb.Database
}

var accountsLeft = -1
var accountsDone = 0

func checkTrieEquality(t *testing.T, dbs *dbs, zkRoot, mptRoot common.Hash, leafChecker func(*testing.T, *dbs, []byte, []byte)) {
zkTrie, err := NewZkTrie(zkRoot, NewZktrieDatabase(dbs.zkDb))
Expand All @@ -301,18 +301,12 @@ func checkTrieEquality(t *testing.T, dbs *dbs, zkRoot, mptRoot common.Hash, leaf
mptTrie, err := NewSecure(mptRoot, NewDatabaseWithConfig(dbs.mptDb, &Config{Preimages: true}))
require.NoError(t, err)

expectedLeaves := zkTrie.CountLeaves()
trieIt := NewIterator(mptTrie.NodeIterator(nil))
if accountsLeft == -1 {
accountsLeft = int(expectedLeaves)
}

for trieIt.Next() {
expectedLeaves--
preimageKey := mptTrie.GetKey(trieIt.Key)
expectedLeaves := zkTrie.CountLeaves(func(key, value []byte) {
preimageKey := zkTrie.GetKey(key)
require.NotEmpty(t, preimageKey)
leafChecker(t, dbs, zkTrie.Get(preimageKey), mptTrie.Get(preimageKey))
}
leafChecker(t, dbs, value, mptTrie.Get(preimageKey))
})

require.Zero(t, expectedLeaves)
}

Expand All @@ -326,8 +320,8 @@ func checkAccountEquality(t *testing.T, dbs *dbs, zkAccountBytes, mptAccountByte
require.True(t, mptAccount.Balance.Cmp(zkAccount.Balance) == 0)
require.Equal(t, mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash)
checkTrieEquality(t, dbs, common.BytesToHash(zkAccount.Root[:]), common.BytesToHash(mptAccount.Root[:]), checkStorageEquality)
accountsLeft--
t.Log("Accounts left:", accountsLeft)
accountsDone++
t.Log("Accounts done:", accountsDone)
}

func checkStorageEquality(t *testing.T, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
Expand Down

0 comments on commit da175b8

Please sign in to comment.