diff --git a/src/bytes.rs b/src/bytes.rs index 4843a2b..b19bb0c 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -31,28 +31,39 @@ impl ByteWriter { } } +fn poll_write_helper( + mut s: Pin<&mut ByteWriter>, + cx: &mut Context<'_>, + buf: &[u8], +) -> Poll> +where + S: Sink + Unpin, +{ + match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + let len = buf.len(); + let msg = Message::binary(buf.to_owned()); + Poll::Ready( + Pin::new(&mut s.0) + .start_send(msg) + .map_err(convert_err) + .map(|()| len), + ) +} + impl futures_io::AsyncWrite for ByteWriter where S: Sink + Unpin, { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match Pin::new(&mut self.0).poll_ready(cx).map_err(convert_err) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } - let len = buf.len(); - let msg = Message::binary(buf.to_owned()); - Poll::Ready( - Pin::new(&mut self.0) - .start_send(msg) - .map_err(convert_err) - .map(|()| len), - ) + poll_write_helper(self, cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -64,6 +75,28 @@ where } } +#[cfg(feature = "tokio-runtime")] +impl tokio::io::AsyncWrite for ByteWriter +where + S: Sink + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_write_helper(self, cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_close(cx).map_err(convert_err) + } +} + fn convert_err(e: WsError) -> io::Error { match e { WsError::Io(io) => io,