Fix clv plotting bugs and edits to Quickstart (pymc-labs#601)
* move fixtures to conftest

* docstrings and moved set_model_fit to conftest

* fixed pandas quickstart warnings

* revert to MockModel and add ParetoNBD support

* quickstart edit for issue 609

* notebook edit
ColtAllen authored and louismagowan committed Apr 11, 2024
1 parent 80a79ba commit 5411653
Showing 9 changed files with 445 additions and 413 deletions.
590 changes: 286 additions & 304 deletions docs/source/notebooks/clv/clv_quickstart.ipynb

Large diffs are not rendered by default.

104 changes: 74 additions & 30 deletions pymc_marketing/clv/plotting.py
@@ -1,10 +1,12 @@
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Union

 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 from matplotlib.lines import Line2D

+from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel
+
 __all__ = [
     "plot_customer_exposure",
     "plot_frequency_recency_matrix",
@@ -156,7 +158,7 @@ def _create_frequency_recency_meshes(


 def plot_frequency_recency_matrix(
-    model,
+    model: Union[BetaGeoModel, ParetoNBDModel],
     t=1,
     max_frequency: Optional[int] = None,
     max_recency: Optional[int] = None,
@@ -172,8 +174,8 @@ def plot_frequency_recency_matrix(

     Parameters
     ----------
-    model: lifetimes model
-        A fitted lifetimes model.
+    model: CLV model
+        A fitted CLV model.
     t: float, optional
         Next units of time to make predictions for
     max_frequency: int, optional
@@ -197,27 +199,49 @@ def plot_frequency_recency_matrix(
     axes: matplotlib.AxesSubplot
     """
     if max_frequency is None:
-        max_frequency = int(model.frequency.max())
+        max_frequency = int(model.data["frequency"].max())

     if max_recency is None:
-        max_recency = int(model.recency.max())
+        max_recency = int(model.data["recency"].max())

     mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
         max_frequency=max_frequency,
         max_recency=max_recency,
     )

-    Z = (
-        model.expected_num_purchases(
-            customer_id=np.arange(mesh_recency.size),  # placeholder
-            t=t,
-            frequency=mesh_frequency.ravel(),
-            recency=mesh_recency.ravel(),
-            T=max_recency,
-        )
-        .mean(("draw", "chain"))
-        .values.reshape(mesh_recency.shape)
-    )
+    # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
+    # We should harmonize them!
+    if isinstance(model, ParetoNBDModel):
+        transaction_data = pd.DataFrame(
+            {
+                "customer_id": np.arange(mesh_recency.size),  # placeholder
+                "frequency": mesh_frequency.ravel(),
+                "recency": mesh_recency.ravel(),
+                "T": max_recency,
+            }
+        )
+
+        Z = (
+            model.expected_purchases(
+                data=transaction_data,
+                future_t=t,
+            )
+            .mean(("draw", "chain"))
+            .values.reshape(mesh_recency.shape)
+        )
+    else:
+        Z = (
+            model.expected_num_purchases(
+                customer_id=np.arange(mesh_recency.size),  # placeholder
+                frequency=mesh_frequency.ravel(),
+                recency=mesh_recency.ravel(),
+                T=max_recency,
+                t=t,
+            )
+            .mean(("draw", "chain"))
+            .values.reshape(mesh_recency.shape)
+        )

     if ax is None:
         ax = plt.subplot(111)
@@ -245,7 +269,7 @@ def plot_frequency_recency_matrix(


 def plot_probability_alive_matrix(
-    model,
+    model: Union[BetaGeoModel, ParetoNBDModel],
     max_frequency: Optional[int] = None,
     max_recency: Optional[int] = None,
     title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
@@ -261,8 +285,8 @@ def plot_probability_alive_matrix(

     Parameters
     ----------
-    model: lifetimes model
-        A fitted lifetimes model.
+    model: CLV model
+        A fitted CLV model.
     max_frequency: int, optional
         The maximum frequency to plot. Default is max observed frequency.
     max_recency: int, optional
@@ -285,26 +309,46 @@ def plot_probability_alive_matrix(
     """

     if max_frequency is None:
-        max_frequency = int(model.frequency.max())
+        max_frequency = int(model.data["frequency"].max())

     if max_recency is None:
-        max_recency = int(model.recency.max())
+        max_recency = int(model.data["recency"].max())

     mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
         max_frequency=max_frequency,
         max_recency=max_recency,
     )

-    Z = (
-        model.expected_probability_alive(
-            customer_id=np.arange(mesh_recency.size),  # placeholder
-            frequency=mesh_frequency.ravel(),
-            recency=mesh_recency.ravel(),
-            T=max_recency,
-        )
-        .mean(("draw", "chain"))
-        .values.reshape(mesh_recency.shape)
-    )
+    # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
+    # We should harmonize them!
+    if isinstance(model, ParetoNBDModel):
+        transaction_data = pd.DataFrame(
+            {
+                "customer_id": np.arange(mesh_recency.size),  # placeholder
+                "frequency": mesh_frequency.ravel(),
+                "recency": mesh_recency.ravel(),
+                "T": max_recency,
+            }
+        )
+
+        Z = (
+            model.expected_probability_alive(
+                data=transaction_data,
+                future_t=0,  # TODO: This can be a function parameter in the case of ParetoNBDModel
+            )
+            .mean(("draw", "chain"))
+            .values.reshape(mesh_recency.shape)
+        )
+    else:
+        Z = (
+            model.expected_probability_alive(
+                customer_id=np.arange(mesh_recency.size),  # placeholder
+                frequency=mesh_frequency.ravel(),
+                recency=mesh_recency.ravel(),
+                T=max_recency,  # type: ignore
+            )
+            .mean(("draw", "chain"))
+            .values.reshape(mesh_recency.shape)
+        )

     interpolation = kwargs.pop("interpolation", "none")

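Aside on the FIXME above: ParetoNBDModel's prediction methods take a single customer-level DataFrame plus a keyword-only future_t, while BetaGeoModel's older expected_num_purchases still takes separate arrays per argument. A minimal sketch of that difference, using only the calls as they appear in this diff (the standalone helper name is illustrative and not part of the PR):

import numpy as np
import pandas as pd

from pymc_marketing.clv import ParetoNBDModel


def _posterior_mean_expected_purchases(model, frequency, recency, T, t=1):
    """Illustrative helper mirroring the branch added in plot_frequency_recency_matrix."""
    if isinstance(model, ParetoNBDModel):
        # Newer API: one DataFrame of customer-level inputs, keyword-only future_t.
        data = pd.DataFrame(
            {
                "customer_id": np.arange(len(frequency)),  # placeholder ids
                "frequency": frequency,
                "recency": recency,
                "T": T,
            }
        )
        preds = model.expected_purchases(data=data, future_t=t)
    else:
        # Older BetaGeoModel API: separate arrays per keyword argument.
        preds = model.expected_num_purchases(
            customer_id=np.arange(len(frequency)),  # placeholder ids
            frequency=frequency,
            recency=recency,
            T=T,
            t=t,
        )
    # Both return an xarray DataArray with chain/draw dims; average them out.
    return preds.mean(("draw", "chain")).values

Given a fitted model and the usual RFM summary arrays, either branch yields the posterior-mean grid that the plotting helpers reshape into a matrix.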
2 changes: 1 addition & 1 deletion tests/clv/models/test_basic.py
@@ -7,7 +7,7 @@
 from arviz import InferenceData, from_dict

 from pymc_marketing.clv.models.basic import CLVModel
-from tests.clv.utils import set_model_fit
+from tests.conftest import set_model_fit


 class CLVModelTest(CLVModel):
2 changes: 1 addition & 1 deletion tests/clv/models/test_gamma_gamma.py
@@ -10,7 +10,7 @@
     GammaGammaModel,
     GammaGammaModelIndividual,
 )
-from tests.clv.utils import set_model_fit
+from tests.conftest import set_model_fit


 class BaseTestGammaGammaModel:
2 changes: 1 addition & 1 deletion tests/clv/models/test_pareto_nbd.py
@@ -9,7 +9,7 @@

 from pymc_marketing.clv import ParetoNBDModel
 from pymc_marketing.clv.distributions import ParetoNBD
-from tests.clv.utils import set_model_fit
+from tests.conftest import set_model_fit


 class TestParetoNBDModel:
100 changes: 52 additions & 48 deletions tests/clv/test_plotting.py
@@ -14,9 +14,57 @@
 )


-@pytest.fixture(scope="module")
-def test_summary_data() -> pd.DataFrame:
-    return pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
+class MockModel:
+    def __init__(self, data: pd.DataFrame):
+        self.data = data
+
+    def _mock_posterior(
+        self, customer_id: Union[np.ndarray, pd.Series]
+    ) -> xr.DataArray:
+        n_customers = len(customer_id)
+        n_chains = 4
+        n_draws = 10
+        chains = np.arange(n_chains)
+        draws = np.arange(n_draws)
+        return xr.DataArray(
+            data=np.ones((n_customers, n_chains, n_draws)),
+            coords={"customer_id": customer_id, "chain": chains, "draw": draws},
+            dims=["customer_id", "chain", "draw"],
+        )
+
+    def expected_probability_alive(
+        self,
+        customer_id: Union[np.ndarray, pd.Series],
+        frequency: Union[np.ndarray, pd.Series],
+        recency: Union[np.ndarray, pd.Series],
+        T: Union[np.ndarray, pd.Series],
+    ):
+        return self._mock_posterior(customer_id)
+
+    def expected_purchases(
+        self,
+        customer_id: Union[np.ndarray, pd.Series],
+        data: pd.DataFrame,
+        *,
+        future_t: Union[np.ndarray, pd.Series, TensorVariable],
+    ):
+        return self._mock_posterior(customer_id)
+
+    # TODO: This is required until CLV API is standardized.
+    def expected_num_purchases(
+        self,
+        customer_id: Union[np.ndarray, pd.Series],
+        t: Union[np.ndarray, pd.Series, TensorVariable],
+        frequency: Union[np.ndarray, pd.Series, TensorVariable],
+        recency: Union[np.ndarray, pd.Series, TensorVariable],
+        T: Union[np.ndarray, pd.Series, TensorVariable],
+    ):
+        return self._mock_posterior(customer_id)
+
+
+@pytest.fixture
+def mock_model(test_summary_data) -> MockModel:
+    return MockModel(test_summary_data)


 @pytest.mark.parametrize(
@@ -33,7 +81,7 @@ def test_plot_customer_exposure(test_summary_data, kwargs) -> None:
     assert isinstance(ax, plt.Axes)


-def test_plot_cumstomer_exposure_with_ax(test_summary_data) -> None:
+def test_plot_customer_exposure_with_ax(test_summary_data) -> None:
     ax = plt.subplot()
     plot_customer_exposure(test_summary_data, ax=ax)

@@ -59,50 +107,6 @@ def test_plot_customer_exposure_invalid_args(test_summary_data, kwargs) -> None:
         plot_customer_exposure(test_summary_data, **kwargs)


-class MockModel:
-    def __init__(self, frequency, recency):
-        self.frequency = frequency
-        self.recency = recency
-
-    def _mock_posterior(
-        self, customer_id: Union[np.ndarray, pd.Series]
-    ) -> xr.DataArray:
-        n_customers = len(customer_id)
-        n_chains = 4
-        n_draws = 10
-        chains = np.arange(n_chains)
-        draws = np.arange(n_draws)
-        return xr.DataArray(
-            data=np.ones((n_customers, n_chains, n_draws)),
-            coords={"customer_id": customer_id, "chain": chains, "draw": draws},
-            dims=["customer_id", "chain", "draw"],
-        )
-
-    def expected_probability_alive(
-        self,
-        customer_id: Union[np.ndarray, pd.Series],
-        frequency: Union[np.ndarray, pd.Series],
-        recency: Union[np.ndarray, pd.Series],
-        T: Union[np.ndarray, pd.Series],
-    ):
-        return self._mock_posterior(customer_id)
-
-    def expected_num_purchases(
-        self,
-        customer_id: Union[np.ndarray, pd.Series],
-        t: Union[np.ndarray, pd.Series, TensorVariable],
-        frequency: Union[np.ndarray, pd.Series, TensorVariable],
-        recency: Union[np.ndarray, pd.Series, TensorVariable],
-        T: Union[np.ndarray, pd.Series, TensorVariable],
-    ):
-        return self._mock_posterior(customer_id)
-
-
-@pytest.fixture
-def mock_model(test_summary_data) -> MockModel:
-    return MockModel(test_summary_data["frequency"], test_summary_data["recency"])
-
-
 def test_plot_frequency_recency_matrix(mock_model) -> None:
     ax: plt.Axes = plot_frequency_recency_matrix(mock_model)

12 changes: 2 additions & 10 deletions tests/clv/test_utils.py
@@ -16,7 +16,7 @@
     rfm_train_test_split,
     to_xarray,
 )
-from tests.clv.utils import set_model_fit
+from tests.conftest import set_model_fit


 def test_to_xarray():
@@ -42,15 +42,6 @@ def test_to_xarray():
     np.testing.assert_array_equal(new_y.coords["test_dim"], customer_id)


-@pytest.fixture(scope="module")
-def test_summary_data() -> pd.DataFrame:
-    rng = np.random.default_rng(14)
-    df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
-    df["monetary_value"] = rng.lognormal(size=(len(df)))
-    df["customer_id"] = df.index
-    return df
-
-
 @pytest.fixture(scope="module")
 def fitted_bg(test_summary_data) -> BetaGeoModel:
     rng = np.random.default_rng(13)
@@ -100,6 +91,7 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
     pnbd_model.build_model()

     # Mock an idata object for tests requiring a fitted model
+    # TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
     fake_fit = pm.sample_prior_predictive(
         samples=50, model=pnbd_model.model, random_seed=rng
     ).prior
18 changes: 0 additions & 18 deletions tests/clv/utils.py

This file was deleted.

(The remaining changed file in this commit is not rendered in this view.)
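Judging from the fixtures deleted from test_plotting.py and test_utils.py and the new "from tests.conftest import set_model_fit" imports, the unrendered file is presumably tests/conftest.py, containing something along these lines (a sketch only, not the actual file contents; the body of set_model_fit is assumed unchanged from the deleted tests/clv/utils.py and is omitted here):

import numpy as np
import pandas as pd
import pytest


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
    # Assumed consolidation of the fixture removed from test_plotting.py and
    # test_utils.py above; the extra columns are presumed to carry over.
    rng = np.random.default_rng(14)
    df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
    df["monetary_value"] = rng.lognormal(size=(len(df)))
    df["customer_id"] = df.index
    return df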
