Skip to content

Fix plot_backtest and plot_backtest_interactive on one-step forecast #1260

Merged
merged 3 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Fixed
-
- Fix `plot_backtest` and `plot_backtest_interactive` on one-step forecast ([1260](https://github.com/tinkoff-ai/etna/pull/1260))
- Fix `BaseReconciliator` to work on `pandas==1.1.5` ([#1229](https://github.com/tinkoff-ai/etna/pull/1229))
- Fix `TSDataset.make_future` to handle hierarchy, quantiles, target components ([#1248](https://github.com/tinkoff-ai/etna/pull/1248))
- Fix warning during creation of `ResampleWithDistributionTransform` ([#1230](https://github.com/tinkoff-ai/etna/pull/1230))
Expand Down
14 changes: 8 additions & 6 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def plot_backtest(
forecast_start = forecast_df.index.min()
history_df = df[df.index < forecast_start]
backtest_df = df[df.index >= forecast_start]
freq_timedelta = df.index[1] - df.index[0]

# prepare colors
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
Expand All @@ -257,6 +256,8 @@ def plot_backtest(
segment_history_df = history_df[segment]
segment_forecast_df = forecast_df[segment]
is_full_folds = set(segment_backtest_df.index) == set(segment_forecast_df.index)
single_point_forecast = len(segment_backtest_df) == 1
Mr-Geekman marked this conversation as resolved.
Show resolved Hide resolved
draw_only_lines = is_full_folds and not single_point_forecast

# plot history
if history_len == "all":
Expand All @@ -270,13 +271,13 @@ def plot_backtest(
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = end_fold + freq_timedelta
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]

# draw test
backtest_df_slice_fold = segment_backtest_df[start_fold:end_fold_exclusive]
ax[i].plot(backtest_df_slice_fold.index, backtest_df_slice_fold.target, color=lines_colors["test"])

if is_full_folds:
if draw_only_lines:
# draw forecast
forecast_df_slice_fold = segment_forecast_df[start_fold:end_fold_exclusive]
ax[i].plot(forecast_df_slice_fold.index, forecast_df_slice_fold.target, color=lines_colors["forecast"])
Expand Down Expand Up @@ -360,7 +361,6 @@ def plot_backtest_interactive(
forecast_start = forecast_df.index.min()
history_df = df[df.index < forecast_start]
backtest_df = df[df.index >= forecast_start]
freq_timedelta = df.index[1] - df.index[0]

# prepare colors
colors = plotly.colors.qualitative.Dark24
Expand All @@ -371,6 +371,8 @@ def plot_backtest_interactive(
segment_history_df = history_df[segment]
segment_forecast_df = forecast_df[segment]
is_full_folds = set(segment_backtest_df.index) == set(segment_forecast_df.index)
single_point_forecast = len(segment_backtest_df) == 1
draw_only_lines = is_full_folds and not single_point_forecast

# plot history
if history_len == "all":
Expand All @@ -395,7 +397,7 @@ def plot_backtest_interactive(
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = end_fold + freq_timedelta
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]

# draw test
backtest_df_slice_fold = segment_backtest_df[start_fold:end_fold_exclusive]
Expand All @@ -412,7 +414,7 @@ def plot_backtest_interactive(
)
)

if is_full_folds:
if draw_only_lines:
# draw forecast
forecast_df_slice_fold = segment_forecast_df[start_fold:end_fold_exclusive]
fig.add_trace(
Expand Down