Skip to content

Commit

Permalink
Add DeadlineMovingAverageModel (#827)
Browse files Browse the repository at this point in the history
* DeadlineMovingAverageModel

* some fixes for deadline model

* add CHANGELOG.md

* add test for pipeline.backtest

* fix changelog conflict

* fix lint mistake

* fix behavior of deadline model in big horizons

* fix docs and lint mistakes

* fix docstrings and add test for big horizon

* fix linter mistake

* add test for big horizon

* fix lint mistake

* update CHANGELOG.md

Co-authored-by: ext.ytarasyuk <[email protected]>
  • Loading branch information
DBcreator and ext.ytarasyuk authored Aug 5, 2022
1 parent d76a11b commit 71a19db
Show file tree
Hide file tree
Showing 4 changed files with 422 additions and 7 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
-
-
-
- `DeadlineMovingAverageModel` ([#827](https://github.com/tinkoff-ai/etna/pull/827))
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
-
-
Expand Down Expand Up @@ -494,4 +494,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Distribution plot
- Anomalies (Outliers) plot
- Backtest (CrossValidation) plot
- Forecast plot
- Forecast plot
1 change: 1 addition & 0 deletions etna/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.models.catboost import CatBoostModelPerSegment
from etna.models.catboost import CatBoostMultiSegmentModel
from etna.models.catboost import CatBoostPerSegmentModel
from etna.models.deadline_ma import DeadlineMovingAverageModel
from etna.models.holt_winters import HoltModel
from etna.models.holt_winters import HoltWintersModel
from etna.models.holt_winters import SimpleExpSmoothingModel
Expand Down
169 changes: 169 additions & 0 deletions etna/models/deadline_ma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import warnings
from enum import Enum
from typing import Dict
from typing import List

import numpy as np
import pandas as pd

from etna.models.base import PerSegmentModel


class SeasonalityMode(Enum):
"""Enum for seasonality mode for DeadlineMovingAverageModel."""

month = "month"
year = "year"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} seasonality allowed"
)


class _DeadlineMovingAverageModel:
"""Moving average model that uses exact previous dates to predict."""

def __init__(self, window: int = 3, seasonality: str = "month"):
"""
Initialize deadline moving average model.
Length of remembered tail of series is equal to the number of ``window`` months or years, depending on the ``seasonality``.
Parameters
----------
window: int
Number of values taken for forecast for each point.
seasonality: str
Only allowed monthly or annual seasonality.
"""
self.name = "target"
self.window = window
self.seasonality = SeasonalityMode(seasonality)
self.freqs_available = {"H", "D"}

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_DeadlineMovingAverageModel":
"""
Fit DeadlineMovingAverageModel model.
Parameters
----------
df: pd.DataFrame
Data to fit on
regressors:
List of the columns with regressors(ignored in this model)
Raises
------
ValueError
If freq of dataframe is not supported
ValueError
If series is too short for chosen shift value
Returns
-------
:
Fitted model
"""
freq = pd.infer_freq(df["timestamp"])
if freq not in self.freqs_available:
raise ValueError(f"{freq} is not supported! Use daily or hourly frequency!")

if set(df.columns) != {"timestamp", "target"}:
warnings.warn(
message=f"{type(self).__name__} does not work with any exogenous series or features. "
f"It uses only target series for predict/\n "
)
targets = df["target"]
timestamps = df["timestamp"]

if self.seasonality == SeasonalityMode.month:
first_index = timestamps.iloc[-1] - pd.DateOffset(months=self.window)

elif self.seasonality == SeasonalityMode.year:
first_index = timestamps.iloc[-1] - pd.DateOffset(years=self.window)

if first_index < timestamps.iloc[0]:
raise ValueError(
"Given series is too short for chosen shift value. Try lower shift value, or give" "longer series."
)

self.series = targets.loc[timestamps >= first_index]
self.timestamps = timestamps.loc[timestamps >= first_index]
self.shift = len(self.series)

return self

def predict(self, df: pd.DataFrame) -> np.ndarray:
"""
Compute predictions from a DeadlineMovingAverageModel.
Parameters
----------
df: pd.DataFrame
Used only for getting the horizon of forecast and timestamps.
Returns
-------
:
Array with predictions.
"""
timestamps = df["timestamp"]
index = pd.date_range(start=self.timestamps.iloc[0], end=timestamps.iloc[-1])
res = np.append(self.series.values, np.zeros(len(df)))
res = pd.DataFrame(res)
res.index = index
for i in range(len(self.series), len(res)):
for w in range(1, self.window + 1):
if self.seasonality == SeasonalityMode.month:
prev_date = res.index[i] - pd.DateOffset(months=w)

elif self.seasonality == SeasonalityMode.year:
prev_date = res.index[i] - pd.DateOffset(years=w)
if prev_date <= self.timestamps.iloc[-1]:
res.loc[index[i]] += self.series.loc[self.timestamps == prev_date].values
else:
res.loc[index[i]] += res.loc[prev_date].values

res.loc[index[i]] = res.loc[index[i]] / self.window

res = res.values.reshape(
len(res),
)

return res[-len(df) :]


class DeadlineMovingAverageModel(PerSegmentModel):
"""Moving average model that uses exact previous dates to predict."""

def __init__(self, window: int = 3, seasonality: str = "month"):
"""
Initialize deadline moving average model.
Parameters
----------
window: int
Number of values taken for forecast for each point.
seasonality: str
Only allowed monthly or annual seasonality.
"""
self.window = window
self.seasonality = seasonality
super(DeadlineMovingAverageModel, self).__init__(
base_model=_DeadlineMovingAverageModel(window=window, seasonality=seasonality)
)

def get_model(self) -> Dict[str, "DeadlineMovingAverageModel"]:
"""Get internal model.
Returns
-------
:
Internal model
"""
return self._get_model()


__all__ = ["DeadlineMovingAverageModel"]
Loading

1 comment on commit 71a19db

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.