diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 5c9ef06f1b..f4e2b6469c 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -5,51 +5,78 @@ package node import ( "bytes" + "fmt" "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" ) -// EncodeAndHash returns the encoding of the node and +// MerkleValue produces the Merkle value from the encoding of a node. +// For root nodes, the Merkle value is always the Blak2b hash of the encoding. +// For other nodes, the Merkle value is either: +// - the encoding if it is less than 32 bytes +// - the Blake2b hash of the encoding +func MerkleValue(encoding []byte, isRoot bool) (merkleValue []byte, err error) { + if !isRoot && len(encoding) < 32 { + merkleValue = make([]byte, len(encoding)) + copy(merkleValue, encoding) + return merkleValue, nil + } + + hashDigest, err := common.Blake2bHash(encoding) + if err != nil { + return nil, err + } + + merkleValue = hashDigest[:] + return merkleValue, nil +} + +// MerkleValue returns the encoding of the node and // the blake2b hash digest of the encoding of the node. // If the encoding is less than 32 bytes, the hash returned // is the encoding and not the hash of the encoding. -func (n *Node) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) { - if !n.Dirty && n.Encoding != nil && n.HashDigest != nil { - return n.Encoding, n.HashDigest, nil +func (n *Node) MerkleValue(isRoot bool) (merkleValue []byte, err error) { + if !n.Dirty && n.HashDigest != nil { + return n.HashDigest, nil } - buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - buffer.Reset() - defer pools.EncodingBuffers.Put(buffer) + _, merkleValue, err = n.EncodeAndHash(isRoot) + return merkleValue, err +} - err = n.Encode(buffer) - if err != nil { - return nil, nil, err +// EncodeAndHash returns the encoding of the node and the +// Merkle value of the node. See the `MerkleValue` +// method for more details on the value of the Merkle value. +func (n *Node) EncodeAndHash(isRoot bool) (encoding, merkleValue []byte, err error) { + if !n.Dirty && n.Encoding != nil && n.HashDigest != nil { + return n.Encoding, n.HashDigest, nil } - bufferBytes := buffer.Bytes() + if n.Dirty || n.Encoding == nil { + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) - // TODO remove this copying since it defeats the purpose of `buffer` - // and the sync.Pool. - n.Encoding = make([]byte, len(bufferBytes)) - copy(n.Encoding, bufferBytes) - encoding = n.Encoding // no need to copy + err = n.Encode(buffer) + if err != nil { + return nil, nil, fmt.Errorf("encode node: %w", err) + } - if !isRoot && buffer.Len() < 32 { - n.HashDigest = make([]byte, len(bufferBytes)) - copy(n.HashDigest, bufferBytes) - hash = n.HashDigest // no need to copy - return encoding, hash, nil + bufferBytes := buffer.Bytes() + + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + n.Encoding = make([]byte, len(bufferBytes)) + copy(n.Encoding, bufferBytes) } + encoding = n.Encoding // no need to copy - // Note: using the sync.Pool's buffer is useful here. - hashArray, err := common.Blake2bHash(buffer.Bytes()) + merkleValue, err = MerkleValue(encoding, isRoot) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("merkle value: %w", err) } - n.HashDigest = hashArray[:] - hash = n.HashDigest // no need to copy + n.HashDigest = merkleValue // no need to copy - return encoding, hash, nil + return encoding, merkleValue, nil } diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go index b2d785342b..ac5c99564a 100644 --- a/internal/trie/node/hash_test.go +++ b/internal/trie/node/hash_test.go @@ -9,6 +9,110 @@ import ( "github.com/stretchr/testify/assert" ) +func Test_MerkleValue(t *testing.T) { + t.Parallel() + + longEncoding := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33} //nolint:lll + longEncodingMerkleValue := []byte{0xfc, 0xd2, 0xd9, 0xac, 0xe8, 0x70, 0x52, 0x81, 0x1d, 0x9f, 0x34, 0x27, 0xb5, 0x8f, 0xf3, 0x98, 0xd2, 0xe9, 0xed, 0x83, 0xf3, 0x1, 0xbc, 0x7e, 0xc1, 0xbe, 0x8b, 0x59, 0x39, 0x62, 0xf1, 0x7d} //nolint:lll + + testCases := map[string]struct { + encoding []byte + isRoot bool + merkleValue []byte + errWrapped error + errMessage string + }{ + "non root small encoding": { + encoding: []byte{1}, + merkleValue: []byte{1}, + }, + "non root long encoding": { + encoding: longEncoding, + merkleValue: longEncodingMerkleValue, + }, + "root small encoding": { + encoding: []byte{1}, + isRoot: true, + merkleValue: []byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25}, + }, + "root long encoding": { + encoding: longEncoding, + isRoot: true, + merkleValue: longEncodingMerkleValue, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + merkleValue, err := MerkleValue(testCase.encoding, testCase.isRoot) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.merkleValue, merkleValue) + }) + } +} + +func Test_Node_MerkleValue(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + node Node + isRoot bool + merkleValue []byte + errWrapped error + errMessage string + }{ + "cached merkle value": { + node: Node{ + HashDigest: []byte{1}, + }, + merkleValue: []byte{1}, + }, + "non root small encoding": { + node: Node{ + Encoding: []byte{1}, + }, + merkleValue: []byte{1}, + }, + "root small encoding": { + node: Node{ + Encoding: []byte{1}, + }, + isRoot: true, + merkleValue: []byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + merkleValue, err := testCase.node.MerkleValue(testCase.isRoot) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.merkleValue, merkleValue) + }) + } +} + func Test_Node_EncodeAndHash(t *testing.T) { t.Parallel() diff --git a/lib/trie/database.go b/lib/trie/database.go index 6c300a8b04..a27d190d1c 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -98,7 +98,8 @@ func (t *Trie) LoadFromProof(proofEncodedNodes [][]byte, rootHash []byte) error decodedNode.Encoding = rawNode decodedNode.HashDigest = nil - _, hash, err := decodedNode.EncodeAndHash(false) + const isRoot = false + hash, err := decodedNode.MerkleValue(isRoot) if err != nil { return fmt.Errorf("cannot encode and hash node at index %d: %w", i, err) } @@ -185,9 +186,10 @@ func (t *Trie) load(db chaindb.Database, n *Node) error { if len(hash) == 0 && child.Type() == node.Leaf { // node has already been loaded inline // just set encoding + hash digest - _, _, err := child.EncodeAndHash(false) + const isRoot = false + _, err := child.MerkleValue(isRoot) if err != nil { - return err + return fmt.Errorf("merkle value: %w", err) } child.SetDirty(false) continue @@ -454,7 +456,7 @@ func (t *Trie) getInsertedNodeHashes(n *Node, hashes map[common.Hash]struct{}) ( return nil } - _, hash, err := n.EncodeAndHash(n == t.root) + hash, err := n.MerkleValue(n == t.root) if err != nil { return fmt.Errorf( "cannot encode and hash node with hash 0x%x: %w",