Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merkledb -- limit number of goroutines calculating node IDs #1960

Merged
merged 13 commits into from
Sep 5, 2023
35 changes: 16 additions & 19 deletions x/merkledb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/semaphore"

"go.opentelemetry.io/otel/attribute"

Expand Down Expand Up @@ -125,9 +126,8 @@ type Config struct {
// RootGenConcurrency is the number of goroutines to use when
// generating a new state root.
//
// If 0 is specified, [runtime.NumCPU] will be used. If -1 is specified,
// no limit will be used.
RootGenConcurrency int
// If 0 is specified, [runtime.NumCPU] will be used.
RootGenConcurrency uint
// The number of bytes to write to disk when intermediate nodes are evicted
// from their cache and written to disk.
EvictionBatchSize int
Expand Down Expand Up @@ -182,12 +182,9 @@ type merkleDB struct {
// Valid children of this trie.
childViews []*trieView

// rootGenConcurrency is the number of goroutines to use when
// generating a new state root.
//
// TODO: Limit concurrency across all views, instead of only within
// a single view (see `workers` in hypersdk)
rootGenConcurrency int
// calculateNodeIDsSema controls the number of goroutines inside
// [calculateNodeIDsHelper] at any given time.
calculateNodeIDsSema *semaphore.Weighted
}

// New returns a new merkle database.
Expand All @@ -205,7 +202,7 @@ func newDatabase(
config Config,
metrics merkleMetrics,
) (*merkleDB, error) {
rootGenConcurrency := runtime.NumCPU()
rootGenConcurrency := uint(runtime.NumCPU())
if config.RootGenConcurrency != 0 {
rootGenConcurrency = config.RootGenConcurrency
}
Expand All @@ -218,15 +215,15 @@ func newDatabase(
},
}
trieDB := &merkleDB{
metrics: metrics,
baseDB: db,
valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize),
intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize),
history: newTrieHistory(config.HistoryLength),
debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer),
infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer),
childViews: make([]*trieView, 0, defaultPreallocationSize),
rootGenConcurrency: rootGenConcurrency,
metrics: metrics,
baseDB: db,
valueNodeDB: newValueNodeDB(db, bufferPool, metrics, config.ValueNodeCacheSize),
intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, config.IntermediateNodeCacheSize, config.EvictionBatchSize),
history: newTrieHistory(config.HistoryLength),
debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer),
infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer),
childViews: make([]*trieView, 0, defaultPreallocationSize),
calculateNodeIDsSema: semaphore.NewWeighted(int64(rootGenConcurrency)),
}

root, err := trieDB.initializeRootIfNeeded()
Expand Down
38 changes: 17 additions & 21 deletions x/merkledb/trieview.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"

"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/ids"
Expand Down Expand Up @@ -213,6 +212,7 @@ func newHistoricalTrieView(
}

// Recalculates the node IDs for all changed nodes in the trie.
// Cancelling [ctx] doesn't cancel calculation. It's used only for tracing.
func (t *trieView) calculateNodeIDs(ctx context.Context) error {
var err error
t.calculateNodesOnce.Do(func() {
Expand All @@ -225,27 +225,25 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error {
// We wait to create the span until after checking that we need to actually
// calculateNodeIDs to make traces more useful (otherwise there may be a span
// per key modified even though IDs are not re-calculated).
ctx, span := t.db.infoTracer.Start(ctx, "MerkleDB.trieview.calculateNodeIDs")
_, span := t.db.infoTracer.Start(ctx, "MerkleDB.trieview.calculateNodeIDs")
defer span.End()

// add all the changed key/values to the nodes of the trie
for key, change := range t.changes.values {
if change.after.IsNothing() {
// Note we're setting [err] defined outside this function.
if err = t.remove(key); err != nil {
return
}
// Note we're setting [err] defined outside this function.
} else if _, err = t.insert(key, change.after); err != nil {
return
}
}

// [eg] limits the number of goroutines we start.
var eg errgroup.Group
eg.SetLimit(t.db.rootGenConcurrency)
t.calculateNodeIDsHelper(ctx, t.root, &eg)
if err = eg.Wait(); err != nil {
return
}
_ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to have the number of threads match the semaphore weight?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand the question. For a new goroutine to run calculateNodeIDsHelper, it must first acquire a weight of 1 from the semaphore, and it must release that weight of 1 when the goroutine exits.

t.calculateNodeIDsHelper(t.root)
t.db.calculateNodeIDsSema.Release(1)
t.changes.rootID = t.root.id

// ensure no ancestor changes occurred during execution
Expand All @@ -259,18 +257,14 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error {

// Calculates the ID of all descendants of [n] which need to be recalculated,
// and then calculates the ID of [n] itself.
func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errgroup.Group) {
func (t *trieView) calculateNodeIDsHelper(n *node) {
var (
// We use [wg] to wait until all descendants of [n] have been updated.
// Note we can't wait on [eg] because [eg] may have started goroutines
// that aren't calculating IDs for descendants of [n].
wg sync.WaitGroup
updatedChildren = make(chan *node, len(n.children))
)

for childIndex, child := range n.children {
childIndex, child := childIndex, child
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this was actually necessary.


childPath := n.key + path(childIndex) + child.compressedPath
childNodeChange, ok := t.changes.nodes[childPath]
if !ok {
Expand All @@ -279,22 +273,24 @@ func (t *trieView) calculateNodeIDsHelper(ctx context.Context, n *node, eg *errg
}

wg.Add(1)
updateChild := func() {
calculateChildID := func() {
defer wg.Done()

t.calculateNodeIDsHelper(ctx, childNodeChange.after, eg)
t.calculateNodeIDsHelper(childNodeChange.after)

// Note that this will never block
updatedChildren <- childNodeChange.after
}

// Try updating the child and its descendants in a goroutine.
if ok := eg.TryGo(func() error {
updateChild()
return nil
}); !ok {
if ok := t.db.calculateNodeIDsSema.TryAcquire(1); ok {
go func() {
calculateChildID()
t.db.calculateNodeIDsSema.Release(1)
}()
} else {
// We're at the goroutine limit; do the work in this goroutine.
updateChild()
calculateChildID()
}
}

Expand Down
Loading