Implement forecast decomposition for ARIMA-like models (#1174)
* implemented decomposition

* added tests for sarima

* added tests for auto_arima

* updated changelog

* changed test params

* same name exogs test

* added notes

* removed test

* added comments for decomposition

* fix lint

* added note
brsnw250 authored Mar 23, 2023
1 parent df78d2e commit 4041a08
Showing 4 changed files with 257 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,11 +23,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1135](https://github.com/tinkoff-ai/etna/issues/1135))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter ` ([#1146](https://github.com/tinkoff-ai/etna/issues/1146))
- Methods `predict_components` for forecast decomposition in `_ProphetAdapter` ([#1161](https://github.com/tinkoff-ai/etna/issues/1161))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_SARIMAXAdapter` and `_AutoARIMAAdapter` ([#1149](https://github.com/tinkoff-ai/etna/issues/1149))
- Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159))
- Add `stride` parameter into `backtest` ([#1165](https://github.com/tinkoff-ai/etna/pull/1165))
- Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071))
- Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094))
- Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127))
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
- Signature of the constructor of `TFTModel`, `DeepARModel` ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
4 changes: 4 additions & 0 deletions etna/models/autoarima.py
@@ -63,6 +63,10 @@ class AutoARIMAModel(
Notes
-----
We use :py:class:`pmdarima.arima.arima.ARIMA`.
This model supports in-sample and out-of-sample prediction decomposition.
The prediction components for the AutoARIMA model are the exogenous and ARIMA components.
The decomposition is obtained directly from the fitted model parameters.
"""

def __init__(
148 changes: 148 additions & 0 deletions etna/models/sarimax.py
@@ -6,10 +6,12 @@
from typing import Sequence
from typing import Tuple

import numpy as np
import pandas as pd
from statsmodels.tools.sm_exceptions import ValueWarning
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper
from statsmodels.tsa.statespace.simulation_smoother import SimulationSmoother

from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence
from etna.models.base import BaseAdapter
@@ -212,6 +214,148 @@ def get_model(self) -> SARIMAXResultsWrapper:
"""
return self._fit_results

    @staticmethod
    def _prepare_components_df(components: np.ndarray, model: SARIMAX) -> pd.DataFrame:
        """Prepare `pd.DataFrame` with components."""
        if model.exog_names is not None:
            components_names = model.exog_names[:]
        else:
            components_names = []

        if model.seasonal_periods == 0:
            components_names.append("arima")
        else:
            components_names.append("sarima")

        df = pd.DataFrame(data=components, columns=components_names)
        return df.add_prefix("target_component_")

    @staticmethod
    def _prepare_design_matrix(ssm: SimulationSmoother) -> np.ndarray:
        """Extract design matrix from state space model."""
        design_mat = ssm["design"]
        if len(design_mat.shape) == 2:
            design_mat = design_mat[..., np.newaxis]

        return design_mat

    def _mle_regression_decomposition(self, state: np.ndarray, ssm: SimulationSmoother, exog: np.ndarray) -> np.ndarray:
        """Estimate SARIMAX components for the MLE regression case.

        SARIMAX representation as a state space model: https://www.statsmodels.org/dev/statespace.html

        In the MLE case, exogenous data is fitted separately from the other components:
        https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/sarimax.py#L1644
        """
        # get design matrix from the state space model
        design_mat = self._prepare_design_matrix(ssm)

        # estimate SARIMA component
        components = np.sum(design_mat * state, axis=1).T

        if len(exog) > 0:
            # restore parameters for exogenous variables
            exog_params = np.linalg.lstsq(a=exog, b=np.squeeze(ssm["obs_intercept"]))[0]

            # estimate exogenous components and append them to the others
            weighted_exog = exog * exog_params[np.newaxis]
            components = np.concatenate([weighted_exog, components], axis=1)

        return components
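As a sanity check on the arithmetic above, here is a minimal numpy sketch of the same idea; all names, shapes, and values are invented for illustration and nothing is taken from statsmodels. Summing design * state over the state dimension gives the (S)ARIMA part, and least squares on the observation intercept recovers the exogenous weights:

import numpy as np

n_obs, k_states, k_exog = 20, 3, 2
rng = np.random.default_rng(0)

design = rng.normal(size=(1, k_states, 1))           # time-invariant design matrix (k_endog = 1)
state = rng.normal(size=(k_states, n_obs))           # predicted states, one column per step
exog = rng.normal(size=(n_obs, k_exog))              # exogenous regressors
beta = np.array([0.5, -1.2])                         # "true" exogenous coefficients

sarima_component = np.sum(design * state, axis=1).T  # shape (n_obs, 1)
obs_intercept = exog @ beta                          # plays the role of ssm["obs_intercept"]

# recover the exogenous weights and build per-regressor components
beta_hat = np.linalg.lstsq(exog, obs_intercept, rcond=None)[0]
weighted_exog = exog * beta_hat[np.newaxis]          # shape (n_obs, k_exog)

components = np.concatenate([weighted_exog, sarima_component], axis=1)
prediction = np.squeeze(sarima_component) + obs_intercept

# the per-component columns add back up to the prediction
np.testing.assert_allclose(components.sum(axis=1), prediction)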

    def _state_regression_decomposition(self, state: np.ndarray, ssm: SimulationSmoother, k_exog: int) -> np.ndarray:
        """Estimate SARIMAX components for the state regression case.

        SARIMAX representation as a state space model: https://www.statsmodels.org/dev/statespace.html

        In the state regression case, parameters for the exogenous variables are estimated inside the SSM.
        """
        # get design matrix from the state space model
        design_mat = self._prepare_design_matrix(ssm)

        if k_exog > 0:
            # estimate SARIMA component
            sarima = np.sum(design_mat[:, :-k_exog] * state[:-k_exog], axis=1)

            # obtain parameters from the SSM and estimate exogenous components
            weighted_exog = np.squeeze(design_mat[:, -k_exog:] * state[-k_exog:])
            components = np.concatenate([weighted_exog, sarima], axis=0).T

        else:
            # without exogenous variables the whole design matrix corresponds to the SARIMA component
            components = np.sum(design_mat * state, axis=1).T

        return components
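A companion sketch for the state-regression branch, again with invented shapes and values: here the regression coefficients sit in the last k_exog entries of the state vector while the design matrix carries the regressor values, so design * state already yields the weighted exogenous terms:

import numpy as np

n_obs, k_arma_states, k_exog = 20, 2, 2
rng = np.random.default_rng(1)
k_states = k_arma_states + k_exog

exog = rng.normal(size=(n_obs, k_exog))
beta = np.array([0.5, -1.2])

# time-varying design: the ARMA part is constant, the exogenous part holds the regressor values
design = np.zeros((1, k_states, n_obs))
design[:, :k_arma_states, :] = rng.normal(size=(1, k_arma_states, 1))
design[0, k_arma_states:, :] = exog.T

# state: ARMA states vary over time, regression states stay at the fitted coefficients
state = np.vstack([rng.normal(size=(k_arma_states, n_obs)), np.tile(beta[:, None], n_obs)])

sarima = np.sum(design[:, :-k_exog] * state[:-k_exog], axis=1)     # shape (1, n_obs)
weighted_exog = np.squeeze(design[:, -k_exog:] * state[-k_exog:])  # shape (k_exog, n_obs)
components = np.concatenate([weighted_exog, sarima], axis=0).T     # shape (n_obs, k_exog + 1)

# the exogenous columns equal regressor value times coefficient, as expected
np.testing.assert_allclose(components[:, :k_exog], exog * beta)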

    def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
        """Estimate prediction components.

        Parameters
        ----------
        df:
            features dataframe

        Returns
        -------
        :
            dataframe with prediction components
        """
        fit_results = self._fit_results
        model = fit_results.model

        if model.hamilton_representation:
            raise ValueError("Prediction decomposition is not implemented for Hamilton representation of an ARMA!")

        state = fit_results.predicted_state[:, :-1]

        if model.mle_regression:
            components = self._mle_regression_decomposition(state=state, ssm=model.ssm, exog=model.exog)

        else:
            components = self._state_regression_decomposition(state=state, ssm=model.ssm, k_exog=model.k_exog)

        return self._prepare_components_df(components=components, model=model)

    def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
        """Estimate forecast components.

        Parameters
        ----------
        df:
            features dataframe

        Returns
        -------
        :
            dataframe with forecast components
        """
        fit_results = self._fit_results

        model = fit_results.model
        if model.hamilton_representation:
            raise ValueError("Prediction decomposition is not implemented for Hamilton representation of an ARMA!")

        horizon = len(df)
        self._encode_categoricals(df)
        self._check_df(df, horizon)

        exog_future = self._select_regressors(df)

        forecast_results = fit_results.get_forecast(horizon, exog=exog_future).prediction_results.results
        state = forecast_results.predicted_state[:, :-1]

        if model.mle_regression:
            # If there are no exog variables, `mle_regression` is set to `False`
            # even if the user set it to `True`.
            components = self._mle_regression_decomposition(
                state=state, ssm=forecast_results.model, exog=exog_future.values  # type: ignore
            )

        else:
            components = self._state_regression_decomposition(
                state=state, ssm=forecast_results.model, k_exog=model.k_exog
            )

        return self._prepare_components_df(components=components, model=model)
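To tie the two entry points together, here is a hedged usage sketch, not part of the commit: the adapter class, its fit/predict_components/forecast_components methods, and the order/seasonal_order arguments come from the diff and its tests, while the single-segment frame layout below (timestamp, target, f1, f2 columns) is an assumption modeled on what the dfs_w_exog fixture appears to provide.

import numpy as np
import pandas as pd

from etna.models.sarimax import _SARIMAXAdapter

rng = np.random.default_rng(0)
timestamps = pd.date_range("2021-01-01", periods=60, freq="D")
df = pd.DataFrame({"timestamp": timestamps, "f1": rng.normal(size=60), "f2": rng.normal(size=60)})
df["target"] = 10 + 2 * df["f1"] - df["f2"] + rng.normal(scale=0.1, size=60)
train, future = df.iloc[:50], df.iloc[50:]

adapter = _SARIMAXAdapter(order=(1, 0, 0), seasonal_order=(0, 0, 0, 0))
adapter.fit(train, ["f1", "f2"])

# in-sample decomposition: target_component_f1, target_component_f2, target_component_sarima
in_sample_components = adapter.predict_components(df=train)

# out-of-sample decomposition for the 10 future points
future_components = adapter.forecast_components(df=future)

# components are additive, so each row sums up to the corresponding model prediction
print(in_sample_components.sum(axis=1).head())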


class _SARIMAXAdapter(_SARIMAXBaseAdapter):
"""
@@ -400,6 +544,10 @@ class SARIMAXModel(
`exogenous regressors` which should be known in the future; however, we use exogenous for
additional features that are not known in the future, and regressors for features we do know in
the future.

This model supports in-sample and out-of-sample prediction decomposition.
The prediction components for the SARIMAX model are the exogenous and SARIMA components.
The decomposition is obtained directly from the fitted model parameters.
"""

def __init__(
103 changes: 103 additions & 0 deletions tests/test_models/test_sarimax_model.py
@@ -1,9 +1,11 @@
from copy import deepcopy

import numpy as np
import pytest
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper

from etna.models import SARIMAXModel
from etna.models.sarimax import _SARIMAXAdapter
from etna.pipeline import Pipeline
from tests.test_models.utils import assert_model_equals_loaded_original

@@ -133,3 +135,104 @@ def test_forecast_1_point(example_tsds):
def test_save_load(example_tsds):
    model = SARIMAXModel()
    assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


@pytest.mark.parametrize(
    "components_method_name,in_sample", (("predict_components", True), ("forecast_components", False))
)
def test_decomposition_hamiltonian_repr_error(dfs_w_exog, components_method_name, in_sample):
    train, test = dfs_w_exog
    pred_df = train if in_sample else test

    model = _SARIMAXAdapter(order=(2, 0, 0), seasonal_order=(1, 0, 0, 3), hamilton_representation=True)
    model.fit(train, ["f1", "f2"])

    components_method = getattr(model, components_method_name)

    with pytest.raises(
        ValueError, match="Prediction decomposition is not implemented for Hamilton representation of an ARMA!"
    ):
        _ = components_method(df=pred_df)


@pytest.mark.parametrize(
    "components_method_name,in_sample", (("predict_components", True), ("forecast_components", False))
)
@pytest.mark.parametrize(
    "regressors, regressors_components",
    (
        (["f1", "f2"], ["target_component_f1", "target_component_f2"]),
        (["f1"], ["target_component_f1"]),
        (["f1", "f1"], ["target_component_f1", "target_component_f1"]),
        ([], []),
    ),
)
@pytest.mark.parametrize("trend", (None, "t"))
def test_components_names(dfs_w_exog, regressors, regressors_components, trend, components_method_name, in_sample):
    expected_components = regressors_components + ["target_component_sarima"]

    train, test = dfs_w_exog
    pred_df = train if in_sample else test

    model = _SARIMAXAdapter(trend=trend)
    model.fit(train, regressors)

    components_method = getattr(model, components_method_name)
    components = components_method(df=pred_df)

    assert sorted(components.columns) == sorted(expected_components)


@pytest.mark.long_2
@pytest.mark.parametrize(
    "components_method_name,in_sample", (("predict_components", True), ("forecast_components", False))
)
@pytest.mark.parametrize(
    "mle_regression,time_varying_regression,regressors",
    (
        (True, False, ["f1", "f1"]),
        (True, False, []),
        (False, True, ["f1", "f2"]),
        (False, False, ["f1", "f2"]),
        (False, False, []),
    ),
)
@pytest.mark.parametrize("trend", (None, "t"))
@pytest.mark.parametrize("enforce_stationarity", (True, False))
@pytest.mark.parametrize("enforce_invertibility", (True, False))
@pytest.mark.parametrize("concentrate_scale", (True, False))
@pytest.mark.parametrize("use_exact_diffuse", (True, False))
def test_components_sum_up_to_target(
    dfs_w_exog,
    components_method_name,
    in_sample,
    mle_regression,
    time_varying_regression,
    trend,
    enforce_stationarity,
    enforce_invertibility,
    concentrate_scale,
    use_exact_diffuse,
    regressors,
):
    train, test = dfs_w_exog

    model = _SARIMAXAdapter(
        trend=trend,
        mle_regression=mle_regression,
        time_varying_regression=time_varying_regression,
        enforce_stationarity=enforce_stationarity,
        enforce_invertibility=enforce_invertibility,
        concentrate_scale=concentrate_scale,
        use_exact_diffuse=use_exact_diffuse,
    )
    model.fit(train, regressors)

    components_method = getattr(model, components_method_name)

    pred_df = train if in_sample else test

    pred = model.predict(pred_df, prediction_interval=False, quantiles=[])
    components = components_method(df=pred_df)

    np.testing.assert_allclose(np.sum(components.values, axis=1), np.squeeze(pred))
