diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 725eeb6db..855db00ee 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -480,12 +480,39 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: "sampler_config": json.loads(attrs["sampler_config"]), } + def build_from_idata(self, idata: az.InferenceData) -> None: + """Build model from the InferenceData object. + + This is part of the :func:`load` method. See :func:`load` for more larger context. + + Usually a wrapper around the :func:`build_model` method unless the model + has some additional steps to be built. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object to build the model from. + + """ + dataset = idata.fit_data.to_dataframe() # type: ignore + X = dataset.drop(columns=[self.output_var]) + y = dataset[self.output_var] + + self.build_model(X, y) + @classmethod def load(cls, fname: str): """Create a ModelBuilder instance from a file. Loads inference data for the model. + This class method has a few steps: + + - Load the InferenceData from the file. + - Construct a new instance of the model using the InferenceData attrs + - Build the model from the InferenceData + - Check if the model id matches the id in the InferenceData loaded. + Parameters ---------- fname : string @@ -521,11 +548,7 @@ def load(cls, fname: str): model = cls(**init_kwargs) model.idata = idata - dataset = idata.fit_data.to_dataframe() - X = dataset.drop(columns=[model.output_var]) - y = dataset[model.output_var] - model.build_model(X, y) - # All previously used data is in idata. + model.build_from_idata(idata) if model.id != idata.attrs["id"]: error_msg = (