Skip to content

Commit

Permalink
ByteWriter: Support tokio::io::AsyncWrite if tokio-runtime is ena…
Browse files Browse the repository at this point in the history
…bled
  • Loading branch information
joshtriplett authored and sdroege committed Feb 18, 2025
1 parent 11d4d33 commit a6f8f9a
Showing 1 changed file with 47 additions and 14 deletions.
61 changes: 47 additions & 14 deletions src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,39 @@ impl<S> ByteWriter<S> {
}
}

fn poll_write_helper<S>(
mut s: Pin<&mut ByteWriter<S>>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>
where
S: Sink<Message, Error = WsError> + Unpin,
{
match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let len = buf.len();
let msg = Message::binary(buf.to_owned());
Poll::Ready(
Pin::new(&mut s.0)
.start_send(msg)
.map_err(convert_err)
.map(|()| len),
)
}

impl<S> futures_io::AsyncWrite for ByteWriter<S>
where
S: Sink<Message, Error = WsError> + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match Pin::new(&mut self.0).poll_ready(cx).map_err(convert_err) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let len = buf.len();
let msg = Message::binary(buf.to_owned());
Poll::Ready(
Pin::new(&mut self.0)
.start_send(msg)
.map_err(convert_err)
.map(|()| len),
)
poll_write_helper(self, cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Expand All @@ -64,6 +75,28 @@ where
}
}

#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
S: Sink<Message, Error = WsError> + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
poll_write_helper(self, cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_close(cx).map_err(convert_err)
}
}

fn convert_err(e: WsError) -> io::Error {
match e {
WsError::Io(io) => io,
Expand Down

0 comments on commit a6f8f9a

Please sign in to comment.