From bc20bd46f67ffb384d67d0605cb5126e1f3d6333 Mon Sep 17 00:00:00 2001 From: skytz Date: Sat, 13 Jan 2024 03:28:22 +0200 Subject: [PATCH 1/5] fix(pagination): prevent infinite loop caused by clone --- src/params.rs | 51 ++++++++++++++------------------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/src/params.rs b/src/params.rs index 1d89de6d5..bcfea9d5f 100644 --- a/src/params.rs +++ b/src/params.rs @@ -182,7 +182,7 @@ pub trait Paginable { pub trait PaginableList { type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug; fn new(data: Vec, url: String, has_more: bool, total_count: Option) -> Self; - fn get_data(&self) -> Vec; + fn get_data(&mut self) -> &mut Vec; fn get_url(&self) -> String; fn get_total_count(&self) -> Option; fn has_more(&self) -> bool; @@ -191,7 +191,7 @@ pub trait PaginableList { /// A single page of a cursor-paginated list of a search object. /// /// For more details, see -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct SearchList { pub object: String, pub url: String, @@ -214,19 +214,6 @@ impl Default for SearchList { } } -impl Clone for SearchList { - fn clone(&self) -> Self { - SearchList { - object: self.object.clone(), - data: self.data.clone(), - has_more: self.has_more, - total_count: self.total_count, - url: self.url.clone(), - next_page: self.next_page.clone(), - } - } -} - impl PaginableList for SearchList { @@ -241,8 +228,8 @@ impl Vec { - self.data.clone() + fn get_data(&mut self) -> &mut Vec { + &mut self.data } fn get_url(&self) -> String { self.url.clone() @@ -264,8 +251,8 @@ impl Vec { - self.data.clone() + fn get_data(&mut self) -> &mut Vec { + &mut self.data } fn get_url(&self) -> String { self.url.clone() @@ -287,7 +274,7 @@ impl SearchList { /// A single page of a cursor-paginated list of an object. /// /// For more details, see -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct List { pub data: Vec, pub has_more: bool, @@ -301,17 +288,6 @@ impl Default for List { } } -impl Clone for List { - fn clone(&self) -> Self { - List { - data: self.data.clone(), - has_more: self.has_more, - total_count: self.total_count, - url: self.url.clone(), - } - } -} - impl List { pub fn paginate

(self, params: P) -> ListPaginator, P> { ListPaginator { page: self, params } @@ -381,7 +357,7 @@ where /// Requires `feature = ["async", "stream"]`. #[cfg(all(feature = "async", feature = "stream"))] pub fn stream( - self, + mut self, client: &Client, ) -> impl futures_util::Stream> + Unpin { // We are going to be popping items off the end of the list, so we need to reverse it. @@ -395,7 +371,7 @@ where async fn unfold_stream( state: Option<(Self, Client)>, ) -> Option<(Result, Option<(Self, Client)>)> { - let (paginator, client) = state?; // If none, we sent the last item in the last iteration + let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration if paginator.page.get_data().len() > 1 { return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client)))); @@ -407,7 +383,7 @@ where } match paginator.next(&client).await { - Ok(next_paginator) => { + Ok(mut next_paginator) => { let data = paginator.page.get_data().pop()?; next_paginator.page.get_data().reverse(); @@ -419,10 +395,11 @@ where } /// Fetch an additional page of data from stripe. - pub fn next(&self, client: &Client) -> Response { + pub fn next(&mut self, client: &Client) -> Response { + let page_url = self.page.get_url(); if let Some(last) = self.page.get_data().last() { - if self.page.get_url().starts_with("/v1/") { - let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed + if page_url.starts_with("/v1/") { + let path = page_url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed // clone the params and set the cursor let params_next = { From 1510caaee708bfacd92f3e95d159373b33ed8b71 Mon Sep 17 00:00:00 2001 From: skytz Date: Sun, 14 Jan 2024 16:26:43 +0200 Subject: [PATCH 2/5] added tests for multiple results per page --- src/params.rs | 162 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 1 deletion(-) diff --git a/src/params.rs b/src/params.rs index bcfea9d5f..d9600ed8e 100644 --- a/src/params.rs +++ b/src/params.rs @@ -589,7 +589,7 @@ mod tests { }); let params = ListCustomers::new(); - let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + let mut res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); println!("{:?}", res); @@ -601,6 +601,85 @@ mod tests { next_item.assert_hits_async(1).await; } + #[cfg(feature = "async")] + #[tokio::test] + async fn list_multiple() { + use httpmock::Method::GET; + use httpmock::MockServer; + + use crate::Client; + use crate::{Customer, ListCustomers}; + + // Start a lightweight mock server. + let server = MockServer::start_async().await; + + let client = Client::from_url(&*server.url("/"), "fake_key"); + + let next_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers").query_param("starting_after", "cus_2"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316733, + "currency": "gbp", + "delinquent": false, + "email": null, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": false, "url": "/v1/customers"}"#, + ); + }); + + let first_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_1", + "object": "customer", + "balance": 0, + "created": 1649316732, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }, { + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316733, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": true, "url": "/v1/customers"}"#, + ); + }); + + let params = ListCustomers::new(); + let mut res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + + let res2 = res.next(&client).await.unwrap(); + + println!("{:?}", res2); + + first_item.assert_hits_async(1).await; + next_item.assert_hits_async(1).await; + } + #[cfg(all(feature = "async", feature = "stream"))] #[tokio::test] async fn stream() { @@ -668,4 +747,85 @@ mod tests { first_item.assert_hits_async(1).await; next_item.assert_hits_async(1).await; } + + #[cfg(all(feature = "async", feature = "stream"))] + #[tokio::test] + async fn stream_multiple() { + use futures_util::StreamExt; + use httpmock::Method::GET; + use httpmock::MockServer; + + use crate::Client; + use crate::{Customer, ListCustomers}; + + // Start a lightweight mock server. + let server = MockServer::start_async().await; + + let client = Client::from_url(&*server.url("/"), "fake_key"); + + let next_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers").query_param("starting_after", "cus_2"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_3", + "object": "customer", + "balance": 0, + "created": 1649316734, + "currency": "gbp", + "delinquent": false, + "email": null, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": false, "url": "/v1/customers"}"#, + ); + }); + + let items = server.mock(|when, then| { + when.method(GET).path("/v1/customers"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_1", + "object": "customer", + "balance": 0, + "created": 1649316732, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }, { + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316733, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": true, "url": "/v1/customers"}"#, + ); + }); + + let params = ListCustomers::default(); + let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + + let stream = res.stream(&client).collect::>().await; + + println!("{:#?}", stream.len()); + assert_eq!(stream.len(), 3); + + items.assert_hits_async(1).await; + next_item.assert_hits_async(1).await; + } } From 82080a9ff2740c00ec622f8a23f681fccfe63ff1 Mon Sep 17 00:00:00 2001 From: skytz Date: Sun, 14 Jan 2024 17:03:36 +0200 Subject: [PATCH 3/5] removed mutability requirement for next function --- src/params.rs | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/params.rs b/src/params.rs index d9600ed8e..a55ece894 100644 --- a/src/params.rs +++ b/src/params.rs @@ -182,7 +182,8 @@ pub trait Paginable { pub trait PaginableList { type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug; fn new(data: Vec, url: String, has_more: bool, total_count: Option) -> Self; - fn get_data(&mut self) -> &mut Vec; + fn get_data_mut(&mut self) -> &mut Vec; + fn get_data(&self) -> &Vec; fn get_url(&self) -> String; fn get_total_count(&self) -> Option; fn has_more(&self) -> bool; @@ -228,9 +229,13 @@ impl &mut Vec { + fn get_data_mut(&mut self) -> &mut Vec { &mut self.data } + + fn get_data(&self) -> &Vec { + &self.data + } fn get_url(&self) -> String { self.url.clone() } @@ -251,9 +256,14 @@ impl &mut Vec { + fn get_data_mut(&mut self) -> &mut Vec { &mut self.data } + + fn get_data(&self) -> &Vec { + &self.data + } + fn get_url(&self) -> String { self.url.clone() } @@ -361,7 +371,7 @@ where client: &Client, ) -> impl futures_util::Stream> + Unpin { // We are going to be popping items off the end of the list, so we need to reverse it. - self.page.get_data().reverse(); + self.page.get_data_mut().reverse(); Box::pin(futures_util::stream::unfold(Some((self, client.clone())), Self::unfold_stream)) } @@ -374,18 +384,18 @@ where let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration if paginator.page.get_data().len() > 1 { - return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client)))); + return Some((Ok(paginator.page.get_data_mut().pop()?), Some((paginator, client)))); // We have more data on this page } if !paginator.page.has_more() { - return Some((Ok(paginator.page.get_data().pop()?), None)); // Final value of the stream, no errors + return Some((Ok(paginator.page.get_data_mut().pop()?), None)); // Final value of the stream, no errors } match paginator.next(&client).await { Ok(mut next_paginator) => { - let data = paginator.page.get_data().pop()?; - next_paginator.page.get_data().reverse(); + let data = paginator.page.get_data_mut().pop()?; + next_paginator.page.get_data_mut().reverse(); // Yield last value of thimuts page, the next page (and client) becomes the state Some((Ok(data), Some((next_paginator, client)))) @@ -395,11 +405,10 @@ where } /// Fetch an additional page of data from stripe. - pub fn next(&mut self, client: &Client) -> Response { - let page_url = self.page.get_url(); + pub fn next(&self, client: &Client) -> Response { if let Some(last) = self.page.get_data().last() { - if page_url.starts_with("/v1/") { - let path = page_url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed + if self.page.get_url().starts_with("/v1/") { + let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed // clone the params and set the cursor let params_next = { @@ -589,7 +598,7 @@ mod tests { }); let params = ListCustomers::new(); - let mut res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); println!("{:?}", res); @@ -670,7 +679,7 @@ mod tests { }); let params = ListCustomers::new(); - let mut res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); + let res = Customer::list(&client, ¶ms).await.unwrap().paginate(params); let res2 = res.next(&client).await.unwrap(); From 27c4d24038dfc74040fa2c177ac1e9c05c0ba8dd Mon Sep 17 00:00:00 2001 From: Alexander Lyon Date: Wed, 24 Jan 2024 16:32:24 +0000 Subject: [PATCH 4/5] fix sync and add equivalent tests --- src/params.rs | 84 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/src/params.rs b/src/params.rs index a55ece894..224cd6e23 100644 --- a/src/params.rs +++ b/src/params.rs @@ -327,11 +327,11 @@ where let mut paginator = self; loop { if !paginator.page.has_more() { - data.extend(paginator.page.get_data().into_iter()); + data.extend(paginator.page.get_data_mut().drain(..)); break; } let next_paginator = paginator.next(client)?; - data.extend(paginator.page.get_data().into_iter()); + data.extend(paginator.page.get_data_mut().drain(..)); paginator = next_paginator } Ok(data) @@ -610,6 +610,86 @@ mod tests { next_item.assert_hits_async(1).await; } + #[cfg(feature = "blocking")] + #[test] + fn get_all() { + use httpmock::Method::GET; + use httpmock::MockServer; + + use crate::Client; + use crate::{Customer, ListCustomers}; + + // Start a lightweight mock server. + let server = MockServer::start(); + + let client = Client::from_url(&*server.url("/"), "fake_key"); + + let next_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers").query_param("starting_after", "cus_2"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316733, + "currency": "gbp", + "delinquent": false, + "email": null, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": false, "url": "/v1/customers"}"#, + ); + }); + + let first_item = server.mock(|when, then| { + when.method(GET).path("/v1/customers"); + then.status(200).body( + r#"{"object": "list", "data": [{ + "id": "cus_1", + "object": "customer", + "balance": 0, + "created": 1649316732, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }, { + "id": "cus_2", + "object": "customer", + "balance": 0, + "created": 1649316733, + "currency": "gbp", + "delinquent": false, + "invoice_prefix": "4AF7482", + "invoice_settings": {}, + "livemode": false, + "metadata": {}, + "preferred_locales": [], + "tax_exempt": "none" + }], "has_more": true, "url": "/v1/customers"}"#, + ); + }); + + let params = ListCustomers::new(); + let res = Customer::list(&client, ¶ms).unwrap().paginate(params); + + let customers = res.get_all(&client).unwrap(); + + println!("{:?}", customers); + + assert_eq!(customers.len(), 3); + first_item.assert_hits(1); + next_item.assert_hits(1); + } + #[cfg(feature = "async")] #[tokio::test] async fn list_multiple() { From 49893df6dfc3de53af645fea94a2bfc1840d3ea1 Mon Sep 17 00:00:00 2001 From: Alexander Lyon Date: Wed, 24 Jan 2024 20:27:46 +0000 Subject: [PATCH 5/5] fix examples --- examples/connect.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/connect.rs b/examples/connect.rs index 980db1dbf..f16289753 100644 --- a/examples/connect.rs +++ b/examples/connect.rs @@ -48,6 +48,7 @@ async fn main() { expand: &[], refresh_url: Some("https://test.com/refresh"), return_url: Some("https://test.com/return"), + collection_options: None, }, ) .await