Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RelativeSquaredError #1765

Merged
merged 26 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b3ec758
Added functional and Module implementation of Relative Squared Error …
wbeardall May 8, 2023
22aded1
Updated docs
wbeardall May 8, 2023
55242a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2023
3581af6
Pre-commit check updates
wbeardall May 8, 2023
daddabf
Pre-commit check updates
wbeardall May 8, 2023
40431a7
Merge branch 'master' into master
Borda May 9, 2023
1f3aa69
Apply suggestions from code review
Borda May 9, 2023
d76ca38
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2023
8a6cf19
Merge branch 'master' into master
Borda May 9, 2023
190b88f
Apply suggestions from code review
Borda May 9, 2023
7c54e04
Merge branch 'master' into master
Borda May 9, 2023
dd50c25
Update src/torchmetrics/functional/regression/rse.py
justusschock May 9, 2023
7e21a87
changelog
SkafteNicki May 9, 2023
b215796
license
SkafteNicki May 9, 2023
80b2e96
Update src/torchmetrics/functional/regression/rse.py
SkafteNicki May 9, 2023
776e656
add plotting + docstring
SkafteNicki May 9, 2023
6439791
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2023
3ce09c5
fix typing
SkafteNicki May 9, 2023
4cefdf7
fix typing
SkafteNicki May 10, 2023
3304c41
Merge branch 'master' into master
SkafteNicki May 10, 2023
58e74b2
fix math
SkafteNicki May 10, 2023
025c27a
Merge branch 'master' into master
mergify[bot] May 10, 2023
ab6a9c8
fix broken test
SkafteNicki May 10, 2023
d44c169
Merge branch 'master' of https://github.com/wbeardall/torchmetrics in…
SkafteNicki May 10, 2023
3418385
try fix
SkafteNicki May 11, 2023
c50aa0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/regression/rse.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Relative Squared Error
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

############################
Relative Squared Error (RSE)
############################

Module Interface
________________

.. autoclass:: torchmetrics.RelativeSquaredError
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.relative_squared_error
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
MinkowskiDistance,
PearsonCorrCoef,
R2Score,
RelativeSquaredError,
SpearmanCorrCoef,
SymmetricMeanAbsolutePercentageError,
TweedieDevianceScore,
Expand Down Expand Up @@ -182,6 +183,7 @@
"Recall",
"RecallAtFixedPrecision",
"RelativeAverageSpectralError",
"RelativeSquaredError",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
minkowski_distance,
pearson_corrcoef,
r2_score,
relative_squared_error,
spearman_corrcoef,
symmetric_mean_absolute_percentage_error,
tweedie_deviance_score,
Expand Down Expand Up @@ -190,6 +191,7 @@
"r2_score",
"recall",
"relative_average_spectral_error",
"relative_squared_error",
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
from torchmetrics.functional.regression.rse import relative_squared_error
from torchmetrics.functional.regression.spearman import spearman_corrcoef
from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
Expand All @@ -45,6 +46,7 @@
"mean_absolute_percentage_error",
"minkowski_distance",
"r2_score",
"relative_squared_error",
"spearman_corrcoef",
"symmetric_mean_absolute_percentage_error",
"tweedie_deviance_score",
Expand Down
66 changes: 66 additions & 0 deletions src/torchmetrics/functional/regression/rse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Union

import torch
from torch import Tensor

from torchmetrics.functional.regression.r2 import _r2_score_update


def _relative_squared_error_compute(
sum_squared_obs: Tensor,
sum_obs: Tensor,
sum_squared_error: Tensor,
n_obs: Union[int, Tensor],
squared: bool = True,
epsilon: float = 1.17e-06,
Borda marked this conversation as resolved.
Show resolved Hide resolved
) -> Tensor:
"""Computes Relative Squared Error.

Args:
sum_squared_obs: Sum of square of all observations
sum_obs: Sum of all observations
sum_squared_error: Residual sum of squares
n_obs: Number of predictions or observations
squared: Returns RRSE value if set to False.
epsilon: Specifies the lower bound for target values. Any target value below epsilon
is set to epsilon (avoids ``ZeroDivisionError``).
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

Example:
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> # RSE uses the same update function as R2 score.
>>> sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target)
>>> _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=True)
tensor(0.0632)
"""
rse = sum_squared_error / torch.clamp((sum_squared_obs - sum_obs * sum_obs / n_obs), min=epsilon)
Borda marked this conversation as resolved.
Show resolved Hide resolved
if not squared:
rse = torch.sqrt(rse)
return torch.mean(rse)


