Use union of default and user provided configs in ModelBuilder #565

Merged · 1 commit · Mar 4, 2024
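In short, judging from the diff: previously a user-supplied `model_config` or `sampler_config` replaced the class defaults outright, so the example notebooks had to merge their overrides into `default_model_config` by hand. With this change, `ModelBuilder.__init__` takes the union of the defaults and the user-provided dict, with user-provided keys taking precedence. A rough sketch of the merge rule, with illustrative keys rather than the library's actual defaults:

```python
defaults = {"beta_channel": {"sigma": 2}, "likelihood": {"sigma": 5}}
overrides = {"likelihood": {"sigma": 3}}

# PEP 584 dict union (Python 3.9+): the right-hand operand wins on shared keys.
merged = defaults | overrides
print(merged)  # {'beta_channel': {'sigma': 2}, 'likelihood': {'sigma': 3}}
```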
5 changes: 2 additions & 3 deletions docs/source/guide/benefits/model_deployment.ipynb
@@ -218,7 +218,7 @@
"source": [
"We first illustrate the use of `model_config` to define custom priors within the model.\n",
"\n",
"Because there are potentially many variables that can be configured, each model provides a `default_model_config` attribute. This will allow you to see which settings are available by default and easily update only the ones you care about by merging dictionary keys.\n",
"Because there are potentially many variables that can be configured, each model provides a `default_model_config` attribute. This will allow you to see which settings are available by default and only define the ones you need to change.\n",
"\n",
"We need to create a dummy model to be able to see the configuration dictionary."
]
@@ -305,8 +305,7 @@
"metadata": {},
"outputs": [],
"source": [
"custom_beta_channel_prior = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}\n",
"my_model_config = dummy_model.default_model_config | custom_beta_channel_prior"
"my_model_config = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}"
]
},
{
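The notebook change above drops the manual union with `default_model_config`; after this PR the constructor performs that merge itself. A minimal sketch of the resulting pattern (the `prior_sigma` value is a placeholder; the notebook defines its own):

```python
prior_sigma = 10  # illustrative value only

# Pass only the entries you want to override; ModelBuilder merges them with
# default_model_config internally, and user-provided keys take precedence.
my_model_config = {"beta_channel": {"sigma": prior_sigma, "dims": ("channel",)}}

# Inspecting default_model_config on a throwaway instance (the notebook's
# `dummy_model`) still shows every available key and its default value.
```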
47 changes: 37 additions & 10 deletions docs/source/notebooks/mmm/mmm_example.ipynb
@@ -1009,9 +1009,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution. If you want to apply your own priors, you can copy the 'model_config' definition below, modify its content, and pass it to `DelayedSaturatedMMM`.\n",
"You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution.\n",
"\n",
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of `DelayedSaturatedMMM` to see the required structure.\n"
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of `DelayedSaturatedMMM` to see the required structure."
]
},
{
@@ -1074,7 +1074,7 @@
}
],
"source": [
"custom_beta_channel_prior = {'beta_channel': {'dist': 'LogNormal',\n",
"my_model_config = {'beta_channel': {'dist': 'LogNormal',\n",
" \"kwargs\":{\"mu\":np.array([2,1]), \"sigma\": prior_sigma},\n",
" },\n",
" \"likelihood\": {\n",
@@ -1085,9 +1085,7 @@
" # {'sigma': 5}\n",
" }\n",
" }\n",
" }\n",
"my_model_config = {**dummy_model.default_model_config, **custom_beta_channel_prior}\n",
"my_model_config"
" }"
]
},
{
Expand All @@ -1110,7 +1108,7 @@
"metadata": {},
"outputs": [],
"source": [
"sampler_config= {\"progressbar\": True}"
"my_sampler_config= {\"progressbar\": True}"
]
},
{
Expand All @@ -1128,7 +1126,7 @@
"source": [
"mmm = DelayedSaturatedMMM(\n",
" model_config = my_model_config,\n",
" sampler_config = sampler_config,\n",
" sampler_config = my_sampler_config,\n",
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\n",
@@ -8988,7 +8986,7 @@
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -9002,7 +9000,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
@@ -9016,6 +9014,35 @@
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
}
},
"nbformat": 4,
12 changes: 8 additions & 4 deletions pymc_marketing/model_builder.py
@@ -75,11 +75,15 @@ def __init__(
>>> model = MyModel(model_config, sampler_config)
"""
if sampler_config is None:
- sampler_config = self.default_sampler_config
+ sampler_config = {}
if model_config is None:
- model_config = self.default_model_config
- self.sampler_config = sampler_config
- self.model_config = model_config # parameters for priors etc.
+ model_config = {}
+ self.sampler_config = (
+ self.default_sampler_config | sampler_config
+ ) # Parameters for fit sampling
+ self.model_config = (
+ self.default_model_config | model_config
+ ) # parameters for priors etc.
self.model: Optional[pm.Model] = None # Set by build_model
self.idata: Optional[
az.InferenceData
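One property of the updated `__init__` worth noting: `|` performs a shallow merge, so a user-provided key replaces the corresponding default entry wholesale rather than being merged field by field. A small sketch with made-up nested keys, not the library's actual defaults:

```python
default_model_config = {"a": {"loc": 0, "scale": 10}, "obs_error": 2}
user_model_config = {"a": {"loc": 1}}  # overrides the whole "a" entry

merged = default_model_config | user_model_config
assert merged == {"a": {"loc": 1}, "obs_error": 2}  # "scale" is not inherited
```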
4 changes: 2 additions & 2 deletions tests/mmm/test_base.py
@@ -68,11 +68,11 @@ def _generate_and_preprocess_model_data(self, X, y):

@property
def default_model_config(self):
- pass
+ return {}

@property
def default_sampler_config(self):
- pass
+ return {}

@property
def output_var(self):
32 changes: 23 additions & 9 deletions tests/model_builder/test_model_builder.py
@@ -57,7 +57,7 @@ def fitted_model_instance(toy_X, toy_y):
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
- model = test_ModelBuilder(
+ model = ModelBuilderTest(
model_config=model_config,
sampler_config=sampler_config,
test_parameter="test_paramter",
@@ -79,15 +79,15 @@ def not_fitted_model_instance(toy_X, toy_y):
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
- model = test_ModelBuilder(
+ model = ModelBuilderTest(
model_config=model_config,
sampler_config=sampler_config,
test_parameter="test_paramter",
)
return model


- class test_ModelBuilder(ModelBuilder):
+ class ModelBuilderTest(ModelBuilder):
def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
self.test_parameter = test_parameter
super().__init__(model_config=model_config, sampler_config=sampler_config)
@@ -159,14 +159,28 @@ def default_sampler_config(self) -> Dict:
}


+ def test_model_and_sampler_config():
+ default = ModelBuilderTest()
+ assert default.model_config == default.default_model_config
+ assert default.sampler_config == default.default_sampler_config
+
+ nondefault = ModelBuilderTest(
+ model_config={"obs_error": 3}, sampler_config={"draws": 42}
+ )
+ assert nondefault.model_config != nondefault.default_model_config
+ assert nondefault.sampler_config != nondefault.default_sampler_config
+ assert nondefault.model_config == default.model_config | {"obs_error": 3}
+ assert nondefault.sampler_config == default.sampler_config | {"draws": 42}
+
+
def test_save_input_params(fitted_model_instance):
assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"'


def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
- test_builder2 = test_ModelBuilder.load(temp.name)
+ test_builder2 = ModelBuilderTest.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
assert fitted_model_instance.id == test_builder2.id
x_pred = np.random.uniform(low=0, high=1, size=100)
@@ -184,14 +198,14 @@ def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> Model


def test_save_without_fit_raises_runtime_error():
- model_builder = test_ModelBuilder()
+ model_builder = ModelBuilderTest()
with pytest.raises(RuntimeError):
model_builder.save("saved_model")


def test_empty_sampler_config_fit(toy_X, toy_y):
sampler_config = {}
- model_builder = test_ModelBuilder(sampler_config=sampler_config)
+ model_builder = ModelBuilderTest(sampler_config=sampler_config)
model_builder.idata = model_builder.fit(
X=toy_X, y=toy_y, chains=1, draws=100, tune=100
)
@@ -215,7 +229,7 @@ def test_fit(fitted_model_instance):


def test_fit_no_y(toy_X):
- model_builder = test_ModelBuilder()
+ model_builder = ModelBuilderTest()
model_builder.idata = model_builder.fit(X=toy_X, chains=1, draws=100, tune=100)
assert model_builder.model is not None
assert model_builder.idata is not None
@@ -260,14 +274,14 @@ def test_model_config_formatting():
],
},
}
- model_builder = test_ModelBuilder()
+ model_builder = ModelBuilderTest()
converted_model_config = model_builder._model_config_formatting(model_config)
np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))


def test_id():
- model_builder = test_ModelBuilder()
+ model_builder = ModelBuilderTest()
expected_id = hashlib.sha256(
str(model_builder.model_config.values()).encode()
+ model_builder.version.encode()