Skip to content

add RMSE metric & rmse functional metric #1051

Merged
merged 3 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- `RMSE` metric & `rmse` functional metric
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a link to the current PR in the format "([#<pr number>](link))"? See the raw view of the CHANGELOG.md file for examples.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

-
-
-
Expand Down
2 changes: 2 additions & 0 deletions etna/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from etna.metrics.base import Metric
from etna.metrics.base import MetricAggregationMode
from etna.metrics.functional_metrics import mape
from etna.metrics.functional_metrics import rmse
from etna.metrics.functional_metrics import sign
from etna.metrics.functional_metrics import smape
from etna.metrics.intervals_metrics import Coverage
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand Down
32 changes: 32 additions & 0 deletions etna/metrics/functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union

import numpy as np
from sklearn.metrics import mean_squared_error

ArrayLike = List[Union[float, List[float]]]

Expand Down Expand Up @@ -112,3 +113,34 @@ def sign(y_true: ArrayLike, y_pred: ArrayLike) -> float:
raise ValueError("Shapes of the labels must be the same")

return np.mean(np.sign(y_true_array - y_pred_array))


def rmse(y_true: ArrayLike, y_pred: ArrayLike) -> float:
    r"""Root mean squared error metric.

    .. math::
        RMSE(y\_true, y\_pred) = \sqrt{\frac{\sum_{i=0}^{n-1}{(y\_true_i - y\_pred_i)^2}}{n}}

    For inputs of shape (n_samples, n_outputs) the RMSE is computed for every
    output separately and the per-output values are averaged uniformly.

    Parameters
    ----------
    y_true:
        array-like of shape (n_samples,) or (n_samples, n_outputs)

        Ground truth (correct) target values.

    y_pred:
        array-like of shape (n_samples,) or (n_samples, n_outputs)

        Estimated target values.

    Returns
    -------
    float
        A floating point value (the best value is 0.0).

    Raises
    ------
    ValueError:
        If the shapes of ``y_true`` and ``y_pred`` differ or the inputs are empty.
    """
    y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

    # Compare the full shapes, not only the number of dimensions: arrays of
    # equal rank but different sizes must be rejected here rather than being
    # silently broadcast against each other.
    if y_true_array.shape != y_pred_array.shape:
        raise ValueError("Shapes of the labels must be the same")

    # The mean of an empty array is NaN; fail loudly instead.
    if y_true_array.size == 0:
        raise ValueError("Labels must contain at least one sample")

    # Computed with NumPy directly so the result does not depend on the
    # ``squared`` keyword of ``sklearn.metrics.mean_squared_error``, which is
    # not available in every scikit-learn release.
    squared_errors = (y_true_array - y_pred_array) ** 2
    # Reduce over the samples axis first, take the root per output, then
    # average the per-output RMSE values.
    per_output_rmse = np.sqrt(np.mean(squared_errors, axis=0))
    return float(np.mean(per_output_rmse))
32 changes: 31 additions & 1 deletion etna/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import Metric
Expand Down Expand Up @@ -68,6 +69,35 @@ def greater_is_better(self) -> bool:
return False


class RMSE(Metric):
    r"""Root mean squared error metric with multi-segment computation support.

    .. math::
        RMSE(y\_true, y\_pred) = \sqrt{\frac{\sum_{i=0}^{n-1}{(y\_true_i - y\_pred_i)^2}}{n}}

    Notes
    -----
    You can read more about logic of multi-segment metrics in Metric docs.
    """

    def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
        """Init metric.

        Parameters
        ----------
        mode: 'macro' or 'per-segment'
            metrics aggregation mode
        kwargs:
            metric's computation arguments
        """
        # The functional ``rmse`` is passed as the per-segment computation.
        super().__init__(mode=mode, metric_fn=rmse, **kwargs)

    @property
    def greater_is_better(self) -> bool:
        """Whether higher metric value is better."""
        # RMSE is an error measure: lower values are better.
        return False


