Skip to content

Commit

Permalink
Replace map_err() conversions with a From call via the Try oper…
Browse files Browse the repository at this point in the history
…ator (#239)

* curl,ureq: Use `thiserror`'s `#[from]` feature via the `Try` operator

* Replace `map_err()` conversions with a `From` call via the `Try` operator

The `?` or `Try` operator in the standard library calls `.into()` on the
`Error` type before bubbling it up, allowing for natural conversions to
other error types.  `thiserror` supports marking such errors as `#[from]`
(implying `#[source]`) to generate the necessary `From<>` implementation
to facilitate automatic conversion from any specified `Error` type to the
corresponding enum variant in our `thiserror` enums.
  • Loading branch information
MarijnS95 authored Nov 28, 2023
1 parent 8c31046 commit e24e255
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 101 deletions.
81 changes: 33 additions & 48 deletions src/curl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use super::{HttpRequest, HttpResponse};
pub enum Error {
/// Error returned by curl crate.
#[error("curl request failed")]
Curl(#[source] curl::Error),
Curl(#[from] curl::Error),
/// Non-curl HTTP error.
#[error("HTTP error")]
Http(#[source] http::Error),
Http(#[from] http::Error),
/// Other error.
#[error("Other error: {}", _0)]
Other(String),
Expand All @@ -28,34 +28,27 @@ pub enum Error {
///
pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
let mut easy = Easy::new();
easy.url(&request.url.to_string()[..])
.map_err(Error::Curl)?;
easy.url(&request.url.to_string()[..])?;

let mut headers = curl::easy::List::new();
request
.headers
.iter()
.map(|(name, value)| {
headers
.append(&format!(
"{}: {}",
name,
value.to_str().map_err(|_| Error::Other(format!(
"invalid {} header value {:?}",
name,
value.as_bytes()
)))?
))
.map_err(Error::Curl)
})
.collect::<Result<_, _>>()?;
for (name, value) in &request.headers {
headers.append(&format!(
"{}: {}",
name,
// TODO: Unnecessary fallibility, curl uses a CString under the hood
value.to_str().map_err(|_| Error::Other(format!(
"invalid {} header value {:?}",
name,
value.as_bytes()
)))?
))?
}

easy.http_headers(headers).map_err(Error::Curl)?;
easy.http_headers(headers)?;

if let Method::POST = request.method {
easy.post(true).map_err(Error::Curl)?;
easy.post_field_size(request.body.len() as u64)
.map_err(Error::Curl)?;
easy.post(true)?;
easy.post_field_size(request.body.len() as u64)?;
} else {
assert_eq!(request.method, Method::GET);
}
Expand All @@ -65,37 +58,29 @@ pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
{
let mut transfer = easy.transfer();

transfer
.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))
.map_err(Error::Curl)?;
transfer.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))?;

transfer
.write_function(|new_data| {
data.extend_from_slice(new_data);
Ok(new_data.len())
})
.map_err(Error::Curl)?;
transfer.write_function(|new_data| {
data.extend_from_slice(new_data);
Ok(new_data.len())
})?;

transfer.perform().map_err(Error::Curl)?;
transfer.perform()?;
}

let status_code = easy.response_code().map_err(Error::Curl)? as u16;
let status_code = easy.response_code()? as u16;

Ok(HttpResponse {
status_code: StatusCode::from_u16(status_code).map_err(|err| Error::Http(err.into()))?,
status_code: StatusCode::from_u16(status_code).map_err(http::Error::from)?,
headers: easy
.content_type()
.map_err(Error::Curl)?
.map(|content_type| {
Ok(vec![(
CONTENT_TYPE,
HeaderValue::from_str(content_type).map_err(|err| Error::Http(err.into()))?,
)]
.into_iter()
.collect::<HeaderMap>())
})
.content_type()?
.map(|content_type| HeaderValue::from_str(content_type).map_err(http::Error::from))
.transpose()?
.unwrap_or_else(HeaderMap::new),
.map_or_else(HeaderMap::new, |content_type| {
vec![(CONTENT_TYPE, content_type)]
.into_iter()
.collect::<HeaderMap>()
}),
body: data,
})
}
66 changes: 21 additions & 45 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1327,9 +1325,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -1412,9 +1408,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}
///
/// Asynchronously sends the request to the authorization server and awaits a response.
Expand All @@ -1429,9 +1423,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1536,9 +1528,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1554,9 +1544,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1660,9 +1648,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1678,9 +1664,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1810,9 +1794,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1828,9 +1810,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -1923,9 +1903,7 @@ where
// From https://tools.ietf.org/html/rfc7009#section-2.2:
// "The content of the response body is ignored by the client as all
// necessary information is conveyed in the response code."
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response_status_only)
endpoint_response_status_only(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1941,9 +1919,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response_status_only(http_response)
}
}
Expand Down Expand Up @@ -2224,9 +2200,7 @@ where
RE: Error + 'static,
EF: ExtraDeviceAuthorizationFields,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -2243,9 +2217,7 @@ where
EF: ExtraDeviceAuthorizationFields,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -2504,8 +2476,12 @@ where
// use that, otherwise use the value given by the device authorization
// response.
let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in());
let chrono_timeout = chrono::Duration::from_std(timeout_dur)
.map_err(|_| RequestTokenError::Other("Failed to convert duration".to_string()))?;
let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| {
RequestTokenError::Other(format!(
"Failed to convert `{:?}` to `chrono::Duration`: {}",
timeout_dur, e
))
})?;

// Calculate the DateTime at which the request times out.
let timeout_dt = (*self.time_fn)()
Expand Down Expand Up @@ -3179,7 +3155,7 @@ where
/// connectivity failed).
///
#[error("Request failed")]
Request(#[source] RE),
Request(#[from] RE),
///
/// Failed to parse server response. Parse errors may occur while parsing either successful
/// or error responses.
Expand Down
16 changes: 8 additions & 8 deletions src/ureq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@ pub enum Error {
/// Synchronous HTTP client for ureq.
///
pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
let mut req = if let Method::POST = request.method {
ureq::post(&request.url.to_string())
let mut req = if request.method == Method::POST {
ureq::post(request.url.as_ref())
} else {
ureq::get(&request.url.to_string())
ureq::get(request.url.as_ref())
};

for (name, value) in request.headers {
if let Some(name) = name {
req = req.set(
&name.to_string(),
name.as_ref(),
// TODO: In newer `ureq` it should be easier to convert arbitrary byte sequences
// without unnecessary UTF-8 fallibility here.
value.to_str().map_err(|_| {
Error::Other(format!(
"invalid {} header value {:?}",
Expand All @@ -59,12 +61,10 @@ pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
.map_err(Box::new)?;

Ok(HttpResponse {
status_code: StatusCode::from_u16(response.status())
.map_err(|err| Error::Http(err.into()))?,
status_code: StatusCode::from_u16(response.status()).map_err(http::Error::from)?,
headers: vec![(
CONTENT_TYPE,
HeaderValue::from_str(response.content_type())
.map_err(|err| Error::Http(err.into()))?,
HeaderValue::from_str(response.content_type()).map_err(http::Error::from)?,
)]
.into_iter()
.collect::<HeaderMap>(),
Expand Down

0 comments on commit e24e255

Please sign in to comment.