fix(comms): simplify and remove possibility of deadlock from pipelines and substream close #4676

Merged
2 changes: 1 addition & 1 deletion base_layer/p2p/src/config.rs
@@ -138,7 +138,7 @@ impl Default for P2pConfig {
allow_test_addresses: false,
listener_liveness_max_sessions: 0,
listener_liveness_allowlist_cidrs: StringList::default(),
user_agent: "".to_string(),
user_agent: String::new(),
auxiliary_tcp_listener_address: None,
rpc_max_simultaneous_sessions: 100,
rpc_max_sessions_per_peer: 10,
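A short note on the one-liner above: `String::new()` and `"".to_string()` produce the same empty string, but `String::new()` is a `const fn` that is documented not to allocate, and it skips the `ToString` machinery entirely. A hypothetical snippet (not part of the PR) confirming the equivalence:

// Hypothetical check, not in the PR: both spellings yield an empty string;
// String::new() is const-evaluable and documented to allocate nothing.
fn main() {
    assert_eq!(String::new(), "".to_string());
    assert_eq!(String::new().capacity(), 0); // "will not allocate any initial buffer"
}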
117 changes: 75 additions & 42 deletions comms/core/src/multiplexing/yamux.rs
@@ -23,7 +23,6 @@
use std::{future::Future, io, pin::Pin, task::Poll};

use futures::{task::Context, Stream};
use tari_shutdown::{Shutdown, ShutdownSignal};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
sync::mpsc,
@@ -91,11 +90,10 @@ impl Yamux {
where
TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static,
{
let shutdown = Shutdown::new();
let (incoming_tx, incoming_rx) = mpsc::channel(10);
let incoming = IncomingWorker::new(connection, incoming_tx, shutdown.to_signal());
let incoming = IncomingWorker::new(connection, incoming_tx);
runtime::task::spawn(incoming.run());
IncomingSubstreams::new(incoming_rx, counter, shutdown)
IncomingSubstreams::new(incoming_rx, counter)
}

/// Get the yamux control struct
@@ -166,19 +164,13 @@ impl Control {
pub struct IncomingSubstreams {
inner: mpsc::Receiver<yamux::Stream>,
substream_counter: AtomicRefCounter,
shutdown: Shutdown,
}

impl IncomingSubstreams {
pub(self) fn new(
inner: mpsc::Receiver<yamux::Stream>,
substream_counter: AtomicRefCounter,
shutdown: Shutdown,
) -> Self {
pub(self) fn new(inner: mpsc::Receiver<yamux::Stream>, substream_counter: AtomicRefCounter) -> Self {
Self {
inner,
substream_counter,
shutdown,
}
}

@@ -201,12 +193,6 @@ impl Stream for IncomingSubstreams {
}
}

impl Drop for IncomingSubstreams {
fn drop(&mut self) {
self.shutdown.trigger();
}
}

/// A yamux stream wrapper that can be read from and written to.
#[derive(Debug)]
pub struct Substream {
Expand Down Expand Up @@ -258,41 +244,23 @@ impl From<yamux::StreamId> for stream_id::Id {
struct IncomingWorker<TSocket> {
connection: yamux::Connection<TSocket>,
sender: mpsc::Sender<yamux::Stream>,
shutdown_signal: ShutdownSignal,
}

impl<TSocket> IncomingWorker<TSocket>
where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */
{
pub fn new(
connection: yamux::Connection<TSocket>,
sender: mpsc::Sender<yamux::Stream>,
shutdown_signal: ShutdownSignal,
) -> Self {
Self {
connection,
sender,
shutdown_signal,
}
pub fn new(connection: yamux::Connection<TSocket>, sender: mpsc::Sender<yamux::Stream>) -> Self {
Self { connection, sender }
}

#[tracing::instrument(name = "yamux::incoming_worker::run", skip(self), fields(connection = %self.connection))]
pub async fn run(mut self) {
loop {
tokio::select! {
biased;

_ = self.shutdown_signal.wait() => {
let mut control = self.connection.control();
if let Err(err) = control.close().await {
error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err);
}
debug!(
target: LOG_TARGET,
"{} Yamux connection has closed", self.connection
);
_ = self.sender.closed() => {
self.close().await;
break
}
},

result = self.connection.next_stream() => {
match result {
@@ -336,14 +304,51 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
}
}
}

async fn close(&mut self) {
let mut control = self.connection.control();
// Sends the close message once polled, while continuing to poll the connection future
let close_fut = control.close();
tokio::pin!(close_fut);
loop {
tokio::select! {
biased;

result = &mut close_fut => {
match result {
Ok(_) => break,
Err(err) => {
error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err);
break;
}
}
},

result = self.connection.next_stream() => {
match result {
Ok(Some(_)) => continue,
Ok(None) => break,
Err(err) => {
error!(target: LOG_TARGET, "Error while closing yamux connection: {}", err);
continue;
}
}
}
}
}
debug!(target: LOG_TARGET, "{} Yamux connection has closed", self.connection);
}
}

#[cfg(test)]
mod test {
use std::{io, time::Duration};
use std::{io, sync::Arc, time::Duration};

use tari_test_utils::collect_stream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::Barrier,
};
use tokio_stream::StreamExt;

use crate::{
@@ -455,6 +460,34 @@ mod test {
Ok(())
}

#[runtime::test]
async fn rude_close_does_not_freeze() -> io::Result<()> {
let (dialer, listener) = MemorySocket::new_pair();

let barrier = Arc::new(Barrier::new(2));
let b = barrier.clone();

task::spawn(async move {
// Drop immediately
let incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)
.unwrap()
.into_incoming();
drop(incoming);
b.wait().await;
});

let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).unwrap();
let mut dialer_control = dialer.get_yamux_control();
let mut substream = dialer_control.open_stream().await.unwrap();
barrier.wait().await;

let mut buf = vec![];
substream.read_to_end(&mut buf).await.unwrap();
assert!(buf.is_empty());

Ok(())
}

#[runtime::test]
async fn send_big_message() -> io::Result<()> {
#[allow(non_upper_case_globals)]
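The heart of the yamux change is the new `IncomingWorker::close`. `yamux::Control::close()` only makes progress while the connection itself is being polled, and the old shutdown branch awaited `control.close()` from inside the select handler, where nothing was driving `next_stream()` — the deadlock this PR removes. The new `close` pins the close future and races it against `next_stream()` until one side finishes. Below is a minimal tokio-only sketch of that pin-and-select shape, with toy channels standing in for the yamux connection (all names here are hypothetical, not the comms types):

use tokio::sync::{mpsc, oneshot};

// Toy stand-ins: `events` plays the role of Connection::next_stream(),
// `close_done` the completion of Control::close().
async fn drain_and_close(mut events: mpsc::Receiver<u32>, close_done: oneshot::Receiver<()>) {
    // Pin the close future once so it can be re-polled across loop
    // iterations, mirroring `tokio::pin!(close_fut)` in the PR.
    tokio::pin!(close_done);
    loop {
        tokio::select! {
            biased;

            // Close handshake finished (or the closer went away): stop.
            _ = &mut close_done => break,

            // Keep driving the "connection" so the close can make progress.
            maybe = events.recv() => match maybe {
                Some(_) => continue, // discard late events/substreams
                None => break,       // remote side fully closed
            },
        }
    }
}

#[tokio::main]
async fn main() {
    let (event_tx, event_rx) = mpsc::channel(4);
    let (done_tx, done_rx) = oneshot::channel();

    tokio::spawn(async move {
        for i in 0..3 {
            let _ = event_tx.send(i).await; // a few stragglers first
        }
        let _ = done_tx.send(()); // then the close completes
    });

    drain_and_close(event_rx, done_rx).await;
    println!("closed without hanging");
}

The same structure also explains the new `rude_close_does_not_freeze` test above: even when the incoming-substreams side is dropped immediately, the worker keeps polling the connection while the close runs, so the dialer's read completes instead of freezing.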
4 changes: 2 additions & 2 deletions comms/core/src/pipeline/builder.rs
@@ -115,7 +115,7 @@ where
let pipeline = (factory)(sink_service);
Ok(OutboundPipelineConfig {
in_receiver,
out_receiver,
out_receiver: Some(out_receiver),
pipeline,
})
}
@@ -147,7 +147,7 @@ pub struct OutboundPipelineConfig<TInItem, TPipeline> {
/// Messages read from this stream are passed to the pipeline
pub in_receiver: mpsc::Receiver<TInItem>,
/// Receiver of `OutboundMessage`s coming from the pipeline
pub out_receiver: mpsc::UnboundedReceiver<OutboundMessage>,
pub out_receiver: Option<mpsc::UnboundedReceiver<OutboundMessage>>,
/// The pipeline (`tower::Service`) to run for each in_stream message
pub pipeline: TPipeline,
}
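Wrapping `out_receiver` in an `Option` lets whoever runs the outbound pipeline move the receiver out of the config exactly once via `Option::take`, rather than borrowing the whole struct, and turns an accidental second take into an explicit failure instead of a silent hang. A hypothetical sketch of that ownership-transfer pattern (names are illustrative, not the comms API):

use tokio::sync::mpsc;

// Hypothetical mirror of the config struct: wrapping the receiver in
// Option lets the runner move it out exactly once.
struct Config {
    out_receiver: Option<mpsc::UnboundedReceiver<String>>,
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut config = Config { out_receiver: Some(rx) };

    // Take ownership of the receiver; a second take() would yield None,
    // turning an accidental double-spawn into an explicit panic here.
    let mut out_rx = config.out_receiver.take().expect("receiver already taken");

    let forwarder = tokio::spawn(async move {
        while let Some(msg) = out_rx.recv().await {
            println!("forwarding: {msg}");
        }
    });

    tx.send("hello".to_string()).unwrap();
    drop(tx); // close the channel so the forwarder exits
    forwarder.await.unwrap();
}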
22 changes: 12 additions & 10 deletions comms/core/src/pipeline/inbound.rs
@@ -88,15 +88,17 @@ where

let num_available = self.executor.num_available();
let max_available = self.executor.max_available();
// Only emit this message if there is any concurrent usage
if num_available < max_available {
debug!(
target: LOG_TARGET,
"Inbound pipeline usage: {}/{}",
max_available - num_available,
max_available
);
}
log!(
target: LOG_TARGET,
if num_available < max_available {
Level::Debug
} else {
Level::Trace
},
"Inbound pipeline usage: {}/{}",
max_available - num_available,
max_available
);

let id = current_id;
current_id = (current_id + 1) % u64::MAX;
@@ -106,7 +108,7 @@
.spawn(async move {
let timer = Instant::now();
trace!(target: LOG_TARGET, "Start inbound pipeline {}", id);
match time::timeout(Duration::from_secs(30), service.oneshot(item)).await {
match time::timeout(Duration::from_secs(10), service.oneshot(item)).await {
Ok(Ok(_)) => {},
Ok(Err(err)) => {
warn!(target: LOG_TARGET, "Inbound pipeline returned an error: '{}'", err);
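The inbound pipeline change replaces the conditional `debug!` with a single `log!` whose level is picked at runtime — Debug when the pipeline has concurrent usage, Trace when idle — so quiet periods still leave a breadcrumb (the per-message pipeline timeout also drops from 30 s to 10 s). The same pattern in isolation (a sketch; the `demo::pipeline` target and the `env_logger` setup are assumptions for the demo, not from the codebase):

use log::{log, Level};

fn report_usage(num_available: usize, max_available: usize) {
    // Choose the level with an expression: a busy pipeline is worth Debug,
    // an idle one only Trace.
    log!(
        target: "demo::pipeline",
        if num_available < max_available { Level::Debug } else { Level::Trace },
        "Inbound pipeline usage: {}/{}",
        max_available - num_available,
        max_available
    );
}

fn main() {
    env_logger::builder()
        .filter_level(log::LevelFilter::Trace)
        .init();
    report_usage(8, 10);  // 2/10 in use: logged at Debug
    report_usage(10, 10); // 0/10 in use: logged at Trace
}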