diff --git a/pymc_marketing/mmm/mmm.py b/pymc_marketing/mmm/mmm.py index 66f175ab7..118cf931a 100644 --- a/pymc_marketing/mmm/mmm.py +++ b/pymc_marketing/mmm/mmm.py @@ -622,7 +622,15 @@ def channel_contributions_forward_pass( progressbar=False, ) - return idata.posterior_predictive.channel_contributions.to_numpy() + channel_contributions = idata.posterior_predictive.channel_contributions + if self.time_varying_media: + # This is coupled with the name of the + # latent process Deterministic + name = "media_temporal_latent_multiplier" + mutliplier = self.fit_result[name] + channel_contributions = channel_contributions * mutliplier + + return channel_contributions.to_numpy() @property def _serializable_model_config(self) -> dict[str, Any]: @@ -996,7 +1004,8 @@ def get_channel_contributions_forward_pass_grid( delta * self.preprocessed_data["X"][self.channel_columns].to_numpy() ) channel_contribution_forward_pass = self.channel_contributions_forward_pass( - channel_data=channel_data, disable_logger_stdout=True + channel_data=channel_data, + disable_logger_stdout=True, ) channel_contributions.append(channel_contribution_forward_pass) return DataArray( diff --git a/tests/mmm/test_mmm.py b/tests/mmm/test_mmm.py index c76f6c40d..1df4ed0a2 100644 --- a/tests/mmm/test_mmm.py +++ b/tests/mmm/test_mmm.py @@ -1264,3 +1264,29 @@ def test_missing_attrs_to_defaults(toy_X, toy_y) -> None: # clean up os.remove(file) + + +def test_channel_contributions_forward_pass_time_varying_media(toy_X, toy_y) -> None: + mmm = MMM( + date_column="date", + channel_columns=["channel_1", "channel_2"], + control_columns=["control_1", "control_2"], + adstock=GeometricAdstock(l_max=2), + saturation=LogisticSaturation(), + time_varying_media=True, + ) + mmm = mock_fit(mmm, toy_X, toy_y) + + posterior = mmm.fit_result + + baseline_contributions = posterior["baseline_channel_contributions"] + multiplier = posterior["media_temporal_latent_multiplier"] + target_scale = mmm.y.max() + recovered_contributions = baseline_contributions * multiplier * target_scale + media_contributions = mmm.channel_contributions_forward_pass( + mmm.preprocessed_data["X"][mmm.channel_columns].to_numpy() + ) + np.testing.assert_allclose( + recovered_contributions.to_numpy(), + media_contributions, + )