From 9c3ce5a7bfff31887d008c49ce33e6dd498f22e8 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 15 Jan 2024 16:00:41 -0500 Subject: [PATCH] wip: hyper v1 upgrade --- .github/workflows/ci.yml | 1 + Cargo.toml | 28 +-- src/async_impl/body.rs | 260 +++++++++++-------------- src/async_impl/client.rs | 41 ++-- src/async_impl/decoder.rs | 104 +++++----- src/async_impl/h3_client/connect.rs | 2 +- src/async_impl/h3_client/dns.rs | 2 +- src/async_impl/h3_client/pool.rs | 2 +- src/async_impl/response.rs | 60 ++++-- src/async_impl/upgrade.rs | 7 +- src/connect.rs | 288 ++++++++++++++++------------ src/dns/gai.rs | 4 +- src/dns/resolve.rs | 4 +- src/dns/trust_dns.rs | 2 +- src/error.rs | 7 +- src/tls.rs | 87 +++++++-- tests/support/server.rs | 42 ++-- 17 files changed, 519 insertions(+), 422 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 957186a0d..0c3318bc2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -209,6 +209,7 @@ jobs: docs: name: Docs + needs: [test] runs-on: ubuntu-latest steps: diff --git a/Cargo.toml b/Cargo.toml index c9a0521fc..52342e45f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-native-tls"] # Enables native-tls specific functionality not available by default. native-tls = ["default-tls"] -native-tls-alpn = ["native-tls", "native-tls-crate/alpn"] +native-tls-alpn = ["native-tls", "native-tls-crate/alpn", "hyper-tls/alpn"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] rustls-tls = ["rustls-tls-webpki-roots"] @@ -74,14 +74,14 @@ __tls = [] # Enables common rustls code. # Equivalent to rustls-tls-manual-roots but shorter :) -__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "rustls-pemfile"] +__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "rustls-pemfile", "rustls-pki-types"] # When enabled, disable using the cached SYS_PROXIES. __internal_proxy_sys_no_cache = [] [dependencies] base64 = "0.21" -http = "0.2" +http = "1" url = "2.2" bytes = "1.0" serde = "1.0" @@ -100,9 +100,11 @@ mime_guess = { version = "2.0", default-features = false, optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] encoding_rs = "0.8" -http-body = "0.4.0" -hyper = { version = "0.14.21", default-features = false, features = ["tcp", "http1", "http2", "client", "runtime"] } -h2 = "0.3.14" +http-body = "1" +http-body-util = "0.1" +hyper = { version = "1", features = ["http1", "http2", "client"] } +hyper-util = { version = "0.1", features = ["http1", "http2", "client", "client-legacy", "tokio"] } +h2 = "0.4" once_cell = "1" log = "0.4" mime = "0.3.16" @@ -114,15 +116,16 @@ ipnet = "2.3" # Optional deps... ## default-tls -hyper-tls = { version = "0.5", optional = true } +hyper-tls = { version = "0.6", optional = true } native-tls-crate = { version = "0.2.10", optional = true, package = "native-tls" } tokio-native-tls = { version = "0.3.0", optional = true } # rustls-tls -hyper-rustls = { version = "0.24.0", default-features = false, optional = true } -rustls = { version = "0.21.6", features = ["dangerous_configuration"], optional = true } -tokio-rustls = { version = "0.24", optional = true } -webpki-roots = { version = "0.25", optional = true } +hyper-rustls = { version = "0.26.0", default-features = false, optional = true } +rustls = { version = "0.22.2", optional = true } +rustls-pki-types = { version = "1.1.0", features = ["alloc"] ,optional = true } +tokio-rustls = { version = "0.25", optional = true } +webpki-roots = { version = "0.26.0", optional = true } rustls-native-certs = { version = "0.6", optional = true } rustls-pemfile = { version = "1.0", optional = true } @@ -149,7 +152,8 @@ futures-channel = { version="0.3", optional = true} [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" -hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server", "runtime"] } +hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } +hyper-util = { version = "0.1", features = ["http1", "http2", "client", "client-legacy", "server-auto", "tokio"] } serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index 0d0357cb6..1d2a7feff 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -4,10 +4,9 @@ use std::pin::Pin; use std::task::{Context, Poll}; use bytes::Bytes; -use futures_core::Stream; use http_body::Body as HttpBody; -use pin_project_lite::pin_project; -use sync_wrapper::SyncWrapper; +use http_body_util::combinators::BoxBody; +//use sync_wrapper::SyncWrapper; #[cfg(feature = "stream")] use tokio::fs::File; use tokio::time::Sleep; @@ -19,31 +18,22 @@ pub struct Body { inner: Inner, } -// The `Stream` trait isn't stable, so the impl isn't public. -pub(crate) struct ImplStream(Body); - enum Inner { Reusable(Bytes), - Streaming { - body: Pin< - Box< - dyn HttpBody> - + Send - + Sync, - >, - >, - timeout: Option>>, - }, + Streaming(BoxBody>), } -pin_project! { - struct WrapStream { - #[pin] - inner: SyncWrapper, - } +/// A body with a total timeout. +/// +/// The timeout does not reset upon each chunk, but rather requires the whole +/// body be streamed before the deadline is reached. +pub(crate) struct TotalTimeoutBody { + inner: B, + timeout: Pin>, } -struct WrapHyper(hyper::Body); +/// Converts any `impl Body` into a `impl Stream` of just its DATA frames. +pub(crate) struct DataStream(pub(crate) B); impl Body { /// Returns a reference to the internal data of the `Body`. @@ -52,7 +42,7 @@ impl Body { pub fn as_bytes(&self) -> Option<&[u8]> { match &self.inner { Inner::Reusable(bytes) => Some(bytes.as_ref()), - Inner::Streaming { .. } => None, + Inner::Streaming(..) => None, } } @@ -90,43 +80,37 @@ impl Body { Body::stream(stream) } + #[cfg(any(feature = "stream", feature = "multipart"))] pub(crate) fn stream(stream: S) -> Body where - S: futures_core::stream::TryStream + Send + 'static, + S: futures_core::stream::TryStream + Send + Sync + 'static, S::Error: Into>, Bytes: From, { use futures_util::TryStreamExt; - - let body = Box::pin(WrapStream { - inner: SyncWrapper::new(stream.map_ok(Bytes::from).map_err(Into::into)), - }); + use http_body::Frame; + use http_body_util::StreamBody; + + let body = http_body_util::BodyExt::boxed(StreamBody::new( + stream + .map_ok(|d| Frame::data(Bytes::from(d))) + .map_err(Into::into), + )); Body { - inner: Inner::Streaming { - body, - timeout: None, - }, - } - } - - pub(crate) fn response(body: hyper::Body, timeout: Option>>) -> Body { - Body { - inner: Inner::Streaming { - body: Box::pin(WrapHyper(body)), - timeout, - }, + inner: Inner::Streaming(body), } } + /* #[cfg(feature = "blocking")] pub(crate) fn wrap(body: hyper::Body) -> Body { Body { inner: Inner::Streaming { body: Box::pin(WrapHyper(body)), - timeout: None, }, } } + */ pub(crate) fn empty() -> Body { Body::reusable(Bytes::new()) @@ -138,6 +122,25 @@ impl Body { } } + // pub? + pub(crate) fn streaming(inner: B) -> Body + where + B: HttpBody + Send + Sync + 'static, + B::Data: Into, + B::Error: Into>, + { + use http_body_util::BodyExt; + + let boxed = inner + .map_frame(|f| f.map_data(Into::into)) + .map_err(Into::into) + .boxed(); + + Body { + inner: Inner::Streaming(boxed), + } + } + pub(crate) fn try_reuse(self) -> (Option, Self) { let reuse = match self.inner { Inner::Reusable(ref chunk) => Some(chunk.clone()), @@ -154,30 +157,32 @@ impl Body { } } - pub(crate) fn into_stream(self) -> ImplStream { - ImplStream(self) + #[cfg(feature = "multipart")] + pub(crate) fn into_stream(self) -> DataStream { + DataStream(self) } #[cfg(feature = "multipart")] pub(crate) fn content_length(&self) -> Option { match self.inner { Inner::Reusable(ref bytes) => Some(bytes.len() as u64), - Inner::Streaming { ref body, .. } => body.size_hint().exact(), + Inner::Streaming(ref body) => body.size_hint().exact(), } } } +/* impl From for Body { #[inline] fn from(body: hyper::Body) -> Body { Self { inner: Inner::Streaming { body: Box::pin(WrapHyper(body)), - timeout: None, }, } } } +*/ impl From for Body { #[inline] @@ -229,132 +234,101 @@ impl fmt::Debug for Body { } } -// ===== impl ImplStream ===== - -impl HttpBody for ImplStream { +impl HttpBody for Body { type Data = Bytes; type Error = crate::Error; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context, - ) -> Poll>> { - let opt_try_chunk = match self.0.inner { - Inner::Streaming { - ref mut body, - ref mut timeout, - } => { - if let Some(ref mut timeout) = timeout { - if let Poll::Ready(()) = timeout.as_mut().poll(cx) { - return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); - } - } - futures_core::ready!(Pin::new(body).poll_data(cx)) - .map(|opt_chunk| opt_chunk.map(Into::into).map_err(crate::error::body)) - } + ) -> Poll, Self::Error>>> { + match self.inner { Inner::Reusable(ref mut bytes) => { - if bytes.is_empty() { - None + let out = bytes.split_off(0); + if out.is_empty() { + Poll::Ready(None) } else { - Some(Ok(std::mem::replace(bytes, Bytes::new()))) + Poll::Ready(Some(Ok(hyper::body::Frame::data(out)))) } } - }; - - Poll::Ready(opt_try_chunk) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } - - fn is_end_stream(&self) -> bool { - match self.0.inner { - Inner::Streaming { ref body, .. } => body.is_end_stream(), - Inner::Reusable(ref bytes) => bytes.is_empty(), - } - } - - fn size_hint(&self) -> http_body::SizeHint { - match self.0.inner { - Inner::Streaming { ref body, .. } => body.size_hint(), - Inner::Reusable(ref bytes) => { - let mut hint = http_body::SizeHint::default(); - hint.set_exact(bytes.len() as u64); - hint - } + Inner::Streaming(ref mut body) => Poll::Ready( + futures_core::ready!(Pin::new(body).poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(crate::error::body)), + ), } } } -impl Stream for ImplStream { - type Item = Result; +// ===== impl TotalTimeoutBody ===== - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.poll_data(cx) +pub(crate) fn total_timeout(body: B, timeout: Pin>) -> TotalTimeoutBody { + TotalTimeoutBody { + inner: body, + timeout, } } -// ===== impl WrapStream ===== - -impl HttpBody for WrapStream +impl hyper::body::Body for TotalTimeoutBody where - S: Stream>, - D: Into, - E: Into>, + B: hyper::body::Body + Unpin, + B::Error: Into>, { - type Data = Bytes; - type Error = E; + type Data = B::Data; + type Error = crate::Error; - fn poll_data( - self: Pin<&mut Self>, + fn poll_frame( + mut self: Pin<&mut Self>, cx: &mut Context, - ) -> Poll>> { - let item = futures_core::ready!(self.project().inner.get_pin_mut().poll_next(cx)?); - - Poll::Ready(item.map(|val| Ok(val.into()))) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + ) -> Poll, Self::Error>>> { + if let Poll::Ready(()) = self.timeout.as_mut().poll(cx) { + return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); + } + Poll::Ready( + futures_core::ready!(Pin::new(&mut self.inner).poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(crate::error::body)), + ) } } -// ===== impl WrapHyper ===== - -impl HttpBody for WrapHyper { - type Data = Bytes; - type Error = Box; +pub(crate) type ResponseBody = + http_body_util::combinators::BoxBody>; - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll>> { - // safe pin projection - Pin::new(&mut self.0) - .poll_data(cx) - .map(|opt| opt.map(|res| res.map_err(Into::into))) - } +pub(crate) fn response( + body: hyper::body::Incoming, + timeout: Option>>, +) -> ResponseBody { + use http_body_util::BodyExt; - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + if let Some(timeout) = timeout { + total_timeout(body, timeout).map_err(Into::into).boxed() + } else { + body.map_err(Into::into).boxed() } +} - fn is_end_stream(&self) -> bool { - self.0.is_end_stream() - } +// ===== impl DataStream ===== - fn size_hint(&self) -> http_body::SizeHint { - HttpBody::size_hint(&self.0) +impl futures_core::Stream for DataStream +where + B: HttpBody + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + return match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) { + Some(Ok(frame)) => { + // skip non-data frames + if let Ok(buf) = frame.into_data() { + Poll::Ready(Some(Ok(buf))) + } else { + continue; + } + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + }; + } } } diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index f082d28bc..a2716c0c8 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -13,7 +13,7 @@ use http::header::{ }; use http::uri::Scheme; use http::Uri; -use hyper::client::{HttpConnector, ResponseFuture as HyperResponseFuture}; +use hyper_util::client::legacy::connect::HttpConnector; #[cfg(feature = "native-tls-crate")] use native_tls_crate::TlsConnector; use pin_project_lite::pin_project; @@ -52,6 +52,8 @@ use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; + /// An asynchronous `Client` to make Requests with. /// /// The Client has various configuration values to tweak, but the defaults @@ -464,18 +466,7 @@ impl ClientBuilder { #[cfg(feature = "rustls-tls-webpki-roots")] if config.tls_built_in_root_certs { - use rustls::OwnedTrustAnchor; - - let trust_anchors = - webpki_roots::TLS_SERVER_ROOTS.iter().map(|trust_anchor| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - trust_anchor.subject, - trust_anchor.spki, - trust_anchor.name_constraints, - ) - }); - - root_cert_store.add_trust_anchors(trust_anchors); + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); } #[cfg(feature = "rustls-tls-native-roots")] @@ -485,11 +476,10 @@ impl ClientBuilder { for cert in rustls_native_certs::load_native_certs() .map_err(crate::error::builder)? { - let cert = rustls::Certificate(cert.0); // Continue on parsing errors, as native stores often include ancient or syntactically // invalid certificates, like root certificates without any X509 extensions. // Inspiration: https://github.com/rustls/rustls/blob/633bf4ba9d9521a95f68766d04c22e2b01e68318/rustls/src/anchors.rs#L105-L112 - match root_cert_store.add(&cert) { + match root_cert_store.add(cert.into()) { Ok(_) => valid_count += 1, Err(err) => { invalid_count += 1; @@ -532,12 +522,8 @@ impl ClientBuilder { } // Build TLS config - let config_builder = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&versions) - .map_err(crate::error::builder)? - .with_root_certificates(root_cert_store); + let config_builder = + rustls::ClientConfig::builder().with_root_certificates(root_cert_store); // Finalize TLS config let mut tls = if let Some(id) = config.identity { @@ -612,7 +598,8 @@ impl ClientBuilder { connector.set_timeout(config.connect_timeout); connector.set_verbose(config.connection_verbose); - let mut builder = hyper::Client::builder(); + let mut builder = + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); if matches!(config.http_version_pref, HttpVersionPref::Http2) { builder.http2_only(true); } @@ -1651,7 +1638,7 @@ impl ClientBuilder { } } -type HyperClient = hyper::Client; +type HyperClient = hyper_util::client::legacy::Client; impl Default for Client { fn default() -> Self { @@ -1828,9 +1815,7 @@ impl Client { ResponseFuture::H3(self.inner.h3_client.as_ref().unwrap().request(req)) } _ => { - let mut req = builder - .body(body.into_stream()) - .expect("valid request parts"); + let mut req = builder.body(body).expect("valid request parts"); *req.headers_mut() = headers.clone(); ResponseFuture::Default(self.inner.hyper.request(req)) } @@ -2194,7 +2179,7 @@ impl PendingRequest { let mut req = hyper::Request::builder() .method(self.method.clone()) .uri(uri) - .body(body.into_stream()) + .body(body) .expect("valid request parts"); *req.headers_mut() = self.headers.clone(); ResponseFuture::Default(self.client.hyper.request(req)) @@ -2430,7 +2415,7 @@ impl Future for PendingRequest { let mut req = hyper::Request::builder() .method(self.method.clone()) .uri(uri.clone()) - .body(body.into_stream()) + .body(body) .expect("valid request parts"); *req.headers_mut() = headers.clone(); std::mem::swap(self.as_mut().headers(), &mut headers); diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index c0542cfb1..92b6181cf 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -16,14 +16,15 @@ use bytes::Bytes; use futures_core::Stream; use futures_util::stream::Peekable; use http::HeaderMap; -use hyper::body::HttpBody; +use hyper::body::Body as HttpBody; +use hyper::body::Frame; #[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] use tokio_util::codec::{BytesCodec, FramedRead}; #[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] use tokio_util::io::StreamReader; -use super::super::Body; +use super::body::ResponseBody; use crate::error; #[derive(Clone, Copy, Debug)] @@ -50,7 +51,7 @@ type PeekableIoStreamReader = StreamReader; enum Inner { /// A `PlainText` decoder just returns the response content as is. - PlainText(super::body::ImplStream), + PlainText(ResponseBody), /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] @@ -72,7 +73,7 @@ enum Inner { /// A future attempt to poll the response body for EOF so we know whether to use gzip or not. struct Pending(PeekableIoStream, DecoderType); -struct IoStream(super::body::ImplStream); +struct IoStream(ResponseBody); enum DecoderType { #[cfg(feature = "gzip")] @@ -100,9 +101,9 @@ impl Decoder { /// A plain text decoder. /// /// This decoder will emit the underlying chunks as-is. - fn plain_text(body: Body) -> Decoder { + fn plain_text(body: ResponseBody) -> Decoder { Decoder { - inner: Inner::PlainText(body.into_stream()), + inner: Inner::PlainText(body), } } @@ -110,12 +111,12 @@ impl Decoder { /// /// This decoder will buffer and decompress chunks that are gzipped. #[cfg(feature = "gzip")] - fn gzip(body: Body) -> Decoder { + fn gzip(body: ResponseBody) -> Decoder { use futures_util::StreamExt; Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Gzip, ))), } @@ -130,7 +131,7 @@ impl Decoder { Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Brotli, ))), } @@ -145,7 +146,7 @@ impl Decoder { Decoder { inner: Inner::Pending(Box::pin(Pending( - IoStream(body.into_stream()).peekable(), + IoStream(body).peekable(), DecoderType::Deflate, ))), } @@ -187,7 +188,11 @@ impl Decoder { /// how to decode the content body of the request. /// /// Uses the correct variant by inspecting the Content-Encoding header. - pub(super) fn detect(_headers: &mut HeaderMap, body: Body, _accepts: Accepts) -> Decoder { + pub(super) fn detect( + _headers: &mut HeaderMap, + body: ResponseBody, + _accepts: Accepts, + ) -> Decoder { #[cfg(feature = "gzip")] { if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") { @@ -213,26 +218,35 @@ impl Decoder { } } -impl Stream for Decoder { - type Item = Result; +impl HttpBody for Decoder { + type Data = Bytes; + type Error = crate::Error; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // Do a read or poll for a pending decoder value. + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { match self.inner { #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) { Poll::Ready(Ok(inner)) => { self.inner = inner; - self.poll_next(cx) + self.poll_frame(cx) } Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))), Poll::Pending => Poll::Pending, }, - Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx), + Inner::PlainText(ref mut body) => { + match futures_core::ready!(Pin::new(body).poll_frame(cx)) { + Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))), + None => Poll::Ready(None), + } + } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } @@ -240,7 +254,7 @@ impl Stream for Decoder { #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } @@ -248,32 +262,13 @@ impl Stream for Decoder { #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), } } } } -} - -impl HttpBody for Decoder { - type Data = Bytes; - type Error = crate::Error; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll>> { - self.poll_next(cx) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } fn size_hint(&self) -> http_body::SizeHint { match self.inner { @@ -285,6 +280,11 @@ impl HttpBody for Decoder { } } +fn empty() -> ResponseBody { + use http_body_util::{combinators::BoxBody, BodyExt, Empty}; + BoxBody::new(Empty::new().map_err(|never| match never {})) +} + impl Future for Pending { type Output = Result; @@ -303,13 +303,10 @@ impl Future for Pending { .expect("just peeked Some") .unwrap_err())); } - None => return Poll::Ready(Ok(Inner::PlainText(Body::empty().into_stream()))), + None => return Poll::Ready(Ok(Inner::PlainText(empty()))), }; - let _body = std::mem::replace( - &mut self.0, - IoStream(Body::empty().into_stream()).peekable(), - ); + let _body = std::mem::replace(&mut self.0, IoStream(empty()).peekable()); match self.1 { #[cfg(feature = "brotli")] @@ -335,10 +332,19 @@ impl Stream for IoStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match futures_core::ready!(Pin::new(&mut self.0).poll_next(cx)) { - Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk))), - Some(Err(err)) => Poll::Ready(Some(Err(err.into_io()))), - None => Poll::Ready(None), + loop { + return match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) { + Some(Ok(frame)) => { + // skip non-data frames + if let Ok(buf) = frame.into_data() { + Poll::Ready(Some(Ok(buf))) + } else { + continue; + } + } + Some(Err(err)) => Poll::Ready(Some(Err(error::into_io(err)))), + None => Poll::Ready(None), + }; } } } @@ -346,6 +352,7 @@ impl Stream for IoStream { // ===== impl Accepts ===== impl Accepts { + /* pub(super) fn none() -> Self { Accepts { #[cfg(feature = "gzip")] @@ -356,6 +363,7 @@ impl Accepts { deflate: false, } } + */ pub(super) fn as_str(&self) -> Option<&'static str> { match (self.is_gzip(), self.is_brotli(), self.is_deflate()) { diff --git a/src/async_impl/h3_client/connect.rs b/src/async_impl/h3_client/connect.rs index 968704713..ec732f66a 100644 --- a/src/async_impl/h3_client/connect.rs +++ b/src/async_impl/h3_client/connect.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use h3::client::SendRequest; use h3_quinn::{Connection, OpenStreams}; use http::Uri; -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use quinn::{ClientConfig, Endpoint, TransportConfig}; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; diff --git a/src/async_impl/h3_client/dns.rs b/src/async_impl/h3_client/dns.rs index 9cb50d1e3..bd59daaed 100644 --- a/src/async_impl/h3_client/dns.rs +++ b/src/async_impl/h3_client/dns.rs @@ -1,5 +1,5 @@ use core::task; -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use std::future::Future; use std::net::SocketAddr; use std::task::Poll; diff --git a/src/async_impl/h3_client/pool.rs b/src/async_impl/h3_client/pool.rs index 6fcb8e719..6d3f047b6 100644 --- a/src/async_impl/h3_client/pool.rs +++ b/src/async_impl/h3_client/pool.rs @@ -13,7 +13,7 @@ use h3::client::SendRequest; use h3_quinn::{Connection, OpenStreams}; use http::uri::{Authority, Scheme}; use http::{Request, Response, Uri}; -use hyper::Body as HyperBody; +use hyper::body as HyperBody; use log::trace; pub(super) type Key = (Scheme, Authority); diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index fc5a5d464..e8e1bde7a 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -4,9 +4,8 @@ use std::pin::Pin; use bytes::Bytes; use encoding_rs::{Encoding, UTF_8}; -use futures_util::stream::StreamExt; -use hyper::client::connect::HttpInfo; use hyper::{HeaderMap, StatusCode, Version}; +use hyper_util::client::legacy::connect::HttpInfo; use mime::Mime; #[cfg(feature = "json")] use serde::de::DeserializeOwned; @@ -19,7 +18,6 @@ use super::body::Body; use super::decoder::{Accepts, Decoder}; #[cfg(feature = "cookies")] use crate::cookie; -use crate::response::ResponseUrl; /// A Response to a submitted `Request`. pub struct Response { @@ -31,13 +29,17 @@ pub struct Response { impl Response { pub(super) fn new( - res: hyper::Response, + res: hyper::Response, url: Url, accepts: Accepts, timeout: Option>>, ) -> Response { let (mut parts, body) = res.into_parts(); - let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts); + let decoder = Decoder::detect( + &mut parts.headers, + super::body::response(body, timeout), + accepts, + ); let res = hyper::Response::from_parts(parts, decoder); Response { @@ -78,9 +80,9 @@ impl Response { /// - The response is compressed and automatically decoded (thus changing /// the actual decoded length). pub fn content_length(&self) -> Option { - use hyper::body::HttpBody; + use hyper::body::Body; - HttpBody::size_hint(self.res.body()).exact() + Body::size_hint(self.res.body()).exact() } /// Retrieve the cookies contained in the response. @@ -256,7 +258,11 @@ impl Response { /// # } /// ``` pub async fn bytes(self) -> crate::Result { - hyper::body::to_bytes(self.res.into_body()).await + use http_body_util::BodyExt; + + BodyExt::collect(self.res.into_body()) + .await + .map(|buf| buf.to_bytes()) } /// Stream a chunk of the response body. @@ -276,10 +282,19 @@ impl Response { /// # } /// ``` pub async fn chunk(&mut self) -> crate::Result> { - if let Some(item) = self.res.body_mut().next().await { - Ok(Some(item?)) - } else { - Ok(None) + use http_body_util::BodyExt; + + // loop to ignore unrecognized frames + loop { + if let Some(res) = self.res.body_mut().frame().await { + let frame = res?; + if let Ok(buf) = frame.into_data() { + return Ok(Some(buf)); + } + // else continue + } else { + return Ok(None); + } } } @@ -308,7 +323,7 @@ impl Response { #[cfg(feature = "stream")] #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] pub fn bytes_stream(self) -> impl futures_core::Stream> { - self.res.into_body() + super::body::DataStream(self.res.into_body()) } // util methods @@ -396,8 +411,20 @@ impl fmt::Debug for Response { } } +/// A `Response` can be piped as the `Body` of another request. +impl From for Body { + fn from(r: Response) -> Body { + Body::streaming(r.res.into_body()) + } +} + +/* +// I'm not sure this conversion is that useful... People should be encouraged +// to use `http::Resposne`, not `reqwest::Response`. impl> From> for Response { fn from(r: http::Response) -> Response { + use crate::response::ResponseUrl; + let (mut parts, body) = r.into_parts(); let body = body.into(); let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none()); @@ -414,12 +441,6 @@ impl> From> for Response { } } -/// A `Response` can be piped as the `Body` of another request. -impl From for Body { - fn from(r: Response) -> Body { - Body::stream(r.res.into_body()) - } -} #[cfg(test)] mod tests { @@ -442,3 +463,4 @@ mod tests { assert_eq!(*response.url(), url); } } +*/ diff --git a/src/async_impl/upgrade.rs b/src/async_impl/upgrade.rs index 4a69b4db5..3b599d0ad 100644 --- a/src/async_impl/upgrade.rs +++ b/src/async_impl/upgrade.rs @@ -3,11 +3,12 @@ use std::task::{self, Poll}; use std::{fmt, io}; use futures_util::TryFutureExt; +use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// An upgraded HTTP connection. pub struct Upgraded { - inner: hyper::upgrade::Upgraded, + inner: TokioIo, } impl AsyncRead for Upgraded { @@ -58,7 +59,9 @@ impl fmt::Debug for Upgraded { impl From for Upgraded { fn from(inner: hyper::upgrade::Upgraded) -> Self { - Upgraded { inner } + Upgraded { + inner: TokioIo::new(inner), + } } } diff --git a/src/connect.rs b/src/connect.rs index 2fdcd56c0..d9a9f44e7 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -2,11 +2,13 @@ use http::header::HeaderValue; use http::uri::{Authority, Scheme}; use http::Uri; -use hyper::client::connect::{Connected, Connection}; -use hyper::service::Service; +use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; +#[cfg(feature = "__tls")] +use hyper_util::rt::TokioIo; #[cfg(feature = "native-tls-crate")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tower_service::Service; use pin_project_lite::pin_project; use std::future::Future; @@ -25,7 +27,7 @@ use crate::dns::DynResolver; use crate::error::BoxError; use crate::proxy::{Proxy, ProxyScheme}; -pub(crate) type HttpConnector = hyper::client::HttpConnector; +pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] pub(crate) struct Connector { @@ -196,7 +198,7 @@ impl Connector { let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector.connect(&host, conn).await?; return Ok(Conn { - inner: self.verbose.wrap(NativeTlsConn { inner: io }), + inner: self.verbose.wrap(io /*NativeTlsConn { inner: io }*/), is_proxy: false, tls_info: self.tls_info, }); @@ -211,8 +213,9 @@ impl Connector { let tls = tls_proxy.clone(); let host = dst.host().ok_or("no host in url")?.to_string(); let conn = socks::connect(proxy, dst, dns).await?; - let server_name = rustls::ServerName::try_from(host.as_str()) - .map_err(|_| "Invalid Server Name")?; + let server_name = + rustls_pki_types::ServerName::try_from(host.as_str().to_owned()) + .map_err(|_| "Invalid Server Name")?; let io = RustlsConnector::from(tls) .connect(server_name, conn) .await?; @@ -262,7 +265,14 @@ impl Connector { if let hyper_tls::MaybeHttpsStream::Https(stream) = io { if !self.nodelay { - stream.get_ref().get_ref().get_ref().set_nodelay(false)?; + stream + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .inner() + .set_nodelay(false)?; } Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: stream }), @@ -293,8 +303,8 @@ impl Connector { if let hyper_rustls::MaybeHttpsStream::Https(stream) = io { if !self.nodelay { - let (io, _) = stream.get_ref(); - io.set_nodelay(false)?; + let (io, _) = stream.inner().get_ref(); + io.inner().inner().set_nodelay(false)?; } Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), @@ -350,10 +360,12 @@ impl Connector { .await?; let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector - .connect(host.ok_or("no host in url")?, tunneled) + .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled)) .await?; return Ok(Conn { - inner: self.verbose.wrap(NativeTlsConn { inner: io }), + inner: self.verbose.wrap(NativeTlsConn { + inner: TokioIo::new(io), + }), is_proxy: false, tls_info: false, }); @@ -366,7 +378,7 @@ impl Connector { tls_proxy, } => { if dst.scheme() == Some(&Scheme::HTTPS) { - use rustls::ServerName; + use rustls_pki_types::ServerName; use std::convert::TryFrom; use tokio_rustls::TlsConnector as RustlsConnector; @@ -377,16 +389,18 @@ impl Connector { let tls = tls.clone(); let conn = http.call(proxy_dst).await?; log::trace!("tunneling HTTPS over proxy"); - let maybe_server_name = - ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name"); + let maybe_server_name = ServerName::try_from(host.as_str().to_owned()) + .map_err(|_| "Invalid Server Name"); let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; let server_name = maybe_server_name?; let io = RustlsConnector::from(tls) - .connect(server_name, tunneled) + .connect(server_name, TokioIo::new(tunneled)) .await?; return Ok(Conn { - inner: self.verbose.wrap(RustlsTlsConn { inner: io }), + inner: self.verbose.wrap(RustlsTlsConn { + inner: TokioIo::new(io), + }), is_proxy: false, tls_info: false, }); @@ -476,18 +490,15 @@ impl TlsInfoFactory for tokio::net::TcpStream { } } -#[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { +#[cfg(feature = "__tls")] +impl TlsInfoFactory for TokioIo { fn tls_info(&self) -> Option { - match self { - hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), - hyper_tls::MaybeHttpsStream::Http(_) => None, - } + self.inner().tls_info() } } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::TlsStream> { +impl TlsInfoFactory for tokio_native_tls::TlsStream>> { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -500,7 +511,11 @@ impl TlsInfoFactory for hyper_tls::TlsStream { +impl TlsInfoFactory + for tokio_native_tls::TlsStream< + TokioIo>>, + > +{ fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -512,32 +527,35 @@ impl TlsInfoFactory for tokio_native_tls::TlsStream { } } -#[cfg(feature = "__rustls")] -impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { +#[cfg(feature = "default-tls")] +impl TlsInfoFactory for hyper_tls::MaybeHttpsStream> { fn tls_info(&self) -> Option { match self { - hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), - hyper_rustls::MaybeHttpsStream::Http(_) => None, + hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_tls::MaybeHttpsStream::Http(_) => None, } } } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::TlsStream { +impl TlsInfoFactory for tokio_rustls::client::TlsStream>> { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() .1 .peer_certificates() .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); + .map(|c| c.first()) + .and_then(|c| c.map(|cc| vec![*cc])); Some(crate::tls::TlsInfo { peer_certificate }) } } #[cfg(feature = "__rustls")] impl TlsInfoFactory - for tokio_rustls::client::TlsStream> + for tokio_rustls::client::TlsStream< + TokioIo>>, + > { fn tls_info(&self) -> Option { let peer_certificate = self @@ -545,30 +563,28 @@ impl TlsInfoFactory .1 .peer_certificates() .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); + .map(|c| c.first()) + .and_then(|c| c.map(|cc| vec![*cc])); Some(crate::tls::TlsInfo { peer_certificate }) } } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::client::TlsStream { +impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream> { fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .1 - .peer_certificates() - .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); - Some(crate::tls::TlsInfo { peer_certificate }) + match self { + hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_rustls::MaybeHttpsStream::Http(_) => None, + } } } pub(crate) trait AsyncConn: - AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static + Read + Write + Connection + Send + Sync + Unpin + 'static { } -impl AsyncConn for T {} +impl AsyncConn for T {} #[cfg(feature = "__tls")] trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {} @@ -614,25 +630,25 @@ impl Connection for Conn { } } -impl AsyncRead for Conn { +impl Read for Conn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } -impl AsyncWrite for Conn { +impl Write for Conn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -641,7 +657,7 @@ impl AsyncWrite for Conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -650,12 +666,12 @@ impl AsyncWrite for Conn { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) + Write::poll_shutdown(this.inner, cx) } } @@ -670,8 +686,9 @@ async fn tunnel( auth: Option, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { + use hyper_util::rt::TokioIo; use tokio::io::{AsyncReadExt, AsyncWriteExt}; let mut buf = format!( @@ -701,13 +718,15 @@ where // headers end buf.extend_from_slice(b"\r\n"); - conn.write_all(&buf).await?; + let mut tokio_conn = TokioIo::new(&mut conn); + + tokio_conn.write_all(&buf).await?; let mut buf = [0; 8192]; let mut pos = 0; loop { - let n = conn.read(&mut buf[pos..]).await?; + let n = tokio_conn.read(&mut buf[pos..]).await?; if n == 0 { return Err(tunnel_eof()); @@ -739,62 +758,69 @@ fn tunnel_eof() -> BoxError { #[cfg(feature = "default-tls")] mod native_tls_conn { use super::TlsInfoFactory; - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_tls::MaybeHttpsStream; + use hyper_util::client::legacy::connect::{Connected, Connection}; + use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use std::{ io::{self, IoSlice}, pin::Pin, task::{Context, Poll}, }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; use tokio_native_tls::TlsStream; pin_project! { pub(super) struct NativeTlsConn { - #[pin] pub(super) inner: TlsStream, + #[pin] pub(super) inner: TokioIo>, } } - impl Connection for NativeTlsConn { - #[cfg(feature = "native-tls-alpn")] + impl Connection for NativeTlsConn>> { fn connected(&self) -> Connected { - match self.inner.get_ref().negotiated_alpn().ok() { - Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => self - .inner - .get_ref() - .get_ref() - .get_ref() - .connected() - .negotiated_h2(), - _ => self.inner.get_ref().get_ref().get_ref().connected(), - } + self.inner + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .connected() } + } - #[cfg(not(feature = "native-tls-alpn"))] + impl Connection for NativeTlsConn>>> { fn connected(&self) -> Connected { - self.inner.get_ref().get_ref().get_ref().connected() + self.inner + .inner() + .get_ref() + .get_ref() + .get_ref() + .inner() + .connected() } } - impl AsyncRead for NativeTlsConn { + impl Read for NativeTlsConn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } - impl AsyncWrite for NativeTlsConn { + impl Write for NativeTlsConn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -803,7 +829,7 @@ mod native_tls_conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -815,7 +841,7 @@ mod native_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown( @@ -823,17 +849,14 @@ mod native_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) + Write::poll_shutdown(this.inner, cx) } } - impl TlsInfoFactory for NativeTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for NativeTlsConn> { + impl TlsInfoFactory for NativeTlsConn + where + TokioIo>: TlsInfoFactory, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -843,51 +866,76 @@ mod native_tls_conn { #[cfg(feature = "__rustls")] mod rustls_tls_conn { use super::TlsInfoFactory; - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_rustls::MaybeHttpsStream; + use hyper_util::client::legacy::connect::{Connected, Connection}; + use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use std::{ io::{self, IoSlice}, pin::Pin, task::{Context, Poll}, }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; pin_project! { pub(super) struct RustlsTlsConn { - #[pin] pub(super) inner: TlsStream, + #[pin] pub(super) inner: TokioIo>, } } - impl Connection for RustlsTlsConn { + impl Connection for RustlsTlsConn>> { fn connected(&self) -> Connected { - if self.inner.get_ref().1.alpn_protocol() == Some(b"h2") { - self.inner.get_ref().0.connected().negotiated_h2() + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() } else { - self.inner.get_ref().0.connected() + self.inner.inner().get_ref().0.inner().connected() + } + } + } + impl Connection for RustlsTlsConn>>> { + fn connected(&self) -> Connected { + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() + } else { + self.inner.inner().get_ref().0.inner().connected() } } } - impl AsyncRead for RustlsTlsConn { + impl Read for RustlsTlsConn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { let this = self.project(); - AsyncRead::poll_read(this.inner, cx, buf) + Read::poll_read(this.inner, cx, buf) } } - impl AsyncWrite for RustlsTlsConn { + impl Write for RustlsTlsConn { fn poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write(this.inner, cx, buf) + Write::poll_write(this.inner, cx, buf) } fn poll_write_vectored( @@ -896,7 +944,7 @@ mod rustls_tls_conn { bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.project(); - AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + Write::poll_write_vectored(this.inner, cx, bufs) } fn is_write_vectored(&self) -> bool { @@ -908,7 +956,7 @@ mod rustls_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_flush(this.inner, cx) + Write::poll_flush(this.inner, cx) } fn poll_shutdown( @@ -916,17 +964,13 @@ mod rustls_tls_conn { cx: &mut Context, ) -> Poll> { let this = self.project(); - AsyncWrite::poll_shutdown(this.inner, cx) - } - } - - impl TlsInfoFactory for RustlsTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() + Write::poll_shutdown(this.inner, cx) } } - - impl TlsInfoFactory for RustlsTlsConn> { + impl TlsInfoFactory for RustlsTlsConn + where + TokioIo>: TlsInfoFactory, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -999,13 +1043,13 @@ mod socks { } mod verbose { - use hyper::client::connect::{Connected, Connection}; + use hyper::rt::{Read, ReadBufCursor, Write}; + use hyper_util::client::legacy::connect::{Connected, Connection}; use std::cmp::min; use std::fmt; use std::io::{self, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub(super) const OFF: Wrapper = Wrapper(false); @@ -1031,22 +1075,25 @@ mod verbose { inner: T, } - impl Connection for Verbose { + impl Connection for Verbose { fn connected(&self) -> Connected { self.inner.connected() } } - impl AsyncRead for Verbose { + impl Read for Verbose { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { match Pin::new(&mut self.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { + /* log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled())); Poll::Ready(Ok(())) + */ + todo!("verbose poll_read"); } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -1054,7 +1101,7 @@ mod verbose { } } - impl AsyncWrite for Verbose { + impl Write for Verbose { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, @@ -1170,6 +1217,7 @@ mod verbose { mod tests { use super::tunnel; use crate::proxy; + use hyper_util::rt::TokioIo; use std::io::{Read, Write}; use std::net::TcpListener; use std::thread; @@ -1232,7 +1280,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1250,7 +1298,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1268,7 +1316,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1292,7 +1340,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel(tcp, host, port, ua(), None).await @@ -1314,7 +1362,7 @@ mod tests { .build() .expect("new rt"); let f = async move { - let tcp = TcpStream::connect(&addr).await?; + let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); tunnel( diff --git a/src/dns/gai.rs b/src/dns/gai.rs index f32f3b0e0..00c981f0a 100644 --- a/src/dns/gai.rs +++ b/src/dns/gai.rs @@ -1,6 +1,6 @@ use futures_util::future::FutureExt; -use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name}; -use hyper::service::Service; +use hyper_util::client::legacy::connect::dns::{GaiResolver as HyperGaiResolver, Name}; +use tower_service::Service; use crate::dns::{Addrs, Resolve, Resolving}; use crate::error::BoxError; diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs index 3686765a0..4c36f30ec 100644 --- a/src/dns/resolve.rs +++ b/src/dns/resolve.rs @@ -1,5 +1,5 @@ -use hyper::client::connect::dns::Name; -use hyper::service::Service; +use hyper_util::client::legacy::connect::dns::Name; +use tower_service::Service; use std::collections::HashMap; use std::future::Future; diff --git a/src/dns/trust_dns.rs b/src/dns/trust_dns.rs index 86ea5a68d..bb5afc137 100644 --- a/src/dns/trust_dns.rs +++ b/src/dns/trust_dns.rs @@ -1,6 +1,6 @@ //! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate -use hyper::client::connect::dns::Name; +use hyper_util::client::legacy::connect::dns::Name; use once_cell::sync::OnceCell; use trust_dns_resolver::{lookup_ip::LookupIpIntoIter, system_conf, TokioAsyncResolver}; diff --git a/src/error.rs b/src/error.rs index 9453fcd92..711b03c96 100644 --- a/src/error.rs +++ b/src/error.rs @@ -121,6 +121,7 @@ impl Error { matches!(self.inner.kind, Kind::Request) } + /* #[cfg(not(target_arch = "wasm32"))] /// Returns true if the error is related to connect pub fn is_connect(&self) -> bool { @@ -138,6 +139,7 @@ impl Error { false } + */ /// Returns true if the error is related to the request or response body pub fn is_body(&self) -> bool { @@ -287,9 +289,8 @@ pub(crate) fn upgrade>(e: E) -> Error { // io::Error helpers -#[allow(unused)] -pub(crate) fn into_io(e: Error) -> io::Error { - e.into_io() +pub(crate) fn into_io(e: BoxError) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) } #[allow(unused)] diff --git a/src/tls.rs b/src/tls.rs index 8d6577394..9cb49fcdc 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -13,9 +13,11 @@ #[cfg(feature = "__rustls")] use rustls::{ - client::HandshakeSignatureValid, client::ServerCertVerified, client::ServerCertVerifier, - DigitallySignedStruct, Error as TLSError, ServerName, + client::danger::HandshakeSignatureValid, client::danger::ServerCertVerified, + client::danger::ServerCertVerifier, DigitallySignedStruct, Error as TLSError, SignatureScheme, }; +#[cfg(feature = "__rustls")] +use rustls_pki_types::{ServerName, UnixTime}; use std::fmt; /// Represents a server X509 certificate. @@ -41,7 +43,6 @@ pub struct Identity { inner: ClientCert, } -#[derive(Clone)] enum ClientCert { #[cfg(feature = "native-tls")] Pkcs12(native_tls_crate::Identity), @@ -49,11 +50,32 @@ enum ClientCert { Pkcs8(native_tls_crate::Identity), #[cfg(feature = "__rustls")] Pem { - key: rustls::PrivateKey, - certs: Vec, + key: rustls_pki_types::PrivateKeyDer<'static>, + certs: Vec>, }, } +impl Clone for ClientCert { + fn clone(&self) -> Self { + match self { + #[cfg(feature = "native-tls")] + Self::Pkcs8(i) => Self::Pkcs8(i.clone()), + #[cfg(feature = "native-tls")] + Self::Pkcs12(i) => Self::Pkcs12(i.clone()), + #[cfg(feature = "__rustls")] + ClientCert::Pem { key, certs } => ClientCert::Pem { + key: key.clone_key(), + certs: certs.clone(), + }, + #[cfg_attr( + any(feature = "native-tls", feature = "__rustls"), + allow(unreachable_patterns) + )] + _ => unreachable!(), + } + } +} + impl Certificate { /// Create a `Certificate` from a binary DER encoded certificate /// @@ -119,7 +141,7 @@ impl Certificate { match self.original { Cert::Der(buf) => root_cert_store - .add(&rustls::Certificate(buf)) + .add(buf.into()) .map_err(crate::error::builder)?, Cert::Pem(buf) => { let mut pem = Cursor::new(buf); @@ -130,7 +152,7 @@ impl Certificate { })?; for c in certs { root_cert_store - .add(&rustls::Certificate(c)) + .add(c.into()) .map_err(crate::error::builder)?; } } @@ -245,8 +267,8 @@ impl Identity { let (key, certs) = { let mut pem = Cursor::new(buf); - let mut sk = Vec::::new(); - let mut certs = Vec::::new(); + let mut sk = Vec::::new(); + let mut certs = Vec::::new(); for item in std::iter::from_fn(|| rustls_pemfile::read_one(&mut pem).transpose()) { match item.map_err(|_| { @@ -254,12 +276,16 @@ impl Identity { "Invalid identity PEM file", ))) })? { - rustls_pemfile::Item::X509Certificate(cert) => { - certs.push(rustls::Certificate(cert)) + rustls_pemfile::Item::X509Certificate(cert) => certs.push(cert.into()), + rustls_pemfile::Item::PKCS8Key(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Pkcs8(key.into())) + } + rustls_pemfile::Item::RSAKey(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Pkcs1(key.into())) + } + rustls_pemfile::Item::ECKey(key) => { + sk.push(rustls_pki_types::PrivateKeyDer::Sec1(key.into())) } - rustls_pemfile::Item::PKCS8Key(key) => sk.push(rustls::PrivateKey(key)), - rustls_pemfile::Item::RSAKey(key) => sk.push(rustls::PrivateKey(key)), - rustls_pemfile::Item::ECKey(key) => sk.push(rustls::PrivateKey(key)), _ => { return Err(crate::error::builder(TLSError::General(String::from( "No valid certificate was found", @@ -302,7 +328,8 @@ impl Identity { self, config_builder: rustls::ConfigBuilder< rustls::ClientConfig, - rustls::client::WantsTransparencyPolicyOrClientCert, + // Not sure here + rustls::client::WantsClientCert, >, ) -> crate::Result { match self.inner { @@ -428,18 +455,18 @@ impl Default for TlsBackend { } #[cfg(feature = "__rustls")] +#[derive(Debug)] pub(crate) struct NoVerifier; #[cfg(feature = "__rustls")] impl ServerCertVerifier for NoVerifier { fn verify_server_cert( &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], + _end_entity: &rustls_pki_types::CertificateDer, + _intermediates: &[rustls_pki_types::CertificateDer], _server_name: &ServerName, - _scts: &mut dyn Iterator, _ocsp_response: &[u8], - _now: std::time::SystemTime, + _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } @@ -447,7 +474,7 @@ impl ServerCertVerifier for NoVerifier { fn verify_tls12_signature( &self, _message: &[u8], - _cert: &rustls::Certificate, + _cert: &rustls_pki_types::CertificateDer, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) @@ -456,11 +483,29 @@ impl ServerCertVerifier for NoVerifier { fn verify_tls13_signature( &self, _message: &[u8], - _cert: &rustls::Certificate, + _cert: &rustls_pki_types::CertificateDer, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::RSA_PKCS1_SHA1, + SignatureScheme::ECDSA_SHA1_Legacy, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::ED25519, + SignatureScheme::ED448, + ] + } } /// Hyper extension carrying extra TLS layer information. diff --git a/tests/support/server.rs b/tests/support/server.rs index 5193a5fbe..90725ef06 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -6,7 +6,6 @@ use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; -use hyper::server::conn::AddrIncoming; use tokio::runtime; use tokio::sync::oneshot; @@ -38,19 +37,19 @@ impl Drop for Server { pub fn http(func: F) -> Server where - F: Fn(http::Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, + F: Fn(http::Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, { http_with_config(func, identity) } -pub fn http_with_config(func: F1, apply_config: F2) -> Server +type Builder = hyper_util::server::conn::auto::Builder; + +pub fn http_with_config(func: F1, apply_config: F2) -> Server where - F1: Fn(http::Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, - F2: FnOnce(hyper::server::Builder) -> hyper::server::Builder - + Send - + 'static, + F1: Fn(http::Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, + F2: FnOnce(&mut Builder) -> Bu + Send + 'static, { // Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { @@ -59,17 +58,24 @@ where .build() .expect("new rt"); let srv = rt.block_on(async move { - let builder = hyper::Server::bind(&([127, 0, 0, 1], 0).into()); + let listener = tokio::net::TcpListener::bind(&([127, 0, 0, 1], 0).into()) + .await + .unwrap(); + let mut builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + apply_config(&mut builder); - apply_config(builder).serve(hyper::service::make_service_fn(move |_| { + while let Ok((io, _)) = listener.accept().await { let func = func.clone(); - async move { - Ok::<_, Infallible>(hyper::service::service_fn(move |req| { - let fut = func(req); - async move { Ok::<_, Infallible>(fut.await) } - })) - } - })) + let svc = hyper::service::service_fn(move |req| { + let fut = func(req); + async move { Ok::<_, Infallible>(fut.await) } + }); + let fut = builder.serve_connection(hyper_util::rt::TokioIo::new(io), svc); + tokio::spawn(async move { + let _ = fut.await; + }); + } }); let addr = srv.local_addr();