Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move fit_result to model_builder.py and remove redundancies from CLV and MMM #1344

Merged
merged 10 commits into from
Jan 8, 2025
18 changes: 0 additions & 18 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.model.core import Model
from xarray import Dataset

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
Expand Down Expand Up @@ -256,23 +255,6 @@ def default_sampler_config(self) -> dict:
def _serializable_model_config(self) -> dict:
return self.model_config

@property
def fit_result(self) -> Dataset:
"""Get the fit result."""
if self.idata is None or "posterior" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior"]

@fit_result.setter
def fit_result(self, res: az.InferenceData) -> None:
if self.idata is None:
self.idata = res
elif "posterior" in self.idata:
warnings.warn("Overriding pre-existing fit_result", stacklevel=1)
self.idata.posterior = res
else:
self.idata.posterior = res

def fit_summary(self, **kwargs):
"""Compute the summary of the fit result."""
res = self.fit_result
Expand Down
7 changes: 0 additions & 7 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,6 @@ def get_target_transformer(self) -> Pipeline:
identity_transformer = FunctionTransformer()
return Pipeline(steps=[("scaler", identity_transformer)])

@property
def fit_result(self) -> Dataset:
"""Get the posterior data."""
if self.idata is None or "posterior" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior"]

def _get_group_predictive_data(
self,
group: Literal["prior_predictive", "posterior_predictive"],
Expand Down
14 changes: 13 additions & 1 deletion pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,16 @@

return self.idata[value]

return property(accessor)
def setter(self, res: az.InferenceData) -> None:
if self.idata is None:
self.idata = res
elif "posterior" in self.idata:
warnings.warn("Overriding pre-existing fit_result", stacklevel=1)
self.idata.posterior = res
else:
self.idata.posterior = res

Check warning on line 82 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L82

Added line #L82 was not covered by tests
sreekailash marked this conversation as resolved.
Show resolved Hide resolved

return property(accessor, setter)


def create_sample_kwargs(
Expand Down Expand Up @@ -959,6 +968,9 @@
posterior = create_idata_accessor(
"posterior", "The model hasn't been fit yet, call .fit() first"
)
fit_result = create_idata_accessor(
"posterior", "The model hasn't been fit yet, call .fit() first"
)
posterior_predictive = create_idata_accessor(
"posterior_predictive",
"The model hasn't been fit yet, call .sample_posterior_predictive() first",
Expand Down
Loading