Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for retrying 429 (too many requests) responses #45

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Versioning].

## [Unreleased] <!-- #release:date -->

* Automatically retry HTTP requests that return status code 429. (too many requests)

## [0.11.0] - 2024-03-29

* Add `portal_url` to `Customer`.
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ once_cell = "1.16.0"
ordered-float = { version = "3.4.0", features = ["serde"] }
rand = "0.8.5"
reqwest = { version = "0.11.13", features = ["json"] }
reqwest-middleware = "0.2.2"
reqwest-retry = "0.2.2"
serde = { version = "1.0.151", features = ["derive"] }
serde-enum-str = "0.3.2"
serde_json = "1.0.91"
Expand All @@ -36,6 +38,7 @@ tokio = { version = "1.23.0", features = ["macros"] }
tokio-stream = "0.1.11"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
wiremock = "0.5.19"

[package.metadata.docs.rs]
all-features = true
Expand Down
7 changes: 5 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

use async_stream::try_stream;
use futures_core::Stream;
use reqwest::{Method, RequestBuilder, Url};
use reqwest::{Method, Url};
use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
use serde::de::DeserializeOwned;
use serde::Deserialize;

Expand All @@ -39,7 +40,7 @@ pub mod taxes;
/// [`Arc`]: std::sync::Arc
#[derive(Debug)]
pub struct Client {
pub(crate) inner: reqwest::Client,
pub(crate) inner: ClientWithMiddleware,
pub(crate) api_key: String,
pub(crate) endpoint: Url,
}
Expand All @@ -65,6 +66,8 @@ impl Client {
url.path_segments_mut()
.expect("builder validated URL can be a base")
.extend(path);
// All request methods and paths are included to support retries for
// 429 status code.
self.inner.request(method, url).bearer_auth(&self.api_key)
}

Expand Down
3 changes: 2 additions & 1 deletion src/client/customers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use codes_iso_3166::part_1::CountryCode;
use codes_iso_4217::CurrencyCode;
use futures_core::Stream;
use futures_util::stream::TryStreamExt;
use reqwest::{Method, RequestBuilder};
use reqwest::Method;
use reqwest_middleware::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_enum_str::{Deserialize_enum_str, Serialize_enum_str};
use time::format_description::well_known::Rfc3339;
Expand Down
33 changes: 30 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ use std::time::Duration;

use once_cell::sync::Lazy;
use reqwest::Url;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;

use crate::Client;
use crate::client::Client;

pub static DEFAULT_ENDPOINT: Lazy<Url> = Lazy::new(|| {
"https://api.billwithorb.com/v1"
Expand All @@ -35,27 +37,52 @@ pub struct ClientConfig {
/// A builder for a [`Client`].
pub struct ClientBuilder {
endpoint: Url,
retry_policy: Option<ExponentialBackoff>,
}

impl Default for ClientBuilder {
fn default() -> ClientBuilder {
ClientBuilder {
endpoint: DEFAULT_ENDPOINT.clone(),
retry_policy: Some(
ExponentialBackoff::builder()
.retry_bounds(Duration::from_secs(1), Duration::from_secs(5))
.build_with_max_retries(5),
),
}
}
}

impl ClientBuilder {
/// Sets the policy for retrying failed API calls.
///
/// Note that the created [`Client`] will retry all API calls that return a 429 status code.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're using the default policy I don't believe this is entirely accurate, it'll also retry server timeouts

pub fn with_retry_policy(mut self, policy: ExponentialBackoff) -> Self {
self.retry_policy = Some(policy);
self
}

/// Sets the endpoint.
pub fn with_endpoint(mut self, endpoint: Url) -> Self {
self.endpoint = endpoint;
self
}

/// Creates a [`Client`] that incorporates the optional parameters
/// configured on the builder and the specified required parameters.
pub fn build(self, config: ClientConfig) -> Client {
let inner = reqwest::ClientBuilder::new()
let client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.timeout(Duration::from_secs(60))
.build()
.unwrap();
Client {
inner,
inner: match self.retry_policy {
Some(policy) => reqwest_middleware::ClientBuilder::new(client.clone())
.with(RetryTransientMiddleware::new_with_policy(policy))
.build(),
None => reqwest_middleware::ClientBuilder::new(client.clone()).build(),
},
api_key: config.api_key,
endpoint: self.endpoint,
}
Expand Down
10 changes: 8 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use reqwest::StatusCode;
#[derive(Debug)]
pub enum Error {
/// An error in the underlying transport.
Transport(reqwest::Error),
Transport(reqwest_middleware::Error),
/// An error returned by the API.
Api(ApiError),
/// The API returned an unexpected response.
Expand Down Expand Up @@ -79,9 +79,15 @@ impl fmt::Display for ApiError {

impl std::error::Error for ApiError {}

impl From<reqwest_middleware::Error> for Error {
fn from(e: reqwest_middleware::Error) -> Error {
Error::Transport(e)
}
}

impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Error {
Error::Transport(e)
Error::Transport(reqwest_middleware::Error::from(e))
}
}

Expand Down
44 changes: 44 additions & 0 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ use futures::stream::TryStreamExt;
use once_cell::sync::Lazy;
use rand::Rng;
use reqwest::StatusCode;
use reqwest_retry::policies::ExponentialBackoff;
use test_log::test;
use tokio::time::{self, Duration};
use tracing::info;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};

use orb_billing::{
AddIncrementCreditLedgerEntryRequestParams, AddVoidCreditLedgerEntryRequestParams, Address,
Expand Down Expand Up @@ -769,3 +771,45 @@ async fn test_errors() {
let res = client.get_customer_by_external_id("$NOEXIST$").await;
assert_error_with_status_code(res, StatusCode::NOT_FOUND);
}

// Tests that 429 responses are retried automatically by the client for API calls
#[test(tokio::test)]
async fn test_retry_429() {
// Start a mock orb API server and a client configured to target that
// server. The retry policy disables backoff to speed up the tests.
const MAX_RETRIES: u32 = 3;
let server = MockServer::start().await;
let client = Client::builder()
.with_endpoint(server.uri().parse().unwrap())
.with_retry_policy(
ExponentialBackoff::builder()
.retry_bounds(Duration::from_millis(1), Duration::from_millis(1))
.build_with_max_retries(MAX_RETRIES),
)
.build(ClientConfig { api_key: "".into() });

// register a mock for the /customers endpoint that returns a 429 response
// code. Ensure the client repeatedly retries the API call until giving
// up after `MAX_RETRIES` attempts and returning the error.
let mock = Mock::given(matchers::method("POST"))
.and(matchers::path("/customers"))
.respond_with(ResponseTemplate::new(429))
.expect(u64::from(MAX_RETRIES) + 1)
.named("put customers");
server.register(mock).await;
let customer_idx = 0;
let res = client
.create_customer(&CreateCustomerRequest {
name: &format!("{TEST_PREFIX}-{customer_idx}"),
email: &format!("orb-testing-{customer_idx}@materialize.com"),
external_id: None,
payment_provider: Some(CustomerPaymentProviderRequest {
kind: PaymentProvider::Stripe,
id: &format!("cus_fake_{customer_idx}"),
}),
..Default::default()
})
.await;

assert!(res.is_err());
}
Loading