Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix clv plotting bugs and edits to Quickstart #601

Merged
merged 9 commits into from
Apr 8, 2024
Merged
598 changes: 294 additions & 304 deletions docs/source/notebooks/clv/clv_quickstart.ipynb

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel

__all__ = [
"plot_customer_exposure",
"plot_frequency_recency_matrix",
Expand Down Expand Up @@ -156,7 +158,7 @@


def plot_frequency_recency_matrix(
model,
model: Union[BetaGeoModel, ParetoNBDModel],
ColtAllen marked this conversation as resolved.
Show resolved Hide resolved
t=1,
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
Expand All @@ -172,8 +174,8 @@

Parameters
----------
model: lifetimes model
A fitted lifetimes model.
model: CLV model
A fitted CLV model.
t: float, optional
Next units of time to make predictions for
max_frequency: int, optional
Expand All @@ -197,10 +199,10 @@
axes: matplotlib.AxesSubplot
"""
if max_frequency is None:
max_frequency = int(model.frequency.max())
max_frequency = int(model.data["frequency"].max())

Check warning on line 202 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L202

Added line #L202 was not covered by tests

if max_recency is None:
max_recency = int(model.recency.max())
max_recency = int(model.data["recency"].max())

Check warning on line 205 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L205

Added line #L205 was not covered by tests

mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
max_frequency=max_frequency,
Expand Down Expand Up @@ -245,7 +247,7 @@


def plot_probability_alive_matrix(
model,
model: Union[BetaGeoModel, ParetoNBDModel],
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
Expand All @@ -261,8 +263,8 @@

Parameters
----------
model: lifetimes model
A fitted lifetimes model.
model: CLV model
A fitted CLV model.
max_frequency: int, optional
The maximum frequency to plot. Default is max observed frequency.
max_recency: int, optional
Expand All @@ -285,10 +287,10 @@
"""

if max_frequency is None:
max_frequency = int(model.frequency.max())
max_frequency = int(model.data["frequency"].max())

Check warning on line 290 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L290

Added line #L290 was not covered by tests

if max_recency is None:
max_recency = int(model.recency.max())
max_recency = int(model.data["recency"].max())

Check warning on line 293 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L293

Added line #L293 was not covered by tests

mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
max_frequency=max_frequency,
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from arviz import InferenceData, from_dict

from pymc_marketing.clv.models.basic import CLVModel
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class CLVModelTest(CLVModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
GammaGammaModel,
GammaGammaModelIndividual,
)
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class BaseTestGammaGammaModel:
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pymc_marketing.clv import ParetoNBDModel
from pymc_marketing.clv.distributions import ParetoNBD
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


class TestParetoNBDModel:
Expand Down
109 changes: 59 additions & 50 deletions tests/clv/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,74 @@
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytest
import xarray as xr
from pytensor.tensor import TensorVariable

from pymc_marketing.clv.models import BetaGeoModel, ParetoNBDModel
from pymc_marketing.clv.plotting import (
plot_customer_exposure,
plot_frequency_recency_matrix,
plot_probability_alive_matrix,
)
from tests.conftest import set_model_fit


@pytest.fixture(scope="module")
def fitted_bg(test_summary_data) -> BetaGeoModel:
rng = np.random.default_rng(13)
data = pd.DataFrame(
{
"customer_id": test_summary_data.index,
"frequency": test_summary_data["frequency"],
"recency": test_summary_data["recency"],
"T": test_summary_data["T"],
}
)
model_config = {
# Narrow Gaussian centered at MLE params from lifetimes BetaGeoFitter
"a_prior": {"dist": "DiracDelta", "kwargs": {"c": 1.85034151}},
"alpha_prior": {"dist": "DiracDelta", "kwargs": {"c": 1.86428187}},
"b_prior": {"dist": "DiracDelta", "kwargs": {"c": 3.18105431}},
"r_prior": {"dist": "DiracDelta", "kwargs": {"c": 0.16385072}},
}
model = BetaGeoModel(
data=data,
model_config=model_config,
)
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
).prior
set_model_fit(model, fake_fit)

return model


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
return pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
rng = np.random.default_rng(45)

model_config = {
# Narrow Gaussian centered at MLE params from lifetimes ParetoNBDFitter
"r_prior": {"dist": "DiracDelta", "kwargs": {"c": 0.5534}},
"alpha_prior": {"dist": "DiracDelta", "kwargs": {"c": 10.5802}},
"s_prior": {"dist": "DiracDelta", "kwargs": {"c": 0.6061}},
"beta_prior": {"dist": "DiracDelta", "kwargs": {"c": 11.6562}},
}
pnbd_model = ParetoNBDModel(
data=test_summary_data,
model_config=model_config,
)
pnbd_model.build_model()

# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
).prior
set_model_fit(pnbd_model, fake_fit)

