From 99c82666cb1021d99e15d78d450e841e1361c87b Mon Sep 17 00:00:00 2001 From: Manav Aggarwal Date: Fri, 28 Oct 2022 12:37:33 +0200 Subject: [PATCH] Fix GetSiblingNode --- deepsubtree.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/deepsubtree.go b/deepsubtree.go index 3f296e21f..6f93a52cb 100644 --- a/deepsubtree.go +++ b/deepsubtree.go @@ -62,7 +62,7 @@ func (tree *ImmutableTree) GetSiblingNode(key []byte) (*Node, error) { func (tree *ImmutableTree) recursiveGetSiblingNode(node *Node, key []byte) (*Node, error) { if node == nil || node.isLeaf() { - return nil, fmt.Errorf("no sibling node found for key: %s", key) + return nil, nil } leftNode, err := node.getLeftNode(tree) if err != nil { @@ -72,14 +72,19 @@ func (tree *ImmutableTree) recursiveGetSiblingNode(node *Node, key []byte) (*Nod if err != nil { return nil, err } - if leftNode != nil && bytes.Equal(leftNode.key, key) { + if leftNode != nil && leftNode.isLeaf() && bytes.Equal(leftNode.key, key) { return rightNode, nil } - if rightNode != nil && bytes.Equal(rightNode.key, key) { + if rightNode != nil && rightNode.isLeaf() && bytes.Equal(rightNode.key, key) { return leftNode, nil } - if bytes.Compare(node.key, key) < 0 { - return tree.recursiveGetSiblingNode(leftNode, key) + + siblingNode, err := tree.recursiveGetSiblingNode(leftNode, key) + if err != nil { + return nil, err + } + if siblingNode != nil { + return siblingNode, nil } return tree.recursiveGetSiblingNode(rightNode, key) }