Skip to content

add RMSE metric & rmse functional metric #1051

Merged
merged 3 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051))
-
-
-
Expand Down
2 changes: 2 additions & 0 deletions etna/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from etna.metrics.base import Metric
from etna.metrics.base import MetricAggregationMode
from etna.metrics.functional_metrics import mape
from etna.metrics.functional_metrics import rmse
from etna.metrics.functional_metrics import sign
from etna.metrics.functional_metrics import smape
from etna.metrics.intervals_metrics import Coverage
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand Down
5 changes: 5 additions & 0 deletions etna/metrics/functional_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import partial
from typing import List
from typing import Union

import numpy as np
from sklearn.metrics import mean_squared_error as mse

ArrayLike = List[Union[float, List[float]]]

Expand Down Expand Up @@ -112,3 +114,6 @@ def sign(y_true: ArrayLike, y_pred: ArrayLike) -> float:
raise ValueError("Shapes of the labels must be the same")

return np.mean(np.sign(y_true_array - y_pred_array))


rmse = partial(mse, squared=False)
32 changes: 31 additions & 1 deletion etna/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import Metric
Expand Down Expand Up @@ -68,6 +69,35 @@ def greater_is_better(self) -> bool:
return False


class RMSE(Metric):
"""Root mean squared error metric with multi-segment computation support.

.. math::
RMSE(y\_true, y\_pred) = \\sqrt\\frac{\\sum_{i=0}^{n-1}{(y\_true_i - y\_pred_i)^2}}{n}

Notes
-----
You can read more about logic of multi-segment metrics in Metric docs.
"""

def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""Init metric.

Parameters
----------
mode: 'macro' or 'per-segment'
metrics aggregation mode
kwargs:
metric's computation arguments
"""
super().__init__(mode=mode, metric_fn=rmse, **kwargs)
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class R2(Metric):
"""Coefficient of determination metric with multi-segment computation support.

Expand Down Expand Up @@ -242,4 +272,4 @@ def greater_is_better(self) -> None:
return None


__all__ = ["MAE", "MSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
5 changes: 3 additions & 2 deletions tests/test_metrics/test_functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape

Expand All @@ -27,7 +28,7 @@ def y_pred_1d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
)
def test_all_1d_metrics(metric, right_metrics_value, y_true_1d, y_pred_1d):
assert round(metric(y_true_1d, y_pred_1d), 10) == right_metrics_value
Expand All @@ -52,7 +53,7 @@ def y_pred_2d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
)
def test_all_2d_metrics(metric, right_metrics_value, y_true_2d, y_pred_2d):
assert round(metric(y_true_2d, y_pred_2d), 10) == right_metrics_value
19 changes: 12 additions & 7 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import MetricAggregationMode
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand All @@ -28,6 +30,7 @@
(
(MAE, "MAE", {}, ""),
(MSE, "MSE", {}, ""),
(RMSE, "RMSE", {}, ""),
(MedAE, "MedAE", {}, ""),
(MSLE, "MSLE", {}, ""),
(MAPE, "MAPE", {}, ""),
Expand All @@ -50,7 +53,7 @@ def test_repr(metric_class, metric_class_repr, metric_params, param_repr):

@pytest.mark.parametrize(
"metric_class",
(MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
(MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
)
def test_name_class_name(metric_class):
"""Check metrics name property without changing its during inheritance"""
Expand All @@ -74,7 +77,7 @@ def test_name_repr(metric_class):
assert metric_name == true_name


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
def test_metrics_macro(metric_class, train_test_dfs):
"""Check metrics interface in 'macro' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -83,7 +86,7 @@ def test_metrics_macro(metric_class, train_test_dfs):
assert isinstance(value, float)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_per_segment(metric_class, train_test_dfs):
"""Check metrics interface in 'per-segment' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -94,14 +97,14 @@ def test_metrics_per_segment(metric_class, train_test_dfs):
assert segment in value


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_invalid_aggregation(metric_class):
"""Check metrics behavior in case of invalid aggregation mode"""
with pytest.raises(NotImplementedError):
_ = metric_class(mode="a")


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid timeranges"""
forecast_df, true_df = two_dfs_with_different_timestamps
Expand All @@ -110,7 +113,7 @@ def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
Expand All @@ -119,7 +122,7 @@ def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments_target(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
forecast_df, true_df = train_test_dfs
Expand All @@ -134,6 +137,7 @@ def test_invalid_segments_target(metric_class, train_test_dfs):
(
(MAE, mae),
(MSE, mse),
(RMSE, rmse),
(MedAE, medae),
(MSLE, msle),
(MAPE, mape),
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
(
(MAE(), False),
(MSE(), False),
(RMSE(), False),
(MedAE(), False),
(MSLE(), False),
(MAPE(), False),
Expand Down