Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(gadget-sdk)!: prevent duplicate and self-referential messages #458

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ testcontainers = { version = "0.20.1" }
symbiotic-rs = { version = "0.1.0" }
dashmap = "6.1.0"
bincode2 = "2.0.1"
lru-mem = "0.3.0"

[profile.dev.package.backtrace]
opt-level = 3
Expand Down
38 changes: 20 additions & 18 deletions blueprint-manager/src/sdk/setup.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::collections::BTreeMap;
use std::time::Duration;

use futures::stream::FuturesOrdered;
use futures::StreamExt;
use gadget_io::tokio::task::JoinHandle;
use gadget_sdk::clients::tangle::runtime::TangleRuntimeClient;
use gadget_sdk::network::Network;
use gadget_sdk::prometheus::PrometheusConfig;
use gadget_sdk::store::{ECDSAKeyStore, KeyValueStoreBackend};
use sp_core::{keccak_256, sr25519, Pair};
use std::collections::BTreeMap;
use std::time::Duration;

use crate::sdk::config::SingleGadgetConfig;
pub use gadget_io::KeystoreContainer;
Expand Down Expand Up @@ -110,28 +111,29 @@ pub async fn wait_for_connection_to_bootnodes(

debug!("Waiting for {n_required} peers to show up across {n_networks} networks");

let mut tasks = gadget_io::tokio::task::JoinSet::new();
let mut tasks = FuturesOrdered::new();

// For each network, we start a task that checks if we have enough peers connected
// and then we wait for all of them to finish.

let wait_for_peers = |handle: GossipHandle, n_required| async move {
'inner: loop {
let n_connected = handle.connected_peers();
if n_connected >= n_required {
break 'inner;
}
let topic = handle.topic();
debug!("`{topic}`: We currently have {n_connected}/{n_required} peers connected to network");
gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await;
}
};

for handle in handles.values() {
tasks.spawn(wait_for_peers(handle.clone(), n_required));
tasks.push_back(wait_for_peers(handle, n_required));
}

// Wait for all tasks to finish
while tasks.join_next().await.is_some() {}
tasks.collect::<()>().await;

Ok(())
}

async fn wait_for_peers(handle: &GossipHandle, required: usize) {
loop {
let n_connected = handle.connected_peers();
if n_connected >= required {
return;
}
let topic = handle.topic();
debug!("`{topic}`: We currently have {n_connected}/{required} peers connected to network");
gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await;
}
}
6 changes: 4 additions & 2 deletions blueprint-test-utils/src/test_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,12 @@ pub async fn new_test_ext_blueprint_manager<
pub fn find_open_tcp_bind_port() -> u16 {
let listener = std::net::TcpListener::bind(format!("{LOCAL_BIND_ADDR}:0"))
.expect("Should bind to localhost");
listener
let port = listener
.local_addr()
.expect("Should have a local address")
.port()
.port();
drop(listener);
port
}

pub struct LocalhostTestExt {
Expand Down
Submodule forge-std updated 1 files
+1 −1 package.json
2 changes: 1 addition & 1 deletion blueprints/incredible-squaring/contracts/lib/forge-std
Submodule forge-std updated 1 files
+1 −1 package.json
3 changes: 2 additions & 1 deletion sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ url = { workspace = true, features = ["serde"] }
uuid = { workspace = true }
failure = { workspace = true }
num-bigint = { workspace = true }

# Keystore deps
ed25519-zebra = { workspace = true }
k256 = { workspace = true, features = ["ecdsa", "ecdsa-core", "arithmetic"] }
Expand Down Expand Up @@ -92,6 +91,8 @@ gadget-blueprint-proc-macro = { workspace = true, default-features = false }
gadget-context-derive = { workspace = true, default-features = false }
gadget-blueprint-proc-macro-core = { workspace = true, default-features = false }

lru-mem = { workspace = true }

# Benchmarking deps
sysinfo = { workspace = true }
dashmap = { workspace = true }
Expand Down
54 changes: 36 additions & 18 deletions sdk/src/network/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
clippy::module_name_repetitions,
clippy::exhaustive_enums
)]
use crate::error::Error;
use crate::{error, trace, warn};
use async_trait::async_trait;
use ecdsa::Public;
use gadget_io::tokio::sync::mpsc::UnboundedSender;
Expand All @@ -13,15 +15,13 @@ use libp2p::kad::store::MemoryStore;
use libp2p::{
gossipsub, mdns, request_response, swarm::NetworkBehaviour, swarm::SwarmEvent, PeerId,
};
use lru_mem::LruCache;
use serde::{Deserialize, Serialize};
use sp_core::ecdsa;
use sp_core::{ecdsa, sha2_256};
use std::collections::BTreeMap;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;

use crate::error::Error;
use crate::{error, trace, warn};

use super::{Network, ParticipantInfo, ProtocolMessage};

