Skip to content

Commit

Permalink
merkledb -- codec remove err checks (#1899)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Laine authored Aug 23, 2023
1 parent 5caabf1 commit cf6eedc
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 116 deletions.
110 changes: 36 additions & 74 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errEncodeNil = errors.New("can't encode nil pointer or interface")
errDecodeNil = errors.New("can't decode nil")
errNegativeNumChildren = errors.New("number of children is negative")
errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor)
Expand All @@ -55,8 +54,10 @@ type encoderDecoder interface {
}

type encoder interface {
encodeDBNode(n *dbNode) ([]byte, error)
encodeHashValues(hv *hashValues) ([]byte, error)
// Assumes [n] is non-nil.
encodeDBNode(n *dbNode) []byte
// Assumes [hv] is non-nil.
encodeHashValues(hv *hashValues) []byte
}

type decoder interface {
Expand All @@ -73,71 +74,45 @@ func newCodec() encoderDecoder {
}
}

// Note that bytes.Buffer.Write always returns nil so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct {
varIntPool sync.Pool
}

func (c *codecImpl) encodeDBNode(n *dbNode) ([]byte, error) {
if n == nil {
return nil, errEncodeNil
}

func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
buf := &bytes.Buffer{}
if err := c.encodeMaybeByteSlice(buf, n.value); err != nil {
return nil, err
}
c.encodeMaybeByteSlice(buf, n.value)
childrenLength := len(n.children)
if err := c.encodeInt(buf, childrenLength); err != nil {
return nil, err
}
c.encodeInt(buf, childrenLength)
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := n.children[index]; ok {
if err := c.encodeInt(buf, int(index)); err != nil {
return nil, err
}
c.encodeInt(buf, int(index))
path := entry.compressedPath.Serialize()
if err := c.encodeSerializedPath(path, buf); err != nil {
return nil, err
}
if _, err := buf.Write(entry.id[:]); err != nil {
return nil, err
}
c.encodeSerializedPath(path, buf)
_, _ = buf.Write(entry.id[:])
}
}
return buf.Bytes(), nil
return buf.Bytes()
}

func (c *codecImpl) encodeHashValues(hv *hashValues) ([]byte, error) {
if hv == nil {
return nil, errEncodeNil
}

func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
buf := &bytes.Buffer{}

length := len(hv.Children)
if err := c.encodeInt(buf, length); err != nil {
return nil, err
}
c.encodeInt(buf, length)

// ensure that the order of entries is consistent
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := hv.Children[index]; ok {
if err := c.encodeInt(buf, int(index)); err != nil {
return nil, err
}
if _, err := buf.Write(entry.id[:]); err != nil {
return nil, err
}
c.encodeInt(buf, int(index))
_, _ = buf.Write(entry.id[:])
}
}
if err := c.encodeMaybeByteSlice(buf, hv.Value); err != nil {
return nil, err
}
if err := c.encodeSerializedPath(hv.Key, buf); err != nil {
return nil, err
}
c.encodeMaybeByteSlice(buf, hv.Value)
c.encodeSerializedPath(hv.Key, buf)

return buf.Bytes(), nil
return buf.Bytes()
}

func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
Expand Down Expand Up @@ -201,13 +176,12 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
return err
}

