diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs index 2500710a1d..6610eaf830 100644 --- a/iroh-gossip/src/net.rs +++ b/iroh-gossip/src/net.rs @@ -599,6 +599,7 @@ mod test { use iroh_net::{derp::DerpMap, MagicEndpoint}; use tokio::spawn; + use tokio::time::timeout; use tokio_util::sync::CancellationToken; use tracing::info; @@ -632,7 +633,7 @@ mod test { #[tokio::test] async fn gossip_net_smoke() { - util::setup_logging(); + let _guard = util::setup_logging(); let (derp_map, derp_region, cleanup) = util::run_derp_and_stun([127, 0, 0, 1].into()) .await .unwrap(); @@ -667,6 +668,11 @@ mod test { let len = 10; + // subscribe nodes 2 and 3 to the topic + let mut stream2 = go2.subscribe(topic).await.unwrap(); + let mut stream3 = go3.subscribe(topic).await.unwrap(); + + // publish messages on node1 let pub1 = spawn(async move { for i in 0..len { let message = format!("hi{}", i); @@ -678,11 +684,11 @@ mod test { } }); + // wait for messages on node2 let sub2 = spawn(async move { - let mut stream = go2.subscribe(topic).await.unwrap(); let mut recv = vec![]; loop { - let ev = stream.recv().await.unwrap(); + let ev = stream2.recv().await.unwrap(); info!("go2 event: {ev:?}"); if let Event::Received(msg, _prev_peer) = ev { recv.push(msg); @@ -693,11 +699,11 @@ mod test { } }); + // wait for messages on node3 let sub3 = spawn(async move { - let mut stream = go3.subscribe(topic).await.unwrap(); let mut recv = vec![]; loop { - let ev = stream.recv().await.unwrap(); + let ev = stream3.recv().await.unwrap(); info!("go3 event: {ev:?}"); if let Event::Received(msg, _prev_peer) = ev { recv.push(msg); @@ -708,9 +714,18 @@ mod test { } }); - pub1.await.unwrap(); - let recv2 = sub2.await.unwrap(); - let recv3 = sub3.await.unwrap(); + timeout(Duration::from_secs(10), pub1) + .await + .unwrap() + .unwrap(); + let recv2 = timeout(Duration::from_secs(10), sub2) + .await + .unwrap() + .unwrap(); + let recv3 = timeout(Duration::from_secs(10), sub3) + .await + .unwrap() + .unwrap(); let expected: Vec = (0..len) .map(|i| Bytes::from(format!("hi{i}").into_bytes())) @@ -720,7 +735,11 @@ mod test { cancel.cancel(); for t in tasks { - t.await.unwrap().unwrap(); + timeout(Duration::from_secs(10), t) + .await + .unwrap() + .unwrap() + .unwrap(); } drop(cleanup); } @@ -735,16 +754,103 @@ mod test { derp::{DerpMap, UseIpv4, UseIpv6}, stun::{is, parse_binding_request, response}, }; - use tokio::sync::oneshot; + use tokio::{runtime::RuntimeFlavor, sync::oneshot}; + use tracing::level_filters::LevelFilter; use tracing::{debug, info, trace}; use tracing_subscriber::{prelude::*, EnvFilter}; - pub fn setup_logging() { + /// Configures logging for the current test, **single-threaded runtime only**. + /// + /// This setup can be used for any sync test or async test using a single-threaded tokio + /// runtime (the default). For multi-threaded runtimes use [`with_logging`]. + /// + /// This configures logging that will interact well with tests: logs will be captured by the + /// test framework and only printed on failure. + /// + /// The logging is unfiltered, it logs all crates and modules on TRACE level. If that's too + /// much consider if your test is too large (or write a version that allows filtering...). + /// + /// # Example + /// + /// ```no_run + /// #[tokio::test] + /// async fn test_something() { + /// let _guard = crate::test_utils::setup_logging(); + /// assert!(true); + /// } + #[must_use = "The tracing guard must only be dropped at the end of the test"] + #[allow(dead_code)] + pub(crate) fn setup_logging() -> tracing::subscriber::DefaultGuard { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + match handle.runtime_flavor() { + RuntimeFlavor::CurrentThread => (), + RuntimeFlavor::MultiThread => { + panic!("setup_logging() does not work in a multi-threaded tokio runtime"); + } + _ => panic!("unknown runtime flavour"), + } + } + testing_subscriber().set_default() + } + + /// Returns the a [`tracing::Subscriber`] configured for our tests. + /// + /// This subscriber will ensure that log output is captured by the test's default output + /// capturing and thus is only shown with the test on failure. By default it uses + /// `RUST_LOG=trace` as configuration but you can specify the `RUST_LOG` environment + /// variable explicitly to override this. + /// + /// To use this in a tokio multi-threaded runtime use: + /// + /// ```no_run + /// use tracing_future::WithSubscriber; + /// use crate::test_utils::testing_subscriber; + /// + /// #[tokio::test(flavor = "multi_thread")] + /// async fn test_something() -> Result<()> { + /// async move { + /// Ok(()) + /// }.with_subscriber(testing_subscriber()).await + /// } + /// ``` + pub(crate) fn testing_subscriber() -> impl tracing::Subscriber { + let var = std::env::var_os("RUST_LOG"); + let trace_log_layer = match var { + Some(_) => None, + None => Some( + tracing_subscriber::fmt::layer() + .with_writer(|| TestWriter) + .with_filter(LevelFilter::TRACE), + ), + }; + let env_log_layer = var.map(|_| { + tracing_subscriber::fmt::layer() + .with_writer(|| TestWriter) + .with_filter(EnvFilter::from_default_env()) + }); tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) - .with(EnvFilter::from_default_env()) - .try_init() - .ok(); + .with(trace_log_layer) + .with(env_log_layer) + } + + /// A tracing writer that interacts well with test output capture. + /// + /// Using this writer will make sure that the output is captured normally and only printed + /// when the test fails. See [`setup_logging`] to actually use this. + #[derive(Debug)] + struct TestWriter; + + impl std::io::Write for TestWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + print!( + "{}", + std::str::from_utf8(buf).expect("tried to log invalid UTF-8") + ); + Ok(buf.len()) + } + fn flush(&mut self) -> std::io::Result<()> { + std::io::stdout().flush() + } } /// A drop guard to clean up test infrastructure.