diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 9c7ec3684d..f178b804dc 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -24,24 +24,19 @@ mod dedup_cache; pub use dedup_cache::DedupCacheDatabase; -use crate::{actor::DhtRequester, inbound::DhtInboundMessage}; -use digest::Digest; +use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage}; use futures::{future::BoxFuture, task::Context}; use log::*; use std::task::Poll; -use tari_comms::{pipeline::PipelineError, types::Challenge}; +use tari_comms::pipeline::PipelineError; use tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; -fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { - Challenge::new().chain(&message.body).finalize().to_vec() -} - /// # DHT Deduplication middleware /// -/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates. +/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates. /// If a duplicate message is detected, it is discarded. #[derive(Clone)] pub struct DedupMiddleware { @@ -60,9 +55,9 @@ impl DedupMiddleware { } } -impl Service for DedupMiddleware +impl Service for DedupMiddleware where - S: Service + Clone + Send + 'static, + S: Service + Clone + Send + 'static, S::Future: Send, { type Error = PipelineError; @@ -73,22 +68,21 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, mut message: DecryptedDhtMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); let allowed_message_occurrences = self.allowed_message_occurrences; Box::pin(async move { - let hash = hash_inbound_message(&message); trace!( target: LOG_TARGET, "Inserting message hash {} for message {} (Trace: {})", - hash.to_hex(), + message.hash.to_hex(), message.tag, message.dht_header.message_tag ); message.dedup_hit_count = dht_requester - .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .add_message_to_dedup_cache(message.hash.clone(), message.source_peer.public_key.clone()) .await?; if message.dedup_hit_count as usize > allowed_message_occurrences { @@ -144,6 +138,7 @@ mod test { envelope::DhtMessageFlags, test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy}, }; + use tari_comms::wrap_in_envelope_body; use tari_test_utils::panic_context; use tokio::runtime::Runtime; @@ -163,13 +158,14 @@ mod test { assert!(dedup.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false); + let inbound_message = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false); + let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message); - rt.block_on(dedup.call(msg.clone())).unwrap(); + rt.block_on(dedup.call(decrypted_msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); mock_state.set_number_of_message_hits(4); - rt.block_on(dedup.call(msg)).unwrap(); + rt.block_on(dedup.call(decrypted_msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running drop(dedup); @@ -179,28 +175,29 @@ mod test { fn deterministic_hash() { const TEST_MSG: &[u8] = b"test123"; const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224"; + let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash1 = hash_inbound_message(&msg); + let decrypted1 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash2 = hash_inbound_message(&msg); + let decrypted2 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); - assert_eq!(hash1, hash2); - let subjects = &[hash1, hash2]; + assert_eq!(decrypted1.hash, decrypted2.hash); + let subjects = &[decrypted1.hash, decrypted2.hash]; assert!(subjects.iter().all(|h| h.to_hex() == EXPECTED_HASH)); } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 1b762e02bc..cae6bbbfb3 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -295,21 +295,21 @@ impl Dht { ServiceBuilder::new() .layer(MetricsLayer::new(self.metrics_collector.clone())) .layer(inbound::DeserializeLayer::new(self.peer_manager.clone())) + .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(inbound::DecryptionLayer::new( + self.config.clone(), + self.node_identity.clone(), + self.connectivity.clone(), + )) .layer(DedupLayer::new( self.dht_requester(), self.config.dedup_allowed_message_occurrences, )) - .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(MessageLoggingLayer::new(format!( "Inbound [{}]", self.node_identity.node_id().short_str() ))) - .layer(inbound::DecryptionLayer::new( - self.config.clone(), - self.node_identity.clone(), - self.connectivity.clone(), - )) - .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a2f75b755f..048b9c65e7 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::envelope::{DhtMessageFlags, DhtMessageHeader}; +use digest::Digest; use std::{ fmt, fmt::{Display, Formatter}, @@ -29,9 +30,17 @@ use std::{ use tari_comms::{ message::{EnvelopeBody, MessageTag}, peer_manager::Peer, - types::CommsPublicKey, + types::{Challenge, CommsPublicKey}, }; +fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { + Challenge::new() + .chain(&message.dht_header.origin_mac) + .chain(&message.body) + .finalize() + .to_vec() +} + #[derive(Debug, Clone)] pub struct DhtInboundMessage { pub tag: MessageTag, @@ -84,6 +93,7 @@ pub struct DecryptedDhtMessage { pub is_already_forwarded: bool, pub decryption_result: Result>, pub dedup_hit_count: u32, + pub hash: Vec, } impl DecryptedDhtMessage { @@ -104,6 +114,7 @@ impl DecryptedDhtMessage { message: DhtInboundMessage, ) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin, @@ -118,6 +129,7 @@ impl DecryptedDhtMessage { pub fn failed(message: DhtInboundMessage) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin: None, diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index cebedac101..de83924bc5 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -623,6 +623,160 @@ async fn dht_propagate_dedup() { assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } +#[tokio::test] +#[allow(non_snake_case)] +async fn dht_do_not_store_invalid_message_in_dedup() { + let mut config = dht_config(); + config.dedup_allowed_message_occurrences = 3; + + // Node C receives messages from A and B + let mut node_C = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; + + // Node B forwards a message from A but modifies it + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; + + // Node A creates a message sends it to B, B modifies it, sends it to C; Node A sends message to C + let node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config.clone(), [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + + log::info!( + "NodeA = {}, NodeB = {}, NodeC = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + node_A + .comms + .connectivity() + .dial_peer(node_B.node_identity().node_id().clone()) + .await + .unwrap(); + + node_A + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + node_B + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut node_C_messaging = node_C.messaging_events.subscribe(); + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + // Just a message to test connectivity between Node A -> Node C, and to get the header from + let out_msg = OutboundDomainMessage::new(123, Person { + name: "John Conway".into(), + age: 82, + }); + + node_A + .dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![node_C.node_identity().node_id().clone()]) + .with_destination(NodeDestination::Unknown) + .finish(), + out_msg, + ) + .await + .unwrap(); + + // Get the message that was received by Node B + let mut msg = node_B.next_inbound_message(Duration::from_secs(10)).await.unwrap(); + let bytes = msg.decryption_result.unwrap().to_encoded_bytes(); + + // Clone header without modification + let header_unmodified = msg.dht_header.clone(); + + // Modify the header + msg.dht_header.message_type = DhtMessageType::from_i32(3i32).unwrap(); + + // Forward modified message to Node C - Should get us banned + node_B + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![msg.source_peer.node_id.clone()]) + .with_destination(NodeDestination::Unknown) + .with_dht_header(msg.dht_header) + .finish(), + bytes.clone(), + ) + .await + .unwrap(); + + node_A + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![]) + .with_dht_header(header_unmodified) + .finish(), + bytes, + ) + .await + .unwrap(); + + // Node C receives the correct message from Node A + let msg = node_C + .next_inbound_message(Duration::from_secs(10)) + .await + .expect("Node C expected an inbound message but it never arrived"); + assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); + let person = msg + .decryption_result + .unwrap() + .decode_part::(1) + .unwrap() + .unwrap(); + assert_eq!(person.name, "John Conway"); + // TODO Test not working as it receives the message only if dedup_allowed_message_occurrences > 1 + + let node_A_id = node_A.node_identity().node_id().clone(); + let node_B_id = node_B.node_identity().node_id().clone(); + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; + + // Check the message flow BEFORE deduping + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); + + let received_from_a = count_messages_received(&received, &[&node_A_id]); + let received_from_b = count_messages_received(&received, &[&node_B_id]); + + assert_eq!(received_from_a, 1); + assert_eq!(received_from_b, 1); +} + #[tokio::test] #[allow(non_snake_case)] async fn dht_repropagate() {