diff --git a/Cargo.lock b/Cargo.lock index 5882630154..35626bba10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,9 +132,9 @@ checksum = "34fde25430d87a9388dadbe6e34d7f72a462c8b43ac8d309b42b0a8505d7e2a5" [[package]] name = "anyhow" -version = "1.0.44" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61604a8f862e1d5c3229fdd78f8b02c68dcf73a4c4b05fd636d12240aaa242c1" +checksum = "38d9ff5d688f1c13395289f67db01d4826b46dd694e7580accdc3e8430f2d98e" [[package]] name = "arc-swap" @@ -175,9 +175,9 @@ dependencies = [ [[package]] name = "arrayvec" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4dc07131ffa69b8072d35f5007352af944213cde02545e2103680baed38fcd" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "async-stream" @@ -236,9 +236,9 @@ checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" [[package]] name = "base58-monero" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1df741a56140a4d26932fde3bd69ce8abec7f50d238f8c16351caa789a6d8fe9" +checksum = "935c90240f9b7749c80746bf88ad9cb346f34b01ee30ad4d566dfdecd6e3cc6a" dependencies = [ "thiserror", ] @@ -275,9 +275,9 @@ dependencies = [ [[package]] name = "base64ct" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b4d9b1225d28d360ec6a231d65af1fd99a2a095154c8040689617290569c5c" +checksum = "392c772b012d685a640cdad68a5a21f4a45e696f85a2c2c907aab2fe49a91e19" [[package]] name = "bigdecimal" @@ -556,9 +556,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.71" +version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79c2681d6594606957bbb8631c4b90a7fcaaa72cdb714743a437b156d6a7eedd" +checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee" [[package]] name = "cexpr" @@ -671,9 +671,9 @@ checksum = "b0fc239e0f6cb375d2402d48afb92f76f5404fd1df208a41930ec81eda078bea" [[package]] name = "clang-sys" -version = "1.2.2" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10612c0ec0e0a1ff0e97980647cb058a6e7aedb913d01d009c406b8b7d0b26ee" +checksum = "fa66045b9cb23c2e9c1520732030608b02ee07e5cfaa5a521ec15ded7fa24c90" dependencies = [ "glob", "libc", @@ -1053,9 +1053,9 @@ dependencies = [ [[package]] name = "curl-sys" -version = "0.4.49+curl-7.79.1" +version = "0.4.51+curl-7.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0f44960aea24a786a46907b8824ebc0e66ca06bf4e4978408c7499620343483" +checksum = "d130987e6a6a34fe0889e1083022fa48cd90e6709a84be3fb8dd95801de5af20" dependencies = [ "cc", "libc", @@ -1322,9 +1322,9 @@ checksum = "56899898ce76aaf4a0f24d914c97ea6ed976d42fec6ad33fcbb0a1103e07b2b0" [[package]] name = "ed25519" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4620d40f6d2601794401d6dd95a5cf69b6c157852539470eeda433a99b3c0efc" +checksum = "74e1069e39f1454367eb2de793ed062fac4c35c2934b76a81d90dd9abcd28816" dependencies = [ "signature", ] @@ -1794,7 +1794,7 @@ checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ "cfg-if 1.0.0", "libc", - "wasi 0.10.0+wasi-snapshot-preview1", + "wasi 0.10.2+wasi-snapshot-preview1", ] 
[[package]] @@ -1928,9 +1928,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hex-literal" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e4590e13640f19f249fe3e4eca5113bc4289f2497710378190e7f4bd96f45b" +checksum = "7ebdb29d2ea9ed0083cd8cece49bbd968021bd99b0849edb4a9a7ee0fdf6a4e0" [[package]] name = "http" @@ -1962,9 +1962,9 @@ checksum = "acd94fdbe1d4ff688b67b04eee2e17bd50995534a61539e45adfefb45e5e5503" [[package]] name = "httpdate" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6456b8a6c8f33fee7d958fcd1b60d55b11940a79e63ae87013e6d22e26034440" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "humantime" @@ -1983,9 +1983,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.14" +version = "0.14.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b91bb1f221b6ea1f1e4371216b70f40748774c2fb5971b450c07773fb92d26b" +checksum = "436ec0091e4f20e655156a30a0df3770fe2900aa301e548e08446ec794b6953c" dependencies = [ "bytes 1.1.0", "futures-channel", @@ -2204,9 +2204,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.105" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" +checksum = "8521a1b57e76b1ec69af7599e75e38e7b7fad6610f037db8c79b127201b5d119" [[package]] name = "libgit2-sys" @@ -2235,9 +2235,9 @@ dependencies = [ [[package]] name = "libloading" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0cf036d15402bea3c5d4de17b3fce76b3e4a56ebc1f577be0e7a72f7c607cf0" +checksum = "afe203d669ec979b7128619bae5a63b7b42e9203c1b29146079ee05e2f604b52" dependencies = [ "cfg-if 1.0.0", "winapi 0.3.9", @@ -2869,9 +2869,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.37" +version = "0.10.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bc6b9e4403633698352880b22cbe2f0e45dd0177f6fabe4585536e56a3e4f75" +checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95" dependencies = [ "bitflags 1.3.2", "cfg-if 1.0.0", @@ -2889,18 +2889,18 @@ checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" [[package]] name = "openssl-src" -version = "111.16.0+1.1.1l" +version = "300.0.2+3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab2173f69416cf3ec12debb5823d244127d23a9b127d5a5189aa97c5fa2859f" +checksum = "14a760a11390b1a5daf72074d4f6ff1a6e772534ae191f999f57e9ee8146d1fb" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.68" +version = "0.9.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c571f25d3f66dd427e417cebf73dbe2361d6125cf6e3a70d143fdf97c9f5150" +checksum = "7df13d165e607909b363a4757a6f133f8a818a74e9d3a98d09c6128e15fa4c73" dependencies = [ "autocfg 1.0.1", "cc", @@ -4094,9 +4094,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.68" +version = "1.0.71" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f690853975602e1bfe1ccbf50504d67174e3bcf340f23b5ea9992e0587a52d8" +checksum = "063bf466a64011ac24040a49009724ee60a57da1b437617ceb32e53ad61bfb19" dependencies = [ "itoa", "ryu", @@ -4710,7 +4710,6 @@ dependencies = [ "tokio-stream", "tokio-util", "tower", - "tower-make", "tracing", "tracing-futures", "yamux", @@ -4759,7 +4758,6 @@ dependencies = [ "thiserror", "tokio", "tokio-stream", - "tokio-test", "tower", "tower-test", "ttl_cache", @@ -4920,7 +4918,7 @@ name = "tari_key_manager" version = "0.21.2" dependencies = [ "argon2", - "arrayvec 0.7.1", + "arrayvec 0.7.2", "blake2", "chacha20", "chrono", @@ -5388,12 +5386,11 @@ dependencies = [ [[package]] name = "time" -version = "0.1.44" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" +checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" dependencies = [ "libc", - "wasi 0.10.0+wasi-snapshot-preview1", "winapi 0.3.9", ] @@ -5418,9 +5415,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83b2a3d4d9091d0abd7eba4dc2710b1718583bd4d8992e2190720ea38f391f7" +checksum = "2c1c1d5a42b6245520c249549ec267180beaffcc0615401ac8e31853d4b6d8d2" dependencies = [ "tinyvec_macros", ] @@ -5634,15 +5631,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" -[[package]] -name = "tower-make" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" -dependencies = [ - "tower-service", -] - [[package]] name = "tower-service" version = "0.3.1" @@ -6004,9 +5992,9 @@ dependencies = [ [[package]] name = "unsigned-varint" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f8d425fafb8cd76bc3f22aace4af471d3156301d7508f2107e98fbeae10bc7f" +checksum = "d86a8dc7f45e4c1b0d30e43038c38f274e77af056aa5f74b93c2cf9eb3c1c836" [[package]] name = "untrusted" @@ -6132,9 +6120,9 @@ checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" [[package]] name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" +version = "0.10.2+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" [[package]] name = "wasm-bindgen" @@ -6339,9 +6327,9 @@ dependencies = [ [[package]] name = "zeroize_derive" -version = "1.2.0" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdff2024a851a322b08f179173ae2ba620445aef1e838f0c196820eade4ae0c7" +checksum = "65f1a51723ec88c66d5d1fe80c841f17f63587d6691901d66be9bec6c3b51f73" dependencies = [ "proc-macro2 1.0.32", "quote 1.0.10", diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index bdacc47765..066f77a674 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -221,7 +221,7 @@ async fn build_node_context( let factories = CryptoFactories::default(); let 
randomx_factory = RandomXFactory::new(config.max_randomx_vms); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new( rules.clone(), diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 1f7ef20f11..799874d4d6 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -494,6 +494,7 @@ impl CommandHandler { s.push(format!( "LAST_SEEN = {}", Utc::now() + .naive_utc() .signed_duration_since(dt) .to_std() .map(format_duration_basic) diff --git a/applications/tari_base_node/src/recovery.rs b/applications/tari_base_node/src/recovery.rs index 21a29f3a68..4aa6a54a1b 100644 --- a/applications/tari_base_node/src/recovery.rs +++ b/applications/tari_base_node/src/recovery.rs @@ -98,7 +98,7 @@ pub async fn run_recovery(node_config: &GlobalConfig) -> Result<(), anyhow::Erro let factories = CryptoFactories::default(); let randomx_factory = RandomXFactory::new(node_config.max_randomx_vms); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new( rules.clone(), diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/error.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/error.rs index df62e5495a..0edfb31fa3 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/error.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/error.rs @@ -70,6 +70,8 @@ pub enum HorizonSyncError { MerkleMountainRangeError(#[from] MerkleMountainRangeError), #[error("Connectivity error: {0}")] ConnectivityError(#[from] ConnectivityError), + #[error("Validation error: {0}")] + ValidationError(#[from] ValidationError), } impl From for HorizonSyncError { diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs index 0a3e208e29..ce0e04bf3c 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs @@ -43,6 +43,7 @@ use crate::{ transaction_kernel::TransactionKernel, transaction_output::TransactionOutput, }, + validation::helpers, }; use croaring::Bitmap; use futures::{stream::FuturesUnordered, StreamExt}; @@ -374,6 +375,11 @@ impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> { let mut output_mmr = MerkleMountainRange::::new(output_pruned_set); let mut witness_mmr = MerkleMountainRange::::new(witness_pruned_set); + let mut constants = self + .shared + .consensus_rules + .consensus_constants(current_header.height()) + .clone(); while let Some(response) = output_stream.next().await { let res: SyncUtxosResponse = response?; @@ -401,6 +407,7 @@ impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> { ); height_utxo_counter += 1; let output = TransactionOutput::try_from(output).map_err(HorizonSyncError::ConversionError)?; + helpers::check_tari_script_byte_size(&output.script, constants.get_max_script_byte_size())?; 
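+                        // Mirrors the block-body validators: a synced UTXO whose serialized TariScript is larger than the consensus maximum for the current header height fails the horizon sync with a validation error instead of being committed to the recovered UTXO set.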
unpruned_outputs.push(output.clone()); output_mmr.push(output.hash())?; @@ -535,6 +542,11 @@ impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> { break; } else { current_header = db.fetch_chain_header(current_header.height() + 1).await?; + constants = self + .shared + .consensus_rules + .consensus_constants(current_header.height()) + .clone(); debug!( target: LOG_TARGET, "Expecting to receive the next UTXO set {}-{} for header #{}", diff --git a/base_layer/core/src/base_node/sync/block_sync/error.rs b/base_layer/core/src/base_node/sync/block_sync/error.rs index 6e4e772e34..ed7d98288b 100644 --- a/base_layer/core/src/base_node/sync/block_sync/error.rs +++ b/base_layer/core/src/base_node/sync/block_sync/error.rs @@ -34,8 +34,6 @@ pub enum BlockSyncError { RpcRequestError(#[from] RpcStatus), #[error("Chain storage error: {0}")] ChainStorageError(#[from] ChainStorageError), - #[error("Peer sent invalid block body: {0}")] - ReceivedInvalidBlockBody(String), #[error("Peer sent a block that did not form a chain. Expected hash = {expected}, got = {got}")] PeerSentBlockThatDidNotFormAChain { expected: String, got: String }, #[error("Connectivity Error: {0}")] @@ -48,4 +46,6 @@ pub enum BlockSyncError { FailedToBan(ConnectivityError), #[error("Failed to construct valid chain block")] FailedToConstructChainBlock, + #[error("Peer violated the block sync protocol: {0}")] + ProtocolViolation(String), } diff --git a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs index b476f4078f..5b4e1d29f2 100644 --- a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs @@ -26,7 +26,7 @@ use crate::{ sync::{hooks::Hooks, rpc, SyncPeer}, BlockSyncConfig, }, - blocks::{Block, ChainBlock}, + blocks::{Block, BlockValidationError, ChainBlock}, chain_storage::{async_db::AsyncBlockchainDb, BlockchainBackend}, proto::base_node::SyncBlocksRequest, tari_utilities::{hex::Hex, Hashable}, @@ -97,7 +97,29 @@ impl BlockSynchronizer { Ok(()) }, Err(err @ BlockSyncError::ValidationError(ValidationError::AsyncTaskFailed(_))) => Err(err), - Err(err @ BlockSyncError::ValidationError(_)) | Err(err @ BlockSyncError::ReceivedInvalidBlockBody(_)) => { + Err(BlockSyncError::ValidationError(err)) => { + match &err { + ValidationError::BlockHeaderError(_) => {}, + ValidationError::BlockError(BlockValidationError::MismatchedMmrRoots) | + ValidationError::BadBlockFound { .. } | + ValidationError::BlockError(BlockValidationError::MismatchedMmrSize { .. }) => { + let num_cleared = self.db.clear_all_pending_headers().await?; + warn!( + target: LOG_TARGET, + "Cleared {} incomplete headers from bad chain", num_cleared + ); + }, + _ => {}, + } + warn!( + target: LOG_TARGET, + "Banning peer because provided block failed validation: {}", err + ); + self.ban_peer(node_id, &err).await?; + Err(err.into()) + }, + Err(err @ BlockSyncError::ProtocolViolation(_)) => { + warn!(target: LOG_TARGET, "Banning peer: {}", err); self.ban_peer(node_id, &err).await?; Err(err) }, @@ -167,9 +189,10 @@ impl BlockSynchronizer { .fetch_chain_header_by_block_hash(block.hash.clone()) .await? 
.ok_or_else(|| { - BlockSyncError::ReceivedInvalidBlockBody("Peer sent hash for block header we do not have".into()) + BlockSyncError::ProtocolViolation("Peer sent hash for block header we do not have".into()) })?; + let current_height = header.height(); let header_hash = header.hash().clone(); if header.header().prev_hash != prev_hash { @@ -184,13 +207,13 @@ impl BlockSynchronizer { let body = block .body .map(AggregateBody::try_from) - .ok_or_else(|| BlockSyncError::ReceivedInvalidBlockBody("Block body was empty".to_string()))? - .map_err(BlockSyncError::ReceivedInvalidBlockBody)?; + .ok_or_else(|| BlockSyncError::ProtocolViolation("Block body was empty".to_string()))? + .map_err(BlockSyncError::ProtocolViolation)?; debug!( target: LOG_TARGET, "Validating block body #{} (PoW = {}, {})", - header.height(), + current_height, header.header().pow_algo(), body.to_counts_string(), ); @@ -198,7 +221,26 @@ impl BlockSynchronizer { let timer = Instant::now(); let (header, header_accum_data) = header.into_parts(); - let block = self.block_validator.validate_body(Block::new(header, body)).await?; + let block = match self.block_validator.validate_body(Block::new(header, body)).await { + Ok(block) => block, + Err(err @ ValidationError::BadBlockFound { .. }) | + Err(err @ ValidationError::FatalStorageError(_)) | + Err(err @ ValidationError::AsyncTaskFailed(_)) | + Err(err @ ValidationError::CustomError(_)) => return Err(err.into()), + Err(err) => { + // Add to bad blocks + if let Err(err) = self + .db + .write_transaction() + .insert_bad_block(header_hash, current_height) + .commit() + .await + { + error!(target: LOG_TARGET, "Failed to insert bad block: {}", err); + } + return Err(err.into()); + }, + }; let block = ChainBlock::try_construct(Arc::new(block), header_accum_data) .map(Arc::new) diff --git a/base_layer/core/src/base_node/sync/header_sync/validator.rs b/base_layer/core/src/base_node/sync/header_sync/validator.rs index 601e372469..efd0d8617f 100644 --- a/base_layer/core/src/base_node/sync/header_sync/validator.rs +++ b/base_layer/core/src/base_node/sync/header_sync/validator.rs @@ -30,6 +30,7 @@ use crate::{ tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}, validation::helpers::{ check_header_timestamp_greater_than_median, + check_not_bad_block, check_pow_data, check_target_difficulty, check_timestamp_ftl, @@ -138,7 +139,13 @@ impl BlockHeaderSyncValidator { ); let achieved_target = check_target_difficulty(&header, target_difficulty, &self.randomx_factory)?; - check_pow_data(&header, &self.consensus_rules, &*self.db.inner().db_read_access()?)?; + let block_hash = header.hash(); + + { + let txn = self.db.inner().db_read_access()?; + check_not_bad_block(&*txn, &block_hash)?; + check_pow_data(&header, &self.consensus_rules, &*txn)?; + } // Header is valid, add this header onto the validation state for the next round // Mutable borrow done later in the function to allow multiple immutable borrows before this line. 
This has @@ -159,7 +166,7 @@ state.target_difficulties.add_back(&header, target_difficulty); let accumulated_data = BlockHeaderAccumulatedData::builder(&state.previous_accum) - .with_hash(header.hash()) + .with_hash(block_hash) .with_achieved_target_difficulty(achieved_target) .with_total_kernel_offset(header.total_kernel_offset.clone()) .build()?; diff --git a/base_layer/core/src/chain_storage/async_db.rs b/base_layer/core/src/chain_storage/async_db.rs index 116e87e8dc..9437431e03 100644 --- a/base_layer/core/src/chain_storage/async_db.rs +++ b/base_layer/core/src/chain_storage/async_db.rs @@ -207,6 +207,8 @@ impl AsyncBlockchainDb { make_async_fn!(fetch_last_header() -> BlockHeader, "fetch_last_header"); + make_async_fn!(clear_all_pending_headers() -> usize, "clear_all_pending_headers"); + make_async_fn!(fetch_last_chain_header() -> ChainHeader, "fetch_last_chain_header"); make_async_fn!(fetch_tip_header() -> ChainHeader, "fetch_tip_header"); @@ -222,6 +224,8 @@ impl AsyncBlockchainDb { make_async_fn!(block_exists(block_hash: BlockHash) -> bool, "block_exists"); + make_async_fn!(bad_block_exists(block_hash: BlockHash) -> bool, "bad_block_exists"); + make_async_fn!(fetch_block(height: u64) -> HistoricalBlock, "fetch_block"); make_async_fn!(fetch_blocks<T: RangeBounds<u64>>(bounds: T) -> Vec<HistoricalBlock>, "fetch_blocks"); @@ -372,6 +376,11 @@ impl<'a, B: BlockchainBackend + 'static> AsyncDbTransaction<'a, B> { self } + pub fn insert_bad_block(&mut self, hash: HashOutput, height: u64) -> &mut Self { + self.transaction.insert_bad_block(hash, height); + self + } + pub fn prune_output_at_positions(&mut self, positions: Vec<u32>) -> &mut Self { self.transaction.prune_outputs_at_positions(positions); self diff --git a/base_layer/core/src/chain_storage/blockchain_backend.rs b/base_layer/core/src/chain_storage/blockchain_backend.rs index ede1ca87c2..63bc5f542d 100644 --- a/base_layer/core/src/chain_storage/blockchain_backend.rs +++ b/base_layer/core/src/chain_storage/blockchain_backend.rs @@ -132,6 +132,10 @@ pub trait BlockchainBackend: Send + Sync { fn orphan_count(&self) -> Result<usize, ChainStorageError>; /// Returns the stored header with the highest corresponding height. fn fetch_last_header(&self) -> Result<BlockHeader, ChainStorageError>; + + /// Clear all headers that are beyond the current height of the longest chain, returning the number of headers that were + /// deleted. + fn clear_all_pending_headers(&self) -> Result<usize, ChainStorageError>; /// Returns the stored header and accumulated data with the highest height. fn fetch_last_chain_header(&self) -> Result<ChainHeader, ChainStorageError>; /// Returns the stored header with the highest corresponding height.
@@ -178,4 +182,7 @@ pub trait BlockchainBackend: Send + Sync { &self, mmr_positions: Vec<u32>, ) -> Result<Vec<Option<(u64, HashOutput)>>, ChainStorageError>; + + /// Check if a block hash is in the bad block list + fn bad_block_exists(&self, block_hash: HashOutput) -> Result<bool, ChainStorageError>; } diff --git a/base_layer/core/src/chain_storage/blockchain_database.rs b/base_layer/core/src/chain_storage/blockchain_database.rs index 07cd9c20b3..d88f795de0 100644 --- a/base_layer/core/src/chain_storage/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/blockchain_database.rs @@ -85,7 +85,6 @@ use crate::{ HeaderValidation, OrphanValidation, PostOrphanBodyValidation, - ValidationError, }, }; @@ -860,6 +859,11 @@ where B: BlockchainBackend Ok(()) } + pub fn clear_all_pending_headers(&self) -> Result<usize, ChainStorageError> { + let db = self.db_write_access()?; + db.clear_all_pending_headers() + } + /// Clean out the entire orphan pool pub fn cleanup_all_orphans(&self) -> Result<(), ChainStorageError> { let mut db = self.db_write_access()?; @@ -962,6 +966,12 @@ where B: BlockchainBackend Ok(db.contains(&DbKey::BlockHash(hash.clone()))? || db.contains(&DbKey::OrphanBlock(hash))?) } + /// Returns true if the given block hash is in the bad block list. + pub fn bad_block_exists(&self, hash: BlockHash) -> Result<bool, ChainStorageError> { + let db = self.db_read_access()?; + db.bad_block_exists(hash) + } + /// Atomically commit the provided transaction to the database backend. This function does not update the metadata. pub fn commit(&self, txn: DbTransaction) -> Result<(), ChainStorageError> { let mut db = self.db_write_access()?; @@ -1276,10 +1286,13 @@ fn insert_best_block(txn: &mut DbTransaction, block: Arc) -> Result< block_hash.to_hex() ); if block.header().pow_algo() == PowAlgorithm::Monero { - let monero_seed = MoneroPowData::from_header(block.header()) - .map_err(|e| ValidationError::CustomError(e.to_string()))? - .randomx_key; - txn.insert_monero_seed_height(monero_seed.to_vec(), block.height()); + let monero_header = + MoneroPowData::from_header(block.header()).map_err(|e| ChainStorageError::InvalidArguments { + func: "insert_best_block", + arg: "block", + message: format!("block contained invalid or malformed monero PoW data: {}", e), + })?; + txn.insert_monero_seed_height(monero_header.randomx_key.to_vec(), block.height()); } let height = block.height(); diff --git a/base_layer/core/src/chain_storage/db_transaction.rs b/base_layer/core/src/chain_storage/db_transaction.rs index bfdd480033..b8d7b83229 100644 --- a/base_layer/core/src/chain_storage/db_transaction.rs +++ b/base_layer/core/src/chain_storage/db_transaction.rs @@ -184,6 +184,15 @@ impl DbTransaction { self } + /// Inserts a block hash into the bad block list + pub fn insert_bad_block(&mut self, block_hash: HashOutput, height: u64) -> &mut Self { + self.operations.push(WriteOperation::InsertBadBlock { + hash: block_hash, + height, + }); + self + } + /// Stores an orphan block. No checks are made as to whether this is actually an orphan. That responsibility lies /// with the calling function. /// The transaction will rollback and write will return an error if the orphan already exists. @@ -295,6 +304,10 @@ pub enum WriteOperation { witness_hash: HashOutput, mmr_position: u32, }, + InsertBadBlock { + hash: HashOutput, + height: u64, + }, DeleteHeader(u64), DeleteOrphan(HashOutput), DeleteBlock(HashOutput), @@ -414,6 +427,7 @@ impl fmt::Display for WriteOperation { SetPrunedHeight { height, ..
} => write!(f, "Set pruned height to {}", height), DeleteHeader(height) => write!(f, "Delete header at height: {}", height), DeleteOrphan(hash) => write!(f, "Delete orphan with hash: {}", hash.to_hex()), + InsertBadBlock { hash, height } => write!(f, "Insert bad block #{} {}", height, hash.to_hex()), SetHorizonData { .. } => write!(f, "Set horizon data"), } } diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb.rs index 4b3cf5f2d4..783400b1c8 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb.rs @@ -25,7 +25,7 @@ use lmdb_zero::{ del, error::{self, LmdbResultExt}, put, - traits::{AsLmdbBytes, FromLmdbBytes}, + traits::{AsLmdbBytes, CreateCursor, FromLmdbBytes}, ConstTransaction, Cursor, CursorIter, @@ -397,7 +397,7 @@ where Ok(result) } -/// Fetches all the size of all key/values in the given DB. Returns the number of entries, the total size of all the +/// Fetches the size of all key/values in the given DB. Returns the number of entries, the total size of all the /// keys and values in bytes. pub fn fetch_db_entry_sizes(txn: &ConstTransaction<'_>, db: &Database) -> Result<(u64, u64, u64), ChainStorageError> { let access = txn.access(); @@ -412,3 +412,40 @@ pub fn fetch_db_entry_sizes(txn: &ConstTransaction<'_>, db: &Database) -> Result } Ok((num_entries, total_key_size, total_value_size)) } + +pub fn lmdb_delete_each_where( + txn: &WriteTransaction<'_>, + db: &Database, + mut predicate: F, +) -> Result +where + K: FromLmdbBytes + ?Sized, + V: DeserializeOwned, + F: FnMut(&K, V) -> Option, +{ + let mut cursor = txn.cursor(db)?; + let mut access = txn.access(); + let mut num_deleted = 0; + while let Some((k, v)) = cursor.next::(&access).to_opt()? 
{ + match deserialize(v) { + Ok(v) => match predicate(k, v) { + Some(true) => { + cursor.del(&mut access, del::Flags::empty())?; + num_deleted += 1; + }, + Some(false) => continue, + None => { + break; + }, + }, + Err(e) => { + error!( + target: LOG_TARGET, + "Could not deserialize value from lmdb: {:?}", e + ); + return Err(ChainStorageError::AccessError(e.to_string())); + }, + } + } + Ok(num_deleted) +} diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs index 6a8908ad75..fd8bc37be5 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs @@ -50,6 +50,7 @@ use crate::{ lmdb::{ fetch_db_entry_sizes, lmdb_delete, + lmdb_delete_each_where, lmdb_delete_key_value, lmdb_delete_keys_starting_with, lmdb_exists, @@ -118,6 +119,7 @@ const LMDB_DB_MONERO_SEED_HEIGHT: &str = "monero_seed_height"; const LMDB_DB_ORPHAN_HEADER_ACCUMULATED_DATA: &str = "orphan_accumulated_data"; const LMDB_DB_ORPHAN_CHAIN_TIPS: &str = "orphan_chain_tips"; const LMDB_DB_ORPHAN_PARENT_MAP_INDEX: &str = "orphan_parent_map_index"; +const LMDB_DB_BAD_BLOCK_LIST: &str = "bad_blocks"; pub fn create_lmdb_database<P: AsRef<Path>>(path: P, config: LMDBConfig) -> Result<LMDBDatabase, ChainStorageError> { let flags = db::CREATE; @@ -129,7 +131,7 @@ pub fn create_lmdb_database>(path: P, config: LMDBConfig) -> Resu let lmdb_store = LMDBBuilder::new() .set_path(path) .set_env_config(config) - .set_max_number_of_databases(20) + .set_max_number_of_databases(40) .add_database(LMDB_DB_METADATA, flags | db::INTEGERKEY) .add_database(LMDB_DB_HEADERS, flags | db::INTEGERKEY) .add_database(LMDB_DB_HEADER_ACCUMULATED_DATA, flags | db::INTEGERKEY) @@ -150,6 +152,7 @@ pub fn create_lmdb_database>(path: P, config: LMDBConfig) -> Resu .add_database(LMDB_DB_MONERO_SEED_HEIGHT, flags) .add_database(LMDB_DB_ORPHAN_CHAIN_TIPS, flags) .add_database(LMDB_DB_ORPHAN_PARENT_MAP_INDEX, flags | db::DUPSORT) + .add_database(LMDB_DB_BAD_BLOCK_LIST, flags) .build() .map_err(|err| ChainStorageError::CriticalError(format!("Could not create LMDB store:{}", err)))?; debug!(target: LOG_TARGET, "LMDB database creation successful"); @@ -180,6 +183,7 @@ pub struct LMDBDatabase { orphan_header_accumulated_data_db: DatabaseRef, orphan_chain_tips_db: DatabaseRef, orphan_parent_map_index: DatabaseRef, + bad_blocks: DatabaseRef, _file_lock: Arc<File>, } @@ -211,6 +215,7 @@ impl LMDBDatabase { monero_seed_height_db: get_database(&store, LMDB_DB_MONERO_SEED_HEIGHT)?, orphan_chain_tips_db: get_database(&store, LMDB_DB_ORPHAN_CHAIN_TIPS)?, orphan_parent_map_index: get_database(&store, LMDB_DB_ORPHAN_PARENT_MAP_INDEX)?, + bad_blocks: get_database(&store, LMDB_DB_BAD_BLOCK_LIST)?, env, env_config: store.env_config(), _file_lock: Arc::new(file_lock), @@ -397,6 +402,9 @@ impl LMDBDatabase { MetadataValue::HorizonData(horizon_data.clone()), )?; }, + InsertBadBlock { hash, height } => { + self.insert_bad_block_and_cleanup(&write_txn, hash, *height)?; + }, } } write_txn.commit()?; @@ -404,7 +412,7 @@ impl LMDBDatabase { Ok(()) } - fn all_dbs(&self) -> [(&'static str, &DatabaseRef); 20] { + fn all_dbs(&self) -> [(&'static str, &DatabaseRef); 21] { [ ("metadata_db", &self.metadata_db), ("headers_db", &self.headers_db), @@ -432,6 +440,7 @@ impl LMDBDatabase { ("monero_seed_height_db", &self.monero_seed_height_db), ("orphan_chain_tips_db", &self.orphan_chain_tips_db), ("orphan_parent_map_index", &self.orphan_parent_map_index), + ("bad_blocks", &self.bad_blocks), ] } @@ -1272,6 +1281,31 @@ impl 
LMDBDatabase { fn fetch_last_header_in_txn(&self, txn: &ConstTransaction<'_>) -> Result<Option<BlockHeader>, ChainStorageError> { lmdb_last(txn, &self.headers_db) } + + fn insert_bad_block_and_cleanup( + &self, + txn: &WriteTransaction<'_>, + hash: &HashOutput, + height: u64, + ) -> Result<(), ChainStorageError> { + const CLEAN_BAD_BLOCKS_BEFORE_REL_HEIGHT: u64 = 10000; + + lmdb_replace(txn, &self.bad_blocks, hash, &height)?; + // Clean up bad blocks that are far from the tip + let metadata = fetch_metadata(&*txn, &self.metadata_db)?; + let deleted_before_height = metadata + .height_of_longest_chain() + .saturating_sub(CLEAN_BAD_BLOCKS_BEFORE_REL_HEIGHT); + if deleted_before_height == 0 { + return Ok(()); + } + + let num_deleted = + lmdb_delete_each_where::<[u8], u64, _>(txn, &self.bad_blocks, |_, v| Some(v < deleted_before_height))?; + debug!(target: LOG_TARGET, "Cleaned out {} stale bad blocks", num_deleted); + + Ok(()) + } } pub fn create_recovery_lmdb_database<P: AsRef<Path>>(path: P) -> Result<(), ChainStorageError> { @@ -2050,6 +2084,37 @@ impl BlockchainBackend for LMDBDatabase { }) .collect() } + + fn bad_block_exists(&self, block_hash: HashOutput) -> Result<bool, ChainStorageError> { + let txn = self.read_transaction()?; + lmdb_exists(&txn, &self.bad_blocks, &block_hash) + } + + fn clear_all_pending_headers(&self) -> Result<usize, ChainStorageError> { + let txn = self.write_transaction()?; + let last_header = match self.fetch_last_header_in_txn(&txn)? { + Some(h) => h, + None => { + return Ok(0); + }, + }; + let metadata = fetch_metadata(&txn, &self.metadata_db)?; + + if metadata.height_of_longest_chain() == last_header.height { + return Ok(0); + } + + let start = metadata.height_of_longest_chain() + 1; + let end = last_header.height; + + let mut num_deleted = 0; + for h in (start..=end).rev() { + self.delete_header(&txn, h)?; + num_deleted += 1; + } + txn.commit()?; + Ok(num_deleted) + } } // Fetch the chain metadata diff --git a/base_layer/core/src/chain_storage/tests/blockchain_database.rs b/base_layer/core/src/chain_storage/tests/blockchain_database.rs index 0d7365229d..2392717cb7 100644 --- a/base_layer/core/src/chain_storage/tests/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/tests/blockchain_database.rs @@ -21,11 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - blocks::{Block, BlockHeader, NewBlockTemplate}, + blocks::{Block, BlockHeader, BlockHeaderAccumulatedData, ChainHeader, NewBlockTemplate}, chain_storage::{BlockchainDatabase, ChainStorageError}, consensus::ConsensusManager, crypto::tari_utilities::hex::Hex, - proof_of_work::Difficulty, + proof_of_work::{AchievedTargetDifficulty, Difficulty, PowAlgorithm}, tari_utilities::Hashable, test_helpers::{ blockchain::{create_new_blockchain, TempDatabase}, @@ -440,11 +440,14 @@ mod fetch_total_size_stats { use super::*; #[test] - fn it_works_when_db_is_empty() { + fn it_measures_the_number_of_entries() { let db = setup(); + let _ = add_many_chained_blocks(2, &db); let stats = db.fetch_total_size_stats().unwrap(); - // Returns one per db - assert_eq!(stats.sizes().len(), 20); + assert_eq!( + stats.sizes().iter().find(|s| s.name == "utxos_db").unwrap().num_entries, + 4003 + ); } } @@ -572,3 +575,52 @@ mod fetch_header_containing_kernel_mmr { matches!(err, ChainStorageError::ValueNotFound { ..
}); } } + +mod clear_all_pending_headers { + use super::*; + + #[test] + fn it_clears_no_headers() { + let db = setup(); + assert_eq!(db.clear_all_pending_headers().unwrap(), 0); + let _ = add_many_chained_blocks(2, &db); + db.clear_all_pending_headers().unwrap(); + let last_header = db.fetch_last_header().unwrap(); + assert_eq!(last_header.height, 2); + } + + #[test] + fn it_clears_headers_after_tip() { + let db = setup(); + let _ = add_many_chained_blocks(2, &db); + let prev_block = db.fetch_block(2).unwrap(); + let mut prev_accum = prev_block.accumulated_data.clone(); + let mut prev_block = Arc::new(prev_block.try_into_block().unwrap()); + let headers = (0..5) + .map(|_| { + let (block, _) = create_next_block(&prev_block, vec![]); + let accum = BlockHeaderAccumulatedData::builder(&prev_accum) + .with_hash(block.hash()) + .with_achieved_target_difficulty( + AchievedTargetDifficulty::try_construct(PowAlgorithm::Sha3, 0.into(), 0.into()).unwrap(), + ) + .with_total_kernel_offset(Default::default()) + .build() + .unwrap(); + + let header = ChainHeader::try_construct(block.header.clone(), accum.clone()).unwrap(); + + prev_block = block; + prev_accum = accum; + header + }) + .collect(); + db.insert_valid_headers(headers).unwrap(); + let last_header = db.fetch_last_header().unwrap(); + assert_eq!(last_header.height, 7); + let num_deleted = db.clear_all_pending_headers().unwrap(); + assert_eq!(num_deleted, 5); + let last_header = db.fetch_last_header().unwrap(); + assert_eq!(last_header.height, 2); + } +} diff --git a/base_layer/core/src/consensus/consensus_constants.rs b/base_layer/core/src/consensus/consensus_constants.rs index 40fc48480a..c45ac75d3b 100644 --- a/base_layer/core/src/consensus/consensus_constants.rs +++ b/base_layer/core/src/consensus/consensus_constants.rs @@ -67,6 +67,8 @@ pub struct ConsensusConstants { faucet_value: MicroTari, /// Transaction Weight params transaction_weight: TransactionWeight, + /// Maximum byte size of TariScript + max_script_byte_size: usize, } /// This is just a convenience wrapper to put all the info into a hashmap per diff algo @@ -167,6 +169,11 @@ impl ConsensusConstants { self.median_timestamp_count } + /// The maximum serialized byte size of TariScript + pub fn get_max_script_byte_size(&self) -> usize { + self.max_script_byte_size + } + /// This is the min initial difficulty that can be requested for the pow pub fn min_pow_difficulty(&self, pow_algo: PowAlgorithm) -> Difficulty { match self.proof_of_work.get(&pow_algo) { @@ -226,6 +233,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: (5000 * 4000) * T, transaction_weight: TransactionWeight::v2(), + max_script_byte_size: 2048, }] } @@ -260,6 +268,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: (5000 * 4000) * T, transaction_weight: TransactionWeight::v1(), + max_script_byte_size: 2048, }] } @@ -321,6 +330,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: (5000 * 4000) * T, transaction_weight: TransactionWeight::v1(), + max_script_byte_size: 2048, }, ConsensusConstants { effective_from_height: 1400, @@ -337,6 +347,7 @@ impl ConsensusConstants { proof_of_work: algos2, faucet_value: (5000 * 4000) * T, transaction_weight: TransactionWeight::v1(), + max_script_byte_size: 2048, }, ] } @@ -371,6 +382,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: (5000 * 4000) * T, transaction_weight: TransactionWeight::v1(), + max_script_byte_size: 2048, }] } @@ -407,6 +419,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: (5000 * 
4000) * T, transaction_weight: TransactionWeight::v2(), + max_script_byte_size: 2048, }] } @@ -441,6 +454,7 @@ impl ConsensusConstants { proof_of_work: algos, faucet_value: MicroTari::from(0), transaction_weight: TransactionWeight::v2(), + max_script_byte_size: 2048, }] } } @@ -478,6 +492,11 @@ impl ConsensusConstantsBuilder { self } + pub fn with_max_script_byte_size(mut self, byte_size: usize) -> Self { + self.consensus.max_script_byte_size = byte_size; + self + } + pub fn with_max_block_transaction_weight(mut self, weight: u64) -> Self { self.consensus.max_block_transaction_weight = weight; self diff --git a/base_layer/core/src/test_helpers/blockchain.rs b/base_layer/core/src/test_helpers/blockchain.rs index f18b55ee65..c91e77aebd 100644 --- a/base_layer/core/src/test_helpers/blockchain.rs +++ b/base_layer/core/src/test_helpers/blockchain.rs @@ -138,7 +138,7 @@ pub fn create_store_with_consensus_and_validators_and_config( pub fn create_store_with_consensus(rules: ConsensusManager) -> BlockchainDatabase<TempDatabase> { let factories = CryptoFactories::default(); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), MockValidator::new(true), OrphanBlockValidator::new(rules.clone(), false, factories), ); @@ -314,6 +314,10 @@ impl BlockchainBackend for TempDatabase { self.db.as_ref().unwrap().fetch_last_header() } + fn clear_all_pending_headers(&self) -> Result<usize, ChainStorageError> { + self.db.as_ref().unwrap().clear_all_pending_headers() + } + fn fetch_last_chain_header(&self) -> Result<ChainHeader, ChainStorageError> { self.db.as_ref().unwrap().fetch_last_chain_header() } @@ -386,6 +390,10 @@ impl BlockchainBackend for TempDatabase { self.db .as_ref() .unwrap() .fetch_header_hash_by_deleted_mmr_positions(mmr_positions) } + + fn bad_block_exists(&self, block_hash: HashOutput) -> Result<bool, ChainStorageError> { + self.db.as_ref().unwrap().bad_block_exists(block_hash) + } } pub fn create_chained_blocks>( diff --git a/base_layer/core/src/validation/block_validators/async_validator.rs b/base_layer/core/src/validation/block_validators/async_validator.rs index 7fd1617f64..1d336d7a38 100644 --- a/base_layer/core/src/validation/block_validators/async_validator.rs +++ b/base_layer/core/src/validation/block_validators/async_validator.rs @@ -324,6 +324,7 @@ impl BlockValidator { .map(|outputs| { let range_proof_prover = self.factories.range_proof.clone(); let db = self.db.inner().clone(); + let max_script_size = self.rules.consensus_constants(height).get_max_script_byte_size(); task::spawn_blocking(move || { let db = db.db_read_access()?; let mut aggregate_sender_offset = PublicKey::default(); @@ -351,6 +352,8 @@ impl BlockValidator { aggregate_sender_offset = aggregate_sender_offset + &output.sender_offset_public_key; } + helpers::check_tari_script_byte_size(&output.script, max_script_size)?; + output.verify_metadata_signature()?; if !bypass_range_proof_verification { output.verify_range_proof(&range_proof_prover)?; diff --git a/base_layer/core/src/validation/block_validators/body_only.rs b/base_layer/core/src/validation/block_validators/body_only.rs index c6dfc9228f..2fb4ac80f5 100644 --- a/base_layer/core/src/validation/block_validators/body_only.rs +++ b/base_layer/core/src/validation/block_validators/body_only.rs @@ -25,6 +25,7 @@ use crate::{ blocks::ChainBlock, chain_storage, chain_storage::BlockchainBackend, + consensus::ConsensusManager, crypto::tari_utilities::hex::Hex, validation::{helpers, PostOrphanBodyValidation, ValidationError}, }; @@ -36,8 +37,15 @@ use tari_common_types::chain_metadata::ChainMetadata; /// This validator checks whether a 
block satisfies *all* consensus rules. If a block passes this validator, it is the /// next block on the blockchain. -#[derive(Default)] -pub struct BodyOnlyValidator; +pub struct BodyOnlyValidator { + rules: ConsensusManager, +} + +impl BodyOnlyValidator { + pub fn new(rules: ConsensusManager) -> Self { + Self { rules } + } +} impl<B: BlockchainBackend> PostOrphanBodyValidation<B> for BodyOnlyValidator { /// The consensus checks that are done (in order of cheapest to verify to most expensive): @@ -66,7 +74,11 @@ impl PostOrphanBodyValidation for BodyOnlyValidator { let block_id = format!("block #{} ({})", block.header().height, block.hash().to_hex()); helpers::check_inputs_are_utxos(backend, &block.block().body)?; - helpers::check_not_duplicate_txos(backend, &block.block().body)?; + helpers::check_outputs( + backend, + self.rules.consensus_constants(block.height()), + &block.block().body, + )?; trace!( target: LOG_TARGET, "Block validation: All inputs and outputs are valid for {}", block_id ); let mmr_roots = chain_storage::calculate_mmr_roots(backend, block.block())?; helpers::check_mmr_roots(block.header(), &mmr_roots)?; + helpers::check_not_bad_block(backend, block.hash())?; trace!( target: LOG_TARGET, "Block validation: MMR roots are valid for {}", diff --git a/base_layer/core/src/validation/block_validators/test.rs b/base_layer/core/src/validation/block_validators/test.rs index 2b5fcbbd75..9541756701 100644 --- a/base_layer/core/src/validation/block_validators/test.rs +++ b/base_layer/core/src/validation/block_validators/test.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::sync::Arc; +use tari_crypto::script; use tari_common::configuration::Network; use tari_test_utils::unpack_enum; @@ -189,3 +190,27 @@ async fn it_checks_txo_sort_order() { let err = validator.validate_block_body(block.block().clone()).await.unwrap_err(); assert!(matches!(err, ValidationError::UnsortedOrDuplicateOutput)); } + +#[tokio::test] +async fn it_limits_the_script_byte_size() { + let rules = ConsensusManager::builder(Network::LocalNet) + .add_consensus_constants( + ConsensusConstantsBuilder::new(Network::LocalNet) + .with_coinbase_lockheight(0) + .with_max_script_byte_size(0) + .build(), + ) + .build(); + let (mut blockchain, validator) = setup_with_rules(rules); + + let (_, coinbase_a) = blockchain.add_next_tip("A", Default::default()); + + let mut schema1 = txn_schema!(from: vec![coinbase_a], to: vec![50 * T, 12 * T]); + schema1.script = script!(Nop Nop Nop); + let (txs, _) = schema_to_transaction(&[schema1]); + let txs = txs.into_iter().map(|t| Arc::try_unwrap(t).unwrap()).collect::<Vec<_>>(); + let (block, _) = blockchain.create_next_tip(BlockSpec::new().with_transactions(txs).finish()); + + let err = validator.validate_block_body(block.block().clone()).await.unwrap_err(); + assert!(matches!(err, ValidationError::TariScriptExceedsMaxSize { .. 
})); +} diff --git a/base_layer/core/src/validation/error.rs b/base_layer/core/src/validation/error.rs index bc2f3bb0d5..d949b7f2f5 100644 --- a/base_layer/core/src/validation/error.rs +++ b/base_layer/core/src/validation/error.rs @@ -89,6 +89,13 @@ pub enum ValidationError { IncorrectPreviousHash { expected: String, block_hash: String }, #[error("Async validation task failed: {0}")] AsyncTaskFailed(#[from] task::JoinError), + #[error("Bad block with hash {hash} found")] + BadBlockFound { hash: String }, + #[error("Script exceeded maximum script size, expected less than {max_script_size} but was {actual_script_size}")] + TariScriptExceedsMaxSize { + max_script_size: usize, + actual_script_size: usize, + }, } // ChainStorageError has a ValidationError variant, so to prevent a cyclic dependency we use a string representation in diff --git a/base_layer/core/src/validation/helpers.rs b/base_layer/core/src/validation/helpers.rs index 3a8f4b0152..30003bf262 100644 --- a/base_layer/core/src/validation/helpers.rs +++ b/base_layer/core/src/validation/helpers.rs @@ -20,15 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::aggregated_body::AggregateBody; -use log::*; -use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}; - use crate::{ blocks::{Block, BlockHeader, BlockHeaderValidationError, BlockValidationError}, chain_storage::{BlockchainBackend, MmrRoots, MmrTree}, - consensus::{emission::Emission, ConsensusConstants, ConsensusManager}, - crypto::commitment::HomomorphicCommitmentFactory, + consensus::{ + emission::Emission, + ConsensusConstants, + ConsensusEncodingSized, + ConsensusEncodingWrapper, + ConsensusManager, + }, + crypto::{commitment::HomomorphicCommitmentFactory, tari_utilities::hex::to_hex}, proof_of_work::{ monero_difficulty, monero_rx::MoneroPowData, @@ -40,15 +42,21 @@ use crate::{ PowError, }, transactions::{ + aggregated_body::AggregateBody, tari_amount::MicroTari, transaction_entities::{KernelSum, TransactionError, TransactionInput, TransactionKernel, TransactionOutput}, CryptoFactories, }, validation::ValidationError, }; +use log::*; use std::cmp::Ordering; use tari_common_types::types::{Commitment, CommitmentFactory, PublicKey}; -use tari_crypto::keys::PublicKey as PublicKeyTrait; +use tari_crypto::{ + keys::PublicKey as PublicKeyTrait, + script::TariScript, + tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}, +}; pub const LOG_TARGET: &str = "c::val::helpers"; @@ -399,14 +407,34 @@ pub fn check_input_is_utxo(db: &B, input: &TransactionInpu Err(ValidationError::UnknownInput) } -/// This function checks that the outputs do not already exist in the UTxO set. -pub fn check_not_duplicate_txos(db: &B, body: &AggregateBody) -> Result<(), ValidationError> { +/// This function checks: +/// 1. the byte size of TariScript does not exceed the maximum +/// 2. that the outputs do not already exist in the UTxO set. +pub fn check_outputs( + db: &B, + constants: &ConsensusConstants, + body: &AggregateBody, +) -> Result<(), ValidationError> { + let max_script_size = constants.get_max_script_byte_size(); for output in body.outputs() { + check_tari_script_byte_size(&output.script, max_script_size)?; check_not_duplicate_txo(db, output)?; } Ok(()) } +/// Checks the byte size of TariScript is less than or equal to the given size, otherwise returns an error. 
+pub fn check_tari_script_byte_size(script: &TariScript, max_script_size: usize) -> Result<(), ValidationError> { + let script_size = ConsensusEncodingWrapper::wrap(script).consensus_encode_exact_size(); + if script_size > max_script_size { + return Err(ValidationError::TariScriptExceedsMaxSize { + max_script_size, + actual_script_size: script_size, + }); + } + Ok(()) +} + /// This function checks that the outputs do not already exist in the UTxO set. pub fn check_not_duplicate_txo<B: BlockchainBackend>( db: &B, @@ -503,6 +531,13 @@ pub fn check_mmr_roots(header: &BlockHeader, mmr_roots: &MmrRoots) -> Result<(), Ok(()) } +pub fn check_not_bad_block<B: BlockchainBackend>(db: &B, hash: &[u8]) -> Result<(), ValidationError> { + if db.bad_block_exists(hash.to_vec())? { + return Err(ValidationError::BadBlockFound { hash: to_hex(hash) }); + } + Ok(()) +} + pub fn check_coinbase_reward( factory: &CommitmentFactory, rules: &ConsensusManager, diff --git a/base_layer/core/src/validation/transaction_validators.rs b/base_layer/core/src/validation/transaction_validators.rs index c0d384f98d..48b28139e1 100644 --- a/base_layer/core/src/validation/transaction_validators.rs +++ b/base_layer/core/src/validation/transaction_validators.rs @@ -26,7 +26,7 @@ use crate::{ chain_storage::{BlockchainBackend, BlockchainDatabase}, transactions::{transaction_entities::Transaction, CryptoFactories}, validation::{ - helpers::{check_inputs_are_utxos, check_not_duplicate_txos}, + helpers::{check_inputs_are_utxos, check_outputs}, MempoolTransactionValidation, ValidationError, }, @@ -117,9 +117,10 @@ impl TxInputAndMaturityValidator { impl<B: BlockchainBackend> MempoolTransactionValidation for TxInputAndMaturityValidator<B> { fn validate(&self, tx: &Transaction) -> Result<(), ValidationError> { + let constants = self.db.consensus_constants()?; let db = self.db.db_read_access()?; check_inputs_are_utxos(&*db, tx.body())?; - check_not_duplicate_txos(&*db, tx.body())?; + check_outputs(&*db, constants, tx.body())?; let tip_height = db.fetch_chain_metadata()?.height_of_longest_chain(); verify_timelocks(tx, tip_height)?; diff --git a/base_layer/core/tests/block_validation.rs b/base_layer/core/tests/block_validation.rs index a689152002..b8e3ddba07 100644 --- a/base_layer/core/tests/block_validation.rs +++ b/base_layer/core/tests/block_validation.rs @@ -81,7 +81,7 @@ fn test_genesis_block() { let rules = ConsensusManager::builder(network).build(); let backend = create_test_db(); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules.clone(), false, factories), ); @@ -268,7 +268,7 @@ fn test_orphan_validator() { let backend = create_test_db(); let orphan_validator = OrphanBlockValidator::new(rules.clone(), false, factories.clone()); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), orphan_validator.clone(), ); @@ -385,10 +385,10 @@ fn test_orphan_body_validation() { .with_block(genesis.clone()) .build(); let backend = create_test_db(); - let body_only_validator = BodyOnlyValidator::default(); + let body_only_validator = BodyOnlyValidator::new(rules.clone()); let header_validator = HeaderValidator::new(rules.clone()); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules.clone(), false, factories.clone()), ); @@ -584,7 +584,7 @@ fn test_header_validation() { let backend 
= create_test_db(); let header_validator = HeaderValidator::new(rules.clone()); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules.clone(), false, factories.clone()), ); @@ -693,7 +693,7 @@ async fn test_block_sync_body_validator() { let backend = create_test_db(); let validators = Validators::new( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules.clone(), false, factories.clone()), ); diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index 15da1b7ac1..867f2e87dd 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -514,7 +514,7 @@ async fn local_get_new_block_with_zero_conf() { let (mut node, rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules.clone()) .with_validators( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, true, factories.clone()), ) @@ -593,7 +593,7 @@ async fn local_get_new_block_with_combined_transaction() { let (mut node, rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules.clone()) .with_validators( - BodyOnlyValidator::default(), + BodyOnlyValidator::new(rules.clone()), HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, true, factories.clone()), ) diff --git a/common/src/exit_codes.rs b/common/src/exit_codes.rs index 7ab4a9ca9d..0a08db19e9 100644 --- a/common/src/exit_codes.rs +++ b/common/src/exit_codes.rs @@ -3,7 +3,7 @@ use thiserror::Error; /// Enum to show failure information #[derive(Debug, Clone, Error)] pub enum ExitCodes { - #[error("There is an error in the wallet configuration: {0}")] + #[error("There is an error in the configuration: {0}")] ConfigError(String), #[error("The application exited because an unknown error occurred: {0}. 
Check the logs for more details.")] UnknownError(String), diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 7c461cf921..d0e0c6c0d5 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -44,14 +44,11 @@ thiserror = "1.0.26" tokio = { version = "1.14", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } tokio-stream = { version = "0.1.7", features = ["sync"] } tokio-util = { version = "0.6.7", features = ["codec", "compat"] } -tower = "0.4" +tower = {version = "0.4", features = ["util"]} tracing = "0.1.26" tracing-futures = "0.2.5" yamux = "=0.9.0" -# RPC dependencies -tower-make = { version = "0.3.0", optional = true } - # Metrics tari_metrics = { path = "../infrastructure/metrics" } @@ -70,4 +67,4 @@ tari_common = { version = "^0.21", path = "../common", features = ["build"] } c_integration = [] avx2 = ["tari_crypto/avx2"] metrics = [] -rpc = ["tower-make", "tower/util"] +rpc = ["tower/make", "tower/util"] diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index 3ac370073f..f7f5e7f445 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -38,7 +38,7 @@ serde_derive = "1.0.90" serde_repr = "0.1.5" thiserror = "1.0.26" tokio = { version = "1.14", features = ["rt", "macros"] } -tower = { version= "0.4", features=["full"] } +tower = { version = "0.4", features = ["full"] } ttl_cache = "0.5.1" # tower-filter dependencies @@ -49,6 +49,8 @@ tari_test_utils = { version = "^0.21", path = "../../infrastructure/test_utils" env_logger = "0.7.0" futures-test = { version = "0.3.5" } +futures-util = "^0.3.1" +lazy_static = "1.4.0" lmdb-zero = "0.4.4" tempfile = "3.1.0" tokio-stream = { version = "0.1.7", features = ["sync"] } @@ -57,9 +59,6 @@ clap = "2.33.0" # tower-filter dependencies tower-test = { version = "^0.4" } -tokio-test = "^0.4.2" -futures-util = "^0.3.1" -lazy_static = "1.4.0" [build-dependencies] tari_common = { version = "^0.21", path = "../../common" } diff --git a/comms/dht/examples/graphing_utilities/utilities.rs b/comms/dht/examples/graphing_utilities/utilities.rs index 8e07f43657..ef70f2f447 100644 --- a/comms/dht/examples/graphing_utilities/utilities.rs +++ b/comms/dht/examples/graphing_utilities/utilities.rs @@ -30,7 +30,7 @@ use petgraph::{ stable_graph::{NodeIndex, StableGraph}, visit::{Bfs, IntoNodeReferences}, }; -use std::{collections::HashMap, convert::TryFrom, fs, fs::File, io::Write, path::Path, process::Command, sync::Mutex}; +use std::{collections::HashMap, fs, fs::File, io::Write, path::Path, process::Command, sync::Mutex}; use tari_comms::{connectivity::ConnectivitySelection, peer_manager::NodeId}; use tari_test_utils::streams::convert_unbounded_mpsc_to_stream; @@ -100,9 +100,7 @@ pub async fn network_graph_snapshot( graph.add_edge( node_index.to_owned(), peer_node_index.to_owned(), - u128::try_from(distance) - .expect("Couldn't convert XorDistance to U128") - .to_string(), + distance.as_u128().to_string(), ); } if let Some(n) = num_neighbours { @@ -123,9 +121,7 @@ pub async fn network_graph_snapshot( neighbour_graph.add_edge( node_index.to_owned(), peer_node_index.to_owned(), - u128::try_from(distance) - .expect("Couldn't convert XorDistance to U128") - .to_string(), + distance.as_u128().to_string(), ); } } diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 978f3581d9..58acbbb4d7 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -76,6 +76,7 @@ pub struct DhtConfig { /// The interval to change the random pool peers. 
/// Default: 2 hours pub connectivity_random_pool_refresh: Duration, + pub connectivity_high_failure_rate_cooldown: Duration, /// Network discovery config pub network_discovery: NetworkDiscoveryConfig, /// Length of time to ban a peer if the peer misbehaves at the DHT-level. @@ -144,6 +145,7 @@ impl Default for DhtConfig { discovery_request_timeout: Duration::from_secs(2 * 60), connectivity_update_interval: Duration::from_secs(2 * 60), connectivity_random_pool_refresh: Duration::from_secs(2 * 60 * 60), + connectivity_high_failure_rate_cooldown: Duration::from_secs(45), auto_join: false, join_cooldown_interval: Duration::from_secs(10 * 60), network_discovery: Default::default(), diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index 07b572b997..c95f98d41c 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -31,7 +31,7 @@ use log::*; use std::{sync::Arc, time::Instant}; use tari_comms::{ connectivity::{ConnectivityError, ConnectivityEvent, ConnectivityEventRx, ConnectivityRequester}, - peer_manager::{node_id::NodeDistance, NodeId, PeerManagerError, PeerQuery, PeerQuerySortBy}, + peer_manager::{NodeDistance, NodeId, PeerManagerError, PeerQuery, PeerQuerySortBy}, NodeIdentity, PeerConnection, PeerManager, @@ -80,6 +80,8 @@ pub struct DhtConnectivity { stats: Stats, dht_events: broadcast::Receiver>, metrics_collector: MetricsCollectorHandle, + cooldown_in_effect: Option, + recent_connection_failure_count: usize, shutdown_signal: ShutdownSignal, } @@ -108,6 +110,8 @@ impl DhtConnectivity { random_pool_last_refresh: None, stats: Stats::new(), dht_events, + recent_connection_failure_count: 0, + cooldown_in_effect: None, shutdown_signal, } } @@ -139,7 +143,7 @@ impl DhtConnectivity { self.refresh_neighbour_pool().await?; let mut ticker = time::interval(self.config.connectivity_update_interval); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); loop { tokio::select! 
{ Ok(event) = connectivity_events.recv() => { @@ -155,12 +159,15 @@ impl DhtConnectivity { }, _ = ticker.tick() => { - if let Err(err) = self.refresh_random_pool_if_required().await { - debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); - } if let Err(err) = self.check_and_ban_flooding_peers().await { debug!(target: LOG_TARGET, "Error checking for peer flooding: {:?}", err); } + if let Err(err) = self.refresh_neighbour_pool_if_required().await { + debug!(target: LOG_TARGET, "Error refreshing neighbour peer pool: {:?}", err); + } + if let Err(err) = self.refresh_random_pool_if_required().await { + debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); + } self.log_status(); }, @@ -185,8 +192,16 @@ impl DhtConnectivity { .partition::, _>(|peer| self.connection_handles.iter().any(|c| c.peer_node_id() == *peer)); debug!( target: LOG_TARGET, - "DHT connectivity: neighbour pool: {}/{} ({} connected), random pool: {}/{} ({} connected, last refreshed \ - {}), active DHT connections: {}/{}", + "DHT connectivity status: {}neighbour pool: {}/{} ({} connected), random pool: {}/{} ({} connected, last \ + refreshed {}), active DHT connections: {}/{}", + self.cooldown_in_effect + .map(|ts| format!( + "COOLDOWN({:.2?} remaining) ", + self.config + .connectivity_high_failure_rate_cooldown + .saturating_sub(ts.elapsed()) + )) + .unwrap_or_else(String::new), self.neighbours.len(), self.config.num_neighbouring_nodes, neighbour_connected.len(), @@ -227,9 +242,7 @@ impl DhtConnectivity { "Network discovery discovered {} more neighbouring peers. Reinitializing pools", info.num_new_peers ); - if info.has_new_neighbours() { - self.refresh_peer_pools().await?; - } + self.refresh_peer_pools().await?; } }, _ => {}, @@ -273,6 +286,25 @@ impl DhtConnectivity { Ok(()) } + async fn refresh_neighbour_pool_if_required(&mut self) -> Result<(), DhtConnectivityError> { + if self.num_connected_neighbours() < self.config.num_neighbouring_nodes { + self.refresh_neighbour_pool().await?; + } + + Ok(()) + } + + fn num_connected_neighbours(&self) -> usize { + self.neighbours + .iter() + .filter(|peer| self.connection_handles.iter().any(|c| c.peer_node_id() == *peer)) + .count() + } + + fn connected_peers_iter(&self) -> impl Iterator { + self.connection_handles.iter().map(|c| c.peer_node_id()) + } + async fn refresh_neighbour_pool(&mut self) -> Result<(), DhtConnectivityError> { let mut new_neighbours = self .fetch_neighbouring_peers(self.config.num_neighbouring_nodes, &[]) @@ -312,7 +344,10 @@ impl DhtConnectivity { difference.iter().for_each(|peer| { self.remove_connection_handle(peer); }); - self.connectivity.request_many_dials(new_neighbours).await?; + + if !new_neighbours.is_empty() { + self.connectivity.request_many_dials(new_neighbours).await?; + } Ok(()) } @@ -338,39 +373,42 @@ impl DhtConnectivity { target: LOG_TARGET, "Unable to refresh random peer pool because there are insufficient known peers", ); - } else { - let (intersection, difference) = self - .random_pool - .drain(..) 
- .partition::, _>(|n| random_peers.contains(n)); - // Remove the peers that we want to keep from the `random_peers` to be added - random_peers.retain(|n| !intersection.contains(n)); - self.random_pool = intersection; - debug!( - target: LOG_TARGET, - "Adding new peers to random peer pool (#new = {}, #keeping = {}, #removing = {})", - random_peers.len(), - self.random_pool.len(), - difference.len() - ); - trace!( - target: LOG_TARGET, - "Random peers: Adding = {:?}, Removing = {:?}", - random_peers, - difference - ); - self.random_pool.extend(random_peers.clone()); - // Drop any connection handles that removed from the random pool - difference.iter().for_each(|peer| { - self.remove_connection_handle(peer); - }); - self.connectivity.request_many_dials(random_peers).await?; + return Ok(()); } + + let (intersection, difference) = self + .random_pool + .drain(..) + .partition::, _>(|n| random_peers.contains(n)); + // Remove the peers that we want to keep from the `random_peers` to be added + random_peers.retain(|n| !intersection.contains(n)); + self.random_pool = intersection; + debug!( + target: LOG_TARGET, + "Adding new peers to random peer pool (#new = {}, #keeping = {}, #removing = {})", + random_peers.len(), + self.random_pool.len(), + difference.len() + ); + trace!( + target: LOG_TARGET, + "Random peers: Adding = {:?}, Removing = {:?}", + random_peers, + difference + ); + self.random_pool.extend(random_peers.clone()); + // Drop any connection handles that removed from the random pool + difference.iter().for_each(|peer| { + self.remove_connection_handle(peer); + }); + self.connectivity.request_many_dials(random_peers).await?; + self.random_pool_last_refresh = Some(Instant::now()); Ok(()) } async fn handle_new_peer_connected(&mut self, conn: PeerConnection) -> Result<(), DhtConnectivityError> { + self.peer_manager.mark_last_seen(conn.peer_node_id()).await?; if conn.peer_features().is_client() { debug!( target: LOG_TARGET, @@ -426,7 +464,7 @@ impl DhtConnectivity { fn remove_connection_handle(&mut self, node_id: &NodeId) { if let Some(idx) = self.connection_handles.iter().position(|c| c.peer_node_id() == node_id) { - let conn = self.connection_handles.remove(idx); + let conn = self.connection_handles.swap_remove(idx); debug!(target: LOG_TARGET, "Removing peer connection {}", conn); } } @@ -438,9 +476,9 @@ impl DhtConnectivity { PeerConnected(conn) => { self.handle_new_peer_connected(conn).await?; }, - PeerConnectFailed(node_id) | PeerDisconnected(node_id) => { + PeerConnectFailed(node_id) => { if self.metrics_collector.clear_metrics(node_id.clone()).await.is_err() { - warn!( + debug!( target: LOG_TARGET, "Failed to clear metrics for peer `{}`. 
Metric collector is shut down.", node_id ); @@ -450,9 +488,52 @@ impl DhtConnectivity { return Ok(()); } - self.replace_pool_peer(&node_id).await?; + const TOLERATED_CONNECTION_FAILURES: usize = 40; + if self.recent_connection_failure_count < TOLERATED_CONNECTION_FAILURES { + self.recent_connection_failure_count += 1; + } + + if self.recent_connection_failure_count == TOLERATED_CONNECTION_FAILURES && + self.cooldown_in_effect.is_none() + { + warn!( + target: LOG_TARGET, + "Too many ({}) connection failures, cooldown is in effect", TOLERATED_CONNECTION_FAILURES + ); + self.cooldown_in_effect = Some(Instant::now()); + } + + if self + .cooldown_in_effect + .map(|ts| ts.elapsed() >= self.config.connectivity_high_failure_rate_cooldown) + .unwrap_or(true) + { + if self.cooldown_in_effect.is_some() { + self.cooldown_in_effect = None; + self.recent_connection_failure_count = 1; + } + self.replace_pool_peer(&node_id).await?; + } + self.log_status(); + }, + PeerDisconnected(node_id) => { + if self.metrics_collector.clear_metrics(node_id.clone()).await.is_err() { + debug!( + target: LOG_TARGET, + "Failed to clear metrics for peer `{}`. Metric collector is shut down.", node_id + ); + }; + if !self.is_pool_peer(&node_id) { + debug!(target: LOG_TARGET, "{} is not managed by the DHT. Ignoring", node_id); + return Ok(()); + } + debug!(target: LOG_TARGET, "Pool peer {} disconnected. Redialling...", node_id); + // Attempt to reestablish the lost connection to the pool peer. If reconnection fails, + // it is replaced with another peer (replace_pool_peer via PeerConnectFailed) + self.connectivity.request_many_dials([node_id]).await?; }, ConnectivityStateOnline(n) => { + self.refresh_peer_pools().await?; if self.config.auto_join && self.should_send_join() { debug!( target: LOG_TARGET, @@ -478,19 +559,26 @@ impl DhtConnectivity { async fn replace_pool_peer(&mut self, current_peer: &NodeId) -> Result<(), DhtConnectivityError> { if self.random_pool.contains(current_peer) { + let exclude = self.get_pool_peers(); + let pos = self + .random_pool + .iter() + .position(|n| n == current_peer) + .expect("unreachable panic"); + self.random_pool.swap_remove(pos); + debug!( target: LOG_TARGET, - "Peer '{}' in random pool is offline. Adding a new random peer if possible", current_peer + "Peer '{}' in random pool is unavailable. Adding a new random peer if possible", current_peer ); - let exclude = self.get_pool_peers(); match self.fetch_random_peers(1, &exclude).await?.pop() { Some(new_peer) => { self.remove_connection_handle(current_peer); if let Some(pos) = self.random_pool.iter().position(|n| n == current_peer) { - self.random_pool.remove(pos); + self.random_pool.swap_remove(pos); } self.random_pool.push(new_peer.clone()); - self.connectivity.request_many_dials(vec![new_peer]).await?; + self.connectivity.request_many_dials([new_peer]).await?; }, None => { debug!( @@ -505,11 +593,18 @@ impl DhtConnectivity { } if self.neighbours.contains(current_peer) { + let exclude = self.get_pool_peers(); + let pos = self + .neighbours + .iter() + .position(|n| n == current_peer) + .expect("unreachable panic"); + self.neighbours.remove(pos); + debug!( target: LOG_TARGET, "Peer '{}' in neighbour pool is offline. 
Adding a new peer if possible", current_peer ); - let exclude = self.get_pool_peers(); match self.fetch_neighbouring_peers(1, &exclude).await?.pop() { Some(node_id) => { self.remove_connection_handle(current_peer); @@ -517,7 +612,7 @@ impl DhtConnectivity { self.neighbours.remove(pos); } self.insert_neighbour(node_id.clone()); - self.connectivity.request_many_dials(vec![node_id]).await?; + self.connectivity.request_many_dials([node_id]).await?; }, None => { info!( @@ -590,6 +685,7 @@ impl DhtConnectivity { ) -> Result, DhtConnectivityError> { let peer_manager = &self.peer_manager; let node_id = self.node_identity.node_id(); + let connected = self.connected_peers_iter().collect::>(); // Fetch to all n nearest neighbour Communication Nodes // which are eligible for connection. // Currently that means: @@ -601,6 +697,7 @@ impl DhtConnectivity { let mut banned_count = 0; let mut excluded_count = 0; let mut filtered_out_node_count = 0; + let mut already_connected = 0; let query = PeerQuery::new() .select_where(|peer| { if peer.is_banned() { @@ -613,6 +710,11 @@ impl DhtConnectivity { return false; } + if connected.contains(&&peer.node_id) { + already_connected += 1; + return false; + } + if peer .offline_since() .map(|since| since <= self.config.offline_peer_cooldown) @@ -630,8 +732,9 @@ impl DhtConnectivity { true }) - .sort_by(PeerQuerySortBy::DistanceFrom(node_id)) - .limit(n); + .sort_by(PeerQuerySortBy::DistanceFromLastConnected(node_id)) + // Fetch double here so that there is a bigger closest peer set that can be ordered by last seen + .limit(n * 2); let peers = peer_manager.perform_query(query).await?; let total_excluded = banned_count + connect_ineligable_count + excluded_count + filtered_out_node_count; @@ -640,18 +743,20 @@ impl DhtConnectivity { target: LOG_TARGET, "\n====================================\n Closest Peer Selection\n\n {num_peers} peer(s) selected\n \ {total} peer(s) were not selected \n\n {banned} banned\n {filtered_out} not communication node\n \ - {not_connectable} are not connectable\n {excluded} explicitly excluded \ + {not_connectable} are not connectable\n {excluded} explicitly excluded\n {already_connected} already \ + connected \n====================================\n", num_peers = peers.len(), total = total_excluded, banned = banned_count, filtered_out = filtered_out_node_count, not_connectable = connect_ineligable_count, - excluded = excluded_count + excluded = excluded_count, + already_connected = already_connected ); } - Ok(peers.into_iter().map(|p| p.node_id).collect()) + Ok(peers.into_iter().map(|p| p.node_id).take(n).collect()) } async fn fetch_random_peers(&self, n: usize, excluded: &[NodeId]) -> Result, DhtConnectivityError> { diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 876b54d42c..21ec72bf21 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -165,7 +165,6 @@ async fn added_neighbours() { #[runtime::test] async fn replace_peer_when_peer_goes_offline() { let node_identity = make_node_identity(); - // let node_identities = repeat_with(|| make_node_identity()).take(5).collect::>(); let node_identities = ordered_node_identities_by_distance(node_identity.node_id(), 6, PeerFeatures::COMMUNICATION_NODE); // Closest to this node @@ -193,7 +192,22 @@ async fn replace_peer_when_peer_goes_offline() { connectivity.publish_event(ConnectivityEvent::PeerDisconnected( node_identities[4].node_id().clone(), )); - 
connectivity.publish_event(ConnectivityEvent::ConnectivityStateOffline); + + async_assert!( + connectivity.call_count().await >= 1, + max_attempts = 20, + interval = Duration::from_millis(10), + ); + + let _ = connectivity.take_calls().await; + // Redial + let dialed = connectivity.take_dialed_peers().await; + assert_eq!(dialed.len(), 1); + assert_eq!(dialed[0], *node_identities[4].node_id()); + + connectivity.publish_event(ConnectivityEvent::PeerConnectFailed( + node_identities[4].node_id().clone(), + )); async_assert!( connectivity.call_count().await >= 1, @@ -203,8 +217,8 @@ async fn replace_peer_when_peer_goes_offline() { // Check that the next closer neighbour was added to the pool let dialed = connectivity.take_dialed_peers().await; - assert_eq!(dialed.len(), 2); - assert_eq!(dialed[0], *node_identities.last().unwrap().node_id()); + assert_eq!(dialed.len(), 1); + assert_eq!(dialed[0], *node_identities[5].node_id()); } #[runtime::test] diff --git a/comms/dht/src/network_discovery/discovering.rs b/comms/dht/src/network_discovery/discovering.rs index 91f4797c97..21bb93bf88 100644 --- a/comms/dht/src/network_discovery/discovering.rs +++ b/comms/dht/src/network_discovery/discovering.rs @@ -30,7 +30,7 @@ use log::*; use std::convert::TryInto; use tari_comms::{ connectivity::ConnectivityError, - peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, + peer_manager::{NodeDistance, NodeId, Peer, PeerFeatures}, PeerConnection, PeerManager, }; diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index ef250217a2..677ca94a1a 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -28,7 +28,7 @@ use crate::{ use futures::StreamExt; use std::{convert::TryInto, sync::Arc, time::Duration}; use tari_comms::{ - peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, + peer_manager::{NodeDistance, NodeId, Peer, PeerFeatures}, protocol::rpc::{mock::RpcRequestMock, RpcStatusCode}, runtime, test_utils::node_identity::{build_node_identity, ordered_node_identities_by_distance}, diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index ae1e8c0965..6fb3420706 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -205,7 +205,6 @@ impl CommsBuilder { match self.peer_storage.take() { Some(storage) => { - // TODO: Peer manager should be refactored to be backend agnostic #[cfg(not(test))] PeerManager::migrate_lmdb(&storage.inner())?; diff --git a/comms/src/connection_manager/common.rs b/comms/src/connection_manager/common.rs index e384d44b76..0fe01f04a8 100644 --- a/comms/src/connection_manager/common.rs +++ b/comms/src/connection_manager/common.rs @@ -122,7 +122,7 @@ pub async fn validate_and_add_peer_from_peer_identity( peer.addresses = addresses.into(); peer.set_offline(false); if let Some(addr) = dialed_addr { - peer.addresses.mark_successful_connection_attempt(addr); + peer.addresses.mark_last_seen_now(addr); } peer.features = PeerFeatures::from_bits_truncate(peer_identity.features); peer.supported_protocols = supported_protocols.clone(); @@ -153,7 +153,7 @@ pub async fn validate_and_add_peer_from_peer_identity( add_valid_identity_signature_to_peer(&mut new_peer, identity_sig)?; } if let Some(addr) = dialed_addr { - new_peer.addresses.mark_successful_connection_attempt(addr); + new_peer.addresses.mark_last_seen_now(addr); } new_peer }, diff --git a/comms/src/connectivity/config.rs b/comms/src/connectivity/config.rs index b2d90c753d..33a42177d7 100644 --- a/comms/src/connectivity/config.rs +++ 
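For review purposes, the dial-failure handling added to `DhtConnectivity` above can be read as a small state machine: count recent connection failures up to a tolerated maximum (40), start a cooldown timer once the threshold is reached, and suspend pool-peer replacement until `connectivity_high_failure_rate_cooldown` (45s by default) has elapsed. A minimal standalone sketch of that bookkeeping follows; the `FailureCooldown` struct and `record_failure` name are illustrative, only the constant and the branch structure mirror the patch:

```rust
use std::time::{Duration, Instant};

const TOLERATED_CONNECTION_FAILURES: usize = 40;

/// Illustrative model of the failure/cooldown bookkeeping in `DhtConnectivity`.
struct FailureCooldown {
    recent_failure_count: usize,
    cooldown_in_effect: Option<Instant>,
    /// Mirrors `connectivity_high_failure_rate_cooldown` (45s by default).
    cooldown: Duration,
}

impl FailureCooldown {
    /// Records a dial failure; returns true if the failed peer should be
    /// replaced now, or false while the cooldown suppresses replacement.
    fn record_failure(&mut self) -> bool {
        if self.recent_failure_count < TOLERATED_CONNECTION_FAILURES {
            self.recent_failure_count += 1;
        }
        if self.recent_failure_count == TOLERATED_CONNECTION_FAILURES && self.cooldown_in_effect.is_none() {
            self.cooldown_in_effect = Some(Instant::now());
        }
        match self.cooldown_in_effect {
            // Cooldown has run its course: reset and resume replacing peers
            Some(started) if started.elapsed() >= self.cooldown => {
                self.cooldown_in_effect = None;
                self.recent_failure_count = 1; // this failure starts the next window
                true
            },
            // Still cooling down: leave the pools alone
            Some(_) => false,
            None => true,
        }
    }
}
```

Note that the counter is reset to 1 (not 0) when the cooldown lapses, so a persistently failing network re-enters cooldown after another 39 failures rather than oscillating on every tick.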
b/comms/src/connectivity/config.rs @@ -41,6 +41,10 @@ pub struct ConnectivityConfig { /// The length of time to wait before disconnecting a connection that failed tie breaking. /// Default: 1s pub connection_tie_break_linger: Duration, + /// If the peer has not been seen within this interval, it will be removed from the peer list on the + /// next connection attempt. + /// Default: 24 hours + pub expire_peer_last_seen_duration: Duration, } impl Default for ConnectivityConfig { @@ -50,8 +54,9 @@ impl Default for ConnectivityConfig { connection_pool_refresh_interval: Duration::from_secs(60), reaper_min_inactive_age: Duration::from_secs(20 * 60), is_connection_reaping_enabled: true, - max_failures_mark_offline: 2, + max_failures_mark_offline: 1, connection_tie_break_linger: Duration::from_secs(2), + expire_peer_last_seen_duration: Duration::from_secs(24 * 60 * 60), } } } diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index e5c0e46bee..30f59edf00 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -92,7 +92,7 @@ impl ConnectivityManager { pool: ConnectionPool::new(), shutdown_signal: self.shutdown_signal, #[cfg(feature = "metrics")] - uptime: Instant::now(), + uptime: Some(Instant::now()), } .spawn() } @@ -148,7 +148,7 @@ struct ConnectivityManagerActor { pool: ConnectionPool, shutdown_signal: ShutdownSignal, #[cfg(feature = "metrics")] - uptime: Instant, + uptime: Option, } impl ConnectivityManagerActor { @@ -194,6 +194,7 @@ impl ConnectivityManagerActor { }, _ = ticker.tick() => { + self.cleanup_connection_stats(); if let Err(err) = self.refresh_connection_pool().await { error!(target: LOG_TARGET, "Error when refreshing connection pools: {:?}", err); } @@ -429,16 +430,32 @@ impl ConnectivityManagerActor { node_id.short_str(), num_failed ); - if self.peer_manager.set_offline(node_id, true).await? { - debug!( - target: LOG_TARGET, - "Peer `{}` was marked as offline but was already offline.", node_id - ); - } else { - // Only publish the `PeerOffline` event if we changed the offline state from online to offline + if !self.peer_manager.set_offline(node_id, true).await? { + // Only publish the `PeerOffline` event if we change from online to offline self.publish_event(ConnectivityEvent::PeerOffline(node_id.clone())); } - self.connection_stats.remove(node_id); + + if let Ok(peer) = self.peer_manager.find_by_node_id(node_id).await { + if !peer.is_banned() && + peer.last_seen_since() + // Haven't seen them in expire_peer_last_seen_duration + .map(|t| t > self.config.expire_peer_last_seen_duration) + // Or don't delete if never seen + .unwrap_or(false) + { + debug!( + target: LOG_TARGET, + "Peer `{}` was marked as offline after {} attempts (last seen: {}). 
Removing peer from peer \ + list", + node_id, + num_failed, + peer.last_seen_since() + .map(|d| format!("{}s ago", d.as_secs())) + .unwrap_or_else(|| "Never".to_string()), + ); + self.peer_manager.delete_peer(node_id).await?; + } + } } Ok(()) @@ -515,10 +532,7 @@ impl ConnectivityManagerActor { } let (node_id, mut new_status, connection) = match event { - PeerDisconnected(node_id) => { - self.connection_stats.remove(node_id); - (&*node_id, ConnectionStatus::Disconnected, None) - }, + PeerDisconnected(node_id) => (&*node_id, ConnectionStatus::Disconnected, None), PeerConnected(conn) => (conn.peer_node_id(), ConnectionStatus::Connected, Some(conn.clone())), PeerConnectFailed(node_id, ConnectionManagerError::DialCancelled) => { @@ -662,7 +676,12 @@ impl ConnectivityManagerActor { metrics::connections(ConnectionDirection::Inbound).set(num_inbound); metrics::connections(ConnectionDirection::Outbound).set(total - num_inbound); - metrics::uptime().set(i64::try_from(self.uptime.elapsed().as_secs()).unwrap_or(i64::MAX)); + + let uptime = self + .uptime + .map(|ts| i64::try_from(ts.elapsed().as_secs()).unwrap_or(i64::MAX)) + .unwrap_or(0); + metrics::uptime().set(uptime); } fn transition(&mut self, next_status: ConnectivityStatus, required_num_peers: usize) { @@ -681,6 +700,11 @@ impl ConnectivityManagerActor { target: LOG_TARGET, "Connectivity is ONLINE ({}/{} connections)", n, required_num_peers ); + + #[cfg(feature = "metrics")] + if self.uptime.is_none() { + self.uptime = Some(Instant::now()); + } self.publish_event(ConnectivityEvent::ConnectivityStateOnline(n)); }, (Degraded(m), Degraded(n)) => { @@ -705,6 +729,10 @@ impl ConnectivityManagerActor { target: LOG_TARGET, "Connectivity is OFFLINE (0/{} connections)", required_num_peers ); + #[cfg(feature = "metrics")] + { + self.uptime = None; + } self.publish_event(ConnectivityEvent::ConnectivityStateOffline); }, (status, next_status) => unreachable!("Unexpected status transition ({} to {})", status, next_status), @@ -748,6 +776,22 @@ impl ConnectivityManagerActor { } Ok(()) } + + fn cleanup_connection_stats(&mut self) { + let mut to_remove = Vec::new(); + for node_id in self.connection_stats.keys() { + let status = self.pool.get_connection_status(node_id); + if matches!( + status, + ConnectionStatus::NotConnected | ConnectionStatus::Failed | ConnectionStatus::Disconnected + ) { + to_remove.push(node_id.clone()); + } + } + for node_id in to_remove { + self.connection_stats.remove(&node_id); + } + } } fn delayed_close(conn: PeerConnection, delay: Duration) { diff --git a/comms/src/connectivity/selection.rs b/comms/src/connectivity/selection.rs index c300d2d353..e673ecdb41 100644 --- a/comms/src/connectivity/selection.rs +++ b/comms/src/connectivity/selection.rs @@ -132,7 +132,7 @@ mod test { use super::*; use crate::{ connection_manager::PeerConnectionRequest, - peer_manager::node_id::NodeDistance, + peer_manager::NodeDistance, test_utils::{mocks::create_dummy_peer_connection, node_id, node_identity::build_node_identity}, }; use std::iter::repeat_with; diff --git a/comms/src/lib.rs b/comms/src/lib.rs index 8d7e03399a..463fc0aadd 100644 --- a/comms/src/lib.rs +++ b/comms/src/lib.rs @@ -72,4 +72,4 @@ pub mod multiaddr { pub use async_trait::async_trait; pub use bytes::{Bytes, BytesMut}; #[cfg(feature = "rpc")] -pub use tower_make::MakeService; +pub use tower::make::MakeService; diff --git a/comms/src/net_address/multiaddr_with_stats.rs b/comms/src/net_address/multiaddr_with_stats.rs index 7125e11d47..fc3d8c2c55 100644 --- 
a/comms/src/net_address/multiaddr_with_stats.rs +++ b/comms/src/net_address/multiaddr_with_stats.rs @@ -81,8 +81,8 @@ impl MutliaddrWithStats { self.rejected_message_count += 1; } - /// Mark that a successful connection was established with this net address - pub fn mark_successful_connection_attempt(&mut self) { + /// Mark that a successful interaction occurred with this address + pub fn mark_last_seen_now(&mut self) { self.last_seen = Some(Utc::now()); self.connection_attempts = 0; } @@ -229,7 +229,7 @@ mod test { net_address_with_stats.mark_failed_connection_attempt(); assert!(net_address_with_stats.last_seen.is_none()); assert_eq!(net_address_with_stats.connection_attempts, 2); - net_address_with_stats.mark_successful_connection_attempt(); + net_address_with_stats.mark_last_seen_now(); assert!(net_address_with_stats.last_seen.is_some()); assert_eq!(net_address_with_stats.connection_attempts, 0); } @@ -251,10 +251,10 @@ mod test { let mut na1 = MutliaddrWithStats::from(net_address.clone()); let mut na2 = MutliaddrWithStats::from(net_address); thread::sleep(Duration::from_millis(1)); - na1.mark_successful_connection_attempt(); + na1.mark_last_seen_now(); assert!(na1 < na2); thread::sleep(Duration::from_millis(1)); - na2.mark_successful_connection_attempt(); + na2.mark_last_seen_now(); assert!(na1 > na2); thread::sleep(Duration::from_millis(1)); na1.mark_message_rejected(); diff --git a/comms/src/net_address/mutliaddresses_with_stats.rs b/comms/src/net_address/mutliaddresses_with_stats.rs index 3a5a2ac788..87111e41ce 100644 --- a/comms/src/net_address/mutliaddresses_with_stats.rs +++ b/comms/src/net_address/mutliaddresses_with_stats.rs @@ -132,13 +132,13 @@ impl MultiaddressesWithStats { } } - /// Mark that a successful connection was established with the specified net address + /// Mark that a successful interaction occurred with the specified address /// /// Returns true if the address is contained in this instance, otherwise false - pub fn mark_successful_connection_attempt(&mut self, address: &Multiaddr) -> bool { + pub fn mark_last_seen_now(&mut self, address: &Multiaddr) -> bool { match self.find_address_mut(address) { Some(addr) => { - addr.mark_successful_connection_attempt(); + addr.mark_last_seen_now(); self.last_attempted = Some(Utc::now()); self.addresses.sort(); true @@ -256,9 +256,9 @@ mod test { net_addresses.add_address(&net_address2); net_addresses.add_address(&net_address3); - assert!(net_addresses.mark_successful_connection_attempt(&net_address3)); - assert!(net_addresses.mark_successful_connection_attempt(&net_address1)); - assert!(net_addresses.mark_successful_connection_attempt(&net_address2)); + assert!(net_addresses.mark_last_seen_now(&net_address3)); + assert!(net_addresses.mark_last_seen_now(&net_address1)); + assert!(net_addresses.mark_last_seen_now(&net_address2)); let desired_last_seen = net_addresses.addresses[0].last_seen; let last_seen = net_addresses.last_seen(); assert_eq!(desired_last_seen.unwrap(), last_seen.unwrap()); @@ -338,7 +338,7 @@ mod test { // assert!(net_addresses.mark_failed_connection_attempt(&net_address2)); // assert!(net_addresses.mark_failed_connection_attempt(&net_address3)); // assert!(net_addresses.mark_failed_connection_attempt(&net_address1)); - // assert!(net_addresses.mark_successful_connection_attempt(&net_address2)); + // assert!(net_addresses.mark_last_seen_now(&net_address2)); // assert_eq!(net_addresses.addresses[0].connection_attempts, 0); // assert_eq!(net_addresses.addresses[1].connection_attempts, 1); // 
assert_eq!(net_addresses.addresses[2].connection_attempts, 2); diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index 03d5c22c35..96a308dbb6 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -23,11 +23,12 @@ use crate::{ peer_manager::{ migrations, - node_id::{NodeDistance, NodeId}, peer::{Peer, PeerFlags}, peer_id::PeerId, peer_storage::PeerStorage, wrapper::KeyValueWrapper, + NodeDistance, + NodeId, PeerFeatures, PeerManagerError, PeerQuery, @@ -197,6 +198,10 @@ impl PeerManager { .closest_peers(node_id, n, excluded_peers, features) } + pub async fn mark_last_seen(&self, node_id: &NodeId) -> Result<(), PeerManagerError> { + self.peer_storage.write().await.mark_last_seen(node_id) + } + /// Fetch n random peers pub async fn random_peers(&self, n: usize, excluded: &[NodeId]) -> Result, PeerManagerError> { // Send to a random set of peers of size n that are Communication Nodes diff --git a/comms/src/peer_manager/migrations.rs b/comms/src/peer_manager/migrations.rs index f921985562..4f13d15f69 100644 --- a/comms/src/peer_manager/migrations.rs +++ b/comms/src/peer_manager/migrations.rs @@ -32,13 +32,15 @@ pub(super) const MIGRATION_VERSION_KEY: u64 = u64::MAX; pub fn migrate(database: &LMDBDatabase) -> Result<(), LMDBError> { // Add migrations here in version order let migrations = vec![v4::Migration.boxed()]; - + if migrations.is_empty() { + return Ok(()); + } let latest_version = migrations.last().unwrap().get_version(); // If the database is empty there is nothing to migrate, so set it to the latest version if database.len()? == 0 { debug!(target: LOG_TARGET, "New database does not require migration"); - if let Err(err) = database.insert(&MIGRATION_VERSION_KEY, &(migrations.len() as u32)) { + if let Err(err) = database.insert(&MIGRATION_VERSION_KEY, &latest_version) { error!( target: LOG_TARGET, "Failed to update migration counter: {}. ** Database may be corrupt **", err @@ -77,7 +79,13 @@ pub fn migrate(database: &LMDBDatabase) -> Result<(), LMDBError> { debug!(target: LOG_TARGET, "Migration {} complete", version); }, - None => break Ok(()), + None => { + error!( + target: LOG_TARGET, + "Migration {} not found. 
Unable to migrate peer db", version + ); + return Ok(()); + }, } } } diff --git a/comms/src/peer_manager/migrations/v4.rs b/comms/src/peer_manager/migrations/v4.rs index 71b97979c2..2d27db8958 100644 --- a/comms/src/peer_manager/migrations/v4.rs +++ b/comms/src/peer_manager/migrations/v4.rs @@ -26,7 +26,6 @@ use crate::{ connection_stats::PeerConnectionStats, migrations::MIGRATION_VERSION_KEY, node_id::deserialize_node_id_from_hex, - IdentitySignature, NodeId, PeerFeatures, PeerFlags, @@ -79,13 +78,13 @@ pub struct PeerV4 { pub banned_until: Option, pub banned_reason: String, pub offline_at: Option, + pub last_seen: Option, pub features: PeerFeatures, pub connection_stats: PeerConnectionStats, pub supported_protocols: Vec, pub added_at: NaiveDateTime, pub user_agent: String, pub metadata: HashMap>, - pub identity_signature: Option, } pub struct Migration; @@ -98,7 +97,7 @@ impl super::Migration for Migration { } fn migrate(&self, db: &LMDBDatabase) -> Result<(), Self::Error> { - db.for_each::(|old_peer| { + let result = db.for_each::(|old_peer| { let result = old_peer.and_then(|(key, peer)| { if key == MIGRATION_VERSION_KEY { return Ok(()); @@ -109,6 +108,7 @@ impl super::Migration for Migration { id: peer.id, public_key: peer.public_key, node_id: peer.node_id, + last_seen: peer.addresses.last_seen().map(|ts| ts.naive_utc()), addresses: peer.addresses, flags: peer.flags, banned_until: peer.banned_until, @@ -120,7 +120,6 @@ impl super::Migration for Migration { added_at: peer.added_at, user_agent: peer.user_agent, metadata: peer.metadata, - identity_signature: None, }) .map_err(Into::into) }); @@ -132,7 +131,14 @@ impl super::Migration for Migration { ); } IterationResult::Continue - })?; + }); + + if let Err(err) = result { + error!( + target: LOG_TARGET, + "Error reading peer db: {} ** Database may be corrupt **", err + ); + } Ok(()) } diff --git a/comms/src/peer_manager/mod.rs b/comms/src/peer_manager/mod.rs index 5513240414..ec8466d9ed 100644 --- a/comms/src/peer_manager/mod.rs +++ b/comms/src/peer_manager/mod.rs @@ -81,6 +81,9 @@ pub use identity_signature::IdentitySignature; pub mod node_id; pub use node_id::NodeId; +mod node_distance; +pub use node_distance::NodeDistance; + mod node_identity; pub use node_identity::NodeIdentity; @@ -91,7 +94,7 @@ mod peer_features; pub use peer_features::PeerFeatures; mod peer_id; -pub use peer_id::PeerId; +pub(crate) use peer_id::PeerId; mod manager; pub use manager::PeerManager; diff --git a/comms/src/peer_manager/node_distance.rs b/comms/src/peer_manager/node_distance.rs new file mode 100644 index 0000000000..3bb9673c54 --- /dev/null +++ b/comms/src/peer_manager/node_distance.rs @@ -0,0 +1,208 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission.
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use super::{node_id::NodeIdError, NodeId}; +use std::{ + convert::{TryFrom, TryInto}, + fmt, + mem, +}; + +pub type NodeDistance = XorDistance; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct XorDistance(u128); + +impl XorDistance { + /// Construct a new zero distance + pub fn new() -> Self { + Self(0) + } + + /// Calculate the distance between two node ids using the XOR metric. + pub fn from_node_ids(x: &NodeId, y: &NodeId) -> Self { + let arr = x ^ y; + arr[..] + .try_into() + .expect("unreachable panic: NodeId::byte_size() <= NodeDistance::byte_size()") + } + + /// Returns the maximum distance. + pub const fn max_distance() -> Self { + Self(u128::MAX) + } + + /// Returns a zero distance. + pub const fn zero() -> Self { + Self(0) + } + + /// Returns the number of bytes required to represent the `XorDistance` + pub const fn byte_size() -> usize { + mem::size_of::() + } + + /// Returns the index of the bucket that this distance falls into. + /// The distance falls into bucket `i` if 2^i <= distance < 2^(i+1).
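As a worked example of the bucket rule just stated, before the `get_bucket_index` implementation below: the bucket index is simply the position of the highest set bit of the distance, i.e. floor(log2(distance)), which is what the `leading_zeros` arithmetic computes. A self-contained illustration (the free `bucket_index` function is an editor's restatement, not part of the patch):

```rust
/// Bucket index = position of the highest set bit, i.e. floor(log2(distance)).
fn bucket_index(distance: u128) -> u8 {
    (128 - distance.leading_zeros() as u8).saturating_sub(1)
}

fn main() {
    // 2^2 <= 5 < 2^3, so a distance of 5 falls into bucket 2
    assert_eq!(bucket_index(5), 2);
    // Powers of two sit at the bottom of their bucket
    assert_eq!(bucket_index(1 << 10), 10);
    // Zero has no set bits; saturating_sub pins the index at 0
    assert_eq!(bucket_index(0), 0);
}
```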
+ pub fn get_bucket_index(&self) -> u8 { + ((Self::byte_size() as u8 * 8) - self.0.leading_zeros() as u8).saturating_sub(1) + } + + pub fn to_bytes(&self) -> [u8; Self::byte_size()] { + self.0.to_be_bytes() + } + + pub fn as_u128(&self) -> u128 { + self.0 + } +} + +impl TryFrom<&[u8]> for XorDistance { + type Error = NodeIdError; + + /// Construct a node distance from a set of bytes + fn try_from(bytes: &[u8]) -> Result { + if bytes.len() > Self::byte_size() { + return Err(NodeIdError::IncorrectByteCount); + } + + let mut buf = [0; Self::byte_size()]; + // Big endian has the MSB at index 0, if size of `bytes` is less than byte_size it must be offset to have + // leading 0 bytes + let offset = Self::byte_size() - bytes.len(); + buf[offset..].copy_from_slice(bytes); + Ok(XorDistance(u128::from_be_bytes(buf))) + } +} + +impl fmt::Display for NodeDistance { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Debug for XorDistance { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut digits = 0; + let mut suffix = ""; + loop { + let prefix = self.0 / u128::pow(10, 3 * (digits + 1)); + + if prefix == 0 || digits > 8 { + return write!(f, "XorDist: {}{}", self.0 / u128::pow(10, 3 * digits), suffix); + } + + digits += 1; + suffix = match suffix { + "" => "thousand", + "thousand" => "million", + "million" => "billion", + "billion" => "trillion", + "trillion" => "quadrillion", + "quadrillion" => "quintillion", + "quintillion" => "sextillion", + "sextillion" => "septillion", + "septillion" => "e24", + _ => suffix, + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::types::CommsPublicKey; + use rand::rngs::OsRng; + use tari_crypto::keys::PublicKey; + + mod ord { + use super::*; + + #[test] + fn it_uses_big_endian_ordering() { + let a = NodeDistance::try_from(&[0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1][..]).unwrap(); + let b = NodeDistance::try_from(&[1u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][..]).unwrap(); + assert!(a < b); + } + } + + mod get_bucket_index { + use super::*; + + #[test] + fn it_returns_the_correct_index() { + fn check_for_dist(lsb_dist: u8, expected: u8) { + assert_eq!( + NodeDistance::try_from(&[0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, lsb_dist][..]) + .unwrap() + .get_bucket_index(), + expected, + "Failed for dist = {}", + lsb_dist + ); + } + + assert_eq!(NodeDistance::max_distance().get_bucket_index(), 127); + assert_eq!(NodeDistance::zero().get_bucket_index(), 0); + + check_for_dist(1, 0); + for i in 2..4 { + check_for_dist(i, 1); + } + for i in 4..8 { + check_for_dist(i, 2); + } + for i in 8..16 { + check_for_dist(i, 3); + } + assert_eq!( + NodeDistance::try_from(&[0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0b01000001, 0, 0][..]) + .unwrap() + .get_bucket_index(), + 8 * 2 + 7 - 1 + ); + + assert_eq!( + NodeDistance::try_from(&[0b10000000u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][..]) + .unwrap() + .get_bucket_index(), + 103 + ); + } + + #[test] + fn correctness_fuzzing() { + for _ in 0..100 { + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); + let a = NodeId::from_public_key(&pk); + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); + let b = NodeId::from_public_key(&pk); + let dist = NodeDistance::from_node_ids(&a, &b); + let i = dist.get_bucket_index() as u32; + let dist = dist.as_u128(); + assert!(2u128.pow(i) <= dist, "Failed for {}, i = {}", dist, i); + assert!(dist < 2u128.pow(i + 1), "Failed for {}, i = {}", dist, i,); + } + } + } +} diff --git a/comms/src/peer_manager/node_id.rs 
b/comms/src/peer_manager/node_id.rs index 651f3d1b46..00e5de4cfc 100644 --- a/comms/src/peer_manager/node_id.rs +++ b/comms/src/peer_manager/node_id.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::types::CommsPublicKey; +use crate::{peer_manager::node_distance::NodeDistance, types::CommsPublicKey}; use blake2::{ digest::{Update, VariableOutput}, VarBlake2b, @@ -33,6 +33,7 @@ use std::{ fmt, hash::{Hash, Hasher}, marker::PhantomData, + ops::BitXor, }; use tari_crypto::tari_utilities::{ hex::{to_hex, Hex}, @@ -41,202 +42,37 @@ use tari_crypto::tari_utilities::{ }; use thiserror::Error; -type NodeIdArray = [u8; NodeId::BYTE_SIZE]; - -pub type NodeDistance = XorDistance; // or HammingDistance +pub(super) type NodeIdArray = [u8; NodeId::byte_size()]; #[derive(Debug, Error, Clone)] pub enum NodeIdError { - #[error("Incorrect byte count (expected {} bytes)", NodeId::BYTE_SIZE)] + #[error("Incorrect byte count (expected {} bytes)", NodeId::byte_size())] IncorrectByteCount, #[error("Invalid digest output size")] InvalidDigestOutputSize, } -//------------------------------------- XOR Metric -----------------------------------------------// -const NODE_XOR_DISTANCE_ARRAY_SIZE: usize = NodeId::BYTE_SIZE; -type NodeXorDistanceArray = [u8; NODE_XOR_DISTANCE_ARRAY_SIZE]; - -#[derive(Clone, Debug, Eq, PartialOrd, Ord, Default)] -pub struct XorDistance(NodeXorDistanceArray); - -impl XorDistance { - /// Construct a new zero distance - pub fn new() -> Self { - Self([0; NODE_XOR_DISTANCE_ARRAY_SIZE]) - } - - /// Calculate the distance between two node ids using the Hamming distance. - pub fn from_node_ids(x: &NodeId, y: &NodeId) -> Self { - Self(xor(&x.0, &y.0)) - } - - /// Returns the maximum distance. - pub const fn max_distance() -> Self { - Self([255; NODE_XOR_DISTANCE_ARRAY_SIZE]) - } - - /// Returns a zero distance. - pub const fn zero() -> Self { - Self([0; NODE_XOR_DISTANCE_ARRAY_SIZE]) - } - - /// Returns the number of bytes required to represent the `XorDistance` - pub const fn byte_length() -> usize { - NODE_XOR_DISTANCE_ARRAY_SIZE - } -} - -impl PartialEq for XorDistance { - fn eq(&self, nd: &XorDistance) -> bool { - self.0 == nd.0 - } -} - -impl TryFrom<&[u8]> for XorDistance { - type Error = NodeIdError; - - /// Construct a node distance from a set of bytes - fn try_from(elements: &[u8]) -> Result { - if elements.len() >= NODE_XOR_DISTANCE_ARRAY_SIZE { - let mut bytes = [0; NODE_XOR_DISTANCE_ARRAY_SIZE]; - bytes.copy_from_slice(&elements[0..NODE_XOR_DISTANCE_ARRAY_SIZE]); - Ok(XorDistance(bytes)) - } else { - Err(NodeIdError::IncorrectByteCount) - } - } -} - -impl TryFrom for u128 { - type Error = String; - - fn try_from(value: XorDistance) -> Result { - if XorDistance::byte_length() > 16 { - return Err("XorDistance has too many bytes to be converted to U128".to_string()); - } - let slice = value.as_bytes(); - let mut bytes: [u8; 16] = [0u8; 16]; - bytes[..XorDistance::byte_length()].copy_from_slice(&slice[..XorDistance::byte_length()]); - Ok(u128::from_be_bytes(bytes)) - } -} - -//---------------------------------- Hamming Distance --------------------------------------------// -const NODE_HAMMING_DISTANCE_ARRAY_SIZE: usize = 1; -type NodeHammingDistanceArray = [u8; NODE_HAMMING_DISTANCE_ARRAY_SIZE]; - -/// Hold the distance calculated between two NodeId's. This is used for DHT-style routing. 
-#[derive(Clone, Debug, Eq, PartialOrd, Ord, Default)] -pub struct HammingDistance(NodeHammingDistanceArray); - -impl HammingDistance { - /// Construct a new zero distance - pub fn new() -> Self { - Self([0; NODE_HAMMING_DISTANCE_ARRAY_SIZE]) - } - - /// Calculate the distance between two node ids using the Hamming distance. - pub fn from_node_ids(x: &NodeId, y: &NodeId) -> Self { - let xor_bytes = xor(&x.0, &y.0); - Self([hamming_distance(xor_bytes)]) - } - - /// Returns the maximum distance. - pub const fn max_distance() -> Self { - Self([NodeId::BYTE_SIZE as u8 * 8; NODE_HAMMING_DISTANCE_ARRAY_SIZE]) - } -} - -impl TryFrom<&[u8]> for HammingDistance { - type Error = NodeIdError; - - /// Construct a node distance from a set of bytes - fn try_from(elements: &[u8]) -> Result { - if elements.len() >= NODE_HAMMING_DISTANCE_ARRAY_SIZE { - let mut bytes = [0; NODE_HAMMING_DISTANCE_ARRAY_SIZE]; - bytes.copy_from_slice(&elements[0..NODE_HAMMING_DISTANCE_ARRAY_SIZE]); - Ok(HammingDistance(bytes)) - } else { - Err(NodeIdError::IncorrectByteCount) - } - } -} - -impl PartialEq for HammingDistance { - fn eq(&self, nd: &HammingDistance) -> bool { - self.0 == nd.0 - } -} - -/// Calculate the Exclusive OR between the node_id x and y. -fn xor(x: &NodeIdArray, y: &NodeIdArray) -> NodeIdArray { - let mut nd = [0u8; NodeId::BYTE_SIZE]; - for i in 0..nd.len() { - nd[i] = x[i] ^ y[i]; - } - nd -} - -/// Calculate the hamming distance (the number of set (1) bits of the XOR metric) -fn hamming_distance(nd: NodeIdArray) -> u8 { - let xor_bytes = &nd; - let mut set_bit_count = 0u8; - for b in xor_bytes { - let mut mask = 0b1u8; - for _ in 0..8 { - if b & mask > 0 { - set_bit_count += 1; - } - mask <<= 1; - } - } - - set_bit_count -} - -impl ByteArray for NodeDistance { - /// Try and convert the given byte array to a NodeDistance. Any failures (incorrect array length, - /// implementation-specific checks, etc) return a [ByteArrayError](enum.ByteArrayError.html). - fn from_bytes(bytes: &[u8]) -> Result { - bytes - .try_into() - .map_err(|err| ByteArrayError::ConversionError(format!("{:?}", err))) - } - - /// Return the NodeDistance as a byte array - fn as_bytes(&self) -> &[u8] { - self.0.as_ref() - } -} - -impl fmt::Display for NodeDistance { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", to_hex(&self.0)) - } -} - -//--------------------------------------- NodeId -------------------------------------------------// - /// A Node Identity is used as a unique identifier for a node in the Tari communications network. 
#[derive(Clone, Eq, Deserialize, Serialize, Default)] pub struct NodeId(NodeIdArray); impl NodeId { - /// 104-bit/13 byte as per RFC-0151 - pub const BYTE_SIZE: usize = 13; - /// Construct a new node id on the origin pub fn new() -> Self { Default::default() } + /// 104-bit/13 byte as per RFC-0151 + pub const fn byte_size() -> usize { + 13 + } + /// Derive a node id from a public key: node_id=hash(public_key) pub fn from_key(key: &K) -> Self { let bytes = key.as_bytes(); - let mut buf = [0u8; NodeId::BYTE_SIZE]; - VarBlake2b::new(NodeId::BYTE_SIZE) - .expect("NodeId::NODE_ID_ARRAY_SIZE is invalid") + let mut buf = [0u8; NodeId::byte_size()]; + VarBlake2b::new(NodeId::byte_size()) + .expect("NodeId::byte_size() is invalid") .chain(bytes) .finalize_variable(|hash| { // Safety: output size and buf size are equal @@ -344,18 +180,31 @@ impl Ord for NodeId { } } +impl BitXor for &NodeId { + type Output = NodeIdArray; + + fn bitxor(self, rhs: Self) -> Self::Output { + let mut xor = [0u8; NodeId::byte_size()]; + #[allow(clippy::needless_range_loop)] + for i in 0..NodeId::byte_size() { + xor[i] = self.0[i] ^ rhs.0[i]; + } + xor + } +} + impl TryFrom<&[u8]> for NodeId { type Error = NodeIdError; /// Construct a node id from 32 bytes - fn try_from(elements: &[u8]) -> Result { - if elements.len() >= NodeId::BYTE_SIZE { - let mut bytes = [0; NodeId::BYTE_SIZE]; - bytes.copy_from_slice(&elements[0..NodeId::BYTE_SIZE]); - Ok(NodeId(bytes)) - } else { - Err(NodeIdError::IncorrectByteCount) + fn try_from(bytes: &[u8]) -> Result { + if bytes.len() != NodeId::byte_size() { + return Err(NodeIdError::IncorrectByteCount); } + + let mut buf = [0; NodeId::byte_size()]; + buf.copy_from_slice(bytes); + Ok(NodeId(buf)) } } @@ -422,8 +271,7 @@ mod test { #[test] fn display() { - let node_id = - NodeId::try_from(&[144u8, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159, 211, 53][..]).unwrap(); + let node_id = NodeId::try_from(&[144u8, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159][..]).unwrap(); let result = format!("{}", node_id); assert_eq!("901c6a70dcc5d87709d92a4d9f", result); @@ -444,64 +292,21 @@ mod test { #[test] fn test_distance_and_ordering() { - let node_id1 = NodeId::try_from( - [ - 144, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159, 211, 53, 207, 0, 157, 5, 55, 235, 247, 160, - 195, 240, 48, 146, 168, 119, 15, 241, 54, - ] - .as_bytes(), - ) - .unwrap(); - let node_id2 = NodeId::try_from( - [ - 186, 43, 62, 14, 60, 214, 9, 180, 145, 122, 55, 160, 83, 83, 45, 185, 219, 206, 226, 128, 5, 26, 20, 0, - 192, 121, 216, 178, 134, 212, 51, 131, - ] - .as_bytes(), - ) - .unwrap(); - let node_id3 = NodeId::try_from( - [ - 60, 32, 246, 39, 108, 201, 214, 91, 30, 230, 3, 126, 31, 46, 66, 203, 27, 51, 240, 177, 230, 22, 118, - 102, 201, 55, 211, 147, 229, 26, 116, 103, - ] - .as_bytes(), - ) - .unwrap(); + let node_id1 = NodeId::try_from(&[144, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159][..]).unwrap(); + let node_id2 = NodeId::try_from(&[186, 43, 62, 14, 60, 214, 9, 180, 145, 122, 55, 160, 83][..]).unwrap(); + let node_id3 = NodeId::try_from(&[60, 32, 246, 39, 108, 201, 214, 91, 30, 230, 3, 126, 31][..]).unwrap(); assert!(node_id1.0 < node_id2.0); assert!(node_id1.0 > node_id3.0); // XOR metric - let desired_n1_to_n2_dist = NodeDistance::try_from( - [ - 42, 55, 84, 126, 224, 19, 209, 195, 152, 163, 29, 237, 204, 128, 24, 118, 219, 83, 231, 183, 238, 237, - 180, 195, 48, 73, 74, 26, 241, 219, 194, 181, - ] - .as_bytes(), - ) - .unwrap(); - let desired_n1_to_n3_dist = 
NodeDistance::try_from( - [ - 172, 60, 156, 87, 176, 12, 14, 44, 23, 63, 41, 51, 128, 253, 119, 4, 27, 174, 245, 134, 13, 225, 214, - 165, 57, 7, 65, 59, 146, 21, 133, 81, - ] - .as_bytes(), - ) - .unwrap(); - // Hamming distance - // let desired_n1_to_n2_dist_bytes: &[u8] = &vec![52u8]; - // let desired_n1_to_n2_dist = NodeDistance::try_from(desired_n1_to_n2_dist_bytes).unwrap(); - // let desired_n1_to_n3_dist = NodeDistance::try_from( - // [ - // 46, 60, 156, 87, 176, 12, 14, 44, 23, 63, 41, 51, 128, 253, 119, 4, 27, 174, 245, 134, 13, 225, 214, - // 165, 57, 7, 65, 59, 146, 21, 133, 81, - // ] - // .as_bytes(), - // ) - // .unwrap(); // Unused bytes will be discarded + let desired_n1_to_n2_dist = + NodeDistance::try_from(&[42, 55, 84, 126, 224, 19, 209, 195, 152, 163, 29, 237, 204][..]).unwrap(); + let desired_n1_to_n3_dist = + NodeDistance::try_from(&[172, 60, 156, 87, 176, 12, 14, 44, 23, 63, 41, 51, 128][..]).unwrap(); + let n1_to_n2_dist = node_id1.distance(&node_id2); let n1_to_n3_dist = node_id1.distance(&node_id3); - assert!(n1_to_n2_dist < n1_to_n3_dist); // XOR metric - // assert!(n1_to_n2_dist > n1_to_n3_dist); // Hamming Distance + // Big-endian ordering + assert!(n1_to_n2_dist < n1_to_n3_dist); assert_eq!(n1_to_n2_dist, desired_n1_to_n2_dist); assert_eq!(n1_to_n3_dist, desired_n1_to_n3_dist); @@ -541,90 +346,30 @@ mod test { assert_eq!(knn_node_ids[2].0, [ 143, 189, 32, 210, 30, 231, 82, 5, 86, 85, 28, 82, 154 ]); - // Hamming distance nearest neighbours - // assert_eq!(knn_node_ids[0].0, [ - // 75, 146, 162, 130, 22, 63, 247, 182, 156, 103, 174, 32, 134 - // ]); - // assert_eq!(knn_node_ids[1].0, [ - // 134, 116, 78, 53, 246, 206, 200, 147, 126, 96, 54, 113, 67 - // ]); - // assert_eq!(knn_node_ids[2].0, [ - // 144, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159 - // ]); + assert_eq!(node_id.closest(&node_ids, node_ids.len() + 1).len(), node_ids.len()); } #[test] fn partial_eq() { - let bytes = [ - 173, 218, 34, 188, 211, 173, 235, 82, 18, 159, 55, 47, 242, 24, 95, 60, 208, 53, 97, 51, 43, 71, 149, 89, - 123, 150, 162, 67, 240, 208, 67, 56, - ] - .as_bytes(); + let bytes = &[173, 218, 34, 188, 211, 173, 235, 82, 18, 159, 55, 47, 242][..]; let nid1 = NodeId::try_from(bytes).unwrap(); let nid2 = NodeId::try_from(bytes).unwrap(); assert_eq!(nid1, nid2); } - #[test] - fn hamming_distance() { - let mut node_id1 = NodeId::default().into_inner().to_vec(); - let mut node_id2 = NodeId::default().into_inner().to_vec(); - // Same bits - node_id1[0] = 0b00010100; - node_id2[0] = 0b00010100; - // Different bits - node_id1[1] = 0b11010100; - node_id1[12] = 0b01000011; - node_id2[10] = 0b01000011; - node_id2[9] = 0b11111111; - let node_id1 = NodeId::from_bytes(node_id1.as_slice()).unwrap(); - let node_id2 = NodeId::from_bytes(node_id2.as_slice()).unwrap(); - - let hamming_dist = HammingDistance::from_node_ids(&node_id1, &node_id2); - assert_eq!(hamming_dist, HammingDistance([18])); - - let node_max = NodeId::from_bytes(&[255; NodeId::BYTE_SIZE]).unwrap(); - let node_min = NodeId::default(); - - let hamming_dist = HammingDistance::from_node_ids(&node_max, &node_min); - assert_eq!(hamming_dist, HammingDistance::max_distance()); - } - #[test] fn convert_xor_distance_to_u128() { - let node_id1 = NodeId::try_from( - [ - 144, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159, 211, 53, 207, 0, 157, 5, 55, 235, 247, 160, - 195, 240, 48, 146, 168, 119, 15, 241, 54, - ] - .as_bytes(), - ) - .unwrap(); - let node_id2 = NodeId::try_from( - [ - 186, 43, 62, 14, 60, 214, 9, 180, 145, 122, 
55, 160, 83, 83, 45, 185, 219, 206, 226, 128, 5, 26, 20, 0, - 192, 121, 216, 178, 134, 212, 51, 131, - ] - .as_bytes(), - ) - .unwrap(); - let node_id3 = NodeId::try_from( - [ - 60, 32, 246, 39, 108, 201, 214, 91, 30, 230, 3, 126, 31, 46, 66, 203, 27, 51, 240, 177, 230, 22, 118, - 102, 201, 55, 211, 147, 229, 26, 116, 103, - ] - .as_bytes(), - ) - .unwrap(); - let n1_to_n2_dist = node_id1.distance(&node_id2); - let n1_to_n3_dist = node_id1.distance(&node_id3); - assert!(n1_to_n2_dist < n1_to_n3_dist); - let n12_distance = u128::try_from(n1_to_n2_dist).unwrap(); - let n13_distance = u128::try_from(n1_to_n3_dist).unwrap(); + let node_id1 = NodeId::try_from(&[128, 28, 106, 112, 220, 197, 216, 119, 9, 128, 42, 77, 55][..]).unwrap(); + let node_id2 = NodeId::try_from(&[160, 28, 106, 112, 220, 197, 216, 119, 9, 128, 42, 77, 54][..]).unwrap(); + let node_id3 = NodeId::try_from(&[64, 28, 106, 112, 220, 197, 216, 119, 9, 128, 42, 77, 54][..]).unwrap(); + let n12_distance = node_id1.distance(&node_id2); + let n13_distance = node_id1.distance(&node_id3); + assert_eq!(n12_distance.to_bytes()[..4], [0, 0, 0, 32]); + assert_eq!(n13_distance.to_bytes()[..4], [0, 0, 0, 192]); assert!(n12_distance < n13_distance); - assert_eq!(n12_distance, 56114865924689668092413877285545836544); - assert_eq!(n13_distance, 228941924089749863963604860508980641792); + assert_eq!(n12_distance.as_u128(), ((128 ^ 160) << (12 * 8)) + 1); + assert_eq!(n13_distance.as_u128(), ((128 ^ 64) << (12 * 8)) + 1); } } diff --git a/comms/src/peer_manager/peer.rs b/comms/src/peer_manager/peer.rs index 880c06f314..669f9550d0 100644 --- a/comms/src/peer_manager/peer.rs +++ b/comms/src/peer_manager/peer.rs @@ -34,7 +34,7 @@ use crate::{ utils::datetime::safe_future_datetime_from_duration, }; use bitflags::bitflags; -use chrono::{DateTime, NaiveDateTime, Utc}; +use chrono::{NaiveDateTime, Utc}; use multiaddr::Multiaddr; use serde::{Deserialize, Serialize}; use std::{ @@ -78,6 +78,7 @@ pub struct Peer { pub banned_until: Option, pub banned_reason: String, pub offline_at: Option, + pub last_seen: Option, /// Features supported by the peer pub features: PeerFeatures, /// Connection statics for the peer @@ -118,6 +119,7 @@ impl Peer { banned_until: None, banned_reason: String::new(), offline_at: None, + last_seen: None, connection_stats: Default::default(), added_at: Utc::now().naive_utc(), supported_protocols, @@ -218,8 +220,14 @@ impl Peer { } /// Provides that date time of the last successful interaction with the peer - pub fn last_seen(&self) -> Option> { - self.addresses.last_seen() + pub fn last_seen(&self) -> Option { + self.last_seen + } + + /// Provides the length of time since the last successful interaction with the peer + pub fn last_seen_since(&self) -> Option { + self.last_seen() + .and_then(|dt| Utc::now().naive_utc().signed_duration_since(dt).to_std().ok()) } /// Returns true if this peer has the given feature, otherwise false diff --git a/comms/src/peer_manager/peer_query.rs b/comms/src/peer_manager/peer_query.rs index 62c739d522..f104ae2207 100644 --- a/comms/src/peer_manager/peer_query.rs +++ b/comms/src/peer_manager/peer_query.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
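The `last_seen_since` helper added to `Peer` above is a thin chrono-to-std conversion, and the `to_std().ok()` step quietly yields `None` for timestamps in the future (for example, clock skew). A standalone sketch under that reading; the free function is illustrative, the real method lives on `Peer`:

```rust
use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc};

/// Sketch of `Peer::last_seen_since`: a chrono signed duration converted to
/// `std::time::Duration`; `to_std()` fails on negative (future) durations.
fn last_seen_since(last_seen: Option<NaiveDateTime>) -> Option<std::time::Duration> {
    last_seen.and_then(|dt| Utc::now().naive_utc().signed_duration_since(dt).to_std().ok())
}

fn main() {
    let five_minutes_ago = Utc::now().naive_utc() - ChronoDuration::minutes(5);
    let since = last_seen_since(Some(five_minutes_ago)).expect("past timestamps convert");
    assert!(since.as_secs() >= 5 * 60 - 1);
    assert_eq!(last_seen_since(None), None);
}
```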
 use crate::peer_manager::{peer_id::PeerId, NodeId, Peer, PeerManagerError};
-use std::cmp::min;
+use std::cmp::{min, Ordering};
 use tari_storage::{IterationResult, KeyValueStore};

 type Predicate<'a, A> = Box<dyn FnMut(&A) -> bool + Send + 'a>;
@@ -33,6 +33,10 @@ pub enum PeerQuerySortBy<'a> {
     None,
     /// Sort by distance from a given node id
     DistanceFrom(&'a NodeId),
+    /// Sort by last connected
+    LastConnected,
+    /// Sort by distance from a given node followed by last connected
+    DistanceFromLastConnected(&'a NodeId),
 }

 impl Default for PeerQuerySortBy<'_> {
@@ -47,7 +51,6 @@ pub struct PeerQuery<'a> {
     select_predicate: Option<Predicate<'a, Peer>>,
     limit: Option<usize>,
     sort_by: PeerQuerySortBy<'a>,
-    until_predicate: Option<Predicate<'a, [Peer]>>,
 }

 impl<'a> PeerQuery<'a> {
@@ -76,12 +79,6 @@ impl<'a> PeerQuery<'a> {
         self
     }

-    pub fn until<F>(mut self, until_predicate: F) -> Self
-    where F: FnMut(&[Peer]) -> bool + Send + 'a {
-        self.until_predicate = Some(Box::new(until_predicate));
-        self
-    }
-
     /// Returns a `PeerQueryExecutor` with this `PeerQuery`
     pub(super) fn executor<DS>(self, store: &DS) -> PeerQueryExecutor<'a, '_, DS>
     where DS: KeyValueStore<PeerId, Peer> {
@@ -102,14 +99,6 @@ impl<'a> PeerQuery<'a> {
             .map(|predicate| (predicate)(peer))
             .unwrap_or(true)
     }
-
-    /// Returns true if the result collector should stop early, otherwise false
-    fn should_stop(&mut self, peers: &[Peer]) -> bool {
-        self.until_predicate
-            .as_mut()
-            .map(|predicate| (predicate)(peers))
-            .unwrap_or(false)
-    }
 }

 /// This struct executes the query using the given store
@@ -127,19 +116,41 @@ where DS: KeyValueStore<PeerId, Peer>
     pub fn get_results(&mut self) -> Result<Vec<Peer>, PeerManagerError> {
         match self.query.sort_by {
-            PeerQuerySortBy::None => self.get_query_results(),
+            PeerQuerySortBy::None => self.get_unsorted_results(),
             PeerQuerySortBy::DistanceFrom(node_id) => self.get_distance_sorted_results(node_id),
+            PeerQuerySortBy::LastConnected => self.get_last_connected_sorted_results(),
+            PeerQuerySortBy::DistanceFromLastConnected(node_id) => {
+                self.get_distance_then_last_connected_results(node_id)
+            },
         }
     }

+    pub fn get_last_connected_sorted_results(&mut self) -> Result<Vec<Peer>, PeerManagerError> {
+        self.get_sorted_results(last_seen_compare_desc)
+    }
+
     pub fn get_distance_sorted_results(&mut self, node_id: &NodeId) -> Result<Vec<Peer>, PeerManagerError> {
-        let mut peer_keys = Vec::new();
-        let mut distances = Vec::new();
+        self.get_sorted_results(|a, b| {
+            let a = a.node_id.distance(node_id);
+            let b = b.node_id.distance(node_id);
+            // Sort ascending
+            a.cmp(&b)
+        })
+    }
+
+    fn get_distance_then_last_connected_results(&mut self, node_id: &NodeId) -> Result<Vec<Peer>, PeerManagerError> {
+        let mut peers = self.get_distance_sorted_results(node_id)?;
+        peers.sort_by(last_seen_compare_desc);
+        Ok(peers)
+    }
+
+    fn get_sorted_results<F>(&mut self, compare: F) -> Result<Vec<Peer>, PeerManagerError>
+    where F: FnMut(&Peer, &Peer) -> Ordering {
+        let mut selected_peers = Vec::new();
         self.store
-            .for_each_ok(|(peer_key, peer)| {
+            .for_each_ok(|(_, peer)| {
                 if self.query.is_selected(&peer) {
-                    peer_keys.push(peer_key);
-                    distances.push(node_id.distance(&peer.node_id));
+                    selected_peers.push(peer);
                 }

                 IterationResult::Continue
@@ -150,46 +161,24 @@ where DS: KeyValueStore<PeerId, Peer>
         let max_available = self
             .query
             .limit
-            .map(|limit| min(peer_keys.len(), limit))
-            .unwrap_or_else(|| peer_keys.len());
+            .map(|limit| min(selected_peers.len(), limit))
+            .unwrap_or_else(|| selected_peers.len());
         if max_available == 0 {
             return Ok(Vec::new());
         }

-        // Perform partial sort of elements only up to N elements
-        let mut selected_peers = Vec::with_capacity(max_available);
-        for i in 0..max_available {
-            for j in (i + 1)..peer_keys.len() {
-                if distances[i] > distances[j] {
-                    distances.swap(i, j);
-                    peer_keys.swap(i, j);
-                }
-            }
-            let peer = self
-                .store
-                .get(&peer_keys[i])
-                .map_err(PeerManagerError::DatabaseError)?
-                .ok_or(PeerManagerError::PeerNotFoundError)?;
-
-            selected_peers.push(peer);
-
-            if self.query.should_stop(&selected_peers) {
-                break;
-            }
-        }
+        selected_peers.sort_by(compare);
+        selected_peers.truncate(max_available);

         Ok(selected_peers)
     }

-    pub fn get_query_results(&mut self) -> Result<Vec<Peer>, PeerManagerError> {
-        let mut selected_peers = match self.query.limit {
-            Some(n) => Vec::with_capacity(n),
-            None => Vec::new(),
-        };
+    pub fn get_unsorted_results(&mut self) -> Result<Vec<Peer>, PeerManagerError> {
+        let mut selected_peers = self.query.limit.map(Vec::with_capacity).unwrap_or_default();

         self.store
             .for_each_ok(|(_, peer)| {
-                if self.query.within_limit(selected_peers.len()) && !self.query.should_stop(&selected_peers) {
+                if self.query.within_limit(selected_peers.len()) {
                     if self.query.is_selected(&peer) {
                         selected_peers.push(peer);
                     }
@@ -205,6 +194,17 @@ where DS: KeyValueStore<PeerId, Peer>
     }
 }

+fn last_seen_compare_desc(a: &Peer, b: &Peer) -> Ordering {
+    match (a.last_seen(), b.last_seen()) {
+        // Sort descending
+        (Some(a), Some(b)) => b.cmp(&a),
+        // Nones go to the end
+        (None, Some(_)) => Ordering::Greater,
+        (Some(_), None) => Ordering::Less,
+        (None, None) => Ordering::Equal,
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -328,45 +328,6 @@ mod test {
         assert!(peers.iter().all(|peer| !peer.is_banned()));
     }

-    #[test]
-    fn select_where_until_query() {
-        // Create peer manager with random peers
-        let mut sample_peers = Vec::new();
-        // Create 20 peers were the 1st and last one is bad
-        let _rng = rand::rngs::OsRng;
-        sample_peers.push(create_test_peer(true));
-        let db = HashmapDatabase::new();
-        let mut id_counter = 0;
-
-        repeat_with(|| create_test_peer(true)).take(3).for_each(|peer| {
-            db.insert(id_counter, peer).unwrap();
-            id_counter += 1;
-        });
-
-        repeat_with(|| create_test_peer(false)).take(5).for_each(|peer| {
-            db.insert(id_counter, peer).unwrap();
-            id_counter += 1;
-        });
-
-        let peers = PeerQuery::new()
-            .select_where(|peer| !peer.is_banned())
-            .until(|peers| peers.len() == 2)
-            .executor(&db)
-            .get_results()
-            .unwrap();
-
-        assert_eq!(peers.len(), 2);
-        assert!(peers.iter().all(|peer| !peer.is_banned()));
-
-        let peers = PeerQuery::new()
-            .until(|peers| peers.len() == 100)
-            .executor(&db)
-            .get_results()
-            .unwrap();
-
-        assert_eq!(peers.len(), 8);
-    }
-
     #[test]
     fn sort_by_query() {
         // Create peer manager with random peers
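As a usage illustration only (not part of the change): the new sort modes plug into the existing builder API as below. This is a hedged sketch that assumes module-internal visibility, since `executor` is `pub(super)`; the helper name `recently_seen_unbanned` is ours, and `DS` stands for any backing store such as the `HashmapDatabase` used in the tests.

// Sketch: up to `n` non-banned peers, most recently seen first.
fn recently_seen_unbanned<DS>(db: &DS, n: usize) -> Result<Vec<Peer>, PeerManagerError>
where DS: KeyValueStore<PeerId, Peer> {
    PeerQuery::new()
        .select_where(|peer| !peer.is_banned())
        .sort_by(PeerQuerySortBy::LastConnected)
        .limit(n)
        .executor(db)
        .get_results()
}
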
diff --git a/comms/src/peer_manager/peer_storage.rs b/comms/src/peer_manager/peer_storage.rs
index c566f53462..5fd4d1658e 100644
--- a/comms/src/peer_manager/peer_storage.rs
+++ b/comms/src/peer_manager/peer_storage.rs
@@ -22,16 +22,19 @@
 use crate::{
     peer_manager::{
-        node_id::{NodeDistance, NodeId},
         peer::{Peer, PeerFlags},
         peer_id::{generate_peer_key, PeerId},
+        NodeDistance,
+        NodeId,
         PeerFeatures,
         PeerManagerError,
         PeerQuery,
+        PeerQuerySortBy,
     },
     protocol::ProtocolId,
     types::{CommsDatabase, CommsPublicKey},
 };
+use chrono::Utc;
 use log::*;
 use multiaddr::Multiaddr;
 use rand::{rngs::OsRng, seq::SliceRandom};
@@ -220,7 +223,7 @@ where DS: KeyValueStore<PeerId, Peer>
     }

     pub fn find_all_starts_with(&self, partial: &[u8]) -> Result<Vec<Peer>, PeerManagerError> {
-        if partial.is_empty() || partial.len() > NodeId::BYTE_SIZE {
+        if partial.is_empty() || partial.len() > NodeId::byte_size() {
             return Ok(Vec::new());
         }

@@ -339,25 +342,17 @@ where DS: KeyValueStore<PeerId, Peer>
             return Ok(Vec::new());
         }

-        let mut distances = Vec::new();
-        self.peer_db
-            .for_each_ok(|(_, peer)| {
-                if features.map(|f| peer.features == f).unwrap_or(true) &&
+        let query = PeerQuery::new()
+            .select_where(|peer| {
+                features.map(|f| peer.features == f).unwrap_or(true) &&
                     !peer.is_banned() &&
                     !peer.is_offline() &&
                     !excluded_peers.contains(&peer.node_id)
-                {
-                    let dist = node_id.distance(&peer.node_id);
-                    distances.push((peer, dist));
-                }
-                IterationResult::Continue
             })
-            .map_err(PeerManagerError::DatabaseError)?;
-
-        distances.sort_by(|(_, dist_a), (_, dist_b)| dist_a.cmp(dist_b));
-        distances.truncate(n);
+            .sort_by(PeerQuerySortBy::DistanceFrom(node_id))
+            .limit(n);

-        Ok(distances.into_iter().map(|(peer, _)| peer).collect())
+        self.perform_query(query)
     }

     /// Compile a random list of communication node peers of size _n_ that are not banned or offline
@@ -560,6 +555,16 @@ where DS: KeyValueStore<PeerId, Peer>
             .map_err(PeerManagerError::DatabaseError)?;
         Ok(result)
     }
+
+    pub fn mark_last_seen(&mut self, node_id: &NodeId) -> Result<(), PeerManagerError> {
+        let mut peer = self.find_by_node_id(node_id)?;
+        peer.last_seen = Some(Utc::now().naive_utc());
+        peer.set_offline(false);
+        self.peer_db
+            .insert(peer.id(), peer)
+            .map_err(PeerManagerError::DatabaseError)?;
+        Ok(())
+    }
 }

 #[allow(clippy::from_over_into)]
diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs
index 61c1651d7e..25d81b4778 100644
--- a/comms/src/protocol/rpc/server/mock.rs
+++ b/comms/src/protocol/rpc/server/mock.rs
@@ -61,8 +61,7 @@ use tokio::{
     sync::{mpsc, Mutex, RwLock},
     task,
 };
-use tower::Service;
-use tower_make::MakeService;
+use tower::{make::MakeService, Service};

 pub struct RpcRequestMock {
     comms_provider: RpcCommsBackend,
diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs
index ed02b82149..a50a01b7e2 100644
--- a/comms/src/protocol/rpc/server/mod.rs
+++ b/comms/src/protocol/rpc/server/mod.rs
@@ -78,8 +78,7 @@ use std::{
 };
 use tokio::{sync::mpsc, time};
 use tokio_stream::Stream;
-use tower::Service;
-use tower_make::MakeService;
+use tower::{make::MakeService, Service};
 use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level};

 const LOG_TARGET: &str = "comms::rpc";
diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs
index 342454e122..4d41db5271 100644
--- a/comms/src/protocol/rpc/server/router.rs
+++ b/comms/src/protocol/rpc/server/router.rs
@@ -51,8 +51,7 @@ use futures::{
 };
 use std::sync::Arc;
 use tokio::sync::mpsc;
-use tower::Service;
-use tower_make::MakeService;
+use tower::{make::MakeService, Service};

 /// Allows service factories of different types to be composed into a single service that resolves a given `ProtocolId`
 pub struct Router {
diff --git a/comms/src/tor/mod.rs b/comms/src/tor/mod.rs
index 01cec723d4..bc08b6285b 100644
--- a/comms/src/tor/mod.rs
+++ b/comms/src/tor/mod.rs
@@ -38,6 +38,7 @@ pub use control_client::{
     PortMapping,
     PrivateKey,
     TorClientError,
+    TorControlEvent,
     TorControlPortClient,
 };
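For reference, a hedged sketch of a call site for the new `mark_last_seen` helper added to peer_storage.rs above. The wrapper function and the `PeerStorage<DS>` type name are assumptions for illustration; only `mark_last_seen` itself comes from the diff.

// Hypothetical call site: after a connection succeeds, stamp the peer with
// Utc::now() and clear any stale offline flag via mark_last_seen.
fn on_peer_connected<DS>(storage: &mut PeerStorage<DS>, node_id: &NodeId) -> Result<(), PeerManagerError>
where DS: KeyValueStore<PeerId, Peer> {
    storage.mark_last_seen(node_id)
}
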
log_timing("few_large_messages", few_large_messages()).await; + if option_env!["COMMS_SKIP_LONG_RUNNING_STRESS_TESTS"].is_none() { + log_timing("basic", basic()).await; + log_timing("many_small_messages", many_small_messages()).await; + log_timing("few_large_messages", few_large_messages()).await; + } // log_timing("payload_limit", payload_limit()).await; // log_timing("high_contention", high_contention()).await; // log_timing("high_concurrency", high_concurrency()).await;