Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dht/encryption): greatly reduce heap allocations for encrypted messaging #4753

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion comms/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ pub mod multiaddr {
}

pub use async_trait::async_trait;
pub use bytes::{Bytes, BytesMut};
pub use bytes::{Buf, BufMut, Bytes, BytesMut};
#[cfg(feature = "rpc")]
pub use tower::make::MakeService;
13 changes: 13 additions & 0 deletions comms/core/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#[macro_use]
mod envelope;

use bytes::BytesMut;
pub use envelope::EnvelopeBody;

mod error;
Expand All @@ -52,5 +54,16 @@ pub trait MessageExt: prost::Message {
);
buf
}

/// Encodes a message into a BytesMut, allocating the buffer on the heap as necessary.
fn encode_into_bytes_mut(&self) -> BytesMut
where Self: Sized {
let mut buf = BytesMut::with_capacity(self.encoded_len());
self.encode(&mut buf).expect(
"prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. \
This buffer's capacity was set with encoded_len",
);
buf
}
}
impl<T: prost::Message> MessageExt for T {}
2 changes: 1 addition & 1 deletion comms/core/src/protocol/rpc/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl BodyBytes {
}

pub fn into_vec(self) -> Vec<u8> {
self.0.map(|bytes| bytes.to_vec()).unwrap_or_else(Vec::new)
self.0.map(|bytes| bytes.into()).unwrap_or_else(Vec::new)
}

pub fn into_bytes(self) -> Option<Bytes> {
Expand Down
1 change: 0 additions & 1 deletion comms/dht/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ tari_common_sqlite = { path = "../../common_sqlite" }

anyhow = "1.0.53"
bitflags = "1.2.0"
bytes = "0.5"
chacha20 = "0.7.1"
chacha20poly1305 = "0.9.1"
chrono = { version = "0.4.19", default-features = false }
Expand Down
299 changes: 169 additions & 130 deletions comms/dht/src/crypt.rs

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ mod test {
assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let inbound_message =
make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false).unwrap();
make_dht_inbound_message(&node_identity, &vec![], DhtMessageFlags::empty(), false, false).unwrap();
let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message);

