diff --git a/CHANGELOG.md b/CHANGELOG.md index 9294fb5a3..9ccfbb910 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - +- +- +- +- +- +- +- ### Changed - - @@ -22,12 +29,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Mark some tests as long ([#929](https://github.com/tinkoff-ai/etna/pull/929)) - -### Fixed - - +- +- +- +### Fixed +- Fix to_dict with function as parameter ([#941](https://github.com/tinkoff-ai/etna/pull/941)) +- - Fix native networks to work with generated future equals to horizon ([#936](https://github.com/tinkoff-ai/etna/pull/936)) - Fix `SARIMAXModel` to work with exogenous data on `pmdarima>=2.0` ([#940](https://github.com/tinkoff-ai/etna/pull/940)) - +- +- +- +- +- +- ## [1.12.0] - 2022-09-05 ### Added diff --git a/etna/core/mixins.py b/etna/core/mixins.py index 90166dab0..2ddec9d7a 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -2,6 +2,7 @@ import warnings from enum import Enum from typing import Any +from typing import Callable from typing import Dict from typing import List @@ -37,6 +38,10 @@ def _get_target_from_class(value: Any): return None return str(value.__module__) + "." + str(value.__class__.__name__) + @staticmethod + def _get_target_from_function(value: Callable): + return str(value.__module__) + "." + str(value.__qualname__) + @staticmethod def _parse_value(value: Any) -> Any: if isinstance(value, BaseMixin): @@ -55,6 +60,8 @@ def _parse_value(value: Any) -> Any: return tuple([BaseMixin._parse_value(elem) for elem in value]) elif isinstance(value, Dict): return {key: BaseMixin._parse_value(item) for key, item in value.items()} + elif inspect.isfunction(value): + return {"_target_": BaseMixin._get_target_from_function(value)} else: answer = {} answer["_target_"] = BaseMixin._get_target_from_class(value) diff --git a/tests/test_core/test_to_dict.py b/tests/test_core/test_to_dict.py index 99f930ae8..6831bd5fd 100644 --- a/tests/test_core/test_to_dict.py +++ b/tests/test_core/test_to_dict.py @@ -18,6 +18,7 @@ from etna.pipeline import Pipeline from etna.transforms import AddConstTransform from etna.transforms import ChangePointsTrendTransform +from etna.transforms import DensityOutliersTransform from etna.transforms import LambdaTransform from etna.transforms import LogTransform @@ -53,9 +54,15 @@ def ensemble_samples(): ChangePointsTrendTransform( in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=50 ), + pytest.param( + DensityOutliersTransform("target", distance_coef=6), + marks=pytest.mark.xfail( + reason="partial function after initialization instead of original function, dumps return different results" + ), + ), pytest.param( LambdaTransform(in_column="target", transform_func=lambda x: x - 2, inverse_transform_func=lambda x: x + 2), - marks=pytest.mark.xfail(reason="some bug"), + marks=pytest.mark.xfail(reason="lambdas in class attributes"), ), ], ) @@ -66,6 +73,22 @@ def test_to_dict_transforms(target_object): assert pickle.dumps(transformed_object) == pickle.dumps(target_object) +# fmt: off +@pytest.mark.parametrize( + "target_object, expected", + [ + ( + DensityOutliersTransform("target", distance_coef=6), + {'in_column': 'target', 'window_size': 15, 'distance_coef': 6, 'n_neighbors': 3, 'distance_func': {'_target_': 'etna.analysis.outliers.density_outliers.absolute_difference_distance'}, '_target_': 'etna.transforms.outliers.point_outliers.DensityOutliersTransform'} # noqa: E501 + ) + ], +) +def test_to_dict_transforms_with_expected(target_object, expected): + dict_object = target_object.to_dict() + assert dict_object == expected +# fmt: on + + @pytest.mark.parametrize( "target_model", [