FIX: to_dict with nn models (#949)
martins0n authored Sep 23, 2022
1 parent eea7498 commit 76b8081
Showing 9 changed files with 119 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -66,7 +66,7 @@ jobs:
       - name: PyTest ("not long")
         run: |
           poetry run pytest tests -v --cov=etna -m "not long_1 and not long_2" --cov-report=xml --durations=10
-          poetry run pytest etna -v --doctest-modules --durations=10
+          poetry run pytest etna -v --doctest-modules --ignore=etna/libs --durations=10
       - name: Upload coverage
         uses: codecov/codecov-action@v2
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 -
 -
 - Mark some tests as long ([#929](https://github.com/tinkoff-ai/etna/pull/929))
--
+- Fix to_dict with nn models and add unsafe conversion for callbacks ([#949](https://github.com/tinkoff-ai/etna/pull/949))
 -
 -
 -
2 changes: 2 additions & 0 deletions etna/core/mixins.py
@@ -52,6 +52,8 @@ def _parse_value(value: Any) -> Any:
             model_parameters = value.get_params()
             answer.update(model_parameters)
             return answer
+        elif hasattr(value, "_init_params"):
+            return {"_target_": BaseMixin._get_target_from_class(value), **value._init_params}
         elif isinstance(value, (str, float, int)):
             return value
         elif isinstance(value, List):
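For context, a minimal sketch of what the new `_init_params` branch produces; the `_DummyCallback` class here is hypothetical and only stands in for an object whose `__init__` was wrapped by the init collector introduced in etna/core/utils.py below:

# Hypothetical class carrying _init_params, as the decorated callback classes do.
class _DummyCallback:
    def __init__(self, monitor: str, patience: int):
        self._init_params = {"monitor": monitor, "patience": patience}

value = _DummyCallback(monitor="val_loss", patience=3)
# For such a value, the new elif branch returns the target path of the class
# plus the recorded constructor arguments, roughly:
# {"_target_": "<module>._DummyCallback", "monitor": "val_loss", "patience": 3}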
42 changes: 42 additions & 0 deletions etna/core/utils.py
@@ -0,0 +1,42 @@
import inspect
from copy import deepcopy
from functools import wraps
from typing import Callable


def init_collector(init: Callable) -> Callable:
    """
    Make decorator for collecting init parameters.
    N.B. if init method has positional only parameters, they will be ignored.
    """

    @wraps(init)
    def wrapper(*args, **kwargs):
        self, *args = args
        init_args = inspect.signature(self.__init__).parameters

        deepcopy_args = deepcopy(args)
        deepcopy_kwargs = deepcopy(kwargs)

        self._init_params = {}
        args_dict = dict(
            zip([arg for arg, param in init_args.items() if param.kind == param.POSITIONAL_OR_KEYWORD], deepcopy_args)
        )
        self._init_params.update(args_dict)
        self._init_params.update(deepcopy_kwargs)

        return init(self, *args, **kwargs)

    return wrapper


def create_type_with_init_collector(type_: type) -> type:
    """Create type with init decorated with init_collector."""
    previous_frame = inspect.stack()[1]
    module = inspect.getmodule(previous_frame[0])
    if module is None:
        return type_
    new_type = type(type_.__name__, (type_,), {"__module__": module.__name__})
    if hasattr(type_, "__init__"):
        new_type.__init__ = init_collector(new_type.__init__)  # type: ignore
    return new_type
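A rough usage sketch of the two helpers above; the `Window` class is made up for illustration and is not part of this diff:

from etna.core.utils import create_type_with_init_collector

# Wrap a class so that constructor arguments are recorded on the instance.
class Window:
    def __init__(self, size: int = 10, stride: int = 1):
        self.size = size
        self.stride = stride

TrackedWindow = create_type_with_init_collector(Window)
w = TrackedWindow(5, stride=2)
# Positional arguments are matched to POSITIONAL_OR_KEYWORD parameter names,
# keyword arguments are deep-copied as passed:
print(w._init_params)  # expected: {'size': 5, 'stride': 2}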
Empty file.
18 changes: 18 additions & 0 deletions etna/libs/pytorch_lightning/callbacks.py
@@ -0,0 +1,18 @@

from copy import deepcopy

from etna.core.utils import create_type_with_init_collector


from pytorch_lightning.callbacks import __all__ as pl_callbacks

generated_types = []

for type_name in pl_callbacks:

    type_ = deepcopy(getattr(__import__('pytorch_lightning.callbacks', fromlist=[type_name]), type_name))
    new_type = create_type_with_init_collector(type_)
    globals()[type_name] = new_type
    generated_types.append(new_type)

__all__ = generated_types
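The practical effect (a sketch, assuming pytorch_lightning is installed): callbacks imported from `etna.libs.pytorch_lightning.callbacks` behave like the originals but additionally remember their constructor arguments, which is what lets `to_dict` serialize `trainer_params` that contain callbacks:

from etna.libs.pytorch_lightning.callbacks import EarlyStopping

callback = EarlyStopping(monitor="val_loss", patience=3)
# The generated class subclasses the original pytorch_lightning callback,
# so it can still be passed to a Trainer, but it also carries:
print(callback._init_params)  # expected: {'monitor': 'val_loss', 'patience': 3}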
43 changes: 24 additions & 19 deletions etna/models/nn/mlp.py
@@ -57,7 +57,7 @@ def __init__(
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.lr = lr
-        self.loss = nn.MSELoss() if loss is None else loss
+        self.loss = loss
         self.optimizer_params = {} if optimizer_params is None else optimizer_params
         layers = [nn.Linear(in_features=input_size, out_features=hidden_size[0]), nn.ReLU()]
         for i in range(1, len(hidden_size)):
@@ -165,24 +165,6 @@ def __init__(
         val_dataloader_params: Optional[dict] = None,
         split_params: Optional[dict] = None,
     ):
-        super().__init__(
-            net=MLPNet(
-                input_size=input_size,
-                hidden_size=hidden_size,
-                lr=lr,
-                loss=loss, # type: ignore
-                optimizer_params=optimizer_params,
-            ),
-            encoder_length=encoder_length,
-            decoder_length=decoder_length,
-            train_batch_size=train_batch_size,
-            test_batch_size=test_batch_size,
-            train_dataloader_params=train_dataloader_params,
-            test_dataloader_params=test_dataloader_params,
-            val_dataloader_params=val_dataloader_params,
-            trainer_params=trainer_params,
-            split_params=split_params,
-        )
         """Init MLP model.
         Parameters
         ----------
@@ -218,3 +200,26 @@ def __init__(
             * **generator**: (*Optional[torch.Generator]*) - generator for reproducibile train-test splitting
             * **torch_dataset_size**: (*Optional[int]*) - number of samples in dataset, in case of dataset not implementing ``__len__``
         """
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.lr = lr
+        self.loss = loss
+        self.optimizer_params = optimizer_params
+        super().__init__(
+            net=MLPNet(
+                input_size=input_size,
+                hidden_size=hidden_size,
+                lr=lr,
+                loss=nn.MSELoss() if loss is None else loss,
+                optimizer_params=optimizer_params,
+            ),
+            encoder_length=encoder_length,
+            decoder_length=decoder_length,
+            train_batch_size=train_batch_size,
+            test_batch_size=test_batch_size,
+            train_dataloader_params=train_dataloader_params,
+            test_dataloader_params=test_dataloader_params,
+            val_dataloader_params=val_dataloader_params,
+            trainer_params=trainer_params,
+            split_params=split_params,
+        )
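The fix follows one pattern in both `MLPModel` above and `RNNModel` in the next file: the raw constructor arguments are stored on the model (`self.loss = loss`, possibly `None`) and the `nn.MSELoss()` fallback is applied only when the inner net is constructed, so `to_dict` reports what the user actually passed. A hedged round-trip sketch, assuming etna and torch are installed; the argument values are illustrative:

from etna.models.nn import MLPModel

model = MLPModel(input_size=1, decoder_length=1, hidden_size=[64, 64], lr=0.01)
config = model.to_dict()
# to_dict now reflects the user-facing arguments instead of an already
# constructed nn.MSELoss() object, e.g.:
# config["hidden_size"] == [64, 64]
# config["_target_"] == "etna.models.nn.mlp.MLPModel"
# (compare with the expected dictionary in the new test case below)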
44 changes: 25 additions & 19 deletions etna/models/nn/rnn.py
@@ -213,25 +213,6 @@ def __init__(
         val_dataloader_params: Optional[dict] = None,
         split_params: Optional[dict] = None,
     ):
-        super().__init__(
-            net=RNNNet(
-                input_size=input_size,
-                num_layers=num_layers,
-                hidden_size=hidden_size,
-                lr=lr,
-                loss=nn.MSELoss() if loss is None else loss,
-                optimizer_params=optimizer_params,
-            ),
-            decoder_length=decoder_length,
-            encoder_length=encoder_length,
-            train_batch_size=train_batch_size,
-            test_batch_size=test_batch_size,
-            train_dataloader_params=train_dataloader_params,
-            test_dataloader_params=test_dataloader_params,
-            val_dataloader_params=val_dataloader_params,
-            trainer_params=trainer_params,
-            split_params=split_params,
-        )
         """Init RNN model based on LSTM cell.
         Parameters
@@ -272,3 +253,28 @@ def __init__(
             * **torch_dataset_size**: (*Optional[int]*) - number of samples in dataset, in case of dataset not implementing ``__len__``
         """
+        self.input_size = input_size
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.lr = lr
+        self.loss = loss
+        self.optimizer_params = optimizer_params
+        super().__init__(
+            net=RNNNet(
+                input_size=input_size,
+                num_layers=num_layers,
+                hidden_size=hidden_size,
+                lr=lr,
+                loss=nn.MSELoss() if loss is None else loss,
+                optimizer_params=optimizer_params,
+            ),
+            decoder_length=decoder_length,
+            encoder_length=encoder_length,
+            train_batch_size=train_batch_size,
+            test_batch_size=test_batch_size,
+            train_dataloader_params=train_dataloader_params,
+            test_dataloader_params=test_dataloader_params,
+            val_dataloader_params=val_dataloader_params,
+            trainer_params=trainer_params,
+            split_params=split_params,
+        )
6 changes: 6 additions & 0 deletions tests/test_core/test_to_dict.py
@@ -9,12 +9,14 @@
 from etna.core import BaseMixin
 from etna.ensembles import StackingEnsemble
 from etna.ensembles import VotingEnsemble
+from etna.libs.pytorch_lightning.callbacks import EarlyStopping
 from etna.metrics import MAE
 from etna.metrics import SMAPE
 from etna.models import AutoARIMAModel
 from etna.models import CatBoostModelPerSegment
 from etna.models import LinearPerSegmentModel
 from etna.models.nn import DeepARModel
+from etna.models.nn import MLPModel
 from etna.pipeline import Pipeline
 from etna.transforms import AddConstTransform
 from etna.transforms import ChangePointsTrendTransform
@@ -80,6 +82,10 @@ def test_to_dict_transforms(target_object):
         (
             DensityOutliersTransform("target", distance_coef=6),
             {'in_column': 'target', 'window_size': 15, 'distance_coef': 6, 'n_neighbors': 3, 'distance_func': {'_target_': 'etna.analysis.outliers.density_outliers.absolute_difference_distance'}, '_target_': 'etna.transforms.outliers.point_outliers.DensityOutliersTransform'} # noqa: E501
         ),
+        (
+            MLPModel(decoder_length=1, hidden_size=[64, 64], input_size=1, trainer_params={"max_epochs": 100, "callbacks": [EarlyStopping(monitor="val_loss", patience=3)]}, lr=0.01, train_batch_size=32, split_params=dict(train_size=0.75)), # noqa: E501
+            {'input_size': 1, 'decoder_length': 1, 'hidden_size': [64, 64], 'encoder_length': 0, 'lr': 0.01, 'train_batch_size': 32, 'test_batch_size': 16, 'trainer_params': {'max_epochs': 100, 'callbacks': [{'monitor': 'val_loss', 'patience': 3, '_target_': 'etna.libs.pytorch_lightning.callbacks.EarlyStopping'}]}, 'train_dataloader_params': {}, 'test_dataloader_params': {}, 'val_dataloader_params': {}, 'split_params': {'train_size': 0.75}, '_target_': 'etna.models.nn.mlp.MLPModel'} # noqa: E501
+        )
     ],
 )
