diff --git a/CHANGELOG.md b/CHANGELOG.md index dfa1f24..c39b303 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ Versioning]. ## [Unreleased] +* Automatically retry HTTP requests that return status code 429. (too many requests) + ## [0.11.0] - 2024-03-29 * Add `portal_url` to `Customer`. diff --git a/Cargo.toml b/Cargo.toml index 53edc3f..b644eed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ once_cell = "1.16.0" ordered-float = { version = "3.4.0", features = ["serde"] } rand = "0.8.5" reqwest = { version = "0.11.13", features = ["json"] } +reqwest-middleware = "0.2.2" +reqwest-retry = "0.2.2" serde = { version = "1.0.151", features = ["derive"] } serde-enum-str = "0.3.2" serde_json = "1.0.91" @@ -36,6 +38,7 @@ tokio = { version = "1.23.0", features = ["macros"] } tokio-stream = "0.1.11" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +wiremock = "0.5.19" [package.metadata.docs.rs] all-features = true diff --git a/src/client.rs b/src/client.rs index 8989f05..1b6c519 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,7 +15,8 @@ use async_stream::try_stream; use futures_core::Stream; -use reqwest::{Method, RequestBuilder, Url}; +use reqwest::{Method, Url}; +use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use serde::de::DeserializeOwned; use serde::Deserialize; @@ -39,7 +40,7 @@ pub mod taxes; /// [`Arc`]: std::sync::Arc #[derive(Debug)] pub struct Client { - pub(crate) inner: reqwest::Client, + pub(crate) inner: ClientWithMiddleware, pub(crate) api_key: String, pub(crate) endpoint: Url, } @@ -65,6 +66,8 @@ impl Client { url.path_segments_mut() .expect("builder validated URL can be a base") .extend(path); + // All request methods and paths are included to support retries for + // 429 status code. self.inner.request(method, url).bearer_auth(&self.api_key) } diff --git a/src/client/customers.rs b/src/client/customers.rs index f41bf0e..c6b4c17 100644 --- a/src/client/customers.rs +++ b/src/client/customers.rs @@ -17,7 +17,8 @@ use codes_iso_3166::part_1::CountryCode; use codes_iso_4217::CurrencyCode; use futures_core::Stream; use futures_util::stream::TryStreamExt; -use reqwest::{Method, RequestBuilder}; +use reqwest::Method; +use reqwest_middleware::RequestBuilder; use serde::{Deserialize, Serialize}; use serde_enum_str::{Deserialize_enum_str, Serialize_enum_str}; use time::format_description::well_known::Rfc3339; diff --git a/src/config.rs b/src/config.rs index f6bd58b..f239385 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,8 +17,10 @@ use std::time::Duration; use once_cell::sync::Lazy; use reqwest::Url; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; -use crate::Client; +use crate::client::Client; pub static DEFAULT_ENDPOINT: Lazy = Lazy::new(|| { "https://api.billwithorb.com/v1" @@ -35,27 +37,52 @@ pub struct ClientConfig { /// A builder for a [`Client`]. pub struct ClientBuilder { endpoint: Url, + retry_policy: Option, } impl Default for ClientBuilder { fn default() -> ClientBuilder { ClientBuilder { endpoint: DEFAULT_ENDPOINT.clone(), + retry_policy: Some( + ExponentialBackoff::builder() + .retry_bounds(Duration::from_secs(1), Duration::from_secs(5)) + .build_with_max_retries(5), + ), } } } impl ClientBuilder { + /// Sets the policy for retrying failed API calls. + /// + /// Note that the created [`Client`] will retry all API calls that return a 429 status code. + pub fn with_retry_policy(mut self, policy: ExponentialBackoff) -> Self { + self.retry_policy = Some(policy); + self + } + + /// Sets the endpoint. + pub fn with_endpoint(mut self, endpoint: Url) -> Self { + self.endpoint = endpoint; + self + } + /// Creates a [`Client`] that incorporates the optional parameters /// configured on the builder and the specified required parameters. pub fn build(self, config: ClientConfig) -> Client { - let inner = reqwest::ClientBuilder::new() + let client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .timeout(Duration::from_secs(60)) .build() .unwrap(); Client { - inner, + inner: match self.retry_policy { + Some(policy) => reqwest_middleware::ClientBuilder::new(client.clone()) + .with(RetryTransientMiddleware::new_with_policy(policy)) + .build(), + None => reqwest_middleware::ClientBuilder::new(client.clone()).build(), + }, api_key: config.api_key, endpoint: self.endpoint, } diff --git a/src/error.rs b/src/error.rs index e76db39..39dd1cd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,7 @@ use reqwest::StatusCode; #[derive(Debug)] pub enum Error { /// An error in the underlying transport. - Transport(reqwest::Error), + Transport(reqwest_middleware::Error), /// An error returned by the API. Api(ApiError), /// The API returned an unexpected response. @@ -79,9 +79,15 @@ impl fmt::Display for ApiError { impl std::error::Error for ApiError {} +impl From for Error { + fn from(e: reqwest_middleware::Error) -> Error { + Error::Transport(e) + } +} + impl From for Error { fn from(e: reqwest::Error) -> Error { - Error::Transport(e) + Error::Transport(reqwest_middleware::Error::from(e)) } } diff --git a/tests/api.rs b/tests/api.rs index 4482cee..6b745ae 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -36,9 +36,11 @@ use futures::stream::TryStreamExt; use once_cell::sync::Lazy; use rand::Rng; use reqwest::StatusCode; +use reqwest_retry::policies::ExponentialBackoff; use test_log::test; use tokio::time::{self, Duration}; use tracing::info; +use wiremock::{matchers, Mock, MockServer, ResponseTemplate}; use orb_billing::{ AddIncrementCreditLedgerEntryRequestParams, AddVoidCreditLedgerEntryRequestParams, Address, @@ -769,3 +771,45 @@ async fn test_errors() { let res = client.get_customer_by_external_id("$NOEXIST$").await; assert_error_with_status_code(res, StatusCode::NOT_FOUND); } + +// Tests that 429 responses are retried automatically by the client for API calls +#[test(tokio::test)] +async fn test_retry_429() { + // Start a mock orb API server and a client configured to target that + // server. The retry policy disables backoff to speed up the tests. + const MAX_RETRIES: u32 = 3; + let server = MockServer::start().await; + let client = Client::builder() + .with_endpoint(server.uri().parse().unwrap()) + .with_retry_policy( + ExponentialBackoff::builder() + .retry_bounds(Duration::from_millis(1), Duration::from_millis(1)) + .build_with_max_retries(MAX_RETRIES), + ) + .build(ClientConfig { api_key: "".into() }); + + // register a mock for the /customers endpoint that returns a 429 response + // code. Ensure the client repeatedly retries the API call until giving + // up after `MAX_RETRIES` attempts and returning the error. + let mock = Mock::given(matchers::method("POST")) + .and(matchers::path("/customers")) + .respond_with(ResponseTemplate::new(429)) + .expect(u64::from(MAX_RETRIES) + 1) + .named("put customers"); + server.register(mock).await; + let customer_idx = 0; + let res = client + .create_customer(&CreateCustomerRequest { + name: &format!("{TEST_PREFIX}-{customer_idx}"), + email: &format!("orb-testing-{customer_idx}@materialize.com"), + external_id: None, + payment_provider: Some(CustomerPaymentProviderRequest { + kind: PaymentProvider::Stripe, + id: &format!("cus_fake_{customer_idx}"), + }), + ..Default::default() + }) + .await; + + assert!(res.is_err()); +}