From 9d496023128ff5b3b5b16065241afc83fc1ba3cd Mon Sep 17 00:00:00 2001
From: hzargar2 <hzargar2@gmail.com>
Date: Sat, 19 Aug 2023 17:21:23 -0400
Subject: [PATCH] feat: updated ListPaginator to be generic over type T where T
 impl PaginableList instead of having separate SearchListPaginator and
 ListPaginator implementations for types SearchList and List.

---
 src/params.rs | 269 ++++++++++++++++++--------------------------------
 1 file changed, 97 insertions(+), 172 deletions(-)

diff --git a/src/params.rs b/src/params.rs
index 4f8410922..bfd1ea2d1 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -179,16 +179,19 @@ pub trait Paginable {
     fn set_last(&mut self, item: Self::O);
 }
 
-#[derive(Debug)]
-pub struct ListPaginator<T, P> {
-    pub page: List<T>,
-    pub params: P,
-}
-
-#[derive(Debug)]
-pub struct SearchListPaginator<T, P> {
-    pub page: SearchList<T>,
-    pub params: P,
+pub trait PaginableList {
+    type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug;
+    fn new(
+        &self,
+        data: Vec<Self::O>,
+        url: String,
+        has_more: bool,
+        total_count: Option<u64>,
+    ) -> Self;
+    fn get_data(&self) -> Vec<Self::O>;
+    fn get_url(&self) -> String;
+    fn get_total_count(&self) -> Option<u64>;
+    fn has_more(&self) -> bool;
 }
 
 /// A single page of a cursor-paginated list of a search object.
@@ -230,151 +233,67 @@ impl<T: Clone> Clone for SearchList<T> {
     }
 }
 
