Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MerkleDB Reduce buffer creation/memcopy on path construction #2124

Merged
merged 16 commits into from
Oct 16, 2023
Merged
98 changes: 50 additions & 48 deletions x/merkledb/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,7 @@ func (p Path) Token(index int) byte {
// Path with [token] appended to the end.
func (p Path) Append(token byte) Path {
buffer := make([]byte, p.bytesNeeded(p.tokensLength+1))
copy(buffer, p.value)
// Shift [token] to the left such that it's at the correct
// index within its storage byte, then OR it with its storage
// byte to write the token into the byte.
buffer[len(buffer)-1] |= token << p.bitsToShift(p.tokensLength)
p.appendIntoBuffer(buffer, token)
return Path{
value: byteSliceToString(buffer),
tokensLength: p.tokensLength + 1,
Expand Down Expand Up @@ -216,49 +212,6 @@ func (p Path) bytesNeeded(tokens int) int {
return size
}

// Extend returns a new Path that equals the passed Path appended to the current Path
func (p Path) Extend(path Path) Path {
if p.tokensLength == 0 {
return path
}
if path.tokensLength == 0 {
return p
}

totalLength := p.tokensLength + path.tokensLength

// copy existing value into the buffer
buffer := make([]byte, p.bytesNeeded(totalLength))
copy(buffer, p.value)

// If the existing value fits into a whole number of bytes,
// the extension path can be copied directly into the buffer.
if !p.hasPartialByte() {
copy(buffer[len(p.value):], path.value)
return Path{
value: byteSliceToString(buffer),
tokensLength: totalLength,
pathConfig: p.pathConfig,
}
}

// The existing path doesn't fit into a whole number of bytes.
// Figure out how many bits to shift.
shift := p.bitsToShift(p.tokensLength - 1)
// Fill the partial byte with the first [shift] bits of the extension path
buffer[len(p.value)-1] |= path.value[0] >> (8 - shift)

// copy the rest of the extension path bytes into the buffer,
// shifted byte shift bits
shiftCopy(buffer[len(p.value):], path.value, shift)

return Path{
value: byteSliceToString(buffer),
tokensLength: totalLength,
pathConfig: p.pathConfig,
}
}

// Treats [src] as a bit array and copies it into [dst] shifted by [shift] bits.
// For example, if [src] is [0b0000_0001, 0b0000_0010] and [shift] is 4,
// we copy [0b0001_0000, 0b0010_0000] into [dst].
Expand Down Expand Up @@ -306,6 +259,55 @@ func (p Path) Skip(tokensToSkip int) Path {
return result
}

func (p Path) AppendExtend(token byte, path Path) Path {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer just Extend

appendBytes := p.bytesNeeded(p.tokensLength + 1)
totalLength := p.tokensLength + 1 + path.tokensLength
buffer := make([]byte, p.bytesNeeded(totalLength))
p.appendIntoBuffer(buffer[:appendBytes], token)

// the extension path will be shifted based on the number of tokens in the partial byte
tokenRemainder := (p.tokensLength + 1) % p.tokensPerByte
path.extendIntoBuffer(tokenRemainder, buffer[appendBytes-1:])
dboehm-avalabs marked this conversation as resolved.
Show resolved Hide resolved

return Path{
value: byteSliceToString(buffer),
tokensLength: totalLength,
pathConfig: p.pathConfig,
}
}

func (p Path) appendIntoBuffer(buffer []byte, token byte) {
copy(buffer, p.value)

// Shift [token] to the left such that it's at the correct
// index within its storage byte, then OR it with its storage
// byte to write the token into the byte.
buffer[len(buffer)-1] |= token << p.bitsToShift(p.tokensLength)
}

func (p Path) extendIntoBuffer(tokenRemainder int, buffer []byte) {
dboehm-avalabs marked this conversation as resolved.
Show resolved Hide resolved
if p.tokensLength == 0 {
return
}

// If the existing value fits into a whole number of bytes,
// the extension path can be copied directly into the buffer.
if tokenRemainder == 0 {
copy(buffer[1:], p.value)
return
}

// The existing path doesn't fit into a whole number of bytes.
// Figure out how many bits to shift.
shift := p.bitsToShift(tokenRemainder - 1)
// Fill the partial byte with the first [shift] bits of the extension path
buffer[0] |= p.value[0] >> (8 - shift)

// copy the rest of the extension path bytes into the buffer,
// shifted byte shift bits
shiftCopy(buffer[1:], p.value, shift)
}

// Take returns a new Path that contains the first tokensToTake tokens of the current Path
func (p Path) Take(tokensToTake int) Path {
if p.tokensLength == tokensToTake {
Expand Down
83 changes: 46 additions & 37 deletions x/merkledb/path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,70 +275,76 @@ func Test_Path_Append(t *testing.T) {
}
}

func Test_Path_Extend(t *testing.T) {
func Test_Path_AppendExtend(t *testing.T) {
require := require.New(t)

path2 := NewPath([]byte{0b1000_0000}, BranchFactor2).Take(1)
p := NewPath([]byte{0b01010101}, BranchFactor2)
extendedP := path2.Extend(p)
require.Equal([]byte{0b10101010, 0b1000_0000}, extendedP.Bytes())
extendedP := path2.AppendExtend(0, p)
require.Equal([]byte{0b10010101, 0b01000_000}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
require.Equal(byte(0), extendedP.Token(3))
require.Equal(byte(1), extendedP.Token(4))
require.Equal(byte(0), extendedP.Token(5))
require.Equal(byte(1), extendedP.Token(6))
require.Equal(byte(0), extendedP.Token(7))
require.Equal(byte(1), extendedP.Token(8))

p = NewPath([]byte{0b01010101, 0b1000_0000}, BranchFactor2).Take(9)
extendedP = path2.Extend(p)
require.Equal([]byte{0b10101010, 0b1100_0000}, extendedP.Bytes())
require.Equal(byte(0), extendedP.Token(2))
require.Equal(byte(1), extendedP.Token(3))
require.Equal(byte(0), extendedP.Token(4))
require.Equal(byte(1), extendedP.Token(5))
require.Equal(byte(0), extendedP.Token(6))
require.Equal(byte(1), extendedP.Token(7))
require.Equal(byte(0), extendedP.Token(8))
require.Equal(byte(1), extendedP.Token(9))

p = NewPath([]byte{0b0101_0101, 0b1000_0000}, BranchFactor2).Take(9)
extendedP = path2.AppendExtend(0, p)
require.Equal([]byte{0b1001_0101, 0b0110_0000}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
require.Equal(byte(0), extendedP.Token(3))
require.Equal(byte(1), extendedP.Token(4))
require.Equal(byte(0), extendedP.Token(5))
require.Equal(byte(1), extendedP.Token(6))
require.Equal(byte(0), extendedP.Token(7))
require.Equal(byte(1), extendedP.Token(8))
require.Equal(byte(0), extendedP.Token(2))
require.Equal(byte(1), extendedP.Token(3))
require.Equal(byte(0), extendedP.Token(4))
require.Equal(byte(1), extendedP.Token(5))
require.Equal(byte(0), extendedP.Token(6))
require.Equal(byte(1), extendedP.Token(7))
require.Equal(byte(0), extendedP.Token(8))
require.Equal(byte(1), extendedP.Token(9))
require.Equal(byte(1), extendedP.Token(10))

path4 := NewPath([]byte{0b0100_0000}, BranchFactor4).Take(1)
p = NewPath([]byte{0b0101_0101}, BranchFactor4)
extendedP = path4.Extend(p)
require.Equal([]byte{0b0101_0101, 0b0100_0000}, extendedP.Bytes())
extendedP = path4.AppendExtend(0, p)
require.Equal([]byte{0b0100_0101, 0b0101_0000}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(1), extendedP.Token(1))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
require.Equal(byte(1), extendedP.Token(3))
require.Equal(byte(1), extendedP.Token(4))
require.Equal(byte(1), extendedP.Token(5))

path16 := NewPath([]byte{0b0001_0000}, BranchFactor16).Take(1)
p = NewPath([]byte{0b0001_0001}, BranchFactor16)
extendedP = path16.Extend(p)
require.Equal([]byte{0b0001_0001, 0b0001_0000}, extendedP.Bytes())
extendedP = path16.AppendExtend(0, p)
require.Equal([]byte{0b0001_0000, 0b0001_0001}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(1), extendedP.Token(1))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
require.Equal(byte(1), extendedP.Token(3))

p = NewPath([]byte{0b0001_0001, 0b0001_0001}, BranchFactor16)
extendedP = path16.Extend(p)
require.Equal([]byte{0b0001_0001, 0b0001_0001, 0b0001_0000}, extendedP.Bytes())
extendedP = path16.AppendExtend(0, p)
require.Equal([]byte{0b0001_0000, 0b0001_0001, 0b0001_0001}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(1), extendedP.Token(1))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
require.Equal(byte(1), extendedP.Token(3))
require.Equal(byte(1), extendedP.Token(4))
require.Equal(byte(1), extendedP.Token(5))

path256 := NewPath([]byte{0b0000_0001}, BranchFactor256)
p = NewPath([]byte{0b0000_0001}, BranchFactor256)
extendedP = path256.Extend(p)
require.Equal([]byte{0b0000_0001, 0b0000_0001}, extendedP.Bytes())
extendedP = path256.AppendExtend(0, p)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add tests where we don't pass 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is covered in fuzz

require.Equal([]byte{0b0000_0001, 0b0000_0000, 0b0000_0001}, extendedP.Bytes())
require.Equal(byte(1), extendedP.Token(0))
require.Equal(byte(1), extendedP.Token(1))
require.Equal(byte(0), extendedP.Token(1))
require.Equal(byte(1), extendedP.Token(2))
}

func TestPathBytesNeeded(t *testing.T) {
Expand Down Expand Up @@ -458,11 +464,12 @@ func TestPathBytesNeeded(t *testing.T) {
}
}

func FuzzPathExtend(f *testing.F) {
func FuzzPathAppendExtend(f *testing.F) {
f.Fuzz(func(
t *testing.T,
first []byte,
second []byte,
token byte,
forceFirstOdd bool,
forceSecondOdd bool,
) {
Expand All @@ -476,13 +483,15 @@ func FuzzPathExtend(f *testing.F) {
if forceSecondOdd && path2.tokensLength > 0 {
path2 = path2.Take(path2.tokensLength - 1)
}
extendedP := path1.Extend(path2)
require.Equal(path1.tokensLength+path2.tokensLength, extendedP.tokensLength)
token = byte(int(token) % int(branchFactor))
extendedP := path1.AppendExtend(token, path2)
require.Equal(path1.tokensLength+path2.tokensLength+1, extendedP.tokensLength)
for i := 0; i < path1.tokensLength; i++ {
require.Equal(path1.Token(i), extendedP.Token(i))
}
require.Equal(token, extendedP.Token(path1.tokensLength))
for i := 0; i < path2.tokensLength; i++ {
require.Equal(path2.Token(i), extendedP.Token(i+path1.tokensLength))
require.Equal(path2.Token(i), extendedP.Token(i+1+path1.tokensLength))
}
}
})
Expand Down
2 changes: 1 addition & 1 deletion x/merkledb/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ func addPathInfo(
if existingChild, ok := n.children[index]; ok {
compressedPath = existingChild.compressedPath
}
childPath := keyPath.Append(index).Extend(compressedPath)
childPath := keyPath.AppendExtend(index, compressedPath)
if (shouldInsertLeftChildren && childPath.Less(insertChildrenLessThan.Value())) ||
(shouldInsertRightChildren && childPath.Greater(insertChildrenGreaterThan.Value())) {
// We didn't set the other values on the child entry, but it doesn't matter.
Expand Down
2 changes: 1 addition & 1 deletion x/merkledb/trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ func Test_Trie_ConcurrentNewViewAndCommit(t *testing.T) {
// Assumes this node has exactly one child.
func getSingleChildPath(n *node) Path {
for index, entry := range n.children {
return n.key.Append(index).Extend(entry.compressedPath)
return n.key.AppendExtend(index, entry.compressedPath)
}
return Path{}
}
Expand Down
6 changes: 3 additions & 3 deletions x/merkledb/trieview.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (t *trieView) calculateNodeIDsHelper(n *node) {
)

for childIndex, child := range n.children {
childPath := n.key.Append(childIndex).Extend(child.compressedPath)
childPath := n.key.AppendExtend(childIndex, child.compressedPath)
childNodeChange, ok := t.changes.nodes[childPath]
if !ok {
// This child wasn't changed.
Expand Down Expand Up @@ -367,7 +367,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) {

childNode, err := t.getNodeWithID(
child.id,
closestNode.key.Append(nextIndex).Extend(child.compressedPath),
closestNode.key.AppendExtend(nextIndex, child.compressedPath),
child.hasValue,
)
if err != nil {
Expand Down Expand Up @@ -694,7 +694,7 @@ func (t *trieView) compressNodePath(parent, node *node) error {
// "Cycle" over the key/values to find the only child.
// Note this iteration once because len(node.children) == 1.
for index, entry := range node.children {
childPath = node.key.Append(index).Extend(entry.compressedPath)
childPath = node.key.AppendExtend(index, entry.compressedPath)
childEntry = entry
}

Expand Down
Loading