Skip to content

Commit

Permalink
changed names
Browse files Browse the repository at this point in the history
  • Loading branch information
brsnw250 committed Mar 15, 2023
1 parent 65576c3 commit 4a67b12
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,26 +322,30 @@ def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
self._check_mul_components()
self._check_df(df)

level = fit_result.level.values
trend = fit_result.trend.values
season = fit_result.season.values

horizon = df["timestamp"].nunique()
horizon_steps = np.arange(1, horizon + 1)

components = {"target_component_level": fit_result.level[-1] * np.ones(horizon)}
components = {"target_component_level": level[-1] * np.ones(horizon)}

if model.trend is not None:
t = horizon_steps.copy()

if model.damped_trend:
t = np.cumsum(fit_result.params["damping_trend"] ** t)

components["target_component_trend"] = fit_result.trend[-1] * t
components["target_component_trend"] = trend[-1] * t

if model.seasonal is not None:
last_period = len(fit_result.season)
last_period = len(season)

seasonal_periods = fit_result.model.seasonal_periods
k = horizon_steps // seasonal_periods

components["target_component_seasonality"] = fit_result.season.values[
components["target_component_seasonality"] = season[
last_period + horizon_steps - seasonal_periods * (k + 1) - 1
]

Expand Down Expand Up @@ -374,14 +378,16 @@ def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
self._check_mul_components()
self._check_df(df)

level = fit_result.level.values
trend = fit_result.trend.values
season = fit_result.season.values

components = {
"target_component_level": np.concatenate(
[[fit_result.params["initial_level"]], fit_result.level.values[:-1]]
),
"target_component_level": np.concatenate([[fit_result.params["initial_level"]], level[:-1]]),
}

if model.trend is not None:
trend = np.concatenate([[fit_result.params["initial_trend"]], fit_result.trend.values[:-1]])
trend = np.concatenate([[fit_result.params["initial_trend"]], trend[:-1]])

if model.damped_trend:
trend *= fit_result.params["damping_trend"]
Expand All @@ -391,7 +397,7 @@ def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
if model.seasonal is not None:
seasonal_periods = model.seasonal_periods
components["target_component_seasonality"] = np.concatenate(
[fit_result.params["initial_seasons"], fit_result.season.values[:-seasonal_periods]]
[fit_result.params["initial_seasons"], season[:-seasonal_periods]]
)

components_df = pd.DataFrame(data=components)
Expand Down

0 comments on commit 4a67b12

Please sign in to comment.