Skip to content

Commit

Permalink
merkledb -- encode lengths as uvarints (#2039)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Laine authored Sep 19, 2023
1 parent 856df85 commit bd83641
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 58 deletions.
2 changes: 1 addition & 1 deletion x/merkledb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Once this is encoded, we `sha256` hash the resulting bytes to get the node's ID.

### Encoding Varints and Bytes

Varints are encoded with `binary.PutVarint` from the standard library's `binary/encoding` package.
Varints are encoded with `binary.PutUvarint` from the standard library's `encoding/binary` package.
Bytes are encoded by simply copying them onto the buffer.

## Design choices
Expand Down
78 changes: 34 additions & 44 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errNegativeNumChildren = errors.New("number of children is negative")
errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor)
errChildIndexTooLarge = fmt.Errorf("invalid child index. Must be less than branching factor of %d", NodeBranchFactor)
errNegativeNibbleLength = errors.New("nibble length is negative")
errIntTooLarge = errors.New("integer too large to be decoded")
errLeadingZeroes = errors.New("varint has leading zeroes")
errInvalidBool = errors.New("decoded bool is neither true nor false")
errNonZeroNibblePadding = errors.New("nibbles should be padded with 0s")
errExtraSpace = errors.New("trailing buffer space")
errNegativeSliceLength = errors.New("negative slice length")
errIntOverflow = errors.New("value overflows int")
)

// encoderDecoder defines the interface needed by merkleDB to marshal
Expand Down Expand Up @@ -86,6 +83,8 @@ func newCodec() encoderDecoder {
// Note that bytes.Buffer.Write always returns nil so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct {
// Invariant: Every byte slice returned by [varIntPool] has
// length [binary.MaxVarintLen64].
varIntPool sync.Pool
}

Expand All @@ -98,12 +97,12 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
)

c.encodeMaybeByteSlice(buf, n.value)
c.encodeInt(buf, numChildren)
c.encodeUint(buf, uint64(numChildren))
// Note we insert children in order of increasing index
// for determinism.
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := n.children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint(buf, uint64(index))
path := entry.compressedPath.Serialize()
c.encodeSerializedPath(buf, path)
_, _ = buf.Write(entry.id[:])
Expand All @@ -121,12 +120,12 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
)

c.encodeInt(buf, numChildren)
c.encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := hv.Children[index]; ok {
c.encodeInt(buf, int(index))
c.encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
}
Expand All @@ -149,26 +148,24 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
}
n.value = value

numChildren, err := c.decodeInt(src)
numChildren, err := c.decodeUint(src)
switch {
case err != nil:
return err
case numChildren < 0:
return errNegativeNumChildren
case numChildren > NodeBranchFactor:
return errTooManyChildren
case numChildren > src.Len()/minChildLen:
case numChildren > uint64(src.Len()/minChildLen):
return io.ErrUnexpectedEOF
}

n.children = make(map[byte]child, NodeBranchFactor)
previousChild := -1
for i := 0; i < numChildren; i++ {
index, err := c.decodeInt(src)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
if err != nil {
return err
}
if index <= previousChild || index >= NodeBranchFactor {
if index >= NodeBranchFactor || (i != 0 && index <= previousChild) {
return errChildIndexTooLarge
}
previousChild = index
Expand Down Expand Up @@ -221,25 +218,19 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (c *codecImpl) encodeInt(dst *bytes.Buffer, value int) {
c.encodeInt64(dst, int64(value))
}

func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
// To ensure encoding/decoding is canonical, we need to check for leading
// zeroes in the varint.
// The last byte of the varint we read is the most significant byte.
// If it's 0, then it's a leading zero, which is considered invalid in the
// canonical encoding.
startLen := src.Len()
val64, err := binary.ReadVarint(src)
switch {
case err == io.EOF:
return 0, io.ErrUnexpectedEOF
case err != nil:
val64, err := binary.ReadUvarint(src)
if err != nil {
if err == io.EOF {
return 0, io.ErrUnexpectedEOF
}
return 0, err
case val64 > math.MaxInt:
return 0, errIntTooLarge
}
endLen := src.Len()

Expand All @@ -257,12 +248,12 @@ func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
}
}

return int(val64), nil
return val64, nil
}

func (c *codecImpl) encodeInt64(dst *bytes.Buffer, value int64) {
func (c *codecImpl) encodeUint(dst *bytes.Buffer, value uint64) {
buf := c.varIntPool.Get().([]byte)
size := binary.PutVarint(buf, value)
size := binary.PutUvarint(buf, value)
_, _ = dst.Write(buf[:size])
c.varIntPool.Put(buf)
}
Expand Down Expand Up @@ -297,17 +288,15 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}

length, err := c.decodeInt(src)
length, err := c.decodeUint(src)
switch {
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
case err != nil:
return nil, err
case length < 0:
return nil, errNegativeSliceLength
case length == 0:
return nil, nil
case length > src.Len():
case length > uint64(src.Len()):
return nil, io.ErrUnexpectedEOF
}

Expand All @@ -320,7 +309,7 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
}

func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeInt(dst, len(value))
c.encodeUint(dst, uint64(len(value)))
if value != nil {
_, _ = dst.Write(value)
}
Expand All @@ -340,7 +329,7 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
}

func (c *codecImpl) encodeSerializedPath(dst *bytes.Buffer, s SerializedPath) {
c.encodeInt(dst, s.NibbleLength)
c.encodeUint(dst, uint64(s.NibbleLength))
_, _ = dst.Write(s.Value)
}

Expand All @@ -349,15 +338,16 @@ func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, err
return SerializedPath{}, io.ErrUnexpectedEOF
}

var (
result SerializedPath
err error
)
if result.NibbleLength, err = c.decodeInt(src); err != nil {
nibbleLength, err := c.decodeUint(src)
if err != nil {
return SerializedPath{}, err
}
if result.NibbleLength < 0 {
return SerializedPath{}, errNegativeNibbleLength
if nibbleLength > math.MaxInt {
return SerializedPath{}, errIntOverflow
}

result := SerializedPath{
NibbleLength: int(nibbleLength),
}
pathBytesLen := result.NibbleLength >> 1
hasOddLen := result.hasOddLength()
Expand Down
16 changes: 3 additions & 13 deletions x/merkledb/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func FuzzCodecInt(f *testing.F) {
codec := codec.(*codecImpl)
reader := bytes.NewReader(b)
startLen := reader.Len()
got, err := codec.decodeInt(reader)
got, err := codec.decodeUint(reader)
if err != nil {
t.SkipNow()
}
Expand All @@ -63,7 +63,7 @@ func FuzzCodecInt(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
codec.encodeInt(&buf, got)
codec.encodeUint(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand Down Expand Up @@ -195,22 +195,12 @@ func TestCodecDecodeDBNode(t *testing.T) {
}

nodeBytes := codec.encodeDBNode(&proof)

// Remove num children (0) from end
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf := bytes.NewBuffer(nodeBytes)
// Put num children -1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, -1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errNegativeNumChildren)

// Remove num children from end
nodeBytes = proofBytesBuf.Bytes()
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf = bytes.NewBuffer(nodeBytes)
// Put num children NodeBranchFactor+1 at end
codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1)
codec.(*codecImpl).encodeUint(proofBytesBuf, NodeBranchFactor+1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errTooManyChildren)
Expand Down

0 comments on commit bd83641

Please sign in to comment.