diff --git a/CHANGELOG.md b/CHANGELOG.md index 87d7ba3318b..c1e048a09b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `SacreBLEUScore` metric to text package ([#546](https://github.com/PyTorchLightning/metrics/pull/546)) +- Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) @@ -42,9 +45,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) -- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551)) +- Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) +- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551)) + ### Deprecated diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 4a2b5c5b1b5..847ebef3ca9 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -14,10 +14,42 @@ metrics. .. autoclass:: torchmetrics.Metric :noindex: -We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating -your own metric type might be too burdensome. -.. autoclass:: torchmetrics.AverageMeter +************************* +Basic Aggregation Metrics +************************* + +Torchmetrics comes with a number of metrics for aggregation of basic statistics: mean, max, min etc. of +either tensors or native python floats. + +CatMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.CatMetric + :noindex: + +MaxMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.MaxMetric + :noindex: + +MeanMetric +~~~~~~~~~~ + +.. autoclass:: torchmetrics.MeanMetric + :noindex: + +MinMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.MinMetric + :noindex: + +SumMetric +~~~~~~~~~ + +.. 
autoclass:: torchmetrics.SumMetric :noindex: ************* diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py new file mode 100644 index 00000000000..106621e9cb4 --- /dev/null +++ b/tests/bases/test_aggregation.py @@ -0,0 +1,166 @@ +import numpy as np +import pytest +import torch + +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric + + +def compare_mean(values, weights): + """reference implementation for mean aggregation.""" + return np.average(values.numpy(), weights=weights) + + +def compare_sum(values, weights): + """reference implementation for sum aggregation.""" + return np.sum(values.numpy()) + + +def compare_min(values, weights): + """reference implementation for min aggregation.""" + return np.min(values.numpy()) + + +def compare_max(values, weights): + """reference implementation for max aggregation.""" + return np.max(values.numpy()) + + +# wrap all other than mean metric to take an additional argument +# this lets them fit into the testing framework +class WrappedMinMetric(MinMetric): + """Wrapped min metric.""" + + def update(self, values, weights): + """only pass values on.""" + super().update(values) + + +class WrappedMaxMetric(MaxMetric): + """Wrapped max metric.""" + + def update(self, values, weights): + """only pass values on.""" + super().update(values) + + +class WrappedSumMetric(SumMetric): + """Wrapped min metric.""" + + def update(self, values, weights): + """only pass values on.""" + super().update(values) + + +class WrappedCatMetric(CatMetric): + """Wrapped cat metric.""" + + def update(self, values, weights): + """only pass values on.""" + super().update(values) + + +@pytest.mark.parametrize( + "values, weights", + [ + (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)), + (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5), + (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5), + ], +) +@pytest.mark.parametrize( + "metric_class, compare_fn", + [ + (WrappedMinMetric, compare_min), + (WrappedMaxMetric, compare_max), + (WrappedSumMetric, compare_sum), + (MeanMetric, compare_mean), + ], +) +class TestAggregation(MetricTester): + """Test aggregation metrics.""" + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): + """test modular implementation.""" + self.run_class_metric_test( + ddp=ddp, + dist_sync_on_step=dist_sync_on_step, + metric_class=metric_class, + sk_metric=compare_fn, + check_scriptable=True, + # Abuse of names here + preds=values, + target=weights, + ) + + +_case1 = float("nan") * torch.ones(5) +_case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) + + +@pytest.mark.parametrize("value", [_case1, _case2]) +@pytest.mark.parametrize("nan_strategy", ["error", "warn"]) +@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) +def test_nan_error(value, nan_strategy, metric_class): + """test correct errors are raised.""" + metric = metric_class(nan_strategy=nan_strategy) + if nan_strategy == "error": + with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"): + metric(value.clone()) + elif nan_strategy == "warn": + with pytest.warns(UserWarning, match="Encounted `nan` values in tensor"): + metric(value.clone()) + + 
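# A minimal usage sketch of the aggregation metrics tested above, assuming only the
# public API introduced in this PR (update()/compute() and the `nan_strategy` argument);
# the variable names below are illustrative, not part of the diff.
import torch
from torchmetrics import MeanMetric, SumMetric

sum_metric = SumMetric()
sum_metric.update(1.0)
sum_metric.update(torch.tensor([2.0, 3.0]))
assert sum_metric.compute() == 6.0  # values are accumulated across update() calls

# nan_strategy="ignore" silently drops nan entries before aggregating
ignoring = SumMetric(nan_strategy="ignore")
ignoring.update(torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]))
assert ignoring.compute() == 12.0

# a float nan_strategy imputes nan entries with that value instead
imputing = MeanMetric(nan_strategy=0.0)
imputing.update(torch.tensor([1.0, float("nan"), 3.0]))
# mean over [1.0, 0.0, 3.0] -> tensor(1.3333)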
+@pytest.mark.parametrize( + "metric_class, nan_strategy, value, expected", + [ + (MinMetric, "ignore", _case1, torch.tensor(float("inf"))), + (MinMetric, 2.0, _case1, 2.0), + (MinMetric, "ignore", _case2, 1.0), + (MinMetric, 2.0, _case2, 1.0), + (MaxMetric, "ignore", _case1, -torch.tensor(float("inf"))), + (MaxMetric, 2.0, _case1, 2.0), + (MaxMetric, "ignore", _case2, 5.0), + (MaxMetric, 2.0, _case2, 5.0), + (SumMetric, "ignore", _case1, 0.0), + (SumMetric, 2.0, _case1, 10.0), + (SumMetric, "ignore", _case2, 12.0), + (SumMetric, 2.0, _case2, 14.0), + (MeanMetric, "ignore", _case1, torch.tensor([float("nan")])), + (MeanMetric, 2.0, _case1, 2.0), + (MeanMetric, "ignore", _case2, 3.0), + (MeanMetric, 2.0, _case2, 2.8), + (CatMetric, "ignore", _case1, []), + (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), + (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), + (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), + ], +) +def test_nan_expected(metric_class, nan_strategy, value, expected): + """test that nan values are handled correctly.""" + metric = metric_class(nan_strategy=nan_strategy) + metric.update(value.clone()) + out = metric.compute() + assert np.allclose(out, expected, equal_nan=True) + + +@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) +def test_error_on_wrong_nan_strategy(metric_class): + """test error raised on wrong nan_strategy argument.""" + with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"): + metric_class(nan_strategy=[]) + + +@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to") +@pytest.mark.parametrize( + "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] +) +def test_mean_metric_broadcasting(weights, expected): + """check that weight broadcasting works for mean metric.""" + values = torch.arange(24).reshape(2, 3, 4) + avg = MeanMetric() + + assert avg(values, weights) == expected diff --git a/tests/bases/test_average.py b/tests/bases/test_average.py deleted file mode 100644 index 9c84caf8ddc..00000000000 --- a/tests/bases/test_average.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import pytest -import torch - -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.average import AverageMeter - - -def average(values, weights): - return np.average(values, weights=weights) - - -def average_ignore_weights(values, weights): - return np.average(values) - - -class DefaultWeightWrapper(AverageMeter): - def update(self, values, weights): - super().update(values) - - -class ScalarWrapper(AverageMeter): - def update(self, values, weights): - # torch.ravel is PyTorch 1.8 only, so use np.ravel instead - values = values.cpu().numpy() - weights = weights.cpu().numpy() - for v, w in zip(np.ravel(values), np.ravel(weights)): - super().update(float(v), float(w)) - - -@pytest.mark.parametrize( - "values, weights", - [ - (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)), - (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5), - (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5), - ], -) -class TestAverageMeter(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn(self, ddp, dist_sync_on_step, values, weights): - 
self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=AverageMeter, - sk_metric=average, - # Abuse of names here - preds=values, - target=weights, - ) - - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn_default(self, ddp, dist_sync_on_step, values, weights): - self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=DefaultWeightWrapper, - sk_metric=average_ignore_weights, - # Abuse of names here - preds=values, - target=weights, - ) - - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn_scalar(self, ddp, dist_sync_on_step, values, weights): - self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=ScalarWrapper, - sk_metric=average, - # Abuse of names here - preds=values, - target=weights, - ) - - -@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to") -@pytest.mark.parametrize( - "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] -) -def test_AverageMeter_broadcasting(weights, expected): - values = torch.arange(24).reshape(2, 3, 4) - avg = AverageMeter() - - assert avg(values, weights) == expected diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index e9908822fc6..45d52ca4d1d 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -12,8 +12,8 @@ _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) from torchmetrics import functional # noqa: E402 +from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR # noqa: E402 -from torchmetrics.average import AverageMeter # noqa: E402 from torchmetrics.classification import ( # noqa: E402 AUC, AUROC, @@ -71,7 +71,6 @@ "Accuracy", "AUC", "AUROC", - "AverageMeter", "AveragePrecision", "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", @@ -80,6 +79,7 @@ "BLEUScore", "BootStrapper", "CalibrationError", + "CatMetric", "CohenKappa", "ConfusionMatrix", "CosineSimilarity", @@ -96,13 +96,16 @@ "KLDivergence", "LPIPS", "MatthewsCorrcoef", + "MaxMetric", "MeanAbsoluteError", "MeanAbsolutePercentageError", + "MeanMetric", "MeanSquaredError", "MeanSquaredLogError", "Metric", "MetricCollection", "MetricTracker", + "MinMetric", "MultioutputWrapper", "PearsonCorrcoef", "PESQ", @@ -128,6 +131,7 @@ "Specificity", "SSIM", "StatScores", + "SumMetric", "SymmetricMeanAbsolutePercentageError", "WER", ] diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py new file mode 100644 index 00000000000..e009abfec2b --- /dev/null +++ b/torchmetrics/aggregation.py @@ -0,0 +1,445 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import warnings +from typing import Any, Callable, List, Optional, Union + +import torch +from torch import Tensor + +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class BaseAggregator(Metric): + """Base class for aggregation metrics. + + Args: + fn: string specifying the reduction function + default_value: default tensor value to use for the metric state + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + """ + + value: Tensor + is_differentiable = None + higher_is_better = None + + def __init__( + self, + fn: Union[Callable, str], + default_value: Union[Tensor, List], + nan_strategy: Union[str, float] = "error", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + allowed_nan_strategy = ("error", "warn", "ignore") + if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float): + raise ValueError( + f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}" + f" but got {nan_strategy}." + ) + + self.nan_strategy = nan_strategy + self.add_state("value", default=default_value, dist_reduce_fx=fn) + + def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: + """Converts input x to a tensor if not already and afterwards checks for nans that either give an error, + warning or just ignored.""" + if not isinstance(x, Tensor): + x = torch.as_tensor(x, dtype=torch.float32, device=self.device) + + nans = torch.isnan(x) + if any(nans.flatten()): + if self.nan_strategy == "error": + raise RuntimeError("Encounted `nan` values in tensor") + if self.nan_strategy == "warn": + warnings.warn("Encounted `nan` values in tensor. Will be removed.", UserWarning) + x = x[~nans] + elif self.nan_strategy == "ignore": + x = x[~nans] + else: + x[nans] = self.nan_strategy + + return x.float() + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """Overwrite in child class.""" + pass + + def compute(self) -> Tensor: + """Compute the aggregated value.""" + return self.value.squeeze() if isinstance(self.value, Tensor) else self.value + + +class MaxMetric(BaseAggregator): + """Aggregate a stream of value into their maximum value. 
+ + Args: + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MaxMetric + >>> metric = MaxMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor(3.) + """ + + def __init__( + self, + nan_strategy: Union[str, float] = "warn", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + "max", + -torch.tensor(float("inf")), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn, + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): # make sure tensor not empty + self.value = torch.max(self.value, torch.max(value)) + + +class MinMetric(BaseAggregator): + """Aggregate a stream of value into their minimum value. + + Args: + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MinMetric + >>> metric = MinMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor(1.) 
+ """ + + def __init__( + self, + nan_strategy: Union[str, float] = "warn", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + "min", + torch.tensor(float("inf")), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn, + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): # make sure tensor not empty + self.value = torch.min(self.value, torch.min(value)) + + +class SumMetric(BaseAggregator): + """Aggregate a stream of value into their sum. + + Args: + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import SumMetric + >>> metric = SumMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor(6.) + """ + + def __init__( + self, + nan_strategy: Union[str, float] = "warn", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + "sum", torch.zeros(1), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + """ + value = self._cast_and_nan_check_input(value) + self.value += value.sum() + + +class CatMetric(BaseAggregator): + """Concatenate a stream of values. + + Args: + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. 
+ + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import CatMetric + >>> metric = CatMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor([1., 2., 3.]) + """ + + def __init__( + self, + nan_strategy: Union[str, float] = "warn", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__("cat", [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): + self.value.append(value) + + def compute(self) -> Tensor: + """Compute the aggregated value.""" + if isinstance(self.value, list) and self.value: + return dim_zero_cat(self.value) + return self.value + + +class MeanMetric(BaseAggregator): + """Aggregate a stream of value into their mean value. + + Args: + nan_strategy: options: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MeanMetric + >>> metric = MeanMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor([2.]) + """ + + def __init__( + self, + nan_strategy: Union[str, float] = "warn", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + "sum", torch.zeros(1), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn + ) + self.add_state("weight", default=torch.zeros(1), dist_reduce_fx="sum") + + def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: # type: ignore + """Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + weight: Either a float or tensor containing weights for calculating + the average. Shape of weight should be able to broadcast with + the shape of `value`. Default to `1.0` corresponding to simple + harmonic average. 
+ """ + value = self._cast_and_nan_check_input(value) + weight = self._cast_and_nan_check_input(weight) + + # broadcast weight to values shape + if not hasattr(torch, "broadcast_to"): + if weight.shape == (): + weight = torch.ones_like(value) * weight + if weight.shape != value.shape: + raise ValueError("Broadcasting not supported on PyTorch <1.8") + else: + weight = torch.broadcast_to(weight, value.shape) + + self.value += (value * weight).sum() + self.weight += weight.sum() + + def compute(self) -> Tensor: + """Compute the aggregated value.""" + return self.value / self.weight diff --git a/torchmetrics/average.py b/torchmetrics/average.py deleted file mode 100644 index b602d57bbca..00000000000 --- a/torchmetrics/average.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Optional, Union - -import torch -from torch import Tensor - -from torchmetrics.metric import Metric - - -class AverageMeter(Metric): - """Computes the average of a stream of values. - - Forward accepts - - ``value`` (float tensor): ``(...)`` - - ``weight`` (float tensor): ``(...)`` - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is - set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. - default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. - When `None`, DDP will be used to perform the allgather. - - Example:: - >>> from torchmetrics import AverageMeter - >>> avg = AverageMeter() - >>> avg.update(3) - >>> avg.update(1) - >>> avg.compute() - tensor(2.) - - >>> avg = AverageMeter() - >>> values = torch.tensor([1., 2., 3.]) - >>> avg(values) - tensor(2.) - - >>> avg = AverageMeter() - >>> values = torch.tensor([1., 2.]) - >>> weights = torch.tensor([3., 1.]) - >>> avg(values, weights) - tensor(1.2500) - """ - - value: Tensor - weight: Tensor - - def __init__( - self, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ) -> None: - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - self.add_state("value", torch.zeros(()), dist_reduce_fx="sum") - self.add_state("weight", torch.zeros(()), dist_reduce_fx="sum") - - # TODO: need to be strings because Unions are not pickleable in Python 3.6 - def update(self, value: "Union[Tensor, float]", weight: "Union[Tensor, float]" = 1.0) -> None: # type: ignore - """Updates the average with. 
- - Args: - value: A tensor of observations (can also be a scalar value) - weight: The weight of each observation (automatically broadcasted - to fit ``value``) - """ - if not isinstance(value, Tensor): - value = torch.as_tensor(value, dtype=torch.float32, device=self.value.device) - if not isinstance(weight, Tensor): - weight = torch.as_tensor(weight, dtype=torch.float32, device=self.weight.device) - - # braodcast_to only supported on PyTorch 1.8+ - if not hasattr(torch, "broadcast_to"): - if weight.shape == (): - weight = torch.ones_like(value) * weight - if weight.shape != value.shape: - raise ValueError("Broadcasting not supported on PyTorch <1.8") - else: - weight = torch.broadcast_to(weight, value.shape) - - self.value += (value * weight).sum() - self.weight += weight.sum() - - def compute(self) -> Tensor: - return self.value / self.weight diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index d23994d11fe..a24c359518f 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -25,7 +25,7 @@ from torch.nn import Module from torchmetrics.utilities import apply_to_collection, rank_zero_warn -from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum +from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_max, dim_zero_mean, dim_zero_min, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version @@ -125,10 +125,10 @@ def add_state( default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. + If value is ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"`` we will use ``torch.sum``, + ``torch.mean``, ``torch.cat``, ``torch.min`` and ``torch.max``` respectively, each with argument + ``dim=0``. Note that the ``"cat"`` reduction only makes sense if the state is a list, and not + a tensor. The user can also pass a custom function in this parameter. persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. Default is ``False``. 
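# A minimal sketch of a custom metric built on the new "max" reduction documented above,
# assuming the standard Metric subclass pattern (add_state/update/compute); the class name
# and logic here are hypothetical and not part of this PR.
import torch
from torchmetrics import Metric


class RunningMax(Metric):
    """Tracks the largest value seen so far; the state reduces with torch.max across processes."""

    def __init__(self):
        super().__init__()
        self.add_state("max_value", default=-torch.tensor(float("inf")), dist_reduce_fx="max")

    def update(self, value: torch.Tensor) -> None:
        # keep the running maximum over all values passed to update()
        self.max_value = torch.max(self.max_value, torch.max(value.float()))

    def compute(self) -> torch.Tensor:
        return self.max_value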
@@ -162,6 +162,10 @@ def add_state( dist_reduce_fx = dim_zero_sum elif dist_reduce_fx == "mean": dist_reduce_fx = dim_zero_mean + elif dist_reduce_fx == "max": + dist_reduce_fx = dim_zero_max + elif dist_reduce_fx == "min": + dist_reduce_fx = dim_zero_min elif dist_reduce_fx == "cat": dist_reduce_fx = dim_zero_cat elif dist_reduce_fx is not None and not callable(dist_reduce_fx): diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 46648352e8f..05caac19163 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -22,6 +22,7 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: + """concatenation along the zero dimension.""" x = x if isinstance(x, (list, tuple)) else [x] x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] if not x: # empty list @@ -30,13 +31,25 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: def dim_zero_sum(x: Tensor) -> Tensor: + """summation along the zero dimension.""" return torch.sum(x, dim=0) def dim_zero_mean(x: Tensor) -> Tensor: + """average along the zero dimension.""" return torch.mean(x, dim=0) +def dim_zero_max(x: Tensor) -> Tensor: + """max along the zero dimension.""" + return torch.max(x, dim=0).values + + +def dim_zero_min(x: Tensor) -> Tensor: + """min along the zero dimension.""" + return torch.min(x, dim=0).values + + def _flatten(x: Sequence) -> list: return [item for sublist in x for item in sublist]
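# A short sketch of MeanMetric's weighted update, mirroring test_mean_metric_broadcasting
# above; weights are broadcast to the shape of the values (via torch.broadcast_to on
# PyTorch 1.8+, with a fallback for scalar weights on older versions).
import torch
from torchmetrics import MeanMetric

values = torch.arange(24.0).reshape(2, 3, 4)         # unweighted mean of 0..23 is 11.5
weights = torch.tensor([1.0, 2.0]).reshape(2, 1, 1)  # weight the second batch twice as much

avg = MeanMetric()
avg.update(values, weights)   # equivalent to calling avg(values, weights)
assert avg.compute() == 13.5  # weighted average, matching the expected value in the test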