Check for missing attrs after sample_prior_predictive and fit #867

Merged
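At a high level, this PR replaces the per-class `_save_input_params` and `load` overrides with a pair of symmetric hooks, `create_idata_attrs` and `attrs_to_init_kwargs`, and adds a check that every non-data `__init__` parameter is covered by the saved attrs. The following is a minimal, self-contained sketch of that symmetry using standalone stand-in classes (not the real `ModelBuilder` API, which has additional abstract methods):

```python
import json
from typing import Any


class TinyBuilder:
    """Standalone stand-in for ModelBuilder, only to illustrate the hook symmetry."""

    def __init__(self, model_config: dict | None = None, sampler_config: dict | None = None):
        self.model_config = model_config or {}
        self.sampler_config = sampler_config or {}

    def create_idata_attrs(self) -> dict[str, str]:
        # Everything needed to rebuild the instance, as JSON-encoded strings.
        return {
            "model_config": json.dumps(self.model_config),
            "sampler_config": json.dumps(self.sampler_config),
        }

    @classmethod
    def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]:
        # Inverse of create_idata_attrs; load() feeds its result to cls(**kwargs).
        return {
            "model_config": json.loads(attrs["model_config"]),
            "sampler_config": json.loads(attrs["sampler_config"]),
        }


class TinyMMM(TinyBuilder):
    def __init__(self, date_column: str = "date", **kwargs):
        super().__init__(**kwargs)
        self.date_column = date_column

    def create_idata_attrs(self) -> dict[str, str]:
        # Subclasses extend the base attrs with their extra __init__ parameters.
        attrs = super().create_idata_attrs()
        attrs["date_column"] = json.dumps(self.date_column)
        return attrs

    @classmethod
    def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]:
        kwargs = super().attrs_to_init_kwargs(attrs)
        kwargs["date_column"] = json.loads(attrs["date_column"])
        return kwargs


# Round trip: one hook's output rebuilds an equivalent instance through the other.
original = TinyMMM(date_column="week", model_config={"a": 1})
clone = TinyMMM(**TinyMMM.attrs_to_init_kwargs(original.create_idata_attrs()))
assert clone.date_column == "week" and clone.model_config == {"a": 1}
```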
117 changes: 35 additions & 82 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Expand Up @@ -15,7 +15,6 @@

import json
import warnings
from pathlib import Path
from typing import Annotated, Any

import arviz as az
@@ -56,7 +55,6 @@
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.utils import from_netcdf

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

@@ -298,19 +296,21 @@ def _generate_and_preprocess_model_data( # type: ignore
self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0]
).days

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
idata.attrs["date_column"] = json.dumps(self.date_column)
idata.attrs["adstock"] = json.dumps(self.adstock.lookup_name)
idata.attrs["saturation"] = json.dumps(self.saturation.lookup_name)
idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
idata.attrs["control_columns"] = json.dumps(self.control_columns)
idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)
def create_idata_attrs(self) -> dict[str, str]:
attrs = super().create_idata_attrs()
attrs["date_column"] = json.dumps(self.date_column)
attrs["adstock"] = json.dumps(self.adstock.lookup_name)
attrs["saturation"] = json.dumps(self.saturation.lookup_name)
attrs["adstock_first"] = json.dumps(self.adstock_first)
attrs["control_columns"] = json.dumps(self.control_columns)
attrs["channel_columns"] = json.dumps(self.channel_columns)
attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
attrs["validate_data"] = json.dumps(self.validate_data)
attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
attrs["time_varying_media"] = json.dumps(self.time_varying_media)

return attrs

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float64]
@@ -636,73 +636,26 @@ def ndarray_to_list(d: dict) -> dict:
return ndarray_to_list(serializable_config)

@classmethod
def load(cls, fname: str):
"""
Creates a MMM instance from a file,
instantiating the model with the saved original input parameters.
Loads inference data for the model.

Parameters
----------
fname : string
This denotes the name with path from where idata should be loaded from.

Returns
-------
Returns an instance of MMM.

Raises
------
ValueError
If the inference data that is loaded doesn't match with the model.
"""

filepath = Path(fname)
idata = from_netcdf(filepath)
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
# Media Transformations
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
# Seasonality
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
# TVP
time_varying_intercept=json.loads(
idata.attrs.get("time_varying_intercept", False)
),
time_varying_media=json.loads(
idata.attrs.get("time_varying_media", False)
),
# Configurations
validate_data=json.loads(idata.attrs["validate_data"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
y = dataset[model.output_var].values
model.build_model(X, y)
# All previously used data is in idata.
if model.id != idata.attrs["id"]:
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model or "
f"configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
return {
"model_config": cls._model_config_formatting(
json.loads(attrs["model_config"])
),
"date_column": json.loads(attrs["date_column"]),
"control_columns": json.loads(attrs["control_columns"]),
"channel_columns": json.loads(attrs["channel_columns"]),
"adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
"adstock": json.loads(attrs.get("adstock", "geometric")),
"saturation": json.loads(attrs.get("saturation", "logistic")),
"adstock_first": json.loads(attrs.get("adstock_first", True)),
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
"time_varying_intercept": json.loads(
attrs.get("time_varying_intercept", False)
),
"time_varying_media": json.loads(attrs.get("time_varying_media", False)),
"validate_data": json.loads(attrs["validate_data"]),
"sampler_config": json.loads(attrs["sampler_config"]),
}

def _data_setter(
self,
125 changes: 90 additions & 35 deletions pymc_marketing/model_builder.py
@@ -17,6 +17,7 @@
import json
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from pathlib import Path
from typing import Any

@@ -135,7 +136,7 @@ def _data_setter(

@property
@abstractmethod
def output_var(self):
def output_var(self) -> str:
"""
Returns the name of the output variable of the model.

@@ -264,6 +265,27 @@ def build_model(
None
"""

def create_idata_attrs(self) -> dict[str, str]:
def default(x):
if isinstance(x, Prior):
return x.to_json()
elif isinstance(x, HSGPKwargs):
return x.model_dump(mode="json")
return x.__dict__

attrs: dict[str, str] = {}

attrs["id"] = self.id
attrs["model_type"] = self._model_type
attrs["version"] = self.version
attrs["sampler_config"] = json.dumps(self.sampler_config)
attrs["model_config"] = json.dumps(
self._serializable_model_config,
default=default,
)

return attrs

def set_idata_attrs(
self, idata: az.InferenceData | None = None
) -> az.InferenceData:
@@ -277,42 +299,59 @@ def set_idata_attrs(

Raises
------
ValueError
If the attrs are missing.
RuntimeError
If no InferenceData object is provided.

Returns
-------
None
InferenceData
The InferenceData instance with the attrs set

Examples
--------
>>> model = MyModel(ModelBuilder)
>>> idata = az.InferenceData(your_dataset)
>>> model.set_idata_attrs(idata=idata)
Set the attrs for an InferenceData object manually.

.. code-block:: python

idata: az.InferenceData = ...
model.set_idata_attrs(idata=idata)

"""
if idata is None:
idata = self.idata
if idata is None:
raise RuntimeError("No idata provided to set attrs on.")

def default(x):
if isinstance(x, Prior):
return x.to_json()
elif isinstance(x, HSGPKwargs):
return x.model_dump(mode="json")
return x.__dict__
attrs = self.create_idata_attrs()
attrs_keys = set(attrs.keys())
required_keys = {
"id",
"model_type",
"version",
"sampler_config",
"model_config",
}
if missing_keys := required_keys - attrs_keys:
msg = (
f"Missing required keys in attrs: {missing_keys}. "
"Call super().create_idata_attrs()."
)
raise ValueError(msg)

idata.attrs["id"] = self.id
idata.attrs["model_type"] = self._model_type
idata.attrs["version"] = self.version
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
idata.attrs["model_config"] = json.dumps(
self._serializable_model_config,
default=default,
)
# Only classes with non-dataset parameters will implement save_input_params
if hasattr(self, "_save_input_params"):
self._save_input_params(idata)
init_parameters: set[str] = set(signature(self.__init__).parameters.keys()) # type: ignore
# Remove since this will be stored in the fit_data group of InferenceData
init_parameters -= {"data"}

if missing_keys := init_parameters - attrs_keys:
msg = (
f"__init__ has parameters that are not in the attrs: {missing_keys}. "
"The save and load functionality will not work correctly."
)
raise ValueError(msg)
Comment on lines +327 to +352

Contributor Author: Think we should do this step at initialization. WDYT? Then we'd know immediately if there is an issue with saving the model. Don't think this is too heavy of a calculation.

Collaborator: Yes!

Contributor Author: This actually has a bit of difficulty: the creation of the id depends on the model_config attribute.

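If the check were moved into `__init__` as suggested in this thread, one possible shape would be the rough sketch below. This is hypothetical, not the implementation merged in this PR, and it sidesteps the `id`-from-`model_config` circularity by simply leaving `id` out of the checked attrs:

```python
from inspect import signature


class ModelBuilderSketch:
    """Hypothetical variant with the attrs/__init__ consistency check at init time."""

    def __init__(self, model_config=None, sampler_config=None):
        self.model_config = model_config or {}
        self.sampler_config = sampler_config or {}
        # Fail fast: a subclass that adds __init__ parameters but forgets to
        # extend create_idata_attrs errors here, not at save/fit time.
        self._check_init_parameters_have_attrs()

    def create_idata_attrs(self) -> dict[str, str]:
        # "id" is deliberately omitted here: the real id is derived from
        # model_config, which is the circular dependency raised above.
        return {"model_config": "{}", "sampler_config": "{}"}

    def _check_init_parameters_have_attrs(self) -> None:
        # Same rule as the check added in set_idata_attrs: every non-data
        # __init__ parameter must have a matching attrs entry.
        init_parameters = set(signature(self.__init__).parameters) - {"data"}
        if missing := init_parameters - set(self.create_idata_attrs()):
            raise ValueError(
                f"__init__ has parameters that are not in the attrs: {missing}"
            )
```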
idata.attrs = attrs
return idata

def save(self, fname: str) -> None:
@@ -374,10 +413,20 @@ def _model_config_formatting(cls, model_config: dict) -> dict:
)
return model_config

@classmethod
def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
return {
"model_config": cls._model_config_formatting(
json.loads(attrs["model_config"])
),
"sampler_config": json.loads(attrs["sampler_config"]),
}

@classmethod
def load(cls, fname: str):
"""
Creates a ModelBuilder instance from a file,

Loads inference data for the model.

Parameters
@@ -393,24 +442,27 @@ def load(cls, fname: str):
------
ValueError
If the inference data that is loaded doesn't match with the model.

Examples
--------
>>> class MyModel(ModelBuilder):
>>> ...
>>> name = './mymodel.nc'
>>> imported_model = MyModel.load(name)
Load a model from a file

.. code-block:: python

file_name: str = "./mymodel.nc"
model = MyModel.load(file_name)

"""
filepath = Path(str(fname))
idata = from_netcdf(filepath)

# needs to be converted, because json.loads was changing tuple to list
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
model = cls(
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
init_kwargs = cls.attrs_to_init_kwargs(idata.attrs)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(**init_kwargs)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
@@ -419,8 +471,11 @@ def load(cls, fname: str):
# All previously used data is in idata.

if model.id != idata.attrs["id"]:
error_msg = f"""The file '{fname}' does not contain an inference data of the same model
or configuration as '{cls._model_type}'"""
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model "
f"or configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
15 changes: 13 additions & 2 deletions tests/clv/models/test_basic.py
@@ -27,10 +27,21 @@
class CLVModelTest(CLVModel):
_model_type = "CLVModelTest"

def __init__(self, data=None, **kwargs):
def __init__(
self,
data=None,
model_config=None,
sampler_config: dict | None = None,
):
if data is None:
data = pd.DataFrame({"y": np.random.randn(10)})
super().__init__(data=data, **kwargs)

super().__init__(
data=data,
model_config=model_config,
sampler_config=sampler_config,
non_distributions=[],
)

@property
def default_model_config(self):