diff --git a/x/merkledb/README.md b/x/merkledb/README.md index 8d15463d7f88..270382bf07d5 100644 --- a/x/merkledb/README.md +++ b/x/merkledb/README.md @@ -210,7 +210,7 @@ Once this is encoded, we `sha256` hash the resulting bytes to get the node's ID. ### Encoding Varints and Bytes -Varints are encoded with `binary.PutVarint` from the standard library's `binary/encoding` package. +Varints are encoded with `binary.PutUvarint` from the standard library's `encoding/binary` package. Bytes are encoded by simply copying them onto the buffer. ## Design choices diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index 62698e16af6d..5c5ad1d6f72a 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -42,16 +42,13 @@ var ( trueBytes = []byte{trueByte} falseBytes = []byte{falseByte} - errNegativeNumChildren = errors.New("number of children is negative") errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor) errChildIndexTooLarge = fmt.Errorf("invalid child index. Must be less than branching factor of %d", NodeBranchFactor) - errNegativeNibbleLength = errors.New("nibble length is negative") - errIntTooLarge = errors.New("integer too large to be decoded") errLeadingZeroes = errors.New("varint has leading zeroes") errInvalidBool = errors.New("decoded bool is neither true nor false") errNonZeroNibblePadding = errors.New("nibbles should be padded with 0s") errExtraSpace = errors.New("trailing buffer space") - errNegativeSliceLength = errors.New("negative slice length") + errIntOverflow = errors.New("value overflows int") ) // encoderDecoder defines the interface needed by merkleDB to marshal @@ -86,6 +83,8 @@ func newCodec() encoderDecoder { // Note that bytes.Buffer.Write always returns nil so we // can ignore its return values in [codecImpl] methods. type codecImpl struct { + // Invariant: Every byte slice returned by [varIntPool] has + // length [binary.MaxVarintLen64]. 
varIntPool sync.Pool } @@ -98,12 +97,12 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte { ) c.encodeMaybeByteSlice(buf, n.value) - c.encodeInt(buf, numChildren) + c.encodeUint(buf, uint64(numChildren)) // Note we insert children in order of increasing index // for determinism. for index := byte(0); index < NodeBranchFactor; index++ { if entry, ok := n.children[index]; ok { - c.encodeInt(buf, int(index)) + c.encodeUint(buf, uint64(index)) path := entry.compressedPath.Serialize() c.encodeSerializedPath(buf, path) _, _ = buf.Write(entry.id[:]) @@ -121,12 +120,12 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte { buf = bytes.NewBuffer(make([]byte, 0, estimatedLen)) ) - c.encodeInt(buf, numChildren) + c.encodeUint(buf, uint64(numChildren)) // ensure that the order of entries is consistent for index := byte(0); index < NodeBranchFactor; index++ { if entry, ok := hv.Children[index]; ok { - c.encodeInt(buf, int(index)) + c.encodeUint(buf, uint64(index)) _, _ = buf.Write(entry.id[:]) } } @@ -149,26 +148,24 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error { } n.value = value - numChildren, err := c.decodeInt(src) + numChildren, err := c.decodeUint(src) switch { case err != nil: return err - case numChildren < 0: - return errNegativeNumChildren case numChildren > NodeBranchFactor: return errTooManyChildren - case numChildren > src.Len()/minChildLen: + case numChildren > uint64(src.Len()/minChildLen): return io.ErrUnexpectedEOF } n.children = make(map[byte]child, NodeBranchFactor) - previousChild := -1 - for i := 0; i < numChildren; i++ { - index, err := c.decodeInt(src) + var previousChild uint64 + for i := uint64(0); i < numChildren; i++ { + index, err := c.decodeUint(src) if err != nil { return err } - if index <= previousChild || index >= NodeBranchFactor { + if index >= NodeBranchFactor || (i != 0 && index <= previousChild) { return errChildIndexTooLarge } previousChild = index @@ -221,25 +218,19 @@ func (*codecImpl) decodeBool(src 
*bytes.Reader) (bool, error) { } } -func (c *codecImpl) encodeInt(dst *bytes.Buffer, value int) { - c.encodeInt64(dst, int64(value)) -} - -func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) { +func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) { // To ensure encoding/decoding is canonical, we need to check for leading // zeroes in the varint. // The last byte of the varint we read is the most significant byte. // If it's 0, then it's a leading zero, which is considered invalid in the // canonical encoding. startLen := src.Len() - val64, err := binary.ReadVarint(src) - switch { - case err == io.EOF: - return 0, io.ErrUnexpectedEOF - case err != nil: + val64, err := binary.ReadUvarint(src) + if err != nil { + if err == io.EOF { + return 0, io.ErrUnexpectedEOF + } return 0, err - case val64 > math.MaxInt: - return 0, errIntTooLarge } endLen := src.Len() @@ -257,12 +248,12 @@ func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) { } } - return int(val64), nil + return val64, nil } -func (c *codecImpl) encodeInt64(dst *bytes.Buffer, value int64) { +func (c *codecImpl) encodeUint(dst *bytes.Buffer, value uint64) { buf := c.varIntPool.Get().([]byte) - size := binary.PutVarint(buf, value) + size := binary.PutUvarint(buf, value) _, _ = dst.Write(buf[:size]) c.varIntPool.Put(buf) } @@ -297,17 +288,15 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) { return nil, io.ErrUnexpectedEOF } - length, err := c.decodeInt(src) + length, err := c.decodeUint(src) switch { case err == io.EOF: return nil, io.ErrUnexpectedEOF case err != nil: return nil, err - case length < 0: - return nil, errNegativeSliceLength case length == 0: return nil, nil - case length > src.Len(): + case length > uint64(src.Len()): return nil, io.ErrUnexpectedEOF } @@ -320,7 +309,7 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) { } func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) { - c.encodeInt(dst, len(value)) + 
c.encodeUint(dst, uint64(len(value))) if value != nil { _, _ = dst.Write(value) } @@ -340,7 +329,7 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) { } func (c *codecImpl) encodeSerializedPath(dst *bytes.Buffer, s SerializedPath) { - c.encodeInt(dst, s.NibbleLength) + c.encodeUint(dst, uint64(s.NibbleLength)) _, _ = dst.Write(s.Value) } @@ -349,15 +338,16 @@ func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, err return SerializedPath{}, io.ErrUnexpectedEOF } - var ( - result SerializedPath - err error - ) - if result.NibbleLength, err = c.decodeInt(src); err != nil { + nibbleLength, err := c.decodeUint(src) + if err != nil { return SerializedPath{}, err } - if result.NibbleLength < 0 { - return SerializedPath{}, errNegativeNibbleLength + if nibbleLength > math.MaxInt { + return SerializedPath{}, errIntOverflow + } + + result := SerializedPath{ + NibbleLength: int(nibbleLength), } pathBytesLen := result.NibbleLength >> 1 hasOddLen := result.hasOddLength() diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go index 7b3c17d58ee1..715d06f69cea 100644 --- a/x/merkledb/codec_test.go +++ b/x/merkledb/codec_test.go @@ -54,7 +54,7 @@ func FuzzCodecInt(f *testing.F) { codec := codec.(*codecImpl) reader := bytes.NewReader(b) startLen := reader.Len() - got, err := codec.decodeInt(reader) + got, err := codec.decodeUint(reader) if err != nil { t.SkipNow() } @@ -63,7 +63,7 @@ func FuzzCodecInt(f *testing.F) { // Encoding [got] should be the same as [b]. 
var buf bytes.Buffer - codec.encodeInt(&buf, got) + codec.encodeUint(&buf, got) bufBytes := buf.Bytes() require.Len(bufBytes, numRead) require.Equal(b[:numRead], bufBytes) @@ -195,22 +195,12 @@ func TestCodecDecodeDBNode(t *testing.T) { } nodeBytes := codec.encodeDBNode(&proof) - // Remove num children (0) from end nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen] proofBytesBuf := bytes.NewBuffer(nodeBytes) - // Put num children -1 at end - codec.(*codecImpl).encodeInt(proofBytesBuf, -1) - - err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode) - require.ErrorIs(err, errNegativeNumChildren) - // Remove num children from end - nodeBytes = proofBytesBuf.Bytes() - nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen] - proofBytesBuf = bytes.NewBuffer(nodeBytes) // Put num children NodeBranchFactor+1 at end - codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1) + codec.(*codecImpl).encodeUint(proofBytesBuf, NodeBranchFactor+1) err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode) require.ErrorIs(err, errTooManyChildren)