Skip to content

Commit

Permalink
Shift of exogenous data (#1254)
Browse files Browse the repository at this point in the history
* save exog last available date

* added exog shift transform

* added tests

* fixed tests

* fixed feature names selection

* new shift estimation logic

* review fixes

* added tests

* reworked tests

* moved exog dates from base

* reworked `_get_feature_names`

* added tests

* formatting

* review fixes

* added test

* moved exog check

* reworked check

* added test
  • Loading branch information
brsnw250 authored May 10, 2023
1 parent 634a5c6 commit e1f642f
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Notebook `forecast_interpretation.ipynb` with forecast decomposition ([#1220](https://github.com/tinkoff-ai/etna/pull/1220))
-
- Exogenous variables shift transform `ExogShiftTransform` ([#1254](https://github.com/tinkoff-ai/etna/pull/1254))
-
-
### Changed
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from etna.transforms.math import AddConstTransform
from etna.transforms.math import BoxCoxTransform
from etna.transforms.math import DifferencingTransform
from etna.transforms.math import ExogShiftTransform
from etna.transforms.math import LagTransform
from etna.transforms.math import LambdaTransform
from etna.transforms.math import LogTransform
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/math/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from etna.transforms.math.add_constant import AddConstTransform
from etna.transforms.math.apply_lambda import LambdaTransform
from etna.transforms.math.differencing import DifferencingTransform
from etna.transforms.math.lags import ExogShiftTransform
from etna.transforms.math.lags import LagTransform
from etna.transforms.math.log import LogTransform
from etna.transforms.math.power import BoxCoxTransform
Expand Down
199 changes: 199 additions & 0 deletions etna/transforms/math/lags.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Union

import pandas as pd

from etna.datasets import TSDataset
from etna.models.utils import determine_num_steps
from etna.transforms.base import FutureMixin
from etna.transforms.base import IrreversibleTransform

Expand Down Expand Up @@ -96,3 +100,198 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
def get_regressors_info(self) -> List[str]:
    """Return the names of the regressor columns created by the transform, one per entry of ``self.lags``."""
    return list(map(self._get_column_name, self.lags))


class ExogShiftTransform(IrreversibleTransform, FutureMixin):
    """Shifts exogenous variables from a given dataframe.

    Every exogenous variable that gets a positive shift is replaced by a column named
    ``<feature>_shift_<shift>``; variables whose shift is zero are left untouched.
    """

    def __init__(self, lag: Union[int, Literal["auto"]], horizon: Optional[int] = None):
        """Create instance of ExogShiftTransform.

        Parameters
        ----------
        lag:
            value for shift estimation

            * if set to `int` all exogenous variables will be shifted `lag` steps forward;

            * if set to `auto` minimal shift will be estimated for each variable based on
              the prediction horizon and available timeline

        horizon:
            prediction horizon. Mandatory when set to `lag="auto"`, ignored otherwise

        Raises
        ------
        ValueError:
            if `lag` is a non-positive integer
        ValueError:
            if `lag` is neither an integer nor `"auto"`
        ValueError:
            if `lag="auto"` and `horizon` is not set or is less than 1
        """
        super().__init__(required_features="all")

        self.lag: Optional[int] = None
        self.horizon: Optional[int] = None
        self._auto = False

        self._freq: Optional[str] = None
        self._created_regressors: Optional[List[str]] = None
        self._exog_shifts: Optional[Dict[str, int]] = None
        self._exog_last_date: Optional[Dict[str, pd.Timestamp]] = None
        self._filter_out_columns = {"target"}

        if isinstance(lag, int):
            if lag <= 0:
                raise ValueError(f"{self.__class__.__name__} works only with positive lags values, {lag} given")
            self.lag = lag

        elif isinstance(lag, str) and lag == "auto":
            if horizon is None:
                raise ValueError("Value of `horizon` should be specified when using `auto`!")

            if horizon < 1:
                raise ValueError(f"{self.__class__.__name__} works only with positive horizon values, {horizon} given")

            self.horizon = horizon
            self._auto = True

        else:
            # Anything else used to be silently treated as "auto"; reject it explicitly instead.
            raise ValueError(
                f"{self.__class__.__name__} expects `lag` to be a positive integer or 'auto', {lag!r} given"
            )

    def _save_exog_last_date(self, df_exog: Optional[pd.DataFrame] = None):
        """Save last available date of each exogenous variable.

        A date counts as available for a variable only if none of its columns
        (one per segment) is missing at that date.

        Parameters
        ----------
        df_exog:
            dataframe with exogenous data; when ``None`` no dates are stored
        """
        self._exog_last_date = {}
        if df_exog is None:
            return

        # `unique()` keeps first-seen order, so the stored mapping (and everything derived
        # from it, e.g. the regressors list) has a deterministic order between runs.
        exog_names = df_exog.columns.get_level_values("feature").unique()

        for name in exog_names:
            feature = df_exog.loc[:, pd.IndexSlice[:, name]]

            na_mask = pd.isna(feature).any(axis=1)
            # NOTE(review): if no row is fully observed the selection is empty and `max()`
            # yields NaT -- confirm that this case cannot happen or is handled by callers.
            self._exog_last_date[name] = feature.index[~na_mask].max()

    def fit(self, ts: TSDataset) -> "ExogShiftTransform":
        """Fit the transform.

        The dataset frequency and the last available date of each exogenous variable
        are stored before delegating to the parent ``fit``.

        Parameters
        ----------
        ts:
            Dataset to fit the transform on.

        Returns
        -------
        :
            The fitted transform instance.
        """
        self._freq = ts.freq
        self._save_exog_last_date(df_exog=ts.df_exog)

        super().fit(ts=ts)

        return self

    def _fit(self, df: pd.DataFrame) -> "ExogShiftTransform":
        """Estimate shifts for exogenous variables.

        Parameters
        ----------
        df:
            dataframe with data.

        Returns
        -------
        :
            Fitted `ExogShiftTransform` instance.
        """
        feature_names = self._get_feature_names(df=df)

        self._exog_shifts = {}
        self._created_regressors = []

        for feature_name in feature_names:
            shift = self._estimate_shift(df=df, feature_name=feature_name)
            self._exog_shifts[feature_name] = shift

            # Only positively shifted variables produce a new (renamed) column.
            if shift > 0:
                self._created_regressors.append(f"{feature_name}_shift_{shift}")

        return self

    def _get_feature_names(self, df: pd.DataFrame) -> List[str]:
        """Return the names of exogenous variables.

        Parameters
        ----------
        df:
            dataframe that is expected to contain every stored exogenous variable

        Returns
        -------
        :
            Names of the exogenous variables saved during ``fit``; empty if nothing was saved.

        Raises
        ------
        ValueError:
            if a stored exogenous variable is absent from ``df``
        """
        feature_names: List[str] = []
        if self._exog_last_date is not None:
            feature_names = list(self._exog_last_date.keys())

        df_columns = df.columns.get_level_values("feature")
        for name in feature_names:
            if name not in df_columns:
                raise ValueError(f"Feature `{name}` is expected to be in the dataframe!")

        return feature_names

    def _estimate_shift(self, df: pd.DataFrame, feature_name: str) -> int:
        """Estimate shift value for exogenous variable.

        With an explicit integer ``lag`` that value is returned as is. In ``auto`` mode the shift
        is ``max(0, delta + horizon)``, where ``delta`` is the number of steps (as computed by
        ``determine_num_steps``) from the last available date of the variable to the last
        timestamp of ``df``; ``delta`` is negative when the variable extends beyond ``df``.

        Parameters
        ----------
        df:
            dataframe with data
        feature_name:
            name of the exogenous variable

        Returns
        -------
        :
            Non-negative number of steps to shift the variable by.

        Raises
        ------
        ValueError:
            if called in ``auto`` mode before ``fit``
        """
        if not self._auto:
            return self.lag  # type: ignore

        if self._exog_last_date is None or self._freq is None:
            raise ValueError("Call `fit()` method before estimating exog shifts!")

        last_date = df.index.max()
        last_feature_date = self._exog_last_date[feature_name]

        if last_feature_date > last_date:
            delta = -determine_num_steps(start_timestamp=last_date, end_timestamp=last_feature_date, freq=self._freq)

        elif last_feature_date < last_date:
            delta = determine_num_steps(start_timestamp=last_feature_date, end_timestamp=last_date, freq=self._freq)

        else:
            delta = 0

        shift = max(0, delta + self.horizon)  # type: ignore

        return shift

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Shift exogenous variables.

        Parameters
        ----------
        df:
            dataframe with data to transform.

        Returns
        -------
        :
            Transformed dataframe.

        Raises
        ------
        ValueError:
            if the transform is not fitted
        """
        if self._exog_shifts is None:
            raise ValueError("Transform is not fitted!")

        result = df
        segments = sorted(set(df.columns.get_level_values("segment")))
        feature_names = self._get_feature_names(df=df)

        shifted_features = []
        features_to_remove = []
        for feature_name in feature_names:
            shift = self._exog_shifts[feature_name]
            if shift <= 0:
                # Zero shift: the variable stays in place under its original name.
                continue

            feature = df.loc[:, pd.IndexSlice[:, feature_name]]
            shifted_feature = feature.shift(shift, freq=self._freq)

            column_name = f"{feature_name}_shift_{shift}"
            # Relabel using the slice's own segment order, so values stay attached to the right
            # segment even if the columns are not sorted, and carry over the level names:
            # depending on the pandas version, concatenating with a nameless MultiIndex may
            # drop the `segment`/`feature` names from the result.
            shifted_feature.columns = pd.MultiIndex.from_arrays(
                [feature.columns.get_level_values("segment"), [column_name] * feature.shape[1]],
                names=list(df.columns.names),
            )

            shifted_features.append(shifted_feature)
            features_to_remove.append(feature_name)

        if len(features_to_remove) > 0:
            result = result.drop(columns=pd.MultiIndex.from_product([segments, features_to_remove]))

        # `pd.concat` builds a new frame, so the in-place sort below never touches the input `df`.
        result = pd.concat([result] + shifted_features, axis=1)
        result.sort_index(axis=1, inplace=True)
        return result

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform.

        Raises
        ------
        ValueError:
            if the transform is not fitted
        """
        if self._created_regressors is None:
            raise ValueError("Fit the transform to get the regressors info!")

        return self._created_regressors
Loading

1 comment on commit e1f642f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.