Skip to content

Commit

Permalink
Hierarchical Model Configuration (#743)
Browse files Browse the repository at this point in the history
* some base logic and tests

* lookup function once

* add error handling

* implement for mmm and media transformations

* add examples

* add to documentation

* add to docstring

* tests for likelihood

* use deepcopy since keys are added

* set default dims and warn

* fix output_var

* migrate failing tests to model_config

* remove the moved test

* use deepcopy since keys are added

* add to docstrings from feedback

* fix handlers at initialize
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 1ad7ed3 commit cc531a7
Show file tree
Hide file tree
Showing 9 changed files with 1,025 additions and 178 deletions.
1 change: 1 addition & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
clv
mmm
model_config
```
24 changes: 15 additions & 9 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import warnings
from collections.abc import Generator, MutableMapping, Sequence
from copy import deepcopy
from inspect import signature
from itertools import product
from typing import Any
Expand All @@ -35,7 +36,11 @@
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt

from pymc_marketing.mmm.utils import _get_distribution_from_dict
from pymc_marketing.model_config import (
DimHandler,
create_dim_handler,
create_distribution,
)

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]
Expand Down Expand Up @@ -154,7 +159,7 @@ class Transformation:
def __init__(self, priors: dict | None = None, prefix: str | None = None) -> None:
self._checks()
priors = priors or {}
self.function_priors = {**self.default_priors, **priors}
self.function_priors = {**deepcopy(self.default_priors), **priors}
self.prefix = prefix or self.prefix

def update_priors(self, priors: dict[str, Any]) -> None:
Expand Down Expand Up @@ -271,20 +276,21 @@ def variable_mapping(self) -> dict[str, str]:
def _create_distributions(
self, dims: Dims | None = None
) -> dict[str, pt.TensorVariable]:
dim_handler: DimHandler = create_dim_handler(dims)
distributions: dict[str, pt.TensorVariable] = {}
for parameter_name, variable_name in self.variable_mapping.items():
parameter_prior = self.function_priors[parameter_name]

distribution = _get_distribution_from_dict(
dist=parameter_prior,
)

distributions[parameter_name] = distribution(
var_dims = parameter_prior.get("dims")
var = create_distribution(
name=variable_name,
dims=dims,
**parameter_prior["kwargs"],
distribution_name=parameter_prior["dist"],
distribution_kwargs=parameter_prior["kwargs"],
dims=var_dims,
)

distributions[parameter_name] = dim_handler(var, var_dims)

return distributions

def sample_prior(
Expand Down
30 changes: 30 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ def function(self, x, b):
saturation.plot_curve(curve)
plt.show()
Define a hierarchical saturation function with only hierarchical parameters
for saturation parameter of logistic saturation.
.. code-block:: python
from pymc_marketing.mmm import LogisticSaturation
priors = {
"lam": {
"dist": "Gamma",
"kwargs": {
"alpha": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
},
"beta": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
},
},
"dims": "channel",
},
"beta": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
"dims": "channel",
},
}
saturation = LogisticSaturation(priors=priors)
"""

import numpy as np
Expand Down
186 changes: 55 additions & 131 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import pymc as pm
import pytensor.tensor as pt
import seaborn as sns
from pytensor.tensor import TensorVariable
from xarray import DataArray, Dataset

from pymc_marketing.constants import DAYS_IN_YEAR
Expand All @@ -47,12 +46,16 @@
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index
from pymc_marketing.mmm.utils import (
_get_distribution_from_dict,
apply_sklearn_transformer_across_dim,
create_new_spend_data,
generate_fourier_modes,
)
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import (
create_distribution_from_config,
create_likelihood_distribution,
get_distribution,
)

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

Expand Down Expand Up @@ -236,112 +239,6 @@ def _save_input_params(self, idata) -> None:
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)

def _create_likelihood_distribution(
self,
dist: dict,
mu: TensorVariable,
observed: np.ndarray | pd.Series,
dims: str,
) -> TensorVariable:
"""
Create and return a likelihood distribution for the model.
This method prepares the distribution and its parameters as specified in the
configuration dictionary, validates them, and constructs the likelihood
distribution using PyMC.
Parameters
----------
dist : Dict
A configuration dictionary that must contain a 'dist' key with the name of
the distribution and a 'kwargs' key with parameters for the distribution.
observed : Union[np.ndarray, pd.Series]
The observed data to which the likelihood distribution will be fitted.
dims : str
The dimensions of the data.
Returns
-------
TensorVariable
The likelihood distribution constructed with PyMC.
Raises
------
ValueError
If 'kwargs' key is missing in `dist`, or the parameter configuration does
not contain 'dist' and 'kwargs' keys, or if 'mu' is present in the nested
'kwargs'
"""
allowed_distributions = [
"Normal",
"StudentT",
"Laplace",
"Logistic",
"LogNormal",
"Wald",
"TruncatedNormal",
"Gamma",
"AsymmetricLaplace",
"VonMises",
]

if dist["dist"] not in allowed_distributions:
raise ValueError(
f"""
The distribution used for the likelihood is not allowed.
Please, use one of the following distributions: {allowed_distributions}.
"""
)

# Validate that 'kwargs' is present and is a dictionary
if "kwargs" not in dist or not isinstance(dist["kwargs"], dict):
raise ValueError(
"The 'kwargs' key must be present in the 'dist' dictionary and be a dictionary itself."
)

if "mu" in dist["kwargs"]:
raise ValueError(
"The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved."
)

parameter_distributions = {}
for param, param_config in dist["kwargs"].items():
# Check if param_config is a dictionary with a 'dist' key
if isinstance(param_config, dict) and "dist" in param_config:
# Prepare nested distribution
if "kwargs" not in param_config:
raise ValueError(
f"The parameter configuration for '{param}' must contain 'kwargs'."
)

parameter_distributions[param] = _get_distribution_from_dict(
dist=param_config
)(**param_config["kwargs"], name=f"likelihood_{param}")
elif isinstance(param_config, int | float):
# Use the value directly
parameter_distributions[param] = param_config
else:
raise ValueError(
f"""
Invalid parameter configuration for '{param}'.
It must be either a dictionary with a 'dist' key or a numeric value.
"""
)

# Extract the likelihood distribution name and instantiate it
likelihood_dist_name = dist["dist"]
likelihood_dist = _get_distribution_from_dict(
dist={"dist": likelihood_dist_name}
)

return likelihood_dist(
name=self.output_var,
mu=mu,
observed=observed,
dims=dims,
**parameter_distributions,
)

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float_]
) -> pt.TensorVariable:
Expand Down Expand Up @@ -429,16 +326,6 @@ def build_model(
)
"""

self.intercept_dist = _get_distribution_from_dict(
dist=self.model_config["intercept"]
)
self.gamma_control_dist = _get_distribution_from_dict(
dist=self.model_config["gamma_control"]
)
self.gamma_fourier_dist = _get_distribution_from_dict(
dist=self.model_config["gamma_fourier"]
)

self._generate_and_preprocess_model_data(X, y)
with pm.Model(
coords=self.model_coords,
Expand All @@ -464,16 +351,19 @@ def build_model(
self._time_index,
dims="date",
)
intercept_dist = get_distribution(
name=self.model_config["intercept"]["dist"]
)
intercept = create_time_varying_intercept(
time_index,
self._time_index_mid,
self._time_resolution,
self.intercept_dist,
intercept_dist,
self.model_config,
)
else:
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
intercept = create_distribution_from_config(
name="intercept", config=self.model_config
)

channel_contributions = pm.Deterministic(
Expand All @@ -492,10 +382,17 @@ def build_model(
for column in self.control_columns
)
):
gamma_control = self.gamma_control_dist(
if self.model_config["gamma_control"].get("dims") != "control":
msg = (
"The 'dims' key in gamma_control must be 'control'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_control"]["dims"] = "control"

gamma_control = create_distribution_from_config(
name="gamma_control",
dims="control",
**self.model_config["gamma_control"]["kwargs"],
config=self.model_config,
)

control_data_ = pm.Data(
Expand Down Expand Up @@ -529,10 +426,17 @@ def build_model(
mutable=True,
)

gamma_fourier = self.gamma_fourier_dist(
if self.model_config["gamma_fourier"].get("dims") != "fourier_mode":
msg = (
"The 'dims' key in gamma_fourier must be 'fourier_mode'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_fourier"]["dims"] = "fourier_mode"

gamma_fourier = create_distribution_from_config(
name="gamma_fourier",
dims="fourier_mode",
**self.model_config["gamma_fourier"]["kwargs"],
config=self.model_config,
)

fourier_contribution = pm.Deterministic(
Expand All @@ -551,8 +455,9 @@ def build_model(

mu = pm.Deterministic(name="mu", var=mu_var, dims="date")

self._create_likelihood_distribution(
dist=self.model_config["likelihood"],
create_likelihood_distribution(
name=self.output_var,
param_config=self.model_config["likelihood"],
mu=mu,
observed=target_,
dims="date",
Expand All @@ -568,8 +473,16 @@ def default_model_config(self) -> dict:
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
},
},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}},
"gamma_control": {
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 2},
"dims": "control",
},
"gamma_fourier": {
"dist": "Laplace",
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
"intercept_tvp_kwargs": {
"m": 200,
"L": None,
Expand All @@ -580,6 +493,17 @@ def default_model_config(self) -> dict:
},
}

for media_transform in [self.adstock, self.saturation]:
for param, config in media_transform.function_priors.items():
if "dims" not in config:
msg = (
f"{param} doesn't have a 'dims' key in config. Setting to channel."
f" Set priors explicitly in {media_transform.__class__.__name__}"
" to avoid this warning."
)
warnings.warn(msg, stacklevel=2)
config["dims"] = "channel"

return {
**base_config,
**self.adstock.model_config,
Expand Down
Loading

0 comments on commit cc531a7

Please sign in to comment.