Skip to content

Commit

Permalink
Add stride parameter into backtest (#1165)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Mar 17, 2023
1 parent 02bf892 commit 91a9105
Show file tree
Hide file tree
Showing 6 changed files with 492 additions and 223 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
- Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159))
- Add `stride` parameter into `backtest` ([#1165](https://github.com/tinkoff-ai/etna/pull/1165))
- Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071))
- Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094))
- Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127))
### Changed
-
-
- Add more scenarios into tests for models ([#1082](https://github.com/tinkoff-ai/etna/pull/1082))
-
-
- Decouple `SeasonalMovingAverageModel` from `PerSegmentModelMixin` ([#1132](https://github.com/tinkoff-ai/etna/pull/1132))
- Decouple `DeadlineMovingAverageModel` from `PerSegmentModelMixin` ([#1140](https://github.com/tinkoff-ai/etna/pull/1140))
### Fixed
Expand Down
95 changes: 73 additions & 22 deletions etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from joblib import delayed
from scipy.stats import norm
from typing_extensions import TypedDict
from typing_extensions import assert_never

from etna.core import AbstractSaveable
from etna.core import BaseMixin
Expand All @@ -35,6 +36,12 @@ class CrossValidationMode(Enum):
expand = "expand"
constant = "constant"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} modes allowed"
)


class FoldMask(BaseMixin):
"""Container to hold the description of the fold mask.
Expand Down Expand Up @@ -209,10 +216,11 @@ def backtest(
ts: TSDataset,
metrics: List[Metric],
n_folds: Union[int, List[FoldMask]] = 5,
mode: str = "expand",
mode: Optional[str] = None,
aggregate_metrics: bool = False,
n_jobs: int = 1,
refit: Union[bool, int] = True,
stride: Optional[int] = None,
joblib_params: Optional[Dict[str, Any]] = None,
forecast_params: Optional[Dict[str, Any]] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
Expand All @@ -229,7 +237,8 @@ def backtest(
n_folds:
Number of folds or the list of fold masks
mode:
One of 'expand', 'constant' -- train generation policy
Train generation policy: 'expand' or 'constant'. Can be set only if ``n_folds`` is an integer.
Defaults to 'expand'.
aggregate_metrics:
If True aggregate metrics above folds, return raw metrics otherwise
n_jobs:
Expand All @@ -243,6 +252,8 @@ def backtest(
* If ``value: int``: pipeline is trained every ``value`` folds starting from the first.
stride:
Number of points between the starts of consecutive folds. Can be set only if ``n_folds`` is an integer. Defaults to ``horizon``.
joblib_params:
Additional parameters for :py:class:`joblib.Parallel`
forecast_params:
Expand Down Expand Up @@ -453,43 +464,65 @@ def _init_backtest(self):

@staticmethod
def _validate_backtest_n_folds(n_folds: int):
"""Check that given n_folds value is valid."""
"""Check that given n_folds value is >= 1."""
if n_folds < 1:
raise ValueError(f"Folds number should be a positive number, {n_folds} given")

@staticmethod
def _validate_backtest_dataset(ts: TSDataset, n_folds: int, horizon: int):
def _validate_backtest_mode(n_folds: Union[int, List[FoldMask]], mode: Optional[str]) -> CrossValidationMode:
    """Validate ``mode`` against ``n_folds`` and convert it into ``CrossValidationMode``.

    Parameters
    ----------
    n_folds:
        Number of folds or the list of fold masks.
    mode:
        Name of the train generation policy (case-insensitive) or ``None``.

    Returns
    -------
    :
        ``CrossValidationMode.expand`` if ``mode`` is ``None``,
        otherwise the enum member matching the lower-cased ``mode``.

    Raises
    ------
    ValueError:
        If ``mode`` is set while ``n_folds`` is not an integer.
    """
    if mode is not None:
        # An explicit mode only makes sense when the folds are generated from a number.
        if isinstance(n_folds, int):
            return CrossValidationMode(mode.lower())
        raise ValueError("Mode shouldn't be set if n_folds are fold masks!")

    # Nothing was requested: fall back to the default policy.
    return CrossValidationMode.expand

@staticmethod
def _validate_backtest_stride(n_folds: Union[int, List[FoldMask]], horizon: int, stride: Optional[int]) -> int:
if stride is None:
return horizon

if not isinstance(n_folds, int):
raise ValueError("Stride shouldn't be set if n_folds are fold masks!")

if stride < 1:
raise ValueError(f"Stride should be a positive number, {stride} given!")

return stride

@staticmethod
def _validate_backtest_dataset(ts: TSDataset, n_folds: int, horizon: int, stride: int):
"""Check all segments have enough timestamps to validate forecaster with given number of splits."""
min_required_length = horizon * n_folds
min_required_length = horizon + (n_folds - 1) * stride
segments = set(ts.df.columns.get_level_values("segment"))
for segment in segments:
segment_target = ts[:, segment, "target"]
if len(segment_target) < min_required_length:
raise ValueError(
f"All the series from feature dataframe should contain at least "
f"{horizon} * {n_folds} = {min_required_length} timestamps; "
f"{horizon} + {n_folds-1} * {stride} = {min_required_length} timestamps; "
f"series {segment} does not."
)

@staticmethod
def _generate_masks_from_n_folds(ts: TSDataset, n_folds: int, horizon: int, mode: str) -> List[FoldMask]:
def _generate_masks_from_n_folds(
ts: TSDataset, n_folds: int, horizon: int, mode: CrossValidationMode, stride: int
) -> List[FoldMask]:
"""Generate fold masks from n_folds."""
mode_enum = CrossValidationMode(mode.lower())
if mode_enum == CrossValidationMode.expand:
if mode is CrossValidationMode.expand:
constant_history_length = 0
elif mode_enum == CrossValidationMode.constant:
elif mode is CrossValidationMode.constant:
constant_history_length = 1
else:
raise NotImplementedError(
f"Only '{CrossValidationMode.expand}' and '{CrossValidationMode.constant}' modes allowed"
)
assert_never(mode)

masks = []
dataset_timestamps = list(ts.index)
min_timestamp_idx, max_timestamp_idx = 0, len(dataset_timestamps)
for offset in range(n_folds, 0, -1):
min_train_idx = min_timestamp_idx + (n_folds - offset) * horizon * constant_history_length
max_train_idx = max_timestamp_idx - horizon * offset - 1
min_train_idx = min_timestamp_idx + (n_folds - offset) * stride * constant_history_length
max_train_idx = max_timestamp_idx - stride * (offset - 1) - horizon - 1
min_test_idx = max_train_idx + 1
max_test_idx = max_train_idx + horizon

Expand Down Expand Up @@ -625,7 +658,7 @@ def _get_fold_info(self) -> pd.DataFrame:
tmp_df[f"{stage_name}_{border}_time"] = [fold_info[f"{stage_name}_timerange"][border]]
tmp_df[self._fold_column] = fold_number
timerange_dfs.append(tmp_df)
timerange_df = pd.concat(timerange_dfs)
timerange_df = pd.concat(timerange_dfs, ignore_index=True)
return timerange_df

def _get_backtest_forecasts(self) -> pd.DataFrame:
Expand All @@ -648,12 +681,16 @@ def _get_backtest_forecasts(self) -> pd.DataFrame:
forecasts.sort_index(axis=1, inplace=True)
return forecasts

def _prepare_fold_masks(self, ts: TSDataset, masks: Union[int, List[FoldMask]], mode: str) -> List[FoldMask]:
def _prepare_fold_masks(
self, ts: TSDataset, masks: Union[int, List[FoldMask]], mode: CrossValidationMode, stride: int
) -> List[FoldMask]:
"""Prepare and validate fold masks."""
if isinstance(masks, int):
self._validate_backtest_n_folds(n_folds=masks)
self._validate_backtest_dataset(ts=ts, n_folds=masks, horizon=self.horizon)
masks = self._generate_masks_from_n_folds(ts=ts, n_folds=masks, horizon=self.horizon, mode=mode)
self._validate_backtest_dataset(ts=ts, n_folds=masks, horizon=self.horizon, stride=stride)
masks = self._generate_masks_from_n_folds(
ts=ts, n_folds=masks, horizon=self.horizon, mode=mode, stride=stride
)
for i, mask in enumerate(masks):
mask.first_train_timestamp = mask.first_train_timestamp if mask.first_train_timestamp else ts.index[0]
masks[i] = mask
Expand Down Expand Up @@ -768,10 +805,11 @@ def backtest(
ts: TSDataset,
metrics: List[Metric],
n_folds: Union[int, List[FoldMask]] = 5,
mode: str = "expand",
mode: Optional[str] = None,
aggregate_metrics: bool = False,
n_jobs: int = 1,
refit: Union[bool, int] = True,
stride: Optional[int] = None,
joblib_params: Optional[Dict[str, Any]] = None,
forecast_params: Optional[Dict[str, Any]] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
Expand All @@ -788,7 +826,8 @@ def backtest(
n_folds:
Number of folds or the list of fold masks
mode:
One of 'expand', 'constant' -- train generation policy, ignored if n_folds is a list of masks
Train generation policy: 'expand' or 'constant'. Can be set only if ``n_folds`` is an integer.
Defaults to 'expand'.
aggregate_metrics:
If True aggregate metrics above folds, return raw metrics otherwise
n_jobs:
Expand All @@ -802,6 +841,8 @@ def backtest(
* If ``value: int``: pipeline is trained every ``value`` folds starting from the first.
stride:
Number of points between the starts of consecutive folds. Can be set only if ``n_folds`` is an integer. Defaults to ``horizon``.
joblib_params:
Additional parameters for :py:class:`joblib.Parallel`
forecast_params:
Expand All @@ -811,7 +852,17 @@ def backtest(
-------
metrics_df, forecast_df, fold_info_df: Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
Metrics dataframe, forecast dataframe and dataframe with information about folds
Raises
------
ValueError:
If ``mode`` is set when ``n_folds`` is a ``List[FoldMask]``.
ValueError:
If ``stride`` is set when ``n_folds`` is a ``List[FoldMask]``.
"""
mode_enum = self._validate_backtest_mode(n_folds=n_folds, mode=mode)
stride = self._validate_backtest_stride(n_folds=n_folds, horizon=self.horizon, stride=stride)

if joblib_params is None:
joblib_params = dict(verbose=11, backend="multiprocessing", mmap_mode="c")

Expand All @@ -820,7 +871,7 @@ def backtest(

self._init_backtest()
self._validate_backtest_metrics(metrics=metrics)
masks = self._prepare_fold_masks(ts=ts, masks=n_folds, mode=mode)
masks = self._prepare_fold_masks(ts=ts, masks=n_folds, mode=mode_enum, stride=stride)
self._folds = self._run_all_folds(
masks=masks,
ts=ts,
Expand Down
3 changes: 3 additions & 0 deletions etna/transforms/timestamp/special_days.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

Expand Down Expand Up @@ -55,6 +57,7 @@ def __init__(self, find_special_weekday: bool = True, find_special_month_day: bo
self.anomaly_week_days: Optional[Tuple[int]] = None
self.anomaly_month_days: Optional[Tuple[int]] = None

self.res_type: Dict[str, Any]
if self.find_special_weekday and find_special_month_day:
self.res_type = {"df_sample": (0, 0), "columns": ["anomaly_weekdays", "anomaly_monthdays"]}
elif self.find_special_weekday:
Expand Down
Loading

0 comments on commit 91a9105

Please sign in to comment.