func (*codecImpl) encodeBool(dst io.Writer, value bool) error {
func (*codecImpl) encodeBool(dst *bytes.Buffer, value bool) {
bytesValue := falseBytes
if value {
bytesValue = trueBytes
}
_, err := dst.Write(bytesValue)
return err
_, _ = dst.Write(bytesValue)
}

func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
Expand All @@ -228,8 +202,8 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (c *codecImpl) encodeInt(dst io.Writer, value int) error {
return c.encodeInt64(dst, int64(value))
func (c *codecImpl) encodeInt(dst *bytes.Buffer, value int) {
c.encodeInt64(dst, int64(value))
}

func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
Expand Down Expand Up @@ -267,22 +241,18 @@ func (*codecImpl) decodeInt(src *bytes.Reader) (int, error) {
return int(val64), nil
}

func (c *codecImpl) encodeInt64(dst io.Writer, value int64) error {
func (c *codecImpl) encodeInt64(dst *bytes.Buffer, value int64) {
buf := c.varIntPool.Get().([]byte)
size := binary.PutVarint(buf, value)
_, err := dst.Write(buf[:size])
_, _ = dst.Write(buf[:size])
c.varIntPool.Put(buf)
return err
}

func (c *codecImpl) encodeMaybeByteSlice(dst io.Writer, maybeValue maybe.Maybe[[]byte]) error {
if err := c.encodeBool(dst, !maybeValue.IsNothing()); err != nil {
return err
}
if maybeValue.IsNothing() {
return nil
func (c *codecImpl) encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
c.encodeBool(dst, !maybeValue.IsNothing())
if maybeValue.HasValue() {
c.encodeByteSlice(dst, maybeValue.Value())
}
return c.encodeByteSlice(dst, maybeValue.Value())
}

func (c *codecImpl) decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
Expand Down Expand Up @@ -338,16 +308,11 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return result, nil
}

func (c *codecImpl) encodeByteSlice(dst io.Writer, value []byte) error {
if err := c.encodeInt(dst, len(value)); err != nil {
return err
}
func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeInt(dst, len(value))
if value != nil {
if _, err := dst.Write(value); err != nil {
return err
}
_, _ = dst.Write(value)
}
return nil
}

func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
Expand All @@ -365,12 +330,9 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
return id, nil
}

func (c *codecImpl) encodeSerializedPath(s SerializedPath, dst io.Writer) error {
if err := c.encodeInt(dst, s.NibbleLength); err != nil {
return err
}
_, err := dst.Write(s.Value)
return err
func (c *codecImpl) encodeSerializedPath(s SerializedPath, dst *bytes.Buffer) {
c.encodeInt(dst, s.NibbleLength)
_, _ = dst.Write(s.Value)
}

func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, error) {
Expand Down
22 changes: 9 additions & 13 deletions x/merkledb/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func FuzzCodecBool(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
require.NoError(codec.encodeBool(&buf, got))
codec.encodeBool(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand Down Expand Up @@ -157,7 +157,7 @@ func FuzzCodecInt(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
require.NoError(codec.encodeInt(&buf, got))
codec.encodeInt(&buf, got)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand Down Expand Up @@ -185,7 +185,7 @@ func FuzzCodecSerializedPath(f *testing.F) {

// Encoding [got] should be the same as [b].
var buf bytes.Buffer
require.NoError(codec.encodeSerializedPath(got, &buf))
codec.encodeSerializedPath(got, &buf)
bufBytes := buf.Bytes()
require.Len(bufBytes, numRead)
require.Equal(b[:numRead], bufBytes)
Expand All @@ -211,8 +211,7 @@ func FuzzCodecDBNodeCanonical(f *testing.F) {
}

// Encoding [node] should be the same as [b].
buf, err := codec.encodeDBNode(node)
require.NoError(err)
buf := codec.encodeDBNode(node)
require.Equal(b, buf)
},
)
Expand Down Expand Up @@ -264,8 +263,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
children: children,
}

nodeBytes, err := codec.encodeDBNode(&node)
require.NoError(err)
nodeBytes := codec.encodeDBNode(&node)

var gotNode dbNode
require.NoError(codec.decodeDBNode(nodeBytes, &gotNode))
Expand All @@ -274,8 +272,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
nilEmptySlices(&gotNode)
require.Equal(node, gotNode)

nodeBytes2, err := codec.encodeDBNode(&gotNode)
require.NoError(err)
nodeBytes2 := codec.encodeDBNode(&gotNode)
require.Equal(nodeBytes, nodeBytes2)
},
)
Expand All @@ -299,14 +296,13 @@ func TestCodec_DecodeDBNode(t *testing.T) {
children: map[byte]child{},
}

nodeBytes, err := codec.encodeDBNode(&proof)
require.NoError(err)
nodeBytes := codec.encodeDBNode(&proof)

// Remove num children (0) from end
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf := bytes.NewBuffer(nodeBytes)
// Put num children -1 at end
require.NoError(codec.(*codecImpl).encodeInt(proofBytesBuf, -1))
codec.(*codecImpl).encodeInt(proofBytesBuf, -1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errNegativeNumChildren)
Expand All @@ -316,7 +312,7 @@ func TestCodec_DecodeDBNode(t *testing.T) {
nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
proofBytesBuf = bytes.NewBuffer(nodeBytes)
// Put num children NodeBranchFactor+1 at end
require.NoError(codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1))
codec.(*codecImpl).encodeInt(proofBytesBuf, NodeBranchFactor+1)

err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode)
require.ErrorIs(err, errTooManyChildren)
Expand Down
12 changes: 2 additions & 10 deletions x/merkledb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -830,11 +830,7 @@ func (db *merkleDB) onEviction(n *node) error {

// Writes [n] to [batch]. Assumes [n] is non-nil.
func writeNodeToBatch(batch database.Batch, n *node) error {
nodeBytes, err := n.marshal()
if err != nil {
return err
}

nodeBytes := n.marshal()
return batch.Put(n.key.Bytes(), nodeBytes)
}

Expand Down Expand Up @@ -1167,12 +1163,8 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
}

// write the newly constructed root to the DB
rootBytes, err := db.root.marshal()
if err != nil {
return ids.Empty, err
}

batch := db.nodeDB.NewBatch()
rootBytes := db.root.marshal()
if err := batch.Put(rootKey, rootBytes); err != nil {
return ids.Empty, err
}
Expand Down
19 changes: 5 additions & 14 deletions x/merkledb/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,12 @@ func (n *node) hasValue() bool {
}

// Returns the byte representation of this node.
func (n *node) marshal() ([]byte, error) {
if n.nodeBytes != nil {
return n.nodeBytes, nil
func (n *node) marshal() []byte {
if n.nodeBytes == nil {
n.nodeBytes = codec.encodeDBNode(&n.dbNode)
}

nodeBytes, err := codec.encodeDBNode(&n.dbNode)
if err != nil {
return nil, err
}
n.nodeBytes = nodeBytes
return n.nodeBytes, nil
return n.nodeBytes
}

// clear the cached values that will need to be recalculated whenever the node changes
Expand All @@ -113,11 +108,7 @@ func (n *node) calculateID(metrics merkleMetrics) error {
Key: n.key.Serialize(),
}

bytes, err := codec.encodeHashValues(hv)
if err != nil {
return err
}

bytes := codec.encodeHashValues(hv)
metrics.HashCalculated()
n.id = hashing.ComputeHash256Array(bytes)
return nil
Expand Down
8 changes: 3 additions & 5 deletions x/merkledb/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func Test_Node_Marshal(t *testing.T) {
require.NoError(t, childNode.calculateID(&mockMetrics{}))
root.addChild(childNode)

data, err := root.marshal()
require.NoError(t, err)
data := root.marshal()
rootParsed, err := parseNode(newPath([]byte("")), data)
require.NoError(t, err)
require.Len(t, rootParsed.children, 1)
Expand Down Expand Up @@ -57,12 +56,11 @@ func Test_Node_Marshal_Errors(t *testing.T) {
require.NoError(t, childNode2.calculateID(&mockMetrics{}))
root.addChild(childNode2)

data, err := root.marshal()
require.NoError(t, err)
data := root.marshal()

for i := 1; i < len(data); i++ {
broken := data[:i]
_, err = parseNode(newPath([]byte("")), broken)
_, err := parseNode(newPath([]byte("")), broken)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
}
}

0 comments on commit cf6eedc

Please sign in to comment.