diff --git a/hmy/downloader/adapter.go b/hmy/downloader/adapter.go new file mode 100644 index 0000000000..3b66686e96 --- /dev/null +++ b/hmy/downloader/adapter.go @@ -0,0 +1,38 @@ +package downloader + +import ( + "context" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + "github.com/harmony-one/harmony/consensus/engine" + "github.com/harmony-one/harmony/core/types" + "github.com/harmony-one/harmony/p2p/stream/common/streammanager" + syncproto "github.com/harmony-one/harmony/p2p/stream/protocols/sync" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" +) + +type syncProtocol interface { + GetCurrentBlockNumber(ctx context.Context, opts ...syncproto.Option) (uint64, sttypes.StreamID, error) + GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) + GetBlockHashes(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]common.Hash, sttypes.StreamID, error) + GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) + + RemoveStream(stID sttypes.StreamID) // If a stream delivers invalid data, remove the stream + SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription + NumStreams() int +} + +type blockChain interface { + engine.ChainReader + Engine() engine.Engine + + InsertChain(chain types.Blocks, verifyHeaders bool) (int, error) + WriteCommitSig(blockNum uint64, lastCommits []byte) error +} + +// insertHelper is the interface help to verify and insert a block. +type insertHelper interface { + verifyAndInsertBlocks(blocks types.Blocks) (int, error) + verifyAndInsertBlock(block *types.Block) error +} diff --git a/hmy/downloader/adapter_test.go b/hmy/downloader/adapter_test.go new file mode 100644 index 0000000000..890088b6c6 --- /dev/null +++ b/hmy/downloader/adapter_test.go @@ -0,0 +1,323 @@ +package downloader + +import ( + "context" + "fmt" + "math/big" + "sync" + + "github.com/harmony-one/harmony/consensus/engine" + staking "github.com/harmony-one/harmony/staking/types" + + "github.com/harmony-one/harmony/block" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + + "github.com/harmony-one/harmony/core/types" + "github.com/harmony-one/harmony/internal/params" + "github.com/harmony-one/harmony/p2p/stream/common/streammanager" + syncproto "github.com/harmony-one/harmony/p2p/stream/protocols/sync" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/harmony-one/harmony/shard" +) + +type testBlockChain struct { + curBN uint64 + insertErrHook func(bn uint64) error + lock sync.Mutex +} + +func newTestBlockChain(curBN uint64, insertErrHook func(bn uint64) error) *testBlockChain { + return &testBlockChain{ + curBN: curBN, + insertErrHook: insertErrHook, + } +} + +func (bc *testBlockChain) CurrentBlock() *types.Block { + bc.lock.Lock() + defer bc.lock.Unlock() + + return makeTestBlock(bc.curBN) +} + +func (bc *testBlockChain) CurrentHeader() *block.Header { + bc.lock.Lock() + defer bc.lock.Unlock() + + return makeTestBlock(bc.curBN).Header() +} + +func (bc *testBlockChain) currentBlockNumber() uint64 { + bc.lock.Lock() + defer bc.lock.Unlock() + + return bc.curBN +} + +func (bc *testBlockChain) InsertChain(chain types.Blocks, verifyHeaders bool) (int, error) { + bc.lock.Lock() + defer bc.lock.Unlock() + + for i, block := range chain { + if bc.insertErrHook != nil { + if err := bc.insertErrHook(block.NumberU64()); err != nil { + return i, err + } + } + if block.NumberU64() <= bc.curBN { + continue + } + if block.NumberU64() != bc.curBN+1 { + return i, fmt.Errorf("not expected block number: %v / %v", block.NumberU64(), bc.curBN+1) + } + bc.curBN++ + } + return len(chain), nil +} + +func (bc *testBlockChain) changeBlockNumber(val uint64) { + bc.lock.Lock() + defer bc.lock.Unlock() + + bc.curBN = val +} + +func (bc *testBlockChain) ShardID() uint32 { return 0 } +func (bc *testBlockChain) ReadShardState(epoch *big.Int) (*shard.State, error) { return nil, nil } +func (bc *testBlockChain) Config() *params.ChainConfig { return nil } +func (bc *testBlockChain) WriteCommitSig(blockNum uint64, lastCommits []byte) error { return nil } +func (bc *testBlockChain) GetHeader(hash common.Hash, number uint64) *block.Header { return nil } +func (bc *testBlockChain) GetHeaderByNumber(number uint64) *block.Header { return nil } +func (bc *testBlockChain) GetHeaderByHash(hash common.Hash) *block.Header { return nil } +func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { return nil } +func (bc *testBlockChain) ReadValidatorList() ([]common.Address, error) { return nil, nil } +func (bc *testBlockChain) ReadCommitSig(blockNum uint64) ([]byte, error) { return nil, nil } +func (bc *testBlockChain) ReadBlockRewardAccumulator(uint64) (*big.Int, error) { return nil, nil } +func (bc *testBlockChain) ValidatorCandidates() []common.Address { return nil } +func (bc *testBlockChain) Engine() engine.Engine { return nil } +func (bc *testBlockChain) ReadValidatorInformation(addr common.Address) (*staking.ValidatorWrapper, error) { + return nil, nil +} +func (bc *testBlockChain) ReadValidatorSnapshot(addr common.Address) (*staking.ValidatorSnapshot, error) { + return nil, nil +} +func (bc *testBlockChain) ReadValidatorSnapshotAtEpoch(epoch *big.Int, addr common.Address) (*staking.ValidatorSnapshot, error) { + return nil, nil +} +func (bc *testBlockChain) ReadValidatorStats(addr common.Address) (*staking.ValidatorStats, error) { + return nil, nil +} +func (bc *testBlockChain) SuperCommitteeForNextEpoch(beacon engine.ChainReader, header *block.Header, isVerify bool) (*shard.State, error) { + return nil, nil +} + +type testInsertHelper struct { + bc *testBlockChain +} + +func (ch *testInsertHelper) verifyAndInsertBlock(block *types.Block) error { + _, err := ch.bc.InsertChain(types.Blocks{block}, true) + return err +} +func (ch *testInsertHelper) verifyAndInsertBlocks(blocks types.Blocks) (int, error) { + return ch.bc.InsertChain(blocks, true) +} + +const ( + initStreamNum = 32 + minStreamNum = 16 +) + +type testSyncProtocol struct { + streamIDs []sttypes.StreamID + remoteChain *testBlockChain + requestErrHook func(uint64) error + + curIndex int + numStreams int + lock sync.Mutex +} + +func newTestSyncProtocol(targetBN uint64, numStreams int, requestErrHook func(uint64) error) *testSyncProtocol { + return &testSyncProtocol{ + streamIDs: makeStreamIDs(numStreams), + remoteChain: newTestBlockChain(targetBN, nil), + requestErrHook: requestErrHook, + curIndex: 0, + numStreams: numStreams, + } +} + +func (sp *testSyncProtocol) GetCurrentBlockNumber(ctx context.Context, opts ...syncproto.Option) (uint64, sttypes.StreamID, error) { + sp.lock.Lock() + defer sp.lock.Unlock() + + bn := sp.remoteChain.currentBlockNumber() + + return bn, sp.nextStreamID(), nil +} + +func (sp *testSyncProtocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) { + sp.lock.Lock() + defer sp.lock.Unlock() + + res := make([]*types.Block, 0, len(bns)) + for _, bn := range bns { + if sp.requestErrHook != nil { + if err := sp.requestErrHook(bn); err != nil { + return nil, sp.nextStreamID(), err + } + } + if bn > sp.remoteChain.currentBlockNumber() { + res = append(res, nil) + } else { + res = append(res, makeTestBlock(bn)) + } + } + return res, sp.nextStreamID(), nil +} + +func (sp *testSyncProtocol) GetBlockHashes(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]common.Hash, sttypes.StreamID, error) { + sp.lock.Lock() + defer sp.lock.Unlock() + + res := make([]common.Hash, 0, len(bns)) + for _, bn := range bns { + if sp.requestErrHook != nil { + if err := sp.requestErrHook(bn); err != nil { + return nil, sp.nextStreamID(), err + } + } + if bn > sp.remoteChain.currentBlockNumber() { + res = append(res, emptyHash) + } else { + res = append(res, makeTestBlockHash(bn)) + } + } + return res, sp.nextStreamID(), nil +} + +func (sp *testSyncProtocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) { + sp.lock.Lock() + defer sp.lock.Unlock() + + res := make([]*types.Block, 0, len(hs)) + for _, h := range hs { + bn := testHashToNumber(h) + if sp.requestErrHook != nil { + if err := sp.requestErrHook(bn); err != nil { + return nil, sp.nextStreamID(), err + } + } + if bn > sp.remoteChain.currentBlockNumber() { + res = append(res, nil) + } else { + res = append(res, makeTestBlock(bn)) + } + } + return res, sp.nextStreamID(), nil +} + +func (sp *testSyncProtocol) RemoveStream(target sttypes.StreamID) { + sp.lock.Lock() + defer sp.lock.Unlock() + + for i, stid := range sp.streamIDs { + if stid == target { + if i == len(sp.streamIDs)-1 { + sp.streamIDs = sp.streamIDs[:i] + } else { + sp.streamIDs = append(sp.streamIDs[:i], sp.streamIDs[i+1:]...) + } + // mock discovery + if len(sp.streamIDs) < minStreamNum { + sp.streamIDs = append(sp.streamIDs, makeStreamID(sp.numStreams)) + sp.numStreams++ + } + } + } +} + +func (sp *testSyncProtocol) NumStreams() int { + sp.lock.Lock() + defer sp.lock.Unlock() + + return len(sp.streamIDs) +} + +func (sp *testSyncProtocol) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription { + var evtFeed event.Feed + go func() { + sp.lock.Lock() + num := len(sp.streamIDs) + sp.lock.Unlock() + for i := 0; i != num; i++ { + evtFeed.Send(streammanager.EvtStreamAdded{Stream: nil}) + } + }() + return evtFeed.Subscribe(ch) +} + +// TODO: add with whitelist stuff +func (sp *testSyncProtocol) nextStreamID() sttypes.StreamID { + if sp.curIndex >= len(sp.streamIDs) { + sp.curIndex = 0 + } + index := sp.curIndex + sp.curIndex++ + if sp.curIndex >= len(sp.streamIDs) { + sp.curIndex = 0 + } + return sp.streamIDs[index] +} + +func (sp *testSyncProtocol) changeBlockNumber(val uint64) { + sp.remoteChain.changeBlockNumber(val) +} + +func makeStreamIDs(size int) []sttypes.StreamID { + res := make([]sttypes.StreamID, 0, size) + for i := 0; i != size; i++ { + res = append(res, makeStreamID(i)) + } + return res +} + +func makeStreamID(index int) sttypes.StreamID { + return sttypes.StreamID(fmt.Sprintf("test stream %v", index)) +} + +var ( + hashNumberMap = map[common.Hash]uint64{} + computed uint64 + hashNumberLock sync.Mutex +) + +func testHashToNumber(h common.Hash) uint64 { + hashNumberLock.Lock() + defer hashNumberLock.Unlock() + + if h == emptyHash { + panic("not allowed") + } + if bn, ok := hashNumberMap[h]; ok { + return bn + } + for ; ; computed++ { + ch := makeTestBlockHash(computed) + hashNumberMap[ch] = computed + if ch == h { + return computed + } + } +} + +func testNumberToHashes(nums []uint64) []common.Hash { + hashes := make([]common.Hash, 0, len(nums)) + for _, num := range nums { + hashes = append(hashes, makeTestBlockHash(num)) + } + return hashes +} diff --git a/hmy/downloader/beaconhelper.go b/hmy/downloader/beaconhelper.go new file mode 100644 index 0000000000..4b0e643e68 --- /dev/null +++ b/hmy/downloader/beaconhelper.go @@ -0,0 +1,152 @@ +package downloader + +import ( + "time" + + "github.com/harmony-one/harmony/core/types" + "github.com/harmony-one/harmony/internal/utils" + "github.com/rs/zerolog" +) + +// lastMileCache keeps the last 50 number blocks in memory cache +const lastMileCap = 50 + +type ( + // beaconHelper is the helper for the beacon downloader. The beaconHelper is only started + // when node is running on side chain, listening to beacon client pub-sub message and + // insert the latest blocks to the beacon chain. + beaconHelper struct { + bc blockChain + ih insertHelper + blockC <-chan *types.Block + // TODO: refactor this hook to consensus module. We'd better put it in + // consensus module under a subscription. + insertHook func() + + lastMileCache *blocksByNumber + insertC chan insertTask + closeC chan struct{} + logger zerolog.Logger + } + + insertTask struct { + doneC chan struct{} + } +) + +func newBeaconHelper(bc blockChain, ih insertHelper, blockC <-chan *types.Block, insertHook func()) *beaconHelper { + return &beaconHelper{ + bc: bc, + ih: ih, + blockC: blockC, + insertHook: insertHook, + lastMileCache: newBlocksByNumber(lastMileCap), + insertC: make(chan insertTask, 1), + closeC: make(chan struct{}), + logger: utils.Logger().With(). + Str("module", "downloader"). + Str("sub-module", "beacon helper"). + Logger(), + } +} + +func (bh *beaconHelper) start() { + go bh.loop() +} + +func (bh *beaconHelper) close() { + close(bh.closeC) +} + +func (bh *beaconHelper) loop() { + for { + select { + case <-time.Tick(10 * time.Second): + bh.insertAsync() + + case b, ok := <-bh.blockC: + if !ok { + return // blockC closed. Node exited + } + if b == nil { + continue + } + bh.lastMileCache.push(b) + bh.insertAsync() + + case it := <-bh.insertC: + inserted, bn, err := bh.insertLastMileBlocks() + if err != nil { + bh.logger.Warn().Err(err).Msg("insert last mile blocks error") + continue + } + bh.logger.Info().Int("inserted", inserted). + Uint64("end height", bn). + Uint32("shard", bh.bc.ShardID()). + Msg("insert last mile blocks") + + close(it.doneC) + + case <-bh.closeC: + return + } + } +} + +// insertSync triggers the insert last mile without blocking +func (bh *beaconHelper) insertAsync() { + select { + case bh.insertC <- insertTask{ + doneC: make(chan struct{}), + }: + default: + } +} + +// insertSync triggers the insert last mile while blocking +func (bh *beaconHelper) insertSync() { + task := insertTask{ + doneC: make(chan struct{}), + } + bh.insertC <- task + <-task.doneC +} + +func (bh *beaconHelper) insertLastMileBlocks() (inserted int, bn uint64, err error) { + bn = bh.bc.CurrentBlock().NumberU64() + 1 + for { + b := bh.getNextBlock(bn) + if b == nil { + bn-- + return + } + if err = bh.ih.verifyAndInsertBlock(b); err != nil { + bn-- + return + } + bh.logger.Info().Uint64("number", b.NumberU64()).Msg("Inserted block from beacon pub-sub") + if bh.insertHook != nil { + bh.insertHook() + } + inserted++ + bn++ + } +} + +func (bh *beaconHelper) getNextBlock(expBN uint64) *types.Block { + for bh.lastMileCache.len() > 0 { + b := bh.lastMileCache.pop() + if b == nil { + return nil + } + if b.NumberU64() < expBN { + continue + } + if b.NumberU64() > expBN { + bh.lastMileCache.push(b) + return nil + } + return b + } + return nil +} diff --git a/hmy/downloader/const.go b/hmy/downloader/const.go new file mode 100644 index 0000000000..c7246c64c6 --- /dev/null +++ b/hmy/downloader/const.go @@ -0,0 +1,68 @@ +package downloader + +import ( + "github.com/harmony-one/harmony/core/types" + nodeconfig "github.com/harmony-one/harmony/internal/configs/node" +) + +const ( + numBlocksByNumPerRequest int = 10 // number of blocks for each request + blocksPerInsert int = 50 // number of blocks for each insert batch + + numBlockHashesPerRequest int = 20 // number of get block hashes for short range sync + numBlocksByHashesUpperCap int = 10 // number of get blocks by hashes upper cap + numBlocksByHashesLowerCap int = 3 // number of get blocks by hashes lower cap + + lastMileThres int = 10 + + // soft cap of size in resultQueue. When the queue size is larger than this limit, + // no more request will be assigned to workers to wait for InsertChain to finish. + softQueueCap int = 100 + + defaultConcurrency = 16 +) + +type ( + // Config is the downloader config + Config struct { + // parameters + Network nodeconfig.NetworkType + Concurrency int // Number of concurrent sync requests + MinStreams int // Minimum number of streams to do sync + InitStreams int // Number of streams requirement for initial bootstrap + + // stream manager config + SmSoftLowCap int + SmHardLowCap int + SmHiCap int + SmDiscBatch int + + // config for beacon config + BHConfig *BeaconHelperConfig + } + + // BeaconHelperConfig is the extra config used for beaconHelper which uses + // pub-sub block message to do sync. + BeaconHelperConfig struct { + BlockC <-chan *types.Block + InsertHook func() + } +) + +func (c *Config) fixValues() { + if c.Concurrency == 0 { + c.Concurrency = defaultConcurrency + } + if c.Concurrency > c.MinStreams { + c.MinStreams = c.Concurrency + } + if c.MinStreams > c.InitStreams { + c.InitStreams = c.MinStreams + } + if c.MinStreams > c.SmSoftLowCap { + c.SmSoftLowCap = c.MinStreams + } + if c.MinStreams > c.SmHardLowCap { + c.SmHardLowCap = c.MinStreams + } +} diff --git a/hmy/downloader/downloader.go b/hmy/downloader/downloader.go new file mode 100644 index 0000000000..905b91f77e --- /dev/null +++ b/hmy/downloader/downloader.go @@ -0,0 +1,250 @@ +package downloader + +import ( + "context" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/harmony-one/harmony/core" + nodeconfig "github.com/harmony-one/harmony/internal/configs/node" + "github.com/harmony-one/harmony/internal/utils" + "github.com/harmony-one/harmony/p2p" + "github.com/harmony-one/harmony/p2p/stream/common/streammanager" + "github.com/harmony-one/harmony/p2p/stream/protocols/sync" + "github.com/rs/zerolog" +) + +type ( + // Downloader is responsible for sync task of one shard + Downloader struct { + bc blockChain + ih insertHelper + syncProtocol syncProtocol + bh *beaconHelper + + downloadC chan struct{} + closeC chan struct{} + ctx context.Context + cancel func() + + evtDownloadFinished event.Feed // channel for each download task finished + evtDownloadStarted event.Feed // channel for each download has started + + status status + config Config + logger zerolog.Logger + } +) + +// NewDownloader creates a new downloader +func NewDownloader(host p2p.Host, bc *core.BlockChain, config Config) *Downloader { + config.fixValues() + + ih := newInsertHelper(bc) + + sp := sync.NewProtocol(sync.Config{ + Chain: bc, + Host: host.GetP2PHost(), + Discovery: host.GetDiscovery(), + ShardID: nodeconfig.ShardID(bc.ShardID()), + Network: config.Network, + + SmSoftLowCap: config.SmSoftLowCap, + SmHardLowCap: config.SmHardLowCap, + SmHiCap: config.SmHiCap, + DiscBatch: config.SmDiscBatch, + }) + host.AddStreamProtocol(sp) + + var bh *beaconHelper + if config.BHConfig != nil && bc.ShardID() == 0 { + bh = newBeaconHelper(bc, ih, config.BHConfig.BlockC, config.BHConfig.InsertHook) + } + + ctx, cancel := context.WithCancel(context.Background()) + + return &Downloader{ + bc: bc, + ih: ih, + syncProtocol: sp, + bh: bh, + + downloadC: make(chan struct{}), + closeC: make(chan struct{}), + ctx: ctx, + cancel: cancel, + + status: newStatus(), + config: config, + logger: utils.Logger().With().Str("module", "downloader").Logger(), + } +} + +// Start start the downloader +func (d *Downloader) Start() { + go d.run() + + if d.bh != nil { + d.bh.start() + } +} + +// Close close the downloader +func (d *Downloader) Close() { + close(d.closeC) + d.cancel() + + if d.bh != nil { + d.bh.close() + } +} + +// DownloadAsync triggers the download async. If there is already a download task that is +// in progress, return ErrDownloadInProgress. +func (d *Downloader) DownloadAsync() { + select { + case d.downloadC <- struct{}{}: + case <-time.After(100 * time.Millisecond): + } +} + +// NumPeers returns the number of peers connected of a specific shard. +func (d *Downloader) NumPeers() int { + return d.syncProtocol.NumStreams() +} + +// IsSyncing return the current sync status +func (d *Downloader) SyncStatus() (bool, uint64) { + syncing, target := d.status.get() + if !syncing { + target = d.bc.CurrentBlock().NumberU64() + } + return syncing, target +} + +// SubscribeDownloadStarted subscribe download started +func (d *Downloader) SubscribeDownloadStarted(ch chan struct{}) event.Subscription { + return d.evtDownloadStarted.Subscribe(ch) +} + +// SubscribeDownloadFinishedEvent subscribe the download finished +func (d *Downloader) SubscribeDownloadFinished(ch chan struct{}) event.Subscription { + return d.evtDownloadFinished.Subscribe(ch) +} + +func (d *Downloader) run() { + d.waitForBootFinish() + d.loop() +} + +// waitForBootFinish wait for stream manager to finish the initial discovery and have +// enough peers to start downloader +func (d *Downloader) waitForBootFinish() { + evtCh := make(chan streammanager.EvtStreamAdded, 1) + sub := d.syncProtocol.SubscribeAddStreamEvent(evtCh) + defer sub.Unsubscribe() + + checkCh := make(chan struct{}, 1) + trigger := func() { + select { + case checkCh <- struct{}{}: + default: + } + } + trigger() + + t := time.NewTicker(10 * time.Second) + + for { + d.logger.Info().Msg("waiting for initial bootstrap discovery") + select { + case <-t.C: + trigger() + + case <-evtCh: + trigger() + + case <-checkCh: + if d.syncProtocol.NumStreams() >= d.config.InitStreams { + return + } + case <-d.closeC: + return + } + } +} + +func (d *Downloader) loop() { + ticker := time.NewTicker(10 * time.Second) + initSync := true + trigger := func() { + select { + case d.downloadC <- struct{}{}: + case <-time.After(100 * time.Millisecond): + } + } + go trigger() + + for { + select { + case <-ticker.C: + go trigger() + + case <-d.downloadC: + addedBN, err := d.doDownload(initSync) + if err != nil { + // If error happens, sleep 5 seconds and retry + d.logger.Warn().Err(err).Bool("bootstrap", initSync).Msg("failed to download") + go func() { + time.Sleep(5 * time.Second) + trigger() + }() + continue + } + d.logger.Info().Int("block added", addedBN). + Uint64("current height", d.bc.CurrentBlock().NumberU64()). + Bool("initSync", initSync). + Uint32("shard", d.bc.ShardID()). + Msg("sync finished") + + if addedBN != 0 { + // If block number has been changed, trigger another sync + // and try to add last mile from pub-sub (blocking) + go trigger() + if d.bh != nil { + d.bh.insertSync() + } + } + initSync = false + + case <-d.closeC: + return + } + } +} + +func (d *Downloader) doDownload(initSync bool) (n int, err error) { + if initSync { + d.logger.Info().Uint64("current number", d.bc.CurrentBlock().NumberU64()). + Uint32("shard ID", d.bc.ShardID()).Msg("start long range sync") + n, err = d.doLongRangeSync() + } else { + d.logger.Info().Uint64("current number", d.bc.CurrentBlock().NumberU64()). + Uint32("shard ID", d.bc.ShardID()).Msg("start short range sync") + n, err = d.doShortRangeSync() + } + if err != nil { + return + } + return +} + +func (d *Downloader) startSyncing() { + d.status.startSyncing() + d.evtDownloadStarted.Send(struct{}{}) +} + +func (d *Downloader) finishSyncing() { + d.status.finishSyncing() + d.evtDownloadFinished.Send(struct{}{}) +} diff --git a/hmy/downloader/downloader_test.go b/hmy/downloader/downloader_test.go new file mode 100644 index 0000000000..123a9a7874 --- /dev/null +++ b/hmy/downloader/downloader_test.go @@ -0,0 +1,97 @@ +package downloader + +import ( + "context" + "fmt" + "testing" + "time" +) + +func TestDownloader_Integration(t *testing.T) { + sp := newTestSyncProtocol(1000, 48, nil) + bc := newTestBlockChain(0, nil) + ctx, cancel := context.WithCancel(context.Background()) + c := Config{} + c.fixValues() // use default config values + + d := &Downloader{ + bc: bc, + ih: &testInsertHelper{bc}, + syncProtocol: sp, + downloadC: make(chan struct{}), + closeC: make(chan struct{}), + ctx: ctx, + cancel: cancel, + config: c, + } + + // subscribe download event + finishedCh := make(chan struct{}, 1) + finishedSub := d.SubscribeDownloadFinished(finishedCh) + startedCh := make(chan struct{}, 1) + startedSub := d.SubscribeDownloadStarted(startedCh) + defer finishedSub.Unsubscribe() + defer startedSub.Unsubscribe() + + // Start the downloader + d.Start() + defer d.Close() + + // During bootstrap, trigger two download task: one long range, one short range. + // The second one will not trigger start / finish events. + if err := checkReceiveChanMulTimes(startedCh, 1, 10*time.Second); err != nil { + t.Fatal(err) + } + if err := checkReceiveChanMulTimes(finishedCh, 1, 10*time.Second); err != nil { + t.Fatal(err) + } + if curBN := d.bc.CurrentBlock().NumberU64(); curBN != 1000 { + t.Fatal("blockchain not synced to the latest") + } + + // Increase the remote block number, and trigger one download task manually + sp.changeBlockNumber(1010) + d.DownloadAsync() + // We shall do short range test twice + if err := checkReceiveChanMulTimes(startedCh, 1, 10*time.Second); err != nil { + t.Fatal(err) + } + if err := checkReceiveChanMulTimes(finishedCh, 1, 10*time.Second); err != nil { + t.Fatal(err) + } + if curBN := d.bc.CurrentBlock().NumberU64(); curBN != 1010 { + t.Fatal("blockchain not synced to the latest") + } + + // Remote block number unchanged, and trigger one download task manually + d.DownloadAsync() + if err := checkReceiveChanMulTimes(startedCh, 0, 10*time.Second); err != nil { + t.Fatal(err) + } + if err := checkReceiveChanMulTimes(finishedCh, 0, 10*time.Second); err != nil { + t.Fatal(err) + } + + // At last, check number of streams, should be exactly the same as the initial number + if sp.numStreams != 48 { + t.Errorf("unexpected number of streams at the end: %v / %v", sp.numStreams, 48) + } +} + +func checkReceiveChanMulTimes(ch chan struct{}, times int, timeout time.Duration) error { + t := time.Tick(timeout) + + for i := 0; i != times; i++ { + select { + case <-ch: + case <-t: + return fmt.Errorf("timed out %v", timeout) + } + } + select { + case <-ch: + return fmt.Errorf("received an extra event") + case <-time.After(100 * time.Millisecond): + } + return nil +} diff --git a/hmy/downloader/downloaders.go b/hmy/downloader/downloaders.go new file mode 100644 index 0000000000..230343db88 --- /dev/null +++ b/hmy/downloader/downloaders.go @@ -0,0 +1,74 @@ +package downloader + +import ( + "github.com/harmony-one/harmony/core" + "github.com/harmony-one/harmony/p2p" +) + +// Downloaders is the set of downloaders +type Downloaders struct { + ds map[uint32]*Downloader +} + +// NewDownloaders creates Downloaders for sync of multiple blockchains +func NewDownloaders(host p2p.Host, bcs []*core.BlockChain, config Config) *Downloaders { + ds := make(map[uint32]*Downloader) + + for _, bc := range bcs { + if bc == nil { + continue + } + if _, ok := ds[bc.ShardID()]; ok { + continue + } + ds[bc.ShardID()] = NewDownloader(host, bc, config) + } + return &Downloaders{ds} +} + +// Start start the downloaders +func (ds *Downloaders) Start() { + for _, d := range ds.ds { + d.Start() + } +} + +// Close close the downloaders +func (ds *Downloaders) Close() { + for _, d := range ds.ds { + d.Close() + } +} + +// DownloadAsync triggers a download +func (ds *Downloaders) DownloadAsync(shardID uint32) { + d, ok := ds.ds[shardID] + if !ok && d != nil { + d.DownloadAsync() + } +} + +// GetShardDownloader get the downloader with the given shard ID +func (ds *Downloaders) GetShardDownloader(shardID uint32) *Downloader { + return ds.ds[shardID] +} + +// NumPeers returns the connected peers for each shard +func (ds *Downloaders) NumPeers() map[uint32]int { + res := make(map[uint32]int) + + for sid, d := range ds.ds { + res[sid] = d.NumPeers() + } + return res +} + +// SyncStatus returns whether the given shard is doing syncing task and the target block +// number. +func (ds *Downloaders) SyncStatus(shardID uint32) (bool, uint64) { + d, ok := ds.ds[shardID] + if !ok { + return false, 0 + } + return d.SyncStatus() +} diff --git a/hmy/downloader/inserthelper.go b/hmy/downloader/inserthelper.go new file mode 100644 index 0000000000..9c5134d9b1 --- /dev/null +++ b/hmy/downloader/inserthelper.go @@ -0,0 +1,220 @@ +package downloader + +import ( + "fmt" + "hash" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum/common" + lru "github.com/hashicorp/golang-lru" + "github.com/pkg/errors" + "golang.org/x/crypto/sha3" + + bls_core "github.com/harmony-one/bls/ffi/go/bls" + "github.com/harmony-one/harmony/consensus/quorum" + "github.com/harmony-one/harmony/consensus/signature" + "github.com/harmony-one/harmony/core/types" + bls_cosi "github.com/harmony-one/harmony/crypto/bls" + "github.com/harmony-one/harmony/internal/chain" + "github.com/harmony-one/harmony/multibls" + "github.com/harmony-one/harmony/shard" +) + +// sigVerifyError is the error type of failing verify the signature of the current block. +// Since this is a sanity field and is not included the block hash, it needs extra verification. +// The error types is used to differentiate the error of signature verification VS insert error. +type sigVerifyError struct { + err error +} + +func (err *sigVerifyError) Error() string { + return fmt.Sprintf("failed verify signature: %v", err.err.Error()) +} + +// insertHelperImpl helps to verify and insert blocks, along with some caching mechanism. +type insertHelperImpl struct { + bc blockChain + + deciderCache *lru.Cache // Epoch -> quorum.Decider + shardStateCache *lru.Cache // Epoch -> *shard.State + verifiedSigCache *lru.Cache // verifiedSigKey -> struct{}{} +} + +func newInsertHelper(bc blockChain) insertHelper { + deciderCache, _ := lru.New(5) + shardStateCache, _ := lru.New(5) + sigCache, _ := lru.New(20) + return &insertHelperImpl{ + bc: bc, + deciderCache: deciderCache, + shardStateCache: shardStateCache, + verifiedSigCache: sigCache, + } +} + +func (ch *insertHelperImpl) verifyAndInsertBlocks(blocks types.Blocks) (int, error) { + for i, block := range blocks { + if err := ch.verifyAndInsertBlock(block); err != nil { + return i, err + } + } + return len(blocks), nil +} + +func (ch *insertHelperImpl) verifyAndInsertBlock(block *types.Block) error { + // verify the commit sig of current block + if err := ch.verifyBlockSignature(block); err != nil { + return &sigVerifyError{err} + } + ch.markBlockSigVerified(block, block.GetCurrentCommitSig()) + + // verify header. Skip verify the previous seal if we have already verified + verifySeal := !ch.isBlockLastSigVerified(block) + if err := ch.bc.Engine().VerifyHeader(ch.bc, block.Header(), verifySeal); err != nil { + return err + } + // Insert chain. + if _, err := ch.bc.InsertChain(types.Blocks{block}, false); err != nil { + return err + } + // Write commit sig data + return ch.bc.WriteCommitSig(block.NumberU64(), block.GetCurrentCommitSig()) +} + +func (ch *insertHelperImpl) verifyBlockSignature(block *types.Block) error { + // TODO: This is the duplicate logic to the implementation of verifySeal and consensus. + // Better refactor to the blockchain or engine structure + decider, err := ch.readDeciderByEpoch(block.Epoch()) + if err != nil { + return err + } + sig, mask, err := decodeCommitSig(block.GetCurrentCommitSig(), decider.Participants()) + if err != nil { + return err + } + if !decider.IsQuorumAchievedByMask(mask) { + return errors.New("quorum not achieved") + } + + commitSigBytes := signature.ConstructCommitPayload(ch.bc, block.Epoch(), block.Hash(), + block.NumberU64(), block.Header().ViewID().Uint64()) + if !sig.VerifyHash(mask.AggregatePublic, commitSigBytes) { + return errors.New("aggregate signature failed verification") + } + return nil +} + +func (ch *insertHelperImpl) writeBlockSignature(block *types.Block) error { + return ch.bc.WriteCommitSig(block.NumberU64(), block.GetCurrentCommitSig()) +} + +func (ch *insertHelperImpl) getDeciderByEpoch(epoch *big.Int) (quorum.Decider, error) { + epochUint := epoch.Uint64() + if decider, ok := ch.deciderCache.Get(epochUint); ok && decider != nil { + return decider.(quorum.Decider), nil + } + decider, err := ch.getDeciderByEpoch(epoch) + if err != nil { + return nil, errors.Wrapf(err, "unable to read quorum of epoch %v", epoch.Uint64()) + } + ch.deciderCache.Add(epochUint, decider) + return decider, nil +} + +func (ch *insertHelperImpl) readDeciderByEpoch(epoch *big.Int) (quorum.Decider, error) { + isStaking := ch.bc.Config().IsStaking(epoch) + decider := ch.getNewDecider(isStaking) + ss, err := ch.getShardState(epoch) + if err != nil { + return nil, err + } + subComm, err := ss.FindCommitteeByID(ch.shardID()) + if err != nil { + return nil, err + } + pubKeys, err := subComm.BLSPublicKeys() + if err != nil { + return nil, err + } + decider.UpdateParticipants(pubKeys) + if _, err := decider.SetVoters(subComm, epoch); err != nil { + return nil, err + } + return decider, nil +} + +func (ch *insertHelperImpl) getNewDecider(isStaking bool) quorum.Decider { + if isStaking { + return quorum.NewDecider(quorum.SuperMajorityVote, ch.bc.ShardID()) + } else { + return quorum.NewDecider(quorum.SuperMajorityStake, ch.bc.ShardID()) + } +} + +func (ch *insertHelperImpl) getShardState(epoch *big.Int) (*shard.State, error) { + if ss, ok := ch.shardStateCache.Get(epoch.Uint64()); ok && ss != nil { + return ss.(*shard.State), nil + } + ss, err := ch.bc.ReadShardState(epoch) + if err != nil { + return nil, err + } + ch.shardStateCache.Add(epoch.Uint64(), ss) + return ss, nil +} + +func (ch *insertHelperImpl) markBlockSigVerified(block *types.Block, sigAndBitmap []byte) { + key := newVerifiedSigKey(block.Hash(), sigAndBitmap) + ch.verifiedSigCache.Add(key, struct{}{}) +} + +func (ch *insertHelperImpl) isBlockLastSigVerified(block *types.Block) bool { + lastSig := block.Header().LastCommitSignature() + lastBM := block.Header().LastCommitBitmap() + lastSigAndBM := append(lastSig[:], lastBM...) + + key := newVerifiedSigKey(block.Hash(), lastSigAndBM) + _, ok := ch.verifiedSigCache.Get(key) + return ok +} + +func (ch *insertHelperImpl) shardID() uint32 { + return ch.bc.ShardID() +} + +func decodeCommitSig(commitBytes []byte, publicKeys multibls.PublicKeys) (*bls_core.Sign, *bls_cosi.Mask, error) { + if len(commitBytes) < bls_cosi.BLSSignatureSizeInBytes { + return nil, nil, fmt.Errorf("unexpected signature bytes size: %v / %v", len(commitBytes), + bls_cosi.BLSSignatureSizeInBytes) + } + return chain.ReadSignatureBitmapByPublicKeys(commitBytes, publicKeys) +} + +type verifiedSigKey struct { + blockHash common.Hash + sbHash common.Hash // hash of block signature + bitmap +} + +var hasherPool = sync.Pool{ + New: func() interface{} { + return sha3.New256() + }, +} + +func newVerifiedSigKey(blockHash common.Hash, sigAndBitmap []byte) verifiedSigKey { + hasher := hasherPool.Get().(hash.Hash) + defer func() { + hasher.Reset() + hasherPool.Put(hasher) + }() + + var sbHash common.Hash + hasher.Write(sigAndBitmap) + hasher.Sum(sbHash[0:]) + + return verifiedSigKey{ + blockHash: blockHash, + sbHash: sbHash, + } +} diff --git a/hmy/downloader/inserthelper_test.go b/hmy/downloader/inserthelper_test.go new file mode 100644 index 0000000000..00acaffe49 --- /dev/null +++ b/hmy/downloader/inserthelper_test.go @@ -0,0 +1,19 @@ +package downloader + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func BenchmarkNewVerifiedSigKey(b *testing.B) { + var bh common.Hash + commitSig := make([]byte, 100) + for i := 0; i != len(commitSig); i++ { + commitSig[i] = 0xf + } + + for i := 0; i != b.N; i++ { + newVerifiedSigKey(bh, commitSig) + } +} diff --git a/hmy/downloader/longrange.go b/hmy/downloader/longrange.go new file mode 100644 index 0000000000..a60cc0e7b8 --- /dev/null +++ b/hmy/downloader/longrange.go @@ -0,0 +1,513 @@ +package downloader + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/harmony-one/harmony/core/types" + syncproto "github.com/harmony-one/harmony/p2p/stream/protocols/sync" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// doLongRangeSync does the long range sync. +// One LongRangeSync consists of several iterations. +// For each iteration, estimate the current block number, then fetch block & insert to blockchain +func (d *Downloader) doLongRangeSync() (int, error) { + var totalInserted int + + d.startSyncing() + defer d.finishSyncing() + + for { + ctx, cancel := context.WithCancel(d.ctx) + + iter := &lrSyncIter{ + bc: d.bc, + p: d.syncProtocol, + ih: d.ih, + d: d, + ctx: ctx, + config: d.config, + logger: d.logger.With().Str("mode", "long range").Logger(), + } + if err := iter.doLongRangeSync(); err != nil { + cancel() + return totalInserted + iter.inserted, err + } + cancel() + + totalInserted += iter.inserted + + if iter.inserted < lastMileThres { + return totalInserted, nil + } + } +} + +// lrSyncIter run a single iteration of a full long range sync. +// First get a rough estimate of the current block height, and then sync to this +// block number +type lrSyncIter struct { + bc blockChain + p syncProtocol + ih insertHelper + d *Downloader + + gbm *getBlocksManager // initialized when finished get block number + inserted int + + config Config + ctx context.Context + logger zerolog.Logger +} + +func (lsi *lrSyncIter) doLongRangeSync() error { + if err := lsi.checkPrerequisites(); err != nil { + return err + } + bn, err := lsi.estimateCurrentNumber() + if err != nil { + return err + } + lsi.logger.Info().Uint64("target number", bn).Msg("estimated remote current number") + lsi.d.status.setTargetBN(bn) + + return lsi.fetchAndInsertBlocks(bn) +} + +func (lsi *lrSyncIter) checkPrerequisites() error { + return lsi.checkHaveEnoughStreams() +} + +// estimateCurrentNumber roughly estimate the current block number. +// The block number does not need to be exact, but just a temporary target of the iteration +func (lsi *lrSyncIter) estimateCurrentNumber() (uint64, error) { + var ( + cnResults = make(map[sttypes.StreamID]uint64) + lock sync.Mutex + wg sync.WaitGroup + ) + wg.Add(lsi.config.Concurrency) + for i := 0; i != lsi.config.Concurrency; i++ { + go func() { + defer wg.Done() + bn, stid, err := lsi.doGetCurrentNumberRequest() + if err != nil { + lsi.logger.Err(err).Str("streamID", string(stid)). + Msg("getCurrentNumber request failed. Removing stream") + if !errors.Is(err, context.Canceled) { + lsi.p.RemoveStream(stid) + } + return + } + lock.Lock() + cnResults[stid] = bn + lock.Unlock() + }() + } + wg.Wait() + + if len(cnResults) == 0 { + select { + case <-lsi.ctx.Done(): + return 0, lsi.ctx.Err() + default: + } + return 0, errors.New("zero block number response from remote nodes") + } + bn := computeBNMaxVote(cnResults) + return bn, nil +} + +func (lsi *lrSyncIter) doGetCurrentNumberRequest() (uint64, sttypes.StreamID, error) { + ctx, cancel := context.WithTimeout(lsi.ctx, 10*time.Second) + defer cancel() + + bn, stid, err := lsi.p.GetCurrentBlockNumber(ctx, syncproto.WithHighPriority()) + if err != nil { + return 0, stid, err + } + return bn, stid, nil +} + +// fetchAndInsertBlocks use the pipeline pattern to boost the performance of inserting blocks. +// TODO: For resharding, use the pipeline to do fast sync (epoch loop, header loop, body loop) +func (lsi *lrSyncIter) fetchAndInsertBlocks(targetBN uint64) error { + gbm := newGetBlocksManager(lsi.bc, targetBN, lsi.logger) + lsi.gbm = gbm + + // Setup workers to fetch blocks from remote node + for i := 0; i != lsi.config.Concurrency; i++ { + worker := &getBlocksWorker{ + gbm: gbm, + protocol: lsi.p, + ctx: lsi.ctx, + } + go worker.workLoop() + } + + // insert the blocks to chain. Return when the target block number is reached. + lsi.insertChainLoop(targetBN) + + select { + case <-lsi.ctx.Done(): + return lsi.ctx.Err() + default: + } + return nil +} + +func (lsi *lrSyncIter) insertChainLoop(targetBN uint64) { + var ( + gbm = lsi.gbm + t = time.NewTicker(100 * time.Millisecond) + resultC = make(chan struct{}, 1) + ) + + trigger := func() { + select { + case resultC <- struct{}{}: + default: + } + } + + for { + select { + case <-lsi.ctx.Done(): + return + + case <-t.C: + // Redundancy, periodically check whether there is blocks that can be processed + trigger() + + case <-gbm.resultC: + // New block arrive in resultQueue + trigger() + + case <-resultC: + blockResults := gbm.PullContinuousBlocks(blocksPerInsert) + if len(blockResults) > 0 { + lsi.processBlocks(blockResults, targetBN) + // more blocks is expected being able to be pulled from queue + trigger() + } + if lsi.bc.CurrentBlock().NumberU64() >= targetBN { + return + } + } + } +} + +func (lsi *lrSyncIter) processBlocks(results []*blockResult, targetBN uint64) { + blocks := blockResultsToBlocks(results) + + for i, block := range blocks { + if err := lsi.ih.verifyAndInsertBlock(block); err != nil { + lsi.logger.Warn().Err(err).Uint64("target block", targetBN). + Uint64("block number", block.NumberU64()). + Msg("insert blocks failed in long range") + + lsi.p.RemoveStream(results[i].stid) + lsi.gbm.HandleInsertError(results, i) + return + } + + lsi.inserted++ + } + lsi.gbm.HandleInsertResult(results) +} + +func (lsi *lrSyncIter) checkHaveEnoughStreams() error { + numStreams := lsi.p.NumStreams() + if numStreams < lsi.config.MinStreams { + return fmt.Errorf("number of streams smaller than minimum: %v < %v", + numStreams, lsi.config.MinStreams) + } + return nil +} + +// getBlocksWorker does the request job +type getBlocksWorker struct { + gbm *getBlocksManager + protocol syncProtocol + + ctx context.Context +} + +func (w *getBlocksWorker) workLoop() { + for { + select { + case <-w.ctx.Done(): + return + default: + } + batch := w.gbm.GetNextBatch() + if len(batch) == 0 { + select { + case <-w.ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + continue + } + } + + blocks, stid, err := w.doBatch(batch) + if err != nil { + if !errors.Is(err, context.Canceled) { + w.protocol.RemoveStream(stid) + } + err = errors.Wrap(err, "request error") + w.gbm.HandleRequestError(batch, err, stid) + } else { + w.gbm.HandleRequestResult(batch, blocks, stid) + } + } +} + +func (w *getBlocksWorker) doBatch(bns []uint64) ([]*types.Block, sttypes.StreamID, error) { + ctx, cancel := context.WithTimeout(w.ctx, 10*time.Second) + defer cancel() + + blocks, stid, err := w.protocol.GetBlocksByNumber(ctx, bns) + if err != nil { + return nil, stid, err + } + if err := validateGetBlocksResult(bns, blocks); err != nil { + return nil, stid, err + } + return blocks, stid, nil +} + +// getBlocksManager is the helper structure for get blocks request management +type getBlocksManager struct { + chain blockChain + + targetBN uint64 + requesting map[uint64]struct{} // block numbers that have been assigned to workers but not received + processing map[uint64]struct{} // block numbers received requests but not inserted + retries *prioritizedNumbers // requests where error happens + rq *resultQueue // result queue wait to be inserted into blockchain + + resultC chan struct{} + logger zerolog.Logger + lock sync.Mutex +} + +func newGetBlocksManager(chain blockChain, targetBN uint64, logger zerolog.Logger) *getBlocksManager { + return &getBlocksManager{ + chain: chain, + targetBN: targetBN, + requesting: make(map[uint64]struct{}), + processing: make(map[uint64]struct{}), + retries: newPrioritizedNumbers(), + rq: newResultQueue(), + resultC: make(chan struct{}, 1), + logger: logger, + } +} + +// GetNextBatch get the next block numbers batch +func (gbm *getBlocksManager) GetNextBatch() []uint64 { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + cap := numBlocksByNumPerRequest + + bns := gbm.getBatchFromRetries(cap) + cap -= len(bns) + gbm.addBatchToRequesting(bns) + + if gbm.availableForMoreTasks() { + addBNs := gbm.getBatchFromUnprocessed(cap) + gbm.addBatchToRequesting(addBNs) + bns = append(bns, addBNs...) + } + + return bns +} + +// HandleRequestError handles the error result +func (gbm *getBlocksManager) HandleRequestError(bns []uint64, err error, stid sttypes.StreamID) { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + gbm.logger.Warn().Err(err).Str("stream", string(stid)).Msg("get blocks error") + + // add requested block numbers to retries + for _, bn := range bns { + delete(gbm.requesting, bn) + gbm.retries.push(bn) + } + + // remove results from result queue by the stream and add back to retries + removed := gbm.rq.removeResultsByStreamID(stid) + for _, bn := range removed { + delete(gbm.processing, bn) + gbm.retries.push(bn) + } +} + +// HandleRequestResult handles get blocks result +func (gbm *getBlocksManager) HandleRequestResult(bns []uint64, blocks []*types.Block, stid sttypes.StreamID) { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + for i, bn := range bns { + delete(gbm.requesting, bn) + if blocks[i] == nil { + gbm.retries.push(bn) + } else { + gbm.processing[bn] = struct{}{} + } + } + gbm.rq.addBlockResults(blocks, stid) + select { + case gbm.resultC <- struct{}{}: + default: + } +} + +// HandleInsertResult handle the insert result +func (gbm *getBlocksManager) HandleInsertResult(inserted []*blockResult) { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + for _, block := range inserted { + delete(gbm.processing, block.getBlockNumber()) + } +} + +// HandleInsertError handles the error during InsertChain +func (gbm *getBlocksManager) HandleInsertError(results []*blockResult, n int) { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + var ( + inserted []*blockResult + errResult *blockResult + abandoned []*blockResult + ) + inserted = results[:n] + errResult = results[n] + if n != len(results) { + abandoned = results[n+1:] + } + + for _, res := range inserted { + delete(gbm.processing, res.getBlockNumber()) + } + for _, res := range abandoned { + gbm.rq.addBlockResults([]*types.Block{res.block}, res.stid) + } + + delete(gbm.processing, errResult.getBlockNumber()) + gbm.retries.push(errResult.getBlockNumber()) + + removed := gbm.rq.removeResultsByStreamID(errResult.stid) + for _, bn := range removed { + delete(gbm.processing, bn) + gbm.retries.push(bn) + } +} + +// PullContinuousBlocks pull continuous blocks from request queue +func (gbm *getBlocksManager) PullContinuousBlocks(cap int) []*blockResult { + gbm.lock.Lock() + defer gbm.lock.Unlock() + + expHeight := gbm.chain.CurrentBlock().NumberU64() + 1 + results, stales := gbm.rq.popBlockResults(expHeight, cap) + // For stale blocks, we remove them from processing + for _, bn := range stales { + delete(gbm.processing, bn) + } + return results +} + +func (gbm *getBlocksManager) getBatchFromRetries(cap int) []uint64 { + var ( + requestBNs []uint64 + curHeight = gbm.chain.CurrentBlock().NumberU64() + ) + for cnt := 0; cnt < cap; cnt++ { + bn := gbm.retries.pop() + if bn == 0 { + break // no more retries + } + if bn <= curHeight { + continue + } + requestBNs = append(requestBNs, bn) + } + return requestBNs +} + +func (gbm *getBlocksManager) getBatchFromUnprocessed(cap int) []uint64 { + var ( + requestBNs []uint64 + curHeight = gbm.chain.CurrentBlock().NumberU64() + ) + bn := curHeight + 1 + // TODO: this algorithm can be potentially optimized. + for cnt := 0; cnt < cap && bn <= gbm.targetBN; cnt++ { + for bn <= gbm.targetBN { + _, ok1 := gbm.requesting[bn] + _, ok2 := gbm.processing[bn] + if !ok1 && !ok2 { + requestBNs = append(requestBNs, bn) + bn++ + break + } + bn++ + } + } + return requestBNs +} + +func (gbm *getBlocksManager) availableForMoreTasks() bool { + return gbm.rq.results.Len() < softQueueCap +} + +func (gbm *getBlocksManager) addBatchToRequesting(bns []uint64) { + for _, bn := range bns { + gbm.requesting[bn] = struct{}{} + } +} + +func validateGetBlocksResult(requested []uint64, result []*types.Block) error { + if len(result) != len(requested) { + return fmt.Errorf("unexpected number of blocks delivered: %v / %v", len(result), len(requested)) + } + for i, block := range result { + if block != nil && block.NumberU64() != requested[i] { + return fmt.Errorf("block with unexpected number delivered: %v / %v", block.NumberU64(), requested[i]) + } + } + return nil +} + +func computeBNMaxVote(votes map[sttypes.StreamID]uint64) uint64 { + var ( + nm = make(map[uint64]int) + res uint64 + maxCnt int + ) + for _, bn := range votes { + _, ok := nm[bn] + if !ok { + nm[bn] = 0 + } + nm[bn]++ + cnt := nm[bn] + + if cnt > maxCnt || (cnt == maxCnt && bn > res) { + res = bn + maxCnt = cnt + } + } + return res +} diff --git a/hmy/downloader/longrange_test.go b/hmy/downloader/longrange_test.go new file mode 100644 index 0000000000..5ac8ba2e89 --- /dev/null +++ b/hmy/downloader/longrange_test.go @@ -0,0 +1,336 @@ +package downloader + +import ( + "context" + "fmt" + "math/rand" + "sync" + "testing" + + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/pkg/errors" +) + +func TestDownloader_doLongRangeSync(t *testing.T) { + targetBN := uint64(1000) + bc := newTestBlockChain(1, nil) + + d := &Downloader{ + bc: bc, + ih: &testInsertHelper{bc}, + syncProtocol: newTestSyncProtocol(targetBN, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + ctx: context.Background(), + } + synced, err := d.doLongRangeSync() + if err != nil { + t.Error(err) + } + if synced == 0 { + t.Errorf("synced false") + } + if curNum := d.bc.CurrentBlock().NumberU64(); curNum != targetBN { + t.Errorf("block number not expected: %v / %v", curNum, targetBN) + } +} + +func TestLrSyncIter_EstimateCurrentNumber(t *testing.T) { + lsi := &lrSyncIter{ + p: newTestSyncProtocol(100, 32, nil), + ctx: context.Background(), + config: Config{ + Concurrency: 16, + MinStreams: 10, + }, + } + bn, err := lsi.estimateCurrentNumber() + if err != nil { + t.Error(err) + } + if bn != 100 { + t.Errorf("unexpected block number: %v / %v", bn, 100) + } +} + +func TestGetBlocksManager_GetNextBatch(t *testing.T) { + tests := []struct { + gbm *getBlocksManager + expBNs []uint64 + }{ + { + gbm: makeGetBlocksManager( + 10, 100, []uint64{9, 11, 12, 13}, + []uint64{14, 15, 16}, []uint64{}, 0, + ), + expBNs: []uint64{17, 18, 19, 20, 21, 22, 23, 24, 25, 26}, + }, + { + gbm: makeGetBlocksManager( + 10, 100, []uint64{9, 13, 14, 15, 16}, + []uint64{}, []uint64{10, 11, 12}, 0, + ), + expBNs: []uint64{11, 12, 17, 18, 19, 20, 21, 22, 23, 24}, + }, + { + gbm: makeGetBlocksManager( + 10, 100, []uint64{9, 13, 14, 15, 16}, + []uint64{}, []uint64{10, 11, 12}, 120, + ), + expBNs: []uint64{11, 12}, + }, + { + gbm: makeGetBlocksManager( + 10, 100, []uint64{9, 13, 14, 15, 16}, + []uint64{}, []uint64{}, 120, + ), + expBNs: []uint64{}, + }, + { + gbm: makeGetBlocksManager( + 10, 20, []uint64{9, 13, 14, 15, 16}, + []uint64{}, []uint64{}, 0, + ), + expBNs: []uint64{11, 12, 17, 18, 19, 20}, + }, + { + gbm: makeGetBlocksManager( + 10, 100, []uint64{9, 13, 14, 15, 16}, + []uint64{}, []uint64{}, 0, + ), + expBNs: []uint64{11, 12, 17, 18, 19, 20, 21, 22, 23, 24}, + }, + } + + for i, test := range tests { + if i < 4 { + continue + } + batch := test.gbm.GetNextBatch() + if len(test.expBNs) != len(batch) { + t.Errorf("Test %v: unexpected size [%v] / [%v]", i, batch, test.expBNs) + } + for i := range test.expBNs { + if test.expBNs[i] != batch[i] { + t.Errorf("Test %v: [%v] / [%v]", i, batch, test.expBNs) + } + } + } +} + +func TestLrSyncIter_FetchAndInsertBlocks(t *testing.T) { + targetBN := uint64(1000) + chain := newTestBlockChain(0, nil) + protocol := newTestSyncProtocol(targetBN, 32, nil) + ctx, _ := context.WithCancel(context.Background()) + + lsi := &lrSyncIter{ + bc: chain, + ih: &testInsertHelper{chain}, + d: &Downloader{}, + p: protocol, + gbm: nil, + config: Config{ + Concurrency: 100, + }, + ctx: ctx, + } + lsi.fetchAndInsertBlocks(targetBN) + + if err := fetchAndInsertBlocksResultCheck(lsi, targetBN, initStreamNum); err != nil { + t.Error(err) + } +} + +// When FetchAndInsertBlocks, one request has an error +func TestLrSyncIter_FetchAndInsertBlocks_ErrRequest(t *testing.T) { + targetBN := uint64(1000) + var once sync.Once + errHook := func(bn uint64) error { + var err error + once.Do(func() { + err = errors.New("test error expected") + }) + return err + } + chain := newTestBlockChain(0, nil) + protocol := newTestSyncProtocol(targetBN, 32, errHook) + ctx, _ := context.WithCancel(context.Background()) + + lsi := &lrSyncIter{ + bc: chain, + ih: &testInsertHelper{chain}, + d: &Downloader{}, + p: protocol, + gbm: nil, + config: Config{ + Concurrency: 100, + }, + ctx: ctx, + } + lsi.fetchAndInsertBlocks(targetBN) + + if err := fetchAndInsertBlocksResultCheck(lsi, targetBN, initStreamNum-1); err != nil { + t.Error(err) + } +} + +// When FetchAndInsertBlocks, one insertion has an error +func TestLrSyncIter_FetchAndInsertBlocks_ErrInsert(t *testing.T) { + targetBN := uint64(1000) + var once sync.Once + errHook := func(bn uint64) error { + var err error + once.Do(func() { + err = errors.New("test error expected") + }) + return err + } + chain := newTestBlockChain(0, errHook) + protocol := newTestSyncProtocol(targetBN, 32, nil) + ctx, _ := context.WithCancel(context.Background()) + + lsi := &lrSyncIter{ + bc: chain, + ih: &testInsertHelper{chain}, + d: &Downloader{}, + p: protocol, + gbm: nil, + config: Config{ + Concurrency: 100, + }, + ctx: ctx, + } + lsi.fetchAndInsertBlocks(targetBN) + + if err := fetchAndInsertBlocksResultCheck(lsi, targetBN, initStreamNum-1); err != nil { + t.Error(err) + } +} + +// When FetchAndInsertBlocks, randomly error happens +func TestLrSyncIter_FetchAndInsertBlocks_RandomErr(t *testing.T) { + targetBN := uint64(10000) + rand.Seed(0) + errHook := func(bn uint64) error { + // 10% error happens + if rand.Intn(10)%10 == 0 { + return errors.New("error expected") + } + return nil + } + chain := newTestBlockChain(0, errHook) + protocol := newTestSyncProtocol(targetBN, 32, errHook) + ctx, _ := context.WithCancel(context.Background()) + + lsi := &lrSyncIter{ + bc: chain, + ih: &testInsertHelper{chain}, + d: &Downloader{}, + p: protocol, + gbm: nil, + config: Config{ + Concurrency: 100, + }, + ctx: ctx, + } + lsi.fetchAndInsertBlocks(targetBN) + + if err := fetchAndInsertBlocksResultCheck(lsi, targetBN, minStreamNum); err != nil { + t.Error(err) + } +} + +func fetchAndInsertBlocksResultCheck(lsi *lrSyncIter, targetBN uint64, expNumStreams int) error { + if bn := lsi.bc.CurrentBlock().NumberU64(); bn != targetBN { + return fmt.Errorf("did not reached targetBN: %v / %v", bn, targetBN) + } + lsi.gbm.lock.Lock() + defer lsi.gbm.lock.Unlock() + if len(lsi.gbm.processing) != 0 { + return fmt.Errorf("not empty processing: %v", lsi.gbm.processing) + } + if len(lsi.gbm.requesting) != 0 { + return fmt.Errorf("not empty requesting: %v", lsi.gbm.requesting) + } + if lsi.gbm.retries.length() != 0 { + return fmt.Errorf("not empty retries: %v", lsi.gbm.retries) + } + if lsi.gbm.rq.length() != 0 { + return fmt.Errorf("not empty result queue: %v", lsi.gbm.rq.results) + } + tsp := lsi.p.(*testSyncProtocol) + if len(tsp.streamIDs) != expNumStreams { + return fmt.Errorf("num streams not expected: %v / %v", len(tsp.streamIDs), expNumStreams) + } + return nil +} + +func TestComputeBNMaxVote(t *testing.T) { + tests := []struct { + votes map[sttypes.StreamID]uint64 + exp uint64 + }{ + { + votes: map[sttypes.StreamID]uint64{ + makeStreamID(0): 10, + makeStreamID(1): 10, + makeStreamID(2): 20, + }, + exp: 10, + }, + { + votes: map[sttypes.StreamID]uint64{ + makeStreamID(0): 10, + makeStreamID(1): 20, + }, + exp: 20, + }, + { + votes: map[sttypes.StreamID]uint64{ + makeStreamID(0): 20, + makeStreamID(1): 10, + makeStreamID(2): 20, + }, + exp: 20, + }, + } + + for i, test := range tests { + res := computeBNMaxVote(test.votes) + if res != test.exp { + t.Errorf("Test %v: unexpected bn %v / %v", i, res, test.exp) + } + } +} + +func makeGetBlocksManager(curBN, targetBN uint64, requesting, processing, retries []uint64, sizeRQ int) *getBlocksManager { + chain := newTestBlockChain(curBN, nil) + requestingM := make(map[uint64]struct{}) + for _, bn := range requesting { + requestingM[bn] = struct{}{} + } + processingM := make(map[uint64]struct{}) + for _, bn := range processing { + processingM[bn] = struct{}{} + } + retriesPN := newPrioritizedNumbers() + for _, retry := range retries { + retriesPN.push(retry) + } + rq := newResultQueue() + for i := uint64(0); i != uint64(sizeRQ); i++ { + rq.addBlockResults(makeTestBlocks([]uint64{i + curBN}), "") + } + return &getBlocksManager{ + chain: chain, + targetBN: targetBN, + requesting: requestingM, + processing: processingM, + retries: retriesPN, + rq: rq, + resultC: make(chan struct{}, 1), + } +} diff --git a/hmy/downloader/shortrange.go b/hmy/downloader/shortrange.go new file mode 100644 index 0000000000..2c0e52e8c7 --- /dev/null +++ b/hmy/downloader/shortrange.go @@ -0,0 +1,448 @@ +package downloader + +import ( + "context" + "fmt" + "math" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/harmony-one/harmony/core/types" + syncProto "github.com/harmony-one/harmony/p2p/stream/protocols/sync" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// doShortRangeSync does the short range sync. +// Compared with long range sync, short range sync is more focused on syncing to the latest block. +// It consist of 3 steps: +// 1. Obtain the block hashes and ompute the longest hash chain.. +// 2. Get blocks by hashes from computed hash chain. +// 3. Insert the blocks to blockchain. +func (d *Downloader) doShortRangeSync() (int, error) { + sh := &srHelper{ + syncProtocol: d.syncProtocol, + ctx: d.ctx, + config: d.config, + logger: d.logger.With().Str("mode", "short range").Logger(), + } + + if err := sh.checkPrerequisites(); err != nil { + return 0, errors.Wrap(err, "prerequisite") + } + curBN := d.bc.CurrentBlock().NumberU64() + hashChain, whitelist, err := sh.getHashChain(curBN) + if err != nil { + return 0, errors.Wrap(err, "getHashChain") + } + if len(hashChain) == 0 { + // short circuit for no sync is needed + return 0, nil + } + + d.startSyncing() + expEndBN := curBN + uint64(len(hashChain)) - 1 + d.status.setTargetBN(expEndBN) + defer d.finishSyncing() + + blocks, err := sh.getBlocksByHashes(hashChain, whitelist) + if err != nil { + if !errors.Is(err, context.Canceled) { + sh.removeStreams(whitelist) // Remote nodes cannot provide blocks with target hashes + } + return 0, errors.Wrap(err, "getBlocksByHashes") + } + n, err := d.ih.verifyAndInsertBlocks(blocks) + if err != nil { + if !errors.As(err, &sigVerifyError{}) { + sh.removeStreams(whitelist) // Data provided by remote nodes is corrupted + } + return n, errors.Wrap(err, "InsertChain") + } + return len(blocks), nil +} + +type srHelper struct { + syncProtocol syncProtocol + + ctx context.Context + config Config + logger zerolog.Logger +} + +func (sh *srHelper) getHashChain(curBN uint64) ([]common.Hash, []sttypes.StreamID, error) { + bns := sh.prepareBlockHashNumbers(curBN) + results := newBlockHashResults(bns) + + var wg sync.WaitGroup + wg.Add(sh.config.Concurrency) + + for i := 0; i != sh.config.Concurrency; i++ { + go func() { + defer wg.Done() + + hashes, stid, err := sh.doGetBlockHashesRequest(bns) + if err != nil { + return + } + results.addResult(hashes, stid) + }() + } + wg.Wait() + + select { + case <-sh.ctx.Done(): + return nil, nil, sh.ctx.Err() + default: + } + + hashChain, wl := results.computeLongestHashChain() + return hashChain, wl, nil +} + +func (sh *srHelper) getBlocksByHashes(hashes []common.Hash, whitelist []sttypes.StreamID) ([]*types.Block, error) { + ctx, cancel := context.WithCancel(sh.ctx) + m := newGetBlocksByHashManager(hashes, whitelist) + + var ( + wg sync.WaitGroup + gErr error + errLock sync.Mutex + ) + + wg.Add(sh.config.Concurrency) + for i := 0; i != sh.config.Concurrency; i++ { + go func() { + defer wg.Done() + defer cancel() // it's ok to cancel context more than once + + for { + if m.isDone() { + return + } + hashes, wl, err := m.getNextHashes() + if err != nil { + errLock.Lock() + gErr = err + errLock.Unlock() + return + } + if len(hashes) == 0 { + select { + case <-time.After(200 * time.Millisecond): + continue + case <-ctx.Done(): + return + } + } + blocks, stid, err := sh.doGetBlocksByHashesRequest(ctx, hashes, wl) + if err != nil { + sh.logger.Err(err).Msg("getBlocksByHashes worker failed") + m.handleResultError(hashes, stid) + } else { + m.addResult(hashes, blocks, stid) + } + } + }() + } + wg.Wait() + + if gErr != nil { + return nil, gErr + } + select { + case <-sh.ctx.Done(): + return nil, sh.ctx.Err() + default: + } + + return m.getResults() +} + +func (sh *srHelper) checkPrerequisites() error { + if sh.syncProtocol.NumStreams() < sh.config.Concurrency { + return errors.New("not enough streams") + } + return nil +} + +func (sh *srHelper) prepareBlockHashNumbers(curNumber uint64) []uint64 { + res := make([]uint64, 0, numBlockHashesPerRequest) + + for bn := curNumber + 1; bn <= curNumber+uint64(numBlockHashesPerRequest); bn++ { + res = append(res, bn) + } + return res +} + +func (sh *srHelper) doGetBlockHashesRequest(bns []uint64) ([]common.Hash, sttypes.StreamID, error) { + ctx, cancel := context.WithTimeout(sh.ctx, 1*time.Second) + defer cancel() + + hashes, stid, err := sh.syncProtocol.GetBlockHashes(ctx, bns) + if err != nil { + return nil, stid, err + } + if len(hashes) != len(bns) { + err := errors.New("unexpected get block hashes result delivered") + sh.logger.Warn().Err(err).Str("stream", string(stid)).Msg("failed to doGetBlockHashesRequest") + sh.syncProtocol.RemoveStream(stid) + return nil, stid, err + } + return hashes, stid, nil +} + +func (sh *srHelper) doGetBlocksByHashesRequest(ctx context.Context, hashes []common.Hash, wl []sttypes.StreamID) ([]*types.Block, sttypes.StreamID, error) { + ctx, cancel := context.WithTimeout(sh.ctx, 10*time.Second) + defer cancel() + + blocks, stid, err := sh.syncProtocol.GetBlocksByHashes(ctx, hashes, + syncProto.WithWhitelist(wl)) + if err != nil { + return nil, stid, err + } + if err := checkGetBlockByHashesResult(blocks, hashes); err != nil { + sh.logger.Warn().Err(err).Str("stream", string(stid)).Msg("failed to getBlockByHashes") + sh.syncProtocol.RemoveStream(stid) + return nil, stid, err + } + return blocks, stid, nil +} + +func (sh *srHelper) removeStreams(sts []sttypes.StreamID) { + for _, st := range sts { + sh.syncProtocol.RemoveStream(st) + } +} + +func checkGetBlockByHashesResult(blocks []*types.Block, hashes []common.Hash) error { + if len(blocks) != len(hashes) { + return errors.New("unexpected number of getBlocksByHashes result") + } + for i, block := range blocks { + if block == nil { + return errors.New("nil block found") + } + if block.Hash() != hashes[i] { + return fmt.Errorf("unexpected block hash: %x / %x", block.Hash(), hashes[i]) + } + } + return nil +} + +type ( + blockHashResults struct { + bns []uint64 + results []map[sttypes.StreamID]common.Hash + + lock sync.Mutex + } +) + +func newBlockHashResults(bns []uint64) *blockHashResults { + results := make([]map[sttypes.StreamID]common.Hash, 0, len(bns)) + for range bns { + results = append(results, make(map[sttypes.StreamID]common.Hash)) + } + return &blockHashResults{ + bns: bns, + results: results, + } +} + +func (res *blockHashResults) addResult(hashes []common.Hash, stid sttypes.StreamID) { + res.lock.Lock() + defer res.lock.Unlock() + + for i, h := range hashes { + if h == emptyHash { + return // nil block hash reached + } + res.results[i][stid] = h + } + return +} + +func (res *blockHashResults) computeLongestHashChain() ([]common.Hash, []sttypes.StreamID) { + var ( + whitelist map[sttypes.StreamID]struct{} + hashChain []common.Hash + ) + for _, result := range res.results { + hash, nextWl := countHashMaxVote(result, whitelist) + if hash == emptyHash { + break + } + hashChain = append(hashChain, hash) + whitelist = nextWl + } + + sts := make([]sttypes.StreamID, 0, len(whitelist)) + for st := range whitelist { + sts = append(sts, st) + } + return hashChain, sts +} + +func countHashMaxVote(m map[sttypes.StreamID]common.Hash, whitelist map[sttypes.StreamID]struct{}) (common.Hash, map[sttypes.StreamID]struct{}) { + var ( + voteM = make(map[common.Hash]int) + res common.Hash + maxCnt = 0 + ) + + for st, h := range m { + if len(whitelist) != 0 { + if _, ok := whitelist[st]; !ok { + continue + } + } + if _, ok := voteM[h]; !ok { + voteM[h] = 0 + } + voteM[h]++ + if voteM[h] > maxCnt { + maxCnt = voteM[h] + res = h + } + } + + nextWl := make(map[sttypes.StreamID]struct{}) + for st, h := range m { + if h != res { + continue + } + if len(whitelist) != 0 { + if _, ok := whitelist[st]; ok { + nextWl[st] = struct{}{} + } + } else { + nextWl[st] = struct{}{} + } + } + return res, nextWl +} + +type getBlocksByHashManager struct { + hashes []common.Hash + pendings map[common.Hash]struct{} + results map[common.Hash]blockResult + whitelist []sttypes.StreamID + + lock sync.Mutex +} + +func newGetBlocksByHashManager(hashes []common.Hash, whitelist []sttypes.StreamID) *getBlocksByHashManager { + return &getBlocksByHashManager{ + hashes: hashes, + pendings: make(map[common.Hash]struct{}), + results: make(map[common.Hash]blockResult), + whitelist: whitelist, + } +} + +func (m *getBlocksByHashManager) getNextHashes() ([]common.Hash, []sttypes.StreamID, error) { + m.lock.Lock() + defer m.lock.Unlock() + + num := m.numBlocksPerRequest() + hashes := make([]common.Hash, 0, num) + if len(m.whitelist) == 0 { + return nil, nil, errors.New("empty white list") + } + + for _, hash := range m.hashes { + if len(hashes) == num { + break + } + _, ok1 := m.pendings[hash] + _, ok2 := m.results[hash] + if !ok1 && !ok2 { + hashes = append(hashes, hash) + } + } + sts := make([]sttypes.StreamID, len(m.whitelist)) + copy(sts, m.whitelist) + return hashes, sts, nil +} + +func (m *getBlocksByHashManager) numBlocksPerRequest() int { + val := divideCeil(len(m.hashes), len(m.whitelist)) + if val < numBlocksByHashesLowerCap { + val = numBlocksByHashesLowerCap + } + if val > numBlocksByHashesUpperCap { + val = numBlocksByHashesUpperCap + } + return val +} + +func (m *getBlocksByHashManager) addResult(hashes []common.Hash, blocks []*types.Block, stid sttypes.StreamID) { + m.lock.Lock() + defer m.lock.Unlock() + + for i, hash := range hashes { + block := blocks[i] + delete(m.pendings, hash) + m.results[hash] = blockResult{ + block: block, + stid: stid, + } + } +} + +func (m *getBlocksByHashManager) handleResultError(hashes []common.Hash, stid sttypes.StreamID) { + m.lock.Lock() + defer m.lock.Unlock() + + m.removeStreamID(stid) + + for _, hash := range hashes { + delete(m.pendings, hash) + } +} + +func (m *getBlocksByHashManager) getResults() ([]*types.Block, error) { + m.lock.Lock() + defer m.lock.Unlock() + + blocks := make([]*types.Block, 0, len(m.hashes)) + for _, hash := range m.hashes { + if m.results[hash].block == nil { + return nil, errors.New("SANITY: nil block found") + } + blocks = append(blocks, m.results[hash].block) + } + return blocks, nil +} + +func (m *getBlocksByHashManager) isDone() bool { + m.lock.Lock() + defer m.lock.Unlock() + + return len(m.results) == len(m.hashes) +} + +func (m *getBlocksByHashManager) removeStreamID(target sttypes.StreamID) { + // O(n^2) complexity. But considering the whitelist size is small, should not + // have performance issue. +loop: + for i, stid := range m.whitelist { + if stid == target { + if i == len(m.whitelist) { + m.whitelist = m.whitelist[:i] + } else { + m.whitelist = append(m.whitelist[:i], m.whitelist[i+1:]...) + } + goto loop + } + } + return +} + +func divideCeil(x, y int) int { + fVal := float64(x) / float64(y) + return int(math.Ceil(fVal)) +} diff --git a/hmy/downloader/shortrange_test.go b/hmy/downloader/shortrange_test.go new file mode 100644 index 0000000000..40993c187b --- /dev/null +++ b/hmy/downloader/shortrange_test.go @@ -0,0 +1,438 @@ +package downloader + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/rs/zerolog" +) + +func TestDownloader_doShortRangeSync(t *testing.T) { + chain := newTestBlockChain(100, nil) + + d := &Downloader{ + bc: chain, + ih: &testInsertHelper{chain}, + syncProtocol: newTestSyncProtocol(105, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + ctx: context.Background(), + logger: zerolog.Logger{}, + } + n, err := d.doShortRangeSync() + if err != nil { + t.Error(err) + } + if n == 0 { + t.Error("not synced") + } + if curNum := d.bc.CurrentBlock().NumberU64(); curNum != 105 { + t.Errorf("unexpected block number after sync: %v / %v", curNum, 105) + } +} + +func TestSrHelper_getHashChain(t *testing.T) { + tests := []struct { + curBN uint64 + syncProtocol syncProtocol + config Config + + expHashChainSize int + expStSize int + }{ + { + curBN: 100, + syncProtocol: newTestSyncProtocol(1000, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: numBlockHashesPerRequest, + expStSize: 16, // Concurrency + }, + { + curBN: 100, + syncProtocol: newTestSyncProtocol(100, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: 0, + expStSize: 0, + }, + { + curBN: 100, + syncProtocol: newTestSyncProtocol(110, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: 10, + expStSize: 16, + }, + { + // stream size is smaller than concurrency + curBN: 100, + syncProtocol: newTestSyncProtocol(1000, 10, nil), + config: Config{ + Concurrency: 16, + MinStreams: 8, + }, + expHashChainSize: numBlockHashesPerRequest, + expStSize: 10, + }, + { + // one stream reports an error, else are fine + curBN: 100, + syncProtocol: newTestSyncProtocol(1000, 32, makeOnceErrorFunc()), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: numBlockHashesPerRequest, + expStSize: 15, // Concurrency + }, + { + // error happens at one block number, all stream removed + curBN: 100, + syncProtocol: newTestSyncProtocol(1000, 32, func(bn uint64) error { + if bn == 110 { + return errors.New("test error") + } + return nil + }), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: 0, + expStSize: 0, + }, + { + curBN: 100, + syncProtocol: newTestSyncProtocol(1000, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expHashChainSize: numBlockHashesPerRequest, + expStSize: 16, // Concurrency + }, + } + + for i, test := range tests { + sh := &srHelper{ + syncProtocol: test.syncProtocol, + ctx: context.Background(), + config: test.config, + } + hashChain, wl, err := sh.getHashChain(test.curBN) + if err != nil { + t.Error(err) + } + if len(hashChain) != test.expHashChainSize { + t.Errorf("Test %v: hash chain size unexpected: %v / %v", i, len(hashChain), test.expHashChainSize) + } + if len(wl) != test.expStSize { + t.Errorf("Test %v: whitelist size unexpected: %v / %v", i, len(wl), test.expStSize) + } + } +} + +func TestSrHelper_GetBlocksByHashes(t *testing.T) { + tests := []struct { + hashes []common.Hash + syncProtocol syncProtocol + config Config + + expBlockNumbers []uint64 + expErr error + }{ + { + hashes: testNumberToHashes([]uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}), + syncProtocol: newTestSyncProtocol(1000, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expBlockNumbers: []uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}, + expErr: nil, + }, + { + // remote node cannot give the block with the given hash + hashes: testNumberToHashes([]uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}), + syncProtocol: newTestSyncProtocol(100, 32, nil), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expBlockNumbers: []uint64{}, + expErr: errors.New("all streams are bad"), + }, + { + // one request return an error, else are fine + hashes: testNumberToHashes([]uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}), + syncProtocol: newTestSyncProtocol(1000, 32, makeOnceErrorFunc()), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expBlockNumbers: []uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}, + expErr: nil, + }, + { + // All nodes encounter an error + hashes: testNumberToHashes([]uint64{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}), + syncProtocol: newTestSyncProtocol(1000, 32, func(n uint64) error { + if n == 109 { + return errors.New("test error") + } + return nil + }), + config: Config{ + Concurrency: 16, + MinStreams: 16, + }, + expErr: errors.New("error expected"), + }, + } + for i, test := range tests { + sh := &srHelper{ + syncProtocol: test.syncProtocol, + ctx: context.Background(), + config: test.config, + } + blocks, err := sh.getBlocksByHashes(test.hashes, makeStreamIDs(5)) + if (err == nil) != (test.expErr == nil) { + t.Errorf("Test %v: unexpected error %v / %v", i, err, test.expErr) + } + if len(blocks) != len(test.expBlockNumbers) { + t.Errorf("Test %v: unepxected block number size: %v / %v", i, len(blocks), len(test.expBlockNumbers)) + } + for i, block := range blocks { + gotNum := testHashToNumber(block.Hash()) + if gotNum != test.expBlockNumbers[i] { + t.Errorf("Test %v: unexpected block number", i) + } + } + } +} + +func TestBlockHashResult_ComputeLongestHashChain(t *testing.T) { + tests := []struct { + bns []uint64 + results map[sttypes.StreamID][]int64 + expChain []uint64 + expWhitelist map[sttypes.StreamID]struct{} + expErr error + }{ + { + bns: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + results: map[sttypes.StreamID][]int64{ + makeStreamID(0): {1, 2, 3, 4, 5, 6, 7}, + makeStreamID(1): {1, 2, 3, 4, 5, 6, 7}, + makeStreamID(2): {1, 2, 3, 4, 5}, // left behind + }, + expChain: []uint64{1, 2, 3, 4, 5, 6, 7}, + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(0): {}, + makeStreamID(1): {}, + }, + }, + { + // minority fork + bns: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + results: map[sttypes.StreamID][]int64{ + makeStreamID(0): {1, 2, 3, 4, 5, 6, 7}, + makeStreamID(1): {1, 2, 3, 4, 5, 6, 7}, + makeStreamID(2): {1, 2, 3, 4, 5, 7, 8, 9}, + }, + expChain: []uint64{1, 2, 3, 4, 5, 6, 7}, + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(0): {}, + makeStreamID(1): {}, + }, + }, { + // nil block + bns: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + results: map[sttypes.StreamID][]int64{ + makeStreamID(0): {}, + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + expChain: nil, + expWhitelist: nil, + }, { + // not continuous block + bns: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + results: map[sttypes.StreamID][]int64{ + makeStreamID(0): {1, 2, 3, 4, 5, 6, 7, -1, 9}, + makeStreamID(1): {1, 2, 3, 4, 5, 6, 7}, + makeStreamID(2): {1, 2, 3, 4, 5, 7, 8, 9}, + }, + expChain: []uint64{1, 2, 3, 4, 5, 6, 7}, + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(0): {}, + makeStreamID(1): {}, + }, + }, + { + // not continuous block + bns: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + results: map[sttypes.StreamID][]int64{}, + expErr: errors.New("zero result"), + }, + } + + for i, test := range tests { + res := newBlockHashResults(test.bns) + for st, hs := range test.results { + res.addResult(makeTestBlockHashes(hs), st) + } + + chain, wl := res.computeLongestHashChain() + + if err := checkHashChainResult(chain, test.expChain); err != nil { + t.Errorf("Test %v: %v", i, err) + } + if err := checkStreamSetEqual(streamIDListToMap(wl), test.expWhitelist); err != nil { + t.Errorf("Test %v: %v", i, err) + } + } +} + +func checkHashChainResult(gots []common.Hash, exps []uint64) error { + if len(gots) != len(exps) { + return errors.New("unexpected size") + } + for i, got := range gots { + exp := exps[i] + if got != makeTestBlockHash(exp) { + return errors.New("unexpected block hash") + } + } + return nil +} + +func TestHashMaxVote(t *testing.T) { + tests := []struct { + m map[sttypes.StreamID]common.Hash + whitelist map[sttypes.StreamID]struct{} + expRes common.Hash + expWhitelist map[sttypes.StreamID]struct{} + }{ + { + m: map[sttypes.StreamID]common.Hash{ + makeStreamID(0): makeTestBlockHash(0), + makeStreamID(1): makeTestBlockHash(1), + makeStreamID(2): makeTestBlockHash(1), + }, + whitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(0): {}, + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + expRes: makeTestBlockHash(1), + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + }, { + m: map[sttypes.StreamID]common.Hash{ + makeStreamID(0): makeTestBlockHash(0), + makeStreamID(1): makeTestBlockHash(1), + makeStreamID(2): makeTestBlockHash(1), + }, + whitelist: nil, + expRes: makeTestBlockHash(1), + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + }, { + m: map[sttypes.StreamID]common.Hash{ + makeStreamID(0): makeTestBlockHash(0), + makeStreamID(1): makeTestBlockHash(1), + makeStreamID(2): makeTestBlockHash(1), + makeStreamID(3): makeTestBlockHash(0), + makeStreamID(4): makeTestBlockHash(0), + }, + whitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(0): {}, + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + expRes: makeTestBlockHash(1), + expWhitelist: map[sttypes.StreamID]struct{}{ + makeStreamID(1): {}, + makeStreamID(2): {}, + }, + }, + } + + for i, test := range tests { + h, wl := countHashMaxVote(test.m, test.whitelist) + + if h != test.expRes { + t.Errorf("Test %v: unexpected hash: %x / %x", i, h, test.expRes) + } + if err := checkStreamSetEqual(wl, test.expWhitelist); err != nil { + t.Errorf("Test %v: %v", i, err) + } + } +} + +func checkStreamSetEqual(m1, m2 map[sttypes.StreamID]struct{}) error { + if len(m1) != len(m2) { + return fmt.Errorf("unexpected size: %v / %v", len(m1), len(m2)) + } + for st := range m1 { + if _, ok := m2[st]; !ok { + return errors.New("not equal") + } + } + return nil +} + +func makeTestBlockHashes(bns []int64) []common.Hash { + hs := make([]common.Hash, 0, len(bns)) + for _, bn := range bns { + if bn < 0 { + hs = append(hs, emptyHash) + } else { + hs = append(hs, makeTestBlockHash(uint64(bn))) + } + } + return hs +} + +func streamIDListToMap(sts []sttypes.StreamID) map[sttypes.StreamID]struct{} { + res := make(map[sttypes.StreamID]struct{}) + + for _, st := range sts { + res[st] = struct{}{} + } + return res +} + +func makeTestBlockHash(bn uint64) common.Hash { + return makeTestBlock(bn).Hash() +} + +func makeOnceErrorFunc() func(num uint64) error { + var once sync.Once + return func(num uint64) error { + var err error + once.Do(func() { + err = errors.New("test error expected") + }) + return err + } +} diff --git a/hmy/downloader/types.go b/hmy/downloader/types.go new file mode 100644 index 0000000000..56c6bcb8c4 --- /dev/null +++ b/hmy/downloader/types.go @@ -0,0 +1,292 @@ +package downloader + +import ( + "container/heap" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/harmony-one/harmony/core/types" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" +) + +var ( + emptyHash common.Hash +) + +type status struct { + isSyncing bool + targetBN uint64 + lock sync.Mutex +} + +func newStatus() status { + return status{} +} + +func (s *status) startSyncing() { + s.lock.Lock() + defer s.lock.Unlock() + + s.isSyncing = true +} + +func (s *status) setTargetBN(val uint64) { + s.lock.Lock() + defer s.lock.Unlock() + + s.targetBN = val +} + +func (s *status) finishSyncing() { + s.lock.Lock() + defer s.lock.Unlock() + + s.isSyncing = false + s.targetBN = 0 +} + +func (s *status) get() (bool, uint64) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.isSyncing, s.targetBN +} + +type getBlocksResult struct { + bns []uint64 + blocks []*types.Block + stid sttypes.StreamID +} + +type resultQueue struct { + results *priorityQueue + lock sync.Mutex +} + +func newResultQueue() *resultQueue { + pq := make(priorityQueue, 0, 200) // 200 - rough estimate + heap.Init(&pq) + return &resultQueue{ + results: &pq, + } +} + +// addBlockResults adds the blocks to the result queue to be processed by insertChainLoop. +// If a nil block is detected in the block list, will not process further blocks. +func (rq *resultQueue) addBlockResults(blocks []*types.Block, stid sttypes.StreamID) { + rq.lock.Lock() + defer rq.lock.Unlock() + + for _, block := range blocks { + if block == nil { + continue + } + heap.Push(rq.results, &blockResult{ + block: block, + stid: stid, + }) + } + return +} + +// popBlockResults pop a continuous list of blocks starting at expStartBN with capped size. +// Return the stale block numbers as the second return value +func (rq *resultQueue) popBlockResults(expStartBN uint64, cap int) ([]*blockResult, []uint64) { + rq.lock.Lock() + defer rq.lock.Unlock() + + var ( + res = make([]*blockResult, 0, cap) + stales []uint64 + ) + + for cnt := 0; rq.results.Len() > 0 && cnt < cap; cnt++ { + br := heap.Pop(rq.results).(*blockResult) + // stale block number + if br.block.NumberU64() < expStartBN { + stales = append(stales, br.block.NumberU64()) + continue + } + if br.block.NumberU64() != expStartBN { + heap.Push(rq.results, br) + return res, stales + } + res = append(res, br) + expStartBN++ + } + return res, stales +} + +// removeResultsByStreamID remove the block results of the given stream, return the block +// number removed from the queue +func (rq *resultQueue) removeResultsByStreamID(stid sttypes.StreamID) []uint64 { + rq.lock.Lock() + defer rq.lock.Unlock() + + var removed []uint64 + +Loop: + for { + for i, res := range *rq.results { + blockRes := res.(*blockResult) + if blockRes.stid == stid { + rq.removeByIndex(i) + removed = append(removed, blockRes.block.NumberU64()) + goto Loop + } + } + break + } + return removed +} + +func (rq *resultQueue) length() int { + return len(*rq.results) +} + +func (rq *resultQueue) removeByIndex(index int) { + heap.Remove(rq.results, index) +} + +// bnPrioritizedItem is the item which uses block number to determine its priority +type bnPrioritizedItem interface { + getBlockNumber() uint64 +} + +type blockResult struct { + block *types.Block + stid sttypes.StreamID +} + +func (br *blockResult) getBlockNumber() uint64 { + return br.block.NumberU64() +} + +func blockResultsToBlocks(results []*blockResult) []*types.Block { + blocks := make([]*types.Block, 0, len(results)) + + for _, result := range results { + blocks = append(blocks, result.block) + } + return blocks +} + +type ( + prioritizedNumber uint64 + + prioritizedNumbers struct { + q *priorityQueue + } +) + +func (b prioritizedNumber) getBlockNumber() uint64 { + return uint64(b) +} + +func newPrioritizedNumbers() *prioritizedNumbers { + pqs := make(priorityQueue, 0) + heap.Init(&pqs) + return &prioritizedNumbers{ + q: &pqs, + } +} + +func (pbs *prioritizedNumbers) push(bn uint64) { + heap.Push(pbs.q, prioritizedNumber(bn)) +} + +func (pbs *prioritizedNumbers) pop() uint64 { + if pbs.q.Len() == 0 { + return 0 + } + item := heap.Pop(pbs.q) + return uint64(item.(prioritizedNumber)) +} + +func (pbs *prioritizedNumbers) length() int { + return len(*pbs.q) +} + +type ( + blockByNumber types.Block + + // blocksByNumber is the priority queue ordered by number + blocksByNumber struct { + q *priorityQueue + cap int + } +) + +func (b *blockByNumber) getBlockNumber() uint64 { + raw := (*types.Block)(b) + return raw.NumberU64() +} + +func newBlocksByNumber(cap int) *blocksByNumber { + pqs := make(priorityQueue, 0) + heap.Init(&pqs) + return &blocksByNumber{ + q: &pqs, + cap: cap, + } +} + +func (bs *blocksByNumber) push(b *types.Block) { + heap.Push(bs.q, (*blockByNumber)(b)) + for bs.q.Len() > bs.cap { + heap.Pop(bs.q) + } +} + +func (bs *blocksByNumber) pop() *types.Block { + if bs.q.Len() == 0 { + return nil + } + item := heap.Pop(bs.q) + return (*types.Block)(item.(*blockByNumber)) +} + +func (bs *blocksByNumber) len() int { + return bs.q.Len() +} + +// priorityQueue is a priorityQueue with lowest block number with highest priority +type priorityQueue []bnPrioritizedItem + +// resultQueue implements heap interface +func (q priorityQueue) Len() int { + return len(q) +} + +// resultQueue implements heap interface +func (q priorityQueue) Less(i, j int) bool { + bn1 := q[i].getBlockNumber() + bn2 := q[j].getBlockNumber() + return bn1 < bn2 // small block number has higher priority +} + +// resultQueue implements heap interface +func (q priorityQueue) Swap(i, j int) { + q[i], q[j] = q[j], q[i] +} + +// resultQueue implements heap interface +func (q *priorityQueue) Push(x interface{}) { + item, ok := x.(bnPrioritizedItem) + if !ok { + panic("wrong type of getBlockNumber interface") + } + *q = append(*q, item) +} + +// resultQueue implements heap interface +func (q *priorityQueue) Pop() interface{} { + prev := *q + n := len(prev) + if n == 0 { + return nil + } + res := prev[n-1] + *q = prev[0 : n-1] + return res +} diff --git a/hmy/downloader/types_test.go b/hmy/downloader/types_test.go new file mode 100644 index 0000000000..0f1443708d --- /dev/null +++ b/hmy/downloader/types_test.go @@ -0,0 +1,261 @@ +package downloader + +import ( + "container/heap" + "fmt" + "math/big" + "strings" + "testing" + + "github.com/harmony-one/harmony/block" + headerV3 "github.com/harmony-one/harmony/block/v3" + "github.com/harmony-one/harmony/core/types" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" +) + +func TestResultQueue_AddBlockResults(t *testing.T) { + tests := []struct { + initBNs []uint64 + addBNs []uint64 + expSize int + }{ + { + initBNs: []uint64{}, + addBNs: []uint64{1, 2, 3, 4}, + expSize: 4, + }, + { + initBNs: []uint64{1, 2, 3, 4}, + addBNs: []uint64{5, 6, 7, 8}, + expSize: 8, + }, + } + for i, test := range tests { + rq := makeTestResultQueue(test.initBNs) + rq.addBlockResults(makeTestBlocks(test.addBNs), "") + + if rq.results.Len() != test.expSize { + t.Errorf("Test %v: unexpected size: %v / %v", i, rq.results.Len(), test.expSize) + } + } +} + +func TestResultQueue_PopBlockResults(t *testing.T) { + tests := []struct { + initBNs []uint64 + cap int + expStart uint64 + expSize int + staleSize int + }{ + { + initBNs: []uint64{1, 2, 3, 4, 5}, + cap: 3, + expStart: 1, + expSize: 3, + staleSize: 0, + }, + { + initBNs: []uint64{1, 2, 3, 4, 5}, + cap: 10, + expStart: 1, + expSize: 5, + staleSize: 0, + }, + { + initBNs: []uint64{1, 3, 4, 5}, + cap: 10, + expStart: 1, + expSize: 1, + staleSize: 0, + }, + { + initBNs: []uint64{1, 2, 3, 4, 5}, + cap: 10, + expStart: 0, + expSize: 0, + staleSize: 0, + }, + { + initBNs: []uint64{1, 1, 1, 1, 2}, + cap: 10, + expStart: 1, + expSize: 2, + staleSize: 3, + }, + { + initBNs: []uint64{1, 2, 3, 4, 5}, + cap: 10, + expStart: 2, + expSize: 4, + staleSize: 1, + }, + } + for i, test := range tests { + rq := makeTestResultQueue(test.initBNs) + res, stales := rq.popBlockResults(test.expStart, test.cap) + if len(res) != test.expSize { + t.Errorf("Test %v: unexpect size %v / %v", i, len(res), test.expSize) + } + if len(stales) != test.staleSize { + t.Errorf("Test %v: unexpect stale size %v / %v", i, len(stales), test.staleSize) + } + } +} + +func TestResultQueue_RemoveResultsByStreamID(t *testing.T) { + tests := []struct { + rq *resultQueue + rmStreamID sttypes.StreamID + removed int + expSize int + }{ + { + rq: makeTestResultQueue([]uint64{1, 2, 3, 4}), + rmStreamID: "test stream id", + removed: 4, + expSize: 0, + }, + { + rq: func() *resultQueue { + rq := makeTestResultQueue([]uint64{2, 3, 4, 5}) + rq.addBlockResults([]*types.Block{ + makeTestBlock(1), + makeTestBlock(5), + makeTestBlock(6), + }, "another test stream id") + return rq + }(), + rmStreamID: "test stream id", + removed: 4, + expSize: 3, + }, + { + rq: func() *resultQueue { + rq := makeTestResultQueue([]uint64{2, 3, 4, 5}) + rq.addBlockResults([]*types.Block{ + makeTestBlock(1), + makeTestBlock(5), + makeTestBlock(6), + }, "another test stream id") + return rq + }(), + rmStreamID: "another test stream id", + removed: 3, + expSize: 4, + }, + } + for i, test := range tests { + res := test.rq.removeResultsByStreamID(test.rmStreamID) + if len(res) != test.removed { + t.Errorf("Test %v: unexpected number removed %v / %v", i, len(res), test.removed) + } + if gotSize := test.rq.results.Len(); gotSize != test.expSize { + t.Errorf("Test %v: unexpected number after removal %v / %v", i, gotSize, test.expSize) + } + } +} + +func makeTestResultQueue(bns []uint64) *resultQueue { + rq := newResultQueue() + for _, bn := range bns { + heap.Push(rq.results, &blockResult{ + block: makeTestBlock(bn), + stid: "test stream id", + }) + } + return rq +} + +func TestPrioritizedBlocks(t *testing.T) { + addBNs := []uint64{4, 7, 6, 9} + + bns := newPrioritizedNumbers() + for _, bn := range addBNs { + bns.push(bn) + } + prevBN := uint64(0) + for len(*bns.q) > 0 { + b := bns.pop() + if b < prevBN { + t.Errorf("number not incrementing") + } + prevBN = b + } + if last := bns.pop(); last != 0 { + t.Errorf("last elem is not 0") + } +} + +func TestBlocksByNumber(t *testing.T) { + addBNs := []uint64{4, 7, 6, 9} + + bns := newBlocksByNumber(10) + for _, bn := range addBNs { + bns.push(makeTestBlock(bn)) + } + if bns.len() != len(addBNs) { + t.Errorf("size unexpected: %v / %v", bns.len(), len(addBNs)) + } + prevBN := uint64(0) + for len(*bns.q) > 0 { + b := bns.pop() + if b.NumberU64() < prevBN { + t.Errorf("number not incrementing") + } + prevBN = b.NumberU64() + } + if lastBlock := bns.pop(); lastBlock != nil { + t.Errorf("last block is not nil") + } +} + +func TestPriorityQueue(t *testing.T) { + testBNs := []uint64{1, 9, 2, 4, 5, 12} + pq := make(priorityQueue, 0, 10) + heap.Init(&pq) + for _, bn := range testBNs { + heap.Push(&pq, &blockResult{ + block: makeTestBlock(bn), + stid: "", + }) + } + cmpBN := uint64(0) + for pq.Len() > 0 { + bn := heap.Pop(&pq).(*blockResult).block.NumberU64() + if bn < cmpBN { + t.Errorf("not incrementing") + } + cmpBN = bn + } + if pq.Len() != 0 { + t.Errorf("after poping, size not 0") + } +} + +func makeTestBlocks(bns []uint64) []*types.Block { + blocks := make([]*types.Block, 0, len(bns)) + for _, bn := range bns { + blocks = append(blocks, makeTestBlock(bn)) + } + return blocks +} + +func makeTestBlock(bn uint64) *types.Block { + testHeader := &block.Header{Header: headerV3.NewHeader()} + testHeader.SetNumber(big.NewInt(int64(bn))) + return types.NewBlockWithHeader(testHeader) +} + +func assertError(got, expect error) error { + if (got == nil) != (expect == nil) { + return fmt.Errorf("unexpected error [%v] / [%v]", got, expect) + } + if (got == nil) || (expect == nil) { + return nil + } + if !strings.Contains(got.Error(), expect.Error()) { + return fmt.Errorf("unexpected error [%v] / [%v]", got, expect) + } + return nil +}