Skip to content

Commit

Permalink
Update BetaGeoModel API (#709)
Browse files Browse the repository at this point in the history
* _extract_predictive_variables util

* deprecation warnings

* expected_purchases_new_customer

* TestBetaGeoModel.setup_class

* test cleanup and test_expected_purchases_new_customer

* expected_probability_alive

* expected_purchases

* TODOs and docstrings

* update runslow tests rtol for new test dataset

* prob_alive_matrix plot fix and notebook testing

* alive loop in bgnbd nb

* quickstart nb fix

* docstring and TODO revisions

* docstring syntax

* docstring edits

* more docstring fixes

* docstrings indent

* docstring indent, clear codecov bug

---------

Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
2 people authored and twiecki committed Sep 10, 2024
1 parent c74d369 commit 7774e7c
Show file tree
Hide file tree
Showing 31 changed files with 7,463 additions and 5,511 deletions.
453 changes: 226 additions & 227 deletions docs/source/notebooks/clv/bg_nbd.ipynb

Large diffs are not rendered by default.

739 changes: 350 additions & 389 deletions docs/source/notebooks/clv/clv_quickstart.ipynb

Large diffs are not rendered by default.

2,467 changes: 1,120 additions & 1,347 deletions docs/source/notebooks/mmm/mmm_budget_allocation_example.ipynb

Large diffs are not rendered by default.

3,555 changes: 2,022 additions & 1,533 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

789 changes: 403 additions & 386 deletions docs/source/notebooks/mmm/mmm_lift_test.ipynb

Large diffs are not rendered by default.

175 changes: 87 additions & 88 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

Binary file not shown.
436 changes: 265 additions & 171 deletions pymc_marketing/clv/models/beta_geo.py

Large diffs are not rendered by default.

27 changes: 10 additions & 17 deletions pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,34 +331,27 @@ def plot_probability_alive_matrix(
max_frequency=max_frequency,
max_recency=max_recency,
)
transaction_data = pd.DataFrame(
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
# We should harmonize them!
if isinstance(model, ParetoNBDModel):
transaction_data = pd.DataFrame(
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)

Z = (
model.expected_probability_alive(
data=transaction_data,
future_t=0, # TODO: This can be a function parameter in the case of ParetoNBDModel
future_t=0, # TODO: This is a required parameter if data is provided.
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
else:
Z = (
model.expected_probability_alive(
customer_id=np.arange(mesh_recency.size), # placeholder
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency, # type: ignore
)
model.expected_probability_alive(data=transaction_data)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
Expand Down
34 changes: 31 additions & 3 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc_marketing.mmm import base, delayed_saturated_mmm, preprocessing, validating
from pymc_marketing.mmm.base import MMM, BaseMMM
from pymc_marketing.mmm.delayed_saturated_mmm import DelayedSaturatedMMM
from pymc_marketing.mmm.base import (
BaseValidateMMM,
MMMModelBuilder,
)
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
WeibullAdstock,
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
SaturationTransformation,
TanhSaturation,
TanhSaturationBaselined,
)
from pymc_marketing.mmm.delayed_saturated_mmm import MMM, DelayedSaturatedMMM
from pymc_marketing.mmm.preprocessing import (
preprocessing_method_X,
preprocessing_method_y,
Expand All @@ -26,10 +43,21 @@
"preprocessing",
"validating",
"MMM",
"BaseMMM",
"MMMModelBuilder",
"BaseValidateMMM",
"DelayedSaturatedMMM",
"preprocessing_method_X",
"preprocessing_method_y",
"validation_method_X",
"validation_method_y",
"AdstockTransformation",
"DelayedAdstock",
"GeometricAdstock",
"WeibullAdstock",
"SaturationTransformation",
"MichaelisMentenSaturation",
"HillSaturation",
"LogisticSaturation",
"TanhSaturation",
"TanhSaturationBaselined",
]
Loading

0 comments on commit 7774e7c

Please sign in to comment.