rt.block_on(dedup.call(decrypted_msg.clone())).unwrap();
Expand All @@ -213,12 +213,12 @@ mod test {
#[test]
fn deterministic_hash() {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "d6333668f259f677703fbe4e89152ee41c7c01f6dec502befc63120246523ffe";
const EXPECTED_HASH: &str = "1c2bb1bcff443af4441b789bd1d6984bb8d7bed2c9f85e8cf4f45615fdd9e47d";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😕 Why does the hash change here? It looks like this should be a drop-in replacement?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah looks like a breaking change :) This has more to do with changing the test helper to take a prost::Message instead of raw bytes. Now the Vec is serialized into protobuf before being hashed, which matches thebehaviour of the real code.


let node_identity = make_node_identity();
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
&TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand All @@ -229,7 +229,7 @@ mod test {
let node_identity = make_node_identity();
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
&TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand Down
11 changes: 6 additions & 5 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ mod test {
let msg = wrap_in_envelope_body!(b"secret".to_vec());
let dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand Down Expand Up @@ -546,7 +546,7 @@ mod test {
// Encrypt for self
let dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
Expand Down Expand Up @@ -602,10 +602,11 @@ mod test {
let node_identity2 = make_node_identity();
let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key());
let key_message = crypt::generate_key_message(&ecdh_key);
let encrypted_bytes = crypt::encrypt(&key_message, &msg.to_encoded_bytes()).unwrap();
let mut encrypted_bytes = msg.encode_into_bytes_mut();
crypt::encrypt(&key_message, &mut encrypted_bytes).unwrap();
let dht_envelope = make_dht_envelope(
&node_identity2,
encrypted_bytes,
&encrypted_bytes.to_vec(),
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
Expand Down Expand Up @@ -667,7 +668,7 @@ mod test {
let msg = wrap_in_envelope_body!(b"secret".to_vec());
let mut dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand Down
5 changes: 2 additions & 3 deletions comms/dht/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use std::{
};

use bitflags::bitflags;
use bytes::Bytes;
use chrono::{DateTime, NaiveDateTime, Utc};
use prost_types::Timestamp;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -249,10 +248,10 @@ impl From<DhtMessageHeader> for DhtHeader {
}

impl DhtEnvelope {
pub fn new(header: DhtHeader, body: &Bytes) -> Self {
pub fn new(header: DhtHeader, body: Vec<u8>) -> Self {
Self {
header: Some(header),
body: body.to_vec(),
body,
}
}
}
Expand Down
51 changes: 23 additions & 28 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use tari_comms::{
message::EnvelopeBody,
peer_manager::NodeIdentity,
pipeline::PipelineError,
BytesMut,
};
use thiserror::Error;
use tower::{layer::Layer, Service, ServiceExt};
Expand Down Expand Up @@ -406,11 +407,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
message_body: &[u8],
) -> Result<EnvelopeBody, DecryptionError> {
let key_message = crypt::generate_key_message(shared_secret);
let decrypted =
crypt::decrypt(&key_message, message_body).map_err(DecryptionError::DecryptionFailedMalformedCipher)?;
let mut decrypted = BytesMut::from(message_body);
crypt::decrypt(&key_message, &mut decrypted).map_err(DecryptionError::DecryptionFailedMalformedCipher)?;
// Deserialization into an EnvelopeBody is done here to determine if the
// decryption produced valid bytes or not.
EnvelopeBody::decode(decrypted.as_slice())
EnvelopeBody::decode(decrypted.freeze())
.and_then(|body| {
// Check if we received a body length of zero
//
Expand Down Expand Up @@ -477,10 +478,11 @@ mod test {

use futures::{executor::block_on, future};
use tari_comms::{
message::{MessageExt, MessageTag},
message::MessageTag,
runtime,
test_utils::mocks::create_connectivity_mock,
wrap_in_envelope_body,
BytesMut,
};
use tari_test_utils::{counter_context, unpack_enum};
use tokio::time::sleep;
Expand All @@ -492,6 +494,7 @@ mod test {
test_utils::{
make_dht_header,
make_dht_inbound_message,
make_dht_inbound_message_raw,
make_keypair,
make_node_identity,
make_valid_message_signature,
Expand Down Expand Up @@ -527,14 +530,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec());
let inbound_msg = make_dht_inbound_message(
&node_identity,
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
true,
)
.unwrap();
let inbound_msg =
make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, true).unwrap();

block_on(service.call(inbound_msg)).unwrap();
let decrypted = result.lock().unwrap().take().unwrap();
Expand All @@ -560,7 +557,7 @@ mod test {
let some_other_node_identity = make_node_identity();
let inbound_msg = make_dht_inbound_message(
&some_other_node_identity,
some_secret,
&some_secret,
DhtMessageFlags::ENCRYPTED,
true,
true,
Expand Down Expand Up @@ -591,7 +588,7 @@ mod test {

let nonsense = b"Cannot Decrypt this".to_vec();
let inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true, true).unwrap();
make_dht_inbound_message_raw(&node_identity, nonsense, DhtMessageFlags::ENCRYPTED, true, true).unwrap();

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand All @@ -615,14 +612,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Secret message to nowhere".to_vec();
let inbound_msg = make_dht_inbound_message(
&node_identity,
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
false,
)
.unwrap();
let inbound_msg =
make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, false).unwrap();

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand All @@ -645,13 +636,15 @@ mod test {
let node_identity = make_node_identity();
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Secret message".to_vec();
let plain_text_msg = BytesMut::from(b"Secret message".as_slice());
let (e_secret_key, e_public_key) = make_keypair();
let shared_secret = crypt::generate_ecdh_secret(&e_secret_key, node_identity.public_key());
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let mut message = plain_text_msg.clone();
crypt::encrypt(&key_message, &mut message).unwrap();
let message = message.freeze();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand All @@ -663,7 +656,7 @@ mod test {
true,
)
.unwrap();
let envelope = DhtEnvelope::new(header.into(), &message.into());
let envelope = DhtEnvelope::new(header.into(), message.into());
let msg_tag = MessageTag::new();
let mut inbound_msg = DhtInboundMessage::new(
msg_tag,
Expand Down Expand Up @@ -706,13 +699,15 @@ mod test {
let node_identity = make_node_identity();
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Public message".to_vec();
let plain_text_msg = BytesMut::from(b"Public message".as_slice());
let (e_secret_key, e_public_key) = make_keypair();
let shared_secret = crypt::generate_ecdh_secret(&e_secret_key, node_identity.public_key());
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let mut message = plain_text_msg.clone();
crypt::encrypt(&key_message, &mut message).unwrap();
let message = message.freeze();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand All @@ -724,7 +719,7 @@ mod test {
true,
)
.unwrap();
let envelope = DhtEnvelope::new(header.into(), &message.into());
let envelope = DhtEnvelope::new(header.into(), message.into());
let msg_tag = MessageTag::new();
let mut inbound_msg = DhtInboundMessage::new(
msg_tag,
Expand Down
4 changes: 2 additions & 2 deletions comms/dht/src/inbound/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ mod test {

let dht_envelope = make_dht_envelope(
&node_identity,
b"A".to_vec(),
&b"A".to_vec(),
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand All @@ -181,7 +181,7 @@ mod test {
.unwrap();

let msg = spy.pop_request().unwrap();
assert_eq!(msg.body, b"A".to_vec());
assert_eq!(msg.body, b"A".to_vec().to_encoded_bytes());
assert_eq!(msg.dht_header, dht_envelope.header.unwrap().try_into().unwrap());
}
}
2 changes: 1 addition & 1 deletion comms/dht/src/inbound/dht_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
.with_debug_info("Propagating join message".to_string())
.with_dht_header(dht_header)
.finish(),
body.to_encoded_bytes(),
body.encode_into_bytes_mut(),
)
.await?;
}
Expand Down
20 changes: 10 additions & 10 deletions comms/dht/src/inbound/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use std::task::Poll;

use futures::{future::BoxFuture, task::Context};
use log::*;
use tari_comms::{peer_manager::Peer, pipeline::PipelineError};
use prost::bytes::BufMut;
use tari_comms::{peer_manager::Peer, pipeline::PipelineError, BytesMut};
use tari_utilities::epoch_time::EpochTime;
use tower::{layer::Layer, Service, ServiceExt};

Expand Down Expand Up @@ -204,12 +205,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
return Ok(());
}
}

let body = decryption_result
let err_body = decryption_result
.as_ref()
.err()
.cloned()
.expect("previous check that decryption failed");
.expect_err("previous check that decryption failed");
let mut body = BytesMut::with_capacity(err_body.len());
body.put(err_body.as_slice());

let excluded_peers = vec![source_peer.node_id.clone()];
let dest_node_id = dht_header.destination.to_derived_node_id();
Expand Down Expand Up @@ -259,7 +259,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
mod test {
use std::time::Duration;

use tari_comms::{runtime, runtime::task, wrap_in_envelope_body};
use tari_comms::{message::MessageExt, runtime, runtime::task, wrap_in_envelope_body};
use tokio::sync::mpsc;

use super::*;
Expand All @@ -278,7 +278,7 @@ mod test {

let node_identity = make_node_identity();
let inbound_msg =
make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false, false).unwrap();
make_dht_inbound_message(&node_identity, &b"".to_vec(), DhtMessageFlags::empty(), false, false).unwrap();
let msg = DecryptedDhtMessage::succeeded(
wrap_in_envelope_body!(Vec::new()),
Some(node_identity.public_key().clone()),
Expand All @@ -300,7 +300,7 @@ mod test {
let sample_body = b"Lorem ipsum";
let inbound_msg = make_dht_inbound_message(
&make_node_identity(),
sample_body.to_vec(),
&sample_body.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand All @@ -318,7 +318,7 @@ mod test {
let (params, body) = oms_mock_state.pop_call().await.unwrap();

// Header and body are preserved when forwarding
assert_eq!(&body.to_vec(), &sample_body);
assert_eq!(&body.to_vec(), &sample_body.to_vec().to_encoded_bytes());
assert_eq!(params.dht_header.unwrap(), header);
}
}
Loading