Skip to content

Commit

Permalink
Add a ByteWriter wrapper that implements AsyncWrite for a `Sink<M…
Browse files Browse the repository at this point in the history
…essage>`

This is useful for programs that want to treat a WebSocket as a stream
of bytes.
  • Loading branch information
joshtriplett authored and sdroege committed Feb 18, 2025
1 parent 79fb4b2 commit 11d4d33
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ required-features = ["async-std-runtime"]
name = "client"
required-features = ["async-std-runtime"]

[[example]]
name = "client-bytes"
required-features = ["async-std-runtime"]

[[example]]
name = "autobahn-server"
required-features = ["async-std-runtime", "futures-03-sink"]
Expand Down
48 changes: 48 additions & 0 deletions examples/client-bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//! A simple example of hooking up stdin/stdout to a WebSocket stream using ByteStream.
//!
//! This example will connect to a server specified in the argument list and
//! then forward all data read on stdin to the server, printing out all data
//! received on stdout.
//!
//! Note that this is not currently optimized for performance, especially around
//! buffer management. Rather it's intended to show an example of working with a
//! client.
//!
//! You can use this example together with the `server` example.
use std::env;

use futures::StreamExt;

use async_std::io;
use async_std::prelude::*;
use async_std::task;
use async_tungstenite::async_std::connect_async;
use async_tungstenite::ByteWriter;

async fn run() {
let connect_addr = env::args()
.nth(1)
.unwrap_or_else(|| panic!("this program requires at least one argument"));

let (ws_stream, _) = connect_async(&connect_addr)
.await
.expect("Failed to connect");
println!("WebSocket handshake has been successfully completed");

let (write, read) = ws_stream.split();
let byte_writer = ByteWriter::new(write);
let stdin_to_ws = task::spawn(async {
io::copy(io::stdin(), byte_writer).await.unwrap();
});
let ws_to_stdout = task::spawn(read.for_each(|message| async {
let data = message.unwrap().into_data();
async_std::io::stdout().write_all(&data).await.unwrap();
}));
stdin_to_ws.await;
ws_to_stdout.await;
}

fn main() {
task::block_on(run())
}
72 changes: 72 additions & 0 deletions src/bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//! Provides an abstraction to use `AsyncWrite` to write bytes to a `WebSocketStream`.
use std::{
io,
pin::Pin,
task::{Context, Poll},
};

use futures_util::Sink;

use crate::{Message, WsError};

/// Treat a `WebSocketStream` as an `AsyncWrite` implementation.
///
/// Every write sends a binary message. If you want to group writes together, consider wrapping
/// this with a `BufWriter`.
#[derive(Debug)]
pub struct ByteWriter<S>(S);

impl<S> ByteWriter<S> {
/// Create a new `ByteWriter` from a `Sink` that accepts a WebSocket `Message`
#[inline(always)]
pub fn new(s: S) -> Self {
Self(s)
}

/// Get the underlying `Sink` back.
#[inline(always)]
pub fn into_inner(self) -> S {
self.0
}
}

impl<S> futures_io::AsyncWrite for ByteWriter<S>
where
S: Sink<Message, Error = WsError> + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match Pin::new(&mut self.0).poll_ready(cx).map_err(convert_err) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let len = buf.len();
let msg = Message::binary(buf.to_owned());
Poll::Ready(
Pin::new(&mut self.0)
.start_send(msg)
.map_err(convert_err)
.map(|()| len),
)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_close(cx).map_err(convert_err)
}
}

fn convert_err(e: WsError) -> io::Error {
match e {
WsError::Io(io) => io,
_ => io::Error::new(io::ErrorKind::Other, e),
}
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ pub mod gio;
#[cfg(feature = "tokio-runtime")]
pub mod tokio;

#[cfg(feature = "futures-03-sink")]
pub mod bytes;
#[cfg(feature = "futures-03-sink")]
pub use bytes::ByteWriter;

use tungstenite::protocol::CloseFrame;

/// Creates a WebSocket handshake from a request and a stream.
Expand Down

0 comments on commit 11d4d33

Please sign in to comment.