Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom priors and likelihood in DelayedSaturated MMM #397

Merged
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
5cab673
Creating Budget Allocation example
cetagostini Aug 28, 2023
9f99c46
Small change on the notebook
cetagostini Aug 28, 2023
e85323c
Switching words
cetagostini Aug 28, 2023
e532ac9
Adding missing code
cetagostini Aug 28, 2023
e240817
Modifying notebook
cetagostini Aug 28, 2023
bf0c73d
Changing introduction
cetagostini Aug 28, 2023
a019e37
Adding links
cetagostini Aug 28, 2023
af344eb
Correcting grammar v1
cetagostini Aug 28, 2023
e23ebd7
changing title
cetagostini Aug 28, 2023
cd9a40c
Modifying narrative v2
cetagostini Aug 28, 2023
c48691e
Modify Narrative V3
cetagostini Aug 29, 2023
a37cbde
Changing Load Model Section
cetagostini Aug 29, 2023
f5fa1eb
Small grammar correction
cetagostini Aug 29, 2023
df5d883
Updating functions and descriptions
cetagostini Sep 9, 2023
a1a8cfa
Merge remote-tracking branch 'upstream/main'
cetagostini Sep 19, 2023
d6fd770
Adding section to handling non-fit errors
cetagostini Sep 19, 2023
6b0b911
Correcting git workflows error
cetagostini Sep 28, 2023
df67dde
Updating notebook
cetagostini Oct 2, 2023
b35b894
Merge remote-tracking branch 'upstream/main'
cetagostini Oct 2, 2023
e39807b
Merge branch 'main' of https://github.com/cetagostini/pymc-marketing
cetagostini Oct 2, 2023
35a6ccd
Merge branch 'pymc-labs:main' into main
cetagostini Oct 6, 2023
703c676
Merge branch 'pymc-labs:main' into main
cetagostini Oct 13, 2023
ebe20ce
model builder changes
cetagostini Oct 17, 2023
6499849
Replacing dict dims
cetagostini Oct 17, 2023
a8aebcf
Commenting not used params
cetagostini Oct 17, 2023
5847254
Adding _pre_process_prior function
cetagostini Oct 23, 2023
e2a5eb6
importing missing library
cetagostini Oct 23, 2023
13b6e96
Correcting error on importing
cetagostini Oct 23, 2023
bf7c7c7
+ importing str_for_dist library
cetagostini Oct 23, 2023
2524860
Correcting model
cetagostini Oct 23, 2023
93807aa
solving error on fit
cetagostini Oct 23, 2023
b64801a
Updating code (Trying to solve dims mismatch)
cetagostini Oct 28, 2023
74e1730
small adjustment
cetagostini Oct 28, 2023
a86debd
Applying changes based on juanito examples
cetagostini Oct 30, 2023
4500a5d
Praying for mercy.
cetagostini Oct 30, 2023
c692123
debug
cetagostini Oct 30, 2023
d63ebfa
modifying _create_distribution function
cetagostini Oct 31, 2023
91a5af9
Adding prior likelihood config
cetagostini Nov 5, 2023
0f88a16
Deleting hint
cetagostini Nov 5, 2023
a19763e
Adjusting hint
cetagostini Nov 5, 2023
a954957
Adding docstrings
cetagostini Nov 5, 2023
28c28eb
Adding extra unit tests
cetagostini Nov 6, 2023
d667c39
small changes
cetagostini Nov 6, 2023
9139d02
Merge remote-tracking branch 'origin/main' into model_builder_mmm_del…
cetagostini Nov 6, 2023
946e6dc
Merge remote-tracking branch 'upstream/main' into model_builder_mmm_d…
cetagostini Nov 6, 2023
1d0044a
solving error
cetagostini Nov 6, 2023
78c8239
Merge remote-tracking branch 'upstream/main' into model_builder_mmm_d…
cetagostini Nov 22, 2023
82f3132
Adding last team feedback
cetagostini Nov 22, 2023
1fc193c
Fixing error
cetagostini Nov 22, 2023
9fdc71f
Applying suggestion from Ricardo
cetagostini Nov 25, 2023
a0b4b3c
Define possible distributions for likelihood
cetagostini Nov 27, 2023
0bd3415
Adding tests and extra distributions
cetagostini Nov 29, 2023
4b60e0d
Adding config to init_test
cetagostini Nov 30, 2023
d2da24c
adding extra test
cetagostini Nov 30, 2023
f276763
lint
cetagostini Nov 30, 2023
3b11a10
adjusting test
cetagostini Nov 30, 2023
0a81f67
Update tests/mmm/test_delayed_saturated_mmm.py
cetagostini Dec 1, 2023
157d6ce
Update test
cetagostini Dec 2, 2023
dbc167e
Update tests/mmm/test_delayed_saturated_mmm.py
cetagostini Dec 3, 2023
94b4a8c
Adding last changes.
cetagostini Dec 3, 2023
9fcc335
Update tests/mmm/test_delayed_saturated_mmm.py
cetagostini Dec 3, 2023
515f6d6
matching string
cetagostini Dec 3, 2023
27e5dad
Correcting error on assert
cetagostini Dec 3, 2023
3550f66
Huge team work!
cetagostini Dec 3, 2023
df19ddd
Simplify match in test
ricardoV94 Dec 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,34 @@ def plot_direct_contribution_curves(
fig.suptitle("Direct response curves", fontsize=16)
return fig

def _get_distribution(self, dist: Dict) -> Callable:
"""
Retrieve a PyMC distribution callable based on the provided dictionary.

Parameters
----------
dist : Dict
A dictionary containing the key 'dist' which should correspond to the
name of a PyMC distribution.

Returns
-------
Callable
A PyMC distribution callable that can be used to instantiate a random
variable.

Raises
------
ValueError
If the specified distribution name in the dictionary does not correspond
to any distribution in PyMC.
"""
try:
prior_distribution = getattr(pm, dist["dist"])
except AttributeError:
raise ValueError(f"Distribution {dist['dist']} does not exist in PyMC")
return prior_distribution

def compute_mean_contributions_over_time(
self, original_scale: bool = False
) -> pd.DataFrame:
Expand Down
218 changes: 166 additions & 52 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
import pymc as pm
import seaborn as sns
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.tensor import TensorVariable
from xarray import DataArray

from pymc_marketing.mmm.base import MMM
Expand Down Expand Up @@ -142,6 +144,86 @@
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)