return pnbd_model


@pytest.mark.parametrize(
Expand Down Expand Up @@ -59,50 +111,7 @@ def test_plot_customer_exposure_invalid_args(test_summary_data, kwargs) -> None:
plot_customer_exposure(test_summary_data, **kwargs)


class MockModel:
def __init__(self, frequency, recency):
self.frequency = frequency
self.recency = recency

def _mock_posterior(
self, customer_id: Union[np.ndarray, pd.Series]
) -> xr.DataArray:
n_customers = len(customer_id)
n_chains = 4
n_draws = 10
chains = np.arange(n_chains)
draws = np.arange(n_draws)
return xr.DataArray(
data=np.ones((n_customers, n_chains, n_draws)),
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
dims=["customer_id", "chain", "draw"],
)

def expected_probability_alive(
self,
customer_id: Union[np.ndarray, pd.Series],
frequency: Union[np.ndarray, pd.Series],
recency: Union[np.ndarray, pd.Series],
T: Union[np.ndarray, pd.Series],
):
return self._mock_posterior(customer_id)

def expected_num_purchases(
self,
customer_id: Union[np.ndarray, pd.Series],
t: Union[np.ndarray, pd.Series, TensorVariable],
frequency: Union[np.ndarray, pd.Series, TensorVariable],
recency: Union[np.ndarray, pd.Series, TensorVariable],
T: Union[np.ndarray, pd.Series, TensorVariable],
):
return self._mock_posterior(customer_id)


@pytest.fixture
def mock_model(test_summary_data) -> MockModel:
return MockModel(test_summary_data["frequency"], test_summary_data["recency"])

ColtAllen marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("mock_model", (fitted_bg, fitted_pnbd))
def test_plot_frequency_recency_matrix(mock_model) -> None:
ax: plt.Axes = plot_frequency_recency_matrix(mock_model)

Expand Down
12 changes: 2 additions & 10 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
rfm_train_test_split,
to_xarray,
)
from tests.clv.utils import set_model_fit
from tests.conftest import set_model_fit


def test_to_xarray():
Expand All @@ -42,15 +42,6 @@ def test_to_xarray():
np.testing.assert_array_equal(new_y.coords["test_dim"], customer_id)


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
rng = np.random.default_rng(14)
df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
df["monetary_value"] = rng.lognormal(size=(len(df)))
df["customer_id"] = df.index
return df


@pytest.fixture(scope="module")
def fitted_bg(test_summary_data) -> BetaGeoModel:
rng = np.random.default_rng(13)
Expand Down Expand Up @@ -100,6 +91,7 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
pnbd_model.build_model()

# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
).prior
Expand Down
18 changes: 0 additions & 18 deletions tests/clv/utils.py

This file was deleted.

28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from typing import Union

import numpy as np
import pandas as pd
import pytest
from arviz import InferenceData
from xarray import Dataset

from pymc_marketing.clv.models import CLVModel


def pytest_addoption(parser):
Expand Down Expand Up @@ -43,3 +50,24 @@ def cdnow_trans() -> pd.DataFrame:
Data source: https://www.brucehardie.com/datasets/
"""
return pd.read_csv("datasets/cdnow_transactions.csv")


@pytest.fixture(scope="module")
def test_summary_data() -> pd.DataFrame:
rng = np.random.default_rng(14)
df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
df["monetary_value"] = rng.lognormal(size=(len(df)))
df["customer_id"] = df.index
return df


def set_model_fit(model: CLVModel, fit: Union[InferenceData, Dataset]):
if isinstance(fit, InferenceData):
assert "posterior" in fit.groups()
else:
fit = InferenceData(posterior=fit)
if model.model is None:
model.build_model()
model.idata = fit
model.idata.add_groups(fit_data=model.data.to_xarray())
model.set_idata_attrs(fit)
Loading