Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use reqwest_retry to automatically retry failed requests #13

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ jobs:
include:
- build: macos
os: macos-latest
rust: 1.65.0
rust: 1.68.2
- build: ubuntu
os: ubuntu-latest
rust: 1.65.0
rust: 1.68.2
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
Expand All @@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
toolchain: 1.65.0
toolchain: 1.68.2
default: true
components: rustfmt
- run: cargo fmt -- --check
Expand All @@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
toolchain: 1.65.0
toolchain: 1.68.2
default: true
components: clippy
- uses: actions-rs/clippy-check@v1
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Versioning].

## [Unreleased] <!-- #release:date -->

* Automatically retry read-only HTTP requests.

## [0.3.0] - 2023-02-26

* Add the `Client::get_tenant` method to get a tenant by ID.
Expand Down
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ categories = ["api-bindings", "web-programming"]
keywords = ["frontegg", "front", "egg", "api", "sdk"]
repository = "https://github.com/MaterializeInc/rust-frontegg"
version = "0.3.0"
rust-version = "1.65"
rust-version = "1.68.2"
edition = "2021"

[dependencies]
async-stream = "0.3.3"
futures-core = "0.3.25"
once_cell = "1.16.0"
reqwest = { version = "0.11.13", features = ["json"] }
reqwest-middleware = "0.2.2"
reqwest-retry = "0.2.2"
serde = { version = "1.0.151", features = ["derive"] }
serde_json = "1.0.91"
time = { version = "0.3.17", features = ["serde", "serde-human-readable"] }
Expand All @@ -30,6 +32,7 @@ tokio = { version = "1.23.0", features = ["macros"] }
tokio-stream = "0.1.11"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
wiremock = "0.5.19"

[package.metadata.docs.rs]
all-features = true
Expand Down
15 changes: 12 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

use std::time::{Duration, SystemTime};

