Skip to content

add_set_params_and_corresponding_tests #1102

Merged
merged 18 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added

-
- Method `set_params` to change parameters of ETNA objects [#1102](https://github.com/tinkoff-ai/etna/pull/1102)
-
### Changed

Expand Down
44 changes: 44 additions & 0 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Tuple
from typing import cast

import hydra_slayer
from sklearn.base import BaseEstimator


Expand Down Expand Up @@ -89,6 +90,49 @@ def to_dict(self):
params["_target_"] = BaseMixin._get_target_from_class(self)
return params

@staticmethod
def _update_nested_dict_with_flat_dict(params_dict: dict, flat_dict: dict):
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
"""Change flat specification into dict of nested params.
GooseIt marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
**params_dict: dict
dict with nested parameters structure, e.g ``{"model": {"learning_rate": value1}}``
**flat_dict: dict
dict with flat paratemers structure, e.g. ``{"model.depth": value2}``

Returns
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
-------
**nested_dict: dict
dict with nested parameters structure, containing all the nested keys of two given dicts,
e.g. ``{"model": {"depth": value1, "learning_rate": value2}}``
"""
for param, param_value in flat_dict.items():
*param_nesting, param_attr = param.split(".")
cycle_dict = params_dict
for param_nested in param_nesting:
cycle_dict = cycle_dict.setdefault(param_nested, {})
cycle_dict[param_attr] = param_value

def set_params(self, **params: dict) -> "BaseMixin":
"""Return new object instance with modified parameters.

The method works on simple estimators as well as on nested objects
(such as :class:`~etna.pipeline.Pipeline`). The latter have
parameters of the form ``<component>.<parameter>`` so that it's
possible to update each component of a nested object.

Parameters
----------
**params: dict
Estimator parameters.
GooseIt marked this conversation as resolved.
Show resolved Hide resolved

"""
params_dict = self.to_dict()
self._update_nested_dict_with_flat_dict(params_dict, params)
estimator_out = hydra_slayer.get_from_params(**params_dict)
return estimator_out


class StringEnumWithRepr(str, Enum):
"""Base class for str enums, that has alternative __repr__ method."""
Expand Down
159 changes: 159 additions & 0 deletions tests/test_core/test_set_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest

from etna.core import BaseMixin
from etna.models import CatBoostMultiSegmentModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform


def test_base_mixin_set_params_changes_params_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model = catboost_model.set_params(**{"learning_rate": 1e-3, "depth": 8})
expected_dict = {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"iterations": 1000,
"depth": 8,
"learning_rate": 1e-3,
"logging_level": "Silent",
"kwargs": {},
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_changes_params_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline = pipeline.set_params(
**{"model.learning_rate": 1e-3, "model.depth": 8, "transforms": [AddConstTransform("column", 1)]}
)
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 8,
"iterations": 1000,
"kwargs": {},
"learning_rate": 0.001,
"logging_level": "Silent",
},
"transforms": [
{
"_target_": "etna.transforms.math.add_constant.AddConstTransform",
"in_column": "column",
"inplace": True,
"value": 1,
}
],
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_doesnt_change_params_inplace_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model.set_params(**{"learning_rate": 1e-3, "depth": 8})
expected_dict = {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"iterations": 1000,
"depth": 10,
"logging_level": "Silent",
"kwargs": {},
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_doesnt_change_params_inplace_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline.set_params(
**{"model.learning_rate": 1e-3, "model.depth": 8, "transforms": [AddConstTransform("column", 1)]}
)
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 10,
"iterations": 1000,
"kwargs": {},
"logging_level": "Silent",
},
"transforms": (),
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_with_nonexistent_attributes_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
catboost_model.set_params(**{"incorrect_attribute": 1e-3})


def test_base_mixin_set_params_with_nonexistent_not_nested_attribute_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
pipeline.set_params(
GooseIt marked this conversation as resolved.
Show resolved Hide resolved
**{
"incorrect_estimator": "value",
}
)


def test_base_mixin_set_params_with_nonexistent_nested_attribute_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
pipeline.set_params(
**{
"model.incorrect_attribute": "value",
}
)


def test_update_nested_dict_with_flat_dict_empty_flat_dict_returns_nested_dict():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 1e-3}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_empty_nested_dict_no_nesting_in_flat_dict():
nested_dict = {}
flat_dict = {"depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"depth": 8}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_empty_nested_dict_nesting_in_flat_dict():
nested_dict = {}
flat_dict = {"model.depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"model": {"depth": 8}}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_no_nesting_in_flat_dict():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {"depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 1e-3, "depth": 8}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_nesting_in_flat_dict():
nested_dict = {"model": {"learning_rate": 1e-3}}
flat_dict = {"model.depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"model": {"learning_rate": 1e-3, "depth": 8}}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_prioritizes_flat_dict_params():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {"learning_rate": 3e-4}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 3e-4}
assert nested_dict == expected_dict