Skip to content

Commit

Permalink
Allow custom priors and likelihood in DelayedSaturated MMM (#397)
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
Co-authored-by: Markus Sagen <[email protected]>
Co-authored-by: nialloulton <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
  • Loading branch information
6 people authored Dec 4, 2023
1 parent 30c91ee commit cf0a954
Show file tree
Hide file tree
Showing 3 changed files with 344 additions and 68 deletions.
28 changes: 28 additions & 0 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,34 @@ def plot_direct_contribution_curves(
fig.suptitle("Direct response curves", fontsize=16)
return fig

def _get_distribution(self, dist: Dict) -> Callable:
    """
    Retrieve a PyMC distribution callable based on the provided dictionary.

    The name stored under ``dist["dist"]`` is looked up as an attribute of
    the ``pm`` namespace. No further check is made that the attribute found
    is actually a distribution class; any attribute of ``pm`` with that name
    is returned as-is.

    Parameters
    ----------
    dist : Dict
        A dictionary containing the key 'dist' which should correspond to the
        name of a PyMC distribution. Other keys (e.g. 'kwargs') are ignored
        here.

    Returns
    -------
    Callable
        A PyMC distribution callable that can be used to instantiate a random
        variable.

    Raises
    ------
    ValueError
        If the specified distribution name in the dictionary does not correspond
        to any attribute of PyMC.
    KeyError
        If the dictionary has no 'dist' key (raised by the lookup itself).
    """
    try:
        prior_distribution = getattr(pm, dist["dist"])
    except AttributeError as err:
        # Chain explicitly so the traceback shows the failed attribute lookup
        # as the direct cause rather than as an error raised "during handling".
        raise ValueError(
            f"Distribution {dist['dist']} does not exist in PyMC"
        ) from err
    return prior_distribution

def compute_mean_contributions_over_time(
self, original_scale: bool = False
) -> pd.DataFrame:
Expand Down
233 changes: 181 additions & 52 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
import pymc as pm
import seaborn as sns
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.tensor import TensorVariable
from xarray import DataArray

from pymc_marketing.mmm.base import MMM
Expand Down Expand Up @@ -142,6 +144,105 @@ def _save_input_params(self, idata) -> None:
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)

def _create_likelihood_distribution(
    self,
    dist: Dict,
    mu: SharedVariable,
    observed: Union[np.ndarray, pd.Series],
    dims: str,
) -> TensorVariable:
    """
    Create and return a likelihood distribution for the model.

    This method validates the likelihood configuration, builds one random
    variable (or takes a constant) for every entry in its 'kwargs', and then
    instantiates the likelihood itself under the name ``"likelihood"``.

    NOTE(review): the PyMC variables are created without an explicit model
    argument, so this presumably has to be called inside an active
    ``pm.Model`` context - confirm against the caller.

    Parameters
    ----------
    dist : Dict
        A configuration dictionary that must contain a 'dist' key with the name of
        the distribution and a 'kwargs' key with parameters for the distribution.
        Each value in 'kwargs' is either a plain ``int``/``float`` (used directly)
        or a nested dictionary with its own 'dist' and 'kwargs' keys, which is
        turned into a random variable named ``likelihood_<param>``.
    mu : SharedVariable
        The value passed as the ``mu`` argument of the likelihood. Because it
        is supplied here, 'mu' may not appear in ``dist['kwargs']``.
    observed : Union[np.ndarray, pd.Series]
        The observed data to which the likelihood distribution will be fitted.
    dims : str
        The dimensions of the data.

    Returns
    -------
    TensorVariable
        The likelihood distribution constructed with PyMC.

    Raises
    ------
    ValueError
        If the 'dist' key is missing or names a distribution outside the
        allowed list, if the 'kwargs' key is missing or not a dictionary, if
        'mu' is present in 'kwargs', if a nested parameter configuration has
        a 'dist' key but no 'kwargs' key, or if a parameter value is neither
        such a dictionary nor a numeric value.
    """

    allowed_distributions = [
        "Normal",
        "StudentT",
        "Laplace",
        "Logistic",
        "LogNormal",
        "Wald",
        "TruncatedNormal",
        "Gamma",
        "AsymmetricLaplace",
        "VonMises",
    ]

    # ``.get`` instead of indexing: a config with no 'dist' key is reported
    # with the same ValueError as any other invalid likelihood, rather than
    # escaping as a bare KeyError.
    likelihood_dist_name = dist.get("dist")
    if likelihood_dist_name not in allowed_distributions:
        raise ValueError(
            f"The distribution used for the likelihood is not allowed. Please, use one of the following distributions: {allowed_distributions}."
        )

    # Validate that 'kwargs' is present and is a dictionary
    if "kwargs" not in dist or not isinstance(dist["kwargs"], dict):
        raise ValueError(
            "The 'kwargs' key must be present in the 'dist' dictionary and be a dictionary itself."
        )

    # 'mu' is supplied by the caller through the ``mu`` argument.
    if "mu" in dist["kwargs"]:
        raise ValueError(
            "The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved."
        )

    parameter_distributions = {}
    for param, param_config in dist["kwargs"].items():
        # Check if param_config is a dictionary with a 'dist' key
        if isinstance(param_config, dict) and "dist" in param_config:
            # Prepare nested distribution
            if "kwargs" not in param_config:
                raise ValueError(
                    f"The parameter configuration for '{param}' must contain 'kwargs'."
                )

            parameter_distributions[param] = self._get_distribution(
                dist=param_config
            )(**param_config["kwargs"], name=f"likelihood_{param}")
        elif isinstance(param_config, (int, float)):
            # Use the value directly
            parameter_distributions[param] = param_config
        else:
            raise ValueError(
                f"Invalid parameter configuration for '{param}'. It must be either a dictionary with a 'dist' key or a numeric value."
            )

    # Resolve the likelihood distribution by name and instantiate it
    likelihood_dist = self._get_distribution(dist={"dist": likelihood_dist_name})

    return likelihood_dist(
        name="likelihood",
        mu=mu,
        observed=observed,
        dims=dims,
        **parameter_distributions,
    )

def build_model(
self,
X: pd.DataFrame,
Expand Down Expand Up @@ -171,13 +272,55 @@ def build_model(
---------------
model : pm.Model
The PyMC model object containing all the defined stochastic and deterministic variables.
Examples
--------
custom_config = {
'intercept': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
'beta_channel': {'dist': 'LogNormal', 'kwargs': {'mu': 1, 'sigma': 3}},
'alpha': {'dist': 'Beta', 'kwargs': {'alpha': 1, 'beta': 3}},
'lam': {'dist': 'Gamma', 'kwargs': {'alpha': 3, 'beta': 1}},
'likelihood': {'dist': 'Normal',
'kwargs': {'sigma': {'dist': 'HalfNormal', 'kwargs': {'sigma': 2}}}
},
'gamma_control': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
'gamma_fourier': {'dist': 'Laplace', 'kwargs': {'mu': 0, 'b': 1}}
}
model = DelayedSaturatedMMM(
date_column="date_week",
channel_columns=["x1", "x2"],
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=custom_config,
)
"""
model_config = self.model_config

self.intercept_dist = self._get_distribution(
dist=self.model_config["intercept"]
)
self.beta_channel_dist = self._get_distribution(
dist=self.model_config["beta_channel"]
)
self.lam_dist = self._get_distribution(dist=self.model_config["lam"])
self.alpha_dist = self._get_distribution(dist=self.model_config["alpha"])
self.gamma_control_dist = self._get_distribution(
dist=self.model_config["gamma_control"]
)
self.gamma_fourier_dist = self._get_distribution(
dist=self.model_config["gamma_fourier"]
)

self._generate_and_preprocess_model_data(X, y)
with pm.Model(coords=self.model_coords) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
value=self.preprocessed_data["X"][self.channel_columns].to_numpy(),
value=self.preprocessed_data["X"][self.channel_columns],
dims=("date", "channel"),
)

Expand All @@ -187,33 +330,26 @@ def build_model(
dims="date",
)

intercept = pm.Normal(
name="intercept",
mu=model_config["intercept"]["mu"],
sigma=model_config["intercept"]["sigma"],
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)

beta_channel = pm.HalfNormal(
beta_channel = self.beta_channel_dist(
name="beta_channel",
sigma=model_config["beta_channel"]["sigma"],
dims=model_config["beta_channel"]["dims"],
**self.model_config["beta_channel"]["kwargs"],
dims=("channel",),
)
alpha = pm.Beta(
alpha = self.alpha_dist(
name="alpha",
alpha=model_config["alpha"]["alpha"],
beta=model_config["alpha"]["beta"],
dims=model_config["alpha"]["dims"],
dims="channel",
**self.model_config["alpha"]["kwargs"],
)

lam = pm.Gamma(
lam = self.lam_dist(
name="lam",
alpha=model_config["lam"]["alpha"],
beta=model_config["lam"]["beta"],
dims=model_config["lam"]["dims"],
dims="channel",
**self.model_config["lam"]["kwargs"],
)

sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"])

channel_adstock = pm.Deterministic(
name="channel_adstock",
var=geometric_adstock(
Expand Down Expand Up @@ -245,19 +381,18 @@ def build_model(
for column in self.control_columns
)
):
gamma_control = self.gamma_control_dist(
name="gamma_control",
dims="control",
**self.model_config["gamma_control"]["kwargs"],
)

control_data_ = pm.MutableData(
name="control_data",
value=self.preprocessed_data["X"][self.control_columns],
dims=("date", "control"),
)

gamma_control = pm.Normal(
name="gamma_control",
mu=model_config["gamma_control"]["mu"],
sigma=model_config["gamma_control"]["sigma"],
dims=model_config["gamma_control"]["dims"],
)

control_contributions = pm.Deterministic(
name="control_contributions",
var=control_data_ * gamma_control,
Expand All @@ -280,11 +415,10 @@ def build_model(
dims=("date", "fourier_mode"),
)

gamma_fourier = pm.Laplace(
gamma_fourier = self.gamma_fourier_dist(
name="gamma_fourier",
mu=model_config["gamma_fourier"]["mu"],
b=model_config["gamma_fourier"]["b"],
dims=model_config["gamma_fourier"]["dims"],
dims="fourier_mode",
**self.model_config["gamma_fourier"]["kwargs"],
)

fourier_contribution = pm.Deterministic(
Expand All @@ -295,36 +429,31 @@ def build_model(

mu_var += fourier_contribution.sum(axis=-1)

mu = pm.Deterministic(
name="mu", var=mu_var, dims=model_config["mu"]["dims"]
)
mu = pm.Deterministic(name="mu", var=mu_var, dims="date")

pm.Normal(
name="likelihood",
self._create_likelihood_distribution(
dist=self.model_config["likelihood"],
mu=mu,
sigma=sigma,
observed=target_,
dims=model_config["likelihood"]["dims"],
dims="date",
)

@property
def default_model_config(self) -> Dict:
model_config: Dict = {
"intercept": {"mu": 0, "sigma": 2},
"beta_channel": {"sigma": 2, "dims": ("channel",)},
"alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)},
"lam": {"alpha": 3, "beta": 1, "dims": ("channel",)},
"sigma": {"sigma": 2},
"gamma_control": {
"mu": 0,
"sigma": 2,
"dims": ("control",),
return {
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}},
"lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}},
"likelihood": {
"dist": "Normal",
"kwargs": {
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
},
},
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}},
}
return model_config

def _get_fourier_models_data(self, X) -> pd.DataFrame:
"""Generates fourier modes to model seasonality.
Expand Down
Loading

0 comments on commit cf0a954

Please sign in to comment.