Merged
MLPModel #860
Changes from 22 commits
Commits (23):
578c83b  MLPModel
e0bd288  fix docstring
b724f27  add tests for step and forward method
ac9688a  update tests
e61950b  fix test
d59982c  fix lint
d3b042f  update tests
58b761a  update tests
670f081  update tests and fix import
3ef4548  add changelog
3e8210c  fix make_samples method
6aed23c  fix make_samples
e99a7ff  update make_samples test
c068141  fix mistake in sampling of last batch
e830a48  Merge branch 'master' into issue-829
1deb1e9  add checking for decoder_real (martins0n)
81ee25c  add decoder_real
b8f0c79  Merge branch 'issue-829' of https://github.com/tinkoff-ai/etna into i…
4e812b0  fix lint
fb224b5  fix batch in step method
152c25d  fix size in batch
1e2e880  fix assert
92412bf  fix docstrings
@@ -0,0 +1,222 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional

import pandas as pd
from typing_extensions import TypedDict

from etna import SETTINGS

if SETTINGS.torch_required:
    import torch
    import torch.nn as nn

import numpy as np

from etna.models.base import DeepBaseModel
from etna.models.base import DeepBaseNet


class MLPBatch(TypedDict):
    """Batch specification for MLP."""

    decoder_real: "torch.Tensor"
    decoder_target: "torch.Tensor"
    segment: "torch.Tensor"


class MLPNet(DeepBaseNet):
    """MLP model."""

    def __init__(
        self,
        input_size: int,
        hidden_size: List[int],
        lr: float,
        loss: "torch.nn.Module",
        optimizer_params: Optional[dict],
    ) -> None:
        """Init MLP model.

        Parameters
        ----------
        input_size:
            size of the input feature space: target plus extra features
        num_layers:
            number of layers
        hidden_size:
            list of sizes of the hidden states
        lr:
            learning rate
        loss:
            loss function
        optimizer_params:
            parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lr = lr
        self.loss = nn.MSELoss() if loss is None else loss
        self.optimizer_params = {} if optimizer_params is None else optimizer_params
        layers = [nn.Linear(in_features=input_size, out_features=hidden_size[0]), nn.ReLU()]
        for i in range(1, len(hidden_size)):
            layers.append(nn.Linear(in_features=hidden_size[i - 1], out_features=hidden_size[i]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(in_features=hidden_size[-1], out_features=1))
        self.mlp = nn.Sequential(*layers)
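        # Illustration (a sketch, not part of the diff): with input_size=3 and hidden_size=[16, 8],
        # the construction above yields
        #     nn.Sequential(
        #         nn.Linear(3, 16), nn.ReLU(),
        #         nn.Linear(16, 8), nn.ReLU(),
        #         nn.Linear(8, 1),
        #     )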

    def forward(self, batch):
Review comment: we can add typing I guess
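One way to act on this suggestion (a sketch, not part of the diff; it reuses the MLPBatch TypedDict defined above) would be to annotate the signature:

    def forward(self, batch: MLPBatch) -> "torch.Tensor": ...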
"""Forward pass. | ||
|
||
Parameters | ||
---------- | ||
batch: | ||
batch of data | ||
Returns | ||
------- | ||
: | ||
forecast | ||
""" | ||
decoder_real = batch["decoder_real"].float() | ||
return self.mlp(decoder_real) | ||
|
||
def step(self, batch: MLPBatch, *args, **kwargs): # type: ignore | ||
"""Step for loss computation for training or validation. | ||
|
||
Parameters | ||
---------- | ||
batch: | ||
batch of data | ||
Returns | ||
------- | ||
: | ||
loss, true_target, prediction_target | ||
""" | ||
decoder_real = batch["decoder_real"].float() | ||
decoder_target = batch["decoder_target"].float() | ||
|
||
output = self.mlp(decoder_real) | ||
loss = self.loss(output, decoder_target) | ||
return loss, decoder_target, output | ||
|
||
def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterable[dict]: | ||
"""Make samples from segment DataFrame.""" | ||
|
||
def _make(df: pd.DataFrame, start_idx: int, decoder_length: int) -> Optional[dict]: | ||
sample: Dict[str, Any] = {"decoder_real": list(), "decoder_target": list(), "segment": None} | ||
total_length = len(df["target"]) | ||
total_sample_length = decoder_length | ||
|
||
if total_sample_length + start_idx > total_length: | ||
return None | ||
|
||
sample["decoder_real"] = ( | ||
df.select_dtypes(include=[np.number]) | ||
.pipe(lambda x: x[[i for i in x.columns if i != "target"]]) | ||
.values[start_idx : start_idx + decoder_length] | ||
) | ||
|
||
target = df["target"].values[start_idx : start_idx + decoder_length].reshape(-1, 1) | ||
sample["decoder_target"] = target | ||
sample["segment"] = df["segment"].values[0] | ||
return sample | ||
|
||
start_idx = 0 | ||
while True: | ||
batch = _make( | ||
df=df, | ||
start_idx=start_idx, | ||
decoder_length=decoder_length, | ||
) | ||
if batch is None: | ||
break | ||
yield batch | ||
start_idx += decoder_length | ||
if start_idx < len(df): | ||
resid_length = len(df) - decoder_length | ||
batch = _make(df=df, start_idx=resid_length, decoder_length=decoder_length) | ||
if batch is not None: | ||
yield batch | ||
|
||
def configure_optimizers(self): | ||
"""Optimizer configuration.""" | ||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, **self.optimizer_params) | ||
return optimizer | ||
|
||
|
||
class MLPModel(DeepBaseModel): | ||
"""MLPModel.""" | ||
|
||
def __init__( | ||
self, | ||
input_size: int, | ||
decoder_length: int, | ||
hidden_size: List, | ||
encoder_length: int = 0, | ||
lr: float = 1e-3, | ||
loss: Optional["torch.nn.Module"] = None, | ||
train_batch_size: int = 16, | ||
test_batch_size: int = 16, | ||
optimizer_params: Optional[dict] = None, | ||
trainer_params: Optional[dict] = None, | ||
train_dataloader_params: Optional[dict] = None, | ||
test_dataloader_params: Optional[dict] = None, | ||
val_dataloader_params: Optional[dict] = None, | ||
split_params: Optional[dict] = None, | ||
): | ||
super().__init__( | ||
net=MLPNet( | ||
input_size=input_size, | ||
hidden_size=hidden_size, | ||
lr=lr, | ||
loss=loss, # type: ignore | ||
optimizer_params=optimizer_params, | ||
), | ||
encoder_length=encoder_length, | ||
decoder_length=decoder_length, | ||
train_batch_size=train_batch_size, | ||
test_batch_size=test_batch_size, | ||
train_dataloader_params=train_dataloader_params, | ||
test_dataloader_params=test_dataloader_params, | ||
val_dataloader_params=val_dataloader_params, | ||
trainer_params=trainer_params, | ||
split_params=split_params, | ||
) | ||
"""Init MLP model. | ||
Parameters | ||
---------- | ||
input_size: | ||
size of the input feature space: target plus extra features | ||
encoder_length: | ||
Review comment: we should change order (see the note after this file's diff)
            encoder length
        decoder_length:
            decoder length
        hidden_size:
            list of sizes of the hidden states
        lr:
            learning rate
        loss:
            loss function, MSELoss by default
        train_batch_size:
            batch size for training
        test_batch_size:
            batch size for testing
        optimizer_params:
            parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
        trainer_params:
            PyTorch Lightning trainer parameters (api reference :py:class:`pytorch_lightning.trainer.trainer.Trainer`)
        train_dataloader_params:
            parameters for train dataloader like sampler for example (api reference :py:class:`torch.utils.data.DataLoader`)
        test_dataloader_params:
            parameters for test dataloader
        val_dataloader_params:
            parameters for validation dataloader
        split_params:
            dictionary with parameters for :py:func:`torch.utils.data.random_split` for train-test splitting

            * **train_size**: (*float*) value from 0 to 1 - fraction of samples to use for training
            * **generator**: (*Optional[torch.Generator]*) - generator for reproducible train-test splitting
            * **torch_dataset_size**: (*Optional[int]*) - number of samples in dataset, in case of dataset not implementing ``__len__``
        """
@@ -0,0 +1,96 @@
from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
from torch import nn

from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import MLPModel
from etna.models.nn.mlp import MLPNet
from etna.transforms import FourierTransform
from etna.transforms import LagTransform
from etna.transforms import StandardScalerTransform


@pytest.mark.parametrize("horizon", [8, 13])
def test_mlp_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_function_with_horizon, horizon):
    ts_train, ts_test = ts_dataset_weekly_function_with_horizon(horizon)
    lag = LagTransform(in_column="target", lags=list(range(horizon, horizon + 4)))
    fourier = FourierTransform(period=7, order=3)
    std = StandardScalerTransform(in_column="target")
    ts_train.fit_transform([std, lag, fourier])

    decoder_length = 14
    model = MLPModel(
        input_size=10,
        hidden_size=[10, 10, 10, 10, 10],
        lr=1e-1,
        decoder_length=decoder_length,
        trainer_params=dict(max_epochs=100),
    )
    future = ts_train.make_future(decoder_length)
    model.fit(ts_train)
    future = model.forecast(future, horizon=horizon)

    mae = MAE("macro")
    assert mae(ts_test, future) < 0.05


def test_mlp_make_samples(simple_df_relevance):
    mlp_module = MagicMock()
    df, df_exog = simple_df_relevance

    ts = TSDataset(df=df, df_exog=df_exog, freq="D")
    df = ts.to_flatten(ts.df)
    encoder_length = 0
    decoder_length = 5
    ts_samples = list(
        MLPNet.make_samples(
            mlp_module, df=df[df.segment == "1"], encoder_length=encoder_length, decoder_length=decoder_length
        )
    )
    first_sample = ts_samples[0]
    second_sample = ts_samples[1]
    last_sample = ts_samples[-1]
    expected = {
        "decoder_real": np.array([[58.0, 0], [59.0, 0], [60.0, 0], [61.0, 0], [62.0, 0]]),
        "decoder_target": np.array([[27.0], [28.0], [29.0], [30.0], [31.0]]),
        "segment": "1",
    }

    assert first_sample["segment"] == "1"
    assert first_sample["decoder_real"].shape == (decoder_length, 2)
    assert first_sample["decoder_target"].shape == (decoder_length, 1)
    assert len(ts_samples) == 7
    assert np.all(last_sample["decoder_target"] == expected["decoder_target"])
    assert np.all(last_sample["decoder_real"] == expected["decoder_real"])
    assert last_sample["segment"] == expected["segment"]
    np.testing.assert_equal(df[["target"]].iloc[:decoder_length], first_sample["decoder_target"])
    np.testing.assert_equal(df[["target"]].iloc[decoder_length : 2 * decoder_length], second_sample["decoder_target"])


def test_mlp_step():
    batch = {
        "decoder_real": torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]),
        "decoder_target": torch.Tensor([[1], [2], [3]]),
        "segment": "A",
    }
    model = MLPNet(input_size=3, hidden_size=[1], lr=1e-2, loss=nn.MSELoss(), optimizer_params=None)
    loss, decoder_target, output = model.step(batch)
    assert type(loss) == torch.Tensor
    assert type(decoder_target) == torch.Tensor
    assert torch.all(decoder_target == batch["decoder_target"])
    assert type(output) == torch.Tensor
    assert output.shape == torch.Size([3, 1])


def test_mlp_layers():
    model = MLPNet(input_size=3, hidden_size=[10], lr=1e-2, loss=None, optimizer_params=None)
    model_ = nn.Sequential(
        nn.Linear(in_features=3, out_features=10), nn.ReLU(), nn.Linear(in_features=10, out_features=1)
    )
    assert repr(model_) == repr(model.mlp)
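For reference, a small sketch of the windowing that test_mlp_make_samples exercises (this assumes the "1" segment of the simple_df_relevance fixture has 31 rows, which is consistent with the expected target values 27..31 asserted above): make_samples yields non-overlapping windows of decoder_length rows, and if rows remain it adds one final window anchored at the end of the segment, so the last sample may overlap the previous one.

    decoder_length = 5
    n_rows = 31  # assumed segment length, consistent with the test's expected last targets 27..31

    # window starts produced by the while-loop in make_samples: 0, 5, 10, 15, 20, 25
    starts = list(range(0, n_rows - decoder_length + 1, decoder_length))
    # one row remains uncovered, so a final overlapping window starts at n_rows - decoder_length
    if starts[-1] + decoder_length < n_rows:
        starts.append(n_rows - decoder_length)

    print(starts)       # [0, 5, 10, 15, 20, 25, 26]
    print(len(starts))  # 7 samples, matching `assert len(ts_samples) == 7`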
Review comment: There is no such param