Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merkledb -- encode lengths as uvarints #2039

Merged
merged 15 commits into from
Sep 19, 2023
2 changes: 1 addition & 1 deletion x/merkledb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Once this is encoded, we `sha256` hash the resulting bytes to get the node's ID.

### Encoding Varints and Bytes

Varints are encoded with `binary.PutVarint` from the standard library's `encoding/binary` package.
Varints are encoded with `binary.PutUvarint` from the standard library's `encoding/binary` package.
Bytes are encoded by simply copying them onto the buffer.

## Design choices
Expand Down
38 changes: 17 additions & 21 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errNegativeNumChildren = errors.New("number of children is negative")
errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor)
errChildIndexTooLarge = fmt.Errorf("invalid child index. Must be less than branching factor of %d", NodeBranchFactor)
errNegativeNibbleLength = errors.New("nibble length is negative")
Expand Down Expand Up @@ -86,6 +85,8 @@ func newCodec() encoderDecoder {
// Note that bytes.Buffer.Write always returns nil so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct {
// Invariant: Every byte slice returned by [varIntPool] has
// length [binary.MaxVarintLen64].
varIntPool sync.Pool
}

Expand All @@ -98,12 +99,12 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
)

c.encodeMaybeByteSlice(buf, n.value)
c.encodeInt(buf, numChildren)
c.encodeUint64(buf, uint64(numChildren))
// Note we insert children in order of increasing index
// for determinism.
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := n.children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint64(buf, uint64(index))
path := entry.compressedPath.Serialize()
c.encodeSerializedPath(buf, path)
_, _ = buf.Write(entry.id[:])
Expand All @@ -121,12 +122,12 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
)

c.encodeInt(buf, numChildren)
c.encodeUint64(buf, uint64(numChildren))

// ensure that the order of entries is consistent
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := hv.Children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint64(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
}
Expand All @@ -149,12 +150,10 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
}
n.value = value

numChildren, err := c.decodeInt(src)
numChildren, err := c.decodeUint(src)
switch {
case err != nil:
return err
case numChildren < 0:
return errNegativeNumChildren
case numChildren > NodeBranchFactor:
return errTooManyChildren
case numChildren > src.Len()/minChildLen:
Expand All @@ -164,7 +163,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
n.children = make(map[byte]child, NodeBranchFactor)
previousChild := -1
for i := 0; i < numChildren; i++ {
index, err := c.decodeInt(src)
index, err := c.decodeUint(src)
if err != nil {
return err
}
Expand Down Expand Up @@ -221,18 +220,15 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (c *codecImpl) encodeInt(dst *bytes.Buffer, value int) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only encode non-negative ints, so this was replaced with `encodeUint64`.

c.encodeInt64(dst, int64(value))
}

func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to decodeUint since we always read uint64s

// decodeUint decodes a uvarint from [src] and returns it as an int.
func (*codecImpl) decodeUint(src *bytes.Reader) (int, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably this should return a uint, but in most of the places where we use this we want the value as an int, so I'm not sure which is better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to just update all of those variables to be uints instead?

// To ensure encoding/decoding is canonical, we need to check for leading
// zeroes in the varint.
// The last byte of the varint we read is the most significant byte.
// If it's 0, then it's a leading zero, which is considered invalid in the
// canonical encoding.
startLen := src.Len()
val64, err := binary.ReadVarint(src)
val64, err := binary.ReadUvarint(src)
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
Expand Down Expand Up @@ -260,9 +256,9 @@ func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
return int(val64), nil
}

func (c *codecImpl) encodeInt64(dst *bytes.Buffer, value int64) {
func (c *codecImpl) encodeUint64(dst *bytes.Buffer, value uint64) {
buf := c.varIntPool.Get().([]byte)
size := binary.PutVarint(buf, value)
size := binary.PutUvarint(buf, value)
_, _ = dst.Write(buf[:size])
c.varIntPool.Put(buf)
}
Expand Down Expand Up @@ -297,7 +293,7 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}

length, err := c.decodeInt(src)
length, err := c.decodeUint(src)
switch {
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
Expand All @@ -320,7 +316,7 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
}

func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeInt(dst, len(value))
c.encodeUint64(dst, uint64(len(value)))
if value != nil {
_, _ = dst.Write(value)
}
Expand All @@ -340,7 +336,7 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
}

func (c *codecImpl) encodeSerializedPath(dst *bytes.Buffer, s SerializedPath) {
c.encodeInt(dst, s.NibbleLength)
c.encodeUint64(dst, uint64(s.NibbleLength))
_, _ = dst.Write(s.Value)
}

Expand All @@ -353,7 +349,7 @@ func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, err
result SerializedPath
err error
)
if result.NibbleLength, err = c.decodeInt(src); err != nil {
if result.NibbleLength, err = c.decodeUint(src); err != nil {
return SerializedPath{}, err
}
if result.NibbleLength < 0 {
Expand Down
16 changes: 3 additions & 13 deletions x/merkledb/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func FuzzCodecInt(f *testing.F) {
codec := codec.(*codecImpl)
reader := bytes.NewReader(b)
startLen := reader.Len()
got, err := codec.decodeInt(reader)
got, err := codec.decodeUint(reader)
if err != nil {
t.SkipNow()
}
Expand All @@ -63,7 +63,7 @@ func FuzzCodecInt(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
codec.encodeInt(&buf, got)
codec.encodeUint64(&buf, uint64(got))
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand Down Expand Up @@ -195,22 +195,12 @@ func TestCodecDecodeDBNode(t *testing.T) {
}

nodeBytes := codec.encodeDBNode(&proof)

// Remove num children (0) from end
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf := bytes.NewBuffer(nodeBytes)
// Put num children -1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, -1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errNegativeNumChildren)

// Remove num children from end
nodeBytes = proofBytesBuf.Bytes()
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf = bytes.NewBuffer(nodeBytes)
// Put num children NodeBranchFactor+1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1)
codec.(*codecImpl).encodeUint64(proofBytesBuf, NodeBranchFactor+1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errTooManyChildren)
Expand Down
Loading