Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix clv plotting bugs and edits to Quickstart #601

Merged
merged 9 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
590 changes: 286 additions & 304 deletions docs/source/notebooks/clv/clv_quickstart.ipynb

Large diffs are not rendered by default.

104 changes: 74 additions & 30 deletions pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel

__all__ = [
"plot_customer_exposure",
"plot_frequency_recency_matrix",
Expand Down Expand Up @@ -156,7 +158,7 @@


def plot_frequency_recency_matrix(
model,
model: Union[BetaGeoModel, ParetoNBDModel],
ColtAllen marked this conversation as resolved.
Show resolved Hide resolved
t=1,
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
Expand All @@ -172,8 +174,8 @@

Parameters
----------
model: lifetimes model
A fitted lifetimes model.
model: CLV model
A fitted CLV model.
t: float, optional
Next units of time to make predictions for
max_frequency: int, optional
Expand All @@ -197,27 +199,49 @@
axes: matplotlib.AxesSubplot
"""
if max_frequency is None:
max_frequency = int(model.frequency.max())
max_frequency = int(model.data["frequency"].max())

if max_recency is None:
max_recency = int(model.recency.max())
max_recency = int(model.data["recency"].max())

mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
max_frequency=max_frequency,
max_recency=max_recency,
)

Z = (
model.expected_num_purchases(
customer_id=np.arange(mesh_recency.size), # placeholder
t=t,
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency,
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
# We should harmonize them!
Comment on lines +212 to +213
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have an issue for this? Otherwise, can we create one?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an issue open; it'll take several PRs to fix:

#527

if isinstance(model, ParetoNBDModel):
transaction_data = pd.DataFrame(

Check warning on line 215 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L215

Added line #L215 was not covered by tests
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

Z = (

Check warning on line 224 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L224

Added line #L224 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does Z mean? Can we give it a more descriptive name 🙏?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for the Z dimension of the heatmap plot. I'll make a note to rename this when the API is fixed.

model.expected_purchases(
data=transaction_data,
future_t=t,
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
else:
Z = (
model.expected_num_purchases(
customer_id=np.arange(mesh_recency.size), # placeholder
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency,
t=t,
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

if ax is None:
ax = plt.subplot(111)

Expand Down Expand Up @@ -245,7 +269,7 @@


def plot_probability_alive_matrix(
model,
model: Union[BetaGeoModel, ParetoNBDModel],
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
Expand All @@ -261,8 +285,8 @@

Parameters
----------
model: lifetimes model
A fitted lifetimes model.
model: CLV model
A fitted CLV model.
max_frequency: int, optional
The maximum frequency to plot. Default is max observed frequency.
max_recency: int, optional
Expand All @@ -285,26 +309,46 @@
"""

if max_frequency is None:
max_frequency = int(model.frequency.max())
max_frequency = int(model.data["frequency"].max())

if max_recency is None:
max_recency = int(model.recency.max())
max_recency = int(model.data["recency"].max())

mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
max_frequency=max_frequency,
max_recency=max_recency,
)
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
# We should harmonize them!
if isinstance(model, ParetoNBDModel):
transaction_data = pd.DataFrame(

Check warning on line 324 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L324

Added line #L324 was not covered by tests
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)

Z = (
model.expected_probability_alive(
customer_id=np.arange(mesh_recency.size), # placeholder
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency,
Z = (

Check warning on line 333 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L333

Added line #L333 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above on the Z variable meaning (name)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above.

model.expected_probability_alive(
data=transaction_data,
future_t=0, # TODO: This can be a function parameter in the case of ParetoNBDModel
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
else:
Z = (
model.expected_probability_alive(
customer_id=np.arange(mesh_recency.size), # placeholder
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency, # type: ignore
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

interpolation = kwargs.pop("interpolation", "none")

Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from arviz import InferenceData, from_dict

from pymc_marketing.clv.models.basic import CLVModel
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class CLVModelTest(CLVModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
GammaGammaModel,
GammaGammaModelIndividual,
)
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class BaseTestGammaGammaModel:
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pymc_marketing.clv import ParetoNBDModel
from pymc_marketing.clv.distributions import ParetoNBD
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class TestParetoNBDModel:
Expand Down
100 changes: 52 additions & 48 deletions tests/clv/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,57 @@
)


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
return pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
class MockModel:
def __init__(self, data: pd.DataFrame):
self.data = data

def _mock_posterior(
self, customer_id: Union[np.ndarray, pd.Series]
) -> xr.DataArray:
n_customers = len(customer_id)
n_chains = 4
n_draws = 10
chains = np.arange(n_chains)
draws = np.arange(n_draws)
return xr.DataArray(
data=np.ones((n_customers, n_chains, n_draws)),
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
dims=["customer_id", "chain", "draw"],
)

def expected_probability_alive(
self,
customer_id: Union[np.ndarray, pd.Series],
frequency: Union[np.ndarray, pd.Series],
recency: Union[np.ndarray, pd.Series],
T: Union[np.ndarray, pd.Series],
):
return self._mock_posterior(customer_id)

def expected_purchases(
self,
customer_id: Union[np.ndarray, pd.Series],
data: pd.DataFrame,
*,
future_t: Union[np.ndarray, pd.Series, TensorVariable],
):
return self._mock_posterior(customer_id)

# TODO: This is required until CLV API is standardized.
def expected_num_purchases(
self,
customer_id: Union[np.ndarray, pd.Series],
t: Union[np.ndarray, pd.Series, TensorVariable],
frequency: Union[np.ndarray, pd.Series, TensorVariable],
recency: Union[np.ndarray, pd.Series, TensorVariable],
T: Union[np.ndarray, pd.Series, TensorVariable],
):
return self._mock_posterior(customer_id)


@pytest.fixture
def mock_model(test_summary_data) -> MockModel:
return MockModel(test_summary_data)


@pytest.mark.parametrize(
Expand All @@ -33,7 +81,7 @@ def test_plot_customer_exposure(test_summary_data, kwargs) -> None:
assert isinstance(ax, plt.Axes)


def test_plot_cumstomer_exposure_with_ax(test_summary_data) -> None:
def test_plot_customer_exposure_with_ax(test_summary_data) -> None:
ax = plt.subplot()
plot_customer_exposure(test_summary_data, ax=ax)

Expand All @@ -59,50 +107,6 @@ def test_plot_customer_exposure_invalid_args(test_summary_data, kwargs) -> None:
plot_customer_exposure(test_summary_data, **kwargs)


class MockModel:
def __init__(self, frequency, recency):
self.frequency = frequency
self.recency = recency

def _mock_posterior(
self, customer_id: Union[np.ndarray, pd.Series]
) -> xr.DataArray:
n_customers = len(customer_id)
n_chains = 4
n_draws = 10
chains = np.arange(n_chains)
draws = np.arange(n_draws)
return xr.DataArray(
data=np.ones((n_customers, n_chains, n_draws)),
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
dims=["customer_id", "chain", "draw"],
)

def expected_probability_alive(
self,
customer_id: Union[np.ndarray, pd.Series],
frequency: Union[np.ndarray, pd.Series],
recency: Union[np.ndarray, pd.Series],
T: Union[np.ndarray, pd.Series],
):
return self._mock_posterior(customer_id)

def expected_num_purchases(
self,
customer_id: Union[np.ndarray, pd.Series],
t: Union[np.ndarray, pd.Series, TensorVariable],
frequency: Union[np.ndarray, pd.Series, TensorVariable],
recency: Union[np.ndarray, pd.Series, TensorVariable],
T: Union[np.ndarray, pd.Series, TensorVariable],
):
return self._mock_posterior(customer_id)


@pytest.fixture
def mock_model(test_summary_data) -> MockModel:
return MockModel(test_summary_data["frequency"], test_summary_data["recency"])
ColtAllen marked this conversation as resolved.
Show resolved Hide resolved


def test_plot_frequency_recency_matrix(mock_model) -> None:
ax: plt.Axes = plot_frequency_recency_matrix(mock_model)

Expand Down
12 changes: 2 additions & 10 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
rfm_train_test_split,
to_xarray,
)
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


def test_to_xarray():
Expand All @@ -42,15 +42,6 @@ def test_to_xarray():
np.testing.assert_array_equal(new_y.coords["test_dim"], customer_id)


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
rng = np.random.default_rng(14)
df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
df["monetary_value"] = rng.lognormal(size=(len(df)))
df["customer_id"] = df.index
return df


@pytest.fixture(scope="module")
def fitted_bg(test_summary_data) -> BetaGeoModel:
rng = np.random.default_rng(13)
Expand Down Expand Up @@ -100,6 +91,7 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
pnbd_model.build_model()

# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
).prior
Expand Down
18 changes: 0 additions & 18 deletions tests/clv/utils.py

This file was deleted.

Loading
Loading