-impl<T> SearchList<T> {
-    pub fn paginate<P>(self, params: P) -> SearchListPaginator<T, P> {
-        SearchListPaginator { page: self, params }
-    }
-}
-
-impl<
-        T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug,
-        P: Clone + Serialize + Send + 'static + std::fmt::Debug,
-    > SearchListPaginator<T, P>
-where
-    P: Paginable<O = T>,
+impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
+    for SearchList<T>
 {
-    /// Repeatedly queries Stripe for more data until all elements in list are fetched, using
-    /// Stripe's default page size.
-    ///
-    /// Requires `feature = "blocking"`.
-    #[cfg(feature = "blocking")]
-    pub fn get_all(self, client: &Client) -> Response<Vec<T>> {
-        let mut data = Vec::with_capacity(self.page.total_count.unwrap_or(0) as usize);
-        let mut paginator = self;
-        loop {
-            if !paginator.page.has_more {
-                data.extend(paginator.page.data.into_iter());
-                break;
-            }
-            let next_paginator = paginator.next(client)?;
-            data.extend(paginator.page.data.into_iter());
-            paginator = next_paginator
-        }
-        Ok(data)
-    }
-
-    /// Get all values in this SearchList, consuming self and lazily paginating until all values are fetched.
-    ///
-    /// This function repeatedly queries Stripe for more data until all elements in list are fetched, using
-    /// the page size specified in params, or Stripe's default page size if none is specified.
-    ///
-    /// ```no_run
-    /// # use stripe::{Customer, SearchListCustomers, StripeError, Client};
-    /// # use futures_util::TryStreamExt;
-    /// # async fn run() -> Result<(), StripeError> {
-    /// # let client = Client::new("sk_test_123");
-    /// # let params = SearchListCustomers { ..Default::default() };
-    ///
-    /// let list = Customer::list(&client, &params).await.unwrap().paginate(params);
-    /// let mut stream = list.stream(&client);
-    ///
-    /// // take a value out from the stream
-    /// if let Some(val) = stream.try_next().await? {
-    ///     println!("GOT = {:?}", val);
-    /// }
-    ///
-    /// // alternatively, you can use stream combinators
-    /// let all_values = stream.try_collect::<Vec<_>>().await?;
-    ///
-    /// # Ok(())
-    /// # }
-    /// ```
-    ///
-    /// Requires `feature = ["async", "stream"]`.
-    #[cfg(all(feature = "async", feature = "stream"))]
-    pub fn stream(
-        mut self,
-        client: &Client,
-    ) -> impl futures_util::Stream<Item = Result<T, StripeError>> + Unpin {
-        // We are going to be popping items off the end of the list, so we need to reverse it.
-        self.page.data.reverse();
+    type O = T;
 
-        Box::pin(futures_util::stream::unfold(Some((self, client.clone())), Self::unfold_stream))
+    fn new(
+        &self,
+        data: Vec<Self::O>,
+        url: String,
+        has_more: bool,
+        total_count: Option<u64>,
+    ) -> SearchList<T> {
+        Self { object: "".to_string(), url, has_more, data: data, next_page: None, total_count }
     }
 
-    /// unfold a single item from the stream
-    #[cfg(all(feature = "async", feature = "stream"))]
-    async fn unfold_stream(
-        state: Option<(Self, Client)>,
-    ) -> Option<(Result<T, StripeError>, Option<(Self, Client)>)> {
-        let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration
-
-        if paginator.page.data.len() > 1 {
-            return Some((Ok(paginator.page.data.pop()?), Some((paginator, client))));
-            // We have more data on this page
-        }
-
-        if !paginator.page.has_more {
-            return Some((Ok(paginator.page.data.pop()?), None)); // Final value of the stream, no errors
-        }
-
-        match paginator.next(&client).await {
-            Ok(mut next_paginator) => {
-                let data = paginator.page.data.pop()?;
-                next_paginator.page.data.reverse();
-
-                // Yield last value of thimuts page, the next page (and client) becomes the state
-                Some((Ok(data), Some((next_paginator, client))))
-            }
-            Err(e) => Some((Err(e), None)), // We ran into an error. The last value of the stream will be the error.
-        }
+    fn get_data(&self) -> Vec<Self::O> {
+        self.data.clone()
     }
+    fn get_url(&self) -> String {
+        self.url.clone()
+    }
+    fn get_total_count(&self) -> Option<u64> {
+        self.total_count.clone()
+    }
+    fn has_more(&self) -> bool {
+        self.has_more.clone()
+    }
+}
 
-    /// Fetch an additional page of data from stripe.
-    pub fn next(&self, client: &Client) -> Response<Self> {
-        if let Some(last) = self.page.data.last() {
-            if self.page.url.starts_with("/v1/") {
-                let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
-
-                // clone the params and set the cursor
-                let params_next = {
-                    let mut p = self.params.clone();
-                    p.set_last(last.clone());
-                    p
-                };
-
-                let page = client.get_query(&path, &params_next);
+impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
+    for List<T>
+{
+    type O = T;
 
-                SearchListPaginator::create_paginator(page, params_next)
-            } else {
-                err(StripeError::UnsupportedVersion)
-            }
-        } else {
-            ok(SearchListPaginator {
-                page: SearchList {
-                    object: self.page.object.clone(),
-                    data: Vec::new(),
-                    has_more: false,
-                    total_count: self.page.total_count,
-                    url: self.page.url.clone(),
-                    next_page: self.page.next_page.clone(),
-                },
-                params: self.params.clone(),
-            })
-        }
+    fn new(
+        &self,
+        data: Vec<Self::O>,
+        url: String,
+        has_more: bool,
+        total_count: Option<u64>,
+    ) -> List<T> {
+        Self { url, has_more, data: data, total_count }
     }
 
-    /// Pin a new future which maps the result inside the page future into
-    /// a SearchListPaginator
-    #[cfg(feature = "async")]
-    fn create_paginator(page: Response<SearchList<T>>, params: P) -> Response<Self> {
-        use futures_util::FutureExt;
-        Box::pin(page.map(|page| page.map(|page| SearchListPaginator { page, params })))
+    fn get_data(&self) -> Vec<Self::O> {
+        self.data.clone()
+    }
+    fn get_url(&self) -> String {
+        self.url.clone()
+    }
+    fn get_total_count(&self) -> Option<u64> {
+        self.total_count.clone()
     }
+    fn has_more(&self) -> bool {
+        self.has_more.clone()
+    }
+}
 
-    #[cfg(feature = "blocking")]
-    fn create_paginator(page: Response<SearchList<T>>, params: P) -> Response<Self> {
-        page.map(|page| SearchListPaginator { page, params })
+impl<T> SearchList<T> {
+    pub fn paginate<P>(self, params: P) -> ListPaginator<SearchList<T>, P> {
+        ListPaginator { page: self, params }
     }
 }
 
@@ -407,33 +326,39 @@ impl<T: Clone> Clone for List<T> {
 }
 
 impl<T> List<T> {
-    pub fn paginate<P>(self, params: P) -> ListPaginator<T, P> {
+    pub fn paginate<P>(self, params: P) -> ListPaginator<List<T>, P> {
         ListPaginator { page: self, params }
     }
 }
 
+#[derive(Debug)]
+pub struct ListPaginator<T, P> {
+    pub page: T,
+    pub params: P,
+}
+
 impl<
-        T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug,
+        T: PaginableList + Send + DeserializeOwned + 'static,
         P: Clone + Serialize + Send + 'static + std::fmt::Debug,
     > ListPaginator<T, P>
 where
-    P: Paginable<O = T>,
+    P: Paginable<O = T::O>,
 {
     /// Repeatedly queries Stripe for more data until all elements in list are fetched, using
     /// Stripe's default page size.
     ///
     /// Requires `feature = "blocking"`.
     #[cfg(feature = "blocking")]
-    pub fn get_all(self, client: &Client) -> Response<Vec<T>> {
-        let mut data = Vec::with_capacity(self.page.total_count.unwrap_or(0) as usize);
+    pub fn get_all(self, client: &Client) -> Response<Vec<T::O>> {
+        let mut data = Vec::with_capacity(self.page.get_total_count().unwrap_or(0) as usize);
         let mut paginator = self;
         loop {
-            if !paginator.page.has_more {
-                data.extend(paginator.page.data.into_iter());
+            if !paginator.page.has_more() {
+                data.extend(paginator.page.get_data().into_iter());
                 break;
             }
             let next_paginator = paginator.next(client)?;
-            data.extend(paginator.page.data.into_iter());
+            data.extend(paginator.page.get_data().into_iter());
             paginator = next_paginator
         }
         Ok(data)
@@ -471,9 +396,9 @@ where
     pub fn stream(
         mut self,
         client: &Client,
-    ) -> impl futures_util::Stream<Item = Result<T, StripeError>> + Unpin {
+    ) -> impl futures_util::Stream<Item = Result<T::O, StripeError>> + Unpin {
         // We are going to be popping items off the end of the list, so we need to reverse it.
-        self.page.data.reverse();
+        self.page.get_data().reverse();
 
         Box::pin(futures_util::stream::unfold(Some((self, client.clone())), Self::unfold_stream))
     }
@@ -482,22 +407,22 @@ where
     #[cfg(all(feature = "async", feature = "stream"))]
     async fn unfold_stream(
         state: Option<(Self, Client)>,
-    ) -> Option<(Result<T, StripeError>, Option<(Self, Client)>)> {
+    ) -> Option<(Result<T::O, StripeError>, Option<(Self, Client)>)> {
         let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration
 
-        if paginator.page.data.len() > 1 {
-            return Some((Ok(paginator.page.data.pop()?), Some((paginator, client))));
+        if paginator.page.get_data().len() > 1 {
+            return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client))));
             // We have more data on this page
         }
 
-        if !paginator.page.has_more {
-            return Some((Ok(paginator.page.data.pop()?), None)); // Final value of the stream, no errors
+        if !paginator.page.has_more() {
+            return Some((Ok(paginator.page.get_data().pop()?), None)); // Final value of the stream, no errors
         }
 
         match paginator.next(&client).await {
             Ok(mut next_paginator) => {
-                let data = paginator.page.data.pop()?;
-                next_paginator.page.data.reverse();
+                let data = paginator.page.get_data().pop()?;
+                next_paginator.page.get_data().reverse();
 
                 // Yield last value of thimuts page, the next page (and client) becomes the state
                 Some((Ok(data), Some((next_paginator, client))))
@@ -508,9 +433,9 @@ where
 
     /// Fetch an additional page of data from stripe.
     pub fn next(&self, client: &Client) -> Response<Self> {
-        if let Some(last) = self.page.data.last() {
-            if self.page.url.starts_with("/v1/") {
-                let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
+        if let Some(last) = self.page.get_data().last() {
+            if self.page.get_url().starts_with("/v1/") {
+                let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
 
                 // clone the params and set the cursor
                 let params_next = {
@@ -527,12 +452,12 @@ where
             }
         } else {
             ok(ListPaginator {
-                page: List {
-                    data: Vec::new(),
-                    has_more: false,
-                    total_count: self.page.total_count,
-                    url: self.page.url.clone(),
-                },
+                page: self.page.new(
+                    Vec::new(),
+                    self.page.get_url(),
+                    self.page.has_more(),
+                    self.page.get_total_count(),
+                ),
                 params: self.params.clone(),
             })
         }
@@ -541,13 +466,13 @@ where
     /// Pin a new future which maps the result inside the page future into
     /// a ListPaginator
     #[cfg(feature = "async")]
-    fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
+    fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
         use futures_util::FutureExt;
         Box::pin(page.map(|page| page.map(|page| ListPaginator { page, params })))
     }
 
     #[cfg(feature = "blocking")]
-    fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
+    fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
         page.map(|page| ListPaginator { page, params })
     }
 }