diff --git a/deepsubtree.go b/deepsubtree.go index 0ccfd102a..b43f48fb0 100644 --- a/deepsubtree.go +++ b/deepsubtree.go @@ -166,7 +166,7 @@ func (dst *DeepSubTree) verifyOperation(operation Operation, key []byte, value [ return err } } - err = dst.AddExistenceProofs(traceOp.Proofs, nil) + err = dst.AddExistenceProofs(traceOp.Proofs, rootHash) if err != nil { return err } @@ -540,11 +540,7 @@ func (dst *DeepSubTree) AddExistenceProofs(existenceProofs []*ics23.ExistencePro return err } } - rootHash, err := dst.GetInitialRootHash() - if err != nil { - return err - } - err = dst.buildTree(rootHash) + err := dst.buildTree(rootHash) if err != nil { return err } diff --git a/mutable_tree.go b/mutable_tree.go index 6db76d25c..03d644623 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -337,7 +337,7 @@ func (tree *MutableTree) recursiveSet(node *Node, key []byte, value []byte, orph newSelf *Node, updated bool, err error, ) { version := tree.version + 1 - + node.addTrace(tree.ImmutableTree, node.key) if node.isLeaf() { if !tree.skipFastStorageUpgrade { tree.addUnsavedAddition(key, NewFastNode(key, value, version)) @@ -500,7 +500,7 @@ func (tree *MutableTree) remove(key []byte) (value []byte, orphaned []*Node, rem // - the orphaned nodes. func (tree *MutableTree) recursiveRemove(node *Node, key []byte, orphans *[]*Node) (newHash []byte, newSelf *Node, newKey []byte, newValue []byte, err error) { version := tree.version + 1 - + node.addTrace(tree.ImmutableTree, node.key) if node.isLeaf() { if bytes.Equal(key, node.key) { *orphans = append(*orphans, node) diff --git a/node.go b/node.go index 77654c3dd..5f5c68326 100644 --- a/node.go +++ b/node.go @@ -180,6 +180,7 @@ func (node *Node) has(t *ImmutableTree, key []byte) (has bool, err error) { // The index is the index in the list of leaf nodes sorted lexicographically by key. The leftmost leaf has index 0. // It's neighbor has index 1 and so on. func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte, err error) { + node.addTrace(t, node.key) if node.isLeaf() { switch bytes.Compare(node.key, key) { case -1: @@ -190,7 +191,6 @@ func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte, return 0, node.value, nil } } - if bytes.Compare(key, node.key) < 0 { leftNode, err := node.getLeftNode(t) if err != nil { @@ -465,11 +465,16 @@ func (node *Node) writeBytes(w io.Writer) error { return nil } +func (node *Node) addTrace(t *ImmutableTree, key []byte) { + if t == nil || t.ndb == nil { + return + } + t.ndb.addTrace(key) +} + func (node *Node) getLeftNode(t *ImmutableTree) (*Node, error) { if node.leftNode != nil { - if t != nil && t.ndb != nil { - t.ndb.addTrace(node.leftNode.key) - } + node.addTrace(t, node.leftNode.key) return node.leftNode, nil } leftNode, err := t.ndb.GetNode(node.leftHash) @@ -482,9 +487,7 @@ func (node *Node) getLeftNode(t *ImmutableTree) (*Node, error) { func (node *Node) getRightNode(t *ImmutableTree) (*Node, error) { if node.rightNode != nil { - if t != nil && t.ndb != nil { - t.ndb.addTrace(node.rightNode.key) - } + node.addTrace(t, node.rightNode.key) return node.rightNode, nil } rightNode, err := t.ndb.GetNode(node.rightHash)