From 854fce1ca52e7b3a5b717e5840cbbe4688f008b4 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 6 Apr 2023 14:44:43 +0300 Subject: [PATCH 1/2] feature: add params_to_tune for DeepARModel and TFTModel --- etna/models/nn/deepar.py | 24 ++++++++++++++++++++++-- etna/models/nn/tft.py | 25 +++++++++++++++++++++++-- tests/test_models/nn/test_deepar.py | 13 +++++++++++++ tests/test_models/nn/test_tft.py | 13 +++++++++++++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 15edc9fae..d80eb25c4 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -23,6 +23,12 @@ from pytorch_forecasting.models import DeepAR from pytorch_lightning import LightningModule +if SETTINGS.auto_required: + from optuna.distributions import BaseDistribution + from optuna.distributions import IntUniformDistribution + from optuna.distributions import LogUniformDistribution + from optuna.distributions import UniformDistribution + class DeepARModel(_DeepCopyMixin, PytorchForecastingMixin, SaveNNMixin, PredictionIntervalContextRequiredAbstractModel): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. @@ -240,8 +246,7 @@ def predict( def get_model(self) -> Any: """Get internal model that is used inside etna class. - Internal model is a model that is used inside etna to forecast segments, - e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. + Model is the instance of :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Returns ------- @@ -249,3 +254,18 @@ def get_model(self) -> Any: Internal model """ return self.model + + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + Returns + ------- + : + Grid to tune. + """ + return { + "hidden_size": IntUniformDistribution(low=4, high=64, step=4), + "rnn_layers": IntUniformDistribution(low=1, high=3, step=1), + "dropout": UniformDistribution(low=0, high=0.5), + "lr": LogUniformDistribution(low=1e-5, high=1e-2), + } diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index ce27d448e..9945e8c92 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -23,6 +23,12 @@ from pytorch_forecasting.models import TemporalFusionTransformer from pytorch_lightning import LightningModule +if SETTINGS.auto_required: + from optuna.distributions import BaseDistribution + from optuna.distributions import IntUniformDistribution + from optuna.distributions import LogUniformDistribution + from optuna.distributions import UniformDistribution + class TFTModel(_DeepCopyMixin, PytorchForecastingMixin, SaveNNMixin, PredictionIntervalContextRequiredAbstractModel): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. @@ -269,8 +275,7 @@ def predict( def get_model(self) -> Any: """Get internal model that is used inside etna class. - Internal model is a model that is used inside etna to forecast segments, - e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. + Model is the instance of :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Returns ------- @@ -278,3 +283,19 @@ def get_model(self) -> Any: Internal model """ return self.model + + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + Returns + ------- + : + Grid to tune. + """ + return { + "hidden_size": IntUniformDistribution(low=4, high=64, step=4), + "lstm_layers": IntUniformDistribution(low=1, high=3, step=1), + "dropout": UniformDistribution(low=0, high=0.5), + "attention_head_size": IntUniformDistribution(low=2, high=8, step=2), + "lr": LogUniformDistribution(low=1e-5, high=1e-2), + } diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index 7ca8750c7..e91a7a860 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -2,6 +2,7 @@ import pandas as pd import pytest +from optuna.samplers import RandomSampler from pytorch_forecasting.data import GroupNormalizer from etna.datasets.tsdataset import TSDataset @@ -191,3 +192,15 @@ def test_repr(): def test_deepar_forecast_throw_error_on_return_components(): with pytest.raises(NotImplementedError, match="This mode isn't currently implemented!"): DeepARModel.forecast(self=Mock(), ts=Mock(), prediction_size=Mock(), return_components=True) + + +def test_params_to_tune(): + model = DeepARModel(decoder_length=3, encoder_length=4) + grid = model.params_to_tune() + # we need sampler to get a value from distribution + sampler = RandomSampler() + + assert len(grid) > 0 + for name, distribution in grid.items(): + value = sampler.sample_independent(study=None, trial=None, param_name=name, param_distribution=distribution) + _ = model.set_params(**{name: value}) diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index 7f0d7b737..4bcdc29aa 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -2,6 +2,7 @@ import pandas as pd import pytest +from optuna.samplers import RandomSampler from etna.metrics import MAE from etna.models.nn import TFTModel @@ -196,3 +197,15 @@ def test_repr(): def test_tft_forecast_throw_error_on_return_components(): with pytest.raises(NotImplementedError, match="This mode isn't currently implemented!"): TFTModel.forecast(self=Mock(), ts=Mock(), prediction_size=Mock(), return_components=True) + + +def test_params_to_tune(): + model = TFTModel(decoder_length=3, encoder_length=4) + grid = model.params_to_tune() + # we need sampler to get a value from distribution + sampler = RandomSampler() + + assert len(grid) > 0 + for name, distribution in grid.items(): + value = sampler.sample_independent(study=None, trial=None, param_name=name, param_distribution=distribution) + _ = model.set_params(**{name: value}) From 5ca505f8ce1e6e98028eb9d35d40aa5c276153bc Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 6 Apr 2023 14:48:07 +0300 Subject: [PATCH 2/2] chore: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52f54e345..d5e113006 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add default `params_to_tune` for `SARIMAXModel`, change default parameters for the model ([#1206](https://github.com/tinkoff-ai/etna/pull/1206)) - Add default `params_to_tune` for linear models ([#1204](https://github.com/tinkoff-ai/etna/pull/1204)) - Add default `params_to_tune` for `SeasonalMovingAverageModel`, `MovingAverageModel`, `NaiveModel` and `DeadlineMovingAverageModel` ([#1208](https://github.com/tinkoff-ai/etna/pull/1208)) +- Add default `params_to_tune` for `DeepARModel` and `TFTModel` ([#1210](https://github.com/tinkoff-ai/etna/pull/1210)) ### Fixed - Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) - `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))