Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Validate before dedup #3450

Closed
43 changes: 20 additions & 23 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,19 @@ mod dedup_cache;

pub use dedup_cache::DedupCacheDatabase;

use crate::{actor::DhtRequester, inbound::DhtInboundMessage};
use digest::Digest;
use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage};
use futures::{future::BoxFuture, task::Context};
use log::*;
use std::task::Poll;
use tari_comms::{pipeline::PipelineError, types::Challenge};
use tari_comms::pipeline::PipelineError;
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};

const LOG_TARGET: &str = "comms::dht::dedup";

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new().chain(&message.body).finalize().to_vec()
}

/// # DHT Deduplication middleware
///
/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates.
/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates.
/// If a duplicate message is detected, it is discarded.
#[derive(Clone)]
pub struct DedupMiddleware<S> {
Expand All @@ -60,9 +55,9 @@ impl<S> DedupMiddleware<S> {
}
}

impl<S> Service<DhtInboundMessage> for DedupMiddleware<S>
impl<S> Service<DecryptedDhtMessage> for DedupMiddleware<S>
where
S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S::Future: Send,
{
type Error = PipelineError;
Expand All @@ -73,22 +68,21 @@ where
Poll::Ready(Ok(()))
}

fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future {
fn call(&mut self, mut message: DecryptedDhtMessage) -> Self::Future {
let next_service = self.next_service.clone();
let mut dht_requester = self.dht_requester.clone();
let allowed_message_occurrences = self.allowed_message_occurrences;
Box::pin(async move {
let hash = hash_inbound_message(&message);
trace!(
target: LOG_TARGET,
"Inserting message hash {} for message {} (Trace: {})",
hash.to_hex(),
message.hash.to_hex(),
message.tag,
message.dht_header.message_tag
);

message.dedup_hit_count = dht_requester
.add_message_to_dedup_cache(hash, message.source_peer.public_key.clone())
.add_message_to_dedup_cache(message.hash.clone(), message.source_peer.public_key.clone())
.await?;

if message.dedup_hit_count as usize > allowed_message_occurrences {
Expand Down Expand Up @@ -144,6 +138,7 @@ mod test {
envelope::DhtMessageFlags,
test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy},
};
use tari_comms::wrap_in_envelope_body;
use tari_test_utils::panic_context;
use tokio::runtime::Runtime;

Expand All @@ -163,13 +158,14 @@ mod test {

assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false);
let inbound_message = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false);
let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message);

rt.block_on(dedup.call(msg.clone())).unwrap();
rt.block_on(dedup.call(decrypted_msg.clone())).unwrap();
assert_eq!(spy.call_count(), 1);

mock_state.set_number_of_message_hits(4);
rt.block_on(dedup.call(msg)).unwrap();
rt.block_on(dedup.call(decrypted_msg)).unwrap();
assert_eq!(spy.call_count(), 1);
// Drop dedup so that the DhtMock will stop running
drop(dedup);
Expand All @@ -179,28 +175,29 @@ mod test {
fn deterministic_hash() {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224";

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash1 = hash_inbound_message(&msg);
let decrypted1 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message);

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash2 = hash_inbound_message(&msg);
let decrypted2 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message);

assert_eq!(hash1, hash2);
let subjects = &[hash1, hash2];
assert_eq!(decrypted1.hash, decrypted2.hash);
let subjects = &[decrypted1.hash, decrypted2.hash];
assert!(subjects.iter().all(|h| h.to_hex() == EXPECTED_HASH));
}
}
14 changes: 7 additions & 7 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,21 +295,21 @@ impl Dht {
ServiceBuilder::new()
.layer(MetricsLayer::new(self.metrics_collector.clone()))
.layer(inbound::DeserializeLayer::new(self.peer_manager.clone()))
.layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
.layer(inbound::DecryptionLayer::new(
self.config.clone(),
self.node_identity.clone(),
self.connectivity.clone(),
))
.layer(DedupLayer::new(
self.dht_requester(),
self.config.dedup_allowed_message_occurrences,
))
.layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
.layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
.layer(MessageLoggingLayer::new(format!(
"Inbound [{}]",
self.node_identity.node_id().short_str()
)))
.layer(inbound::DecryptionLayer::new(
self.config.clone(),
self.node_identity.clone(),
self.connectivity.clone(),
))
.layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
.layer(store_forward::StoreLayer::new(
self.config.clone(),
Arc::clone(&self.peer_manager),
Expand Down
14 changes: 13 additions & 1 deletion comms/dht/src/inbound/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::envelope::{DhtMessageFlags, DhtMessageHeader};
use digest::Digest;
use std::{
fmt,
fmt::{Display, Formatter},
Expand All @@ -29,9 +30,17 @@ use std::{
use tari_comms::{
message::{EnvelopeBody, MessageTag},
peer_manager::Peer,
types::CommsPublicKey,
types::{Challenge, CommsPublicKey},
};

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new()
.chain(&message.dht_header.origin_mac)
.chain(&message.body)
.finalize()
.to_vec()
}

#[derive(Debug, Clone)]
pub struct DhtInboundMessage {
pub tag: MessageTag,
Expand Down Expand Up @@ -84,6 +93,7 @@ pub struct DecryptedDhtMessage {
pub is_already_forwarded: bool,
pub decryption_result: Result<EnvelopeBody, Vec<u8>>,
pub dedup_hit_count: u32,
pub hash: Vec<u8>,
}

impl DecryptedDhtMessage {
Expand All @@ -104,6 +114,7 @@ impl DecryptedDhtMessage {
message: DhtInboundMessage,
) -> Self {
Self {
hash: hash_inbound_message(&message),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin,
Expand All @@ -118,6 +129,7 @@ impl DecryptedDhtMessage {

pub fn failed(message: DhtInboundMessage) -> Self {
Self {
hash: hash_inbound_message(&message),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin: None,
Expand Down
154 changes: 154 additions & 0 deletions comms/dht/tests/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,160 @@ async fn dht_propagate_dedup() {
assert_eq!(count_messages_received(&received, &[&node_C_id]), 1);
}

#[tokio::test]
#[allow(non_snake_case)]
async fn dht_do_not_store_invalid_message_in_dedup() {
let mut config = dht_config();
config.dedup_allowed_message_occurrences = 3;

// Node C receives messages from A and B
let mut node_C = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await;

// Node B forwards a message from A but modifies it
let mut node_B = make_node(
"node_B",
PeerFeatures::COMMUNICATION_NODE,
config.clone(),
Some(node_C.to_peer()),
)
.await;

// Node A creates a message sends it to B, B modifies it, sends it to C; Node A sends message to C
let node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config.clone(), [
node_B.to_peer(),
node_C.to_peer(),
])
.await;

log::info!(
"NodeA = {}, NodeB = {}, NodeC = {}",
node_A.node_identity().node_id().short_str(),
node_B.node_identity().node_id().short_str(),
node_C.node_identity().node_id().short_str(),
);

// Connect the peers that should be connected
node_A
.comms
.connectivity()
.dial_peer(node_B.node_identity().node_id().clone())
.await
.unwrap();

node_A
.comms
.connectivity()
.dial_peer(node_C.node_identity().node_id().clone())
.await
.unwrap();

node_B
.comms
.connectivity()
.dial_peer(node_C.node_identity().node_id().clone())
.await
.unwrap();

let mut node_C_messaging = node_C.messaging_events.subscribe();

#[derive(Clone, PartialEq, ::prost::Message)]
struct Person {
#[prost(string, tag = "1")]
name: String,
#[prost(uint32, tag = "2")]
age: u32,
}

// Just a message to test connectivity between Node A -> Node C, and to get the header from
let out_msg = OutboundDomainMessage::new(123, Person {
name: "John Conway".into(),
age: 82,
});

node_A
.dht
.outbound_requester()
.send_message(
SendMessageParams::new()
.propagate(NodeDestination::Unknown, vec![node_C.node_identity().node_id().clone()])
.with_destination(NodeDestination::Unknown)
.finish(),
out_msg,
)
.await
.unwrap();

// Get the message that was received by Node B
let mut msg = node_B.next_inbound_message(Duration::from_secs(10)).await.unwrap();
let bytes = msg.decryption_result.unwrap().to_encoded_bytes();

// Clone header without modification
let header_unmodified = msg.dht_header.clone();

// Modify the header
msg.dht_header.message_type = DhtMessageType::from_i32(3i32).unwrap();

// Forward modified message to Node C - Should get us banned
node_B
.dht
.outbound_requester()
.send_raw(
SendMessageParams::new()
.propagate(NodeDestination::Unknown, vec![msg.source_peer.node_id.clone()])
.with_destination(NodeDestination::Unknown)
.with_dht_header(msg.dht_header)
.finish(),
bytes.clone(),
)
.await
.unwrap();

node_A
.dht
.outbound_requester()
.send_raw(
SendMessageParams::new()
.propagate(NodeDestination::Unknown, vec![])
.with_dht_header(header_unmodified)
.finish(),
bytes,
)
.await
.unwrap();

// Node C receives the correct message from Node A
let msg = node_C
.next_inbound_message(Duration::from_secs(10))
.await
.expect("Node C expected an inbound message but it never arrived");
assert!(msg.decryption_succeeded());
log::info!("Received message {}", msg.tag);
let person = msg
.decryption_result
.unwrap()
.decode_part::<Person>(1)
.unwrap()
.unwrap();
assert_eq!(person.name, "John Conway");
// TODO Test not working as it receives the message only if dedup_allowed_message_occurrences > 1

let node_A_id = node_A.node_identity().node_id().clone();
let node_B_id = node_B.node_identity().node_id().clone();

node_A.shutdown().await;
node_B.shutdown().await;
node_C.shutdown().await;

// Check the message flow BEFORE deduping
let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20)));

let received_from_a = count_messages_received(&received, &[&node_A_id]);
let received_from_b = count_messages_received(&received, &[&node_B_id]);

assert_eq!(received_from_a, 1);
assert_eq!(received_from_b, 1);
}

#[tokio::test]
#[allow(non_snake_case)]
async fn dht_repropagate() {
Expand Down