Skip to content

Commit

Permalink
PoC: Use Pydantic as data validator (#809)
Browse files Browse the repository at this point in the history
* prior with pydantic

* dependencies

* validate adstock

* make mypy happy

* add validation sample curve

* make the prior type tighter

* add test type

* add validation init mmm

* mmm

* start with Fourier

* fix type

* fix test and improve docstrings

* docstrings

* types

* self type

* init validator

* types model builder

* improve docstrings

* more input validations mmm init

* validation budget optimizer

* fix dummy example types

* hsgp kwargs class

* add kwargs

* undo type hint in dict

* fix fourier names

* better docs

* fix tests

* add type hint

* undo

* fix type error

* feedback2

* restrict signature

* serialize fourier

* docs and tests

* fix docs

* work on parsing

* add hsgp to parsing config

* add tests

* uncomment

* undo changes

* undo model config parser

* handle hsgp_kwargs

* add hsgp flag

* docs

* undo type hint

* improve hints

* add more sections to docs

* Update pymc_marketing/mmm/tvp.py

Co-authored-by: Will Dean <[email protected]>

* feedback 4

* fix test

---------

Co-authored-by: Will Dean <[email protected]>
  • Loading branch information
2 people authored and twiecki committed Sep 10, 2024
1 parent 2b59c72 commit 8ffe8c1
Show file tree
Hide file tree
Showing 22 changed files with 480 additions and 235 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
:toctree: generated/
clv
hsgp_kwargs
mmm
model_config
model_builder
prior
```
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@
"source": [
"dummy_model = MMM(\n",
" date_column=\"\",\n",
" channel_columns=\"\",\n",
" channel_columns=[\"\"],\n",
" adstock=\"geometric\",\n",
" saturation=\"logistic\",\n",
" adstock_max_lag=4,\n",
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import arviz as az
import pandas as pd
import pymc as pm
from pydantic import ConfigDict, InstanceOf, validate_call
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.model.core import Model
Expand All @@ -32,11 +33,12 @@
class CLVModel(ModelBuilder):
_model_type = "CLVModel"

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
data: pd.DataFrame,
*,
model_config: ModelConfig | None = None,
model_config: InstanceOf[ModelConfig] | None = None,
sampler_config: dict | None = None,
non_distributions: list[str] | None = None,
):
Expand Down Expand Up @@ -65,7 +67,7 @@ def _validate_cols(
if data[required_col].nunique() != n:
raise ValueError(f"Column {required_col} has duplicate entries")

def __repr__(self):
def __repr__(self) -> str:
if not hasattr(self, "model"):
return self._model_type
else:
Expand Down
82 changes: 82 additions & 0 deletions pymc_marketing/hsgp_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to store and validate keyword argument for the Hilbert Space Gaussian Process (HSGP) components."""

from typing import Annotated

import pymc as pm
from pydantic import BaseModel, Field, InstanceOf


class HSGPKwargs(BaseModel):
    """HSGP keyword arguments for the time-varying prior.

    See [1]_ and [2]_ for the theoretical background on the Hilbert Space Gaussian Process (HSGP).
    See [6]_ for a practical guide through the method using code examples.
    See the :class:`~pymc.gp.HSGP` class for more information on the Hilbert Space Gaussian Process in PyMC.
    We also recommend the following resources for a more practical introduction to HSGP: [3]_, [4]_, [5]_.

    References
    ----------
    .. [1] Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression.
    .. [2] Riutort-Mayol, G., and Andersen, M., and Solin, A., and Vehtari, A. (2022). Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming.
    .. [3] PyMC Example Gallery: `"Gaussian Processes: HSGP Reference & First Steps" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html>`_.
    .. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_.
    .. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_.
    .. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_.

    Parameters
    ----------
    m : int
        Number of basis functions. Must be strictly positive. Default is 200.
    L : float, optional
        Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
        (considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP).
        Must be strictly positive when given. By default it is None.
    eta_lam : float
        Exponential prior for the variance. Must be strictly positive. Default is 1.
    ls_mu : float
        Mean of the inverse gamma prior for the lengthscale. Must be strictly positive. Default is 5.
    ls_sigma : float
        Standard deviation of the inverse gamma prior for the lengthscale. Must be strictly positive.
        Default is 5.
    cov_func : ~pymc.gp.cov.Covariance, optional
        Gaussian process Covariance function. By default it is None.
    """  # noqa E501

    # A non-positive number of basis functions is meaningless, so reject it at
    # validation time, consistently with the other numeric fields below.
    m: int = Field(200, gt=0, description="Number of basis functions")
    # ``Annotated[...] | None`` applies the ``gt=0`` constraint only when a
    # value is supplied; ``None`` remains a valid "not set" marker.
    L: (
        Annotated[
            float,
            Field(
                gt=0,
                description="""
                Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
                (considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP)
                """,
            ),
        ]
        | None
    ) = None
    eta_lam: float = Field(1, gt=0, description="Exponential prior for the variance")
    ls_mu: float = Field(
        5, gt=0, description="Mean of the inverse gamma prior for the lengthscale"
    )
    ls_sigma: float = Field(
        5,
        gt=0,
        description="Standard deviation of the inverse gamma prior for the lengthscale",
    )
    # ``InstanceOf`` makes pydantic perform a plain isinstance check rather than
    # trying to build a schema for the PyMC covariance class.
    cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field(
        None, description="Gaussian process Covariance function"
    )
31 changes: 17 additions & 14 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from scipy.optimize import minimize

from pymc_marketing.mmm.components.adstock import AdstockTransformation
Expand All @@ -30,7 +31,7 @@ def __init__(self, message: str):
super().__init__(message)


class BudgetOptimizer:
class BudgetOptimizer(BaseModel):
"""
A class for optimizing budget allocation in a marketing mix model.
Expand Down Expand Up @@ -58,19 +59,21 @@ class BudgetOptimizer:
Default is True.
"""

def __init__(
self,
adstock: AdstockTransformation,
saturation: SaturationTransformation,
num_days: int,
parameters: dict[str, dict[str, dict[str, float]]],
adstock_first: bool = True,
):
self.adstock = adstock
self.saturation = saturation
self.num_days = num_days
self.parameters = parameters
self.adstock_first = adstock_first
adstock: AdstockTransformation = Field(
..., description="The adstock transformation class."
)
saturation: SaturationTransformation = Field(
..., description="The saturation transformation class."
)
num_days: int = Field(..., gt=0, description="The number of days.")
parameters: dict[str, dict[str, dict[str, float]]] = Field(
..., description="A dictionary of parameters for each channel."
)
adstock_first: bool = Field(
True,
description="Whether to apply adstock transformation first or saturation transformation first.",
)
model_config = ConfigDict(arbitrary_types_allowed=True)

def objective(self, budgets: list[float]) -> float:
"""
Expand Down
44 changes: 29 additions & 15 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def function(self, x, alpha):

import numpy as np
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
Expand Down Expand Up @@ -81,13 +82,20 @@ class AdstockTransformation(Transformation):
prefix: str = "adstock"
lookup_name: str

@validate_call
def __init__(
self,
l_max: int,
normalize: bool = True,
mode: ConvMode = ConvMode.After,
priors: dict | None = None,
prefix: str | None = None,
l_max: int = Field(
..., gt=0, description="Maximum lag for the adstock transformation."
),
normalize: bool = Field(
True, description="Whether to normalize the adstock values."
),
mode: ConvMode = Field(ConvMode.After, description="Convolution mode."),
priors: dict[str, str | InstanceOf[Prior]] | None = Field(
default=None, description="Priors for the parameters."
),
prefix: str | None = Field(None, description="Prefix for the parameters."),
) -> None:
self.l_max = l_max
self.normalize = normalize
Expand Down Expand Up @@ -368,16 +376,22 @@ def _get_adstock_function(
if isinstance(function, AdstockTransformation):
return function

if function not in ADSTOCK_TRANSFORMATIONS:
elif isinstance(function, str):
if function not in ADSTOCK_TRANSFORMATIONS:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
warnings.warn(
"The preferred method of initializing a lagging function is to use the class directly.",
DeprecationWarning,
stacklevel=1,
)

return ADSTOCK_TRANSFORMATIONS[function](**kwargs)

else:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
warnings.warn(
"The preferred method of initializing a lagging function is to use the class directly.",
DeprecationWarning,
stacklevel=1,
)

return ADSTOCK_TRANSFORMATIONS[function](**kwargs)
8 changes: 6 additions & 2 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def function(self, x, b):

import numpy as np
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
Expand Down Expand Up @@ -130,10 +131,13 @@ class InfiniteReturns(SaturationTransformation):

prefix: str = "saturation"

@validate_call
def sample_curve(
self,
parameters: xr.Dataset,
max_value: float = 1.0,
parameters: InstanceOf[xr.Dataset] = Field(
..., description="Parameters of the saturation transformation."
),
max_value: float = Field(1.0, gt=0, description="Maximum range value."),
) -> xr.DataArray:
"""Sample the curve of the saturation transformation given parameters.
Expand Down
Loading

0 comments on commit 8ffe8c1

Please sign in to comment.