diff --git a/common/crypto/Cargo.toml b/common/crypto/Cargo.toml index 8a172db3b..2054e146e 100644 --- a/common/crypto/Cargo.toml +++ b/common/crypto/Cargo.toml @@ -7,3 +7,9 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ophelia-secp256k1 = { git = "https://github.com/zeroqn/ophelia.git" } +ophelia = { git = "https://github.com/zeroqn/ophelia.git" } + +[features] +default = ["generate"] +generate = ["ophelia-secp256k1/generate", "ophelia/generate"] diff --git a/common/crypto/src/lib.rs b/common/crypto/src/lib.rs index 31e1bb209..e1bda462c 100644 --- a/common/crypto/src/lib.rs +++ b/common/crypto/src/lib.rs @@ -1,7 +1,4 @@ -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -} +pub use ophelia::{Crypto, PrivateKey, PublicKey, Signature}; +pub use ophelia_secp256k1::{ + Secp256k1, Secp256k1PrivateKey, Secp256k1PublicKey, Secp256k1Signature, +}; diff --git a/core/mempool/Cargo.toml b/core/mempool/Cargo.toml index a03417324..aca30a1e3 100644 --- a/core/mempool/Cargo.toml +++ b/core/mempool/Cargo.toml @@ -8,19 +8,23 @@ edition = "2018" [dependencies] protocol = { path = "../../protocol" } +common-crypto = { path = "../../common/crypto"} futures-preview = "0.3.0-alpha.18" +runtime-tokio = "0.3.0-alpha.6" runtime = "0.3.0-alpha.7" crossbeam-queue = "0.1" derive_more = "0.15" async-trait = "0.1" parking_lot = "0.8" +num-traits = "0.2" bytes = "0.4" rayon = "1.0" +rand = "0.6" hex = "0.3" [dev-dependencies] -num-traits = "0.2" chashmap = "2.2" -rand = "0.7.0" + + diff --git a/core/mempool/src/error.rs b/core/mempool/src/error.rs deleted file mode 100644 index ebfa31b57..000000000 --- a/core/mempool/src/error.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::error::Error; - -use derive_more::{Display, From}; - -use protocol::types::Hash; -use protocol::{ProtocolError, ProtocolErrorKind}; - -#[derive(Debug, Display, From)] -pub enum MemPoolError { - #[display(fmt = "Tx: {:?} insert failed", tx_hash)] - Insert { tx_hash: Hash }, - #[display(fmt = "Mempool reach limit: {}", pool_size)] - ReachLimit { pool_size: usize }, - #[display(fmt = "Tx: {:?} exists in pool", tx_hash)] - Dup { tx_hash: Hash }, - #[display(fmt = "Pull {} tx_hashes, return {} signed_txs", require, response)] - EnsureBreak { require: usize, response: usize }, - #[display( - fmt = "Return mismatch number of full transaction, require: {}, response: {}. This should not happen!", - require, - response - )] - MisMatch { require: usize, response: usize }, - #[display( - fmt = "Transaction insert into candidate queue with len: {} failed which should not happen!", - len - )] - InsertCandidate { len: usize }, -} - -impl Error for MemPoolError {} - -impl From for ProtocolError { - fn from(error: MemPoolError) -> ProtocolError { - ProtocolError::new(ProtocolErrorKind::Mempool, Box::new(error)) - } -} diff --git a/core/mempool/src/lib.rs b/core/mempool/src/lib.rs index 1fae63d3f..b9b1c48a3 100644 --- a/core/mempool/src/lib.rs +++ b/core/mempool/src/lib.rs @@ -1,23 +1,24 @@ #![feature(test)] -mod error; mod map; +mod test; mod tx_cache; +use std::error::Error; use std::sync::atomic::{AtomicU64, Ordering}; use async_trait::async_trait; +use derive_more::{Display, From}; use protocol::traits::{Context, MemPool, MemPoolAdapter, MixedTxHashes}; use protocol::types::{Hash, SignedTransaction}; -use protocol::ProtocolResult; +use protocol::{ProtocolError, ProtocolErrorKind, ProtocolResult}; -use crate::error::MemPoolError; use crate::map::Map; use crate::tx_cache::TxCache; /// Memory pool for caching transactions. -struct HashMemPool { +pub struct HashMemPool { /// Pool size limit. pool_size: usize, /// A system param limits the life time of an off-chain transaction. @@ -57,6 +58,18 @@ where current_epoch_id: AtomicU64::new(current_epoch_id), } } + + pub fn get_tx_cache(&self) -> &TxCache { + &self.tx_cache + } + + pub fn get_callback_cache(&self) -> &Map { + &self.callback_cache + } + + pub fn get_adapter(&self) -> &Adapter { + &self.adapter + } } #[async_trait] @@ -66,40 +79,19 @@ where { async fn insert(&self, ctx: Context, tx: SignedTransaction) -> ProtocolResult<()> { let tx_hash = &tx.tx_hash; - // 1. check size - if self.tx_cache.len() >= self.pool_size { - return Err(MemPoolError::ReachLimit { - pool_size: self.pool_size, - } - .into()); - } - // 2. check pool exist - if self.tx_cache.contain(tx_hash) { - return Err(MemPoolError::Dup { - tx_hash: tx_hash.clone(), - } - .into()); - } - // 3. check signature + self.tx_cache.check_reach_limit(self.pool_size)?; + self.tx_cache.check_exist(tx_hash)?; self.adapter .check_signature(ctx.clone(), tx.clone()) .await?; - - // 4. check transaction self.adapter .check_transaction(ctx.clone(), tx.clone()) .await?; - - // 5. check storage exist self.adapter .check_storage_exist(ctx.clone(), tx_hash.clone()) .await?; - - // 6. do insert self.tx_cache.insert_new_tx(tx.clone())?; - - // 7. network broadcast self.adapter.broadcast_tx(ctx.clone(), tx).await?; Ok(()) @@ -189,3 +181,31 @@ where Ok(()) } } + +#[derive(Debug, Display, From)] +pub enum MemPoolError { + #[display(fmt = "Tx: {:?} inserts failed", tx_hash)] + Insert { tx_hash: Hash }, + #[display(fmt = "Mempool reaches limit: {}", pool_size)] + ReachLimit { pool_size: usize }, + #[display(fmt = "Tx: {:?} exists in pool", tx_hash)] + Dup { tx_hash: Hash }, + #[display(fmt = "Pull txs, require: {}, response: {}", require, response)] + EnsureBreak { require: usize, response: usize }, + #[display(fmt = "Fetch full txs, require: {}, response: {}", require, response)] + MisMatch { require: usize, response: usize }, + #[display(fmt = "Tx inserts candidate_queue failed, len: {}", len)] + InsertCandidate { len: usize }, + #[display(fmt = "Tx: {:?} check_sig failed", tx_hash)] + CheckSig { tx_hash: Hash }, + #[display(fmt = "Check_hash failed, expect: {:?}, get: {:?}", expect, actual)] + CheckHash { expect: Hash, actual: Hash }, +} + +impl Error for MemPoolError {} + +impl From for ProtocolError { + fn from(error: MemPoolError) -> ProtocolError { + ProtocolError::new(ProtocolErrorKind::Mempool, Box::new(error)) + } +} diff --git a/core/mempool/src/map.rs b/core/mempool/src/map.rs index 0637846ec..655347557 100644 --- a/core/mempool/src/map.rs +++ b/core/mempool/src/map.rs @@ -143,11 +143,11 @@ mod tests { use crate::map::Map; - const GEN_TX_SIZE: usize = 100_000; + const GEN_TX_SIZE: usize = 1000; #[bench] - fn bench_insert_sharding(b: &mut Bencher) { - let txs = gen_txs(GEN_TX_SIZE); + fn bench_map_insert(b: &mut Bencher) { + let txs = mock_txs(GEN_TX_SIZE); b.iter(move || { let cache = Map::new(GEN_TX_SIZE); @@ -158,8 +158,8 @@ mod tests { } #[bench] - fn bench_insert_std(b: &mut Bencher) { - let txs = gen_txs(GEN_TX_SIZE); + fn bench_std_map_insert(b: &mut Bencher) { + let txs = mock_txs(GEN_TX_SIZE); b.iter(move || { let cache = Arc::new(RwLock::new(HashMap::new())); @@ -170,8 +170,8 @@ mod tests { } #[bench] - fn bench_insert_chashmap(b: &mut Bencher) { - let txs = gen_txs(GEN_TX_SIZE); + fn bench_chashmap_insert(b: &mut Bencher) { + let txs = mock_txs(GEN_TX_SIZE); b.iter(move || { let cache = CHashMap::new(); @@ -181,7 +181,7 @@ mod tests { }); } - fn gen_txs(size: usize) -> Vec<(Hash, Hash)> { + fn mock_txs(size: usize) -> Vec<(Hash, Hash)> { let mut txs = Vec::with_capacity(size); for _ in 0..size { let tx: Vec = (0..10).map(|_| random::()).collect(); diff --git a/core/mempool/src/test.rs b/core/mempool/src/test.rs new file mode 100644 index 000000000..4192aaa6b --- /dev/null +++ b/core/mempool/src/test.rs @@ -0,0 +1,549 @@ +#[cfg(test)] +mod tests { + extern crate test; + + use std::collections::HashMap; + use std::convert::{From, TryFrom}; + use std::sync::Arc; + + use async_trait::async_trait; + use bytes::Bytes; + use chashmap::CHashMap; + use futures::executor; + use num_traits::FromPrimitive; + use rand::random; + use rand::rngs::OsRng; + use rayon::iter::IntoParallelRefIterator; + use rayon::prelude::*; + use test::Bencher; + + use common_crypto::{ + Crypto, PrivateKey, PublicKey, Secp256k1, Secp256k1PrivateKey, Secp256k1PublicKey, + Secp256k1Signature, Signature, + }; + use protocol::codec::ProtocolCodec; + use protocol::traits::{Context, MemPool, MemPoolAdapter, MixedTxHashes}; + use protocol::types::{ + AccountAddress as Address, Fee, Hash, RawTransaction, SignedTransaction, TransactionAction, + }; + use protocol::ProtocolResult; + + use crate::{HashMemPool, MemPoolError}; + + const AMOUNT: i32 = 42; + const CYCLE_LIMIT: u64 = 10_000; + const CURRENT_EPOCH_ID: u64 = 999; + const POOL_SIZE: usize = 100_000; + const TIMEOUT: u64 = 1000; + const TIMEOUT_GAP: u64 = 100; + const TX_CYCLE: u64 = 1; + + pub struct HashMemPoolAdapter { + network_txs: CHashMap, + } + + impl HashMemPoolAdapter { + fn new() -> HashMemPoolAdapter { + HashMemPoolAdapter { + network_txs: CHashMap::new(), + } + } + } + + #[async_trait] + impl MemPoolAdapter for HashMemPoolAdapter { + async fn pull_txs( + &self, + _ctx: Context, + tx_hashes: Vec, + ) -> ProtocolResult> { + let mut vec = Vec::new(); + for hash in tx_hashes { + if let Some(tx) = self.network_txs.get(&hash) { + vec.push(tx.clone()); + } + } + Ok(vec) + } + + async fn broadcast_tx(&self, _ctx: Context, tx: SignedTransaction) -> ProtocolResult<()> { + self.network_txs.insert(tx.tx_hash.clone(), tx); + Ok(()) + } + + async fn check_signature( + &self, + _ctx: Context, + tx: SignedTransaction, + ) -> ProtocolResult<()> { + check_hash(tx.clone()).await?; + check_sig(&tx) + } + + async fn check_transaction( + &self, + _ctx: Context, + _tx: SignedTransaction, + ) -> ProtocolResult<()> { + Ok(()) + } + + async fn check_storage_exist(&self, _ctx: Context, _tx_hash: Hash) -> ProtocolResult<()> { + Ok(()) + } + } + + macro_rules! insert { + (normal($pool_size: expr, $input: expr, $output: expr)) => { + insert!(inner($pool_size, 1, $input, 0, $output)); + }; + (repeat($repeat: expr, $input: expr, $output: expr)) => { + insert!(inner($input * 10, $repeat, $input, 0, $output)); + }; + (invalid($valid: expr, $invalid: expr, $output: expr)) => { + insert!(inner($valid * 10, 1, $valid, $invalid, $output)); + }; + (inner($pool_size: expr, $repeat: expr, $valid: expr, $invalid: expr, $output: expr)) => { + let mempool = Arc::new(new_mempool( + $pool_size, + CYCLE_LIMIT, + TIMEOUT_GAP, + CURRENT_EPOCH_ID, + )); + let txs = mock_txs($valid, $invalid, TIMEOUT); + for _ in 0..$repeat { + concurrent_insert(txs.clone(), Arc::clone(&mempool)); + } + assert_eq!(mempool.get_tx_cache().len(), $output); + }; + } + + #[test] + fn test_insert() { + // 1. insertion under pool size. + insert!(normal(100, 100, 100)); + + // 2. insertion above pool size. + insert!(normal(100, 101, 100)); + + // 3. repeat insertion + insert!(repeat(5, 200, 200)); + + // 4. invalid insertion + insert!(invalid(80, 10, 80)); + } + + macro_rules! package { + (normal($cycle_limit: expr, $insert: expr, $expect_order: expr, $expect_propose: expr)) => { + package!(inner( + $cycle_limit, + CURRENT_EPOCH_ID, + TIMEOUT_GAP, + TIMEOUT, + $insert, + $expect_order, + $expect_propose + )); + }; + (timeout($current_epoch_id: expr, $timeout_gap: expr, $timeout: expr, $insert: expr, $expect: expr)) => { + package!(inner( + $insert, + $current_epoch_id, + $timeout_gap, + $timeout, + $insert, + $expect, + 0 + )); + }; + (inner($cycle_limit: expr, $current_epoch_id: expr, $timeout_gap: expr, $timeout: expr, $insert: expr, $expect_order: expr, $expect_propose: expr)) => { + let mempool = &Arc::new(new_mempool( + $insert * 10, + $cycle_limit, + $timeout_gap, + $current_epoch_id, + )); + let txs = mock_txs($insert, 0, $timeout); + concurrent_insert(txs.clone(), Arc::clone(mempool)); + let mixed_tx_hashes = exec_package(Arc::clone(mempool)); + assert_eq!(mixed_tx_hashes.order_tx_hashes.len(), $expect_order); + assert_eq!(mixed_tx_hashes.propose_tx_hashes.len(), $expect_propose); + }; + } + + #[test] + fn test_package() { + // 1. pool_size <= cycle_limit + package!(normal(100, 50, 50, 0)); + package!(normal(100, 100, 100, 0)); + + // 2. cycle_limit < pool_size <= 2 * cycle_limit + package!(normal(100, 101, 100, 1)); + package!(normal(100, 200, 100, 100)); + + // 3. 2 * cycle_limit < pool_size + package!(normal(100, 201, 100, 100)); + + // 4. current_epoch_id >= tx.timeout + package!(timeout(100, 50, 100, 10, 0)); + package!(timeout(100, 50, 90, 10, 0)); + + // 5. current_epoch_id + timeout_gap < tx.timeout + package!(timeout(100, 50, 151, 10, 0)); + package!(timeout(100, 50, 160, 10, 0)); + + // 6. tx.timeout - timeout_gap =< current_epoch_id < tx.timeout + package!(timeout(100, 50, 150, 10, 10)); + package!(timeout(100, 50, 101, 10, 10)); + } + + #[test] + fn test_package_order_consistent_with_insert_order() { + let mempool = &Arc::new(default_mempool()); + + let txs = &default_mock_txs(100); + txs.iter() + .for_each(|signed_tx| exec_insert(signed_tx, Arc::clone(mempool))); + let mixed_tx_hashes = exec_package(Arc::clone(mempool)); + assert!(check_order_consistant(&mixed_tx_hashes, txs)); + + // flush partial txs and test order consistency + let (remove_txs, reserve_txs) = txs.split_at(50); + let remove_hashes: Vec = remove_txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + exec_flush(remove_hashes, Arc::clone(mempool)); + let mixed_tx_hashes = exec_package(Arc::clone(mempool)); + assert!(check_order_consistant(&mixed_tx_hashes, reserve_txs)); + } + + #[test] + fn test_flush() { + let mempool = Arc::new(default_mempool()); + + // insert txs + let txs = default_mock_txs(555); + concurrent_insert(txs.clone(), Arc::clone(&mempool)); + assert_eq!(mempool.get_tx_cache().len(), 555); + + let callback_cache = mempool.get_callback_cache(); + txs.iter().for_each(|tx| { + callback_cache.insert(tx.tx_hash.clone(), tx.clone()); + }); + assert_eq!(callback_cache.len(), 555); + + // flush exist txs + let (remove_txs, _) = txs.split_at(123); + let remove_hashes: Vec = remove_txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + exec_flush(remove_hashes, Arc::clone(&mempool)); + assert_eq!(mempool.get_tx_cache().len(), 432); + assert_eq!(mempool.get_tx_cache().queue_len(), 555); + exec_package(Arc::clone(&mempool)); + assert_eq!(mempool.get_tx_cache().queue_len(), 432); + assert_eq!(callback_cache.len(), 0); + + // flush absent txs + let txs = default_mock_txs(222); + let remove_hashes: Vec = txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + exec_flush(remove_hashes, Arc::clone(&mempool)); + assert_eq!(mempool.get_tx_cache().len(), 432); + assert_eq!(mempool.get_tx_cache().queue_len(), 432); + } + + macro_rules! ensure_order_txs { + ($in_pool: expr, $out_pool: expr) => { + let mempool = &Arc::new(default_mempool()); + + let txs = &default_mock_txs($in_pool + $out_pool); + let (in_pool_txs, out_pool_txs) = txs.split_at($in_pool); + concurrent_insert(in_pool_txs.to_vec(), Arc::clone(mempool)); + concurrent_broadcast(out_pool_txs.to_vec(), Arc::clone(mempool)); + + let tx_hashes: Vec = txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + exec_ensure_order_txs(tx_hashes.clone(), Arc::clone(mempool)); + + assert_eq!(mempool.get_callback_cache().len(), $out_pool); + + let fetch_txs = exec_get_full_txs(tx_hashes, Arc::clone(mempool)); + assert_eq!(fetch_txs.len(), txs.len()); + }; + } + + #[test] + fn test_ensure_order_txs() { + // all txs are in pool + ensure_order_txs!(100, 0); + // 50 txs are not in pool + ensure_order_txs!(50, 50); + // all txs are not in pool + ensure_order_txs!(0, 100); + } + + #[test] + fn test_sync_propose_txs() { + let mempool = &Arc::new(default_mempool()); + + let txs = &default_mock_txs(50); + let (exist_txs, need_sync_txs) = txs.split_at(20); + concurrent_insert(exist_txs.to_vec(), Arc::clone(mempool)); + concurrent_broadcast(need_sync_txs.to_vec(), Arc::clone(mempool)); + + let tx_hashes: Vec = txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + exec_sync_propose_txs(tx_hashes.clone(), Arc::clone(mempool)); + + assert_eq!(mempool.get_tx_cache().len(), 50); + } + + #[bench] + fn bench_insert(b: &mut Bencher) { + let mempool = &Arc::new(default_mempool()); + + b.iter(|| { + let txs = default_mock_txs(100); + concurrent_insert(txs, Arc::clone(mempool)); + }); + } + + #[bench] + fn bench_package(b: &mut Bencher) { + let mempool = Arc::new(default_mempool()); + let txs = default_mock_txs(50_000); + concurrent_insert(txs.clone(), Arc::clone(&mempool)); + b.iter(|| { + exec_package(Arc::clone(&mempool)); + }); + } + + #[bench] + fn bench_flush(b: &mut Bencher) { + let mempool = &Arc::new(default_mempool()); + let txs = &default_mock_txs(100); + let remove_hashes: &Vec = &txs.iter().map(|tx| tx.tx_hash.clone()).collect(); + b.iter(|| { + concurrent_insert(txs.clone(), Arc::clone(mempool)); + exec_flush(remove_hashes.clone(), Arc::clone(mempool)); + exec_package(Arc::clone(mempool)); + }); + } + + #[bench] + fn bench_mock_txs(b: &mut Bencher) { + b.iter(|| { + default_mock_txs(100); + }); + } + + #[bench] + fn bench_check_sig(b: &mut Bencher) { + let txs = &default_mock_txs(100); + + b.iter(|| { + concurrent_check_sig(txs.clone()); + }); + } + + fn default_mock_txs(size: usize) -> Vec { + mock_txs(size, 0, TIMEOUT) + } + + fn mock_txs(valid_size: usize, invalid_size: usize, timeout: u64) -> Vec { + let mut vec = Vec::new(); + let mut rng = OsRng::new().expect("OsRng"); + let (priv_key, pub_key) = Secp256k1::generate_keypair(&mut rng); + let address = pub_key_to_address(&pub_key).unwrap(); + for i in 0..valid_size + invalid_size { + vec.push(mock_signed_tx( + &priv_key, + &pub_key, + &address, + timeout, + i < valid_size, + )); + } + vec + } + + fn default_mempool() -> HashMemPool { + new_mempool(POOL_SIZE, CYCLE_LIMIT, TIMEOUT_GAP, CURRENT_EPOCH_ID) + } + + fn new_mempool( + pool_size: usize, + cycle_limit: u64, + timeout_gap: u64, + current_epoch_id: u64, + ) -> HashMemPool { + let adapter = HashMemPoolAdapter::new(); + HashMemPool::new( + pool_size, + timeout_gap, + cycle_limit, + current_epoch_id, + adapter, + ) + } + + fn pub_key_to_address(pub_key: &Secp256k1PublicKey) -> ProtocolResult
{ + let mut pub_key_str = Hash::digest(pub_key.to_bytes()).as_hex(); + pub_key_str.truncate(40); + pub_key_str.insert_str(0, "10"); + Address::from_bytes(Bytes::from(hex::decode(pub_key_str).unwrap())) + } + + async fn check_hash(tx: SignedTransaction) -> ProtocolResult<()> { + let mut raw = tx.raw; + let raw_bytes = raw.encode().await?; + let tx_hash = Hash::digest(raw_bytes); + if tx_hash != tx.tx_hash { + return Err(MemPoolError::CheckHash { + expect: tx.tx_hash.clone(), + actual: tx_hash.clone(), + } + .into()); + } + Ok(()) + } + + fn check_sig(tx: &SignedTransaction) -> ProtocolResult<()> { + if Secp256k1::verify_signature(&tx.tx_hash.as_bytes(), &tx.signature, &tx.pubkey).is_err() { + return Err(MemPoolError::CheckSig { + tx_hash: tx.tx_hash.clone(), + } + .into()); + } + Ok(()) + } + + fn concurrent_check_sig(txs: Vec) { + txs.par_iter().for_each(|signed_tx| { + check_sig(signed_tx).unwrap(); + }); + } + + fn concurrent_insert( + txs: Vec, + mempool: Arc>, + ) { + txs.par_iter() + .for_each(|signed_tx| exec_insert(signed_tx, Arc::clone(&mempool))); + } + + fn concurrent_broadcast( + txs: Vec, + mempool: Arc>, + ) { + txs.par_iter().for_each(|signed_tx| { + executor::block_on(async { + mempool + .get_adapter() + .broadcast_tx(HashMap::new(), signed_tx.clone()) + .await + .unwrap(); + }) + }); + } + + fn exec_insert(signed_tx: &SignedTransaction, mempool: Arc>) { + executor::block_on(async { + let _ = mempool.insert(HashMap::new(), signed_tx.clone()).await; + }); + } + + fn exec_flush(remove_hashes: Vec, mempool: Arc>) { + executor::block_on(async { + mempool.flush(HashMap::new(), remove_hashes).await.unwrap(); + }); + } + + fn exec_package(mempool: Arc>) -> MixedTxHashes { + executor::block_on(async { mempool.package(HashMap::new()).await.unwrap() }) + } + + fn exec_ensure_order_txs( + require_hashes: Vec, + mempool: Arc>, + ) { + executor::block_on(async { + mempool + .ensure_order_txs(HashMap::new(), require_hashes) + .await + .unwrap(); + }) + } + + fn exec_sync_propose_txs( + require_hashes: Vec, + mempool: Arc>, + ) { + executor::block_on(async { + mempool + .sync_propose_txs(HashMap::new(), require_hashes) + .await + .unwrap(); + }) + } + + fn exec_get_full_txs( + require_hashes: Vec, + mempool: Arc>, + ) -> Vec { + executor::block_on(async { + mempool + .get_full_txs(HashMap::new(), require_hashes) + .await + .unwrap() + }) + } + + fn mock_signed_tx( + priv_key: &Secp256k1PrivateKey, + pub_key: &Secp256k1PublicKey, + address: &Address, + timeout: u64, + valid: bool, + ) -> SignedTransaction { + let nonce = Hash::digest(Bytes::from(get_random_bytes(10))); + let fee = Fee { + asset_id: nonce.clone(), + cycle: TX_CYCLE, + }; + let action = TransactionAction::Transfer { + receiver: address.clone(), + asset_id: nonce.clone(), + amount: FromPrimitive::from_i32(AMOUNT).unwrap(), + }; + let mut raw = RawTransaction { + chain_id: nonce.clone(), + nonce, + timeout, + fee, + action, + }; + + let raw_bytes = executor::block_on(async { raw.encode().await.unwrap() }); + let tx_hash = Hash::digest(raw_bytes); + + let signature = if valid { + Secp256k1::sign_message(&tx_hash.as_bytes(), &priv_key.to_bytes()).unwrap() + } else { + Secp256k1Signature::try_from([0u8; 64].as_parallel_slice()).unwrap() + }; + + SignedTransaction { + raw, + tx_hash, + pubkey: pub_key.to_bytes(), + signature: signature.to_bytes(), + } + } + + fn get_random_bytes(len: usize) -> Vec { + (0..len).map(|_| random::()).collect() + } + + fn check_order_consistant(mixed_tx_hashes: &MixedTxHashes, txs: &[SignedTransaction]) -> bool { + mixed_tx_hashes + .order_tx_hashes + .iter() + .enumerate() + .any(|(i, hash)| hash == &txs.get(i).unwrap().tx_hash) + } +} diff --git a/core/mempool/src/tx_cache.rs b/core/mempool/src/tx_cache.rs index 14cb12568..562f66bb8 100644 --- a/core/mempool/src/tx_cache.rs +++ b/core/mempool/src/tx_cache.rs @@ -7,8 +7,8 @@ use protocol::traits::MixedTxHashes; use protocol::types::{Hash, SignedTransaction}; use protocol::ProtocolResult; -use crate::error::MemPoolError; use crate::map::Map; +use crate::MemPoolError; /// Wrap `SignedTransaction` with two marks for mempool management. /// @@ -161,7 +161,7 @@ impl TxCache { pub fn show_unknown(&self, tx_hashes: Vec) -> Vec { tx_hashes .into_iter() - .filter(|tx_hash| self.contain(tx_hash)) + .filter(|tx_hash| !self.contain(tx_hash)) .collect() } @@ -220,7 +220,7 @@ impl TxCache { cycle_count += shared_tx.tx.raw.fee.cycle; if cycle_count > cycle_limit { stage = stage.next(); - cycle_count = 0; + cycle_count = shared_tx.tx.raw.fee.cycle; } match stage { @@ -245,6 +245,25 @@ impl TxCache { }) } + #[inline] + pub fn check_exist(&self, tx_hash: &Hash) -> ProtocolResult<()> { + if self.contain(tx_hash) { + return Err(MemPoolError::Dup { + tx_hash: tx_hash.clone(), + } + .into()); + } + Ok(()) + } + + #[inline] + pub fn check_reach_limit(&self, pool_size: usize) -> ProtocolResult<()> { + if self.len() >= pool_size { + return Err(MemPoolError::ReachLimit { pool_size }.into()); + } + Ok(()) + } + #[inline] pub fn contain(&self, tx_hash: &Hash) -> bool { self.map.contains_key(tx_hash) @@ -256,7 +275,7 @@ impl TxCache { } #[allow(dead_code)] - fn queue_len(&self) -> usize { + pub fn queue_len(&self) -> usize { if self.is_zero.load(Ordering::Relaxed) { self.queue_0.len() } else { @@ -348,11 +367,11 @@ mod tests { use crate::tx_cache::TxCache; use std::thread::JoinHandle; - const POOL_SIZE: usize = 100_000; + const POOL_SIZE: usize = 1000; const BYTES_LEN: usize = 10; - const TX_NUM: usize = 100_000; + const TX_NUM: usize = 1000; const TX_CYCLE: u64 = 1; - const CYCLE_LIMIT: u64 = 50000; + const CYCLE_LIMIT: u64 = 500; const CURRENT_H: u64 = 100; const TIMEOUT: u64 = 150;