Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: Use Pydantic as data validator #809

Merged
merged 55 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
bf2d621
prior with pydantic
juanitorduz Jul 4, 2024
d51331c
dependencies
juanitorduz Jul 4, 2024
e4aefbc
validate adstock
juanitorduz Jul 4, 2024
6014c5e
make mypy happy
juanitorduz Jul 4, 2024
11848af
add validation sample curve
juanitorduz Jul 4, 2024
20982cb
Merge branch 'main' into pydantic1
juanitorduz Jul 5, 2024
2138235
Merge branch 'main' into pydantic1
juanitorduz Jul 5, 2024
07a8cd5
make the prior type tighter
juanitorduz Jul 5, 2024
56e033e
add test type
juanitorduz Jul 5, 2024
828df9b
add validation init mmm
juanitorduz Jul 5, 2024
cd8f20a
mmm
juanitorduz Jul 5, 2024
cba5b2a
start with Fourier
juanitorduz Jul 5, 2024
921730f
fix type
juanitorduz Jul 5, 2024
0427e5d
fix test and imprtove docstrings
juanitorduz Jul 5, 2024
409f935
docstrings
juanitorduz Jul 5, 2024
f628ccb
types
juanitorduz Jul 5, 2024
5ca2de3
self type
juanitorduz Jul 5, 2024
cfe8021
init validator
juanitorduz Jul 5, 2024
c610046
types model builder
juanitorduz Jul 5, 2024
1104298
improve docstrings
juanitorduz Jul 5, 2024
b75e3c9
more input validations mmm init
juanitorduz Jul 5, 2024
62a5be3
validation budget optimizer
juanitorduz Jul 5, 2024
63ed4fc
fix dummy example types
juanitorduz Jul 5, 2024
9866b40
hsgp kwargs class
juanitorduz Jul 6, 2024
6d91790
add kwargs
juanitorduz Jul 6, 2024
7416728
undo type hint in dict
juanitorduz Jul 6, 2024
36ef9f3
fix fourier names
juanitorduz Jul 6, 2024
1cd86fe
better docs
juanitorduz Jul 6, 2024
bd8ce8a
fix tests
juanitorduz Jul 6, 2024
7f1b944
add type hint
juanitorduz Jul 6, 2024
4b83f1b
undo
juanitorduz Jul 6, 2024
79c2cc3
fix type error
juanitorduz Jul 6, 2024
32d9118
feedback2
juanitorduz Jul 6, 2024
39ad8ce
restrict signature
juanitorduz Jul 6, 2024
1bd9627
serialize fourier
juanitorduz Jul 6, 2024
b69d2ce
docs and tests
juanitorduz Jul 6, 2024
5a14589
fix docs
juanitorduz Jul 6, 2024
49599b5
Merge branch 'main' into pydantic1
juanitorduz Jul 8, 2024
7df75af
Merge branch 'main' into pydantic1
juanitorduz Jul 8, 2024
101d861
work on parsing
juanitorduz Jul 8, 2024
a3201a6
add hsgp to parsing config
juanitorduz Jul 8, 2024
650e507
add tests
juanitorduz Jul 8, 2024
2d96a71
uncomment
juanitorduz Jul 8, 2024
327f880
undo changes
juanitorduz Jul 8, 2024
b1c7c84
undo model config parser
juanitorduz Jul 8, 2024
7f88be6
Merge branch 'main' into pydantic1
juanitorduz Jul 8, 2024
93f6106
handle hsgp_kwargs
juanitorduz Jul 8, 2024
7d88913
add hsgp flag
juanitorduz Jul 8, 2024
e650a49
docs
juanitorduz Jul 8, 2024
b90e742
undo type hint
juanitorduz Jul 8, 2024
08ed78c
improve hints
juanitorduz Jul 8, 2024
3fd1329
add more sections to docs
juanitorduz Jul 8, 2024
3eb8446
Update pymc_marketing/mmm/tvp.py
juanitorduz Jul 9, 2024
86e4d7a
feedback 4
juanitorduz Jul 9, 2024
47851cc
fix test
juanitorduz Jul 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
:toctree: generated/
clv
hsgp_kwargs
mmm
model_config
model_builder
prior
```
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@
"source": [
"dummy_model = MMM(\n",
" date_column=\"\",\n",
" channel_columns=\"\",\n",
" channel_columns=[\"\"],\n",
" adstock=\"geometric\",\n",
" saturation=\"logistic\",\n",
" adstock_max_lag=4,\n",
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import arviz as az
import pandas as pd
import pymc as pm
from pydantic import ConfigDict, InstanceOf, validate_call
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.model.core import Model
Expand All @@ -32,11 +33,12 @@
class CLVModel(ModelBuilder):
_model_type = "CLVModel"

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
data: pd.DataFrame,
*,
model_config: ModelConfig | None = None,
model_config: InstanceOf[ModelConfig] | None = None,
sampler_config: dict | None = None,
non_distributions: list[str] | None = None,
):
Expand Down Expand Up @@ -65,7 +67,7 @@ def _validate_cols(
if data[required_col].nunique() != n:
raise ValueError(f"Column {required_col} has duplicate entries")

def __repr__(self):
def __repr__(self) -> str:
if not hasattr(self, "model"):
return self._model_type
else:
Expand Down
80 changes: 80 additions & 0 deletions pymc_marketing/hsgp_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
from typing import Annotated

import pymc as pm
from pydantic import BaseModel, Field, InstanceOf


class HSGPKwargs(BaseModel):
    """HSGP keyword arguments for the time-varying prior.

    A validated container for the settings of the Hilbert Space Gaussian
    Process (HSGP) approximation. Every numeric field is checked on
    construction and must be strictly positive.

    See [1]_ and [2]_ for the theoretical background on the Hilbert Space Gaussian Process (HSGP).
    See [6]_ for a practical guide through the method using code examples.
    See the :class:`~pymc.gp.HSGP` class for more information on the Hilbert Space Gaussian Process in PyMC.
    We also recommend the following resources for a more practical introduction to HSGP: [3]_, [4]_, [5]_.

    References
    ----------
    .. [1] Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression.
    .. [2] Riutort-Mayol, G., and Anderson, M., and Solin, A., and Vehtari, A. (2022). Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming.
    .. [3] PyMC Example Gallery: `"Gaussian Processes: HSGP Reference & First Steps" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html>`_.
    .. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_.
    .. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_.
    .. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_.

    Parameters
    ----------
    m : int
        Number of basis functions. Must be positive. Default is 200.
    L : float, optional
        Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
        (considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP).
        Must be positive when provided. By default it is None.
    eta_lam : float
        Exponential prior for the variance. Must be positive. Default is 1.
    ls_mu : float
        Mean of the inverse gamma prior for the lengthscale. Must be positive. Default is 5.
    ls_sigma : float
        Standard deviation of the inverse gamma prior for the lengthscale. Must be positive. Default is 5.
    cov_func : ~pymc.gp.cov.Covariance, optional
        Gaussian process Covariance function. By default it is None.
    """  # noqa E501

    # A non-positive number of basis functions is meaningless, so it is
    # rejected here in the same way as the other numeric fields.
    m: int = Field(200, gt=0, description="Number of basis functions")
    # The constraint lives inside ``Annotated`` so that ``gt=0`` applies only
    # to the float branch of the union and ``None`` remains a valid value.
    L: (
        Annotated[
            float,
            Field(
                gt=0,
                description="""
                Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
                (considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP)
                """,
            ),
        ]
        | None
    ) = None
    eta_lam: float = Field(1, gt=0, description="Exponential prior for the variance")
    ls_mu: float = Field(
        5, gt=0, description="Mean of the inverse gamma prior for the lengthscale"
    )
    ls_sigma: float = Field(
        5,
        gt=0,
        description="Standard deviation of the inverse gamma prior for the lengthscale",
    )
    # ``InstanceOf`` makes pydantic perform a plain isinstance check, since
    # the PyMC covariance class is not a pydantic-aware type.
    cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field(
        None, description="Gaussian process Covariance function"
    )
31 changes: 17 additions & 14 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from scipy.optimize import minimize

from pymc_marketing.mmm.components.adstock import AdstockTransformation
Expand All @@ -30,7 +31,7 @@ def __init__(self, message: str):
super().__init__(message)


class BudgetOptimizer:
class BudgetOptimizer(BaseModel):
"""
A class for optimizing budget allocation in a marketing mix model.
Expand Down Expand Up @@ -58,19 +59,21 @@ class BudgetOptimizer:
Default is True.
"""

def __init__(
self,
adstock: AdstockTransformation,
saturation: SaturationTransformation,
num_days: int,
parameters: dict[str, dict[str, dict[str, float]]],
adstock_first: bool = True,
):
self.adstock = adstock
self.saturation = saturation
self.num_days = num_days
self.parameters = parameters
self.adstock_first = adstock_first
adstock: AdstockTransformation = Field(
..., description="The adstock transformation class."
)
saturation: SaturationTransformation = Field(
..., description="The saturation transformation class."
)
num_days: int = Field(..., gt=0, description="The number of days.")
parameters: dict[str, dict[str, dict[str, float]]] = Field(
..., description="A dictionary of parameters for each channel."
)
adstock_first: bool = Field(
True,
description="Whether to apply adstock transformation first or saturation transformation first.",
)
model_config = ConfigDict(arbitrary_types_allowed=True)

def objective(self, budgets: list[float]) -> float:
"""
Expand Down
44 changes: 29 additions & 15 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def function(self, x, alpha):

import numpy as np
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
Expand Down Expand Up @@ -81,13 +82,20 @@ class AdstockTransformation(Transformation):
prefix: str = "adstock"
lookup_name: str

@validate_call
def __init__(
self,
l_max: int,
normalize: bool = True,
mode: ConvMode = ConvMode.After,
priors: dict | None = None,
prefix: str | None = None,
l_max: int = Field(
..., gt=0, description="Maximum lag for the adstock transformation."
),
normalize: bool = Field(
True, description="Whether to normalize the adstock values."
),
mode: ConvMode = Field(ConvMode.After, description="Convolution mode."),
priors: dict[str, str | InstanceOf[Prior]] | None = Field(
default=None, description="Priors for the parameters."
),
prefix: str | None = Field(None, description="Prefix for the parameters."),
) -> None:
self.l_max = l_max
self.normalize = normalize
Expand Down Expand Up @@ -368,16 +376,22 @@ def _get_adstock_function(
if isinstance(function, AdstockTransformation):
return function

if function not in ADSTOCK_TRANSFORMATIONS:
elif isinstance(function, str):
if function not in ADSTOCK_TRANSFORMATIONS:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
warnings.warn(
"The preferred method of initializing a lagging function is to use the class directly.",
DeprecationWarning,
stacklevel=1,
)

return ADSTOCK_TRANSFORMATIONS[function](**kwargs)

else:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
warnings.warn(
"The preferred method of initializing a lagging function is to use the class directly.",
DeprecationWarning,
stacklevel=1,
)

return ADSTOCK_TRANSFORMATIONS[function](**kwargs)
8 changes: 6 additions & 2 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def function(self, x, b):

import numpy as np
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
Expand Down Expand Up @@ -130,10 +131,13 @@ class InfiniteReturns(SaturationTransformation):

prefix: str = "saturation"

@validate_call
def sample_curve(
self,
parameters: xr.Dataset,
max_value: float = 1.0,
parameters: InstanceOf[xr.Dataset] = Field(
..., description="Parameters of the saturation transformation."
),
max_value: float = Field(1.0, gt=0, description="Maximum range value."),
) -> xr.DataArray:
"""Sample the curve of the saturation transformation given parameters.
Expand Down
Loading
Loading