diff --git a/Cargo.toml b/Cargo.toml index dafba4c24..cab56b24d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples=examples"] name = "stripe" [features] -default = ["full", "webhook-events", "uuid"] +default = ["full", "webhook-events", "uuid", "stream"] full = [ "checkout", "billing", @@ -45,6 +45,8 @@ full = [ "webhook-endpoints", ] +stream = ["futures-util"] + # stripe feature groups checkout = ["billing"] billing = [] @@ -98,6 +100,9 @@ tokio = { version = "1.2", optional = true } smart-default = "0.6.0" uuid = { version = "0.8", optional=true, features=["v4"] } +# stream for lists +futures-util = { version = "0.3.21", optional = true } + # webhook support hmac = { version = "0.12", optional = true } sha2 = { version = "0.10", optional = true } diff --git a/src/params.rs b/src/params.rs index 068fd6994..79b264f3b 100644 --- a/src/params.rs +++ b/src/params.rs @@ -250,6 +250,53 @@ impl List { Ok(data) } + /// Get all values in this List, consuming self and lazily paginating until all values are fetched. + /// + /// This function repeatedly queries Stripe for more data until all elements in list are fetched, using + /// the page size specified in params, or Stripe's default page size if none is specified. + /// + /// ```no_run + /// let value_stream = list.get_all(&client); + /// while let Some(val) = value_stream.try_next().await? { + /// println!("GOT = {:?}", val); + /// } + /// + /// // Alternatively, you can collect all values into a Vec + /// let all_values = list.get_all(&client).try_collect::().await?; + /// ``` + #[cfg(all(feature = "async", feature = "stream"))] + pub fn stream( + mut self, + client: &Client, + ) -> impl futures_util::Stream> { + // We are going to be popping items off the end of the list, so we need to reverse it. + self.page.data.reverse(); + + futures_util::stream::unfold(Some((self, client.clone())), |state| async { + let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration + + if paginator.page.data.len() > 1 { + return Some((Ok(paginator.page.data.pop()?), Some((paginator, client)))); + // We have more data on this page + } + + if !paginator.page.has_more { + return Some((Ok(paginator.page.data.pop()?), None)); // Final value of the stream, no errors + } + + match paginator.next(&client).await { + Ok(mut next_paginator) => { + let data = paginator.page.data.pop()?; + next_paginator.page.data.reverse(); + + // Yield last value of thimuts page, the next page (and client) becomes the state + Some((Ok(data), Some((next_paginator, client)))) + } + Err(e) => Some((Err(e), None)), // We ran into an error. The last value of the stream will be the error. + } + }) + } + /// Fetch an additional page of data from stripe. pub fn next(&self, client: &Client) -> Response> { if let Some(last_id) = self.data.last().map(|d| d.cursor()) { @@ -427,4 +474,72 @@ mod tests { first_item.assert_hits_async(1).await; next_item.assert_hits_async(1).await; } + + #[cfg(all(feature = "async", feature = "stream"))] + #[tokio::test] + async fn stream() { + use futures_util::StreamExt; + use httpmock::Method::GET; + use httpmock::MockServer; + + use crate::Client; + use crate::{Customer, ListCustomers}; + + // Start a lightweight mock server. + let server = MockServer::start_async().await; + + let client = Client::from_url(&*server.url("/"), "fake_key"); + + let next_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers").query_param("starting_after", "cus_1"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316731, + "currency": "gbp", + "delinquent": false, + "email": null, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": false, "url": "/v1/customers"}"#, + ); + }); + + let first_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_1", + "object": "customer", + "balance": 0, + "created": 1649316731, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": true, "url": "/v1/customers"}"#, + ); + }); + + let params = ListCustomers::new(); + let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + + let stream = res.stream(&client).collect::>().await; + + println!("{:#?}", stream); + assert_eq!(stream.len(), 2); + + first_item.assert_hits_async(1).await; + next_item.assert_hits_async(1).await; + } }