Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include tvp scaling in contribution grid #1253

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,15 @@ def channel_contributions_forward_pass(
progressbar=False,
)

return idata.posterior_predictive.channel_contributions.to_numpy()
channel_contributions = idata.posterior_predictive.channel_contributions
if self.time_varying_media:
# This is coupled with the name of the
# latent process Deterministic
name = "media_temporal_latent_multiplier"
mutliplier = self.fit_result[name]
channel_contributions = channel_contributions * mutliplier

return channel_contributions.to_numpy()

@property
def _serializable_model_config(self) -> dict[str, Any]:
Expand Down Expand Up @@ -996,7 +1004,8 @@ def get_channel_contributions_forward_pass_grid(
delta * self.preprocessed_data["X"][self.channel_columns].to_numpy()
)
channel_contribution_forward_pass = self.channel_contributions_forward_pass(
channel_data=channel_data, disable_logger_stdout=True
channel_data=channel_data,
disable_logger_stdout=True,
)
channel_contributions.append(channel_contribution_forward_pass)
return DataArray(
Expand Down
26 changes: 26 additions & 0 deletions tests/mmm/test_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,3 +1264,29 @@ def test_missing_attrs_to_defaults(toy_X, toy_y) -> None:

# clean up
os.remove(file)


def test_channel_contributions_forward_pass_time_varying_media(toy_X, toy_y) -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
control_columns=["control_1", "control_2"],
adstock=GeometricAdstock(l_max=2),
saturation=LogisticSaturation(),
time_varying_media=True,
)
mmm = mock_fit(mmm, toy_X, toy_y)

posterior = mmm.fit_result

baseline_contributions = posterior["baseline_channel_contributions"]
multiplier = posterior["media_temporal_latent_multiplier"]
target_scale = mmm.y.max()
recovered_contributions = baseline_contributions * multiplier * target_scale
media_contributions = mmm.channel_contributions_forward_pass(
mmm.preprocessed_data["X"][mmm.channel_columns].to_numpy()
)
np.testing.assert_allclose(
recovered_contributions.to_numpy(),
media_contributions,
)
Loading