Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merkledb -- encode lengths as uvarints #2039

Merged
merged 15 commits into from
Sep 19, 2023
2 changes: 1 addition & 1 deletion x/merkledb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Once this is encoded, we `sha256` hash the resulting bytes to get the node's ID.

### Encoding Varints and Bytes

Varints are encoded with `binary.PutVarint` from the standard library's `binary/encoding` package.
Varints are encoded with `binary.PutUvarint` from the standard library's `encoding/binary` package.
Bytes are encoded by simply copying them onto the buffer.

## Design choices
Expand Down
77 changes: 34 additions & 43 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errNegativeNumChildren = errors.New("number of children is negative")
errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor)
errChildIndexTooLarge = fmt.Errorf("invalid child index. Must be less than branching factor of %d", NodeBranchFactor)
errNegativeNibbleLength = errors.New("nibble length is negative")
errIntTooLarge = errors.New("integer too large to be decoded")
errLeadingZeroes = errors.New("varint has leading zeroes")
errInvalidBool = errors.New("decoded bool is neither true nor false")
errNonZeroNibblePadding = errors.New("nibbles should be padded with 0s")
errExtraSpace = errors.New("trailing buffer space")
errNegativeSliceLength = errors.New("negative slice length")
errIntOverflow = errors.New("value overflows int")
)

// encoderDecoder defines the interface needed by merkleDB to marshal
Expand Down Expand Up @@ -86,6 +83,8 @@ func newCodec() encoderDecoder {
// Note that bytes.Buffer.Write always returns nil so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct {
// Invariant: Every byte slice returned by [varIntPool] has
// length [binary.MaxVarintLen64].
varIntPool sync.Pool
}

Expand All @@ -98,12 +97,12 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
)

c.encodeMaybeByteSlice(buf, n.value)
c.encodeInt(buf, numChildren)
c.encodeUint(buf, uint64(numChildren))
// Note we insert children in order of increasing index
// for determinism.
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := n.children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint(buf, uint64(index))
path := entry.compressedPath.Serialize()
c.encodeSerializedPath(buf, path)
_, _ = buf.Write(entry.id[:])
Expand All @@ -121,12 +120,12 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
)

c.encodeInt(buf, numChildren)
c.encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := hv.Children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
}
Expand All @@ -149,26 +148,24 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
}
n.value = value

numChildren, err := c.decodeInt(src)
numChildren, err := c.decodeUint(src)
switch {
case err != nil:
return err
case numChildren < 0:
return errNegativeNumChildren
case numChildren > NodeBranchFactor:
return errTooManyChildren
case numChildren > src.Len()/minChildLen:
case numChildren > uint64(src.Len()/minChildLen):
return io.ErrUnexpectedEOF
}

n.children = make(map[byte]child, NodeBranchFactor)
previousChild := -1
for i := 0; i < numChildren; i++ {
index, err := c.decodeInt(src)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
if err != nil {
return err
}
if index <= previousChild || index >= NodeBranchFactor {
if index >= NodeBranchFactor || (i != 0 && index <= previousChild) {
return errChildIndexTooLarge
}
previousChild = index
Expand Down Expand Up @@ -221,25 +218,19 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (c *codecImpl) encodeInt(dst *bytes.Buffer, value int) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only encode non-negative ints so this was replaced with encodeUint64

c.encodeInt64(dst, int64(value))
}

func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to decodeUint since we always read uint64s

func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
// To ensure encoding/decoding is canonical, we need to check for leading
// zeroes in the varint.
// The last byte of the varint we read is the most significant byte.
// If it's 0, then it's a leading zero, which is considered invalid in the
// canonical encoding.
startLen := src.Len()
val64, err := binary.ReadVarint(src)
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
case err != nil:
val64, err := binary.ReadUvarint(src)
if err != nil {
if err == io.EOF {
return 0, io.ErrUnexpectedEOF
}
return 0, err
case val64 > math.MaxInt:
return 0, errIntTooLarge
}
endLen := src.Len()

Expand All @@ -257,12 +248,12 @@ func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
}
}

