Skip to content

Fix plot_backtest and plot_backtest_interactive on one-step forecast #1260

Merged
merged 3 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Fixed
-
- Fix `plot_backtest` and `plot_backtest_interactive` on one-step forecast ([1260](https://github.com/tinkoff-ai/etna/pull/1260))
- Fix `BaseReconciliator` to work on `pandas==1.1.5` ([#1229](https://github.com/tinkoff-ai/etna/pull/1229))
- Fix `TSDataset.make_future` to handle hierarchy, quantiles, target components ([#1248](https://github.com/tinkoff-ai/etna/pull/1248))
- Fix warning during creation of `ResampleWithDistributionTransform` ([#1230](https://github.com/tinkoff-ai/etna/pull/1230))
Expand Down
14 changes: 8 additions & 6 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def plot_backtest(
forecast_start = forecast_df.index.min()
history_df = df[df.index < forecast_start]
backtest_df = df[df.index >= forecast_start]
freq_timedelta = df.index[1] - df.index[0]

# prepare colors
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
Expand All @@ -257,6 +256,8 @@ def plot_backtest(
segment_history_df = history_df[segment]
segment_forecast_df = forecast_df[segment]
is_full_folds = set(segment_backtest_df.index) == set(segment_forecast_df.index)
single_point_forecast = len(segment_backtest_df) == 1
Mr-Geekman marked this conversation as resolved.
Show resolved Hide resolved
draw_only_lines = is_full_folds and not single_point_forecast

# plot history
if history_len == "all":
Expand All @@ -270,13 +271,13 @@ def plot_backtest(
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = end_fold + freq_timedelta
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]

# draw test
backtest_df_slice_fold = segment_backtest_df[start_fold:end_fold_exclusive]
ax[i].plot(backtest_df_slice_fold.index, backtest_df_slice_fold.target, color=lines_colors["test"])

if is_full_folds:
if draw_only_lines:
# draw forecast
forecast_df_slice_fold = segment_forecast_df[start_fold:end_fold_exclusive]
ax[i].plot(forecast_df_slice_fold.index, forecast_df_slice_fold.target, color=lines_colors["forecast"])
Expand Down Expand Up @@ -360,7 +361,6 @@ def plot_backtest_interactive(
forecast_start = forecast_df.index.min()
history_df = df[df.index < forecast_start]
backtest_df = df[df.index >= forecast_start]
freq_timedelta = df.index[1] - df.index[0]

# prepare colors
colors = plotly.colors.qualitative.Dark24
Expand All @@ -371,6 +371,8 @@ def plot_backtest_interactive(
segment_history_df = history_df[segment]
segment_forecast_df = forecast_df[segment]
is_full_folds = set(segment_backtest_df.index) == set(segment_forecast_df.index)
single_point_forecast = len(segment_backtest_df) == 1
draw_only_lines = is_full_folds and not single_point_forecast

# plot history
if history_len == "all":
Expand All @@ -395,7 +397,7 @@ def plot_backtest_interactive(
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = end_fold + freq_timedelta
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]

# draw test
backtest_df_slice_fold = segment_backtest_df[start_fold:end_fold_exclusive]
Expand All @@ -412,7 +414,7 @@ def plot_backtest_interactive(
)
)

if is_full_folds:
if draw_only_lines:
# draw forecast
forecast_df_slice_fold = segment_forecast_df[start_fold:end_fold_exclusive]
fig.add_trace(
Expand Down