From 2dc1a6873e993e7790f03b07c2a66925434bb83e Mon Sep 17 00:00:00 2001 From: Alexander Lyon Date: Fri, 3 Jun 2022 09:36:41 +0100 Subject: [PATCH] feat: add paginator API --- src/params.rs | 97 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/src/params.rs b/src/params.rs index 79b264f3b..594d029ab 100644 --- a/src/params.rs +++ b/src/params.rs @@ -185,6 +185,17 @@ where } } +pub trait Paginable { + type O: Object + Send; + fn set_last(&mut self, item: Self::O); +} + +#[derive(Debug)] +pub struct ListPaginator { + pub page: List, + pub params: P, +} + /// A single page of a cursor-paginated list of an object. /// /// For more details, see @@ -213,34 +224,31 @@ impl Clone for List { } } -impl List { - /// Prefer `List::next` when possible - pub fn get_next(client: &Client, url: &str, last_id: &str) -> Response> { - if url.starts_with("/v1/") { - let path = url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed - client.get_query( - &path, - [("starting_after", last_id)].iter().cloned().collect::>(), - ) - } else { - err(StripeError::UnsupportedVersion) - } +impl List { + pub fn paginate

(self, params: P) -> ListPaginator { + ListPaginator { page: self, params } } } -impl List { +impl< + T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug, + P: Clone + Serialize + Send + 'static + std::fmt::Debug, + > ListPaginator +where + P: Paginable, +{ /// Repeatedly queries Stripe for more data until all elements in list are fetched, using /// Stripe's default page size. /// /// Requires `feature = "blocking"`. #[cfg(feature = "blocking")] pub fn get_all(self, client: &Client) -> Response> { - let mut data = Vec::new(); - let mut next = self; + let mut data = Vec::with_capacity(self.page.total_count); + let mut paginator = self; loop { - if next.has_more { - let resp = next.next(client)?; - data.extend(next.data); + if paginator.page.has_more { + let resp = paginator.next(client)?; + data.extend(paginator.page.data); next = resp; } else { data.extend(next.data); @@ -298,18 +306,50 @@ impl List { } /// Fetch an additional page of data from stripe. - pub fn next(&self, client: &Client) -> Response> { - if let Some(last_id) = self.data.last().map(|d| d.cursor()) { - List::get_next(client, &self.url, last_id.as_ref()) + pub fn next(&self, client: &Client) -> Response { + if let Some(last) = self.page.data.last() { + if self.page.url.starts_with("/v1/") { + let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed + + // clone the params and set the cursor + let params_next = { + let mut p = self.params.clone(); + p.set_last(last.clone()); + p + }; + + println!("next"); + let page = client.get_query(&path, ¶ms_next); + + ListPaginator::create_paginator(page, params_next) + } else { + err(StripeError::UnsupportedVersion) + } } else { - ok(List { - data: Vec::new(), - has_more: false, - total_count: self.total_count, - url: self.url.clone(), + ok(ListPaginator { + page: List { + data: Vec::new(), + has_more: false, + total_count: self.page.total_count, + url: self.page.url.clone(), + }, + params: self.params.clone(), }) } } + + /// Pin a new future which maps the result inside the page future into + /// a ListPaginator + #[cfg(feature = "async")] + fn create_paginator(page: Response>, params: P) -> Response { + use futures_util::FutureExt; + Box::pin(page.map(|page| page.map(|page| ListPaginator { page, params }))) + } + + #[cfg(feature = "blocking")] + fn create_paginator(page: Response>, params: P) -> Response { + ok(ListPaginator { page, params }) + } } pub type Metadata = HashMap; @@ -439,7 +479,7 @@ mod tests { "metadata": {}, "preferred_locales": [], "tax_exempt": "none" - }], "has_more": true, "url": "/v1/customers"}"#, + }], "has_more": false, "url": "/v1/customers"}"#, ); }); @@ -463,7 +503,8 @@ mod tests { ); }); - let res = Customer::list(&client, ListCustomers::new()).await.unwrap(); + let params = ListCustomers::new(); + let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); println!("{:?}", res);