def _create_likelihood_distribution(
self,
dist: Dict,
mu: SharedVariable,
cetagostini marked this conversation as resolved.
Show resolved Hide resolved
observed: Union[np.ndarray, pd.Series],
dims: str,
) -> TensorVariable:
"""
Create and return a likelihood distribution for the model.

This method prepares the distribution and its parameters as specified in the
configuration dictionary, validates them, and constructs the likelihood
distribution using PyMC.

Parameters
----------
dist : Dict
A configuration dictionary that must contain a 'dist' key with the name of
the distribution and a 'kwargs' key with parameters for the distribution.
observed : Union[np.ndarray, pd.Series]
The observed data to which the likelihood distribution will be fitted.
dims : str
The dimensions of the data.

Returns
-------
TensorVariable
The likelihood distribution constructed with PyMC.

Raises
------
ValueError
If 'kwargs' key is missing in `dist`, or the parameter configuration does
not contain 'dist' and 'kwargs' keys, or if 'mu' is present in the nested
'kwargs'
"""
# Validate that 'kwargs' is present and is a dictionary
if "kwargs" not in dist or not isinstance(dist["kwargs"], dict):
raise ValueError(

Check warning on line 185 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L185

Added line #L185 was not covered by tests
"The 'kwargs' key must be present in the 'dist' dictionary and be a dictionary itself."
)

if "mu" in dist["kwargs"]:
raise ValueError(
"The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved."
)
cetagostini marked this conversation as resolved.
Show resolved Hide resolved

parameter_distributions = {}
for param, param_config in dist["kwargs"].items():
# Check if param_config is a dictionary with a 'dist' key
if isinstance(param_config, dict) and "dist" in param_config:
# Prepare nested distribution
if "kwargs" not in param_config:
raise ValueError(

Check warning on line 200 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L200

Added line #L200 was not covered by tests
f"The parameter configuration for '{param}' must contain 'kwargs'."
)

parameter_distributions[param] = self._get_distribution(
dist=param_config
)(**param_config["kwargs"], name=f"likelihood_{param}")
elif isinstance(param_config, (int, float)):
# Use the value directly
parameter_distributions[param] = param_config

Check warning on line 209 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L209

Added line #L209 was not covered by tests
else:
raise ValueError(
f"Invalid parameter configuration for '{param}'. It must be either a dictionary with a 'dist' key or a numeric value."
)

# Extract the likelihood distribution name and instantiate it
likelihood_dist_name = dist["dist"]
likelihood_dist = self._get_distribution(dist={"dist": likelihood_dist_name})

return likelihood_dist(
name="likelihood",
mu=mu,
observed=observed,
dims=dims,
**parameter_distributions,
)

def build_model(
self,
X: pd.DataFrame,
Expand Down Expand Up @@ -171,13 +253,59 @@
---------------
model : pm.Model
The PyMC model object containing all the defined stochastic and deterministic variables.

Examples
--------
custom_config = {
'intercept': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
'beta_channel': {'dist': 'LogNormal', 'kwargs': {'mu': 1, 'sigma': 3}},
'alpha': {'dist': 'Beta', 'kwargs': {'alpha': 1, 'beta': 3}},
'lam': {'dist': 'Gamma', 'kwargs': {'alpha': 3, 'beta': 1}},
'likelihood': {
'dist': 'StudentT',
'nu': {
'dist': 'Gamma', 'kwargs': {'alpha': 3, 'beta': 1}},
'sigma': {
'dist': 'HalfNormal', 'kwargs': {'sigma': 3}}
},
'gamma_control': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
'gamma_fourier': {'dist': 'Laplace', 'kwargs': {'mu': 0, 'b': 1}}
}

model = DelayedSaturatedMMM(
date_column="date_week",
channel_columns=["x1", "x2"],
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=custom_config,
)
"""
model_config = self.model_config

self.intercept_dist = self._get_distribution(
dist=self.model_config["intercept"]
)
self.beta_channel_dist = self._get_distribution(
dist=self.model_config["beta_channel"]
)
self.lam_dist = self._get_distribution(dist=self.model_config["lam"])
self.alpha_dist = self._get_distribution(dist=self.model_config["alpha"])
self.gamma_control_dist = self._get_distribution(
dist=self.model_config["gamma_control"]
)
self.gamma_fourier_dist = self._get_distribution(
dist=self.model_config["gamma_fourier"]
)

self._generate_and_preprocess_model_data(X, y)
with pm.Model(coords=self.model_coords) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
value=self.preprocessed_data["X"][self.channel_columns].to_numpy(),
value=self.preprocessed_data["X"][self.channel_columns],
dims=("date", "channel"),
)

Expand All @@ -187,33 +315,26 @@
dims="date",
)

intercept = pm.Normal(
name="intercept",
mu=model_config["intercept"]["mu"],
sigma=model_config["intercept"]["sigma"],
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)

beta_channel = pm.HalfNormal(
beta_channel = self.beta_channel_dist(
name="beta_channel",
sigma=model_config["beta_channel"]["sigma"],
dims=model_config["beta_channel"]["dims"],
**self.model_config["beta_channel"]["kwargs"],
dims=("channel",),
)
alpha = pm.Beta(
alpha = self.alpha_dist(
name="alpha",
alpha=model_config["alpha"]["alpha"],
beta=model_config["alpha"]["beta"],
dims=model_config["alpha"]["dims"],
dims="channel",
**self.model_config["alpha"]["kwargs"],
)

lam = pm.Gamma(
lam = self.lam_dist(
name="lam",
alpha=model_config["lam"]["alpha"],
beta=model_config["lam"]["beta"],
dims=model_config["lam"]["dims"],
dims="channel",
**self.model_config["lam"]["kwargs"],
)

sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"])

channel_adstock = pm.Deterministic(
name="channel_adstock",
var=geometric_adstock(
Expand Down Expand Up @@ -245,19 +366,18 @@
for column in self.control_columns
)
):
gamma_control = self.gamma_control_dist(
name="gamma_control",
dims="control",
**self.model_config["gamma_control"]["kwargs"],
)

control_data_ = pm.MutableData(
name="control_data",
value=self.preprocessed_data["X"][self.control_columns],
dims=("date", "control"),
)

gamma_control = pm.Normal(
name="gamma_control",
mu=model_config["gamma_control"]["mu"],
sigma=model_config["gamma_control"]["sigma"],
dims=model_config["gamma_control"]["dims"],
)

control_contributions = pm.Deterministic(
name="control_contributions",
var=control_data_ * gamma_control,
Expand All @@ -280,11 +400,10 @@
dims=("date", "fourier_mode"),
)

gamma_fourier = pm.Laplace(
gamma_fourier = self.gamma_fourier_dist(
name="gamma_fourier",
mu=model_config["gamma_fourier"]["mu"],
b=model_config["gamma_fourier"]["b"],
dims=model_config["gamma_fourier"]["dims"],
dims="fourier_mode",
**self.model_config["gamma_fourier"]["kwargs"],
)

fourier_contribution = pm.Deterministic(
Expand All @@ -295,36 +414,31 @@

mu_var += fourier_contribution.sum(axis=-1)

mu = pm.Deterministic(
name="mu", var=mu_var, dims=model_config["mu"]["dims"]
)
mu = pm.Deterministic(name="mu", var=mu_var, dims="date")

pm.Normal(
name="likelihood",
self._create_likelihood_distribution(
dist=self.model_config["likelihood"],
mu=mu,
sigma=sigma,
observed=target_,
dims=model_config["likelihood"]["dims"],
dims="date",
)

@property
def default_model_config(self) -> Dict:
model_config: Dict = {
"intercept": {"mu": 0, "sigma": 2},
"beta_channel": {"sigma": 2, "dims": ("channel",)},
"alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)},
"lam": {"alpha": 3, "beta": 1, "dims": ("channel",)},
"sigma": {"sigma": 2},
"gamma_control": {
"mu": 0,
"sigma": 2,
"dims": ("control",),
return {
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}},
"lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}},
"likelihood": {
"dist": "Normal",
"kwargs": {
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
},
},
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}},
}
return model_config

def _get_fourier_models_data(self, X) -> pd.DataFrame:
"""Generates fourier modes to model seasonality.
Expand Down
Loading