Skip to content

Commit

Permalink
Add tokio support for ByteReader
Browse files Browse the repository at this point in the history
Factor out a helper for the common portions.

This also fixes a bug in the async-std version when doing a short read.
  • Loading branch information
joshtriplett authored and sdroege committed Feb 18, 2025
1 parent 2ac7c70 commit 4a2cb34
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 24 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ required-features = ["gio-runtime"]
name = "gio-echo-server"
required-features = ["gio-runtime"]

[[example]]
name = "tokio-client-bytes"
required-features = ["tokio-runtime"]

[[example]]
name = "tokio-echo"
required-features = ["tokio-runtime"]
Expand Down
46 changes: 46 additions & 0 deletions examples/tokio-client-bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//! A simple example of hooking up stdin/stdout to a WebSocket stream using ByteStream.
//!
//! This example will connect to a server specified in the argument list and
//! then forward all data read on stdin to the server, printing out all data
//! received on stdout.
//!
//! Note that this is not currently optimized for performance, especially around
//! buffer management. Rather it's intended to show an example of working with a
//! client.
//!
//! You can use this example together with the `server` example.
use std::env;

use futures::StreamExt;

use async_tungstenite::tokio::connect_async;
use async_tungstenite::{ByteReader, ByteWriter};
use tokio::io;
use tokio::task;

async fn run() {
let connect_addr = env::args()
.nth(1)
.unwrap_or_else(|| panic!("this program requires at least one argument"));

let (ws_stream, _) = connect_async(&connect_addr)
.await
.expect("Failed to connect");
println!("WebSocket handshake has been successfully completed");

let (write, read) = ws_stream.split();
let mut byte_writer = ByteWriter::new(write);
let mut byte_reader = ByteReader::new(read);
let stdin_to_ws =
task::spawn(async move { io::copy(&mut io::stdin(), &mut byte_writer).await });
let ws_to_stdout =
task::spawn(async move { io::copy(&mut byte_reader, &mut io::stdout()).await });
stdin_to_ws.await.unwrap().unwrap();
ws_to_stdout.await.unwrap().unwrap();
}

fn main() {
let rt = tokio::runtime::Runtime::new().expect("runtime");
rt.block_on(run())
}
79 changes: 55 additions & 24 deletions src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,39 +125,70 @@ impl<S> ByteReader<S> {
}
}

fn poll_read_helper<S>(
mut s: Pin<&mut ByteReader<S>>,
cx: &mut Context<'_>,
buf_len: usize,
) -> Poll<io::Result<Option<Bytes>>>
where
S: Stream<Item = Result<Message, WsError>> + Unpin,
{
Poll::Ready(Ok(Some(match s.bytes {
None => match Pin::new(&mut s.stream).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Ok(None)),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
Poll::Ready(Some(Ok(msg))) => {
let bytes = msg.into_data();
if bytes.len() > buf_len {
s.bytes.insert(bytes).split_to(buf_len)
} else {
bytes
}
}
},
Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
Some(ref mut bytes) => {
let bytes = bytes.clone();
s.bytes = None;
bytes
}
})))
}

impl<S> futures_io::AsyncRead for ByteReader<S>
where
S: Stream<Item = Result<Message, WsError>> + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let buf_len = buf.len();
let bytes_to_read = match self.bytes {
None => match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Ok(0)),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
Poll::Ready(Some(Ok(msg))) => {
let bytes = msg.into_data();
if bytes.len() > buf_len {
self.bytes.insert(bytes).split_to(buf_len)
} else {
bytes
}
}
},
Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
Some(ref mut bytes) => {
let bytes = bytes.clone();
self.bytes = None;
bytes
poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
bytes.map_or(0, |bytes| {
buf[..bytes.len()].copy_from_slice(&bytes);
bytes.len()
})
})
}
}

#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncRead for ByteReader<S>
where
S: Stream<Item = Result<Message, WsError>> + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf,
) -> Poll<io::Result<()>> {
poll_read_helper(self, cx, buf.remaining()).map_ok(|bytes| {
if let Some(ref bytes) = bytes {
buf.put_slice(bytes);
}
};
buf.copy_from_slice(&bytes_to_read);
Poll::Ready(Ok(bytes_to_read.len()))
})
}
}

Expand Down

0 comments on commit 4a2cb34

Please sign in to comment.