Skip to content

Speed up metrics computation by optimizing segment validation #1338

Merged
merged 7 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
-
- Add sorting by timestamp before the fit in `CatBoostPerSegmentModel` and `CatBoostMultiSegmentModel` ([#1337](https://github.com/tinkoff-ai/etna/pull/1337))
- Speed up metrics computation by optimizing segment validation, forbid NaNs during metrics computation ([#1338](https://github.com/tinkoff-ai/etna/pull/1338))
- Unify errors, warnings and checks in models ([#1312](https://github.com/tinkoff-ai/etna/pull/1312))
- Remove upper limitation on version of numba ([#1321](https://github.com/tinkoff-ai/etna/pull/1321))

Expand Down
97 changes: 71 additions & 26 deletions etna/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def name(self) -> str:
return self.__class__.__name__

@staticmethod
def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
"""
Check if all the segments from ``y_true`` are in ``y_pred`` and vice versa.
def _validate_segments(y_true: TSDataset, y_pred: TSDataset):
"""Check that segments in ``y_true`` and ``y_pred`` are the same.

Parameters
----------
Expand All @@ -125,9 +124,7 @@ def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
Raises
------
ValueError:
if there are mismatches in y_true and y_pred segments,
ValueError:
if one of segments in y_true or y_pred doesn't contain 'target' column.
if there are mismatches in y_true and y_pred segments
"""
segments_true = set(y_true.df.columns.get_level_values("segment"))
segments_pred = set(y_pred.df.columns.get_level_values("segment"))
Expand All @@ -144,33 +141,78 @@ def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
f"There are segments in y_true that are not in y_pred, for example: "
f"{', '.join(list(true_diff_pred)[:5])}"
)
for segment in segments_true:

@staticmethod
def _validate_target_columns(y_true: TSDataset, y_pred: TSDataset):
"""Check that all the segments from ``y_true`` and ``y_pred`` has 'target' column.

Parameters
----------
y_true:
y_true dataset
y_pred:
y_pred dataset

Raises
------
ValueError:
if one of segments in y_true or y_pred doesn't contain 'target' column.
"""
segments = set(y_true.df.columns.get_level_values("segment"))

for segment in segments:
for name, dataset in zip(("y_true", "y_pred"), (y_true, y_pred)):
if "target" not in dataset.loc[:, segment].columns:
if (segment, "target") not in dataset.columns:
raise ValueError(
f"All the segments in {name} should contain 'target' column. Segment {segment} doesn't."
)

@staticmethod
def _validate_timestamp_columns(timestamp_true: pd.Series, timestamp_pred: pd.Series):
"""
Check that ``y_true`` and ``y_pred`` have the same timestamp.
def _validate_index(y_true: TSDataset, y_pred: TSDataset):
"""Check that ``y_true`` and ``y_pred`` have the same timestamps.

Parameters
----------
timestamp_true:
y_true's timestamp column
timestamp_pred:
y_pred's timestamp column
y_true:
y_true dataset
y_pred:
y_pred dataset

Raises
------
ValueError:
If there are mismatches in ``y_true`` and ``y_pred`` timestamps
"""
if set(timestamp_pred) != set(timestamp_true):
if not y_true.index.equals(y_pred.index):
raise ValueError("y_true and y_pred have different timestamps")

@staticmethod
def _validate_nans(y_true: TSDataset, y_pred: TSDataset):
"""Check that ``y_true`` and ``y_pred`` doesn't have NaNs.

Parameters
----------
y_true:
y_true dataset
y_pred:
y_pred dataset

Raises
------
ValueError:
If there are NaNs in ``y_true`` or ``y_pred``
"""
df_true = y_true.df.loc[:, pd.IndexSlice[:, "target"]]
df_pred = y_pred.df.loc[:, pd.IndexSlice[:, "target"]]

df_true_isna = df_true.isna().any().any()
if df_true_isna > 0:
raise ValueError("There are NaNs in y_true")

df_pred_isna = df_pred.isna().any().any()
if df_pred_isna > 0:
raise ValueError("There are NaNs in y_pred")

@staticmethod
def _macro_average(metrics_per_segments: Dict[str, float]) -> Union[float, Dict[str, float]]:
"""
Expand Down Expand Up @@ -226,18 +268,21 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
metric's value aggregated over segments or not (depends on mode)
"""
self._log_start()
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)

df_true = y_true[:, :, "target"].sort_index(axis=1)
df_pred = y_pred[:, :, "target"].sort_index(axis=1)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
metrics_per_segment[segment] = self.metric_fn(
y_true=y_true[:, segment, "target"].values, y_pred=y_pred[:, segment, "target"].values, **self.kwargs
)
segments = df_true.columns.get_level_values("segment").unique()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be sorted as index in the dataframe is sorted?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be we need test for such behaviour(input datasets have unsorted segments)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it will be sorted because we sorted index of df_true. Also we have a guarantee that unique returns values in the order of its appearance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll try to add a test on this.


for i, segment in enumerate(segments):
cur_y_true = df_true.iloc[:, i]
cur_y_pred = df_pred.iloc[:, i]
metrics_per_segment[segment] = self.metric_fn(y_true=cur_y_true, y_pred=cur_y_pred, **self.kwargs)
metrics = self._aggregate_metrics(metrics_per_segment)
return metrics

Expand Down
18 changes: 8 additions & 10 deletions etna/metrics/intervals_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,15 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
-------
metric's value aggregated over segments or not (depends on mode)
"""
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
upper_quantile_flag = y_true[:, segment, "target"] <= y_pred[:, segment, f"target_{self.quantiles[1]:.4g}"]
lower_quantile_flag = y_true[:, segment, "target"] >= y_pred[:, segment, f"target_{self.quantiles[0]:.4g}"]

Expand Down Expand Up @@ -135,16 +134,15 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
-------
metric's value aggregated over segments or not (depends on mode)
"""
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
upper_quantile = y_pred[:, segment, f"target_{self.quantiles[1]:.4g}"]
lower_quantile = y_pred[:, segment, f"target_{self.quantiles[0]:.4g}"]

Expand Down
71 changes: 60 additions & 11 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -117,34 +120,60 @@ def test_metrics_invalid_aggregation(metric_class):
@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid timeranges"""
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
metric = metric_class()
with pytest.raises(ValueError, match="There are segments in .* that are not in .*"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_target_columns(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
forecast_df, true_df = train_test_dfs
columns = forecast_df.df.columns.to_list()
columns[0] = ("segment_1", "not_target")
forecast_df.df.columns = pd.MultiIndex.from_tuples(columns, names=["segment", "feature"])
metric = metric_class()
with pytest.raises(ValueError, match="All the segments in .* should contain 'target' column"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_index(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid index"""
forecast_df, true_df = two_dfs_with_different_timestamps
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="y_true and y_pred have different timestamps"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
def test_invalid_nans_pred(metric_class, train_test_dfs):
"""Check metrics behavior in case of nans in prediction."""
forecast_df, true_df = train_test_dfs
forecast_df.df.iloc[0, 0] = np.NaN
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="There are NaNs in y_pred"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_segments_target(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
def test_invalid_nans_true(metric_class, train_test_dfs):
"""Check metrics behavior in case of nans in true values."""
forecast_df, true_df = train_test_dfs
forecast_df.df.drop(columns=[("segment_1", "target")], inplace=True)
true_df.df.iloc[0, 0] = np.NaN
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="There are NaNs in y_true"):
_ = metric(y_true=true_df, y_pred=forecast_df)


Expand Down Expand Up @@ -181,6 +210,26 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
assert value == true_metric_value


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_metric_values_with_changed_segment_order(metric_class, train_test_dfs):
forecast_df, true_df = train_test_dfs
forecast_df_new, true_df_new = deepcopy(train_test_dfs)
segments = np.array(forecast_df.segments)

forecast_segment_order = segments[[3, 2, 0, 1, 4]]
forecast_df_new.df = forecast_df_new.df.loc[:, pd.IndexSlice[forecast_segment_order, :]]
true_segment_order = segments[[4, 1, 3, 2, 0]]
true_df_new.df = true_df_new.df.loc[:, pd.IndexSlice[true_segment_order, :]]

metric = metric_class(mode="per-segment")
metrics_initial = metric(y_pred=forecast_df, y_true=true_df)
metrics_changed_order = metric(y_pred=forecast_df_new, y_true=true_df_new)

assert metrics_initial == metrics_changed_order


@pytest.mark.parametrize(
"metric, greater_is_better",
(
Expand Down