use reqwest::{Method, RequestBuilder, Url};
use reqwest::{Method, Url};
use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
Expand All @@ -38,7 +39,8 @@ const AUTH_VENDOR_PATH: [&str; 2] = ["auth", "vendor"];
/// [`Arc`]: std::sync::Arc
#[derive(Debug)]
pub struct Client {
pub(crate) inner: reqwest::Client,
pub(crate) client_retryable: ClientWithMiddleware,
pub(crate) client_non_retryable: ClientWithMiddleware,
pub(crate) client_id: String,
pub(crate) secret_key: String,
pub(crate) vendor_endpoint: Url,
Expand Down Expand Up @@ -67,7 +69,14 @@ impl Client {
.expect("builder validated URL can be a base")
.clear()
.extend(path);
self.inner.request(method, url)
match method {
// GET and HEAD requests are idempotent and we can safely retry
// them without fear of duplicating data.
Method::GET | Method::HEAD => self.client_retryable.request(method, url),
// All other requests are assumed to be mutating and therefore
// we leave it to the caller to retry them.
_ => self.client_non_retryable.request(method, url),
}
}

async fn send_request<T>(&self, req: RequestBuilder) -> Result<T, Error>
Expand Down
2 changes: 1 addition & 1 deletion src/client/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ pub struct User {

/// Binds a [`User`] to a [`Tenant`] for a `frontegg.user.*` webhook event
///
/// [`Tenant`]: crate::client::tenant::Tenant
/// [`Tenant`]: crate::client::tenants::Tenant
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WebhookTenantBinding {
Expand Down
36 changes: 33 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ use std::time::Duration;

use once_cell::sync::Lazy;
use reqwest::Url;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;

use crate::Client;
use crate::client::Client;

pub static DEFAULT_VENDOR_ENDPOINT: Lazy<Url> = Lazy::new(|| {
"https://api.frontegg.com"
Expand All @@ -37,27 +39,55 @@ pub struct ClientConfig {
/// A builder for a [`Client`].
pub struct ClientBuilder {
vendor_endpoint: Url,
retry_policy: Option<ExponentialBackoff>,
}

impl Default for ClientBuilder {
fn default() -> ClientBuilder {
ClientBuilder {
vendor_endpoint: DEFAULT_VENDOR_ENDPOINT.clone(),
retry_policy: Some(
ExponentialBackoff::builder()
.retry_bounds(Duration::from_millis(100), Duration::from_secs(3))
.build_with_max_retries(5),
),
}
}
}

impl ClientBuilder {
/// Sets the policy for retrying failed read-only API calls.
///
/// Note that the created [`Client`] will retry only read-only API calls,
/// like [`get_tenant`](Client::get_tenant), but not mutating API calls,
/// like [`create_user`](Client::create_user).
pub fn with_retry_policy(mut self, policy: ExponentialBackoff) -> Self {
self.retry_policy = Some(policy);
self
}

/// Sets the vendor endpoint.
pub fn with_vendor_endpoint(mut self, endpoint: Url) -> Self {
self.vendor_endpoint = endpoint;
self
}

/// Creates a [`Client`] that incorporates the optional parameters
/// configured on the builder and the specified required parameters.
pub fn build(self, config: ClientConfig) -> Client {
let inner = reqwest::ClientBuilder::new()
let client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.timeout(Duration::from_secs(60))
.build()
.unwrap();
Client {
inner,
client_retryable: match self.retry_policy {
Some(policy) => reqwest_middleware::ClientBuilder::new(client.clone())
.with(RetryTransientMiddleware::new_with_policy(policy))
.build(),
None => reqwest_middleware::ClientBuilder::new(client.clone()).build(),
},
client_non_retryable: reqwest_middleware::ClientBuilder::new(client).build(),
client_id: config.client_id,
secret_key: config.secret_key,
vendor_endpoint: self.vendor_endpoint,
Expand Down
10 changes: 8 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use reqwest::StatusCode;
#[derive(Debug)]
pub enum Error {
/// An error in the underlying transport.
Transport(reqwest::Error),
Transport(reqwest_middleware::Error),
/// An error returned by the API.
Api(ApiError),
}
Expand Down Expand Up @@ -61,9 +61,15 @@ impl fmt::Display for ApiError {

impl std::error::Error for ApiError {}

impl From<reqwest_middleware::Error> for Error {
fn from(e: reqwest_middleware::Error) -> Error {
Error::Transport(e)
}
}

impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Error {
Error::Transport(e)
Error::Transport(reqwest_middleware::Error::from(e))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
use std::fmt;
use std::iter;

use reqwest::RequestBuilder;
use reqwest_middleware::RequestBuilder;
use uuid::Uuid;

pub trait RequestBuilderExt {
Expand Down
70 changes: 68 additions & 2 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@

use std::collections::HashSet;
use std::env;
use std::time::Duration;

use futures::stream::TryStreamExt;
use once_cell::sync::Lazy;
use reqwest::StatusCode;
use reqwest_retry::policies::ExponentialBackoff;
use serde_json::json;
use test_log::test;
use tracing::info;
use uuid::Uuid;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};

use frontegg::{ApiError, Client, ClientConfig, Error, TenantRequest, UserListConfig, UserRequest};

Expand All @@ -60,6 +63,69 @@ async fn delete_existing_tenants(client: &Client) {
}
}

/// Tests that errors are retried automatically by the client for read API calls
/// but not for write API calls.
#[test(tokio::test)]
async fn test_retries_with_mock_server() {
// Start a mock Frontegg API server and a client configured to target that
// server. The retry policy disables backoff to speed up the tests.
const MAX_RETRIES: u32 = 3;
let server = MockServer::start().await;
let client = Client::builder()
.with_vendor_endpoint(server.uri().parse().unwrap())
.with_retry_policy(
ExponentialBackoff::builder()
.retry_bounds(Duration::from_millis(1), Duration::from_millis(1))
.build_with_max_retries(MAX_RETRIES),
)
.build(ClientConfig {
client_id: "".into(),
secret_key: "".into(),
});

// Register authentication handler.
let mock = Mock::given(matchers::path("/auth/vendor"))
.and(matchers::method("POST"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("{\"token\":\"test\", \"expiresIn\":2687784526}"),
)
.expect(1)
.named("auth");
server.register(mock).await;

// Register a mock for the `get_tenant` call that returns a 429 response
// code and ensure the client repeatedly retries the API call until giving
// up after `MAX_RETRIES` retries and returning the error.
let mock = Mock::given(matchers::method("GET"))
.and(matchers::path_regex("/tenants/.*"))
.respond_with(ResponseTemplate::new(429))
.expect(u64::from(MAX_RETRIES) + 1)
.named("get tenants");
server.register(mock).await;
let res = client.get_tenant(Uuid::new_v4()).await;
assert!(res.is_err());

// Register a mock for the `create_tenant` call that returns a 429 response
// code and ensure the client only tries the API call once.
let mock = Mock::given(matchers::method("POST"))
.and(matchers::path_regex("/tenants/.*"))
.respond_with(ResponseTemplate::new(429))
.expect(1)
.named("post tenants");
server.register(mock).await;
let _ = client
.create_tenant(&TenantRequest {
id: Uuid::new_v4(),
name: &format!("{TENANT_NAME_PREFIX} 1"),
metadata: json!({
"tenant_number": 1,
}),
})
.await;
}

/// Tests basic functionality of creating and retrieving tenants and users.
#[test(tokio::test)]
async fn test_tenants_and_users() {
// Set up.
Expand Down Expand Up @@ -146,8 +212,8 @@ async fn test_tenants_and_users() {
// the same properties.
let user = client.get_user(created_user.id).await.unwrap();
assert_eq!(created_user.id, user.id);
assert_eq!(created_user.name, user.name);
assert_eq!(created_user.email, user.email);
assert_eq!(user.name, name);
assert_eq!(user.email, email);
assert_eq!(user.tenants.len(), 1);
assert_eq!(user.tenants[0].tenant_id, tenant.id);

Expand Down