Skip to content

Implement forecast decomposition for Holt-like models #1162

Merged
merged 9 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ChangePointsLevelTransform` and base classes `PerIntervalModel`, `BaseChangePointsModelAdapter` for per-interval transforms ([#998](https://github.com/tinkoff-ai/etna/pull/998))
- Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102))
- Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129))
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` [#1125](https://github.com/tinkoff-ai/etna/issues/1125)
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` [#1135](https://github.com/tinkoff-ai/etna/issues/1135)
-
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1125](https://github.com/tinkoff-ai/etna/issues/1125))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1135](https://github.com/tinkoff-ai/etna/issues/1135))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter ` ([#1146](https://github.com/tinkoff-ai/etna/issues/1146))
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
- Signature of the constructor of `TFTModel`, `DeepARModel` ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mypy-check:
mypy

spell-check:
codespell etna/ *.md tests/ -L mape,hist
codespell etna/ *.md tests/ -L mape,hist,lamda
python -m scripts.notebook_codespell

imported-deps-check:
Expand Down
144 changes: 144 additions & 0 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from scipy.special import inv_boxcox
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.holtwinters.results import HoltWintersResultsWrapper

Expand Down Expand Up @@ -277,6 +278,135 @@ def get_model(self) -> HoltWintersResultsWrapper:
"""
return self._result

def _check_mul_components(self):
"""Raise error if model has multiplicative components."""
model = self._model

if model is None:
raise ValueError("This model is not fitted!")

if (model.trend is not None and model.trend == "mul") or (
model.seasonal is not None and model.seasonal == "mul"
):
raise ValueError("Forecast decomposition is only supported for additive components!")

def _rescale_components(self, components: pd.DataFrame) -> pd.DataFrame:
"""Rescale components when Box-Cox transform used."""
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
if self._result is None:
raise ValueError("This model is not fitted!")

pred = np.sum(components.values, axis=1)
transformed_pred = inv_boxcox(pred, self._result.params["lamda"])
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
components *= (transformed_pred / pred).reshape((-1, 1))
return components

def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
"""Estimate forecast components.

Parameters
----------
df:
features dataframe

Returns
-------
:
dataframe with forecast components
"""
model = self._model
fit_result = self._result

if fit_result is None or model is None:
raise ValueError("This model is not fitted!")

self._check_mul_components()
self._check_df(df)

level = fit_result.level.values
trend = fit_result.trend.values
season = fit_result.season.values

horizon = df["timestamp"].nunique()
horizon_steps = np.arange(1, horizon + 1)

components = {"target_component_level": level[-1] * np.ones(horizon)}

if model.trend is not None:
t = horizon_steps.copy()

if model.damped_trend:
t = np.cumsum(fit_result.params["damping_trend"] ** t)
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

components["target_component_trend"] = trend[-1] * t

if model.seasonal is not None:
last_period = len(season)

seasonal_periods = fit_result.model.seasonal_periods
k = horizon_steps // seasonal_periods

components["target_component_seasonality"] = season[
last_period + horizon_steps - seasonal_periods * (k + 1) - 1
]

components_df = pd.DataFrame(data=components)

if model._use_boxcox:
components_df = self._rescale_components(components=components_df)

return components_df

def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
"""Estimate prediction components.

Parameters
----------
df:
features dataframe

Returns
-------
:
dataframe with prediction components
"""
model = self._model
fit_result = self._result

if fit_result is None or model is None:
raise ValueError("This model is not fitted!")

self._check_mul_components()
self._check_df(df)

level = fit_result.level.values
trend = fit_result.trend.values
season = fit_result.season.values

components = {
"target_component_level": np.concatenate([[fit_result.params["initial_level"]], level[:-1]]),
}

if model.trend is not None:
trend = np.concatenate([[fit_result.params["initial_trend"]], trend[:-1]])

if model.damped_trend:
trend *= fit_result.params["damping_trend"]

components["target_component_trend"] = trend

if model.seasonal is not None:
seasonal_periods = model.seasonal_periods
components["target_component_seasonality"] = np.concatenate(
[fit_result.params["initial_seasons"], season[:-seasonal_periods]]
)

components_df = pd.DataFrame(data=components)

if model._use_boxcox:
components_df = self._rescale_components(components=components_df)

return components_df


class HoltWintersModel(
PerSegmentModelMixin,
Expand All @@ -289,6 +419,11 @@ class HoltWintersModel(
Notes
-----
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.

This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for Holt-Winters model are: level, trend and seasonality.
For in-sample decomposition, components are obtained directly from the fitted model. For out-of-sample,
components estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down Expand Up @@ -486,6 +621,11 @@ class HoltModel(HoltWintersModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
They implement :py:class:`statsmodels.tsa.holtwinters.Holt` model
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.

This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for Holt model are: level and trend.
For in-sample decomposition, components are obtained directly from the fitted model. For out-of-sample,
components estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down Expand Up @@ -583,6 +723,10 @@ class SimpleExpSmoothingModel(HoltWintersModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
They implement :py:class:`statsmodels.tsa.holtwinters.SimpleExpSmoothing` model
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.

This model supports in-sample and out-of-sample prediction decomposition.
For in-sample decomposition, level component is obtained directly from the fitted model. For out-of-sample,
it estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down
132 changes: 132 additions & 0 deletions tests/test_models/test_holt_winters_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest
from statsmodels.tsa.holtwinters.results import HoltWintersResultsWrapper

Expand All @@ -8,6 +9,7 @@
from etna.models import HoltModel
from etna.models import HoltWintersModel
from etna.models import SimpleExpSmoothingModel
from etna.models.holt_winters import _HoltWintersAdapter
from etna.pipeline import Pipeline
from tests.test_models.utils import assert_model_equals_loaded_original

Expand Down Expand Up @@ -119,3 +121,133 @@ def test_get_model_after_training(example_tsds, etna_model_class, expected_class
@pytest.mark.parametrize("model", [HoltModel(), HoltWintersModel(), SimpleExpSmoothingModel()])
def test_save_load(model, example_tsds):
assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


@pytest.fixture()
def multi_trend_dfs(multitrend_df):
df = multitrend_df.copy()
df.columns = df.columns.droplevel("segment")
df.reset_index(inplace=True)
df["target"] += 10 - df["target"].min()

return df.iloc[:-9], df.iloc[-9:]


@pytest.fixture()
def seasonal_dfs():
target = pd.Series(
[
41.727458,
24.041850,
32.328103,
37.328708,
46.213153,
29.346326,
36.482910,
42.977719,
48.901525,
31.180221,
37.717881,
40.420211,
51.206863,
31.887228,
40.978263,
43.772491,
55.558567,
33.850915,
42.076383,
45.642292,
59.766780,
35.191877,
44.319737,
47.913736,
],
index=pd.period_range(start="2005Q1", end="2010Q4", freq="Q"),
)

df = pd.DataFrame(
{
"timestamp": target.index.to_timestamp(),
"target": target.values,
}
)

return df.iloc[:-9], df.iloc[-9:]


def test_check_mul_components_not_fitted_error():
model = _HoltWintersAdapter()
with pytest.raises(ValueError, match="This model is not fitted!"):
model._check_mul_components()


def test_rescale_components_not_fitted_error():
model = _HoltWintersAdapter()
with pytest.raises(ValueError, match="This model is not fitted!"):
model._rescale_components(pd.DataFrame({}))


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
def test_decomposition_not_fitted_error(seasonal_dfs, components_method_name):
_, test = seasonal_dfs

model = _HoltWintersAdapter()
components_method = getattr(model, components_method_name)

with pytest.raises(ValueError, match="This model is not fitted!"):
components_method(df=test)


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
@pytest.mark.parametrize("trend,seasonal", (("mul", "mul"), ("mul", None), (None, "mul")))
def test_check_mul_components(seasonal_dfs, trend, seasonal, components_method_name):
_, test = seasonal_dfs

model = _HoltWintersAdapter(trend=trend, seasonal=seasonal)
model.fit(test, [])

components_method = getattr(model, components_method_name)

with pytest.raises(ValueError, match="Forecast decomposition is only supported for additive components!"):
components_method(df=test)


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
@pytest.mark.parametrize("trend,trend_component", (("add", ["target_component_trend"]), (None, [])))
@pytest.mark.parametrize("seasonal,seasonal_component", (("add", ["target_component_seasonality"]), (None, [])))
def test_components_names(seasonal_dfs, trend, trend_component, seasonal, seasonal_component, components_method_name):
expected_components_names = set(trend_component + seasonal_component + ["target_component_level"])
_, test = seasonal_dfs

model = _HoltWintersAdapter(trend=trend, seasonal=seasonal)
model.fit(test, [])
components_method = getattr(model, components_method_name)
components = components_method(df=test)

assert set(components.columns) == expected_components_names


@pytest.mark.parametrize(
"components_method_name,in_sample", (("predict_components", True), ("forecast_components", False))
)
@pytest.mark.parametrize("df_names", ("seasonal_dfs", "multi_trend_dfs"))
@pytest.mark.parametrize("trend,damped_trend", (("add", True), ("add", False), (None, False)))
@pytest.mark.parametrize("seasonal", ("add", None))
@pytest.mark.parametrize("use_boxcox", (True, False))
def test_components_sum_up_to_target(
df_names, trend, seasonal, damped_trend, use_boxcox, components_method_name, in_sample, request
):
dfs = request.getfixturevalue(df_names)
train, test = dfs
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

model = _HoltWintersAdapter(trend=trend, seasonal=seasonal, damped_trend=damped_trend, use_boxcox=use_boxcox)
model.fit(train, [])

components_method = getattr(model, components_method_name)

pred_df = train if in_sample else test
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

components = components_method(df=pred_df)
pred = model.predict(pred_df)

np.testing.assert_allclose(np.sum(components.values, axis=1), pred)