return int(val64), nil
return val64, nil
}

func (c *codecImpl) encodeInt64(dst *bytes.Buffer, value int64) {
// encodeUint appends the canonical uvarint encoding of [value] to [dst].
// A scratch buffer is borrowed from [varIntPool] to avoid a per-call
// allocation; by the pool's invariant it is always large enough
// (length binary.MaxVarintLen64) to hold any uint64 encoding.
func (c *codecImpl) encodeUint(dst *bytes.Buffer, value uint64) {
	scratch := c.varIntPool.Get().([]byte)
	n := binary.PutUvarint(scratch, value)
	// bytes.Buffer.Write never returns a non-nil error, so the
	// return values are deliberately ignored.
	_, _ = dst.Write(scratch[:n])
	c.varIntPool.Put(scratch)
}
Expand Down Expand Up @@ -297,14 +288,13 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}

length, err := c.decodeInt(src)
length64, err := c.decodeUint(src)
length := int(length64)
switch {
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
case err != nil:
return nil, err
case length < 0:
return nil, errNegativeSliceLength
case length == 0:
return nil, nil
case length > src.Len():
Expand All @@ -320,7 +310,7 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
}

func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeInt(dst, len(value))
c.encodeUint(dst, uint64(len(value)))
if value != nil {
_, _ = dst.Write(value)
}
Expand All @@ -340,7 +330,7 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
}

// encodeSerializedPath appends [s] to [dst]: first the nibble length as a
// uvarint, then the raw path bytes verbatim.
func (c *codecImpl) encodeSerializedPath(dst *bytes.Buffer, s SerializedPath) {
	nibbleLen := uint64(s.NibbleLength)
	c.encodeUint(dst, nibbleLen)
	// bytes.Buffer.Write never returns a non-nil error; ignore its results.
	_, _ = dst.Write(s.Value)
}

Expand All @@ -349,15 +339,16 @@ func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, err
return SerializedPath{}, io.ErrUnexpectedEOF
}

var (
result SerializedPath
err error
)
if result.NibbleLength, err = c.decodeInt(src); err != nil {
nibbleLength, err := c.decodeUint(src)
if err != nil {
return SerializedPath{}, err
}
if result.NibbleLength < 0 {
return SerializedPath{}, errNegativeNibbleLength
if nibbleLength > math.MaxInt {
return SerializedPath{}, errIntOverflow
}

result := SerializedPath{
NibbleLength: int(nibbleLength),
}
pathBytesLen := result.NibbleLength >> 1
hasOddLen := result.hasOddLength()
Expand Down
16 changes: 3 additions & 13 deletions x/merkledb/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func FuzzCodecInt(f *testing.F) {
codec := codec.(*codecImpl)
reader := bytes.NewReader(b)
startLen := reader.Len()
got, err := codec.decodeInt(reader)
got, err := codec.decodeUint(reader)
if err != nil {
t.SkipNow()
}
Expand All @@ -63,7 +63,7 @@ func FuzzCodecInt(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
codec.encodeInt(&buf, got)
codec.encodeUint(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand Down Expand Up @@ -195,22 +195,12 @@ func TestCodecDecodeDBNode(t *testing.T) {
}

nodeBytes := codec.encodeDBNode(&proof)

// Remove num children (0) from end
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf := bytes.NewBuffer(nodeBytes)
// Put num children -1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, -1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errNegativeNumChildren)

// Remove num children from end
nodeBytes = proofBytesBuf.Bytes()
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf = bytes.NewBuffer(nodeBytes)
// Put num children NodeBranchFactor+1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1)
codec.(*codecImpl).encodeUint(proofBytesBuf, NodeBranchFactor+1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errTooManyChildren)
Expand Down
Loading