Skip to content

Commit

Permalink
Folds number estimation (#1279)
Browse files Browse the repository at this point in the history
* implemented `estimate_max_n_folds`

* added tests

* updated doc

* removed check

* renamed checks

* added test

* updated changelog

* updated docs

* renamed test

* fixed doc
  • Loading branch information
brsnw250 authored Jun 6, 2023
1 parent bcb98a6 commit 8e8c0f6
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Notebook `forecast_interpretation.ipynb` with forecast decomposition ([#1220](https://github.com/tinkoff-ai/etna/pull/1220))
- Exogenous variables shift transform `ExogShiftTransform`([#1254](https://github.com/tinkoff-ai/etna/pull/1254))
- Parameter `start_timestamp` to forecast CLI command ([#1265](https://github.com/tinkoff-ai/etna/pull/1265))
-
- Function `estimate_max_n_folds` for folds number estimation ([#1279](https://github.com/tinkoff-ai/etna/pull/1279))
-
### Changed
- Set the default value of `final_model` to `LinearRegression(positive=True)` in the constructor of `StackingEnsemble` ([#1238](https://github.com/tinkoff-ai/etna/pull/1238))
- Add microseconds to `FileLogger`'s directory name ([#1264](https://github.com/tinkoff-ai/etna/pull/1264))
Expand Down
122 changes: 122 additions & 0 deletions etna/commands/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from enum import Enum
from math import floor
from typing import Literal
from typing import Optional
from typing import Union

from etna.datasets import TSDataset
from etna.pipeline import Pipeline


class MethodsWithFolds(str, Enum):
    """Names of the pipeline methods that take an `n_folds` argument."""

    forecast = "forecast"
    backtest = "backtest"

    @classmethod
    def _missing_(cls, value):
        """Reject an unknown method name with a message listing the valid ones."""
        allowed = ", ".join(repr(member.value) for member in cls)
        raise ValueError(f"{value} is not a valid method name. Only {allowed} are allowed")


def _estimate_n_folds(num_points: int, horizon: int, stride: int, context_size: int) -> int:
"""Estimate number of folds."""
if num_points < horizon + context_size:
raise ValueError("Not enough data points!")

res = (num_points - horizon + stride - context_size) / stride
return floor(res)


def _max_n_folds_forecast(pipeline: Pipeline, context_size: int, ts: Optional[TSDataset] = None) -> int:
    """Estimate max n_folds for forecast method.

    Parameters
    ----------
    pipeline:
        Pipeline whose ``horizon`` (and ``ts``, when ``ts`` is not given) is used.
    context_size:
        Number of points reserved as context.
    ts:
        Dataset to estimate on; when omitted, ``pipeline.ts`` is used instead.

    Returns
    -------
    :
        Number of folds.

    Raises
    ------
    ValueError:
        If neither ``ts`` nor ``pipeline.ts`` is available.
    """
    dataset = pipeline.ts if ts is None else ts
    if dataset is None:
        raise ValueError(
            "There is no ts for forecast method! Pass ts into function or make sure that pipeline is fitted."
        )

    n_points = len(dataset.index)
    step = pipeline.horizon

    # for the forecast case the horizon is used as the stride as well
    return _estimate_n_folds(num_points=n_points, horizon=step, stride=step, context_size=context_size)


def _max_n_folds_backtest(pipeline: Pipeline, context_size: int, ts: TSDataset, **method_kwargs) -> int:
    """Estimate max n_folds for backtest method.

    Parameters
    ----------
    pipeline:
        Pipeline whose ``horizon`` is used.
    context_size:
        Number of points reserved as context.
    ts:
        Dataset to estimate on.
    method_kwargs:
        Backtest arguments that impact number of folds. Only ``stride`` and
        ``forecast_params`` are inspected; an explicit ``None`` for either one is
        treated the same as omitting it.

    Returns
    -------
    :
        Number of folds.

    Raises
    ------
    NotImplementedError:
        If ``forecast_params`` requests prediction intervals.
    """
    # NOTE(review): callers presumably forward backtest arguments verbatim, where
    # `None` means "not set" -- so `None` must not reach `.get` or the arithmetic.
    forecast_params = method_kwargs.get("forecast_params") or {}

    # process backtest with intervals case
    if forecast_params.get("prediction_interval", False):
        raise NotImplementedError("Number of folds estimation for backtest with intervals is not implemented!")

    num_points = len(ts.index)

    horizon = pipeline.horizon
    stride = method_kwargs.get("stride")
    if stride is None:
        # fall back to the horizon when stride is missing or explicitly None
        stride = horizon

    return _estimate_n_folds(num_points=num_points, horizon=horizon, stride=stride, context_size=context_size)


def estimate_max_n_folds(
    pipeline: Pipeline,
    method_name: Union[Literal["forecast"], Literal["backtest"]],
    context_size: int,
    ts: Optional[TSDataset] = None,
    **method_kwargs,
) -> int:
    """Estimate number of folds using provided data and pipeline configuration.

    This function helps to estimate maximum number of folds that can be used when performing
    forecast with intervals or pipeline backtest. Number of folds estimated using the following formula:

    .. math::
        max\\_n\\_folds = \\left\\lfloor\\frac{num\\_points - horizon + stride - context\\_size}{stride}\\right\\rfloor,

    where :math:`num\\_points` is number of points in the dataset,
    :math:`horizon` is length of forecasting horizon,
    :math:`stride` is number of points between folds,
    :math:`context\\_size` is pipeline context size.

    Parameters
    ----------
    pipeline:
        Pipeline for which to estimate number of folds.
    method_name:
        Method name for which to estimate number of folds.
    context_size:
        Minimum number of points for pipeline to be estimated.
    ts:
        Dataset which will be used for estimation.
    method_kwargs:
        Additional arguments for methods that impact number of folds.

    Returns
    -------
    :
        Number of folds.

    Raises
    ------
    ValueError:
        If ``context_size`` is not positive.
    ValueError:
        If ``method_name`` is not one of the supported method names.
    ValueError:
        If ``ts`` is not given when estimating for backtest method.
    """
    if context_size < 1:
        raise ValueError("Pipeline `context_size` parameter must be positive integer!")

    # Validate the method name before looking at `ts`, so that a misspelled name
    # is reported as such instead of as a missing `ts` for backtest.
    method = MethodsWithFolds(method_name)

    if method == MethodsWithFolds.forecast:
        return _max_n_folds_forecast(pipeline=pipeline, context_size=context_size, ts=ts)

    if ts is None:
        raise ValueError("Parameter `ts` is required when estimating for backtest method")

    return _max_n_folds_backtest(pipeline=pipeline, context_size=context_size, ts=ts, **method_kwargs)
8 changes: 8 additions & 0 deletions tests/test_commands/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df


Expand Down Expand Up @@ -171,3 +172,10 @@ def start_timestamp_forecast_omegaconf_path():
tmp.flush()
yield Path(tmp.name)
tmp.close()


@pytest.fixture
def empty_ts():
    """Provide a ``TSDataset`` with freq "D" built from a dataframe that has no rows."""
    column_names = ("segment", "timestamp", "target")
    raw_df = pd.DataFrame({name: [] for name in column_names})
    converted_df = TSDataset.to_dataset(df=raw_df)
    return TSDataset(df=converted_df, freq="D")
Loading

1 comment on commit 8e8c0f6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.