Skip to content

add RMSE metric & rmse functional metric #1051

Merged
merged 3 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051))
-
-
-
Expand Down
2 changes: 2 additions & 0 deletions etna/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from etna.metrics.base import Metric
from etna.metrics.base import MetricAggregationMode
from etna.metrics.functional_metrics import mape
from etna.metrics.functional_metrics import rmse
from etna.metrics.functional_metrics import sign
from etna.metrics.functional_metrics import smape
from etna.metrics.intervals_metrics import Coverage
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand Down
5 changes: 5 additions & 0 deletions etna/metrics/functional_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import partial
from typing import List
from typing import Union

import numpy as np
from sklearn.metrics import mean_squared_error as mse

ArrayLike = List[Union[float, List[float]]]

Expand Down Expand Up @@ -112,3 +114,6 @@ def sign(y_true: ArrayLike, y_pred: ArrayLike) -> float:
raise ValueError("Shapes of the labels must be the same")

return np.mean(np.sign(y_true_array - y_pred_array))


# Root mean squared error as a functional metric: sklearn's
# ``mean_squared_error`` (imported in this module as ``mse``) with
# ``squared=False`` pre-bound through ``functools.partial``.
# NOTE(review): newer scikit-learn releases deprecate the ``squared`` argument
# in favour of a dedicated root-mean-squared-error function -- confirm against
# the scikit-learn version this project pins.
rmse = partial(mse, squared=False)
32 changes: 31 additions & 1 deletion etna/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import Metric
Expand Down Expand Up @@ -68,6 +69,35 @@ def greater_is_better(self) -> bool:
return False


class RMSE(Metric):
    r"""Root mean squared error metric with multi-segment computation support.

    .. math::
        RMSE(y\_true, y\_pred) = \sqrt{\frac{\sum_{i=0}^{n-1}{(y\_true_i - y\_pred_i)^2}}{n}}

    Notes
    -----
    You can read more about the logic of multi-segment metrics in the Metric docs.
    """

    # The docstring above is a raw string on purpose: the LaTeX formula is full
    # of backslashes, and sequences such as ``\_`` are invalid escapes in a
    # regular string literal (a warning on current Python versions).

    def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
        """Init metric.

        Parameters
        ----------
        mode: 'macro' or 'per-segment'
            metrics aggregation mode
        kwargs:
            metric's computation arguments; forwarded unchanged to the base
            ``Metric`` initializer together with the ``rmse`` metric function
        """
        super().__init__(mode=mode, metric_fn=rmse, **kwargs)

    @property
    def greater_is_better(self) -> bool:
        """Whether higher metric value is better.

        Returns
        -------
        bool
            Always ``False``: RMSE is an error measure, so lower values are better.
        """
        return False


class R2(Metric):
"""Coefficient of determination metric with multi-segment computation support.

Expand Down Expand Up @@ -242,4 +272,4 @@ def greater_is_better(self) -> None:
return None


__all__ = ["MAE", "MSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
5 changes: 3 additions & 2 deletions tests/test_metrics/test_functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape

Expand All @@ -27,7 +28,7 @@ def y_pred_1d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
)
def test_all_1d_metrics(metric, right_metrics_value, y_true_1d, y_pred_1d):
assert round(metric(y_true_1d, y_pred_1d), 10) == right_metrics_value
Expand All @@ -52,7 +53,7 @@ def y_pred_2d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
)
def test_all_2d_metrics(metric, right_metrics_value, y_true_2d, y_pred_2d):
assert round(metric(y_true_2d, y_pred_2d), 10) == right_metrics_value
19 changes: 12 additions & 7 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import MetricAggregationMode
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand All @@ -28,6 +30,7 @@
(
(MAE, "MAE", {}, ""),
(MSE, "MSE", {}, ""),
(RMSE, "RMSE", {}, ""),
(MedAE, "MedAE", {}, ""),
(MSLE, "MSLE", {}, ""),
(MAPE, "MAPE", {}, ""),
Expand All @@ -50,7 +53,7 @@ def test_repr(metric_class, metric_class_repr, metric_params, param_repr):

@pytest.mark.parametrize(
"metric_class",
(MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
(MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
)
def test_name_class_name(metric_class):
"""Check metrics name property without changing it during inheritance"""
Expand All @@ -74,7 +77,7 @@ def test_name_repr(metric_class):
assert metric_name == true_name


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
def test_metrics_macro(metric_class, train_test_dfs):
"""Check metrics interface in 'macro' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -83,7 +86,7 @@ def test_metrics_macro(metric_class, train_test_dfs):
assert isinstance(value, float)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_per_segment(metric_class, train_test_dfs):
"""Check metrics interface in 'per-segment' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -94,14 +97,14 @@ def test_metrics_per_segment(metric_class, train_test_dfs):
assert segment in value


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_invalid_aggregation(metric_class):
"""Check metrics behavior in case of invalid aggregation mode"""
with pytest.raises(NotImplementedError):
_ = metric_class(mode="a")


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid timeranges"""
forecast_df, true_df = two_dfs_with_different_timestamps
Expand All @@ -110,7 +113,7 @@ def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
Expand All @@ -119,7 +122,7 @@ def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments_target(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
forecast_df, true_df = train_test_dfs
Expand All @@ -134,6 +137,7 @@ def test_invalid_segments_target(metric_class, train_test_dfs):
(
(MAE, mae),
(MSE, mse),
(RMSE, rmse),
(MedAE, medae),
(MSLE, msle),
(MAPE, mape),
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
(
(MAE(), False),
(MSE(), False),
(RMSE(), False),
(MedAE(), False),
(MSLE(), False),
(MAPE(), False),
Expand Down