-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PoC: Use Pydantic as data validator (#809)
* prior with pydantic * dependencies * validate adstock * make mypy happy * add validation sample curve * make the prior type tighter * add test type * add validation init mmm * mmm * start with Fourier * fix type * fix test and imprtove docstrings * docstrings * types * self type * init validator * types model builder * improve docstrings * more input validations mmm init * validation budget optimizer * fix dummy example types * hsgp kwargs class * add kwargs * undo type hint in dict * fix fourier names * better docs * fix tests * add type hint * undo * fix type error * feedback2 * restrict signature * serialize fourier * docs and tests * fix docs * work on parsing * add hsgp to parsing config * add tests * uncomment * undo changes * undo model config parser * handle hsgp_kwargs * add hsgp flag * docs * undo type hint * improve hints * add more sections to docs * Update pymc_marketing/mmm/tvp.py Co-authored-by: Will Dean <[email protected]> * feedback 4 * fix test --------- Co-authored-by: Will Dean <[email protected]>
- Loading branch information
Showing
22 changed files
with
480 additions
and
235 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,9 @@ | |
:toctree: generated/ | ||
clv | ||
hsgp_kwargs | ||
mmm | ||
model_config | ||
model_builder | ||
prior | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2024 The PyMC Labs Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Class to store and validate keyword argument for the Hilbert Space Gaussian Process (HSGP) components.""" | ||
|
||
from typing import Annotated | ||
|
||
import pymc as pm | ||
from pydantic import BaseModel, Field, InstanceOf | ||
|
||
|
||
class HSGPKwargs(BaseModel): | ||
"""HSGP keyword arguments for the time-varying prior. | ||
See [1]_ and [2]_ for the theoretical background on the Hilbert Space Gaussian Process (HSGP). | ||
See , [6]_ for a practical guide through the method using code examples. | ||
See the :class:`~pymc.gp.HSGP` class for more information on the Hilbert Space Gaussian Process in PyMC. | ||
We also recommend the following resources for a more practical introduction to HSGP: [3]_, [4]_, [5]_. | ||
References | ||
---------- | ||
.. [1] Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression. | ||
.. [2] Ruitort-Mayol, G., and Anderson, M., and Solin, A., and Vehtari, A. (2022). Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming. | ||
.. [3] PyMC Example Gallery: `"Gaussian Processes: HSGP Reference & First Steps" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html>`_. | ||
.. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_. | ||
.. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_. | ||
.. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_. | ||
Parameters | ||
---------- | ||
m : int | ||
Number of basis functions. Default is 200. | ||
L : float, optional | ||
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data | ||
(considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP). | ||
By default it is None. | ||
eta_lam : float | ||
Exponential prior for the variance. Default is 1. | ||
ls_mu : float | ||
Mean of the inverse gamma prior for the lengthscale. Default is 5. | ||
ls_sigma : float | ||
Standard deviation of the inverse gamma prior for the lengthscale. Default is 5. | ||
cov_func : ~pymc.gp.cov.Covariance, optional | ||
Gaussian process Covariance function. By default it is None. | ||
""" # noqa E501 | ||
|
||
m: int = Field(200, description="Number of basis functions") | ||
L: ( | ||
Annotated[ | ||
float, | ||
Field( | ||
gt=0, | ||
description=""" | ||
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data | ||
(considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP) | ||
""", | ||
), | ||
] | ||
| None | ||
) = None | ||
eta_lam: float = Field(1, gt=0, description="Exponential prior for the variance") | ||
ls_mu: float = Field( | ||
5, gt=0, description="Mean of the inverse gamma prior for the lengthscale" | ||
) | ||
ls_sigma: float = Field( | ||
5, | ||
gt=0, | ||
description="Standard deviation of the inverse gamma prior for the lengthscale", | ||
) | ||
cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field( | ||
None, description="Gaussian process Covariance function" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.