diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 99a22e3..a11c686 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,10 +16,10 @@ jobs: include: - build: macos os: macos-latest - rust: 1.65.0 + rust: 1.68.2 - build: ubuntu os: ubuntu-latest - rust: 1.65.0 + rust: 1.68.2 steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 @@ -38,7 +38,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: - toolchain: 1.65.0 + toolchain: 1.68.2 default: true components: rustfmt - run: cargo fmt -- --check @@ -50,7 +50,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: - toolchain: 1.65.0 + toolchain: 1.68.2 default: true components: clippy - uses: actions-rs/clippy-check@v1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7443f36..c642273 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ Versioning]. ## [Unreleased] +* Automatically retry read-only HTTP requests. + ## [0.3.0] - 2023-02-26 * Add the `Client::get_tenant` method to get a tenant by ID. diff --git a/Cargo.toml b/Cargo.toml index 497df79..856f3d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ categories = ["api-bindings", "web-programming"] keywords = ["frontegg", "front", "egg", "api", "sdk"] repository = "https://github.com/MaterializeInc/rust-frontegg" version = "0.3.0" -rust-version = "1.65" +rust-version = "1.68.2" edition = "2021" [dependencies] @@ -17,6 +17,8 @@ async-stream = "0.3.3" futures-core = "0.3.25" once_cell = "1.16.0" reqwest = { version = "0.11.13", features = ["json"] } +reqwest-middleware = "0.2.2" +reqwest-retry = "0.2.2" serde = { version = "1.0.151", features = ["derive"] } serde_json = "1.0.91" time = { version = "0.3.17", features = ["serde", "serde-human-readable"] } @@ -30,6 +32,7 @@ tokio = { version = "1.23.0", features = ["macros"] } tokio-stream = "0.1.11" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +wiremock = "0.5.19" [package.metadata.docs.rs] all-features = true diff --git a/src/client.rs b/src/client.rs index d459a01..bd08071 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,7 +15,8 @@ use std::time::{Duration, SystemTime}; -use reqwest::{Method, RequestBuilder, Url}; +use reqwest::{Method, Url}; +use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; @@ -38,7 +39,8 @@ const AUTH_VENDOR_PATH: [&str; 2] = ["auth", "vendor"]; /// [`Arc`]: std::sync::Arc #[derive(Debug)] pub struct Client { - pub(crate) inner: reqwest::Client, + pub(crate) client_retryable: ClientWithMiddleware, + pub(crate) client_non_retryable: ClientWithMiddleware, pub(crate) client_id: String, pub(crate) secret_key: String, pub(crate) vendor_endpoint: Url, @@ -67,7 +69,14 @@ impl Client { .expect("builder validated URL can be a base") .clear() .extend(path); - self.inner.request(method, url) + match method { + // GET and HEAD requests are idempotent and we can safely retry + // them without fear of duplicating data. + Method::GET | Method::HEAD => self.client_retryable.request(method, url), + // All other requests are assumed to be mutating and therefore + // we leave it to the caller to retry them. + _ => self.client_non_retryable.request(method, url), + } } async fn send_request(&self, req: RequestBuilder) -> Result diff --git a/src/client/users.rs b/src/client/users.rs index b7ff0bd..6d357f7 100644 --- a/src/client/users.rs +++ b/src/client/users.rs @@ -173,7 +173,7 @@ pub struct User { /// Binds a [`User`] to a [`Tenant`] for a `frontegg.user.*` webhook event /// -/// [`Tenant`]: crate::client::tenant::Tenant +/// [`Tenant`]: crate::client::tenants::Tenant #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct WebhookTenantBinding { diff --git a/src/config.rs b/src/config.rs index b6a2060..0a8db4a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,8 +17,10 @@ use std::time::Duration; use once_cell::sync::Lazy; use reqwest::Url; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; -use crate::Client; +use crate::client::Client; pub static DEFAULT_VENDOR_ENDPOINT: Lazy = Lazy::new(|| { "https://api.frontegg.com" @@ -37,27 +39,55 @@ pub struct ClientConfig { /// A builder for a [`Client`]. pub struct ClientBuilder { vendor_endpoint: Url, + retry_policy: Option, } impl Default for ClientBuilder { fn default() -> ClientBuilder { ClientBuilder { vendor_endpoint: DEFAULT_VENDOR_ENDPOINT.clone(), + retry_policy: Some( + ExponentialBackoff::builder() + .retry_bounds(Duration::from_millis(100), Duration::from_secs(3)) + .build_with_max_retries(5), + ), } } } impl ClientBuilder { + /// Sets the policy for retrying failed read-only API calls. + /// + /// Note that the created [`Client`] will retry only read-only API calls, + /// like [`get_tenant`](Client::get_tenant), but not mutating API calls, + /// like [`create_user`](Client::create_user). + pub fn with_retry_policy(mut self, policy: ExponentialBackoff) -> Self { + self.retry_policy = Some(policy); + self + } + + /// Sets the vendor endpoint. + pub fn with_vendor_endpoint(mut self, endpoint: Url) -> Self { + self.vendor_endpoint = endpoint; + self + } + /// Creates a [`Client`] that incorporates the optional parameters /// configured on the builder and the specified required parameters. pub fn build(self, config: ClientConfig) -> Client { - let inner = reqwest::ClientBuilder::new() + let client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .timeout(Duration::from_secs(60)) .build() .unwrap(); Client { - inner, + client_retryable: match self.retry_policy { + Some(policy) => reqwest_middleware::ClientBuilder::new(client.clone()) + .with(RetryTransientMiddleware::new_with_policy(policy)) + .build(), + None => reqwest_middleware::ClientBuilder::new(client.clone()).build(), + }, + client_non_retryable: reqwest_middleware::ClientBuilder::new(client).build(), client_id: config.client_id, secret_key: config.secret_key, vendor_endpoint: self.vendor_endpoint, diff --git a/src/error.rs b/src/error.rs index bce3de0..427749b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,7 @@ use reqwest::StatusCode; #[derive(Debug)] pub enum Error { /// An error in the underlying transport. - Transport(reqwest::Error), + Transport(reqwest_middleware::Error), /// An error returned by the API. Api(ApiError), } @@ -61,9 +61,15 @@ impl fmt::Display for ApiError { impl std::error::Error for ApiError {} +impl From for Error { + fn from(e: reqwest_middleware::Error) -> Error { + Error::Transport(e) + } +} + impl From for Error { fn from(e: reqwest::Error) -> Error { - Error::Transport(e) + Error::Transport(reqwest_middleware::Error::from(e)) } } diff --git a/src/util.rs b/src/util.rs index 800deb6..351799f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -16,7 +16,7 @@ use std::fmt; use std::iter; -use reqwest::RequestBuilder; +use reqwest_middleware::RequestBuilder; use uuid::Uuid; pub trait RequestBuilderExt { diff --git a/tests/api.rs b/tests/api.rs index ee9c13b..9e9b203 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -26,14 +26,17 @@ use std::collections::HashSet; use std::env; +use std::time::Duration; use futures::stream::TryStreamExt; use once_cell::sync::Lazy; use reqwest::StatusCode; +use reqwest_retry::policies::ExponentialBackoff; use serde_json::json; use test_log::test; use tracing::info; use uuid::Uuid; +use wiremock::{matchers, Mock, MockServer, ResponseTemplate}; use frontegg::{ApiError, Client, ClientConfig, Error, TenantRequest, UserListConfig, UserRequest}; @@ -60,6 +63,69 @@ async fn delete_existing_tenants(client: &Client) { } } +/// Tests that errors are retried automatically by the client for read API calls +/// but not for write API calls. +#[test(tokio::test)] +async fn test_retries_with_mock_server() { + // Start a mock Frontegg API server and a client configured to target that + // server. The retry policy disables backoff to speed up the tests. + const MAX_RETRIES: u32 = 3; + let server = MockServer::start().await; + let client = Client::builder() + .with_vendor_endpoint(server.uri().parse().unwrap()) + .with_retry_policy( + ExponentialBackoff::builder() + .retry_bounds(Duration::from_millis(1), Duration::from_millis(1)) + .build_with_max_retries(MAX_RETRIES), + ) + .build(ClientConfig { + client_id: "".into(), + secret_key: "".into(), + }); + + // Register authentication handler. + let mock = Mock::given(matchers::path("/auth/vendor")) + .and(matchers::method("POST")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("{\"token\":\"test\", \"expiresIn\":2687784526}"), + ) + .expect(1) + .named("auth"); + server.register(mock).await; + + // Register a mock for the `get_tenant` call that returns a 429 response + // code and ensure the client repeatedly retries the API call until giving + // up after `MAX_RETRIES` retries and returning the error. + let mock = Mock::given(matchers::method("GET")) + .and(matchers::path_regex("/tenants/.*")) + .respond_with(ResponseTemplate::new(429)) + .expect(u64::from(MAX_RETRIES) + 1) + .named("get tenants"); + server.register(mock).await; + let res = client.get_tenant(Uuid::new_v4()).await; + assert!(res.is_err()); + + // Register a mock for the `create_tenant` call that returns a 429 response + // code and ensure the client only tries the API call once. + let mock = Mock::given(matchers::method("POST")) + .and(matchers::path_regex("/tenants/.*")) + .respond_with(ResponseTemplate::new(429)) + .expect(1) + .named("post tenants"); + server.register(mock).await; + let _ = client + .create_tenant(&TenantRequest { + id: Uuid::new_v4(), + name: &format!("{TENANT_NAME_PREFIX} 1"), + metadata: json!({ + "tenant_number": 1, + }), + }) + .await; +} + +/// Tests basic functionality of creating and retrieving tenants and users. #[test(tokio::test)] async fn test_tenants_and_users() { // Set up. @@ -146,8 +212,8 @@ async fn test_tenants_and_users() { // the same properties. let user = client.get_user(created_user.id).await.unwrap(); assert_eq!(created_user.id, user.id); - assert_eq!(created_user.name, user.name); - assert_eq!(created_user.email, user.email); + assert_eq!(user.name, name); + assert_eq!(user.email, email); assert_eq!(user.tenants.len(), 1); assert_eq!(user.tenants[0].tenant_id, tenant.id);