Skip to content

Commit

Permalink
Save & load support for time varying parameters (#815)
Browse files Browse the repository at this point in the history
* add missing init for save and load

* get rid of warnings from JSON parsing

* new error message without line break
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 2b9424c commit 6fdd3d2
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 36 deletions.
8 changes: 4 additions & 4 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ def build_model(self) -> None: # type: ignore[override]
"purchase_covariate": self.purchase_covariate_cols,
"dropout_covariate": self.dropout_covariate_cols,
"obs_var": ["recency", "frequency"],
"customer_id": self.data["customer_id"],
}
mutable_coords = {"customer_id": self.data["customer_id"]}
with pm.Model(coords=coords, coords_mutable=mutable_coords) as self.model:
with pm.Model(coords=coords) as self.model:
if self.purchase_covariate_cols:
purchase_data = pm.MutableData(
purchase_data = pm.Data(
"purchase_data",
self.data[self.purchase_covariate_cols],
dims=["customer_id", "purchase_covariate"],
Expand Down Expand Up @@ -273,7 +273,7 @@ def build_model(self) -> None: # type: ignore[override]

# churn priors
if self.dropout_covariate_cols:
dropout_data = pm.MutableData(
dropout_data = pm.Data(
"dropout_data",
self.data[self.dropout_covariate_cols],
dims=["customer_id", "dropout_covariate"],
Expand Down
58 changes: 34 additions & 24 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,9 @@ def _generate_and_preprocess_model_data( # type: ignore
date_data = X[self.date_column]
channel_data = X[self.channel_columns]

self.coords_mutable: dict[str, Any] = {
"date": date_data,
}
coords: dict[str, Any] = {
"channel": self.channel_columns,
"date": date_data,
}

new_X_dict = {
Expand Down Expand Up @@ -250,6 +248,8 @@ def _save_input_params(self, idata) -> None:
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float64]
Expand Down Expand Up @@ -347,20 +347,17 @@ def build_model(
self._generate_and_preprocess_model_data(X, y)
with pm.Model(
coords=self.model_coords,
coords_mutable=self.coords_mutable,
) as self.model:
channel_data_ = pm.Data(
name="channel_data",
value=self.preprocessed_data["X"][self.channel_columns],
dims=("date", "channel"),
mutable=True,
)

target_ = pm.Data(
name="target",
value=self.preprocessed_data["y"],
dims="date",
mutable=True,
)
if self.time_varying_intercept | self.time_varying_media:
time_index = pm.Data(
Expand Down Expand Up @@ -441,7 +438,6 @@ def build_model(
name="control_data",
value=self.preprocessed_data["X"][self.control_columns],
dims=("date", "control"),
mutable=True,
)

control_contributions = pm.Deterministic(
Expand All @@ -459,7 +455,6 @@ def build_model(
self.date_column
].dt.dayofyear.to_numpy(),
dims="date",
mutable=True,
)

def create_deterministic(x: pt.TensorVariable) -> None:
Expand Down Expand Up @@ -544,7 +539,6 @@ def channel_contributions_forward_pass(
"""
coords = {
**self.model_coords,
**self.coords_mutable,
}
with pm.Model(coords=coords):
pm.Deterministic(
Expand Down Expand Up @@ -602,28 +596,44 @@ def load(cls, fname: str):
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
validate_data=json.loads(idata.attrs["validate_data"]),
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
# Media Transformations
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
# Seasonality
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
# TVP
time_varying_intercept=json.loads(
idata.attrs.get("time_varying_intercept", False)
),
time_varying_media=json.loads(
idata.attrs.get("time_varying_media", False)
),
# Configurations
validate_data=json.loads(idata.attrs["validate_data"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
y = dataset[model.output_var].values
model.build_model(X, y)
# All previously used data is in idata.
if model.id != idata.attrs["id"]:
error_msg = f"""The file '{fname}' does not contain an inference data of the same model
or configuration as '{cls._model_type}'"""
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model or "
f"configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"numpy>=1.17",
"pandas",
# NOTE: Used as minimum pymc version with ci.yml `OLDEST_PYMC_VERSION`
"pymc>=5.12.0,<5.16.0",
"pymc>=5.13.0,<5.16.0",
"scikit-learn>=1.1.1",
"seaborn>=0.12.2",
"xarray",
Expand Down
4 changes: 2 additions & 2 deletions tests/clv/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_pareto_nbd_sample_prior(
s = pm.Gamma(name="s", alpha=5, beta=1, size=s_size)
beta = pm.Gamma(name="beta", alpha=5, beta=1, size=beta_size)

T = pm.MutableData(name="T", value=np.array(10))
T = pm.Data(name="T", value=np.array(10))

ParetoNBD(
name="pareto_nbd",
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_beta_geo_beta_binom_sample_prior(
gamma = pm.Normal(name="gamma", mu=gamma_true, sigma=1e-4, size=gamma_size)
delta = pm.Normal(name="delta", mu=delta_true, sigma=1e-4, size=delta_size)

T = pm.MutableData(name="T", value=np.array(T_true))
T = pm.Data(name="T", value=np.array(T_true))

BetaGeoBetaBinom(
name="beta_geo_beta_binom",
Expand Down
43 changes: 40 additions & 3 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,11 @@ def mock_property(self):
# Apply the monkeypatch for the property
monkeypatch.setattr(MMM, "id", property(mock_property))

error_msg = """The file 'test_model' does not contain an inference data of the same model
or configuration as 'MMM'"""

error_msg = (
"The file 'test_model' does not "
"contain an inference data of the "
"same model or configuration as 'MMM'"
)
with pytest.raises(ValueError, match=error_msg):
MMM.load("test_model")
os.remove("test_model")
Expand Down Expand Up @@ -1017,3 +1019,38 @@ def test_initialize_defaults_channel_media_dims() -> None:
for transform in [mmm.adstock, mmm.saturation]:
for config in transform.function_priors.values():
assert config.dims == ("channel",)


@pytest.mark.parametrize(
"time_varying_intercept, time_varying_media",
[
(True, False),
(False, True),
(True, True),
],
)
def test_save_load_with_tvp(
time_varying_intercept, time_varying_media, toy_X, toy_y
) -> None:
mmm = MMM(
channel_columns=["channel_1", "channel_2"],
date_column="date",
adstock="geometric",
saturation="logistic",
adstock_max_lag=5,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
)
mmm = mock_fit(mmm, toy_X, toy_y)

file = "tmp-model"
mmm.save(file)
loaded_mmm = MMM.load(file)

assert mmm.time_varying_intercept == loaded_mmm.time_varying_intercept
assert mmm.time_varying_intercept == time_varying_intercept
assert mmm.time_varying_media == loaded_mmm.time_varying_media
assert mmm.time_varying_media == time_varying_media

# clean up
os.remove(file)
4 changes: 2 additions & 2 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
with pm.Model(coords=coords) as self.model:
if model_config is None:
model_config = self.default_model_config
x = pm.MutableData("x", self.X["input"].values)
y_data = pm.MutableData("y_data", self.y)
x = pm.Data("x", self.X["input"].values)
y_data = pm.Data("y_data", self.y)

# prior parameters
a_loc = model_config["a"]["loc"]
Expand Down

0 comments on commit 6fdd3d2

Please sign in to comment.