From b54df4ced534c8cf571637432cb4a7c3774a2799 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 25 Oct 2022 16:04:22 +1100 Subject: [PATCH] Add test for poll-based API --- tests/poll_api.rs | 249 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 tests/poll_api.rs diff --git a/tests/poll_api.rs b/tests/poll_api.rs new file mode 100644 index 00000000..d755b71c --- /dev/null +++ b/tests/poll_api.rs @@ -0,0 +1,249 @@ +use futures::future::BoxFuture; +use futures::stream::FuturesUnordered; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, StreamExt}; +use quickcheck::{Arbitrary, Gen, QuickCheck}; +use std::future::Future; +use std::io; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::runtime::Runtime; +use tokio_util::compat::TokioAsyncReadCompatExt; +use yamux::{Connection, Mode, WindowUpdateMode}; + +#[test] +fn prop_config_send_recv_multi() { + let _ = env_logger::try_init(); + + fn prop(msgs: Vec, cfg1: TestConfig, cfg2: TestConfig) { + Runtime::new().unwrap().block_on(async move { + let num_messagses = msgs.len(); + + let (listener, address) = bind().await.expect("bind"); + + let server = async { + let socket = listener.accept().await.expect("accept").0.compat(); + let connection = Connection::new(socket, cfg1.0, Mode::Server); + + EchoServer::new(connection).await + }; + + let client = async { + let socket = TcpStream::connect(address).await.expect("connect").compat(); + let connection = Connection::new(socket, cfg2.0, Mode::Client); + + MessageSender::new(connection, msgs).await + }; + + let (server_processed, client_processed) = + futures::future::try_join(server, client).await.unwrap(); + + assert_eq!(server_processed, num_messagses); + assert_eq!(client_processed, num_messagses); + }) + } + + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _, _) -> _) +} + +#[derive(Clone, Debug)] +struct Msg(Vec); + +impl Arbitrary for Msg { + fn arbitrary(g: &mut Gen) -> Msg { + let mut msg = Msg(Arbitrary::arbitrary(g)); + if msg.0.is_empty() { + msg.0.push(Arbitrary::arbitrary(g)); + } + + msg + } + + fn shrink(&self) -> Box> { + Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(|v| Msg(v))) + } +} + +#[derive(Clone, Debug)] +struct TestConfig(yamux::Config); + +impl Arbitrary for TestConfig { + fn arbitrary(g: &mut Gen) -> Self { + let mut c = yamux::Config::default(); + c.set_window_update_mode(if bool::arbitrary(g) { + WindowUpdateMode::OnRead + } else { + WindowUpdateMode::OnReceive + }); + c.set_read_after_close(Arbitrary::arbitrary(g)); + c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024)); + TestConfig(c) + } +} + +async fn bind() -> io::Result<(TcpListener, SocketAddr)> { + let i = Ipv4Addr::new(127, 0, 0, 1); + let s = SocketAddr::V4(SocketAddrV4::new(i, 0)); + let l = TcpListener::bind(&s).await?; + let a = l.local_addr()?; + Ok((l, a)) +} + +struct EchoServer { + connection: Connection, + worker_streams: FuturesUnordered>>, + streams_processed: usize, +} + +impl EchoServer { + fn new(connection: Connection) -> Self { + Self { + connection, + worker_streams: FuturesUnordered::default(), + streams_processed: 0, + } + } +} + +impl Future for EchoServer +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = yamux::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match this.worker_streams.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(()))) => { + this.streams_processed += 1; + continue; + } + Poll::Ready(Some(Err(e))) => { + eprintln!("A stream failed: {}", e); + continue; + } + Poll::Ready(None) | Poll::Pending => {} + } + + match this.connection.poll_next_inbound(cx)? { + Poll::Ready(Some(mut stream)) => { + this.worker_streams.push( + async move { + { + let (mut r, mut w) = AsyncReadExt::split(&mut stream); + futures::io::copy(&mut r, &mut w).await?; + } + stream.close().await?; + Ok(()) + } + .boxed(), + ); + continue; + } + Poll::Ready(None) => return Poll::Ready(Ok(this.streams_processed)), + Poll::Pending => {} + } + + return Poll::Pending; + } + } +} + +struct MessageSender { + connection: Connection, + pending_messages: Vec, + worker_streams: FuturesUnordered>, + streams_processed: usize, +} + +impl MessageSender { + fn new(connection: Connection, messages: Vec) -> Self { + Self { + connection, + pending_messages: messages, + worker_streams: FuturesUnordered::default(), + streams_processed: 0, + } + } +} + +impl Future for MessageSender +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = yamux::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + if this.pending_messages.is_empty() && this.worker_streams.is_empty() { + futures::ready!(this.connection.poll_close(cx)?); + + return Poll::Ready(Ok(this.streams_processed)); + } + + match this.worker_streams.poll_next_unpin(cx) { + Poll::Ready(Some(())) => { + this.streams_processed += 1; + continue; + } + Poll::Ready(None) | Poll::Pending => {} + } + + if let Some(Msg(message)) = this.pending_messages.pop() { + match this.connection.poll_new_outbound(cx)? { + Poll::Ready(stream) => { + this.worker_streams.push( + async move { + let id = stream.id(); + let len = message.len(); + + let (mut reader, mut writer) = AsyncReadExt::split(stream); + + let write_fut = async { + writer.write_all(&message).await.unwrap(); + log::debug!("C: {}: sent {} bytes", id, len); + writer.close().await.unwrap(); + }; + + let mut received = Vec::new(); + let read_fut = async { + reader.read_to_end(&mut received).await.unwrap(); + log::debug!("C: {}: received {} bytes", id, received.len()); + }; + + futures::future::join(write_fut, read_fut).await; + + assert_eq!(message, received) + } + .boxed(), + ); + continue; + } + Poll::Pending => { + this.pending_messages.push(Msg(message)); + } + } + } + + match this.connection.poll_next_inbound(cx)? { + Poll::Ready(Some(stream)) => { + drop(stream); + panic!("Did not expect remote to open a stream"); + } + Poll::Ready(None) => { + panic!("Did not expect remote to close the connection"); + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } +}