diff --git a/x/merkledb/db.go b/x/merkledb/db.go index 453800ba92a5..62fcbd9f4c0d 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -15,6 +15,7 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "golang.org/x/sync/semaphore" "go.opentelemetry.io/otel/attribute" @@ -125,9 +126,8 @@ type Config struct { // RootGenConcurrency is the number of goroutines to use when // generating a new state root. // - // If 0 is specified, [runtime.NumCPU] will be used. If -1 is specified, - // no limit will be used. - RootGenConcurrency int + // If 0 is specified, [runtime.NumCPU] will be used. + RootGenConcurrency uint // The number of bytes to write to disk when intermediate nodes are evicted // from their cache and written to disk. EvictionBatchSize int @@ -182,12 +182,9 @@ type merkleDB struct { // Valid children of this trie. childViews []*trieView - // rootGenConcurrency is the number of goroutines to use when - // generating a new state root. - // - // TODO: Limit concurrency across all views, instead of only within - // a single view (see `workers` in hypersdk) - rootGenConcurrency int + // calculateNodeIDsSema controls the number of goroutines inside + // [calculateNodeIDsHelper] at any given time. + calculateNodeIDsSema *semaphore.Weighted } // New returns a new merkle database. @@ -205,7 +202,7 @@ func newDatabase( config Config, metrics merkleMetrics, ) (*merkleDB, error) { - rootGenConcurrency := runtime.NumCPU() + rootGenConcurrency := uint(runtime.NumCPU()) if config.RootGenConcurrency != 0 { rootGenConcurrency = config.RootGenConcurrency } @@ -218,15 +215,15 @@ func newDatabase( }, } trieDB := &merkleDB{ - metrics: metrics, - baseDB: db, - valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize), - intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize), - history: newTrieHistory(config.HistoryLength), - debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer), - infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer), - childViews: make([]*trieView, 0, defaultPreallocationSize), - rootGenConcurrency: rootGenConcurrency, + metrics: metrics, + baseDB: db, + valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize), + intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize), + history: newTrieHistory(config.HistoryLength), + debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer), + infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer), + childViews: make([]*trieView, 0, defaultPreallocationSize), + calculateNodeIDsSema: semaphore.NewWeighted(int64(rootGenConcurrency)), } root, err := trieDB.initializeRootIfNeeded() diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index c3bbfcdb5f38..555b0434476f 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -15,7 +15,6 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" @@ -213,6 +212,7 @@ func newHistoricalTrieView( } // Recalculates the node IDs for all changed nodes in the trie. +// Cancelling [ctx] doesn't cancel calculation. It's used only for tracing. func (t *trieView) calculateNodeIDs(ctx context.Context) error { var err error t.calculateNodesOnce.Do(func() { @@ -225,27 +225,25 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { // We wait to create the span until after checking that we need to actually // calculateNodeIDs to make traces more useful (otherwise there may be a span // per key modified even though IDs are not re-calculated). - ctx, span := t.db.infoTracer.Start(ctx, "MerkleDB.trieview.calculateNodeIDs") + _, span := t.db.infoTracer.Start(ctx, "MerkleDB.trieview.calculateNodeIDs") defer span.End() // add all the changed key/values to the nodes of the trie for key, change := range t.changes.values { if change.after.IsNothing() { + // Note we're setting [err] defined outside this function. if err = t.remove(key); err != nil { return } + // Note we're setting [err] defined outside this function. } else if _, err = t.insert(key, change.after); err != nil { return } } - // [eg] limits the number of goroutines we start. - var eg errgroup.Group - eg.SetLimit(t.db.rootGenConcurrency) - t.calculateNodeIDsHelper(ctx, t.root, &eg) - if err = eg.Wait(); err != nil { - return - } + _ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1) + t.calculateNodeIDsHelper(t.root) + t.db.calculateNodeIDsSema.Release(1) t.changes.rootID = t.root.id // ensure no ancestor changes occurred during execution @@ -259,18 +257,14 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { // Calculates the ID of all descendants of [n] which need to be recalculated, // and then calculates the ID of [n] itself. -func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errgroup.Group) { +func (t *trieView) calculateNodeIDsHelper(n *node) { var ( // We use [wg] to wait until all descendants of [n] have been updated. - // Note we can't wait on [eg] because [eg] may have started goroutines - // that aren't calculating IDs for descendants of [n]. wg sync.WaitGroup updatedChildren = make(chan *node, len(n.children)) ) for childIndex, child := range n.children { - childIndex, child := childIndex, child - childPath := n.key + path(childIndex) + child.compressedPath childNodeChange, ok := t.changes.nodes[childPath] if !ok { @@ -279,22 +273,24 @@ func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errg } wg.Add(1) - updateChild := func() { + calculateChildID := func() { defer wg.Done() - t.calculateNodeIDsHelper(ctx, childNodeChange.after, eg) + t.calculateNodeIDsHelper(childNodeChange.after) // Note that this will never block updatedChildren <- childNodeChange.after } // Try updating the child and its descendants in a goroutine. - if ok := eg.TryGo(func() error { - updateChild() - return nil - }); !ok { + if ok := t.db.calculateNodeIDsSema.TryAcquire(1); ok { + go func() { + calculateChildID() + t.db.calculateNodeIDsSema.Release(1) + }() + } else { // We're at the goroutine limit; do the work in this goroutine. - updateChild() + calculateChildID() } }