diff --git a/tokio/src/io/interest.rs b/tokio/src/io/interest.rs index 3a39cf761b7..abda774dde7 100644 --- a/tokio/src/io/interest.rs +++ b/tokio/src/io/interest.rs @@ -115,6 +115,7 @@ impl Interest { /// /// assert!(BOTH.is_readable()); /// assert!(BOTH.is_writable()); + /// ``` pub const fn add(self, other: Interest) -> Interest { Interest(self.0.add(other.0)) } @@ -135,6 +136,12 @@ impl Interest { } } +impl Default for Interest { + fn default() -> Self { + Interest::READABLE.add(Interest::WRITABLE) + } +} + impl ops::BitOr for Interest { type Output = Self; diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 34e393c895f..6d6a3092c25 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -158,13 +158,58 @@ impl TcpListener { /// } /// ``` pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + self.accept_with_interest(Default::default()).await + } + + /// Accepts a new incoming connection from this listener with custom + /// interest registration. + /// + /// This function will yield once a new TCP connection is established. When + /// established, the corresponding [`TcpStream`] and the remote peer's + /// address will be returned. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. + /// + /// [`TcpStream`]: struct@crate::net::TcpStream + /// + /// # Examples + /// + /// ```no_run + /// use tokio::{io::Interest, net::TcpListener}; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// match listener + /// .accept_with_interest(Default::default().add(Interest::PRIORITY)) + /// .await + /// { + /// Ok((_socket, addr)) => println!("new client: {:?}", addr), + /// Err(e) => println!("couldn't get client: {:?}", e), + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn accept_with_interest( + &self, + interest: Interest, + ) -> io::Result<(TcpStream, SocketAddr)> { let (mio, addr) = self .io .registration() .async_io(Interest::READABLE, || self.io.accept()) .await?; - let stream = TcpStream::new(mio)?; + let stream = TcpStream::new_with_interest(mio, interest)?; Ok((stream, addr)) } diff --git a/tokio/src/net/tcp/socket.rs b/tokio/src/net/tcp/socket.rs index df792f9a615..ed7d4470225 100644 --- a/tokio/src/net/tcp/socket.rs +++ b/tokio/src/net/tcp/socket.rs @@ -654,7 +654,7 @@ impl TcpSocket { unsafe { mio::net::TcpStream::from_raw_socket(raw_socket) } }; - TcpStream::connect_mio(mio).await + TcpStream::connect_mio(mio, Default::default()).await } /// Converts the socket into a `TcpListener`. diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index b7104d78b40..39fb280c468 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -112,12 +112,59 @@ impl TcpStream { /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn connect(addr: A) -> io::Result { + Self::connect_with_interest(addr, Default::default()).await + } + + /// Opens a TCP connection to a remote host with custom interest + /// registration.. + /// + /// `addr` is an address of the remote host. Anything which implements the + /// [`ToSocketAddrs`] trait can be supplied as the address. If `addr` + /// yields multiple addresses, connect will be attempted with each of the + /// addresses until a connection is successful. If none of the addresses + /// result in a successful connection, the error returned from the last + /// connection attempt (the last address) is returned. + /// + /// To configure the socket before connecting, you can use the [`TcpSocket`] + /// type. + /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// [`TcpSocket`]: struct@crate::net::TcpSocket + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::{AsyncWriteExt, Interest}; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect_with_interest( + /// "127.0.0.1:8080", + /// Default::default().add(Interest::PRIORITY), + /// ) + /// .await?; + /// + /// // Write some data. + /// stream.write_all(b"hello world!").await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn connect_with_interest(addr: A, interest: Interest) -> io::Result { let addrs = to_socket_addrs(addr).await?; let mut last_err = None; for addr in addrs { - match TcpStream::connect_addr(addr).await { + match TcpStream::connect_addr(addr, interest).await { Ok(stream) => return Ok(stream), Err(e) => last_err = Some(e), } @@ -132,13 +179,13 @@ impl TcpStream { } /// Establishes a connection to the specified `addr`. - async fn connect_addr(addr: SocketAddr) -> io::Result { + async fn connect_addr(addr: SocketAddr, interest: Interest) -> io::Result { let sys = mio::net::TcpStream::connect(addr)?; - TcpStream::connect_mio(sys).await + TcpStream::connect_mio(sys, interest).await } - pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result { - let stream = TcpStream::new(sys)?; + pub(crate) async fn connect_mio(sys: mio::net::TcpStream, interest: Interest) -> io::Result { + let stream = TcpStream::new_with_interest(sys, interest)?; // Once we've connected, wait for the stream to be writable as // that's when the actual connection has been initiated. Once we're @@ -161,6 +208,14 @@ impl TcpStream { Ok(TcpStream { io }) } + pub(crate) fn new_with_interest( + connected: mio::net::TcpStream, + interest: Interest, + ) -> io::Result { + let io = PollEvented::new_with_interest(connected, interest)?; + Ok(TcpStream { io }) + } + /// Creates new `TcpStream` from a `std::net::TcpStream`. /// /// This function is intended to be used to wrap a TCP stream from the