Skip to content

Commit

Permalink
feat: add paginator API
Browse files Browse the repository at this point in the history
  • Loading branch information
arlyon committed Jun 3, 2022
1 parent 8160ee4 commit 2dc1a68
Showing 1 changed file with 69 additions and 28 deletions.
97 changes: 69 additions & 28 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ where
}
}

pub trait Paginable {
type O: Object + Send;
fn set_last(&mut self, item: Self::O);
}

#[derive(Debug)]
pub struct ListPaginator<T, P> {
pub page: List<T>,
pub params: P,
}

/// A single page of a cursor-paginated list of an object.
///
/// For more details, see <https://stripe.com/docs/api/pagination>
Expand Down Expand Up @@ -213,34 +224,31 @@ impl<T: Clone> Clone for List<T> {
}
}

impl<T: DeserializeOwned + Send + 'static> List<T> {
/// Prefer `List::next` when possible
pub fn get_next(client: &Client, url: &str, last_id: &str) -> Response<List<T>> {
if url.starts_with("/v1/") {
let path = url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
client.get_query(
&path,
[("starting_after", last_id)].iter().cloned().collect::<HashMap<_, _>>(),
)
} else {
err(StripeError::UnsupportedVersion)
}
impl<T> List<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<T, P> {
ListPaginator { page: self, params }
}
}

impl<T: Paginate + DeserializeOwned + Send + 'static> List<T> {
impl<
T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug,
P: Clone + Serialize + Send + 'static + std::fmt::Debug,
> ListPaginator<T, P>
where
P: Paginable<O = T>,
{
/// Repeatedly queries Stripe for more data until all elements in list are fetched, using
/// Stripe's default page size.
///
/// Requires `feature = "blocking"`.
#[cfg(feature = "blocking")]
pub fn get_all(self, client: &Client) -> Response<Vec<T>> {
let mut data = Vec::new();
let mut next = self;
let mut data = Vec::with_capacity(self.page.total_count);
let mut paginator = self;
loop {
if next.has_more {
let resp = next.next(client)?;
data.extend(next.data);
if paginator.page.has_more {
let resp = paginator.next(client)?;
data.extend(paginator.page.data);
next = resp;
} else {
data.extend(next.data);
Expand Down Expand Up @@ -298,18 +306,50 @@ impl<T: Paginate + DeserializeOwned + Send + 'static> List<T> {
}

/// Fetch an additional page of data from stripe.
pub fn next(&self, client: &Client) -> Response<List<T>> {
if let Some(last_id) = self.data.last().map(|d| d.cursor()) {
List::get_next(client, &self.url, last_id.as_ref())
pub fn next(&self, client: &Client) -> Response<Self> {
if let Some(last) = self.page.data.last() {
if self.page.url.starts_with("/v1/") {
let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed

// clone the params and set the cursor
let params_next = {
let mut p = self.params.clone();
p.set_last(last.clone());
p
};

println!("next");
let page = client.get_query(&path, &params_next);

ListPaginator::create_paginator(page, params_next)
} else {
err(StripeError::UnsupportedVersion)
}
} else {
ok(List {
data: Vec::new(),
has_more: false,
total_count: self.total_count,
url: self.url.clone(),
ok(ListPaginator {
page: List {
data: Vec::new(),
has_more: false,
total_count: self.page.total_count,
url: self.page.url.clone(),
},
params: self.params.clone(),
})
}
}

/// Pin a new future which maps the result inside the page future into
/// a ListPaginator
#[cfg(feature = "async")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
use futures_util::FutureExt;
Box::pin(page.map(|page| page.map(|page| ListPaginator { page, params })))
}

#[cfg(feature = "blocking")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
ok(ListPaginator { page, params })
}
}

pub type Metadata = HashMap<String, String>;
Expand Down Expand Up @@ -439,7 +479,7 @@ mod tests {
"metadata": {},
"preferred_locales": [],
"tax_exempt": "none"
}], "has_more": true, "url": "/v1/customers"}"#,
}], "has_more": false, "url": "/v1/customers"}"#,
);
});

Expand All @@ -463,7 +503,8 @@ mod tests {
);
});

let res = Customer::list(&client, ListCustomers::new()).await.unwrap();
let params = ListCustomers::new();
let res = Customer::list(&client, &params).await.unwrap().paginate(params);

println!("{:?}", res);

Expand Down

0 comments on commit 2dc1a68

Please sign in to comment.