Skip to content

Commit

Permalink
HSGP as component (#1246)
Browse files Browse the repository at this point in the history
* allow for non-centered with scalers

* bring back the logic from #829

* closes #1177

* import plot_curve

* support for more dimensions

* move the test suite

* change the import in the example

* add back the type hint

* hopefully display the plots

* correct the docstring

* more the tests around

* change the reference format

* implement HSGPPeriod based on prior_linearized example

* Be explicit in the docstring

* Switch to using pytensor for operations

* pull out eta and ls as Prior class

* Add serialization methods for Periodic class

* set seed in the example for period

* use pydantic for validation

* add some tests

* handle case with pytensor array

* Some validation on the L parameter

* Remove reference for prior since works for posterior as well

* add test for period hsgp

* test for serialization

* support for CovFunc as well

* consolidate similar logic

* use the HSGP class internally

* handle the drop first case

* checks for non-dict parameters and higher dimen block

* fix the failing tvp tests

* run the notebook with new names

* migrate to mmm.hsgp module

* support for more dimensions in HSGPPeriodic

* Specific training in description

---------

Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
wd60622 and juanitorduz authored Dec 12, 2024
1 parent f4d5030 commit 45d03ea
Show file tree
Hide file tree
Showing 17 changed files with 4,050 additions and 2,518 deletions.
3,599 changes: 1,873 additions & 1,726 deletions docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pymc_marketing/hsgp_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,6 @@ class HSGPKwargs(BaseModel):
gt=0,
description="Standard deviation of the inverse gamma prior for the lengthscale",
)
cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field(
cov_func: InstanceOf[pm.gp.cov.Covariance] | str | None = Field(
None, description="Gaussian process Covariance function"
)
20 changes: 20 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@
saturation_from_dict,
)
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
from pymc_marketing.mmm.hsgp import (
HSGP,
CovFunc,
HSGPPeriodic,
PeriodicCovFunc,
approx_hsgp_hyperparams,
create_complexity_penalizing_prior,
create_constrained_inverse_gamma_prior,
create_eta_prior,
create_m_and_L_recommendations,
)
from pymc_marketing.mmm.linear_trend import LinearTrend
from pymc_marketing.mmm.media_transformation import (
MediaConfig,
Expand All @@ -52,11 +63,14 @@
from pymc_marketing.mmm.validating import validation_method_X, validation_method_y

__all__ = [
"HSGP",
"MMM",
"AdstockTransformation",
"BaseValidateMMM",
"CovFunc",
"DelayedAdstock",
"GeometricAdstock",
"HSGPPeriodic",
"HillSaturation",
"HillSaturationSigmoid",
"InverseScaledLogisticSaturation",
Expand All @@ -68,6 +82,7 @@
"MediaTransformation",
"MichaelisMentenSaturation",
"MonthlyFourier",
"PeriodicCovFunc",
"RootSaturation",
"SaturationTransformation",
"TanhSaturation",
Expand All @@ -76,7 +91,12 @@
"WeibullPDFAdstock",
"YearlyFourier",
"adstock_from_dict",
"approx_hsgp_hyperparams",
"base",
"create_complexity_penalizing_prior",
"create_constrained_inverse_gamma_prior",
"create_eta_prior",
"create_m_and_L_recommendations",
"mmm",
"preprocessing",
"preprocessing_method_X",
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
from pytensor import tensor as pt
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.mmm.plot import (
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.plot import (
SelToString,
plot_curve,
plot_hdi,
plot_samples,
)
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import DimHandler, Prior, create_dim_handler

# "x" for saturation, "time since exposure" for adstock
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, create_dim_handler

X_NAME: str = "day"
Expand Down
Loading

0 comments on commit 45d03ea

Please sign in to comment.