Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot Waterfall Components Decomposition #631

Merged
merged 6 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 114 additions & 1 deletion pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def compute_mean_contributions_over_time(
)

if getattr(self, "yearly_seasonality", None):
contributions_fourier_over_time = (
contributions_fourier_over_time = pd.DataFrame(
az.extract(
self.fit_result,
var_names=["fourier_contributions"],
Expand All @@ -1172,6 +1172,8 @@ def compute_mean_contributions_over_time(
.to_dataframe()
.squeeze()
.unstack()
.sum(axis=1),
columns=["yearly_seasonality"],
)
else:
contributions_fourier_over_time = pd.DataFrame(
Expand Down Expand Up @@ -1300,6 +1302,117 @@ def plot_channel_contribution_share_hdi(
def graphviz(self, **kwargs):
    """Render this object's model with ``pm.model_to_graphviz``.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``pm.model_to_graphviz``.

    Returns
    -------
    Whatever ``pm.model_to_graphviz`` returns for ``self.model``.
    """
    model = self.model
    return pm.model_to_graphviz(model, **kwargs)

def _process_decomposition_components(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Collapse per-date contributions into one total per component.

    Parameters
    ----------
    data : pd.DataFrame
        Wide dataframe with one row per date and one column per component,
        as returned by "compute_mean_contributions_over_time".

    Returns
    -------
    pd.DataFrame
        One row per component with the columns "component", "contribution"
        (sum over all dates) and "percentage" (share of the overall total,
        in percent), sorted by "contribution" in ascending order.
    """
    # Wide -> long: one (date, component, contribution) record per cell.
    long_format = data.stack().reset_index()
    long_format.columns = pd.Index(["date", "component", "contribution"])

    # Sum over dates; selecting only "contribution" leaves "date" out of the sum.
    totals = (
        long_format.groupby("component")[["contribution"]]
        .sum()
        .sort_values(by="contribution", ascending=True)
        .reset_index()
    )

    grand_total = totals["contribution"].sum()
    totals["percentage"] = (totals["contribution"] / grand_total) * 100

    return totals

def plot_waterfall_components_decomposition(
    self,
    original_scale: bool = True,
    figsize: tuple[int, int] = (14, 7),
    **kwargs,
) -> plt.Figure:
    """
    Create a waterfall plot showing the decomposition of the target into its components.

    Positive contributions are stacked left to right starting at zero.
    Negative contributions do not move the running total; each is drawn
    ending at the current running total and extending to the left.

    Parameters
    ----------
    original_scale : bool, optional
        If True, the contributions are plotted in the original scale of the target.
    figsize : Tuple, optional
        The size of the figure. The default is (14, 7).
    **kwargs
        Additional keyword arguments to pass to the matplotlib `subplots` function.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    """
    dataframe = self.compute_mean_contributions_over_time(
        original_scale=original_scale
    )

    dataframe = self._process_decomposition_components(data=dataframe)
    total_contribution = dataframe["contribution"].sum()

    fig, ax = plt.subplots(figsize=figsize, layout="constrained", **kwargs)

    cumulative_contribution = 0

    for index, row in dataframe.iterrows():
        contribution = row["contribution"]
        color = "lightblue" if contribution >= 0 else "salmon"

        # ``bar_start`` is always the LEFT edge of the bar: a negative
        # contribution ends at the running total, so it starts further left.
        bar_start = (
            cumulative_contribution + contribution
            if contribution < 0
            else cumulative_contribution
        )
        # The width must be the magnitude. Passing the signed (negative)
        # value together with an already-shifted left edge would draw the
        # bar one full width too far to the left, away from its label.
        bar_width = abs(contribution)
        ax.barh(row["component"], bar_width, left=bar_start, color=color)

        if contribution > 0:
            cumulative_contribution += contribution

        # Centre the label on the bar.
        label_pos = bar_start + (bar_width / 2)

        ax.text(
            label_pos,
            index,
            f"{row['contribution']:,.0f}\n({row['percentage']:.1f}%)",
            ha="center",
            va="center",
            color="black",
            fontsize=10,
        )

    ax.set_title("Response Decomposition Waterfall by Components")
    ax.set_xlabel("Cumulative Contribution")
    ax.set_ylabel("Components")

    # Label the x axis as a share of the total contribution (0% .. 100%).
    xticks = np.linspace(0, total_contribution, num=11)
    xticklabels = [f"{(x/total_contribution)*100:.0f}%" for x in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_visible(False)

    ax.set_yticks(np.arange(len(dataframe)))
    ax.set_yticklabels(dataframe["component"])

    return fig


class MMM(BaseMMM, ValidateTargetColumn, ValidateDateColumn, ValidateChannelColumns):
    """Combine ``BaseMMM`` with the target, date and channel column validators.

    The class body adds no attributes or methods of its own; everything is
    inherited from the listed bases.
    """
17 changes: 11 additions & 6 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
assert mmm.model_config is not None
n_channel: int = len(mmm.channel_columns)
n_control: int = len(mmm.control_columns)
fourier_terms: int = 2 * mmm.yearly_seasonality
mmm.fit(
X=toy_X,
y=toy_y,
Expand Down Expand Up @@ -322,17 +321,23 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
)
assert mean_model_contributions_ts.shape == (
toy_X.shape[0],
n_channel + n_control + fourier_terms + 1,
n_channel
+ n_control
+ 2, # 2 for yearly seasonality (+1) and intercept (+)
)

processed_df = mmm._process_decomposition_components(
data=mean_model_contributions_ts
)

assert processed_df.shape == (n_channel + n_control + 2, 3)

assert mean_model_contributions_ts.columns.tolist() == [
"channel_1",
"channel_2",
"control_1",
"control_2",
"sin_order_1",
"cos_order_1",
"sin_order_2",
"cos_order_2",
"yearly_seasonality",
"intercept",
]

Expand Down
1 change: 1 addition & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {"original_scale": True}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
("plot_direct_contribution_curves", {}),
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
Expand Down
Loading