From eb286418aea05b367141deb2dd0e6969d4529a1e Mon Sep 17 00:00:00 2001 From: Ruben2424 <61056653+Ruben2424@users.noreply.github.com> Date: Thu, 29 Jun 2023 22:55:43 +0200 Subject: [PATCH] feat(client): Make clients able to use non-Send executor (#3184) Closes #3017 BREAKING CHANGE: `client::conn::http2` types now use another generic for an `Executor`. Code that names `Connection` needs to include the additional generic parameter. Signed-off-by: Sven Pfennig --- examples/single_threaded.rs | 171 +++++++++++++-- src/client/conn/http2.rs | 77 ++++--- src/client/dispatch.rs | 89 +++++--- src/common/exec.rs | 38 +--- src/common/mod.rs | 2 +- src/ffi/task.rs | 8 +- src/proto/h2/client.rs | 406 +++++++++++++++++++++++++++++------- src/proto/h2/mod.rs | 2 +- src/rt/bounds.rs | 72 +++++-- 9 files changed, 652 insertions(+), 213 deletions(-) diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index ee109d54fa..de6256239c 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -1,17 +1,22 @@ #![deny(warnings)] +use http_body_util::BodyExt; use hyper::server::conn::http2; use std::cell::Cell; use std::net::SocketAddr; use std::rc::Rc; +use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use hyper::body::{Body as HttpBody, Bytes, Frame}; use hyper::service::service_fn; +use hyper::Request; use hyper::{Error, Response}; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; +use std::thread; +use tokio::net::TcpStream; struct Body { // Our Body type is !Send and !Sync: @@ -40,28 +45,57 @@ impl HttpBody for Body { } } -fn main() -> Result<(), Box> { +fn main() { pretty_env_logger::init(); - // Configure a runtime that runs everything on the current thread - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("build runtime"); - - // Combine it with a `LocalSet, which means it can spawn !Send futures... - let local = tokio::task::LocalSet::new(); - local.block_on(&rt, run()) + let server = thread::spawn(move || { + // Configure a runtime for the server that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); + + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local.block_on(&rt, server()).unwrap(); + }); + + let client = thread::spawn(move || { + // Configure a runtime for the client that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); + + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local + .block_on( + &rt, + client("http://localhost:3000".parse::().unwrap()), + ) + .unwrap(); + }); + + server.join().unwrap(); + client.join().unwrap(); } -async fn run() -> Result<(), Box> { - let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); +async fn server() -> Result<(), Box> { + let mut stdout = io::stdout(); + let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); // Using a !Send request counter is fine on 1 thread... let counter = Rc::new(Cell::new(0)); let listener = TcpListener::bind(addr).await?; - println!("Listening on http://{}", addr); + + stdout + .write_all(format!("Listening on http://{}", addr).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + loop { let (stream, _) = listener.accept().await?; @@ -80,12 +114,121 @@ async fn run() -> Result<(), Box> { .serve_connection(stream, service) .await { - println!("Error serving connection: {:?}", err); + let mut stdout = io::stdout(); + stdout + .write_all(format!("Error serving connection: {:?}", err).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); } }); } } +struct IOTypeNotSend { + _marker: PhantomData<*const ()>, + stream: TcpStream, +} + +impl IOTypeNotSend { + fn new(stream: TcpStream) -> Self { + Self { + _marker: PhantomData, + stream, + } + } +} + +impl AsyncWrite for IOTypeNotSend { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl AsyncRead for IOTypeNotSend { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +async fn client(url: hyper::Uri) -> Result<(), Box> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + let stream = TcpStream::connect(addr).await?; + + let stream = IOTypeNotSend::new(stream); + + let (mut sender, conn) = hyper::client::conn::http2::handshake(LocalExec, stream).await?; + + tokio::task::spawn_local(async move { + if let Err(err) = conn.await { + let mut stdout = io::stdout(); + stdout + .write_all(format!("Connection failed: {:?}", err).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + } + }); + + let authority = url.authority().unwrap().clone(); + + // Make 4 requests + for _ in 0..4 { + let req = Request::builder() + .uri(url.clone()) + .header(hyper::header::HOST, authority.as_str()) + .body(Body::from("test".to_string()))?; + + let mut res = sender.send_request(req).await?; + + let mut stdout = io::stdout(); + stdout + .write_all(format!("Response: {}\n", res.status()).as_bytes()) + .await + .unwrap(); + stdout + .write_all(format!("Headers: {:#?}\n", res.headers()).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + + // Print the response body + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + stdout.write_all(&chunk).await.unwrap(); + } + } + stdout.write_all(b"\n-----------------\n").await.unwrap(); + stdout.flush().await.unwrap(); + } + Ok(()) +} + // NOTE: This part is only needed for HTTP/2. HTTP/1 doesn't need an executor. // // Since the Server needs to spawn some background tasks, we needed diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index a4cdc22f71..16c7af0a3c 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -1,6 +1,6 @@ //! HTTP/2 client connections -use std::error::Error as StdError; +use std::error::Error; use std::fmt; use std::marker::PhantomData; use std::sync::Arc; @@ -12,12 +12,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::time::Time; -use crate::common::{ - exec::{BoxSendFuture, Exec}, - task, Future, Pin, Poll, -}; +use crate::common::{task, Future, Pin, Poll}; use crate::proto; -use crate::rt::{Executor, Timer}; +use crate::rt::bounds::ExecutorClient; +use crate::rt::Timer; /// The sender side of an established connection. pub struct SendRequest { @@ -37,20 +35,22 @@ impl Clone for SendRequest { /// In most cases, this should just be spawned into an executor, so that it /// can process incoming and outgoing messages, notice hangups, and the like. #[must_use = "futures do nothing unless polled"] -pub struct Connection +pub struct Connection where - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + 'static + Unpin, B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { - inner: (PhantomData, proto::h2::ClientTask), + inner: (PhantomData, proto::h2::ClientTask), } /// A builder to configure an HTTP connection. /// /// After setting options, the builder is used to create a handshake future. #[derive(Clone, Debug)] -pub struct Builder { - pub(super) exec: Exec, +pub struct Builder { + pub(super) exec: Ex, pub(super) timer: Time, h2_builder: proto::h2::client::Config, } @@ -59,13 +59,16 @@ pub struct Builder { /// /// This is a shortcut for `Builder::new().handshake(io)`. /// See [`client::conn`](crate::client::conn) for more. -pub async fn handshake(exec: E, io: T) -> crate::Result<(SendRequest, Connection)> +pub async fn handshake( + exec: E, + io: T, +) -> crate::Result<(SendRequest, Connection)> where - E: Executor + Send + Sync + 'static, - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: Body + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + Unpin + Clone, { Builder::new(exec).handshake(io).await } @@ -188,12 +191,13 @@ impl fmt::Debug for SendRequest { // ===== impl Connection -impl Connection +impl Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: Body + Unpin + Send + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + Unpin, { /// Returns whether the [extended CONNECT protocol][1] is enabled or not. /// @@ -209,22 +213,26 @@ where } } -impl fmt::Debug for Connection +impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + T: AsyncRead + AsyncWrite + fmt::Debug + 'static + Unpin, B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Connection").finish() } } -impl Future for Connection +impl Future for Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, - B: Body + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + E: Unpin, + B::Error: Into>, + E: ExecutorClient + 'static + Send + Sync + Unpin, { type Output = crate::Result<()>; @@ -239,22 +247,22 @@ where // ===== impl Builder -impl Builder { +impl Builder +where + Ex: Clone, +{ /// Creates a new connection builder. #[inline] - pub fn new(exec: E) -> Builder - where - E: Executor + Send + Sync + 'static, - { + pub fn new(exec: Ex) -> Builder { Builder { - exec: Exec::new(exec), + exec, timer: Time::Empty, h2_builder: Default::default(), } } /// Provide a timer to execute background HTTP2 tasks. - pub fn timer(&mut self, timer: M) -> &mut Builder + pub fn timer(&mut self, timer: M) -> &mut Builder where M: Timer + Send + Sync + 'static, { @@ -388,12 +396,13 @@ impl Builder { pub fn handshake( &self, io: T, - ) -> impl Future, Connection)>> + ) -> impl Future, Connection)>> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: Body + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + Ex: ExecutorClient + Unpin, { let opts = self.clone(); diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 3aef84012f..40cb554917 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -1,11 +1,18 @@ #[cfg(feature = "http2")] use std::future::Future; +use http::{Request, Response}; +use http_body::Body; +use pin_project_lite::pin_project; use tokio::sync::{mpsc, oneshot}; +use tracing::trace; +use crate::{ + body::Incoming, + common::{task, Poll}, +}; #[cfg(feature = "http2")] -use crate::common::Pin; -use crate::common::{task, Poll}; +use crate::{common::Pin, proto::h2::client::ResponseFutMap}; #[cfg(test)] pub(crate) type RetryPromise = oneshot::Receiver)>>; @@ -266,37 +273,57 @@ impl Callback { } } } +} - #[cfg(feature = "http2")] - pub(crate) async fn send_when( - self, - mut when: impl Future)>> + Unpin, - ) { - use futures_util::future; - use tracing::trace; - - let mut cb = Some(self); - - // "select" on this callback being canceled, and the future completing - future::poll_fn(move |cx| { - match Pin::new(&mut when).poll(cx) { - Poll::Ready(Ok(res)) => { - cb.take().expect("polled after complete").send(Ok(res)); - Poll::Ready(()) - } - Poll::Pending => { - // check if the callback is canceled - ready!(cb.as_mut().unwrap().poll_canceled(cx)); - trace!("send_when canceled"); - Poll::Ready(()) - } - Poll::Ready(Err(err)) => { - cb.take().expect("polled after complete").send(Err(err)); - Poll::Ready(()) - } +#[cfg(feature = "http2")] +pin_project! { + pub struct SendWhen + where + B: Body, + B: 'static, + { + #[pin] + pub(crate) when: ResponseFutMap, + #[pin] + pub(crate) call_back: Option, Response>>, + } +} + +#[cfg(feature = "http2")] +impl Future for SendWhen +where + B: Body + 'static, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + let mut call_back = this.call_back.take().expect("polled after complete"); + + match Pin::new(&mut this.when).poll(cx) { + Poll::Ready(Ok(res)) => { + call_back.send(Ok(res)); + Poll::Ready(()) } - }) - .await + Poll::Pending => { + // check if the callback is canceled + match call_back.poll_canceled(cx) { + Poll::Ready(v) => v, + Poll::Pending => { + // Move call_back back to struct before return + this.call_back.set(Some(call_back)); + return std::task::Poll::Pending; + } + }; + trace!("send_when canceled"); + Poll::Ready(()) + } + Poll::Ready(Err(err)) => { + call_back.send(Err(err)); + Poll::Ready(()) + } + } } } diff --git a/src/common/exec.rs b/src/common/exec.rs index ef006c9d84..69d19e9bb7 100644 --- a/src/common/exec.rs +++ b/src/common/exec.rs @@ -1,50 +1,14 @@ -use std::fmt; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; - -use crate::rt::Executor; - -pub(crate) type BoxSendFuture = Pin + Send>>; - -// Executor must be provided by the user -#[derive(Clone)] -pub(crate) struct Exec(Arc + Send + Sync>); - -// ===== impl Exec ===== - -impl Exec { - pub(crate) fn new(exec: E) -> Self - where - E: Executor + Send + Sync + 'static, - { - Self(Arc::new(exec)) - } - - pub(crate) fn execute(&self, fut: F) - where - F: Future + Send + 'static, - { - self.0.execute(Box::pin(fut)) - } -} - -impl fmt::Debug for Exec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Exec").finish() - } -} // If http2 is not enable, we just have a stub here, so that the trait bounds // that *would* have been needed are still checked. Why? // // Because enabling `http2` shouldn't suddenly add new trait bounds that cause // a compilation error. -#[cfg(not(feature = "http2"))] -#[allow(missing_debug_implementations)] + pub struct H2Stream(std::marker::PhantomData<(F, B)>); -#[cfg(not(feature = "http2"))] impl Future for H2Stream where F: Future, E>>, diff --git a/src/common/mod.rs b/src/common/mod.rs index 67b2bbde59..2392851951 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -10,7 +10,7 @@ macro_rules! ready { pub(crate) mod buf; #[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] pub(crate) mod date; -#[cfg(any(feature = "http1", feature = "http2", feature = "server"))] +#[cfg(not(feature = "http2"))] pub(crate) mod exec; pub(crate) mod io; mod never; diff --git a/src/ffi/task.rs b/src/ffi/task.rs index ef54fe408f..a973a7bab3 100644 --- a/src/ffi/task.rs +++ b/src/ffi/task.rs @@ -177,8 +177,12 @@ impl WeakExec { } } -impl crate::rt::Executor> for WeakExec { - fn execute(&self, fut: BoxFuture<()>) { +impl crate::rt::Executor for WeakExec +where + F: Future + Send + 'static, + F::Output: Send + Sync + AsTaskType, +{ + fn execute(&self, fut: F) { if let Some(exec) = self.0.upgrade() { exec.spawn(hyper_task::boxed(fut)); } diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index adadfce68d..56aff85a9f 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -1,25 +1,30 @@ -use std::error::Error as StdError; +use std::marker::PhantomData; + use std::time::Duration; use bytes::Bytes; +use futures_channel::mpsc::{Receiver, Sender}; use futures_channel::{mpsc, oneshot}; -use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; -use futures_util::stream::StreamExt as _; -use h2::client::{Builder, SendRequest}; +use futures_util::future::{self, Either, FutureExt as _, Select}; +use futures_util::stream::{StreamExt as _, StreamFuture}; +use h2::client::{Builder, Connection, SendRequest}; use h2::SendStream; use http::{Method, StatusCode}; +use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; +use super::ping::{Ponger, Recorder}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::{Body, Incoming as IncomingBody}; -use crate::client::dispatch::Callback; +use crate::client::dispatch::{Callback, SendWhen}; use crate::common::time::Time; -use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; +use crate::common::{task, Future, Never, Pin, Poll}; use crate::ext::Protocol; use crate::headers; use crate::proto::h2::UpgradedSendStream; use crate::proto::Dispatched; +use crate::rt::bounds::ExecutorClient; use crate::upgrade::Upgraded; use crate::{Request, Response}; use h2::client::ResponseFuture; @@ -98,17 +103,19 @@ fn new_ping_config(config: &Config) -> ping::Config { } } -pub(crate) async fn handshake( +pub(crate) async fn handshake( io: T, req_rx: ClientRx, config: &Config, - exec: Exec, + mut exec: E, timer: Time, -) -> crate::Result> +) -> crate::Result> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static, - B: Body, + T: AsyncRead + AsyncWrite + Unpin + 'static, + B: Body + 'static, B::Data: Send + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { let (h2_tx, mut conn) = new_builder(config) .handshake::<_, SendBuf>(io) @@ -122,40 +129,24 @@ where let (conn_drop_ref, rx) = mpsc::channel(1); let (cancel_tx, conn_eof) = oneshot::channel(); - let conn_drop_rx = rx.into_future().map(|(item, _rx)| { - if let Some(never) = item { - match never {} - } - }); + let conn_drop_rx = rx.into_future(); let ping_config = new_ping_config(&config); let (conn, ping) = if ping_config.is_enabled() { let pp = conn.ping_pong().expect("conn.ping_pong"); - let (recorder, mut ponger) = ping::channel(pp, ping_config, timer); + let (recorder, ponger) = ping::channel(pp, ping_config, timer); - let conn = future::poll_fn(move |cx| { - match ponger.poll(cx) { - Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { - conn.set_target_window_size(wnd); - conn.set_initial_window_size(wnd)?; - } - Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { - debug!("connection keep-alive timed out"); - return Poll::Ready(Ok(())); - } - Poll::Pending => {} - } - - Pin::new(&mut conn).poll(cx) - }); + let conn: Conn<_, B> = Conn::new(ponger, conn); (Either::Left(conn), recorder) } else { (Either::Right(conn), ping::disabled()) }; - let conn = conn.map_err(|e| debug!("connection error: {}", e)); + let conn: ConnMapErr = ConnMapErr { conn }; - exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + exec.execute_h2_future(H2ClientFuture::Task { + task: ConnTask::new(conn, conn_drop_rx, cancel_tx), + }); Ok(ClientTask { ping, @@ -165,25 +156,195 @@ where h2_tx, req_rx, fut_ctx: None, + marker: PhantomData, }) } -async fn conn_task(conn: C, drop_rx: D, cancel_tx: oneshot::Sender) +pin_project! { + struct Conn + where + B: Body, + { + #[pin] + ponger: Ponger, + #[pin] + conn: Connection::Data>>, + } +} + +impl Conn +where + B: Body, + T: AsyncRead + AsyncWrite + Unpin, +{ + fn new(ponger: Ponger, conn: Connection::Data>>) -> Self { + Conn { ponger, conn } + } +} + +impl Future for Conn where - C: Future + Unpin, - D: Future + Unpin, + B: Body, + T: AsyncRead + AsyncWrite + Unpin, { - match future::select(conn, drop_rx).await { - Either::Left(_) => { - // ok or err, the `conn` has finished + type Output = Result<(), h2::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + match this.ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + this.conn.set_target_window_size(wnd); + this.conn.set_initial_window_size(wnd)?; + } + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} } - Either::Right(((), conn)) => { - // mpsc has been dropped, hopefully polling - // the connection some more should start shutdown - // and then close - trace!("send_request dropped, starting conn shutdown"); - drop(cancel_tx); - let _ = conn.await; + + Pin::new(&mut this.conn).poll(cx) + } +} + +pin_project! { + struct ConnMapErr + where + B: Body, + T: AsyncRead, + T: AsyncWrite, + T: Unpin, + { + #[pin] + conn: Either, Connection::Data>>>, + } +} + +impl Future for ConnMapErr +where + B: Body, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + self.project() + .conn + .poll(cx) + .map_err(|e| debug!("connection error: {}", e)) + } +} + +pin_project! { + pub struct ConnTask + where + B: Body, + T: AsyncRead, + T: AsyncWrite, + T: Unpin, + { + #[pin] + select: Select, StreamFuture>>, + #[pin] + cancel_tx: Option>, + conn: Option>, + } +} + +impl ConnTask +where + B: Body, + T: AsyncRead + AsyncWrite + Unpin, +{ + fn new( + conn: ConnMapErr, + drop_rx: StreamFuture>, + cancel_tx: oneshot::Sender, + ) -> Self { + Self { + select: future::select(conn, drop_rx), + cancel_tx: Some(cancel_tx), + conn: None, + } + } +} + +impl Future for ConnTask +where + B: Body, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + if let Some(conn) = this.conn { + conn.poll_unpin(cx).map(|_| ()) + } else { + match ready!(this.select.poll_unpin(cx)) { + Either::Left((_, _)) => { + // ok or err, the `conn` has finished + return Poll::Ready(()); + } + Either::Right((_, b)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(this.cancel_tx.take().expect("Future polled twice")); + this.conn = &mut Some(b); + return Poll::Pending; + } + } + } + } +} + +pin_project! { + #[project = H2ClientFutureProject] + pub enum H2ClientFuture + where + B: http_body::Body, + B: 'static, + B::Error: Into>, + T: AsyncRead, + T: AsyncWrite, + T: Unpin, + { + Pipe { + #[pin] + pipe: PipeMap, + }, + Send { + #[pin] + send_when: SendWhen, + }, + Task { + #[pin] + task: ConnTask, + }, + } +} + +impl Future for H2ClientFuture +where + B: http_body::Body + 'static, + B::Error: Into>, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = (); + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + + match this { + H2ClientFutureProject::Pipe { pipe } => pipe.poll(cx), + H2ClientFutureProject::Send { send_when } => send_when.poll(cx), + H2ClientFutureProject::Task { task } => task.poll(cx), } } } @@ -202,43 +363,89 @@ where impl Unpin for FutCtx {} -pub(crate) struct ClientTask +pub(crate) struct ClientTask where B: Body, + E: Unpin, { ping: ping::Recorder, conn_drop_ref: ConnDropRef, conn_eof: ConnEof, - executor: Exec, + executor: E, h2_tx: SendRequest>, req_rx: ClientRx, fut_ctx: Option>, + marker: PhantomData, } -impl ClientTask +impl ClientTask where B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, + T: AsyncRead + AsyncWrite + Unpin, { pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { self.h2_tx.is_extended_connect_protocol_enabled() } } -impl ClientTask +pin_project! { + pub struct PipeMap + where + S: Body, + { + #[pin] + pipe: PipeToSendStream, + #[pin] + conn_drop_ref: Option>, + #[pin] + ping: Option, + } +} + +impl Future for PipeMap where - B: Body + Send + 'static, + B: http_body::Body, + B::Error: Into>, +{ + type Output = (); + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let mut this = self.project(); + + match this.pipe.poll_unpin(cx) { + Poll::Ready(result) => { + if let Err(e) = result { + debug!("client request body error: {}", e); + } + drop(this.conn_drop_ref.take().expect("Future polled twice")); + drop(this.ping.take().expect("Future polled twice")); + return Poll::Ready(()); + } + Poll::Pending => (), + }; + Poll::Pending + } +} + +impl ClientTask +where + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + E: ExecutorClient + Unpin, + B::Error: Into>, + T: AsyncRead + AsyncWrite + Unpin, { fn poll_pipe(&mut self, f: FutCtx, cx: &mut task::Context<'_>) { let ping = self.ping.clone(); + let send_stream = if !f.is_connect { if !f.eos { - let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| { - if let Err(e) = res { - debug!("client request body error: {}", e); - } - }); + let mut pipe = PipeToSendStream::new(f.body, f.body_tx); // eagerly see if the body pipe is ready and // can thus skip allocating in the executor @@ -250,13 +457,15 @@ where // "open stream" alive while this body is // still sending... let ping = ping.clone(); - let pipe = pipe.map(move |x| { - drop(conn_drop_ref); - drop(ping); - x - }); + + let pipe = PipeMap { + pipe, + conn_drop_ref: Some(conn_drop_ref), + ping: Some(ping), + }; // Clear send task - self.executor.execute(pipe); + self.executor + .execute_h2_future(H2ClientFuture::Pipe { pipe: pipe }); } } } @@ -266,7 +475,49 @@ where Some(f.body_tx) }; - let fut = f.fut.map(move |result| match result { + self.executor.execute_h2_future(H2ClientFuture::Send { + send_when: SendWhen { + when: ResponseFutMap { + fut: f.fut, + ping: Some(ping), + send_stream: Some(send_stream), + }, + call_back: Some(f.cb), + }, + }); + } +} + +pin_project! { + pub(crate) struct ResponseFutMap + where + B: Body, + B: 'static, + { + #[pin] + fut: ResponseFuture, + #[pin] + ping: Option, + #[pin] + send_stream: Option::Data>>>>, + } +} + +impl Future for ResponseFutMap +where + B: Body + 'static, +{ + type Output = Result, (crate::Error, Option>)>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + let result = ready!(this.fut.poll(cx)); + + let ping = this.ping.take().expect("Future polled twice"); + let send_stream = this.send_stream.take().expect("Future polled twice"); + + match result { Ok(res) => { // record that we got the response headers ping.record_non_data(); @@ -277,17 +528,17 @@ where warn!("h2 connect response with non-zero body not supported"); send_stream.send_reset(h2::Reason::INTERNAL_ERROR); - return Err(( + return Poll::Ready(Err(( crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), - None, - )); + None::>, + ))); } let (parts, recv_stream) = res.into_parts(); let mut res = Response::from_parts(parts, IncomingBody::empty()); let (pending, on_upgrade) = crate::upgrade::pending(); let io = H2Upgraded { - ping, + ping: ping, send_stream: unsafe { UpgradedSendStream::new(send_stream) }, recv_stream, buf: Bytes::new(), @@ -297,31 +548,32 @@ where pending.fulfill(upgraded); res.extensions_mut().insert(on_upgrade); - Ok(res) + Poll::Ready(Ok(res)) } else { let res = res.map(|stream| { let ping = ping.for_stream(&stream); IncomingBody::h2(stream, content_length.into(), ping) }); - Ok(res) + Poll::Ready(Ok(res)) } } Err(err) => { ping.ensure_not_timed_out().map_err(|e| (e, None))?; debug!("client response error: {}", err); - Err((crate::Error::new_h2(err), None)) + Poll::Ready(Err((crate::Error::new_h2(err), None::>))) } - }); - self.executor.execute(f.cb.send_when(fut)); + } } } -impl Future for ClientTask +impl Future for ClientTask where - B: Body + Send + 'static, + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + 'static + Send + Sync + Unpin, + T: AsyncRead + AsyncWrite + Unpin, { type Output = crate::Result; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index c81c0b4665..d0e8c0c323 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -85,7 +85,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { // body adapters used by both Client and Server pin_project! { - struct PipeToSendStream + pub(crate) struct PipeToSendStream where S: Body, { diff --git a/src/rt/bounds.rs b/src/rt/bounds.rs index 69115ef2ca..6368339796 100644 --- a/src/rt/bounds.rs +++ b/src/rt/bounds.rs @@ -6,14 +6,18 @@ #[cfg(all(feature = "server", feature = "http2"))] pub use self::h2::Http2ConnExec; -#[cfg(all(feature = "server", feature = "http2"))] +#[cfg(all(feature = "client", feature = "http2"))] +pub use self::h2_client::ExecutorClient; + +#[cfg(all(feature = "client", feature = "http2"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] -mod h2 { - use crate::{common::exec::Exec, proto::h2::server::H2Stream, rt::Executor}; - use http_body::Body; - use std::future::Future; +mod h2_client { + use std::{error::Error, future::Future}; + use tokio::io::{AsyncRead, AsyncWrite}; - /// An executor to spawn http2 connections. + use crate::{proto::h2::client::H2ClientFuture, rt::Executor}; + + /// An executor to spawn http2 futures for the client. /// /// This trait is implemented for any type that implements [`Executor`] /// trait for any future. @@ -21,28 +25,64 @@ mod h2 { /// This trait is sealed and cannot be implemented for types outside this crate. /// /// [`Executor`]: crate::rt::Executor - pub trait Http2ConnExec: sealed::Sealed<(F, B)> + Clone { + pub trait ExecutorClient: sealed_client::Sealed<(B, T)> + where + B: http_body::Body, + B::Error: Into>, + T: AsyncRead + AsyncWrite + Unpin, + { #[doc(hidden)] - fn execute_h2stream(&mut self, fut: H2Stream); + fn execute_h2_future(&mut self, future: H2ClientFuture); } - impl Http2ConnExec for Exec + impl ExecutorClient for E where - H2Stream: Future + Send + 'static, - B: Body, + E: Executor>, + B: http_body::Body + 'static, + B::Error: Into>, + H2ClientFuture: Future, + T: AsyncRead + AsyncWrite + Unpin, { - fn execute_h2stream(&mut self, fut: H2Stream) { - self.execute(fut) + fn execute_h2_future(&mut self, future: H2ClientFuture) { + self.execute(future) } } - impl sealed::Sealed<(F, B)> for Exec + impl sealed_client::Sealed<(B, T)> for E where - H2Stream: Future + Send + 'static, - B: Body, + E: Executor>, + B: http_body::Body + 'static, + B::Error: Into>, + H2ClientFuture: Future, + T: AsyncRead + AsyncWrite + Unpin, { } + mod sealed_client { + pub trait Sealed {} + } +} + +#[cfg(all(feature = "server", feature = "http2"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] +mod h2 { + use crate::{proto::h2::server::H2Stream, rt::Executor}; + use http_body::Body; + use std::future::Future; + + /// An executor to spawn http2 connections. + /// + /// This trait is implemented for any type that implements [`Executor`] + /// trait for any future. + /// + /// This trait is sealed and cannot be implemented for types outside this crate. + /// + /// [`Executor`]: crate::rt::Executor + pub trait Http2ConnExec: sealed::Sealed<(F, B)> + Clone { + #[doc(hidden)] + fn execute_h2stream(&mut self, fut: H2Stream); + } + #[doc(hidden)] impl Http2ConnExec for E where