class R2(Metric):
"""Coefficient of determination metric with multi-segment computation support.

Expand Down Expand Up @@ -242,4 +272,4 @@ def greater_is_better(self) -> None:
return None


__all__ = ["MAE", "MSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
5 changes: 3 additions & 2 deletions tests/test_metrics/test_functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape

Expand All @@ -27,7 +28,7 @@ def y_pred_1d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0), (sign, -1)),
)
def test_all_1d_metrics(metric, right_metrics_value, y_true_1d, y_pred_1d):
assert round(metric(y_true_1d, y_pred_1d), 10) == right_metrics_value
Expand All @@ -52,7 +53,7 @@ def y_pred_2d():

@pytest.mark.parametrize(
"metric, right_metrics_value",
((mae, 1), (mse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
((mae, 1), (mse, 1), (rmse, 1), (mape, 100), (smape, 66.6666666667), (medae, 1), (r2_score, 0.0), (sign, -1)),
)
def test_all_2d_metrics(metric, right_metrics_value, y_true_2d, y_pred_2d):
assert round(metric(y_true_2d, y_pred_2d), 10) == right_metrics_value
19 changes: 12 additions & 7 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.metrics import mse
from etna.metrics import msle
from etna.metrics import r2_score
from etna.metrics import rmse
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics.base import MetricAggregationMode
Expand All @@ -16,6 +17,7 @@
from etna.metrics.metrics import MSE
from etna.metrics.metrics import MSLE
from etna.metrics.metrics import R2
from etna.metrics.metrics import RMSE
from etna.metrics.metrics import SMAPE
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import Sign
Expand All @@ -28,6 +30,7 @@
(
(MAE, "MAE", {}, ""),
(MSE, "MSE", {}, ""),
(RMSE, "RMSE", {}, ""),
(MedAE, "MedAE", {}, ""),
(MSLE, "MSLE", {}, ""),
(MAPE, "MAPE", {}, ""),
Expand All @@ -50,7 +53,7 @@ def test_repr(metric_class, metric_class_repr, metric_params, param_repr):

@pytest.mark.parametrize(
"metric_class",
(MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
(MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign),
)
def test_name_class_name(metric_class):
"""Check metrics name property without changing its during inheritance"""
Expand All @@ -74,7 +77,7 @@ def test_name_repr(metric_class):
assert metric_name == true_name


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign))
def test_metrics_macro(metric_class, train_test_dfs):
"""Check metrics interface in 'macro' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -83,7 +86,7 @@ def test_metrics_macro(metric_class, train_test_dfs):
assert isinstance(value, float)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_per_segment(metric_class, train_test_dfs):
"""Check metrics interface in 'per-segment' mode"""
forecast_df, true_df = train_test_dfs
Expand All @@ -94,14 +97,14 @@ def test_metrics_per_segment(metric_class, train_test_dfs):
assert segment in value


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_metrics_invalid_aggregation(metric_class):
"""Check metrics behavior in case of invalid aggregation mode"""
with pytest.raises(NotImplementedError):
_ = metric_class(mode="a")


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid timeranges"""
forecast_df, true_df = two_dfs_with_different_timestamps
Expand All @@ -110,7 +113,7 @@ def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
Expand All @@ -119,7 +122,7 @@ def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize("metric_class", (MAE, MSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, DummyMetric))
def test_invalid_segments_target(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
forecast_df, true_df = train_test_dfs
Expand All @@ -134,6 +137,7 @@ def test_invalid_segments_target(metric_class, train_test_dfs):
(
(MAE, mae),
(MSE, mse),
(RMSE, rmse),
(MedAE, medae),
(MSLE, msle),
(MAPE, mape),
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
(
(MAE(), False),
(MSE(), False),
(RMSE(), False),
(MedAE(), False),
(MSLE(), False),
(MAPE(), False),
Expand Down