From dd1f69d35eea30f4b64db7796ab0ea92823a7487 Mon Sep 17 00:00:00 2001 From: IshaanJolly Date: Sat, 21 Sep 2024 00:46:03 +0100 Subject: [PATCH 1/4] feat: test.txt added for commit check --- test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test.txt diff --git a/test.txt b/test.txt new file mode 100644 index 000000000..e69de29bb From 5c681b1853d8b0f0e12137d541282ccd87e8bb5c Mon Sep 17 00:00:00 2001 From: IshaanJolly Date: Sun, 22 Sep 2024 20:07:23 +0100 Subject: [PATCH 2/4] feat: replaced plot_curve with plot_samples within ./mmm/plot.py --- pymc_marketing/mmm/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 6efdba7d5..fa2a34976 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -570,7 +570,7 @@ def plot_curve( subplot_kwargs : dict, optional Addtional kwargs to while creating the fig and axes sample_kwargs : dict, optional - Kwargs for the :func:`plot_curve` function + Kwargs for the :func:`plot_sample` function hdi_kwargs : dict, optional Kwargs for the :func:`plot_hdi` function same_axes : bool From d64b1c25101dd5520c7ceb67a8d130b7620ddfe7 Mon Sep 17 00:00:00 2001 From: IshaanJolly Date: Sun, 22 Sep 2024 20:32:23 +0100 Subject: [PATCH 3/4] feat: n_samples added to distributions_new_customers --- pymc_marketing/clv/models/beta_geo_beta_binom.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/clv/models/beta_geo_beta_binom.py b/pymc_marketing/clv/models/beta_geo_beta_binom.py index 45f8b68f0..f9eec3f18 100644 --- a/pymc_marketing/clv/models/beta_geo_beta_binom.py +++ b/pymc_marketing/clv/models/beta_geo_beta_binom.py @@ -524,6 +524,7 @@ def _distribution_new_customers( "purchase_rate", "recency_frequency", ), + n_samples: int = 1000, ) -> xarray.Dataset: """Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers. @@ -542,6 +543,8 @@ def _distribution_new_customers( Random state to use for sampling. var_names : sequence of str, optional Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"]. + n_samples : int, optional + Number of posterior predictive samples to generate. Defaults to 1000 """ if data is None: @@ -557,7 +560,7 @@ def _distribution_new_customers( if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1: # For map fit add a dummy draw dimension - dataset = dataset.squeeze("draw").expand_dims(draw=range(1000)) + dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples)) coords = self.model.coords.copy() # type: ignore coords["customer_id"] = data["customer_id"] @@ -668,6 +671,7 @@ def distribution_new_customer_recency_frequency( *, T: int | np.ndarray | pd.Series | None = None, random_seed: RandomState | None = None, + n_samples: int = 1, ) -> xarray.Dataset: """BG/BB process representing purchases across the customer population. @@ -687,6 +691,8 @@ def distribution_new_customer_recency_frequency( Not required if `data` Dataframe contains a `T` column. random_seed : ~numpy.random.RandomState, optional Random state to use for sampling. + n_samples : int, optional + Number of samples to generate. Defaults to 1. Returns ------- @@ -698,4 +704,5 @@ def distribution_new_customer_recency_frequency( T=T, random_seed=random_seed, var_names=["recency_frequency"], + n_samples=n_samples, )["recency_frequency"] From 625a53040fbb99745a47bc20b38b6c57d88300eb Mon Sep 17 00:00:00 2001 From: IshaanJolly Date: Mon, 23 Sep 2024 10:26:41 +0100 Subject: [PATCH 4/4] revert the plot.py changes --- pymc_marketing/mmm/plot.py | 2 +- test.txt | 0 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 test.txt diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index fa2a34976..6efdba7d5 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -570,7 +570,7 @@ def plot_curve( subplot_kwargs : dict, optional Addtional kwargs to while creating the fig and axes sample_kwargs : dict, optional - Kwargs for the :func:`plot_sample` function + Kwargs for the :func:`plot_curve` function hdi_kwargs : dict, optional Kwargs for the :func:`plot_hdi` function same_axes : bool diff --git a/test.txt b/test.txt deleted file mode 100644 index e69de29bb..000000000