Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1035 Distribution new customer enhancements #1061

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pymc_marketing/clv/models/beta_geo_beta_binom.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def _distribution_new_customers(
"purchase_rate",
"recency_frequency",
),
n_samples: int = 1000,
) -> xarray.Dataset:
"""Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers.

Expand All @@ -542,6 +543,8 @@ def _distribution_new_customers(
Random state to use for sampling.
var_names : sequence of str, optional
Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"].
n_samples : int, optional
Number of posterior predictive samples to generate. Defaults to 1000

"""
if data is None:
Expand All @@ -557,7 +560,7 @@ def _distribution_new_customers(

if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1:
# For map fit add a dummy draw dimension
dataset = dataset.squeeze("draw").expand_dims(draw=range(1000))
dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples))

coords = self.model.coords.copy() # type: ignore
coords["customer_id"] = data["customer_id"]
Expand Down Expand Up @@ -668,6 +671,7 @@ def distribution_new_customer_recency_frequency(
*,
T: int | np.ndarray | pd.Series | None = None,
random_seed: RandomState | None = None,
n_samples: int = 1,
) -> xarray.Dataset:
"""BG/BB process representing purchases across the customer population.

Expand All @@ -687,6 +691,8 @@ def distribution_new_customer_recency_frequency(
Not required if `data` Dataframe contains a `T` column.
random_seed : ~numpy.random.RandomState, optional
Random state to use for sampling.
n_samples : int, optional
Number of samples to generate. Defaults to 1.

Returns
-------
Expand All @@ -698,4 +704,5 @@ def distribution_new_customer_recency_frequency(
T=T,
random_seed=random_seed,
var_names=["recency_frequency"],
n_samples=n_samples,
)["recency_frequency"]
Loading