Check for missing attrs after sample_prior_predictive and fit (#867)
* separate the attr creation from attachment and perform check

* remove data for CLV

* fix model_builder tests

* fix clv tests

* more specific model builder checks

* rework with no args and kwargs

* rework common load method

* Update pymc_marketing/model_builder.py
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 430ccc5 commit ee1e804
Showing 5 changed files with 258 additions and 128 deletions.
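The gist of the refactor: the ad-hoc _save_input_params hook and the per-class load overrides are replaced by two symmetric hooks. Subclasses build their serializable attrs in create_idata_attrs() and turn them back into constructor arguments in attrs_to_init_kwargs(); the base ModelBuilder attaches the attrs, checks them against the __init__ signature after fit and sample_prior_predictive, and provides the common load. Below is a minimal sketch of a subclass under the new scheme; MyModel and its my_column parameter are illustrative only, and the other abstract methods (build_model, output_var, default_model_config, ...) are omitted for brevity.

    import json
    from typing import Any

    from pymc_marketing.model_builder import ModelBuilder


    class MyModel(ModelBuilder):
        """Hypothetical subclass showing the new serialization hooks."""

        _model_type = "MyModel"

        def __init__(self, my_column: str, model_config=None, sampler_config=None):
            super().__init__(model_config=model_config, sampler_config=sampler_config)
            self.my_column = my_column

        def create_idata_attrs(self) -> dict[str, str]:
            # Base attrs: id, model_type, version, sampler_config, model_config.
            attrs = super().create_idata_attrs()
            # One JSON-encoded entry per extra __init__ parameter.
            attrs["my_column"] = json.dumps(self.my_column)
            return attrs

        @classmethod
        def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
            # Inverse of create_idata_attrs: decode attrs back into __init__ kwargs.
            kwargs = super().attrs_to_init_kwargs(attrs)
            kwargs["my_column"] = json.loads(attrs["my_column"])
            return kwargs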
117 changes: 35 additions & 82 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -15,7 +15,6 @@

import json
import warnings
from pathlib import Path
from typing import Annotated, Any

import arviz as az
@@ -56,7 +55,6 @@
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.utils import from_netcdf

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

@@ -298,19 +296,21 @@ def _generate_and_preprocess_model_data(  # type: ignore
self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0]
).days

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
idata.attrs["date_column"] = json.dumps(self.date_column)
idata.attrs["adstock"] = json.dumps(self.adstock.lookup_name)
idata.attrs["saturation"] = json.dumps(self.saturation.lookup_name)
idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
idata.attrs["control_columns"] = json.dumps(self.control_columns)
idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)
def create_idata_attrs(self) -> dict[str, str]:
attrs = super().create_idata_attrs()
attrs["date_column"] = json.dumps(self.date_column)
attrs["adstock"] = json.dumps(self.adstock.lookup_name)
attrs["saturation"] = json.dumps(self.saturation.lookup_name)
attrs["adstock_first"] = json.dumps(self.adstock_first)
attrs["control_columns"] = json.dumps(self.control_columns)
attrs["channel_columns"] = json.dumps(self.channel_columns)
attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
attrs["validate_data"] = json.dumps(self.validate_data)
attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
attrs["time_varying_media"] = json.dumps(self.time_varying_media)

return attrs

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float64]
@@ -636,73 +636,26 @@ def ndarray_to_list(d: dict) -> dict:
return ndarray_to_list(serializable_config)

@classmethod
def load(cls, fname: str):
"""
Creates a MMM instance from a file,
instantiating the model with the saved original input parameters.
Loads inference data for the model.
Parameters
----------
fname : string
This denotes the name with path from where idata should be loaded from.
Returns
-------
Returns an instance of MMM.
Raises
------
ValueError
If the inference data that is loaded doesn't match with the model.
"""

filepath = Path(fname)
idata = from_netcdf(filepath)
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
# Media Transformations
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
# Seasonality
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
# TVP
time_varying_intercept=json.loads(
idata.attrs.get("time_varying_intercept", False)
),
time_varying_media=json.loads(
idata.attrs.get("time_varying_media", False)
),
# Configurations
validate_data=json.loads(idata.attrs["validate_data"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
y = dataset[model.output_var].values
model.build_model(X, y)
# All previously used data is in idata.
if model.id != idata.attrs["id"]:
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model or "
f"configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
return {
"model_config": cls._model_config_formatting(
json.loads(attrs["model_config"])
),
"date_column": json.loads(attrs["date_column"]),
"control_columns": json.loads(attrs["control_columns"]),
"channel_columns": json.loads(attrs["channel_columns"]),
"adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
"adstock": json.loads(attrs.get("adstock", "geometric")),
"saturation": json.loads(attrs.get("saturation", "logistic")),
"adstock_first": json.loads(attrs.get("adstock_first", True)),
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
"time_varying_intercept": json.loads(
attrs.get("time_varying_intercept", False)
),
"time_varying_media": json.loads(attrs.get("time_varying_media", False)),
"validate_data": json.loads(attrs["validate_data"]),
"sampler_config": json.loads(attrs["sampler_config"]),
}

def _data_setter(
self,
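With the MMM-specific load removed, saving and restoring an MMM now goes entirely through the shared base-class path. A rough usage sketch follows; the constructor arguments, training data, and file name are placeholders rather than anything taken from this commit.

    from pymc_marketing.mmm import MMM

    mmm = MMM(
        date_column="date_week",
        channel_columns=["channel_1", "channel_2"],
        adstock="geometric",          # illustrative configuration
        saturation="logistic",
        adstock_max_lag=8,
    )
    mmm.fit(X, y)                     # X, y: assumed training data
    mmm.save("mmm.nc")                # attrs were created, checked, and attached after fit
    loaded = MMM.load("mmm.nc")       # common ModelBuilder.load + MMM.attrs_to_init_kwargs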
125 changes: 90 additions & 35 deletions pymc_marketing/model_builder.py
@@ -17,6 +17,7 @@
import json
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from pathlib import Path
from typing import Any

@@ -135,7 +136,7 @@ def _data_setter(

@property
@abstractmethod
def output_var(self):
def output_var(self) -> str:
"""
Returns the name of the output variable of the model.
@@ -264,6 +265,27 @@ def build_model(
None
"""

def create_idata_attrs(self) -> dict[str, str]:
def default(x):
if isinstance(x, Prior):
return x.to_json()
elif isinstance(x, HSGPKwargs):
return x.model_dump(mode="json")
return x.__dict__

attrs: dict[str, str] = {}

attrs["id"] = self.id
attrs["model_type"] = self._model_type
attrs["version"] = self.version
attrs["sampler_config"] = json.dumps(self.sampler_config)
attrs["model_config"] = json.dumps(
self._serializable_model_config,
default=default,
)

return attrs

def set_idata_attrs(
self, idata: az.InferenceData | None = None
) -> az.InferenceData:
@@ -277,42 +299,59 @@ def set_idata_attrs(
Raises
------
ValueError
If the attrs are missing for a property initialization of the class
RuntimeError
If no InferenceData object is provided.
Returns
-------
None
InferenceData
The InferenceData instance with the attrs set
Examples
--------
>>> model = MyModel(ModelBuilder)
>>> idata = az.InferenceData(your_dataset)
>>> model.set_idata_attrs(idata=idata)
Set the attrs for an InferenceData object manually.
.. code-block:: python
idata: az.InferenceData = ...
model.set_idata_attrs(idata=idata)
"""
if idata is None:
idata = self.idata
if idata is None:
raise RuntimeError("No idata provided to set attrs on.")

def default(x):
if isinstance(x, Prior):
return x.to_json()
elif isinstance(x, HSGPKwargs):
return x.model_dump(mode="json")
return x.__dict__
attrs = self.create_idata_attrs()
attrs_keys = set(attrs.keys())
required_keys = {
"id",
"model_type",
"version",
"sampler_config",
"model_config",
}
if missing_keys := required_keys - attrs_keys:
msg = (
f"Missing required keys in attrs: {missing_keys}. "
"Call super().create_idata_attrs()."
)
raise ValueError(msg)

idata.attrs["id"] = self.id
idata.attrs["model_type"] = self._model_type
idata.attrs["version"] = self.version
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
idata.attrs["model_config"] = json.dumps(
self._serializable_model_config,
default=default,
)
# Only classes with non-dataset parameters will implement save_input_params
if hasattr(self, "_save_input_params"):
self._save_input_params(idata)
init_parameters: set[str] = set(signature(self.__init__).parameters.keys()) # type: ignore
# Remove since this will be stored in the fit_data group of InferenceData
init_parameters -= {"data"}

if missing_keys := init_parameters - attrs_keys:
msg = (
f"__init__ has parameters that are not in the attrs: {missing_keys}. "
"The save and load functionality will not work correctly."
)
raise ValueError(msg)

idata.attrs = attrs
return idata

def save(self, fname: str) -> None:
@@ -374,10 +413,20 @@ def _model_config_formatting(cls, model_config: dict) -> dict:
)
return model_config

@classmethod
def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
return {
"model_config": cls._model_config_formatting(
json.loads(attrs["model_config"])
),
"sampler_config": json.loads(attrs["sampler_config"]),
}

@classmethod
def load(cls, fname: str):
"""
Creates a ModelBuilder instance from a file,
Loads inference data for the model.
Parameters
@@ -393,24 +442,27 @@ def load(cls, fname: str):
------
ValueError
If the inference data that is loaded doesn't match with the model.
Examples
--------
>>> class MyModel(ModelBuilder):
>>> ...
>>> name = './mymodel.nc'
>>> imported_model = MyModel.load(name)
Load a model from a file
.. code-block:: python
file_name: str = "./mymodel.nc"
model = MyModel.load(file_name)
"""
filepath = Path(str(fname))
idata = from_netcdf(filepath)

# needs to be converted, because json.loads was changing tuple to list
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
model = cls(
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
init_kwargs = cls.attrs_to_init_kwargs(idata.attrs)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(**init_kwargs)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
@@ -419,8 +471,11 @@ def load(cls, fname: str):
# All previously used data is in idata.

if model.id != idata.attrs["id"]:
error_msg = f"""The file '{fname}' does not contain an inference data of the same model
or configuration as '{cls._model_type}'"""
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model "
f"or configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
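The practical effect of the new checks: a subclass whose __init__ takes a parameter that never makes it into the attrs now fails loudly when the attrs are set, right after fit or sample_prior_predictive, rather than producing a model that cannot be loaded later. A hypothetical failing case, building on the MyModel sketch above:

    class BadModel(MyModel):
        """Adds a constructor parameter without updating create_idata_attrs."""

        def __init__(self, my_column: str, extra_flag: bool = False,
                     model_config=None, sampler_config=None):
            super().__init__(my_column=my_column, model_config=model_config,
                             sampler_config=sampler_config)
            self.extra_flag = extra_flag

    # set_idata_attrs compares the __init__ parameters against the attrs keys
    # and would raise roughly:
    #   ValueError: __init__ has parameters that are not in the attrs: {'extra_flag'}. ...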
15 changes: 13 additions & 2 deletions tests/clv/models/test_basic.py
@@ -27,10 +27,21 @@
class CLVModelTest(CLVModel):
_model_type = "CLVModelTest"

def __init__(self, data=None, **kwargs):
def __init__(
self,
data=None,
model_config=None,
sampler_config: dict | None = None,
):
if data is None:
data = pd.DataFrame({"y": np.random.randn(10)})
super().__init__(data=data, **kwargs)

super().__init__(
data=data,
model_config=model_config,
sampler_config=sampler_config,
non_distributions=[],
)

@property
def default_model_config(self):
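For the CLV test model, spelling out model_config and sampler_config keeps the __init__ signature aligned with the attrs check; data is deliberately exempt because it is stored in the fit_data group of the InferenceData rather than in attrs. A rough sketch of the save/load roundtrip this kind of test exercises; the exact fit arguments are assumed and not shown in this diff.

    model = CLVModelTest()               # data defaults to a small random frame
    model.fit()                          # attrs are created, validated, and attached
    model.save("clv_model_test.nc")

    loaded = CLVModelTest.load("clv_model_test.nc")
    assert loaded.id == model.id         # load still verifies the stored model id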
(The remaining 2 changed files are not shown.)
