Skip to content

Commit

Permalink
Remove sentinel node from MerkleDB proofs (#2106)
Browse files Browse the repository at this point in the history
Signed-off-by: David Boehm <[email protected]>
Co-authored-by: Stephen Buttolph <[email protected]>
Co-authored-by: Dan Laine <[email protected]>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent 094ce50 commit 86201ae
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 83 deletions.
60 changes: 37 additions & 23 deletions x/merkledb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ const (
)

var (
rootKey []byte
_ MerkleDB = (*merkleDB)(nil)
_ MerkleDB = (*merkleDB)(nil)

codec = newCodec()

Expand All @@ -54,8 +53,8 @@ var (
hadCleanShutdown = []byte{1}
didNotHaveCleanShutdown = []byte{0}

errSameRoot = errors.New("start and end root are the same")
errNoNewRoot = errors.New("there was no updated root in change list")
errSameRoot = errors.New("start and end root are the same")
errNoNewSentinel = errors.New("there was no updated sentinel node in change list")
)

type ChangeProofer interface {
Expand Down Expand Up @@ -194,8 +193,10 @@ type merkleDB struct {
debugTracer trace.Tracer
infoTracer trace.Tracer

// The root of this trie.
root *node
// The sentinel node of this trie.
// It is the node with a nil key and is the ancestor of all nodes in the trie.
// If it has a value or has multiple children, it is also the root of the trie.
sentinelNode *node

// Valid children of this trie.
childViews []*trieView
Expand Down Expand Up @@ -286,7 +287,7 @@ func newDatabase(
// Deletes every intermediate node and rebuilds them by re-adding every key/value.
// TODO: make this more efficient by only clearing out the stale portions of the trie.
func (db *merkleDB) rebuild(ctx context.Context, cacheSize int) error {
db.root = newNode(Key{})
db.sentinelNode = newNode(Key{})

// Delete intermediate nodes.
if err := database.ClearPrefix(db.baseDB, intermediateNodePrefix, rebuildIntermediateDeletionWriteSize); err != nil {
Expand Down Expand Up @@ -569,7 +570,20 @@ func (db *merkleDB) GetMerkleRoot(ctx context.Context) (ids.ID, error) {

// Assumes [db.lock] is read locked.
func (db *merkleDB) getMerkleRoot() ids.ID {
return db.root.id
if !isSentinelNodeTheRoot(db.sentinelNode) {
// if the sentinel node should be skipped, the trie's root is the nil key node's only child
for _, childEntry := range db.sentinelNode.children {
return childEntry.id
}
}
return db.sentinelNode.id
}

// isSentinelNodeTheRoot returns true if the passed in sentinel node has a value and or multiple child nodes
// When this is true, the root of the trie is the sentinel node
// When this is false, the root of the trie is the sentinel node's single child
func isSentinelNodeTheRoot(sentinel *node) bool {
return sentinel.valueDigest.HasValue() || len(sentinel.children) != 1
}

func (db *merkleDB) GetProof(ctx context.Context, key []byte) (*Proof, error) {
Expand Down Expand Up @@ -915,9 +929,9 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e
return nil
}

rootChange, ok := changes.nodes[Key{}]
sentinelChange, ok := changes.nodes[Key{}]
if !ok {
return errNoNewRoot
return errNoNewSentinel
}

currentValueNodeBatch := db.valueNodeDB.NewBatch()
Expand Down Expand Up @@ -959,7 +973,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e

// Only modify in-memory state after the commit succeeds
// so that we don't need to clean up on error.
db.root = rootChange.after
db.sentinelNode = sentinelChange.after
db.history.record(changes)
return nil
}
Expand Down Expand Up @@ -1140,33 +1154,33 @@ func (db *merkleDB) invalidateChildrenExcept(exception *trieView) {
}

func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
// not sure if the root exists or had a value or not
// not sure if the sentinel node exists or if it had a value
// check under both prefixes
var err error
db.root, err = db.intermediateNodeDB.Get(Key{})
db.sentinelNode, err = db.intermediateNodeDB.Get(Key{})
if errors.Is(err, database.ErrNotFound) {
db.root, err = db.valueNodeDB.Get(Key{})
db.sentinelNode, err = db.valueNodeDB.Get(Key{})
}
if err == nil {
// Root already exists, so calculate its id
db.root.calculateID(db.metrics)
return db.root.id, nil
// sentinel node already exists, so calculate the root ID of the trie
db.sentinelNode.calculateID(db.metrics)
return db.getMerkleRoot(), nil
}
if !errors.Is(err, database.ErrNotFound) {
return ids.Empty, err
}

// Root doesn't exist; make a new one.
db.root = newNode(Key{})
// sentinel node doesn't exist; make a new one.
db.sentinelNode = newNode(Key{})

// update its ID
db.root.calculateID(db.metrics)
db.sentinelNode.calculateID(db.metrics)

if err := db.intermediateNodeDB.Put(Key{}, db.root); err != nil {
if err := db.intermediateNodeDB.Put(Key{}, db.sentinelNode); err != nil {
return ids.Empty, err
}

return db.root.id, nil
return db.sentinelNode.id, nil
}

// Returns a view of the trie as it was when it had root [rootID] for keys within range [start, end].
Expand Down Expand Up @@ -1243,7 +1257,7 @@ func (db *merkleDB) getNode(key Key, hasValue bool) (*node, error) {
case db.closed:
return nil, database.ErrClosed
case key == Key{}:
return db.root, nil
return db.sentinelNode, nil
case hasValue:
return db.valueNodeDB.Get(key)
}
Expand Down
17 changes: 11 additions & 6 deletions x/merkledb/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func Test_History_Simple(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id

origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down Expand Up @@ -338,7 +339,8 @@ func Test_History_RepeatedRoot(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id

origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down Expand Up @@ -380,7 +382,8 @@ func Test_History_ExcessDeletes(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id

origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down Expand Up @@ -412,7 +415,8 @@ func Test_History_DontIncludeAllNodes(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id

origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down Expand Up @@ -440,7 +444,7 @@ func Test_History_Branching2Nodes(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id
origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down Expand Up @@ -468,7 +472,8 @@ func Test_History_Branching3Nodes(t *testing.T) {
origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10)
require.NoError(err)
require.NotNil(origProof)
origRootID := db.root.id

origRootID := db.getMerkleRoot()
require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize))

batch = db.NewBatch()
Expand Down
49 changes: 23 additions & 26 deletions x/merkledb/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func Test_Proof_Verify_Bad_Data(t *testing.T) {
expectedErr: nil,
},
{
name: "odd length key with value",
name: "odd length key path with value",
malform: func(proof *Proof) {
proof.Path[1].ValueOrHash = maybe.Some([]byte{1, 2})
proof.Path[0].ValueOrHash = maybe.Some([]byte{1, 2})
},
expectedErr: ErrPartialByteLengthWithValue,
},
Expand Down Expand Up @@ -150,7 +150,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) {
context.Background(),
maybe.Some([]byte{1}),
maybe.Some([]byte{5, 5}),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
))

Expand All @@ -160,7 +160,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) {
context.Background(),
maybe.Some([]byte{1}),
maybe.Some([]byte{5, 5}),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
)
require.ErrorIs(err, ErrInvalidProof)
Expand All @@ -187,9 +187,9 @@ func Test_RangeProof_Verify_Bad_Data(t *testing.T) {
expectedErr: ErrProofValueDoesntMatch,
},
{
name: "EndProof: odd length key with value",
name: "EndProof: odd length key path with value",
malform: func(proof *RangeProof) {
proof.EndProof[1].ValueOrHash = maybe.Some([]byte{1, 2})
proof.EndProof[0].ValueOrHash = maybe.Some([]byte{1, 2})
},
expectedErr: ErrPartialByteLengthWithValue,
},
Expand Down Expand Up @@ -255,6 +255,7 @@ func Test_Proof(t *testing.T) {
context.Background(),
ViewChanges{
BatchOps: []database.BatchOp{
{Key: []byte("key"), Value: []byte("value")},
{Key: []byte("key0"), Value: []byte("value0")},
{Key: []byte("key1"), Value: []byte("value1")},
{Key: []byte("key2"), Value: []byte("value2")},
Expand All @@ -273,12 +274,11 @@ func Test_Proof(t *testing.T) {

require.Len(proof.Path, 3)

require.Equal(ToKey([]byte("key")), proof.Path[0].Key)
require.Equal(maybe.Some([]byte("value")), proof.Path[0].ValueOrHash)
require.Equal(ToKey([]byte("key1")), proof.Path[2].Key)
require.Equal(maybe.Some([]byte("value1")), proof.Path[2].ValueOrHash)

require.Equal(ToKey([]byte{}), proof.Path[0].Key)
require.True(proof.Path[0].ValueOrHash.IsNothing())

expectedRootID, err := trie.GetMerkleRoot(context.Background())
require.NoError(err)
require.NoError(proof.Verify(context.Background(), expectedRootID, dbTrie.tokenSize))
Expand Down Expand Up @@ -501,9 +501,8 @@ func Test_RangeProof(t *testing.T) {
require.Equal([]byte{2}, proof.KeyValues[1].Value)
require.Equal([]byte{3}, proof.KeyValues[2].Value)

require.Nil(proof.EndProof[0].Key.Bytes())
require.Equal([]byte{0}, proof.EndProof[1].Key.Bytes())
require.Equal([]byte{3}, proof.EndProof[2].Key.Bytes())
require.Equal([]byte{0}, proof.EndProof[0].Key.Bytes())
require.Equal([]byte{3}, proof.EndProof[1].Key.Bytes())

// only a single node here since others are duplicates in endproof
require.Equal([]byte{1}, proof.StartProof[0].Key.Bytes())
Expand All @@ -512,7 +511,7 @@ func Test_RangeProof(t *testing.T) {
context.Background(),
maybe.Some([]byte{1}),
maybe.Some([]byte{3, 5}),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
))
}
Expand Down Expand Up @@ -557,15 +556,14 @@ func Test_RangeProof_NilStart(t *testing.T) {
require.Equal([]byte("value1"), proof.KeyValues[0].Value)
require.Equal([]byte("value2"), proof.KeyValues[1].Value)

require.Equal(ToKey([]byte("key2")), proof.EndProof[2].Key)
require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[1].Key)
require.Equal(ToKey([]byte("")), proof.EndProof[0].Key)
require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key)
require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[0].Key)

require.NoError(proof.Verify(
context.Background(),
maybe.Nothing[[]byte](),
maybe.Some([]byte("key35")),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
))
}
Expand All @@ -592,15 +590,14 @@ func Test_RangeProof_NilEnd(t *testing.T) {

require.Equal([]byte{1}, proof.StartProof[0].Key.Bytes())

require.Nil(proof.EndProof[0].Key.Bytes())
require.Equal([]byte{0}, proof.EndProof[1].Key.Bytes())
require.Equal([]byte{2}, proof.EndProof[2].Key.Bytes())
require.Equal([]byte{0}, proof.EndProof[0].Key.Bytes())
require.Equal([]byte{2}, proof.EndProof[1].Key.Bytes())

require.NoError(proof.Verify(
context.Background(),
maybe.Some([]byte{1}),
maybe.Nothing[[]byte](),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
))
}
Expand Down Expand Up @@ -635,15 +632,15 @@ func Test_RangeProof_EmptyValues(t *testing.T) {
require.Len(proof.StartProof, 1)
require.Equal(ToKey([]byte("key1")), proof.StartProof[0].Key)

require.Len(proof.EndProof, 3)
require.Equal(ToKey([]byte("key2")), proof.EndProof[2].Key)
require.Equal(ToKey([]byte{}), proof.EndProof[0].Key)
require.Len(proof.EndProof, 2)
require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key)
require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[0].Key)

require.NoError(proof.Verify(
context.Background(),
maybe.Some([]byte("key1")),
maybe.Some([]byte("key2")),
db.root.id,
db.getMerkleRoot(),
db.tokenSize,
))
}
Expand Down Expand Up @@ -779,7 +776,7 @@ func Test_ChangeProof_Verify_Bad_Data(t *testing.T) {
{
name: "odd length key path with value",
malform: func(proof *ChangeProof) {
proof.EndProof[1].ValueOrHash = maybe.Some([]byte{1, 2})
proof.EndProof[0].ValueOrHash = maybe.Some([]byte{1, 2})
},
expectedErr: ErrPartialByteLengthWithValue,
},
Expand Down
Loading

0 comments on commit 86201ae

Please sign in to comment.