Skip to content

Commit

Permalink
Add inverse transformation into predict method of pipelines (#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Jul 18, 2023
1 parent 7e54706 commit 41440d8
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- `mrmr` feature selection working with categoricals ([#1311](https://github.com/tinkoff-ai/etna/pull/1311))
- Fix version of `statsforecast` to 1.4 to avoid dependency conflicts during installation ([#1313](https://github.com/tinkoff-ai/etna/pull/1313))
- Add inverse transformation into `predict` method of pipelines ([#1314](https://github.com/tinkoff-ai/etna/pull/1314))

### Removed
- Building docker images with cuda 10.2 ([#1306](https://github.com/tinkoff-ai/etna/pull/1306))
Expand Down
2 changes: 2 additions & 0 deletions etna/pipeline/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def _predict(
)
else:
raise NotImplementedError(f"Unknown model type: {self.model.__class__.__name__}!")

results.inverse_transform(self.transforms)
return results


Expand Down
19 changes: 18 additions & 1 deletion tests/test_pipeline/test_autoregressive_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_backtest_forecasts_sanity(step_ts: TSDataset):
(ProphetModel(), []),
],
)
def test_predict(model, transforms, example_tsds):
def test_predict_format(model, transforms, example_tsds):
ts = example_tsds
pipeline = AutoRegressivePipeline(model=model, transforms=transforms, horizon=7)
pipeline.fit(ts)
Expand All @@ -282,6 +282,23 @@ def test_predict(model, transforms, example_tsds):
assert len(result_df) == len(example_tsds.segments) * num_points


def test_predict_values(example_tsds):
original_ts = deepcopy(example_tsds)

model = LinearPerSegmentModel()
transforms = [AddConstTransform(in_column="target", value=10, inplace=True), DateFlagsTransform()]
pipeline = AutoRegressivePipeline(model=model, transforms=transforms, horizon=5)
pipeline.fit(example_tsds)
predictions_pipeline = pipeline.predict(ts=original_ts)

original_ts.fit_transform(transforms)
model.fit(original_ts)
predictions_manual = model.predict(original_ts)
predictions_manual.inverse_transform(transforms)

pd.testing.assert_frame_equal(predictions_pipeline.to_pandas(), predictions_manual.to_pandas())


@pytest.mark.parametrize("load_ts", [True, False])
@pytest.mark.parametrize(
"model, transforms",
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pipeline/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def test_predict_mixin_predict_create_ts_called(start_timestamp, end_timestamp,
mixin._create_ts.assert_called_once_with(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp)


@pytest.mark.parametrize(
"start_timestamp, end_timestamp",
[
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-01")),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02")),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-10")),
(pd.Timestamp("2020-01-05"), pd.Timestamp("2020-01-10")),
],
)
def test_predict_mixin_predict_inverse_transform_called(start_timestamp, end_timestamp, example_tsds):
ts = MagicMock()
mixin = make_mixin()

result = mixin._predict(
ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp, prediction_interval=False, quantiles=[]
)

result.inverse_transform.assert_called_once_with(mixin.transforms)


@pytest.mark.parametrize(
"start_timestamp, end_timestamp",
[
Expand Down
25 changes: 21 additions & 4 deletions tests/test_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def test_forecast_with_intervals_other_model(base_forecast, model_class):
)


def test_forecast(example_tsds):
"""Test that the forecast from the Pipeline is correct."""
def test_forecast_values(example_tsds):
"""Test that the forecast from the Pipeline generates correct values."""
original_ts = deepcopy(example_tsds)

model = LinearPerSegmentModel()
Expand All @@ -221,7 +221,7 @@ def test_forecast(example_tsds):
forecast_manual = model.forecast(future)
forecast_manual.inverse_transform(transforms)

assert np.all(forecast_pipeline.df.values == forecast_manual.df.values)
pd.testing.assert_frame_equal(forecast_pipeline.to_pandas(), forecast_manual.to_pandas())


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def test_pipeline_with_deepmodel(example_tsds):
(ProphetModel(), []),
],
)
def test_predict(model, transforms, example_tsds):
def test_predict_format(model, transforms, example_tsds):
ts = example_tsds
pipeline = Pipeline(model=model, transforms=transforms, horizon=7)
pipeline.fit(ts)
Expand All @@ -1143,6 +1143,23 @@ def test_predict(model, transforms, example_tsds):
assert len(result_df) == len(example_tsds.segments) * num_points


def test_predict_values(example_tsds):
original_ts = deepcopy(example_tsds)

model = LinearPerSegmentModel()
transforms = [AddConstTransform(in_column="target", value=10, inplace=True), DateFlagsTransform()]
pipeline = Pipeline(model=model, transforms=transforms, horizon=5)
pipeline.fit(example_tsds)
predictions_pipeline = pipeline.predict(ts=original_ts)

original_ts.fit_transform(transforms)
model.fit(original_ts)
predictions_manual = model.predict(original_ts)
predictions_manual.inverse_transform(transforms)

pd.testing.assert_frame_equal(predictions_pipeline.to_pandas(), predictions_manual.to_pandas())


@pytest.mark.parametrize("load_ts", [True, False])
@pytest.mark.parametrize(
"model, transforms",
Expand Down

1 comment on commit 41440d8

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.