Skip to content

Commit

Permalink
Merge branch 'master' into add_auto_base_and_auto_abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
GooseIt authored Feb 22, 2023
2 parents 37f2106 + cdf1ee3 commit 881f38e
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added

-
- Method `set_params` to change parameters of ETNA objects [#1102](https://github.com/tinkoff-ai/etna/pull/1102)
-
### Changed

Expand Down
46 changes: 46 additions & 0 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Tuple
from typing import cast

import hydra_slayer
from sklearn.base import BaseEstimator


Expand Down Expand Up @@ -89,6 +90,51 @@ def to_dict(self):
params["_target_"] = BaseMixin._get_target_from_class(self)
return params

@staticmethod
def _update_nested_dict_with_flat_dict(params_dict: dict, flat_dict: dict):
"""Update nested dict with flat dict.
The method updates ``params_dict`` with values from ``flat_dict``,
so that ``params_dict`` contains all the nested keys of two given dicts,
e.g. for ``params_dict = {"model": {"learning_rate": value1}}``
and ``flat_dict = {"model.depth": value2}``
resulting ``params_dict`` will be
``{"model": {"depth": value1, "learning_rate": value2}}``
Parameters
----------
**params_dict: dict
dict with nested parameters structure, e.g ``{"model": {"learning_rate": value1}}``
**flat_dict: dict
dict with flat paratemers structure, e.g. ``{"model.depth": value2}``
"""
for param, param_value in flat_dict.items():
*param_nesting, param_attr = param.split(".")
cycle_dict = params_dict
for param_nested in param_nesting:
cycle_dict = cycle_dict.setdefault(param_nested, {})
cycle_dict[param_attr] = param_value

def set_params(self, **params: dict) -> "BaseMixin":
"""Return new object instance with modified parameters.
The method works on simple estimators as well as on nested objects
(such as :class:`~etna.pipeline.Pipeline`). The latter have
parameters of the form ``<component>.<parameter>`` so that it's
possible to update each component of a nested object.
Parameters
----------
**params: dict
Estimator parameters.
"""
params_dict = self.to_dict()
self._update_nested_dict_with_flat_dict(params_dict, params)
estimator_out = hydra_slayer.get_from_params(**params_dict)
return estimator_out


class StringEnumWithRepr(str, Enum):
"""Base class for str enums, that has alternative __repr__ method."""
Expand Down
159 changes: 159 additions & 0 deletions tests/test_core/test_set_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest

from etna.core import BaseMixin
from etna.models import CatBoostMultiSegmentModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform


def test_base_mixin_set_params_changes_params_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model = catboost_model.set_params(**{"learning_rate": 1e-3, "depth": 8})
expected_dict = {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"iterations": 1000,
"depth": 8,
"learning_rate": 1e-3,
"logging_level": "Silent",
"kwargs": {},
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_changes_params_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline = pipeline.set_params(
**{"model.learning_rate": 1e-3, "model.depth": 8, "transforms": [AddConstTransform("column", 1)]}
)
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 8,
"iterations": 1000,
"kwargs": {},
"learning_rate": 0.001,
"logging_level": "Silent",
},
"transforms": [
{
"_target_": "etna.transforms.math.add_constant.AddConstTransform",
"in_column": "column",
"inplace": True,
"value": 1,
}
],
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_doesnt_change_params_inplace_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
catboost_model.set_params(**{"learning_rate": 1e-3, "depth": 8})
expected_dict = {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"iterations": 1000,
"depth": 10,
"logging_level": "Silent",
"kwargs": {},
}
obtained_dict = catboost_model.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_doesnt_change_params_inplace_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
pipeline.set_params(
**{"model.learning_rate": 1e-3, "model.depth": 8, "transforms": [AddConstTransform("column", 1)]}
)
expected_dict = {
"_target_": "etna.pipeline.pipeline.Pipeline",
"horizon": 5,
"model": {
"_target_": "etna.models.catboost.CatBoostMultiSegmentModel",
"depth": 10,
"iterations": 1000,
"kwargs": {},
"logging_level": "Silent",
},
"transforms": (),
}
obtained_dict = pipeline.to_dict()
assert obtained_dict == expected_dict


def test_base_mixin_set_params_with_nonexistent_attributes_estimator():
catboost_model = CatBoostMultiSegmentModel(iterations=1000, depth=10)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
catboost_model.set_params(**{"incorrect_attribute": 1e-3})


def test_base_mixin_set_params_with_nonexistent_not_nested_attribute_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
pipeline.set_params(
**{
"incorrect_estimator": "value",
}
)


def test_base_mixin_set_params_with_nonexistent_nested_attribute_pipeline():
pipeline = Pipeline(model=CatBoostMultiSegmentModel(iterations=1000, depth=10), transforms=(), horizon=5)
with pytest.raises(TypeError, match=".*got an unexpected keyword argument.*"):
pipeline.set_params(
**{
"model.incorrect_attribute": "value",
}
)


def test_update_nested_dict_with_flat_dict_empty_flat_dict_returns_nested_dict():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 1e-3}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_empty_nested_dict_no_nesting_in_flat_dict():
nested_dict = {}
flat_dict = {"depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"depth": 8}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_empty_nested_dict_nesting_in_flat_dict():
nested_dict = {}
flat_dict = {"model.depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"model": {"depth": 8}}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_no_nesting_in_flat_dict():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {"depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 1e-3, "depth": 8}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_nesting_in_flat_dict():
nested_dict = {"model": {"learning_rate": 1e-3}}
flat_dict = {"model.depth": 8}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"model": {"learning_rate": 1e-3, "depth": 8}}
assert nested_dict == expected_dict


def test_update_nested_dict_with_flat_dict_prioritizes_flat_dict_params():
nested_dict = {"learning_rate": 1e-3}
flat_dict = {"learning_rate": 3e-4}
BaseMixin._update_nested_dict_with_flat_dict(nested_dict, flat_dict)
expected_dict = {"learning_rate": 3e-4}
assert nested_dict == expected_dict

0 comments on commit 881f38e

Please sign in to comment.