Skip to content

Commit

Permalink
chore(trie): add MerkleValue method and function
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jun 14, 2022
1 parent 2fa5d8a commit acc0652
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 31 deletions.
81 changes: 54 additions & 27 deletions internal/trie/node/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,78 @@ package node

import (
"bytes"
"fmt"

"github.com/ChainSafe/gossamer/internal/trie/pools"
"github.com/ChainSafe/gossamer/lib/common"
)

// EncodeAndHash returns the encoding of the node and
// MerkleValue produces the Merkle value from the encoding of a node.
// For root nodes, the Merkle value is always the Blak2b hash of the encoding.
// For other nodes, the Merkle value is either:
// - the encoding if it is less than 32 bytes
// - the Blake2b hash of the encoding
func MerkleValue(encoding []byte, isRoot bool) (merkleValue []byte, err error) {
if !isRoot && len(encoding) < 32 {
merkleValue = make([]byte, len(encoding))
copy(merkleValue, encoding)
return merkleValue, nil
}

hashDigest, err := common.Blake2bHash(encoding)
if err != nil {
return nil, err
}

merkleValue = hashDigest[:]
return merkleValue, nil
}

// MerkleValue returns the encoding of the node and
// the blake2b hash digest of the encoding of the node.
// If the encoding is less than 32 bytes, the hash returned
// is the encoding and not the hash of the encoding.
func (n *Node) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) {
if !n.Dirty && n.Encoding != nil && n.HashDigest != nil {
return n.Encoding, n.HashDigest, nil
func (n *Node) MerkleValue(isRoot bool) (merkleValue []byte, err error) {
if !n.Dirty && n.HashDigest != nil {
return n.HashDigest, nil
}

buffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
buffer.Reset()
defer pools.EncodingBuffers.Put(buffer)
_, merkleValue, err = n.EncodeAndHash(isRoot)
return merkleValue, err
}

err = n.Encode(buffer)
if err != nil {
return nil, nil, err
// EncodeAndHash returns the encoding of the node and the
// Merkle value of the node. See the `MerkleValue`
// method for more details on the value of the Merkle value.
func (n *Node) EncodeAndHash(isRoot bool) (encoding, merkleValue []byte, err error) {
if !n.Dirty && n.Encoding != nil && n.HashDigest != nil {
return n.Encoding, n.HashDigest, nil
}

bufferBytes := buffer.Bytes()
if n.Dirty || n.Encoding == nil {
buffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
buffer.Reset()
defer pools.EncodingBuffers.Put(buffer)

// TODO remove this copying since it defeats the purpose of `buffer`
// and the sync.Pool.
n.Encoding = make([]byte, len(bufferBytes))
copy(n.Encoding, bufferBytes)
encoding = n.Encoding // no need to copy
err = n.Encode(buffer)
if err != nil {
return nil, nil, fmt.Errorf("encode node: %w", err)
}

if !isRoot && buffer.Len() < 32 {
n.HashDigest = make([]byte, len(bufferBytes))
copy(n.HashDigest, bufferBytes)
hash = n.HashDigest // no need to copy
return encoding, hash, nil
bufferBytes := buffer.Bytes()

// TODO remove this copying since it defeats the purpose of `buffer`
// and the sync.Pool.
n.Encoding = make([]byte, len(bufferBytes))
copy(n.Encoding, bufferBytes)
}
encoding = n.Encoding // no need to copy

// Note: using the sync.Pool's buffer is useful here.
hashArray, err := common.Blake2bHash(buffer.Bytes())
merkleValue, err = MerkleValue(encoding, isRoot)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("merkle value: %w", err)
}
n.HashDigest = hashArray[:]
hash = n.HashDigest // no need to copy
n.HashDigest = merkleValue // no need to copy

return encoding, hash, nil
return encoding, merkleValue, nil
}
104 changes: 104 additions & 0 deletions internal/trie/node/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,110 @@ import (
"github.com/stretchr/testify/assert"
)

func Test_MerkleValue(t *testing.T) {
t.Parallel()

longEncoding := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33} //nolint:lll
longEncodingMerkleValue := []byte{0xfc, 0xd2, 0xd9, 0xac, 0xe8, 0x70, 0x52, 0x81, 0x1d, 0x9f, 0x34, 0x27, 0xb5, 0x8f, 0xf3, 0x98, 0xd2, 0xe9, 0xed, 0x83, 0xf3, 0x1, 0xbc, 0x7e, 0xc1, 0xbe, 0x8b, 0x59, 0x39, 0x62, 0xf1, 0x7d} //nolint:lll

testCases := map[string]struct {
encoding []byte
isRoot bool
merkleValue []byte
errWrapped error
errMessage string
}{
"non root small encoding": {
encoding: []byte{1},
merkleValue: []byte{1},
},
"non root long encoding": {
encoding: longEncoding,
merkleValue: longEncodingMerkleValue,
},
"root small encoding": {
encoding: []byte{1},
isRoot: true,
merkleValue: []byte{
0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20,
0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2,
0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef,
0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25},
},
"root long encoding": {
encoding: longEncoding,
isRoot: true,
merkleValue: longEncodingMerkleValue,
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

merkleValue, err := MerkleValue(testCase.encoding, testCase.isRoot)

assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.merkleValue, merkleValue)
})
}
}

func Test_Node_MerkleValue(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
node Node
isRoot bool
merkleValue []byte
errWrapped error
errMessage string
}{
"cached merkle value": {
node: Node{
HashDigest: []byte{1},
},
merkleValue: []byte{1},
},
"non root small encoding": {
node: Node{
Encoding: []byte{1},
},
merkleValue: []byte{1},
},
"root small encoding": {
node: Node{
Encoding: []byte{1},
},
isRoot: true,
merkleValue: []byte{
0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20,
0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2,
0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef,
0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

merkleValue, err := testCase.node.MerkleValue(testCase.isRoot)

assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.merkleValue, merkleValue)
})
}
}

func Test_Node_EncodeAndHash(t *testing.T) {
t.Parallel()

Expand Down
10 changes: 6 additions & 4 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ func (t *Trie) LoadFromProof(proofEncodedNodes [][]byte, rootHash []byte) error
decodedNode.Encoding = rawNode
decodedNode.HashDigest = nil

_, hash, err := decodedNode.EncodeAndHash(false)
const isRoot = false
hash, err := decodedNode.MerkleValue(isRoot)
if err != nil {
return fmt.Errorf("cannot encode and hash node at index %d: %w", i, err)
}
Expand Down Expand Up @@ -185,9 +186,10 @@ func (t *Trie) load(db chaindb.Database, n *Node) error {
if len(hash) == 0 && child.Type() == node.Leaf {
// node has already been loaded inline
// just set encoding + hash digest
_, _, err := child.EncodeAndHash(false)
const isRoot = false
_, err := child.MerkleValue(isRoot)
if err != nil {
return err
return fmt.Errorf("merkle value: %w", err)
}
child.SetDirty(false)
continue
Expand Down Expand Up @@ -454,7 +456,7 @@ func (t *Trie) getInsertedNodeHashes(n *Node, hashes map[common.Hash]struct{}) (
return nil
}

_, hash, err := n.EncodeAndHash(n == t.root)
hash, err := n.MerkleValue(n == t.root)
if err != nil {
return fmt.Errorf(
"cannot encode and hash node with hash 0x%x: %w",
Expand Down

0 comments on commit acc0652

Please sign in to comment.