From 335bfd2b8165e1c6cb109b69fb4a73093ec0cf0b Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Tue, 28 Nov 2023 17:09:04 +0100 Subject: [PATCH 1/2] curl,ureq: Use `thiserror`'s `#[from]` feature via the `Try` operator --- src/curl.rs | 81 ++++++++++++++++++++++------------------------------- src/ureq.rs | 16 +++++------ 2 files changed, 41 insertions(+), 56 deletions(-) diff --git a/src/curl.rs b/src/curl.rs index c9cc7e7..2386535 100644 --- a/src/curl.rs +++ b/src/curl.rs @@ -14,10 +14,10 @@ use super::{HttpRequest, HttpResponse}; pub enum Error { /// Error returned by curl crate. #[error("curl request failed")] - Curl(#[source] curl::Error), + Curl(#[from] curl::Error), /// Non-curl HTTP error. #[error("HTTP error")] - Http(#[source] http::Error), + Http(#[from] http::Error), /// Other error. #[error("Other error: {}", _0)] Other(String), @@ -28,34 +28,27 @@ pub enum Error { /// pub fn http_client(request: HttpRequest) -> Result { let mut easy = Easy::new(); - easy.url(&request.url.to_string()[..]) - .map_err(Error::Curl)?; + easy.url(&request.url.to_string()[..])?; let mut headers = curl::easy::List::new(); - request - .headers - .iter() - .map(|(name, value)| { - headers - .append(&format!( - "{}: {}", - name, - value.to_str().map_err(|_| Error::Other(format!( - "invalid {} header value {:?}", - name, - value.as_bytes() - )))? - )) - .map_err(Error::Curl) - }) - .collect::>()?; + for (name, value) in &request.headers { + headers.append(&format!( + "{}: {}", + name, + // TODO: Unnecessary fallibility, curl uses a CString under the hood + value.to_str().map_err(|_| Error::Other(format!( + "invalid {} header value {:?}", + name, + value.as_bytes() + )))? + ))? + } - easy.http_headers(headers).map_err(Error::Curl)?; + easy.http_headers(headers)?; if let Method::POST = request.method { - easy.post(true).map_err(Error::Curl)?; - easy.post_field_size(request.body.len() as u64) - .map_err(Error::Curl)?; + easy.post(true)?; + easy.post_field_size(request.body.len() as u64)?; } else { assert_eq!(request.method, Method::GET); } @@ -65,37 +58,29 @@ pub fn http_client(request: HttpRequest) -> Result { { let mut transfer = easy.transfer(); - transfer - .read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0))) - .map_err(Error::Curl)?; + transfer.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))?; - transfer - .write_function(|new_data| { - data.extend_from_slice(new_data); - Ok(new_data.len()) - }) - .map_err(Error::Curl)?; + transfer.write_function(|new_data| { + data.extend_from_slice(new_data); + Ok(new_data.len()) + })?; - transfer.perform().map_err(Error::Curl)?; + transfer.perform()?; } - let status_code = easy.response_code().map_err(Error::Curl)? as u16; + let status_code = easy.response_code()? as u16; Ok(HttpResponse { - status_code: StatusCode::from_u16(status_code).map_err(|err| Error::Http(err.into()))?, + status_code: StatusCode::from_u16(status_code).map_err(http::Error::from)?, headers: easy - .content_type() - .map_err(Error::Curl)? - .map(|content_type| { - Ok(vec![( - CONTENT_TYPE, - HeaderValue::from_str(content_type).map_err(|err| Error::Http(err.into()))?, - )] - .into_iter() - .collect::()) - }) + .content_type()? + .map(|content_type| HeaderValue::from_str(content_type).map_err(http::Error::from)) .transpose()? - .unwrap_or_else(HeaderMap::new), + .map_or_else(HeaderMap::new, |content_type| { + vec![(CONTENT_TYPE, content_type)] + .into_iter() + .collect::() + }), body: data, }) } diff --git a/src/ureq.rs b/src/ureq.rs index 82ee6c3..a4eaf79 100644 --- a/src/ureq.rs +++ b/src/ureq.rs @@ -30,16 +30,18 @@ pub enum Error { /// Synchronous HTTP client for ureq. /// pub fn http_client(request: HttpRequest) -> Result { - let mut req = if let Method::POST = request.method { - ureq::post(&request.url.to_string()) + let mut req = if request.method == Method::POST { + ureq::post(request.url.as_ref()) } else { - ureq::get(&request.url.to_string()) + ureq::get(request.url.as_ref()) }; for (name, value) in request.headers { if let Some(name) = name { req = req.set( - &name.to_string(), + name.as_ref(), + // TODO: In newer `ureq` it should be easier to convert arbitrary byte sequences + // without unnecessary UTF-8 fallibility here. value.to_str().map_err(|_| { Error::Other(format!( "invalid {} header value {:?}", @@ -59,12 +61,10 @@ pub fn http_client(request: HttpRequest) -> Result { .map_err(Box::new)?; Ok(HttpResponse { - status_code: StatusCode::from_u16(response.status()) - .map_err(|err| Error::Http(err.into()))?, + status_code: StatusCode::from_u16(response.status()).map_err(http::Error::from)?, headers: vec![( CONTENT_TYPE, - HeaderValue::from_str(response.content_type()) - .map_err(|err| Error::Http(err.into()))?, + HeaderValue::from_str(response.content_type()).map_err(http::Error::from)?, )] .into_iter() .collect::(), From 85c70cb9256d221cb866c06b0b0a54f3dd4036e2 Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Tue, 28 Nov 2023 17:15:28 +0100 Subject: [PATCH 2/2] Replace `map_err()` conversions with a `From` call via the `Try` operator The `?` or `Try` operator in the standard library calls `.into()` on the `Error` type before bubbling it up, allowing for natural conversions to other error types. `thiserror` supports marking such errors as `#[from]` (implying `#[source]`) to generate the necessary `From<>` implementation to facilitate automatic conversion from any specified `Error` type to the corresponding enum variant in our `thiserror` enums. --- src/lib.rs | 66 +++++++++++++++++------------------------------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index fce91c6..10a17be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1309,9 +1309,7 @@ where F: FnOnce(HttpRequest) -> Result, RE: Error + 'static, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// @@ -1327,9 +1325,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } } @@ -1412,9 +1408,7 @@ where F: FnOnce(HttpRequest) -> Result, RE: Error + 'static, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// /// Asynchronously sends the request to the authorization server and awaits a response. @@ -1429,9 +1423,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } @@ -1536,9 +1528,7 @@ where F: FnOnce(HttpRequest) -> Result, RE: Error + 'static, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// @@ -1554,9 +1544,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } @@ -1660,9 +1648,7 @@ where F: FnOnce(HttpRequest) -> Result, RE: Error + 'static, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// @@ -1678,9 +1664,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } @@ -1810,9 +1794,7 @@ where F: FnOnce(HttpRequest) -> Result, RE: Error + 'static, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// @@ -1828,9 +1810,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } } @@ -1923,9 +1903,7 @@ where // From https://tools.ietf.org/html/rfc7009#section-2.2: // "The content of the response body is ignored by the client as all // necessary information is conveyed in the response code." - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response_status_only) + endpoint_response_status_only(http_client(self.prepare_request()?)?) } /// @@ -1941,9 +1919,7 @@ where RE: Error + 'static, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response_status_only(http_response) } } @@ -2224,9 +2200,7 @@ where RE: Error + 'static, EF: ExtraDeviceAuthorizationFields, { - http_client(self.prepare_request()?) - .map_err(RequestTokenError::Request) - .and_then(endpoint_response) + endpoint_response(http_client(self.prepare_request()?)?) } /// @@ -2243,9 +2217,7 @@ where EF: ExtraDeviceAuthorizationFields, { let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; + let http_response = http_client(http_request).await?; endpoint_response(http_response) } } @@ -2504,8 +2476,12 @@ where // use that, otherwise use the value given by the device authorization // response. let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in()); - let chrono_timeout = chrono::Duration::from_std(timeout_dur) - .map_err(|_| RequestTokenError::Other("Failed to convert duration".to_string()))?; + let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| { + RequestTokenError::Other(format!( + "Failed to convert `{:?}` to `chrono::Duration`: {}", + timeout_dur, e + )) + })?; // Calculate the DateTime at which the request times out. let timeout_dt = (*self.time_fn)() @@ -3179,7 +3155,7 @@ where /// connectivity failed). /// #[error("Request failed")] - Request(#[source] RE), + Request(#[from] RE), /// /// Failed to parse server response. Parse errors may occur while parsing either successful /// or error responses.