diff --git a/api/blockchain_api.go b/api/blockchain_api.go index bd12c420..f5c2481a 100644 --- a/api/blockchain_api.go +++ b/api/blockchain_api.go @@ -173,7 +173,7 @@ func (api *BlockchainApi) TxReceipt(hash common.Hash) *TxReceipt { } func (api *BlockchainApi) Mempool() []common.Hash { - pending := api.pool.GetPendingTransaction(true, false) + pending := api.pool.GetPendingTransaction(true, common.MultiShard, false) var txs []common.Hash for _, tx := range pending { diff --git a/core/mempool/async_txpool.go b/core/mempool/async_txpool.go index 787c11ef..3b545d12 100644 --- a/core/mempool/async_txpool.go +++ b/core/mempool/async_txpool.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/idena-network/idena-go/blockchain/types" + "github.com/idena-network/idena-go/common" ) const batchSize = 1000 @@ -45,8 +46,8 @@ func (pool *AsyncTxPool) AddExternalTxs(txs ...*types.Transaction) error { return nil } -func (pool *AsyncTxPool) GetPendingTransaction(noFilter bool, count bool) []*types.Transaction { - return pool.txPool.GetPendingTransaction(noFilter, count) +func (pool *AsyncTxPool) GetPendingTransaction(noFilter bool, id common.ShardId, count bool) []*types.Transaction { + return pool.txPool.GetPendingTransaction(noFilter, id, count) } func (pool *AsyncTxPool) loop() { diff --git a/core/mempool/txpool.go b/core/mempool/txpool.go index e9aa949a..48012a63 100644 --- a/core/mempool/txpool.go +++ b/core/mempool/txpool.go @@ -38,7 +38,7 @@ var ( type TransactionPool interface { AddInternalTx(tx *types.Transaction) error AddExternalTxs(txs ...*types.Transaction) error - GetPendingTransaction(noFilter bool, count bool) []*types.Transaction + GetPendingTransaction(noFilter bool, id common.ShardId, count bool) []*types.Transaction IsSyncing() bool } @@ -399,7 +399,7 @@ func (pool *TxPool) put(tx *types.Transaction) error { return nil } -func (pool *TxPool) GetPendingTransaction(noFilter bool, count bool) []*types.Transaction { +func (pool *TxPool) GetPendingTransaction(noFilter bool, shardId common.ShardId, count bool) []*types.Transaction { all := pool.all.List() pool.mutex.Lock() defer pool.mutex.Unlock() @@ -410,9 +410,11 @@ func (pool *TxPool) GetPendingTransaction(noFilter bool, count bool) []*types.Tr } for _, tx := range all { if noFilter || pool.txSyncCounts[tx.Hash()] <= maxTxSyncCounts { - result = append(result, tx) - if count { - pool.txSyncCounts[tx.Hash()] ++ + if shardId == common.MultiShard || tx.LoadShardId() == shardId { + result = append(result, tx) + if count { + pool.txSyncCounts[tx.Hash()]++ + } } } } diff --git a/core/mempool/txpool_test.go b/core/mempool/txpool_test.go index 4ca78fdf..41fb6b9b 100644 --- a/core/mempool/txpool_test.go +++ b/core/mempool/txpool_test.go @@ -316,7 +316,7 @@ func TestTxPool_AddWithTxKeeper(t *testing.T) { time.Sleep(time.Second) require.Len(t, pool.txKeeper.txs, 320) - pool.txKeeper.RemoveTxs([]common.Hash{pool.GetPendingTransaction(false, false)[0].Hash()}) + pool.txKeeper.RemoveTxs([]common.Hash{pool.GetPendingTransaction(false, common.MultiShard, false)[0].Hash()}) time.Sleep(time.Second) prevPool := pool diff --git a/deferredtx/job_test.go b/deferredtx/job_test.go index 19d24d76..1a5ebc29 100644 --- a/deferredtx/job_test.go +++ b/deferredtx/job_test.go @@ -43,7 +43,7 @@ func (f *fakeTxPool) AddExternalTxs(txs ...*types.Transaction) error { panic("implement me") } -func (f fakeTxPool) GetPendingTransaction(bool, bool) []*types.Transaction { +func (f fakeTxPool) GetPendingTransaction(bool, common.ShardId, bool) []*types.Transaction { panic("implement me") } diff --git a/protocol/gossip.go b/protocol/gossip.go index 3471842f..a2679608 100644 --- a/protocol/gossip.go +++ b/protocol/gossip.go @@ -884,7 +884,7 @@ func (h *IdenaGossipHandler) RequestBlockByHash(hash common.Hash) { func (h *IdenaGossipHandler) syncTxPool(p *protoPeer) { const maximalPeersNumberForFullSync = 3 - pending := h.txpool.GetPendingTransaction(p.peers <= maximalPeersNumberForFullSync, true) + pending := h.txpool.GetPendingTransaction(p.peers <= maximalPeersNumberForFullSync, p.shardId, true) for _, tx := range pending { payload := pushPullHash{ Type: pushTx, diff --git a/tests/txpool_test.go b/tests/txpool_test.go index 990fdd47..eddc7154 100644 --- a/tests/txpool_test.go +++ b/tests/txpool_test.go @@ -263,7 +263,7 @@ func TestTxPool_ResetTo(t *testing.T) { }, }) - txs := pool.GetPendingTransaction(true, false) + txs := pool.GetPendingTransaction(true, common.MultiShard, false) require.Equal(4, len(txs)) require.Contains(txs, tx1)