Skip to content

Commit

Permalink
Plot Waterfall Components Decomposition (#631)
Browse files Browse the repository at this point in the history
* Creating plot waterfall

Co-Authored-By: Carlos Trujillo <[email protected]>

* requested changes

* pre-commit

---------

Co-authored-by: Carlos Trujillo <[email protected]>
  • Loading branch information
cetagostini-wise and cetagostini authored Apr 28, 2024
1 parent d531d0c commit 1c186f0
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 7 deletions.
115 changes: 114 additions & 1 deletion pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def compute_mean_contributions_over_time(
)

if getattr(self, "yearly_seasonality", None):
contributions_fourier_over_time = (
contributions_fourier_over_time = pd.DataFrame(
az.extract(
self.fit_result,
var_names=["fourier_contributions"],
Expand All @@ -1172,6 +1172,8 @@ def compute_mean_contributions_over_time(
.to_dataframe()
.squeeze()
.unstack()
.sum(axis=1),
columns=["yearly_seasonality"],
)
else:
contributions_fourier_over_time = pd.DataFrame(
Expand Down Expand Up @@ -1300,6 +1302,117 @@ def plot_channel_contribution_share_hdi(
def graphviz(self, **kwargs):
    """
    Build a graph representation of ``self.model``.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``pm.model_to_graphviz``.

    Returns
    -------
    The object returned by ``pm.model_to_graphviz`` for ``self.model``.
    """
    model_graph = pm.model_to_graphviz(self.model, **kwargs)
    return model_graph

def _process_decomposition_components(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregate per-row contributions into one total per component.

    Parameters
    ----------
    data : pd.DataFrame
        Contributions with one column per component, as returned by
        ``compute_mean_contributions_over_time``. The input is not modified.

    Returns
    -------
    pd.DataFrame
        One row per component with the columns ``"component"``,
        ``"contribution"`` (the column total) and ``"percentage"`` (share of
        the grand total, in percent). Rows are sorted by contribution in
        ascending order, ties broken by component name, and the index is a
        fresh ``0..n-1`` range.
    """
    # A plain column sum yields the same per-component totals as the former
    # stack -> groupby("component") -> sum round trip. It avoids the legacy
    # ``DataFrame.stack`` implementation (deprecated in recent pandas) and
    # makes no assumption about the shape or name of the row index.
    totals = data.sum(axis=0)
    summary = pd.DataFrame(
        {"component": totals.index, "contribution": totals.to_numpy()}
    )

    # The secondary key keeps the order deterministic when two components
    # have exactly the same total.
    summary = summary.sort_values(
        by=["contribution", "component"], ascending=True
    ).reset_index(drop=True)

    total_contribution = summary["contribution"].sum()
    summary["percentage"] = (summary["contribution"] / total_contribution) * 100

    return summary

def plot_waterfall_components_decomposition(
    self,
    original_scale: bool = True,
    figsize: tuple[int, int] = (14, 7),
    **kwargs,
) -> plt.Figure:
    """
    Plot a waterfall-style decomposition of the target into its components.

    Each component gets one horizontal bar, labelled with its total
    contribution and its share of the overall total. Positive bars are
    stacked one after another along the x axis; a negative bar extends to
    the left of the running total without moving it.

    Parameters
    ----------
    original_scale : bool, optional
        Passed to ``compute_mean_contributions_over_time``. If True, the
        contributions are plotted in the original scale of the target.
    figsize : tuple, optional
        The size of the figure. The default is (14, 7).
    **kwargs
        Additional keyword arguments to pass to the matplotlib `subplots`
        function.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    """
    contributions = self.compute_mean_contributions_over_time(
        original_scale=original_scale
    )
    dataframe = self._process_decomposition_components(data=contributions)
    total_contribution = dataframe["contribution"].sum()

    fig, ax = plt.subplots(figsize=figsize, layout="constrained", **kwargs)

    cumulative_contribution = 0

    rows = zip(
        dataframe["component"],
        dataframe["contribution"],
        dataframe["percentage"],
    )
    for position, (component, contribution, percentage) in enumerate(rows):
        color = "lightblue" if contribution >= 0 else "salmon"

        # Anchor every bar at the running total and let the signed width
        # give the direction. A negative bar then covers
        # [cumulative + contribution, cumulative], which is the span the
        # label below is centred on. Shifting ``left`` by the contribution
        # *and* passing a negative width would move the bar twice and leave
        # the label outside of it.
        ax.barh(component, contribution, left=cumulative_contribution, color=color)

        # Midpoint of the bar, valid for both signs.
        label_pos = cumulative_contribution + (contribution / 2)

        ax.text(
            label_pos,
            # Positional y so the label lines up with the bar regardless of
            # how the dataframe happens to be indexed.
            position,
            f"{contribution:,.0f}\n({percentage:.1f}%)",
            ha="center",
            va="center",
            color="black",
            fontsize=10,
        )

        # Only positive contributions advance the running total.
        if contribution > 0:
            cumulative_contribution += contribution

    ax.set_title("Response Decomposition Waterfall by Components")
    ax.set_xlabel("Cumulative Contribution")
    ax.set_ylabel("Components")

    # Relabel the x axis as a percentage of the total contribution. Skipped
    # when the total is zero, where the ratio is undefined.
    if total_contribution != 0:
        xticks = np.linspace(0, total_contribution, num=11)
        xticklabels = [f"{(x/total_contribution)*100:.0f}%" for x in xticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)

    for side in ("right", "top", "left"):
        ax.spines[side].set_visible(False)

    ax.set_yticks(np.arange(len(dataframe)))
    ax.set_yticklabels(dataframe["component"])

    return fig


class MMM(
    BaseMMM,
    ValidateTargetColumn,
    ValidateDateColumn,
    ValidateChannelColumns,
):
    """
    Combine ``BaseMMM`` with ``ValidateTargetColumn``, ``ValidateDateColumn``
    and ``ValidateChannelColumns``.

    The class defines no members of its own; all behaviour is inherited from
    its bases.
    """
17 changes: 11 additions & 6 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
assert mmm.model_config is not None
n_channel: int = len(mmm.channel_columns)
n_control: int = len(mmm.control_columns)
fourier_terms: int = 2 * mmm.yearly_seasonality
mmm.fit(
X=toy_X,
y=toy_y,
Expand Down Expand Up @@ -322,17 +321,23 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
)
assert mean_model_contributions_ts.shape == (
toy_X.shape[0],
n_channel + n_control + fourier_terms + 1,
n_channel
+ n_control
+ 2, # 2 for yearly seasonality (+1) and intercept (+)
)

processed_df = mmm._process_decomposition_components(
data=mean_model_contributions_ts
)

assert processed_df.shape == (n_channel + n_control + 2, 3)

assert mean_model_contributions_ts.columns.tolist() == [
"channel_1",
"channel_2",
"control_1",
"control_2",
"sin_order_1",
"cos_order_1",
"sin_order_2",
"cos_order_2",
"yearly_seasonality",
"intercept",
]

Expand Down
1 change: 1 addition & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {"original_scale": True}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
("plot_direct_contribution_curves", {}),
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
Expand Down

0 comments on commit 1c186f0

Please sign in to comment.