merkledb -- limit number of goroutines calculating node IDs #1960

Merged: 13 commits, Sep 5, 2023
36 changes: 15 additions & 21 deletions x/merkledb/db.go
@@ -15,6 +15,7 @@ import (

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/semaphore"

"go.opentelemetry.io/otel/attribute"

@@ -182,12 +183,9 @@ type merkleDB struct {
// Valid children of this trie.
childViews []*trieView

// rootGenConcurrency is the number of goroutines to use when
// generating a new state root.
//
// TODO: Limit concurrency across all views, instead of only within
// a single view (see `workers` in hypersdk)
rootGenConcurrency int
// calculateNodeIDsSema controls the number of goroutines calculating
// node IDs. Shared across all views.
calculateNodeIDsSema *semaphore.Weighted
}

// New returns a new merkle database.
@@ -218,15 +216,15 @@ func newDatabase(
},
}
trieDB := &merkleDB{
metrics: metrics,
baseDB: db,
valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize),
intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize),
history: newTrieHistory(config.HistoryLength),
debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer),
infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer),
childViews: make([]*trieView, 0, defaultPreallocationSize),
rootGenConcurrency: rootGenConcurrency,
metrics: metrics,
baseDB: db,
valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize),
intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize),
history: newTrieHistory(config.HistoryLength),
debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer),
infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer),
childViews: make([]*trieView, 0, defaultPreallocationSize),
calculateNodeIDsSema: semaphore.NewWeighted(int64(rootGenConcurrency)),
}

root, err := trieDB.initializeRootIfNeeded()
@@ -1085,9 +1083,7 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
}
if err == nil {
// Root already exists, so calculate its id
if err := db.root.calculateID(db.metrics); err != nil {
return ids.Empty, err
}
db.root.calculateID(db.metrics)
return db.root.id, nil
}
if err != database.ErrNotFound {
@@ -1098,9 +1094,7 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
db.root = newNode(nil, RootPath)

// update its ID
if err := db.root.calculateID(db.metrics); err != nil {
return ids.Empty, err
}
db.root.calculateID(db.metrics)

if err := db.intermediateNodeDB.Put(RootPath, db.root); err != nil {
return ids.Empty, err
13 changes: 5 additions & 8 deletions x/merkledb/node.go
@@ -98,21 +98,18 @@ func (n *node) onNodeChanged() {
}

// Returns and caches the ID of this node.
func (n *node) calculateID(metrics merkleMetrics) error {
func (n *node) calculateID(metrics merkleMetrics) {
if n.id != ids.Empty {
return nil
return
}

hv := &hashValues{
metrics.HashCalculated()
bytes := codec.encodeHashValues(&hashValues{
Children: n.children,
Value: n.valueDigest,
Key: n.key.Serialize(),
}

bytes := codec.encodeHashValues(hv)
metrics.HashCalculated()
})
n.id = hashing.ComputeHash256Array(bytes)
return nil
}

// Set [n]'s value to [val].
6 changes: 3 additions & 3 deletions x/merkledb/node_test.go
@@ -21,7 +21,7 @@ func Test_Node_Marshal(t *testing.T) {
childNode.setValue(maybe.Some([]byte("value")))
require.NotNil(t, childNode)

require.NoError(t, childNode.calculateID(&mockMetrics{}))
childNode.calculateID(&mockMetrics{})
root.addChild(childNode)

data := root.bytes()
@@ -45,15 +45,15 @@ func Test_Node_Marshal_Errors(t *testing.T) {
childNode1.setValue(maybe.Some([]byte("value1")))
require.NotNil(t, childNode1)

require.NoError(t, childNode1.calculateID(&mockMetrics{}))
childNode1.calculateID(&mockMetrics{})
root.addChild(childNode1)

fullpath = newPath([]byte{237})
childNode2 := newNode(root, fullpath)
childNode2.setValue(maybe.Some([]byte("value2")))
require.NotNil(t, childNode2)

require.NoError(t, childNode2.calculateID(&mockMetrics{}))
childNode2.calculateID(&mockMetrics{})
root.addChild(childNode2)

data := root.bytes()
43 changes: 18 additions & 25 deletions x/merkledb/trieview.go
@@ -15,7 +15,6 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"

"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/ids"
@@ -213,6 +212,7 @@ func newHistoricalTrieView(
}

// Recalculates the node IDs for all changed nodes in the trie.
// Cancelling [ctx] doesn't cancel calculation. It's used only for tracing.
func (t *trieView) calculateNodeIDs(ctx context.Context) error {
var err error
t.calculateNodesOnce.Do(func() {
@@ -231,23 +231,19 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error {
// add all the changed key/values to the nodes of the trie
for key, change := range t.changes.values {
if change.after.IsNothing() {
// Note we're setting [err] defined outside this function.
if err = t.remove(key); err != nil {
return
}
// Note we're setting [err] defined outside this function.
} else if _, err = t.insert(key, change.after); err != nil {
return
}
}

// [eg] limits the number of goroutines we start.
var eg errgroup.Group
eg.SetLimit(t.db.rootGenConcurrency)
if err = t.calculateNodeIDsHelper(ctx, t.root, &eg); err != nil {
return
}
if err = eg.Wait(); err != nil {
return
}
_ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1)
Contributor: is this to have the # of threads match the semaphore weight?

Author: Not sure I understand the question. In order for a new goroutine to run calculateNodeIDsHelper, it must acquire 1 from the semaphore and release that 1 when it exits the goroutine.
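A minimal, self-contained sketch of the pattern described above, assuming the shared weighted semaphore bounds goroutine creation and each worker releases its slot when it finishes. runBounded and the example tasks are hypothetical names for illustration, not code from this PR.

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

// runBounded executes every task, spawning a goroutine only when the shared
// semaphore has a free slot; otherwise the task runs inline in this goroutine.
func runBounded(sema *semaphore.Weighted, tasks []func()) {
	var wg sync.WaitGroup
	for _, task := range tasks {
		task := task // copy for the closure (needed before Go 1.22)
		wg.Add(1)
		if sema.TryAcquire(1) {
			go func() {
				defer wg.Done()
				defer sema.Release(1) // give the slot back when this goroutine exits
				task()
			}()
		} else {
			// At the concurrency limit: do the work synchronously.
			task()
			wg.Done()
		}
	}
	wg.Wait()
}

func main() {
	// Weight corresponds to a rootGenConcurrency-style limit; 4 is arbitrary.
	sema := semaphore.NewWeighted(4)
	runBounded(sema, []func(){
		func() { fmt.Println("hash child 0") },
		func() { fmt.Println("hash child 1") },
	})
}

The fallback to inline execution is what keeps the calculation making progress even when every semaphore slot is already taken.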

t.calculateNodeIDsHelper(ctx, t.root)
t.db.calculateNodeIDsSema.Release(1)
t.changes.rootID = t.root.id

// ensure no ancestor changes occurred during execution
@@ -261,18 +257,17 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error {

// Calculates the ID of all descendants of [n] which need to be recalculated,
// and then calculates the ID of [n] itself.
func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errgroup.Group) error {
func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node) {
_, span := t.db.debugTracer.Start(ctx, "MerkleDB.trieview.calculateNodeIDsHelper")
defer span.End()

var (
// We use [wg] to wait until all descendants of [n] have been updated.
// Note we can't wait on [eg] because [eg] may have started goroutines
// that aren't calculating IDs for descendants of [n].
wg sync.WaitGroup
updatedChildren = make(chan *node, len(n.children))
)

for childIndex, child := range n.children {
childIndex, child := childIndex, child
Author: Don't think this was actually necessary.
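For background on the line this comment refers to (a generic Go illustration, not code from this PR): before Go 1.22, a for loop reused a single loop variable across iterations, so a closure run in a goroutine had to capture a per-iteration copy. That shadowing is only required when the closure captures the loop variable itself; the closure here appears to capture only childNodeChange, which is declared inside the loop body, so the copy is likely redundant.

package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		i := i // before Go 1.22, omitting this copy lets goroutines observe later values of i
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}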


childPath := n.key + path(childIndex) + child.compressedPath
childNodeChange, ok := t.changes.nodes[childPath]
if !ok {
@@ -281,24 +276,22 @@ func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errg
}

wg.Add(1)
updateChild := func() error {
calculateChildID := func() {
defer wg.Done()

if err := t.calculateNodeIDsHelper(ctx, childNodeChange.after, eg); err != nil {
return err
}
t.calculateNodeIDsHelper(ctx, childNodeChange.after)

// Note that this will never block
updatedChildren <- childNodeChange.after
return nil
}

// Try updating the child and its descendants in a goroutine.
if ok := eg.TryGo(updateChild); !ok {
if ok := t.db.calculateNodeIDsSema.TryAcquire(1); ok {
go calculateChildID()
t.db.calculateNodeIDsSema.Release(1)
} else {
// We're at the goroutine limit; do the work in this goroutine.
if err := updateChild(); err != nil {
return err
}
calculateChildID()
}
}

@@ -311,7 +304,7 @@ func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errg
}

// The IDs of [n]'s descendants are up to date, so we can calculate [n]'s ID.
return n.calculateID(t.db.metrics)
n.calculateID(t.db.metrics)
}

// GetProof returns a proof that [bytesPath] is in or not in trie [t].