/// Maximum allowed size for a Signed Message.
Expand All @@ -48,6 +48,7 @@ pub struct NetworkServiceWithoutSwarm<'a> {
pub ecdsa_peer_id_to_libp2p_id: Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub ecdsa_key: &'a ecdsa::Pair,
pub span: tracing::Span,
pub my_id: PeerId,
}

impl<'a> NetworkServiceWithoutSwarm<'a> {
Expand All @@ -61,6 +62,7 @@ impl<'a> NetworkServiceWithoutSwarm<'a> {
ecdsa_peer_id_to_libp2p_id: &self.ecdsa_peer_id_to_libp2p_id,
ecdsa_key: self.ecdsa_key,
span: &self.span,
my_id: self.my_id,
}
}
}
Expand All @@ -71,6 +73,7 @@ pub struct NetworkService<'a> {
pub ecdsa_peer_id_to_libp2p_id: &'a Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub ecdsa_key: &'a ecdsa::Pair,
pub span: &'a tracing::Span,
pub my_id: PeerId,
}

impl NetworkService<'_> {
Expand Down Expand Up @@ -247,13 +250,14 @@ impl NetworkService<'_> {
}
}

#[derive(Clone)]
pub struct GossipHandle {
pub topic: IdentTopic,
pub tx_to_outbound: UnboundedSender<IntraNodePayload>,
pub rx_from_inbound: Arc<Mutex<gadget_io::tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>>>,
pub connected_peers: Arc<AtomicU32>,
pub ecdsa_peer_id_to_libp2p_id: Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub recent_messages: parking_lot::Mutex<LruCache<[u8; 32], ()>>,
pub my_id: PeerId,
}

impl GossipHandle {
Expand Down Expand Up @@ -338,18 +342,29 @@ enum MessageType {
#[async_trait]
impl Network for GossipHandle {
async fn next_message(&self) -> Option<ProtocolMessage> {
let mut lock = self
.rx_from_inbound
.try_lock()
.expect("There should be only a single caller for `next_message`");
loop {
let mut lock = self
.rx_from_inbound
.try_lock()
.expect("There should be only a single caller for `next_message`");

let message = lock.recv().await?;
match bincode::deserialize(&message) {
Ok(message) => Some(message),
Err(e) => {
error!("Failed to deserialize message: {e}");
drop(lock);
Network::next_message(self).await
let message_bytes = lock.recv().await?;
drop(lock);
match bincode::deserialize::<ProtocolMessage>(&message_bytes) {
Ok(message) => {
let hash = sha2_256(&message.payload);
let mut map = self.recent_messages.lock();
if map
.insert(hash, ())
.expect("Should not exceed memory limit (rx)")
.is_none()
{
return Some(message);
}
}
Err(e) => {
error!("Failed to deserialize message: {e}");
}
}
}
}
Expand Down Expand Up @@ -377,14 +392,17 @@ impl Network for GossipHandle {
MessageType::Broadcast
};

let raw_payload = bincode::serialize(&message).map_err(|e| Error::Network {
reason: format!("Failed to serialize message: {e}"),
})?;
let payload_inner = match message_type {
MessageType::Broadcast => GossipOrRequestResponse::Gossip(GossipMessage {
topic: self.topic.to_string(),
raw_payload: bincode::serialize(&message).expect("Should serialize"),
raw_payload,
}),
MessageType::P2P(_) => GossipOrRequestResponse::Request(MyBehaviourRequest::Message {
topic: self.topic.to_string(),
raw_payload: bincode::serialize(&message).expect("Should serialize"),
raw_payload,
}),
};

Expand Down
6 changes: 6 additions & 0 deletions sdk/src/network/handlers/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl NetworkService<'_> {
error!("Got message from unknown peer");
return;
};

// Reject messages from self
if origin == self.my_id {
return;
}

trace!("Got message from peer: {origin}");
match bincode::deserialize::<GossipMessage>(&message.data) {
Ok(GossipMessage { topic, raw_payload }) => {
Expand Down
5 changes: 5 additions & 0 deletions sdk/src/network/handlers/p2p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ impl NetworkService<'_> {
)
}
Message { topic, raw_payload } => {
// Reject messages from self
if peer == self.my_id {
return;
}

let topic = IdentTopic::new(topic);
if let Some((_, tx, _)) = self
.inbound_mapping
Expand Down
8 changes: 4 additions & 4 deletions sdk/src/network/messaging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub trait MessageMetadata {
}

#[async_trait]
pub trait Network {
pub trait MessagingNetwork {
type Message: MessageMetadata + Send + Sync + 'static;

async fn next_message(&self) -> Option<Payload<Self::Message>>;
Expand Down Expand Up @@ -133,7 +133,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
backend: Arc<B>,
local_delivery: Arc<L>,
Expand All @@ -147,7 +147,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
Expand All @@ -165,7 +165,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
pub fn new(backend: B, local_delivery: L, network: N) -> Self {
let this = Self {
Expand Down
Loading
Loading