Skip to content

Commit

Permalink
Adapt AsyncRead, AsynWrite
Browse files Browse the repository at this point in the history
  • Loading branch information
Urhengulas committed Oct 24, 2020
1 parent 842ad6e commit 538f70f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 83 deletions.
30 changes: 8 additions & 22 deletions src/common/io/rewind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::marker::Unpin;
use std::{cmp, io};

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::common::{task, Pin, Poll};

Expand Down Expand Up @@ -46,27 +46,22 @@ impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
#[inline]
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.len());
prefix.copy_to_slice(&mut buf[..copy_len]);
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(prefix.to_vec().as_slice());
prefix.advance(copy_len);
// Put back whats left
if !prefix.is_empty() {
self.pre = Some(prefix);
}

return Poll::Ready(Ok(copy_len));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
Expand All @@ -92,15 +87,6 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

#[inline]
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
}
}

#[cfg(test)]
Expand Down
33 changes: 4 additions & 29 deletions src/server/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ mod addr_stream {
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;

use crate::common::{task, Pin, Poll};
Expand Down Expand Up @@ -231,30 +231,14 @@ mod addr_stream {
}

impl AsyncRead for AddrStream {
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [std::mem::MaybeUninit<u8>],
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}

#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}

#[inline]
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
}
}

impl AsyncWrite for AddrStream {
Expand All @@ -267,15 +251,6 @@ mod addr_stream {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

#[inline]
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
// TCP flush is a noop
Expand Down
40 changes: 8 additions & 32 deletions src/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::io;
use std::marker::Unpin;

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::oneshot;

use crate::common::io::Rewind;
Expand Down Expand Up @@ -105,15 +105,11 @@ impl Upgraded {
}

impl AsyncRead for Upgraded {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.io.prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_read(cx, buf)
}
}
Expand All @@ -127,14 +123,6 @@ impl AsyncWrite for Upgraded {
Pin::new(&mut self.io).poll_write(cx, buf)
}

fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_flush(cx)
}
Expand Down Expand Up @@ -247,15 +235,11 @@ impl dyn Io + Send {
}

impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
Expand All @@ -269,14 +253,6 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
Expand All @@ -292,7 +268,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
cx: &mut task::Context<'_>,
mut buf: &mut dyn Buf,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write_buf(cx, &mut buf)
Pin::new(&mut self.0).poll_write(cx, buf.bytes())
}
}

Expand Down

0 comments on commit 538f70f

Please sign in to comment.