def relative_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor:
r"""Computes the relative squared error (RSE).

.. math:: \text{RSE} = \frac{\\sum_i^N(y_i - \\hat{y_i})^2}{\\sum_i^N(y_i - \\overline{y})^2}
Borda marked this conversation as resolved.
Show resolved Hide resolved
Where :math:`y` is a tensor of target values with mean :math:`\\overline{y}`, and
:math:`\\hat{y}` is a tensor of predictions.

If `preds` and `targets` are 2D tensors, the RSE is averaged over the second dim.

Args:
preds: estimated labels
target: ground truth labels
squared: returns RRSE value if set to False
Return:
Tensor with RSE

Example:
>>> from torchmetrics.functional.regression import relative_squared_error
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error(preds, target)
tensor(0.0514)
"""
sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target)
return _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=squared)
2 changes: 2 additions & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torchmetrics.regression.mse import MeanSquaredError
from torchmetrics.regression.pearson import PearsonCorrCoef
from torchmetrics.regression.r2 import R2Score
from torchmetrics.regression.rse import RelativeSquaredError
from torchmetrics.regression.spearman import SpearmanCorrCoef
from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError
from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore
Expand All @@ -43,6 +44,7 @@
"MeanSquaredError",
"PearsonCorrCoef",
"R2Score",
"RelativeSquaredError",
"SpearmanCorrCoef",
"SymmetricMeanAbsolutePercentageError",
"TweedieDevianceScore",
Expand Down
70 changes: 70 additions & 0 deletions src/torchmetrics/regression/rse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Any

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_update
from torchmetrics.functional.regression.rse import _relative_squared_error_compute
from torchmetrics.metric import Metric


class RelativeSquaredError(Metric):
r"""Computes the relative squared error (RSE).

.. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}
Where :math:`y` is a tensor of target values with mean :math:`\overline{y}`, and
Borda marked this conversation as resolved.
Show resolved Hide resolved
:math:`\hat{y}` is a tensor of predictions.

If num_outputs > 1, the returned value is averaged over all the outputs.

Args:
num_outputs: Number of outputs in multioutput setting
squared: If True returns RSE value, if False returns RRSE value.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> from torchmetrics.regression import RelativeSquaredError
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error = RelativeSquaredError()
>>> relative_squared_error(preds, target)
tensor(0.0514)
"""
is_differentiable = True
higher_is_better = False
full_state_update = False
sum_squared_error: Tensor
sum_error: Tensor
residual: Tensor
total: Tensor

def __init__(
self,
num_outputs: int = 1,
squared: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.num_outputs = num_outputs

self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.squared = squared

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total

def compute(self) -> Tensor:
"""Computes relative squared error over state."""
return _relative_squared_error_compute(
self.sum_squared_error, self.sum_error, self.residual, self.total, squared=self.squared
)
135 changes: 135 additions & 0 deletions tests/unittests/regression/test_rse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
import torch

from torchmetrics.functional import relative_squared_error
from torchmetrics.regression import RelativeSquaredError
from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

num_targets = 5

Input = namedtuple("Input", ["preds", "target"])

_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)


def _sk_rse(target, preds, squared):
mean = np.mean(target, axis=0, keepdims=True)
error = target - preds
sum_squared_error = np.sum(error * error, axis=0)
deviation = target - mean
sum_squared_deviation = np.sum(deviation * deviation, axis=0)
rse = sum_squared_error / np.maximum(sum_squared_deviation, 1.17e-06)
if not squared:
rse = np.sqrt(rse)
return np.mean(rse)


def _single_target_ref_metric(preds, target, squared):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return _sk_rse(sk_target, sk_preds, squared=squared)


def _multi_target_ref_metric(preds, target, squared):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
return _sk_rse(sk_target, sk_preds, squared=squared)


@pytest.mark.parametrize("squared", [False, True])
@pytest.mark.parametrize(
"preds, target, ref_metric, num_outputs",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric, 1),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric, num_targets),
],
)
class TestRelativeSquaredError(MetricTester):
"""Test class for `RelativeSquaredError` metric."""

@pytest.mark.parametrize("ddp", [True, False])
def test_rse(self, squared, preds, target, ref_metric, num_outputs, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp,
preds,
target,
RelativeSquaredError,
partial(ref_metric, squared=squared),
metric_args={"squared": squared, "num_outputs": num_outputs},
)

def test_rse_functional(self, squared, preds, target, ref_metric, num_outputs):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds,
target,
relative_squared_error,
partial(ref_metric, squared=squared),
metric_args={"squared": squared},
)

def test_rse_differentiability(self, squared, preds, target, ref_metric, num_outputs):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=partial(RelativeSquaredError, num_outputs=num_outputs),
metric_functional=relative_squared_error,
metric_args={"squared": squared},
)

@pytest.mark.xfail(raises=RuntimeError, reason="clamp_min_cpu not implented for `Half`.")
def test_rse_half_cpu(self, squared, preds, target, ref_metric, num_outputs):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(
preds,
target,
partial(RelativeSquaredError, num_outputs=num_outputs),
relative_squared_error,
{"squared": squared},
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_rse_half_gpu(self, squared, preds, target, ref_metric, num_outputs):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds,
target,
partial(RelativeSquaredError, num_outputs=num_outputs),
relative_squared_error,
{"squared": squared},
)


def test_error_on_different_shape(metric_class=RelativeSquaredError):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_error_on_multidim_tensors(metric_class=RelativeSquaredError):
"""Test that error is raised if a larger than 2D tensor is given as input."""
metric = metric_class()
with pytest.raises(
ValueError,
match=r"Expected both prediction and target to be 1D or 2D tensors," r" but received tensors with dimension .",
):
metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))