From 6c2dfb40cddc0708141c84bcf80c071e688535a3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 1 Mar 2024 17:21:16 +0000 Subject: [PATCH] Use union of default and user provided configs in ModelBuilder --- .../guide/benefits/model_deployment.ipynb | 5 +- docs/source/notebooks/mmm/mmm_example.ipynb | 47 +++++++++++++++---- pymc_marketing/model_builder.py | 12 +++-- tests/mmm/test_base.py | 4 +- tests/model_builder/test_model_builder.py | 32 +++++++++---- 5 files changed, 72 insertions(+), 28 deletions(-) diff --git a/docs/source/guide/benefits/model_deployment.ipynb b/docs/source/guide/benefits/model_deployment.ipynb index b6f6ac007..176539432 100644 --- a/docs/source/guide/benefits/model_deployment.ipynb +++ b/docs/source/guide/benefits/model_deployment.ipynb @@ -218,7 +218,7 @@ "source": [ "We first illustrate the use of `model_config` to define custom priors within the model.\n", "\n", - "Because there are potentially many variables that can be configured, each model provides a `default_model_config` attribute. This will allow you to see which settings are available by default and easily update only the ones you care about by merging dictionary keys.\n", + "Because there are potentially many variables that can be configured, each model provides a `default_model_config` attribute. This will allow you to see which settings are available by default and only define the ones you need to change.\n", "\n", "We need to create a dummy model to be able to see the configuration dictionary." ] @@ -305,8 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "custom_beta_channel_prior = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}\n", - "my_model_config = dummy_model.default_model_config | custom_beta_channel_prior" + "my_model_config = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}" ] }, { diff --git a/docs/source/notebooks/mmm/mmm_example.ipynb b/docs/source/notebooks/mmm/mmm_example.ipynb index b1c5978ab..eec650ba3 100644 --- a/docs/source/notebooks/mmm/mmm_example.ipynb +++ b/docs/source/notebooks/mmm/mmm_example.ipynb @@ -1009,9 +1009,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution. If you want to apply your own priors, you can copy the 'model_config' definition below, modify its content, and pass it to `DelayedSaturatedMMM`.\n", + "You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution.\n", "\n", - "If you're unsure how to define your own priors, you can use the 'default_model_config' property of `DelayedSaturatedMMM` to see the required structure.\n" + "If you're unsure how to define your own priors, you can use the 'default_model_config' property of `DelayedSaturatedMMM` to see the required structure." ] }, { @@ -1074,7 +1074,7 @@ } ], "source": [ - "custom_beta_channel_prior = {'beta_channel': {'dist': 'LogNormal',\n", + "my_model_config = {'beta_channel': {'dist': 'LogNormal',\n", " \"kwargs\":{\"mu\":np.array([2,1]), \"sigma\": prior_sigma},\n", " },\n", " \"likelihood\": {\n", @@ -1085,9 +1085,7 @@ " # {'sigma': 5}\n", " }\n", " }\n", - " }\n", - "my_model_config = {**dummy_model.default_model_config, **custom_beta_channel_prior}\n", - "my_model_config" + " }" ] }, { @@ -1110,7 +1108,7 @@ "metadata": {}, "outputs": [], "source": [ - "sampler_config= {\"progressbar\": True}" + "my_sampler_config= {\"progressbar\": True}" ] }, { @@ -1128,7 +1126,7 @@ "source": [ "mmm = DelayedSaturatedMMM(\n", " model_config = my_model_config,\n", - " sampler_config = sampler_config,\n", + " sampler_config = my_sampler_config,\n", " date_column=\"date_week\",\n", " channel_columns=[\"x1\", \"x2\"],\n", " control_columns=[\n", @@ -8988,7 +8986,7 @@ "metadata": { "hide_input": false, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -9002,7 +9000,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.8.10" }, "toc": { "base_numbering": 1, @@ -9016,6 +9014,35 @@ "toc_position": {}, "toc_section_display": true, "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false } }, "nbformat": 4, diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index d90a175ce..e34765ec7 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -75,11 +75,15 @@ def __init__( >>> model = MyModel(model_config, sampler_config) """ if sampler_config is None: - sampler_config = self.default_sampler_config + sampler_config = {} if model_config is None: - model_config = self.default_model_config - self.sampler_config = sampler_config - self.model_config = model_config # parameters for priors etc. + model_config = {} + self.sampler_config = ( + self.default_sampler_config | sampler_config + ) # Parameters for fit sampling + self.model_config = ( + self.default_model_config | model_config + ) # parameters for priors etc. self.model: Optional[pm.Model] = None # Set by build_model self.idata: Optional[ az.InferenceData diff --git a/tests/mmm/test_base.py b/tests/mmm/test_base.py index c96eda4a4..a244800f0 100644 --- a/tests/mmm/test_base.py +++ b/tests/mmm/test_base.py @@ -68,11 +68,11 @@ def _generate_and_preprocess_model_data(self, X, y): @property def default_model_config(self): - pass + return {} @property def default_sampler_config(self): - pass + return {} @property def output_var(self): diff --git a/tests/model_builder/test_model_builder.py b/tests/model_builder/test_model_builder.py index 531e590c5..2eeff355e 100644 --- a/tests/model_builder/test_model_builder.py +++ b/tests/model_builder/test_model_builder.py @@ -57,7 +57,7 @@ def fitted_model_instance(toy_X, toy_y): "b": {"loc": 0, "scale": 10}, "obs_error": 2, } - model = test_ModelBuilder( + model = ModelBuilderTest( model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter", @@ -79,7 +79,7 @@ def not_fitted_model_instance(toy_X, toy_y): "b": {"loc": 0, "scale": 10}, "obs_error": 2, } - model = test_ModelBuilder( + model = ModelBuilderTest( model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter", @@ -87,7 +87,7 @@ def not_fitted_model_instance(toy_X, toy_y): return model -class test_ModelBuilder(ModelBuilder): +class ModelBuilderTest(ModelBuilder): def __init__(self, model_config=None, sampler_config=None, test_parameter=None): self.test_parameter = test_parameter super().__init__(model_config=model_config, sampler_config=sampler_config) @@ -159,6 +159,20 @@ def default_sampler_config(self) -> Dict: } +def test_model_and_sampler_config(): + default = ModelBuilderTest() + assert default.model_config == default.default_model_config + assert default.sampler_config == default.default_sampler_config + + nondefault = ModelBuilderTest( + model_config={"obs_error": 3}, sampler_config={"draws": 42} + ) + assert nondefault.model_config != nondefault.default_model_config + assert nondefault.sampler_config != nondefault.default_sampler_config + assert nondefault.model_config == default.model_config | {"obs_error": 3} + assert nondefault.sampler_config == default.sampler_config | {"draws": 42} + + def test_save_input_params(fitted_model_instance): assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"' @@ -166,7 +180,7 @@ def test_save_input_params(fitted_model_instance): def test_save_load(fitted_model_instance): temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) fitted_model_instance.save(temp.name) - test_builder2 = test_ModelBuilder.load(temp.name) + test_builder2 = ModelBuilderTest.load(temp.name) assert fitted_model_instance.idata.groups() == test_builder2.idata.groups() assert fitted_model_instance.id == test_builder2.id x_pred = np.random.uniform(low=0, high=1, size=100) @@ -184,14 +198,14 @@ def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> Model def test_save_without_fit_raises_runtime_error(): - model_builder = test_ModelBuilder() + model_builder = ModelBuilderTest() with pytest.raises(RuntimeError): model_builder.save("saved_model") def test_empty_sampler_config_fit(toy_X, toy_y): sampler_config = {} - model_builder = test_ModelBuilder(sampler_config=sampler_config) + model_builder = ModelBuilderTest(sampler_config=sampler_config) model_builder.idata = model_builder.fit( X=toy_X, y=toy_y, chains=1, draws=100, tune=100 ) @@ -215,7 +229,7 @@ def test_fit(fitted_model_instance): def test_fit_no_y(toy_X): - model_builder = test_ModelBuilder() + model_builder = ModelBuilderTest() model_builder.idata = model_builder.fit(X=toy_X, chains=1, draws=100, tune=100) assert model_builder.model is not None assert model_builder.idata is not None @@ -260,14 +274,14 @@ def test_model_config_formatting(): ], }, } - model_builder = test_ModelBuilder() + model_builder = ModelBuilderTest() converted_model_config = model_builder._model_config_formatting(model_config) np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",)) np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0])) def test_id(): - model_builder = test_ModelBuilder() + model_builder = ModelBuilderTest() expected_id = hashlib.sha256( str(model_builder.model_config.values()).encode() + model_builder.version.encode()