Skip to content

add_set_params_and_corresponding_tests #1102

Merged
merged 18 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added

-
- `set_params` method to change parameters of ETNA objects [#1025](https://github.com/tinkoff-ai/etna/issues/1025)
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
-
### Changed

Expand Down
39 changes: 39 additions & 0 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,45 @@ def to_dict(self):
params["_target_"] = BaseMixin._get_target_from_class(self)
return params

def set_params(self, **params):
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
"""Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects
(such as :class:`~etna.pipeline.Pipeline`). The latter have
parameters of the form ``<component>.<parameter>`` so that it's
possible to update each component of a nested object.

Parameters
----------
**params : dict
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
Estimator parameters.
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
Returns
-------
self : estimator instance
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
Estimator instance.
"""
for parameter, parameter_value in params.items():
# split specification into list of nested params, "model.depth" -> ["model", "depth"]
param_nested = parameter.split(".")
# all nested params except the last specify path to the estimator whose attribute will be set
# in the following cycle, we find this estimator
estimator = self
nesting_correct = True
for param_current_nesting_level in param_nested[:-1]:
try:
estimator = estimator.__getattribute__(param_current_nesting_level)
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
except AttributeError:
nesting_correct = False
break
if nesting_correct:
try:
# if there is no such attribute, the first row will throw AttributeError
estimator.__getattribute__(param_nested[-1])
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
estimator.__setattr__(param_nested[-1], parameter_value)
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
except AttributeError:
pass
return self
GooseIt marked this conversation as resolved.
Show resolved Hide resolved


class StringEnumWithRepr(str, Enum):
"""Base class for str enums, that has alternative __repr__ method."""
Expand Down
82 changes: 82 additions & 0 deletions tests/test_core/test_set_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from etna.models import CatBoostMultiSegmentModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform


def test_base_mixin_set_params_changes_params_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model.set_params(**{"learning_rate": 1e-3, "depth": 8})
expected_dict = {
"iterations": 1000,
"depth": 8,
"learning_rate": 1e-3,
"logging_level": "Silent",
"kwargs": {},
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_changes_params_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline.set_params(**{"model.learning_rate": 1e-3, "model.depth": 8, "transforms": AddConstTransform("column", 1)})
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 8,
"iterations": 1000,
"kwargs": {},
"learning_rate": 0.001,
"logging_level": "Silent",
},
"transforms": {
"_target_": "etna.transforms.math.add_constant.AddConstTransform",
"in_column": "column",
"inplace": True,
"value": 1,
},
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_with_nonexistent_attributes_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model.set_params(**{"incorrect_attribute_1": 1e-3, "incorrect_attribute_2": 8})
expected_dict = {
"iterations": 1000,
"depth": 10,
"logging_level": "Silent",
"kwargs": {},
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_with_nonexistent_attributes_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline.set_params(
**{
"incorrect_estimator": "value",
"model.incorrect_attribute": "value",
"model.incorrect_nesting.depth": "value",
}
)
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 10,
"iterations": 1000,
"kwargs": {},
"logging_level": "Silent",
},
"transforms": (),
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict