From de47365b4d47a51322dcd20f2b31af37297ad58e Mon Sep 17 00:00:00 2001 From: Stanimal Date: Wed, 18 Aug 2021 10:47:43 +0400 Subject: [PATCH] fix: improve p2p RPC robustness - Correctly handles edge case where the client can receive a response after the deadline + grace period has expired due to extreme latency. - Minor code simplifications - Adds integration stress tests to comms - Increases the client-side deadline grace period. This is needed because at any point we could be transferring a lot of data, causing delays, which the client must tolerate. - Minor performance optimisations (e.g. removed usage of `.split()`, which uses a BiLock) --- Cargo.lock | 1 + .../wallet/src/connectivity_service/handle.rs | 4 +- .../src/connectivity_service/service.rs | 6 +- .../tests/output_manager_service/service.rs | 4 +- comms/Cargo.toml | 1 + comms/src/builder/error.rs | 2 +- .../src/connection_manager/peer_connection.rs | 12 +- comms/src/connectivity/manager.rs | 4 +- comms/src/protocol/rpc/client.rs | 54 +++- comms/src/protocol/rpc/client_pool.rs | 22 +- comms/src/protocol/rpc/context.rs | 13 +- comms/src/protocol/rpc/error.rs | 2 + comms/src/protocol/rpc/message.rs | 4 + comms/src/protocol/rpc/server/mock.rs | 2 +- comms/src/protocol/rpc/server/mod.rs | 194 +++++++----- comms/src/protocol/rpc/status.rs | 15 +- comms/src/protocol/rpc/test/client_pool.rs | 15 +- comms/src/protocol/rpc/test/smoke.rs | 55 ++-- comms/tests/greeting_service.rs | 166 ++++++++++ comms/tests/helpers.rs | 62 ++++ comms/tests/rpc_stress.rs | 290 ++++++++++++++++++ comms/tests/substream_stress.rs | 160 ++++++++++ 22 files changed, 935 insertions(+), 153 deletions(-) create mode 100644 comms/tests/greeting_service.rs create mode 100644 comms/tests/helpers.rs create mode 100644 comms/tests/rpc_stress.rs create mode 100644 comms/tests/substream_stress.rs diff --git a/Cargo.lock b/Cargo.lock index 9cbf3d7b19..affba253d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4486,6 +4486,7 @@ dependencies = [ "serde_json", "snow", "tari_common", + "tari_comms_rpc_macros", "tari_crypto", "tari_shutdown", "tari_storage", diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index 41c9d51c5a..d596c543db 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -66,7 +66,7 @@ impl WalletConnectivityHandle { /// Obtain a BaseNodeWalletRpcClient. /// /// This can be relied on to obtain a pooled BaseNodeWalletRpcClient rpc session from a currently selected base - /// node/nodes. It will be block until this is happens. The ONLY other time it will return is if the node is + /// node/nodes. It will block until this happens. The ONLY other time it will return is if the node is /// shutting down, where it will return None. Use this function whenever no work can be done without a /// BaseNodeWalletRpcClient RPC session. pub async fn obtain_base_node_wallet_rpc_client(&mut self) -> Option> { @@ -89,7 +89,7 @@ impl WalletConnectivityHandle { /// Obtain a BaseNodeSyncRpcClient. /// /// This can be relied on to obtain a pooled BaseNodeSyncRpcClient rpc session from a currently selected base - /// node/nodes. It will be block until this is happens. The ONLY other time it will return is if the node is + /// node/nodes. It will block until this happens. The ONLY other time it will return is if the node is /// shutting down, where it will return None.
Use this function whenever no work can be done without a /// BaseNodeSyncRpcClient RPC session. pub async fn obtain_base_node_sync_rpc_client(&mut self) -> Option> { diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index 5f8a7c469b..f71c1ae598 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -242,8 +242,10 @@ impl WalletConnectivityService { conn.peer_node_id() ); self.pools = Some(ClientPoolContainer { - base_node_sync_rpc_client: conn.create_rpc_client_pool(self.config.base_node_rpc_pool_size), - base_node_wallet_rpc_client: conn.create_rpc_client_pool(self.config.base_node_rpc_pool_size), + base_node_sync_rpc_client: conn + .create_rpc_client_pool(self.config.base_node_rpc_pool_size, Default::default()), + base_node_wallet_rpc_client: conn + .create_rpc_client_pool(self.config.base_node_rpc_pool_size, Default::default()), }); self.notify_pending_requests().await?; debug!( diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index 72cd592725..415b5d8261 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -1816,14 +1816,14 @@ fn test_txo_validation_rpc_timeout() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = delay_for(Duration::from_secs(100)).fuse(); let mut failed = 0; loop { futures::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; + failed+=1; } } diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 61466e7c1c..b30b7439c8 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -48,6 +48,7 @@ anyhow = "1.0.32" [dev-dependencies] tari_test_utils = {version="^0.9", path="../infrastructure/test_utils"} +tari_comms_rpc_macros = {version="*", path="./rpc_macros"} env_logger = "0.7.0" serde_json = "1.0.39" diff --git a/comms/src/builder/error.rs b/comms/src/builder/error.rs index b213fd09f0..a5df9ed2a9 100644 --- a/comms/src/builder/error.rs +++ b/comms/src/builder/error.rs @@ -36,7 +36,7 @@ pub enum CommsBuilderError { ConnectionManagerError(#[from] ConnectionManagerError), #[error("Node identity not set. Call `with_node_identity(node_identity)` on [CommsBuilder]")] NodeIdentityNotSet, - #[error("Shutdown signa not set. Call `with_shutdown_signal(shutdown_signal)` on [CommsBuilder]")] + #[error("Shutdown signal not set. Call `with_shutdown_signal(shutdown_signal)` on [CommsBuilder]")] ShutdownSignalNotSet, #[error("The PeerStorage was not provided to the CommsBuilder. Use `with_peer_storage` to set it.")] PeerStorageNotProvided, diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 4fef839c68..3e15b50bb0 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -233,9 +233,15 @@ impl PeerConnection { /// Creates a new RpcClientPool that can be shared between tasks. The client pool will lazily establish up to /// `max_sessions` sessions and provides client session that is least used. 
#[cfg(feature = "rpc")] - pub fn create_rpc_client_pool(&self, max_sessions: usize) -> RpcClientPool - where T: RpcPoolClient + From + NamedProtocolService + Clone { - RpcClientPool::new(self.clone(), max_sessions) + pub fn create_rpc_client_pool( + &self, + max_sessions: usize, + client_config: RpcClientBuilder, + ) -> RpcClientPool + where + T: RpcPoolClient + From + NamedProtocolService + Clone, + { + RpcClientPool::new(self.clone(), max_sessions, client_config) } /// Immediately disconnects the peer connection. This can only fail if the peer connection worker diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index 9998aa5397..b692c8c5af 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -228,7 +228,7 @@ impl ConnectivityManagerActor { _ => { debug!( target: LOG_TARGET, - "No existing connection found for peer `{}`. Dialling...", + "No existing connection found for peer `{}`. Dialing...", node_id.short_str() ); if let Err(err) = self.connection_manager.send_dial_peer(node_id, reply).await { @@ -528,7 +528,7 @@ impl ConnectivityManagerActor { num_failed ); if self.peer_manager.set_offline(node_id, true).await? { - warn!( + debug!( target: LOG_TARGET, "Peer `{}` was marked as offline but was already offline.", node_id ); diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 22f62e0680..5ed6a040d8 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -50,6 +50,7 @@ use futures::{ use log::*; use prost::Message; use std::{ + convert::TryFrom, fmt, future::Future, marker::PhantomData, @@ -75,7 +76,7 @@ impl RpcClient { TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (request_tx, request_rx) = mpsc::channel(1); - let connector = ClientConnector { inner: request_tx }; + let connector = ClientConnector::new(request_tx); let (ready_tx, ready_rx) = oneshot::channel(); task::spawn(RpcClientWorker::new(config, request_rx, framed, ready_tx).run()); ready_rx @@ -227,7 +228,7 @@ impl Default for RpcClientConfig { fn default() -> Self { Self { deadline: Some(Duration::from_secs(30)), - deadline_grace_period: Duration::from_secs(10), + deadline_grace_period: Duration::from_secs(30), handshake_timeout: Duration::from_secs(30), } } @@ -239,6 +240,10 @@ pub struct ClientConnector { } impl ClientConnector { + pub(self) fn new(sender: mpsc::Sender) -> Self { + Self { inner: sender } + } + pub fn close(&mut self) { self.inner.close_channel(); } @@ -293,8 +298,8 @@ pub struct RpcClientWorker { request_rx: mpsc::Receiver, framed: CanonicalFraming, // Request ids are limited to u16::MAX because varint encoding is used over the wire and the magnitude of the value - // sent determines the byte size. A u16 will be more than enough for the purpose (currently just logging) - request_id: u16, + // sent determines the byte size. A u16 will be more than enough for the purpose + next_request_id: u16, ready_tx: Option>>, latency: Option, } @@ -312,7 +317,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send config, request_rx, framed, - request_id: 0, + next_request_id: 0, ready_tx: Some(ready_tx), latency: None, } @@ -348,7 +353,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send match req { SendRequest { request, reply } => { if let Err(err) = self.do_request_response(request, reply).await { - debug!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); + error!(target: LOG_TARGET, "Unexpected error: {}. 
Worker is terminating.", err); break; } }, @@ -433,8 +438,8 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send }, }; - match Self::convert_to_result(resp) { - Ok(resp) => { + match Self::convert_to_result(resp, request_id) { + Ok(Ok(resp)) => { // The consumer may drop the receiver before all responses are received. // We just ignore that as we still want obey the protocol and receive messages until the FIN flag or // the connection is dropped @@ -447,7 +452,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send break; } }, - Err(err) => { + Ok(Err(err)) => { debug!(target: LOG_TARGET, "Remote service returned error: {}", err); if !response_tx.is_closed() { let _ = response_tx.send(Err(err)).await; @@ -455,6 +460,14 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send response_tx.close_channel(); break; }, + Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) => { + warn!(target: LOG_TARGET, "{}", err); + // Ignore the response, this can happen when there is excessive latency. The server sends back a + // reply before the deadline but it is only received after the client has timed + // out + continue; + }, + Err(err) => return Err(err), } } @@ -462,16 +475,29 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send } fn next_request_id(&mut self) -> u16 { - let next_id = self.request_id; + let next_id = self.next_request_id; // request_id is allowed to wrap around back to 0 - self.request_id = self.request_id.checked_add(1).unwrap_or(0); + self.next_request_id = self.next_request_id.checked_add(1).unwrap_or(0); next_id } - fn convert_to_result(resp: proto::rpc::RpcResponse) -> Result, RpcStatus> { + fn convert_to_result( + resp: proto::rpc::RpcResponse, + request_id: u16, + ) -> Result, RpcStatus>, RpcError> { + let resp_id = u16::try_from(resp.request_id) + .map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?; + + if resp_id != request_id { + return Err(RpcError::ResponseIdDidNotMatchRequest { + expected: request_id, + actual: resp.request_id as u16, + }); + } + let status = RpcStatus::from(&resp); if !status.is_ok() { - return Err(status); + return Ok(Err(status)); } let resp = Response { @@ -479,7 +505,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send message: resp.message.into(), }; - Ok(resp) + Ok(Ok(resp)) } } diff --git a/comms/src/protocol/rpc/client_pool.rs b/comms/src/protocol/rpc/client_pool.rs index 1a89fe3328..6829b41265 100644 --- a/comms/src/protocol/rpc/client_pool.rs +++ b/comms/src/protocol/rpc/client_pool.rs @@ -22,7 +22,14 @@ use crate::{ peer_manager::NodeId, - protocol::rpc::{error::HandshakeRejectReason, NamedProtocolService, RpcClient, RpcError, RpcHandshakeError}, + protocol::rpc::{ + error::HandshakeRejectReason, + NamedProtocolService, + RpcClient, + RpcClientBuilder, + RpcError, + RpcHandshakeError, + }, PeerConnection, }; use log::*; @@ -43,8 +50,8 @@ impl RpcClientPool where T: RpcPoolClient + From + NamedProtocolService + Clone { /// Create a new RpcClientPool. Panics if passed a pool_size of 0. 
- pub(crate) fn new(peer_connection: PeerConnection, pool_size: usize) -> Self { - let pool = LazyPool::new(peer_connection, pool_size); + pub(crate) fn new(peer_connection: PeerConnection, pool_size: usize, client_config: RpcClientBuilder) -> Self { + let pool = LazyPool::new(peer_connection, pool_size, client_config); Self { pool: Arc::new(Mutex::new(pool)), } @@ -60,16 +67,18 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone pub(super) struct LazyPool { connection: PeerConnection, clients: Vec>, + client_config: RpcClientBuilder, } impl LazyPool where T: RpcPoolClient + From + NamedProtocolService + Clone { - pub fn new(connection: PeerConnection, capacity: usize) -> Self { + pub fn new(connection: PeerConnection, capacity: usize, client_config: RpcClientBuilder) -> Self { assert!(capacity > 0, "Pool capacity of 0 is invalid"); Self { connection, clients: Vec::with_capacity(capacity), + client_config, } } @@ -162,7 +171,10 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone async fn add_new_client_session(&mut self) -> Result<&RpcClientLease, RpcClientPoolError> { debug_assert!(!self.is_full(), "add_new_client called when pool is full"); - let client = self.connection.connect_rpc().await?; + let client = self + .connection + .connect_rpc_using_builder(self.client_config.clone()) + .await?; let client = RpcClientLease::new(client); self.clients.push(client); Ok(self.clients.last().unwrap()) diff --git a/comms/src/protocol/rpc/context.rs b/comms/src/protocol/rpc/context.rs index 1de5e6a023..f6e9d988f6 100644 --- a/comms/src/protocol/rpc/context.rs +++ b/comms/src/protocol/rpc/context.rs @@ -77,19 +77,28 @@ impl RpcCommsProvider for RpcCommsBackend { } pub struct RequestContext { + request_id: u32, backend: Box, node_id: NodeId, } impl RequestContext { - pub(super) fn new(node_id: NodeId, backend: Box) -> Self { - Self { backend, node_id } + pub(super) fn new(request_id: u32, node_id: NodeId, backend: Box) -> Self { + Self { + request_id, + backend, + node_id, + } } pub fn peer_node_id(&self) -> &NodeId { &self.node_id } + pub fn request_id(&self) -> u32 { + self.request_id + } + pub(crate) async fn fetch_peer(&self) -> Result { self.backend.fetch_peer(&self.node_id).await } diff --git a/comms/src/protocol/rpc/error.rs b/comms/src/protocol/rpc/error.rs index 0565e93575..959e781cf1 100644 --- a/comms/src/protocol/rpc/error.rs +++ b/comms/src/protocol/rpc/error.rs @@ -45,6 +45,8 @@ pub enum RpcError { ServerClosedRequest, #[error("Request cancelled")] RequestCancelled, + #[error("Response did not match the request ID (expected {expected} actual {actual})")] + ResponseIdDidNotMatchRequest { expected: u16, actual: u16 }, #[error("Client internal error: {0}")] ClientInternalError(String), #[error("Handshake error: {0}")] diff --git a/comms/src/protocol/rpc/message.rs b/comms/src/protocol/rpc/message.rs index c1455fb3c9..2e963b3d2d 100644 --- a/comms/src/protocol/rpc/message.rs +++ b/comms/src/protocol/rpc/message.rs @@ -242,6 +242,10 @@ impl proto::rpc::RpcResponse { pub fn flags(&self) -> RpcMessageFlags { RpcMessageFlags::from_bits_truncate(self.flags as u8) } + + pub fn is_fin(&self) -> bool { + self.flags as u8 & RpcMessageFlags::FIN.bits() != 0 + } } impl fmt::Display for proto::rpc::RpcResponse { diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 0747e51da9..19741a0a1a 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -80,7 +80,7 @@ impl RpcRequestMock { } pub fn 
request_with_context(&self, node_id: NodeId, msg: T) -> Request { - let context = RequestContext::new(node_id, Box::new(self.comms_provider.clone())); + let context = RequestContext::new(0, node_id, Box::new(self.comms_provider.clone())); Request::with_context(context, 0.into(), msg) } diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index d00c50c75d..092120ee48 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -53,14 +53,13 @@ use crate::{ protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, Sink, SinkExt, StreamExt}; +use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use prost::Message; use std::{ - io, + future::Future, time::{Duration, Instant}, }; -use tari_shutdown::{OptionalShutdownSignal, ShutdownSignal}; use tokio::time; use tower::Service; use tower_make::MakeService; @@ -150,7 +149,6 @@ pub struct RpcServerBuilder { maximum_simultaneous_sessions: Option, minimum_client_deadline: Duration, handshake_timeout: Duration, - shutdown_signal: OptionalShutdownSignal, } impl RpcServerBuilder { @@ -173,11 +171,6 @@ impl RpcServerBuilder { self } - pub fn with_shutdown_signal(mut self, shutdown_signal: ShutdownSignal) -> Self { - self.shutdown_signal = Some(shutdown_signal).into(); - self - } - pub fn finish(self) -> RpcServer { let (request_tx, request_rx) = mpsc::channel(10); RpcServer { @@ -194,7 +187,6 @@ impl Default for RpcServerBuilder { maximum_simultaneous_sessions: Some(1000), minimum_client_deadline: Duration::from_secs(1), handshake_timeout: Duration::from_secs(15), - shutdown_signal: Default::default(), } } } @@ -248,8 +240,7 @@ where let mut protocol_notifs = self .protocol_notifications .take() - .expect("PeerRpcServer initialized without protocol_notifications") - .take_until(self.config.shutdown_signal.clone()); + .expect("PeerRpcServer initialized without protocol_notifications"); let mut requests = self .request_rx @@ -274,8 +265,7 @@ where debug!( target: LOG_TARGET, - "Peer RPC server is shut down because the shutdown signal was triggered or the protocol notification \ - stream ended" + "Peer RPC server is shut down because the protocol notification stream ended" ); Ok(()) @@ -367,10 +357,9 @@ where let service = ActivePeerRpcService { config: self.config.clone(), node_id: node_id.clone(), - framed: Some(framed), + framed, service, comms_provider: self.comms_provider.clone(), - shutdown_signal: self.config.shutdown_signal.clone(), }; self.executor @@ -385,9 +374,8 @@ struct ActivePeerRpcService { config: RpcServerBuilder, node_id: NodeId, service: TSvc, - framed: Option>, + framed: CanonicalFraming, comms_provider: TCommsProvider, - shutdown_signal: OptionalShutdownSignal, } impl ActivePeerRpcService @@ -408,28 +396,26 @@ where } async fn run(&mut self) -> Result<(), RpcServerError> { - let (mut sink, stream) = self.framed.take().unwrap().split(); - let mut stream = stream.fuse().take_until(self.shutdown_signal.clone()); - - while let Some(result) = stream.next().await { + while let Some(result) = self.framed.next().await { let start = Instant::now(); - if let Err(err) = self.handle(&mut sink, result?.freeze()).await { - sink.close().await?; + if let Err(err) = self.handle(result?.freeze()).await { + self.framed.close().await?; return Err(err); } - debug!(target: LOG_TARGET, "RPC request completed in {:.0?}", start.elapsed()); + let elapsed = 
start.elapsed(); + debug!( + target: LOG_TARGET, + "RPC request completed in {:.0?}{}", + elapsed, + if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } + ); } - sink.close().await?; + self.framed.close().await?; Ok(()) } - fn create_request_context(&self) -> RequestContext { - RequestContext::new(self.node_id.clone(), Box::new(self.comms_provider.clone())) - } - - async fn handle(&mut self, sink: &mut W, mut request: Bytes) -> Result<(), RpcServerError> - where W: Sink + Unpin { + async fn handle(&mut self, mut request: Bytes) -> Result<(), RpcServerError> { let decoded_msg = proto::rpc::RpcRequest::decode(&mut request)?; let request_id = decoded_msg.request_id; @@ -453,7 +439,7 @@ where flags: RpcMessageFlags::FIN.bits().into(), message: status.details_bytes(), }; - sink.send(bad_request.to_encoded_bytes().into()).await?; + self.framed.send(bad_request.to_encoded_bytes().into()).await?; return Ok(()); } @@ -462,9 +448,14 @@ where "[Peer=`{}`] Got request {}", self.node_id, decoded_msg ); - let req = Request::with_context(self.create_request_context(), method, decoded_msg.message.into()); + let req = Request::with_context( + self.create_request_context(request_id), + method, + decoded_msg.message.into(), + ); - let service_result = time::timeout(deadline, self.service.call(req)).await; + let service_call = log_timing(request_id, "service call", self.service.call(req)); + let service_result = time::timeout(deadline, service_call).await; let service_result = match service_result { Ok(v) => v, Err(_) => { @@ -478,9 +469,42 @@ where match service_result { Ok(body) => { + // This is the most basic way we can push responses back to the peer. Keeping this here for reference + // and possible future evaluation + // + // body.into_message() + // .map(|msg| match msg { + // Ok(msg) => { + // trace!(target: LOG_TARGET, "Sending body len = {}", msg.len()); + // let mut flags = RpcMessageFlags::empty(); + // if msg.is_finished() { + // flags |= RpcMessageFlags::FIN; + // } + // proto::rpc::RpcResponse { + // request_id, + // status: RpcStatus::ok().as_code(), + // flags: flags.bits().into(), + // message: msg.into(), + // } + // }, + // Err(err) => { + // debug!(target: LOG_TARGET, "Body contained an error: {}", err); + // proto::rpc::RpcResponse { + // request_id, + // status: err.as_code(), + // flags: RpcMessageFlags::FIN.bits().into(), + // message: err.details().as_bytes().to_vec(), + // } + // }, + // }) + // .map(|resp| Ok(resp.to_encoded_bytes().into())) + // .forward(PreventClose::new(sink)) + // .await?; + let mut message = body.into_message(); loop { - match time::timeout(deadline, message.next()).await { + let msg_read = log_timing(request_id, "message read", message.next()); + match time::timeout(deadline, msg_read).await { Ok(Some(msg)) => { let resp = match msg { Ok(msg) => { @@ -507,7 +531,10 @@ where }, }; - if !send_response_checked(sink, request_id, resp).await? { + let is_valid = + log_timing(request_id, "transmit", self.send_response(request_id, resp)).await?; + + if !is_valid { break; } }, @@ -521,7 +548,7 @@ where break; }, } - } + } // end loop }, Err(err) => { debug!(target: LOG_TARGET, "Service returned an error: {}", err); @@ -532,50 +559,63 @@ where message: err.details_bytes(), }; - sink.send(resp.to_encoded_bytes().into()).await?; + self.framed.send(resp.to_encoded_bytes().into()).await?; }, } Ok(()) } -} -/// Sends an RpcResponse on the given Sink. 
If the size of the message exceeds the RPC_MAX_FRAME_SIZE, an error is -/// returned to the client and false is returned from this function, otherwise the message is sent and true is returned -#[inline] -async fn send_response_checked( - sink: &mut S, - request_id: u32, - resp: proto::rpc::RpcResponse, -) -> Result -where - S: Sink + Unpin, -{ - match resp.to_encoded_bytes() { - buf if buf.len() > RPC_MAX_FRAME_SIZE => { - let msg = format!( - "This node tried to return a message that exceeds the maximum frame size. Max = {:.4} MiB, Got = \ - {:.4} MiB", - RPC_MAX_FRAME_SIZE as f32 / (1024.0 * 1024.0), - buf.len() as f32 / (1024.0 * 1024.0) - ); - warn!(target: LOG_TARGET, "{}", msg); - sink.send( - proto::rpc::RpcResponse { - request_id, - status: RpcStatusCode::MalformedResponse as u32, - flags: RpcMessageFlags::FIN.bits().into(), - message: msg.as_bytes().to_vec(), - } - .to_encoded_bytes() - .into(), - ) - .await?; - Ok(false) - }, - buf => { - sink.send(buf.into()).await?; - Ok(true) - }, + /// Sends an RpcResponse on the given Sink. If the size of the message exceeds the RPC_MAX_FRAME_SIZE, an error is + /// returned to the client and false is returned from this function, otherwise the message is sent and true is + /// returned + async fn send_response(&mut self, request_id: u32, resp: proto::rpc::RpcResponse) -> Result { + match resp.to_encoded_bytes() { + buf if buf.len() > RPC_MAX_FRAME_SIZE => { + let msg = format!( + "This node tried to return a message that exceeds the maximum frame size. Max = {:.4} MiB, Got = \ + {:.4} MiB", + RPC_MAX_FRAME_SIZE as f32 / (1024.0 * 1024.0), + buf.len() as f32 / (1024.0 * 1024.0) + ); + warn!(target: LOG_TARGET, "{}", msg); + self.framed + .send( + proto::rpc::RpcResponse { + request_id, + status: RpcStatusCode::MalformedResponse as u32, + flags: RpcMessageFlags::FIN.bits().into(), + message: msg.as_bytes().to_vec(), + } + .to_encoded_bytes() + .into(), + ) + .await?; + Ok(false) + }, + buf => { + self.framed.send(buf.into()).await?; + Ok(true) + }, + } } + + fn create_request_context(&self, request_id: u32) -> RequestContext { + RequestContext::new(request_id, self.node_id.clone(), Box::new(self.comms_provider.clone())) + } +} + +async fn log_timing>(request_id: u32, tag: &str, fut: F) -> R { + let t = Instant::now(); + let ret = fut.await; + let elapsed = t.elapsed(); + trace!( + target: LOG_TARGET, + "RPC TIMING(REQ_ID={}): '{}' took {:.2}s{}", + request_id, + tag, + elapsed.as_secs_f32(), + if elapsed.as_secs() >= 5 { " (SLOW)" } else { "" } + ); + ret } diff --git a/comms/src/protocol/rpc/status.rs b/comms/src/protocol/rpc/status.rs index ffa4fa839f..e0ddf7fe22 100644 --- a/comms/src/protocol/rpc/status.rs +++ b/comms/src/protocol/rpc/status.rs @@ -36,21 +36,21 @@ pub struct RpcStatus { impl RpcStatus { pub fn ok() -> Self { - RpcStatus { + Self { code: RpcStatusCode::Ok, details: Default::default(), } } pub fn unsupported_method(details: T) -> Self { - RpcStatus { + Self { code: RpcStatusCode::UnsupportedMethod, details: details.to_string(), } } pub fn not_implemented(details: T) -> Self { - RpcStatus { + Self { code: RpcStatusCode::NotImplemented, details: details.to_string(), } @@ -99,6 +99,13 @@ impl RpcStatus { } } + pub(super) fn protocol_error(details: T) -> Self { + Self { + code: RpcStatusCode::ProtocolError, + details: details.to_string(), + } + } + pub fn as_code(&self) -> u32 { self.code as u32 } @@ -177,6 +184,8 @@ pub enum RpcStatusCode { General = 6, /// Entity not found NotFound = 7, + /// RPC protocol error + ProtocolError 
= 8, // The following status represents anything that is not recognised (i.e not one of the above codes). /// Unrecognised RPC status code InvalidRpcStatusCode, diff --git a/comms/src/protocol/rpc/test/client_pool.rs b/comms/src/protocol/rpc/test/client_pool.rs index 246bd9f12f..e1eb957d5f 100644 --- a/comms/src/protocol/rpc/test/client_pool.rs +++ b/comms/src/protocol/rpc/test/client_pool.rs @@ -52,7 +52,6 @@ async fn setup(num_concurrent_sessions: usize) -> (PeerConnection, PeerConnectio task::spawn( RpcServer::builder() .with_maximum_simultaneous_sessions(num_concurrent_sessions) - .with_shutdown_signal(shutdown.to_signal()) .finish() .add_service(GreetingServer::new(GreetingService::default())) .serve(notif_rx, context), @@ -80,7 +79,7 @@ mod lazy_pool { #[runtime::test] async fn it_connects_lazily() { let (conn, mock_state, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(conn, 2); + let mut pool = LazyPool::::new(conn, 2, Default::default()); assert_eq!(mock_state.num_open_substreams(), 0); let _conn1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); @@ -91,7 +90,7 @@ mod lazy_pool { #[runtime::test] async fn it_reuses_unused_connections() { let (conn, mock_state, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(conn, 2); + let mut pool = LazyPool::::new(conn, 2, Default::default()); let _ = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(pool.refresh_num_active_connections(), 1); async_assert_eventually!(mock_state.num_open_substreams(), expect = 1); @@ -103,7 +102,7 @@ mod lazy_pool { #[runtime::test] async fn it_reuses_least_used_connections() { let (conn, mock_state, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(conn, 2); + let mut pool = LazyPool::::new(conn, 2, Default::default()); let conn1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); let conn2 = pool.get_least_used_or_connect().await.unwrap(); @@ -124,7 +123,7 @@ mod lazy_pool { #[runtime::test] async fn it_reuses_used_connections_if_necessary() { let (conn, mock_state, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(conn, 1); + let mut pool = LazyPool::::new(conn, 1, Default::default()); let conn1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); let conn2 = pool.get_least_used_or_connect().await.unwrap(); @@ -136,7 +135,7 @@ mod lazy_pool { #[runtime::test] async fn it_gracefully_handles_insufficient_server_sessions() { let (conn, mock_state, _shutdown) = setup(1).await; - let mut pool = LazyPool::::new(conn, 2); + let mut pool = LazyPool::::new(conn, 2, Default::default()); let conn1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); let conn2 = pool.get_least_used_or_connect().await.unwrap(); @@ -148,7 +147,7 @@ mod lazy_pool { #[runtime::test] async fn it_prunes_disconnected_sessions() { let (conn, mock_state, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(conn, 2); + let mut pool = LazyPool::::new(conn, 2, Default::default()); let mut conn1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); let _conn2 = pool.get_least_used_or_connect().await.unwrap(); @@ -165,7 +164,7 @@ mod lazy_pool { #[runtime::test] async fn it_fails_when_peer_connected_disconnects() { let (mut peer_conn, _, _shutdown) = setup(2).await; - let mut pool = LazyPool::::new(peer_conn.clone(), 2); + let mut pool = 
LazyPool::::new(peer_conn.clone(), 2, Default::default()); let mut _conn1 = pool.get_least_used_or_connect().await.unwrap(); peer_conn.disconnect().await.unwrap(); let err = pool.get_least_used_or_connect().await.unwrap_err(); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index c02d78db73..14570b31ce 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -52,7 +52,7 @@ use crate::{ test_utils::node_identity::build_node_identity, NodeIdentity, }; -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::{channel::mpsc, future, future::Either, SinkExt, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; @@ -64,34 +64,39 @@ pub(super) async fn setup_service( num_concurrent_sessions: usize, ) -> ( mpsc::Sender>, - task::JoinHandle>, + task::JoinHandle<()>, RpcCommsBackend, Shutdown, ) { let (notif_tx, notif_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let (context, _) = create_mocked_rpc_context(); - let server_hnd = task::spawn( - RpcServer::builder() - .with_maximum_simultaneous_sessions(num_concurrent_sessions) - .with_minimum_client_deadline(Duration::from_secs(0)) - .with_shutdown_signal(shutdown.to_signal()) - .finish() - .add_service(GreetingServer::new(service_impl)) - .serve(notif_rx, context.clone()), - ); + let server_hnd = task::spawn({ + let context = context.clone(); + let shutdown_signal = shutdown.to_signal(); + async move { + let fut = RpcServer::builder() + .with_maximum_simultaneous_sessions(num_concurrent_sessions) + .with_minimum_client_deadline(Duration::from_secs(0)) + .finish() + .add_service(GreetingServer::new(service_impl)) + .serve(notif_rx, context); + + futures::pin_mut!(fut); + + match future::select(shutdown_signal, fut).await { + Either::Left((r, _)) => r.unwrap(), + Either::Right((r, _)) => r.unwrap(), + } + } + }); (notif_tx, server_hnd, context, shutdown) } pub(super) async fn setup( service_impl: T, num_concurrent_sessions: usize, -) -> ( - MemorySocket, - task::JoinHandle>, - Arc, - Shutdown, -) { +) -> (MemorySocket, task::JoinHandle<()>, Arc, Shutdown) { let (mut notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (inbound, outbound) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -110,8 +115,7 @@ pub(super) async fn setup( } #[runtime::test_basic] -async fn request_reponse_errors_and_streaming() // a.k.a smoke test -{ +async fn request_response_errors_and_streaming() { let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; let framed = framing::canonical(socket, 1024); @@ -180,7 +184,7 @@ async fn request_reponse_errors_and_streaming() // a.k.a smoke test unpack_enum!(RpcError::ClientClosed = err); shutdown.trigger().unwrap(); - server_hnd.await.unwrap().unwrap(); + server_hnd.await.unwrap(); } #[runtime::test_basic] @@ -244,17 +248,6 @@ async fn response_too_big() { let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); } -#[runtime::test_basic] -async fn server_shutdown_after_connect() { - let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; - let framed = framing::canonical(socket, 1024); - let mut client = GreetingClient::connect(framed).await.unwrap(); - shutdown.trigger().unwrap(); - - let err = client.say_hello(Default::default()).await.unwrap_err(); - unpack_enum!(RpcError::RequestCancelled = 
err); -} - #[runtime::test_basic] async fn server_shutdown_before_connect() { let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; diff --git a/comms/tests/greeting_service.rs b/comms/tests/greeting_service.rs new file mode 100644 index 0000000000..c13ae842e4 --- /dev/null +++ b/comms/tests/greeting_service.rs @@ -0,0 +1,166 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#![allow(dead_code)] +#![cfg(feature = "rpc")] + +use core::iter; +use futures::{channel::mpsc, stream, SinkExt, StreamExt}; +use std::{cmp, time::Duration}; +use tari_comms::{ + async_trait, + protocol::rpc::{Request, Response, RpcStatus, Streaming}, +}; +use tari_comms_rpc_macros::tari_rpc; +use tokio::{task, time}; + +#[tari_rpc(protocol_name = b"t/greeting/1", server_struct = GreetingServer, client_struct = GreetingClient)] +pub trait GreetingRpc: Send + Sync + 'static { + #[rpc(method = 1)] + async fn say_hello(&self, request: Request) -> Result, RpcStatus>; + #[rpc(method = 2)] + async fn get_greetings(&self, request: Request) -> Result, RpcStatus>; + #[rpc(method = 3)] + async fn reply_with_msg_of_size(&self, request: Request) -> Result>, RpcStatus>; + #[rpc(method = 4)] + async fn stream_large_items( + &self, + request: Request, + ) -> Result>, RpcStatus>; + #[rpc(method = 5)] + async fn slow_response(&self, request: Request) -> Result, RpcStatus>; +} + +pub struct GreetingService { + greetings: Vec, +} + +impl GreetingService { + pub const DEFAULT_GREETINGS: &'static [&'static str] = + &["Sawubona", "Jambo", "Bonjour", "Hello", "Molo", "Olá", "سلام", "你好"]; + + pub fn new(greetings: &[&str]) -> Self { + Self { + greetings: greetings.iter().map(ToString::to_string).collect(), + } + } +} + +impl Default for GreetingService { + fn default() -> Self { + Self::new(Self::DEFAULT_GREETINGS) + } +} + +#[async_trait] +impl GreetingRpc for GreetingService { + async fn say_hello(&self, request: Request) -> Result, RpcStatus> { + let msg = request.message(); + let greeting = self + .greetings + .get(msg.language as usize) + .ok_or_else(|| RpcStatus::bad_request(format!("{} is not a valid language identifier", msg.language)))?; + + let greeting = format!("{} {}", greeting, msg.name); + Ok(Response::new(SayHelloResponse { greeting })) + } + + async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { + let num = *request.message(); + let (mut tx, rx) = mpsc::channel(num as usize); + let greetings = self.greetings[..cmp::min(num as usize + 1, self.greetings.len())].to_vec(); + task::spawn(async move { + let iter = greetings.into_iter().map(Ok); + let mut stream = stream::iter(iter) + // "Extra" Result::Ok is to satisfy send_all + .map(Ok); + tx.send_all(&mut stream).await.unwrap(); + }); + + Ok(Streaming::new(rx)) + } + + async fn reply_with_msg_of_size(&self, request: Request) -> Result>, RpcStatus> { + let size = request.into_message() as usize; + Ok(Response::new(iter::repeat(0).take(size).collect())) + } + + async fn stream_large_items( + &self, + request: Request, + ) -> Result>, RpcStatus> { + let req_id = request.context().request_id(); + let StreamLargeItemsRequest { + id, + item_size, + num_items, + } = request.into_message(); + let (mut tx, rx) = mpsc::channel(10); + let t = std::time::Instant::now(); + task::spawn(async move { + let item = iter::repeat(0u8).take(item_size as usize).collect::>(); + for (i, item) in iter::repeat_with(|| Ok(item.clone())) + .take(num_items as usize) + .enumerate() + { + tx.send(item).await.unwrap(); + log::info!( + "[{}] reqid: {} t={:.2?} sent {}/{}", + id, + req_id, + t.elapsed(), + i + 1, + num_items + ); + } + }); + Ok(Streaming::new(rx)) + } + + async fn slow_response(&self, request: Request) -> Result, RpcStatus> { + time::delay_for(Duration::from_secs(request.into_message())).await; + Ok(Response::new(())) + } +} + +#[derive(prost::Message)] +pub struct SayHelloRequest { + #[prost(string, tag = "1")] + pub name: String, + 
#[prost(uint32, tag = "2")] + pub language: u32, +} + +#[derive(prost::Message)] +pub struct SayHelloResponse { + #[prost(string, tag = "1")] + pub greeting: String, +} + +#[derive(prost::Message)] +pub struct StreamLargeItemsRequest { + #[prost(uint64, tag = "1")] + pub id: u64, + #[prost(uint64, tag = "2")] + pub num_items: u64, + #[prost(uint64, tag = "3")] + pub item_size: u64, +} diff --git a/comms/tests/helpers.rs b/comms/tests/helpers.rs new file mode 100644 index 0000000000..e67f6dd6a4 --- /dev/null +++ b/comms/tests/helpers.rs @@ -0,0 +1,62 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use rand::rngs::OsRng; +use std::sync::Arc; +use tari_comms::{peer_manager::PeerFeatures, types::CommsDatabase, CommsBuilder, NodeIdentity, UnspawnedCommsNode}; +use tari_shutdown::ShutdownSignal; +use tari_storage::{ + lmdb_store::{LMDBBuilder, LMDBConfig}, + LMDBWrapper, +}; +use tari_test_utils::{paths::create_temporary_data_path, random}; + +pub fn create_peer_storage() -> CommsDatabase { + let database_name = random::string(8); + let datastore = LMDBBuilder::new() + .set_path(create_temporary_data_path()) + .set_env_config(LMDBConfig::default()) + .set_max_number_of_databases(1) + .add_database(&database_name, lmdb_zero::db::CREATE) + .build() + .unwrap(); + + let peer_database = datastore.get_handle(&database_name).unwrap(); + LMDBWrapper::new(Arc::new(peer_database)) +} + +pub fn create_comms(signal: ShutdownSignal) -> UnspawnedCommsNode { + let node_identity = Arc::new(NodeIdentity::random( + &mut OsRng, + "/ip4/127.0.0.1/tcp/0".parse().unwrap(), + PeerFeatures::COMMUNICATION_NODE, + )); + + CommsBuilder::new() + .allow_test_addresses() + .with_listener_address("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .with_node_identity(node_identity) + .with_peer_storage(create_peer_storage(), None) + .with_shutdown_signal(signal) + .build() + .unwrap() +} diff --git a/comms/tests/rpc_stress.rs b/comms/tests/rpc_stress.rs new file mode 100644 index 0000000000..24d124d1ee --- /dev/null +++ b/comms/tests/rpc_stress.rs @@ -0,0 +1,290 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#![cfg(feature = "rpc")] + // Run as normal, --nocapture for some extra output // cargo test --package tari_comms --test rpc_stress run --all-features --release -- --exact --nocapture mod greeting_service; use greeting_service::{GreetingClient, GreetingServer, GreetingService, StreamLargeItemsRequest}; mod helpers; use helpers::create_comms; use futures::{future, StreamExt}; use std::{env, future::Future, time::Duration}; use tari_comms::{ protocol::rpc::{RpcClientBuilder, RpcServer}, transports::TcpTransport, CommsNode, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{task, time::Instant}; pub async fn spawn_node(signal: ShutdownSignal) -> CommsNode { let rpc_server = RpcServer::builder() .with_unlimited_simultaneous_sessions() .finish() .add_service(GreetingServer::new(GreetingService::default())); let comms = create_comms(signal) .add_rpc_server(rpc_server) .spawn_with_transport(TcpTransport::new()) .await .unwrap(); comms .node_identity() .set_public_address(comms.listening_address().clone()); comms } async fn run_stress_test(test_params: Params) { let Params { num_tasks, num_concurrent_sessions, deadline, payload_size, num_items, } = test_params; let time = Instant::now(); assert!( num_tasks >= num_concurrent_sessions, "concurrent tasks must be at least the number of concurrent sessions, otherwise the (lazy) pool won't make the given number \ of sessions" ); println!( "RPC stress test will transfer a total of {} MiB of data", (num_tasks * payload_size * num_items) / (1024 * 1024) ); log::info!( "RPC stress test will transfer a total of {} MiB of data", (num_tasks * payload_size * num_items) / (1024 * 1024) ); let shutdown = Shutdown::new(); let node1 = spawn_node(shutdown.to_signal()).await; let node2 = spawn_node(shutdown.to_signal()).await; node1 .peer_manager() .add_peer(node2.node_identity().to_peer()) .await .unwrap(); let conn1_2 = node1 .connectivity() .dial_peer(node2.node_identity().node_id().clone()) .await .unwrap(); let client_pool = conn1_2.create_rpc_client_pool::<GreetingClient>( num_concurrent_sessions, RpcClientBuilder::new().with_deadline(deadline), ); let mut tasks = Vec::with_capacity(num_tasks); for i in 0..num_tasks { let pool = client_pool.clone(); tasks.push(task::spawn(async move { let mut client = pool.get().await.unwrap(); // let mut stream = client // .get_greetings(GreetingService::DEFAULT_GREETINGS.len() as u32) // .await // .unwrap(); // let mut count = 0; // while let Some(Ok(_)) = stream.next().await { // count += 1; // } // assert_eq!(count, GreetingService::DEFAULT_GREETINGS.len()); // let err = client.slow_response(5).await.unwrap_err(); // unpack_enum!(RpcError::RequestFailed(err) = err); // assert_eq!(err.status_code(), RpcStatusCode::Timeout); // let msg = client.reply_with_msg_of_size(1024).await.unwrap(); // assert_eq!(msg.len(), 1024); let time = std::time::Instant::now(); log::info!("[{}] start {:.2?}", i, time.elapsed(),); let mut stream = client .stream_large_items(StreamLargeItemsRequest { id: i as u64, num_items: num_items as u64, item_size: payload_size as u64, }) .await .unwrap(); log::info!("[{}] got stream {:.2?}", i, time.elapsed()); let mut count = 0; while let Some(r) = stream.next().await { count += 1; log::info!( "[{}] (count = {}) consuming stream {:.2?} : {}", i, count, time.elapsed(),
r.as_ref().err().map(ToString::to_string).unwrap_or_else(String::new) + ); + + let _ = r.unwrap(); + } + assert_eq!(count, num_items); + })); + } + + future::join_all(tasks).await.into_iter().for_each(Result::unwrap); + log::info!("Stress test took {:.2?}", time.elapsed()); +} + +struct Params { + pub num_tasks: usize, + pub num_concurrent_sessions: usize, + pub deadline: Duration, + pub payload_size: usize, + pub num_items: usize, +} + +async fn quick() { + run_stress_test(Params { + num_tasks: 10, + num_concurrent_sessions: 10, + deadline: Duration::from_secs(5), + payload_size: 1024, + num_items: 10, + }) + .await; +} + +async fn basic() { + run_stress_test(Params { + num_tasks: 10, + num_concurrent_sessions: 5, + deadline: Duration::from_secs(15), + payload_size: 1024 * 1024, + num_items: 4, + }) + .await; +} + +#[allow(dead_code)] +async fn many_small_messages() { + run_stress_test(Params { + num_tasks: 10, + num_concurrent_sessions: 10, + deadline: Duration::from_secs(5), + payload_size: 1024, + num_items: 10 * 1024, + }) + .await; +} + +async fn few_large_messages() { + run_stress_test(Params { + num_tasks: 10, + num_concurrent_sessions: 10, + deadline: Duration::from_secs(5), + payload_size: 1024 * 1024, + num_items: 10, + }) + .await; +} + +async fn payload_limit() { + run_stress_test(Params { + num_tasks: 50, + num_concurrent_sessions: 10, + deadline: Duration::from_secs(20), + payload_size: 4 * 1024 * 1024 - 100, + num_items: 2, + }) + .await; +} + +async fn high_contention() { + run_stress_test(Params { + num_tasks: 1000, + num_concurrent_sessions: 10, + deadline: Duration::from_secs(15), + payload_size: 1024 * 1024, + num_items: 4, + }) + .await; +} + +async fn high_concurrency() { + run_stress_test(Params { + num_tasks: 1000, + num_concurrent_sessions: 1000, + deadline: Duration::from_secs(15), + payload_size: 1024 * 1024, + num_items: 4, + }) + .await; +} + +async fn high_contention_high_concurrency() { + run_stress_test(Params { + num_tasks: 2000, + num_concurrent_sessions: 1000, + deadline: Duration::from_secs(15), + payload_size: 1024 * 1024, + num_items: 4, + }) + .await; +} + +#[tokio_macros::test] +async fn run_ci() { + log_timing("quick", quick()).await; + log_timing("basic", basic()).await; + log_timing("many_small_messages", many_small_messages()).await; + log_timing("few_large_messages", few_large_messages()).await; +} + +#[tokio_macros::test] +async fn run() { + if env::var("CI").is_ok() { + println!("Skipping the stress test on CI"); + return; + } + // let _ = env_logger::try_init(); + log_timing("quick", quick()).await; + log_timing("basic", basic()).await; + log_timing("many_small_messages", many_small_messages()).await; + log_timing("few_large_messages", few_large_messages()).await; + log_timing("payload_limit", payload_limit()).await; + log_timing("high_contention", high_contention()).await; + log_timing("high_concurrency", high_concurrency()).await; + log_timing("high_contention_high_concurrency", high_contention_high_concurrency()).await; +} + +async fn log_timing>(name: &str, fut: F) -> R { + let t = Instant::now(); + println!("'{}' is running...", name); + let ret = fut.await; + let elapsed = t.elapsed(); + println!("'{}' completed in {:.2}s", name, elapsed.as_secs_f32()); + ret +} diff --git a/comms/tests/substream_stress.rs b/comms/tests/substream_stress.rs new file mode 100644 index 0000000000..cbae6e8e52 --- /dev/null +++ b/comms/tests/substream_stress.rs @@ -0,0 +1,160 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source 
and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod helpers; +use helpers::create_comms; + +use futures::{channel::mpsc, future, SinkExt, StreamExt}; +use std::time::Duration; +use tari_comms::{ + framing, + protocol::{ProtocolEvent, ProtocolId, ProtocolNotificationRx}, + transports::TcpTransport, + BytesMut, + CommsNode, + Substream, +}; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tari_test_utils::unpack_enum; +use tokio::{task, time::Instant}; + +const PROTOCOL_NAME: &[u8] = b"test/dummy/protocol"; + +pub async fn spawn_node(signal: ShutdownSignal) -> (CommsNode, ProtocolNotificationRx) { + let (notif_tx, notif_rx) = mpsc::channel(1); + let comms = create_comms(signal) + .add_protocol(&[ProtocolId::from_static(PROTOCOL_NAME)], notif_tx) + .spawn_with_transport(TcpTransport::new()) + .await + .unwrap(); + + comms + .node_identity() + .set_public_address(comms.listening_address().clone()); + (comms, notif_rx) +} + +async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_size: usize, frame_size: usize) { + let shutdown = Shutdown::new(); + let (node1, _) = spawn_node(shutdown.to_signal()).await; + let (node2, mut notif_rx) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let sample = { + let mut buf = BytesMut::with_capacity(payload_size); + buf.fill(1); + buf.freeze() + }; + + task::spawn({ + let sample = sample.clone(); + async move { + while let Some(event) = notif_rx.next().await { + unpack_enum!(ProtocolEvent::NewInboundSubstream(_n, remote_substream) = event.event); + let mut remote_substream = framing::canonical(remote_substream, frame_size); + + task::spawn({ + let sample = sample.clone(); + async move { + let mut count = 0; + while let Some(r) = remote_substream.next().await { + r.unwrap(); + count += 1; + remote_substream.send(sample.clone()).await.unwrap(); + + if count == num_iterations { + break; + } + } + + assert_eq!(count, num_iterations); + } + 
}); + } + } + }); + + let mut substreams = Vec::with_capacity(num_substreams); + for _ in 0..num_substreams { + let substream = conn + .open_framed_substream(&ProtocolId::from_static(PROTOCOL_NAME), frame_size) + .await + .unwrap(); + substreams.push(substream); + } + + let tasks = substreams + .into_iter() + .enumerate() + .map(|(id, mut substream)| { + let sample = sample.clone(); + task::spawn(async move { + let mut count = 1; + // Send first to get the ball rolling + substream.send(sample.clone()).await.unwrap(); + let mut total_time = Duration::from_secs(0); + while let Some(r) = substream.next().await { + r.unwrap(); + count += 1; + let t = Instant::now(); + substream.send(sample.clone()).await.unwrap(); + total_time += t.elapsed(); + if count == num_iterations { + break; + } + } + + println!("[task {}] send time = {:.2?}", id, total_time,); + assert_eq!(count, num_iterations); + total_time + }) + }) + .collect::>(); + + let send_latencies = future::join_all(tasks) + .await + .into_iter() + .map(Result::unwrap) + .collect::>(); + let avg = send_latencies.iter().sum::().as_millis() / send_latencies.len() as u128; + println!("avg t = {}ms", avg); +} + +#[tokio_macros::test] +async fn many_at_frame_limit() { + const NUM_SUBSTREAMS: usize = 20; + const NUM_ITERATIONS_PER_STREAM: usize = 100; + const MAX_FRAME_SIZE: usize = 4 * 1024 * 1024; + const PAYLOAD_SIZE: usize = 4 * 1024 * 1024; + run_stress_test(NUM_SUBSTREAMS, NUM_ITERATIONS_PER_STREAM, PAYLOAD_SIZE, MAX_FRAME_SIZE).await; +}
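---

Usage sketch (illustrative, not part of the patch): `PeerConnection::create_rpc_client_pool` now takes an `RpcClientBuilder` alongside the pool size, so each pool carries its own client configuration; existing callers such as the wallet connectivity service pass `Default::default()` to keep the default config (30s deadline, 30s grace period). The sketch below assumes the `rpc` feature is enabled and reuses the test's macro-generated `GreetingClient`; the helper name, pool size, and deadline are made up for illustration.

    use std::time::Duration;
    use tari_comms::{protocol::rpc::RpcClientBuilder, PeerConnection};

    // Illustrative only: check a client out of a per-connection RPC pool.
    async fn greet_via_pool(conn: &PeerConnection) {
        // Up to 10 lazily-established sessions; each client uses a 15s
        // per-request deadline instead of the 30s default.
        let pool = conn.create_rpc_client_pool::<GreetingClient>(
            10,
            RpcClientBuilder::new().with_deadline(Duration::from_secs(15)),
        );
        // `get()` returns the least-used session, connecting only if needed.
        let mut client = pool.get().await.unwrap();
        let _ = client.say_hello(Default::default()).await;
    }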