Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: simplify Service impls #1861

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/json-rpc/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ pub type BorrowedRpcResult<'a, E> = RpcResult<&'a RawValue, E, &'a RawValue>;
/// Transform a transport response into an [`RpcResult`], discarding the [`Id`].
///
/// [`Id`]: crate::Id
pub fn transform_response<T, E, ErrResp>(
response: Response<T, ErrResp>,
) -> Result<T, RpcError<E, ErrResp>>
pub fn transform_response<T, E, ErrResp>(response: Response<T, ErrResp>) -> RpcResult<T, E, ErrResp>
where
ErrResp: RpcReturn,
{
Expand Down
112 changes: 51 additions & 61 deletions crates/transport-http/src/hyper_transport.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_transport::{
utils::guess_local_url, TransportConnect, TransportError, TransportErrorKind, TransportFut,
TransportResult,
};
use http_body_util::{BodyExt, Full};
use hyper::{
Expand Down Expand Up @@ -79,63 +80,46 @@ where
ResBody::Error: std::error::Error + Send + Sync + 'static,
ResBody::Data: Send,
{
/// Make a request to the server using the given service.
fn request_hyper(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
let span = debug_span!("HyperClient", url = %this.url);
Box::pin(
async move {
debug!(count = req.len(), "sending request packet to server");
let ser = req.serialize().map_err(TransportError::ser_err)?;
// convert the Box<RawValue> into a hyper request<B>
let body = ser.get().as_bytes().to_owned().into();

let req = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(this.url.as_str())
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.body(body)
.expect("request parts are invalid");

let mut service = this.client.service.clone();
let resp = service.call(req).await.map_err(TransportErrorKind::custom)?;

let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp
.into_body()
.collect()
.await
.map_err(TransportErrorKind::custom)?
.to_bytes();

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != hyper::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body).map_err(|err| {
TransportError::deser_err(err, String::from_utf8_lossy(body.as_ref()))
})
}
.instrument(span),
)
async fn do_hyper(self, req: RequestPacket) -> TransportResult<ResponsePacket> {
debug!(count = req.len(), "sending request packet to server");
let ser = req.serialize().map_err(TransportError::ser_err)?;
// convert the Box<RawValue> into a hyper request<B>
let body = ser.get().as_bytes().to_owned().into();

let req = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(self.url.as_str())
.header(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"))
.body(body)
.expect("request parts are invalid");

let mut service = self.client.service;
let resp = service.call(req).await.map_err(TransportErrorKind::custom)?;

let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.into_body().collect().await.map_err(TransportErrorKind::custom)?.to_bytes();

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != hyper::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(body.as_ref())))
}
}

Expand Down Expand Up @@ -168,12 +152,14 @@ where
type Error = TransportError;
type Future = TransportFut<'static>;

fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
task::Poll::Ready(Ok(()))
#[inline]
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
(&*self).poll_ready(cx)
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_hyper(req)
(&*self).call(req)
}
}

Expand All @@ -188,11 +174,15 @@ where
type Error = TransportError;
type Future = TransportFut<'static>;

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// `hyper` always returns `Ok(())`.
task::Poll::Ready(Ok(()))
}

fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_hyper(req)
let this = self.clone();
let span = debug_span!("HyperTransport", url = %this.url);
Box::pin(this.do_hyper(req).instrument(span))
}
}
87 changes: 40 additions & 47 deletions crates/transport-http/src/reqwest_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{Http, HttpConnect};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_transport::{
utils::guess_local_url, TransportConnect, TransportError, TransportErrorKind, TransportFut,
TransportResult,
};
use std::task;
use tower::Service;
Expand Down Expand Up @@ -37,46 +38,38 @@ impl Http<Client> {
Self { client: Default::default(), url }
}

/// Make a request.
fn request_reqwest(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
let span: tracing::Span = debug_span!("ReqwestTransport", url = %self.url);
Box::pin(
async move {
let resp = this
.client
.post(this.url)
.json(&req)
.send()
.await
.map_err(TransportErrorKind::custom)?;
let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.bytes().await.map_err(TransportErrorKind::custom)?;

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != reqwest::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(&body)))
}
.instrument(span),
)
async fn do_reqwest(self, req: RequestPacket) -> TransportResult<ResponsePacket> {
let resp = self
.client
.post(self.url)
.json(&req)
.send()
.await
.map_err(TransportErrorKind::custom)?;
let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.bytes().await.map_err(TransportErrorKind::custom)?;

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != reqwest::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(&body)))
}
}

Expand All @@ -86,14 +79,13 @@ impl Service<RequestPacket> for Http<reqwest::Client> {
type Future = TransportFut<'static>;

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// reqwest always returns ok
task::Poll::Ready(Ok(()))
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
(&*self).poll_ready(cx)
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_reqwest(req)
(&*self).call(req)
}
}

Expand All @@ -104,12 +96,13 @@ impl Service<RequestPacket> for &Http<reqwest::Client> {

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// reqwest always returns ok
// `reqwest` always returns `Ok(())`.
task::Poll::Ready(Ok(()))
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_reqwest(req)
let this = self.clone();
let span = debug_span!("ReqwestTransport", url = %this.url);
Box::pin(this.do_reqwest(req).instrument(span))
}
}
Loading