From 4a2cb3491517d8c87dbfbd4eb6d618e918aad284 Mon Sep 17 00:00:00 2001 From: Josh Triplett Date: Tue, 18 Feb 2025 13:14:11 +0100 Subject: [PATCH] Add tokio support for `ByteReader` Factor out a helper for the common portions. This also fixes a bug in the async-std version when doing a short read. --- Cargo.toml | 4 ++ examples/tokio-client-bytes.rs | 46 ++++++++++++++++++++ src/bytes.rs | 79 +++++++++++++++++++++++----------- 3 files changed, 105 insertions(+), 24 deletions(-) create mode 100644 examples/tokio-client-bytes.rs diff --git a/Cargo.toml b/Cargo.toml index 15c2416..f77f927 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -180,6 +180,10 @@ required-features = ["gio-runtime"] name = "gio-echo-server" required-features = ["gio-runtime"] +[[example]] +name = "tokio-client-bytes" +required-features = ["tokio-runtime"] + [[example]] name = "tokio-echo" required-features = ["tokio-runtime"] diff --git a/examples/tokio-client-bytes.rs b/examples/tokio-client-bytes.rs new file mode 100644 index 0000000..b68b654 --- /dev/null +++ b/examples/tokio-client-bytes.rs @@ -0,0 +1,46 @@ +//! A simple example of hooking up stdin/stdout to a WebSocket stream using ByteStream. +//! +//! This example will connect to a server specified in the argument list and +//! then forward all data read on stdin to the server, printing out all data +//! received on stdout. +//! +//! Note that this is not currently optimized for performance, especially around +//! buffer management. Rather it's intended to show an example of working with a +//! client. +//! +//! You can use this example together with the `server` example. + +use std::env; + +use futures::StreamExt; + +use async_tungstenite::tokio::connect_async; +use async_tungstenite::{ByteReader, ByteWriter}; +use tokio::io; +use tokio::task; + +async fn run() { + let connect_addr = env::args() + .nth(1) + .unwrap_or_else(|| panic!("this program requires at least one argument")); + + let (ws_stream, _) = connect_async(&connect_addr) + .await + .expect("Failed to connect"); + println!("WebSocket handshake has been successfully completed"); + + let (write, read) = ws_stream.split(); + let mut byte_writer = ByteWriter::new(write); + let mut byte_reader = ByteReader::new(read); + let stdin_to_ws = + task::spawn(async move { io::copy(&mut io::stdin(), &mut byte_writer).await }); + let ws_to_stdout = + task::spawn(async move { io::copy(&mut byte_reader, &mut io::stdout()).await }); + stdin_to_ws.await.unwrap().unwrap(); + ws_to_stdout.await.unwrap().unwrap(); +} + +fn main() { + let rt = tokio::runtime::Runtime::new().expect("runtime"); + rt.block_on(run()) +} diff --git a/src/bytes.rs b/src/bytes.rs index 2770b3d..7927635 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -125,39 +125,70 @@ impl ByteReader { } } +fn poll_read_helper( + mut s: Pin<&mut ByteReader>, + cx: &mut Context<'_>, + buf_len: usize, +) -> Poll>> +where + S: Stream> + Unpin, +{ + Poll::Ready(Ok(Some(match s.bytes { + None => match Pin::new(&mut s.stream).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(Ok(None)), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))), + Poll::Ready(Some(Ok(msg))) => { + let bytes = msg.into_data(); + if bytes.len() > buf_len { + s.bytes.insert(bytes).split_to(buf_len) + } else { + bytes + } + } + }, + Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len), + Some(ref mut bytes) => { + let bytes = bytes.clone(); + s.bytes = None; + bytes + } + }))) +} + impl futures_io::AsyncRead for ByteReader where S: Stream> + Unpin, { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let buf_len = buf.len(); - let bytes_to_read = match self.bytes { - None => match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(Ok(0)), - Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))), - Poll::Ready(Some(Ok(msg))) => { - let bytes = msg.into_data(); - if bytes.len() > buf_len { - self.bytes.insert(bytes).split_to(buf_len) - } else { - bytes - } - } - }, - Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len), - Some(ref mut bytes) => { - let bytes = bytes.clone(); - self.bytes = None; - bytes + poll_read_helper(self, cx, buf.len()).map_ok(|bytes| { + bytes.map_or(0, |bytes| { + buf[..bytes.len()].copy_from_slice(&bytes); + bytes.len() + }) + }) + } +} + +#[cfg(feature = "tokio-runtime")] +impl tokio::io::AsyncRead for ByteReader +where + S: Stream> + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf, + ) -> Poll> { + poll_read_helper(self, cx, buf.remaining()).map_ok(|bytes| { + if let Some(ref bytes) = bytes { + buf.put_slice(bytes); } - }; - buf.copy_from_slice(&bytes_to_read); - Poll::Ready(Ok(bytes_to_read.len())) + }) } }