diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 178197cd..b7319c1d 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -26,7 +26,6 @@ from pymc.backends import NDArray from pymc.backends.base import MultiTrace from pymc.model.core import Model -from xarray import Dataset from pymc_marketing.model_builder import ModelBuilder from pymc_marketing.model_config import ModelConfig, parse_model_config @@ -256,23 +255,6 @@ def default_sampler_config(self) -> dict: def _serializable_model_config(self) -> dict: return self.model_config - @property - def fit_result(self) -> Dataset: - """Get the fit result.""" - if self.idata is None or "posterior" not in self.idata: - raise RuntimeError("The model hasn't been fit yet, call .fit() first") - return self.idata["posterior"] - - @fit_result.setter - def fit_result(self, res: az.InferenceData) -> None: - if self.idata is None: - self.idata = res - elif "posterior" in self.idata: - warnings.warn("Overriding pre-existing fit_result", stacklevel=1) - self.idata.posterior = res - else: - self.idata.posterior = res - def fit_summary(self, **kwargs): """Compute the summary of the fit result.""" res = self.fit_result diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index d1061e1b..8af129aa 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -271,13 +271,6 @@ def get_target_transformer(self) -> Pipeline: identity_transformer = FunctionTransformer() return Pipeline(steps=[("scaler", identity_transformer)]) - @property - def fit_result(self) -> Dataset: - """Get the posterior data.""" - if self.idata is None or "posterior" not in self.idata: - raise RuntimeError("The model hasn't been fit yet, call .fit() first") - return self.idata["posterior"] - def _get_group_predictive_data( self, group: Literal["prior_predictive", "posterior_predictive"], diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index ca2c04fb..7507057a 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -684,6 +684,42 @@ def fit( self.set_idata_attrs(self.idata) return self.idata # type: ignore + @property + def fit_result(self) -> xr.Dataset: + """Get the posterior fit_result. + + Returns + ------- + InferenceData object. + + """ + return create_idata_accessor( + "posterior", "The model hasn't been fit yet, call .fit() first" + ).__get__(self) + + @fit_result.setter + def fit_result(self, res: az.InferenceData) -> None: + """Create a setter method to overwrite the pre-existing fit_result. + + Parameters + ---------- + res : az.InferenceData + The inferencedata object to be set + + Returns + ------- + property + The property setter for the InferenceData object. + + """ + if self.idata is None: + self.idata = res + elif "posterior" in self.idata: + warnings.warn("Overriding pre-existing fit_result", stacklevel=1) + self.idata.posterior = res + else: + self.idata.posterior = res + def predict( self, X_pred: np.ndarray | pd.DataFrame | pd.Series, @@ -959,6 +995,7 @@ def graphviz(self, **kwargs): posterior = create_idata_accessor( "posterior", "The model hasn't been fit yet, call .fit() first" ) + posterior_predictive = create_idata_accessor( "posterior_predictive", "The model hasn't been fit yet, call .sample_posterior_predictive() first", diff --git a/tests/clv/models/test_basic.py b/tests/clv/models/test_basic.py index 0e021464..c72fd531 100644 --- a/tests/clv/models/test_basic.py +++ b/tests/clv/models/test_basic.py @@ -132,11 +132,6 @@ def test_wrong_fit_method(self): ): model.fit(fit_method="wrong_method") - def test_fit_result_error(self): - model = CLVModelTest() - with pytest.raises(RuntimeError, match="The model hasn't been fit yet"): - model.fit_result - def test_load(self, mocker): model = CLVModelTest() @@ -153,20 +148,6 @@ def test_default_sampler_config(self): model = CLVModelTest() assert model.sampler_config == {} - def test_set_fit_result(self): - model = CLVModelTest() - model.build_model() - model.idata = None - fake_fit = pm.sample_prior_predictive( - samples=50, model=model.model, random_seed=1234 - ) - fake_fit.add_groups(dict(posterior=fake_fit.prior)) - model.fit_result = fake_fit - with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"): - model.fit_result = fake_fit - model.idata = None - model.fit_result = fake_fit - def test_fit_summary_for_mcmc(self, mocker): model = CLVModelTest() diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 5d09e6a2..ffd7ea87 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -269,6 +269,27 @@ def test_fit_dup_Y(toy_X, toy_y): model_builder.fit(X=toy_X, chains=1, draws=100, tune=100) +def test_fit_result_error(): + model = ModelBuilderTest() + with pytest.raises(RuntimeError, match="The model hasn't been fit yet"): + model.fit_result + + +def test_set_fit_result(toy_X, toy_y): + model = ModelBuilderTest() + model.build_model(X=toy_X, y=toy_y) + model.idata = None + fake_fit = pm.sample_prior_predictive( + samples=50, model=model.model, random_seed=1234 + ) + fake_fit.add_groups(dict(posterior=fake_fit.prior)) + model.fit_result = fake_fit + with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"): + model.fit_result = fake_fit + model.idata = None + model.fit_result = fake_fit + + @pytest.mark.skipif( sys.platform == "win32", reason="Permissions for temp files not granted on windows CI.",