diff --git a/blockchain/chain.go b/blockchain/chain.go index 065cef2234..b28ff88a7e 100644 --- a/blockchain/chain.go +++ b/blockchain/chain.go @@ -2465,9 +2465,14 @@ func New(ctx context.Context, config *Config) (*BlockChain, error) { best := b.BestSnapshot() tipHeight := uint32(best.Height) - b.spendPruner, err = spendpruner.NewSpendJournalPruner(b.db, - b.BatchRemoveSpendEntry, tipHeight, spendpruner.BatchPruneInterval, - spendpruner.DependencyPruneInterval) + cfg := &spendpruner.SpendJournalPrunerConfig{ + DB: b.db, + BatchRemoveSpendEntry: b.BatchRemoveSpendEntry, + BatchPruneInterval: spendpruner.BatchPruneInterval, + DependencyPruneInterval: spendpruner.DependencyPruneInterval, + BlockHeightByHash: b.BlockHeightByHash, + } + b.spendPruner, err = spendpruner.NewSpendJournalPruner(cfg, tipHeight) if err != nil { return nil, err } diff --git a/blockchain/internal/spendpruner/db.go b/blockchain/internal/spendpruner/db.go index d64910da82..9fb32e64e4 100644 --- a/blockchain/internal/spendpruner/db.go +++ b/blockchain/internal/spendpruner/db.go @@ -96,16 +96,18 @@ func dbUpdateSpendConsumerDependencies(dbTx database.Tx, blockHash chainhash.Has func dbPersistSpendHeights(dbTx database.Tx, spendHeights map[chainhash.Hash]uint32) error { heightsBucket := dbTx.Metadata().Bucket(spendJournalHeightsBucketName) - // return immediately if there are no spend heights to persist. + // Return immediately if there are no spend heights to persist. if len(spendHeights) == 0 { return nil } // Persist all spend height map entries. for blockHash, height := range spendHeights { - var b [8]byte + hashCopy := blockHash + + var b [4]byte binary.LittleEndian.PutUint32(b[:], height) - err := heightsBucket.Put(blockHash[:], b[:]) + err := heightsBucket.Put(hashCopy[:], b[:]) if err != nil { return err } @@ -132,8 +134,8 @@ func dbPruneSpendDependencies(dbTx database.Tx, keys []chainhash.Hash) error { } // Prune all spend dependency map entries. - for _, blockHash := range keys { - err := depsBucket.Delete(blockHash[:]) + for idx := range keys { + err := depsBucket.Delete(keys[idx][:]) if err != nil { return err } @@ -153,8 +155,8 @@ func dbPruneSpendHeights(dbTx database.Tx, keys []chainhash.Hash) error { } // Persist all spend height map entries. - for _, blockHash := range keys { - err := heightsBucket.Delete(blockHash[:]) + for idx := range keys { + err := heightsBucket.Delete(keys[idx][:]) if err != nil { return err } @@ -189,13 +191,11 @@ func dbFetchSpendHeights(dbTx database.Tx) (map[chainhash.Hash]uint32, error) { spendHeights := make(map[chainhash.Hash]uint32) cursor := heightsBucket.Cursor() for ok := cursor.First(); ok; ok = cursor.Next() { - hash, err := chainhash.NewHash(cursor.Key()) - if err != nil { - return nil, err - } + var hash chainhash.Hash + copy(hash[:], cursor.Key()) height := binary.LittleEndian.Uint32(cursor.Value()) - spendHeights[*hash] = height + spendHeights[hash] = height } return spendHeights, nil diff --git a/blockchain/internal/spendpruner/db_test.go b/blockchain/internal/spendpruner/db_test.go index 1b923fc1a1..6567d262d8 100644 --- a/blockchain/internal/spendpruner/db_test.go +++ b/blockchain/internal/spendpruner/db_test.go @@ -90,8 +90,8 @@ func TestDeserializeConsumerDependencies(t *testing.T) { // createdDB creates the test database. This is intended to be used for testing // purposes only. -func createDB() (database.DB, func(), error) { - dbPath := filepath.Join(os.TempDir(), "spdb") +func createDB(dir string) (database.DB, func(), error) { + dbPath := filepath.Join(dir, "spdb") err := os.MkdirAll(dbPath, 0700) if err != nil { @@ -100,25 +100,21 @@ func createDB() (database.DB, func(), error) { db, err := database.Create("ffldb", dbPath, wire.SimNet) if err != nil { - os.RemoveAll(dbPath) return nil, nil, err } err = initConsumerDependenciesBucket(db) if err != nil { - os.RemoveAll(dbPath) return nil, nil, err } err = initSpendJournalHeightsBucket(db) if err != nil { - os.RemoveAll(dbPath) return nil, nil, err } teardown := func() { db.Close() - os.RemoveAll(dbPath) } return db, teardown, nil diff --git a/blockchain/internal/spendpruner/pruner.go b/blockchain/internal/spendpruner/pruner.go index 6ca8b6e0f8..4e4f85109d 100644 --- a/blockchain/internal/spendpruner/pruner.go +++ b/blockchain/internal/spendpruner/pruner.go @@ -35,17 +35,33 @@ const ( maxDependencyDifference = 288 ) +// SpendJournalPrunerConfig is the configuration struct for the spend journal +// pruner. +type SpendJournalPrunerConfig struct { + // DB represents the spendpruner database. + DB database.DB + // BatchRemoveSpendEntry purges the spend journal entries of the + // provided batched block hashes if they are not part of the main chain. + BatchRemoveSpendEntry func(hash []chainhash.Hash) error + // BatchPruneInterval is the maximum time between processing batched prunes. + BatchPruneInterval time.Duration + // BlockHeightByHash returns the height of the block with the given hash + // in the main chain. + BlockHeightByHash func(hash *chainhash.Hash) (int64, error) + // DependencyPruneInterval is the maximum time between processing + // dependency prunes. + DependencyPruneInterval time.Duration +} + // SpendJournalPruner represents a spend journal pruner that ensures spend // journal entries needed by consumers are retained until no longer needed. type SpendJournalPruner struct { + cfg *SpendJournalPrunerConfig + // This field tracks the chain tip height based on block connections and // disconnections. currentTip uint32 // Update atomically. - // This removes the spend journal entries of the provided block hashes if - // they are not part of the main chain. - batchRemoveSpendEntry func(hash []chainhash.Hash) error - // These fields track spend consumers, their spend journal dependencies // and block heights of the spend entries. dependents map[chainhash.Hash][]string @@ -63,43 +79,31 @@ type SpendJournalPruner struct { pruneBatch []chainhash.Hash pruneBatchMtx sync.Mutex - // This is the maximum time between processing batched prunes. - batchPruneInterval time.Duration - - // This is the maximum time between processing dependency prunes. - dependencyPruneInterval time.Duration - - // This field provides access to the database. - db database.DB - // This field synchronizes channel sends and receives. quit chan struct{} } // NewSpendJournalPruner initializes a spend journal pruner. -func NewSpendJournalPruner(db database.DB, batchRemoveSpendEntry func(hash []chainhash.Hash) error, currentTip uint32, batchPruneInterval time.Duration, dependencyPruneInterval time.Duration) (*SpendJournalPruner, error) { - err := initConsumerDependenciesBucket(db) +func NewSpendJournalPruner(cfg *SpendJournalPrunerConfig, currentTip uint32) (*SpendJournalPruner, error) { + err := initConsumerDependenciesBucket(cfg.DB) if err != nil { return nil, err } - err = initSpendJournalHeightsBucket(db) + err = initSpendJournalHeightsBucket(cfg.DB) if err != nil { return nil, err } spendPruner := &SpendJournalPruner{ - db: db, - currentTip: currentTip, - batchRemoveSpendEntry: batchRemoveSpendEntry, - batchPruneInterval: batchPruneInterval, - dependencyPruneInterval: dependencyPruneInterval, - dependents: make(map[chainhash.Hash][]string), - spendHeights: make(map[chainhash.Hash]uint32), - consumers: make(map[string]SpendConsumer), - pruneBatch: make([]chainhash.Hash, 0, batchThreshold), - ch: make(chan struct{}, batchSignalBufferSize), - quit: make(chan struct{}), + cfg: cfg, + currentTip: currentTip, + dependents: make(map[chainhash.Hash][]string), + spendHeights: make(map[chainhash.Hash]uint32), + consumers: make(map[string]SpendConsumer), + pruneBatch: make([]chainhash.Hash, 0, batchThreshold), + ch: make(chan struct{}, batchSignalBufferSize), + quit: make(chan struct{}), } err = spendPruner.loadSpendConsumerDependencies() @@ -112,9 +116,33 @@ func NewSpendJournalPruner(db database.DB, batchRemoveSpendEntry func(hash []cha return nil, err } + spendPruner.generateDependencySpendHeights() + return spendPruner, nil } +// generateDepedencySpendHeights creates the associated spend heights +// for loaded spend dependencies without spend heights. +func (s *SpendJournalPruner) generateDependencySpendHeights() { + s.dependentsMtx.Lock() + defer s.dependentsMtx.Unlock() + s.spendHeightsMtx.Lock() + defer s.spendHeightsMtx.Unlock() + + for depHash := range s.dependents { + hash := depHash + if _, ok := s.spendHeights[hash]; !ok { + height, err := s.cfg.BlockHeightByHash(&hash) + if err != nil { + log.Error("no spend height found for hash %s: %v", hash, err) + height = 0 + } + + s.spendHeights[hash] = uint32(height) + } + } +} + // AddConsumer adds a spend journal consumer to the pruner. func (s *SpendJournalPruner) AddConsumer(consumer SpendConsumer) { s.consumersMtx.Lock() @@ -174,7 +202,7 @@ func (s *SpendJournalPruner) pruneSpendDependencies(dependencies []chainhash.Has s.spendHeightsMtx.Unlock() } - err := s.db.Update(func(tx database.Tx) error { + err := s.cfg.DB.Update(func(tx database.Tx) error { err := dbPruneSpendDependencies(tx, dependencies) if err != nil { return err @@ -315,7 +343,7 @@ func (s *SpendJournalPruner) addSpendConsumerDependencies(blockHash *chainhash.H // Update the persisted spend consumer deps entry for the provided block // hash as well as the spend heights map if it was updated. - err := s.db.Update(func(tx database.Tx) error { + err := s.cfg.DB.Update(func(tx database.Tx) error { err := dbUpdateSpendConsumerDependencies(tx, *blockHash, dependents) if err != nil { return err @@ -406,7 +434,7 @@ func (s *SpendJournalPruner) removeSpendConsumerDependencies(blockHash *chainhas // Remove the tracked spend journal entry for the provided // block hash. - return s.db.Update(func(tx database.Tx) error { + return s.cfg.DB.Update(func(tx database.Tx) error { err := dbUpdateSpendConsumerDependencies(tx, *blockHash, nil) if err != nil { msg := fmt.Sprintf("unable to remove persisted consumer "+ @@ -428,7 +456,7 @@ func (s *SpendJournalPruner) removeSpendConsumerDependencies(blockHash *chainhas // loadSpendConsumerDependencies loads persisted consumer spend dependencies // from the database. func (s *SpendJournalPruner) loadSpendConsumerDependencies() error { - return s.db.View(func(tx database.Tx) error { + return s.cfg.DB.View(func(tx database.Tx) error { consumerDeps, err := dbFetchSpendConsumerDependencies(tx) if err != nil { msg := fmt.Sprintf("unable to load spend consumer "+ @@ -447,7 +475,7 @@ func (s *SpendJournalPruner) loadSpendConsumerDependencies() error { // loadSpendJournalHeights loads persisted spend journal heights // from the database. func (s *SpendJournalPruner) loadSpendJournalHeights() error { - return s.db.View(func(tx database.Tx) error { + return s.cfg.DB.View(func(tx database.Tx) error { spendHeights, err := dbFetchSpendHeights(tx) if err != nil { msg := fmt.Sprintf("unable to load spend journal "+ @@ -500,7 +528,7 @@ func (s *SpendJournalPruner) handleBatchPrunes(ctx context.Context) { s.pruneBatch = s.pruneBatch[:0] s.pruneBatchMtx.Unlock() - err := s.batchRemoveSpendEntry(batch) + err := s.cfg.BatchRemoveSpendEntry(batch) if err != nil { log.Errorf("unable to batch remove spend entries: %v", err) } @@ -512,8 +540,8 @@ func (s *SpendJournalPruner) handleBatchPrunes(ctx context.Context) { // // This should be run as a goroutine. func (s *SpendJournalPruner) handleTicks(ctx context.Context) { - batchTicker := time.NewTicker(s.batchPruneInterval) - dependencyTicker := time.NewTicker(s.dependencyPruneInterval) + batchTicker := time.NewTicker(s.cfg.BatchPruneInterval) + dependencyTicker := time.NewTicker(s.cfg.DependencyPruneInterval) for { select { case <-ctx.Done(): diff --git a/blockchain/internal/spendpruner/pruner_test.go b/blockchain/internal/spendpruner/pruner_test.go index cf25709aa7..50a4ce84bd 100644 --- a/blockchain/internal/spendpruner/pruner_test.go +++ b/blockchain/internal/spendpruner/pruner_test.go @@ -80,19 +80,31 @@ func TestSpendPruner(t *testing.T) { needSpendDataErr: fmt.Errorf("unable to confirm spend data need"), } - db, teardown, err := createDB() + db, teardown, err := createDB(t.TempDir()) if err != nil { t.Fatal(err) } defer teardown() + height := int64(0) + blockHeightByHash := func(hash *chainhash.Hash) (int64, error) { + height++ + return height, nil + } + batchPruneInterval := time.Millisecond * 100 dependentPruneInterval := time.Millisecond * 100 tipHeight := uint32(1) ctx, cancel := context.WithCancel(context.Background()) - pruner, err := NewSpendJournalPruner(db, chain.BatchRemoveSpendEntry, - tipHeight, batchPruneInterval, dependentPruneInterval) + cfg := &SpendJournalPrunerConfig{ + DB: db, + BatchRemoveSpendEntry: chain.BatchRemoveSpendEntry, + BatchPruneInterval: batchPruneInterval, + DependencyPruneInterval: dependentPruneInterval, + BlockHeightByHash: blockHeightByHash, + } + pruner, err := NewSpendJournalPruner(cfg, tipHeight) if err != nil { t.Fatal(err) } @@ -207,7 +219,7 @@ func TestSpendPruner(t *testing.T) { } removeSpendConsumerDep := func(pruner *SpendJournalPruner, blockHash *chainhash.Hash, consumerID string) error { - return pruner.db.Update(func(tx database.Tx) error { + return pruner.cfg.DB.Update(func(tx database.Tx) error { return pruner.RemoveSpendConsumerDependency(tx, blockHash, consumerID) }) } @@ -402,8 +414,7 @@ func TestSpendPruner(t *testing.T) { cancel() // Load the spend pruner from the database. - pruner, err = NewSpendJournalPruner(db, chain.BatchRemoveSpendEntry, - tipHeight, batchPruneInterval, dependentPruneInterval) + pruner, err = NewSpendJournalPruner(cfg, tipHeight) if err != nil { t.Fatal(err) } @@ -427,3 +438,78 @@ func TestSpendPruner(t *testing.T) { expected, len(deps)) } } + +func TestGenerateDependencySpendHeights(t *testing.T) { + db, teardown, err := createDB(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + defer teardown() + + height := int64(0) + blockHeightByHash := func(hash *chainhash.Hash) (int64, error) { + height++ + return height, nil + } + + batchPruneInterval := time.Millisecond * 100 + dependentPruneInterval := time.Millisecond * 100 + tipHeight := uint32(1) + cfg := &SpendJournalPrunerConfig{ + DB: db, + BatchRemoveSpendEntry: func(hash []chainhash.Hash) error { + return nil + }, + BatchPruneInterval: batchPruneInterval, + DependencyPruneInterval: dependentPruneInterval, + BlockHeightByHash: blockHeightByHash, + } + pruner := &SpendJournalPruner{ + cfg: cfg, + currentTip: tipHeight, + dependents: make(map[chainhash.Hash][]string), + spendHeights: make(map[chainhash.Hash]uint32), + consumers: make(map[string]SpendConsumer), + pruneBatch: make([]chainhash.Hash, 0, batchThreshold), + ch: make(chan struct{}, batchSignalBufferSize), + quit: make(chan struct{}), + } + + hashA := chainhash.Hash{'a'} + hashB := chainhash.Hash{'b'} + hashC := chainhash.Hash{'c'} + + pruner.dependentsMtx.Lock() + pruner.dependents[hashA] = []string{} + pruner.dependents[hashB] = []string{} + pruner.dependents[hashC] = []string{} + pruner.dependentsMtx.Unlock() + + pruner.spendHeightsMtx.Lock() + pruner.spendHeights[hashA] = 10 + pruner.spendHeights[hashC] = 20 + pruner.spendHeightsMtx.Unlock() + + pruner.generateDependencySpendHeights() + + pruner.spendHeightsMtx.Lock() + spendHeightsLen := len(pruner.spendHeights) + spendHeight, ok := pruner.spendHeights[hashB] + pruner.spendHeightsMtx.Unlock() + + // Ensure the spend heights set is now 3. + if spendHeightsLen != 3 { + t.Fatalf("expected 3 spend height entries. got %d", spendHeightsLen) + } + + // Ensure the height associated with hashA is 1. + if !ok { + t.Fatalf("expected hashA to have a spend height entry") + } + + if spendHeight != 1 { + t.Fatalf("expected associated spend height for hashA to "+ + "be 1, got %d", height) + } +}