Skip to content

Commit

Permalink
Replaced fit_result with posterior in mmm
Browse files Browse the repository at this point in the history
  • Loading branch information
sreekailash committed Jan 4, 2025
1 parent 10e1cbd commit cac170c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 53 deletions.
50 changes: 8 additions & 42 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
self._time_resolution: int
self._time_index: NDArray[np.int_]
self._time_index_mid: int
self._fit_result: az.InferenceData
self.posterior: az.InferenceData
self._posterior_predictive: az.InferenceData
super().__init__(model_config=model_config, sampler_config=sampler_config)

Expand Down Expand Up @@ -271,40 +271,6 @@ def get_target_transformer(self) -> Pipeline:
identity_transformer = FunctionTransformer()
return Pipeline(steps=[("scaler", identity_transformer)])

@property
def prior(self) -> Dataset:
"""Get the prior data."""
if self.idata is None or "prior" not in self.idata:
raise RuntimeError(
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior"]

@property
def prior_predictive(self) -> Dataset:
"""Get the prior predictive data."""
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError(
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior_predictive"]

@property
def fit_result(self) -> Dataset:
"""Get the posterior data."""
if self.idata is None or "posterior" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior"]

@property
def posterior_predictive(self) -> Dataset:
"""Get the posterior predictive data."""
if self.idata is None or "posterior_predictive" not in self.idata:
raise RuntimeError(
"The model hasn't been fit yet, call .sample_posterior_predictive() first"
)
return self.idata["posterior_predictive"]

def _get_group_predictive_data(
self,
group: Literal["prior_predictive", "posterior_predictive"],
Expand Down Expand Up @@ -828,7 +794,7 @@ def plot_errors(

def _format_model_contributions(self, var_contribution: str) -> DataArray:
contributions = az.extract(
self.fit_result,
self.posterior,
var_names=[var_contribution],
combined=False,
)
Expand Down Expand Up @@ -903,7 +869,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
)
if self.X is not None:
intercept = az.extract(
self.fit_result, var_names=["intercept"], combined=False
self.posterior, var_names=["intercept"], combined=False
)

if intercept.ndim == 2:
Expand Down Expand Up @@ -953,7 +919,7 @@ def compute_channel_contribution_original_scale(self) -> DataArray:
"""
channel_contribution = az.extract(
data=self.fit_result, var_names=["channel_contributions"], combined=False
data=self.posterior, var_names=["channel_contributions"], combined=False
)

# sklearn preprocessers expect 2-D arrays of (obs, features)
Expand Down Expand Up @@ -991,7 +957,7 @@ def compute_mean_contributions_over_time(
"""
contributions_channel_over_time = (
az.extract(
self.fit_result,
self.posterior,
var_names=["channel_contributions"],
combined=True,
)
Expand All @@ -1005,7 +971,7 @@ def compute_mean_contributions_over_time(
if getattr(self, "control_columns", None):
contributions_control_over_time = (
az.extract(
self.fit_result,
self.posterior,
var_names=["control_contributions"],
combined=True,
)
Expand All @@ -1022,7 +988,7 @@ def compute_mean_contributions_over_time(
if getattr(self, "yearly_seasonality", None):
contributions_fourier_over_time = pd.DataFrame(
az.extract(
self.fit_result,
self.posterior,
var_names=["fourier_contributions"],
combined=True,
)
Expand All @@ -1040,7 +1006,7 @@ def compute_mean_contributions_over_time(

contributions_intercept_over_time = (
az.extract(
self.fit_result,
self.posterior,
var_names=["intercept"],
combined=True,
)
Expand Down
22 changes: 11 additions & 11 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def channel_contributions_forward_pass(
)

idata = pm.sample_posterior_predictive(
self.fit_result,
self.posterior,
var_names=["channel_contributions"],
progressbar=False,
)
Expand All @@ -627,7 +627,7 @@ def channel_contributions_forward_pass(
# This is coupled with the name of the
# latent process Deterministic
name = "media_temporal_latent_multiplier"
mutliplier = self.fit_result[name]
mutliplier = self.posterior[name]
channel_contributions = channel_contributions * mutliplier

return channel_contributions.to_numpy()
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
)

param_samples_df = pd.DataFrame(
data=az.extract(data=self.fit_result, var_names=[param_name]).T,
data=az.extract(data=self.posterior, var_names=[param_name]).T,
columns=self.channel_columns,
)

Expand All @@ -1074,7 +1074,7 @@ def get_ts_contribution_posterior(
----------
var_contribution : str
The variable for which to get the contributions. It must be a valid variable
in the `fit_result` attribute.
in the `posterior` attribute.
original_scale : bool, optional
Whether to plot in the original scale.
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def plot_components_contributions(
)
if self.X is not None:
intercept = az.extract(
self.fit_result, var_names=["intercept"], combined=False
self.posterior, var_names=["intercept"], combined=False
)

if original_scale:
Expand Down Expand Up @@ -1399,7 +1399,7 @@ def new_spend_contributions(
self.channel_transformer.transform(new_data) if not prior else new_data
)

idata: Dataset = self.fit_result if not prior else self.prior
idata: Dataset = self.posterior if not prior else self.prior

coords = {
"time_since_spend": np.arange(-self.adstock.l_max, self.adstock.l_max + 1),
Expand Down Expand Up @@ -1588,7 +1588,7 @@ def format_recovered_transformation_parameters(
"""
# Retrieve channel names
channels = self.fit_result.channel.values
channels = self.posterior.channel.values

# Initialize the dictionary to store channel information
channels_info = {}
Expand All @@ -1607,14 +1607,14 @@ def format_recovered_transformation_parameters(
for group_name, params in param_groups.items():
# Build dictionary for the current group of parameters
param_dict = {
param.replace(group_name.split("_")[0] + "_", ""): self.fit_result[
param.replace(group_name.split("_")[0] + "_", ""): self.posterior[
param
]
.quantile(quantile, dim=["chain", "draw"])
.to_pandas()
.to_dict()[channel]
for param in params
if param in self.fit_result
if param in self.posterior
}
channel_info[group_name] = param_dict

Expand Down Expand Up @@ -1646,11 +1646,11 @@ def _format_parameters_for_budget_allocator(self) -> dict[str, Any]:
}
"""
saturation_params: dict[str, np.ndarray] = {
key: self.fit_result[f"saturation_{key}"].values
key: self.posterior[f"saturation_{key}"].values
for key in self.saturation.default_priors.keys()
}
adstock_params: dict[str, np.ndarray] = {
key: self.fit_result[f"adstock_{key}"].values
key: self.posterior[f"adstock_{key}"].values
for key in self.adstock.default_priors.keys()
}

Expand Down

0 comments on commit cac170c

Please sign in to comment.