Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor the node struct #748

Merged
merged 2 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type KVPairReceiver func(pair *KVPair) error
//
// The algorithm don't run in constant memory strictly, but it tried the best the only
// keep minimal intermediate states in memory.
func (ndb *nodeDB) extractStateChanges(prevVersion int64, prevRoot *NodeKey, root *NodeKey, receiver KVPairReceiver) error {
func (ndb *nodeDB) extractStateChanges(prevVersion int64, prevRoot, root []byte, receiver KVPairReceiver) error {
curIter, err := NewNodeIterator(root, ndb)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions import.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (i *Importer) writeNode(node *Node) error {
bytesCopy := make([]byte, buf.Len())
copy(bytesCopy, buf.Bytes())

if err := i.batch.Set(i.tree.ndb.nodeKey(node.nodeKey), bytesCopy); err != nil {
if err := i.batch.Set(i.tree.ndb.nodeKey(node.GetKey()), bytesCopy); err != nil {
return err
}

Expand Down Expand Up @@ -135,8 +135,8 @@ func (i *Importer) Add(exportNode *ExportNode) error {
} else if stackSize >= 2 && i.stack[stackSize-1].subtreeHeight < node.subtreeHeight && i.stack[stackSize-2].subtreeHeight < node.subtreeHeight {
node.leftNode = i.stack[stackSize-2]
node.rightNode = i.stack[stackSize-1]
node.leftNodeKey = node.leftNode.nodeKey
node.rightNodeKey = node.rightNode.nodeKey
node.leftNodeKey = node.leftNode.GetKey()
node.rightNodeKey = node.rightNode.GetKey()
node.size = node.leftNode.size + node.rightNode.size
// Update the stack now.
if err := i.writeNode(i.stack[stackSize-2]); err != nil {
Expand Down Expand Up @@ -169,7 +169,7 @@ func (i *Importer) Commit() error {

switch len(i.stack) {
case 0:
if err := i.batch.Set(i.tree.ndb.nodeKey(&NodeKey{version: i.version, nonce: 1}), []byte{}); err != nil {
if err := i.batch.Set(i.tree.ndb.nodeKey(GetRootKey(i.version)), []byte{}); err != nil {
return err
}
case 1:
Expand All @@ -178,7 +178,7 @@ func (i *Importer) Commit() error {
return err
}
if i.stack[0].nodeKey.version < i.version { // it means there is no update in the given version
if err := i.batch.Set(i.tree.ndb.nodeKey(&NodeKey{version: i.version, nonce: 1}), i.tree.ndb.nodeKey(i.stack[0].nodeKey)); err != nil {
if err := i.batch.Set(i.tree.ndb.nodeKey(GetRootKey(i.version)), i.tree.ndb.nodeKey(i.stack[0].GetKey())); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ type NodeIterator struct {
}

// NewNodeIterator returns a new NodeIterator to traverse the tree of the root node.
func NewNodeIterator(rootKey *NodeKey, ndb *nodeDB) (*NodeIterator, error) {
func NewNodeIterator(rootKey []byte, ndb *nodeDB) (*NodeIterator, error) {
if rootKey == nil {
return &NodeIterator{
nodesToVisit: []*Node{},
Expand Down
4 changes: 2 additions & 2 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ func TestNodeIterator_Success(t *testing.T) {
require.NoError(t, err)

// check if the iterating count is same with the entire node count of the tree
itr, err := NewNodeIterator(tree.root.nodeKey, tree.ndb)
itr, err := NewNodeIterator(tree.root.GetKey(), tree.ndb)
require.NoError(t, err)
nodeCount := 0
for ; itr.Valid(); itr.Next(false) {
Expand All @@ -353,7 +353,7 @@ func TestNodeIterator_Success(t *testing.T) {
require.Equal(t, int64(nodeCount), tree.Size()*2-1)

// check if the skipped node count is right
itr, err = NewNodeIterator(tree.root.nodeKey, tree.ndb)
itr, err = NewNodeIterator(tree.root.GetKey(), tree.ndb)
require.NoError(t, err)
updateCount := 0
skipCount := 0
Expand Down
8 changes: 4 additions & 4 deletions mutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -973,10 +973,10 @@ func (tree *MutableTree) balance(node *Node) (newSelf *Node, err error) {
func (tree *MutableTree) saveNewNodes(version int64) error {
nonce := int32(0)
newNodes := make([]*Node, 0)
var recursiveAssignKey func(*Node) (*NodeKey, error)
recursiveAssignKey = func(node *Node) (*NodeKey, error) {
var recursiveAssignKey func(*Node) ([]byte, error)
recursiveAssignKey = func(node *Node) ([]byte, error) {
if node.nodeKey != nil {
return node.nodeKey, nil
return node.nodeKey.GetKey(), nil
}
nonce++
node.nodeKey = &NodeKey{
Expand All @@ -1002,7 +1002,7 @@ func (tree *MutableTree) saveNewNodes(version int64) error {
if err != nil {
return nil, err
}
return node.nodeKey, nil
return node.nodeKey.GetKey(), nil
}

if _, err := recursiveAssignKey(tree.root); err != nil {
Expand Down
12 changes: 6 additions & 6 deletions mutable_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ func TestUpgradeStorageToFast_DbErrorConstructor_Failure(t *testing.T) {

// rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk
rIterMock.EXPECT().Valid().Return(true).Times(1)
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(1))
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(GetRootKey(1)))
rIterMock.EXPECT().Close().Return(nil).Times(1)

expectedError := errors.New("some db error")
Expand All @@ -832,7 +832,7 @@ func TestUpgradeStorageToFast_DbErrorEnableFastStorage_Failure(t *testing.T) {

// rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk
rIterMock.EXPECT().Valid().Return(true).Times(1)
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(1))
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(GetRootKey(1)))
rIterMock.EXPECT().Close().Return(nil).Times(1)

expectedError := errors.New("some db error")
Expand Down Expand Up @@ -883,7 +883,7 @@ func TestFastStorageReUpgradeProtection_NoForceUpgrade_Success(t *testing.T) {

// rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk
rIterMock.EXPECT().Valid().Return(true).Times(1)
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(1))
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(GetRootKey(1)))
rIterMock.EXPECT().Close().Return(nil).Times(1)

batchMock := mock.NewMockBatch(ctrl)
Expand Down Expand Up @@ -946,7 +946,7 @@ func TestFastStorageReUpgradeProtection_ForceUpgradeFirstTime_NoForceSecondTime_

// rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk
rIterMock.EXPECT().Valid().Return(true).Times(1)
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(latestTreeVersion))
rIterMock.EXPECT().Key().Return(nodeKeyFormat.Key(GetRootKey(latestTreeVersion)))
rIterMock.EXPECT().Close().Return(nil).Times(1)

fastNodeKeyToDelete := []byte("some_key")
Expand Down Expand Up @@ -1447,7 +1447,7 @@ func TestMutableTree_InitialVersion_FirstVersion(t *testing.T) {
_, version, err := tree.SaveVersion()
require.NoError(t, err)
require.Equal(t, initialVersion, version)
rootKey := &NodeKey{version: version, nonce: 1}
rootKey := GetRootKey(version)
// the nodes created at the first version are not assigned with the `InitialVersion`
node, err := tree.ndb.GetNode(rootKey)
require.NoError(t, err)
Expand All @@ -1459,7 +1459,7 @@ func TestMutableTree_InitialVersion_FirstVersion(t *testing.T) {
_, version, err = tree.SaveVersion()
require.NoError(t, err)
require.Equal(t, initialVersion+1, version)
rootKey = &NodeKey{version: version, nonce: 1}
rootKey = GetRootKey(version)
// the following versions behaves normally
node, err = tree.ndb.GetNode(rootKey)
require.NoError(t, err)
Expand Down
58 changes: 40 additions & 18 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,38 @@ type NodeKey struct {
nonce int32
}

// GetKey returns a byte slice of the NodeKey.
func (nk *NodeKey) GetKey() []byte {
b := make([]byte, 12)
binary.BigEndian.PutUint64(b, uint64(nk.version))
binary.BigEndian.PutUint32(b[8:], uint32(nk.nonce))
return b
}

// GetNodeKey returns a NodeKey from a byte slice.
func GetNodeKey(key []byte) *NodeKey {
return &NodeKey{
version: int64(binary.BigEndian.Uint64(key)),
nonce: int32(binary.BigEndian.Uint32(key[8:])),
}
}

// GetRootKey returns a byte slice of the root node key for the given version.
func GetRootKey(version int64) []byte {
b := make([]byte, 12)
binary.BigEndian.PutUint64(b, uint64(version))
binary.BigEndian.PutUint32(b[8:], 1)
return b
}

// Node represents a node in a Tree.
type Node struct {
key []byte
value []byte
hash []byte
nodeKey *NodeKey
leftNodeKey *NodeKey
rightNodeKey *NodeKey
leftNodeKey []byte
rightNodeKey []byte
size int64
leftNode *Node
rightNode *Node
Expand All @@ -57,11 +74,16 @@ func NewNode(key []byte, value []byte) *Node {
}
}

// GetKey returns the key of the node.
func (node *Node) GetKey() []byte {
return node.nodeKey.GetKey()
}

// MakeNode constructs an *Node from an encoded byte slice.
//
// The new node doesn't have its hash saved or set. The caller must set it
// afterwards.
func MakeNode(nodeKey *NodeKey, buf []byte) (*Node, error) {
func MakeNode(nk []byte, buf []byte) (*Node, error) {
// Read node header (height, size, key).
height, n, cause := encoding.DecodeVarint(buf)
if cause != nil {
Expand All @@ -87,7 +109,7 @@ func MakeNode(nodeKey *NodeKey, buf []byte) (*Node, error) {
node := &Node{
subtreeHeight: int8(height),
size: size,
nodeKey: nodeKey,
nodeKey: GetNodeKey(nk),
key: key,
}

Expand Down Expand Up @@ -143,16 +165,12 @@ func MakeNode(nodeKey *NodeKey, buf []byte) (*Node, error) {
}
rightNodeKey.nonce = int32(nonce)

node.leftNodeKey = &leftNodeKey
node.rightNodeKey = &rightNodeKey
node.leftNodeKey = leftNodeKey.GetKey()
node.rightNodeKey = rightNodeKey.GetKey()
}
return node, nil
}

func (node *Node) GetKey() []byte {
return node.nodeKey.GetKey()
}

// String returns a string representation of the node key.
func (nk *NodeKey) String() string {
return fmt.Sprintf("(%d, %d)", nk.version, nk.nonce)
Expand Down Expand Up @@ -456,12 +474,14 @@ func (node *Node) encodedSize() int {
} else {
n += encoding.EncodeBytesSize(node.hash)
if node.leftNodeKey != nil {
n += encoding.EncodeVarintSize(node.leftNodeKey.version) +
encoding.EncodeVarintSize(int64(node.leftNodeKey.nonce))
nk := GetNodeKey(node.leftNodeKey)
n += encoding.EncodeVarintSize(nk.version) +
encoding.EncodeVarintSize(int64(nk.nonce))
}
if node.rightNodeKey != nil {
n += encoding.EncodeVarintSize(node.rightNodeKey.version) +
encoding.EncodeVarintSize(int64(node.rightNodeKey.nonce))
nk := GetNodeKey(node.rightNodeKey)
n += encoding.EncodeVarintSize(nk.version) +
encoding.EncodeVarintSize(int64(nk.nonce))
}
}
return n
Expand Down Expand Up @@ -500,23 +520,25 @@ func (node *Node) writeBytes(w io.Writer) error {
if node.leftNodeKey == nil {
return ErrLeftNodeKeyEmpty
}
cause = encoding.EncodeVarint(w, node.leftNodeKey.version)
leftNodeKey := GetNodeKey(node.leftNodeKey)
cause = encoding.EncodeVarint(w, leftNodeKey.version)
if cause != nil {
return fmt.Errorf("writing the version of left node key, %w", cause)
}
cause = encoding.EncodeVarint(w, int64(node.leftNodeKey.nonce))
cause = encoding.EncodeVarint(w, int64(leftNodeKey.nonce))
if cause != nil {
return fmt.Errorf("writing the nonce of left node key, %w", cause)
}

if node.rightNodeKey == nil {
return ErrRightNodeKeyEmpty
}
cause = encoding.EncodeVarint(w, node.rightNodeKey.version)
rightNodeKey := GetNodeKey(node.rightNodeKey)
cause = encoding.EncodeVarint(w, rightNodeKey.version)
if cause != nil {
return fmt.Errorf("writing the version of right node key, %w", cause)
}
cause = encoding.EncodeVarint(w, int64(node.rightNodeKey.nonce))
cause = encoding.EncodeVarint(w, int64(rightNodeKey.nonce))
if cause != nil {
return fmt.Errorf("writing the nonce of right node key, %w", cause)
}
Expand Down
Loading