min max wrapper #556

Merged
merged 86 commits into master from test49/min-max-metric on Nov 17, 2021
Changes from 47 commits
Commits
86 commits
e43dbb5
scaffolding of PR
janhenriklambrechts Sep 30, 2021
d3d7ff9
more scaffolding
janhenriklambrechts Sep 30, 2021
8e2457e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
79917a1
Add linting skip comment
janhenriklambrechts Oct 7, 2021
f0da817
Merge branch 'master' into test49/min-max-metric
Borda Oct 13, 2021
8b8624f
changed name to minmax
janhenriklambrechts Oct 13, 2021
26da745
implemented wrapper design and modified tests
janhenriklambrechts Oct 13, 2021
3a07657
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2021
6febac7
removed useless parameter from test
janhenriklambrechts Oct 13, 2021
519e714
flake + typing fixes
janhenriklambrechts Oct 13, 2021
b661e13
resolve merge conflict
janhenriklambrechts Oct 13, 2021
e297c66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2021
8e38515
clean descriptions of minmax for docs
janhenriklambrechts Oct 13, 2021
fb87ae1
merge conflict
janhenriklambrechts Oct 13, 2021
649ba68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2021
0fa8389
Merge branch 'master' into test49/min-max-metric
Borda Oct 14, 2021
f53b1f4
added MinMaxMetric to __all__
janhenriklambrechts Oct 14, 2021
c04abe5
Merge branch 'master' into test49/min-max-metric
Borda Oct 14, 2021
341ddf8
removed redundant device flag in test
janhenriklambrechts Oct 15, 2021
cdb44b3
added test and assertion when compute is not a scalar:
janhenriklambrechts Oct 15, 2021
a55c83e
introduced infinity as bounds
janhenriklambrechts Oct 15, 2021
1cd2241
Merge branch 'master' into test49/min-max-metric
janhenriklambrechts Oct 15, 2021
dd4c9ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2021
e7c7a1b
Merge branch 'master' into test49/min-max-metric
Borda Oct 20, 2021
1dc9a7c
Apply suggestions from code review
Borda Oct 20, 2021
7078a16
update typing in helper function
janhenriklambrechts Oct 20, 2021
306f690
summarize helper function
janhenriklambrechts Oct 20, 2021
6890b36
added example of minmaxmetric and removed debugging print statements
janhenriklambrechts Oct 21, 2021
ab76517
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
80456cc
changelog
SkafteNicki Oct 22, 2021
bde7ae5
docs
SkafteNicki Oct 22, 2021
a899406
Update torchmetrics/wrappers/minmax.py
janhenriklambrechts Oct 22, 2021
0887a70
Update torchmetrics/wrappers/minmax.py
janhenriklambrechts Oct 22, 2021
b9e34aa
Update torchmetrics/wrappers/minmax.py
janhenriklambrechts Oct 22, 2021
a055286
Merge branch 'master' into test49/min-max-metric
SkafteNicki Oct 25, 2021
4d8c296
remove personalizable metric values, implemented nicki comments
janhenriklambrechts Oct 25, 2021
2643fb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
8afd4b5
Update torchmetrics/wrappers/minmax.py
janhenriklambrechts Oct 25, 2021
321e87c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
6869815
fix implementation
SkafteNicki Oct 25, 2021
73700bd
improve tests
SkafteNicki Oct 25, 2021
76c8273
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
4d09a22
fix mypy
SkafteNicki Oct 25, 2021
f9574c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
15f31e2
prepare 0.6 RC
Borda Oct 25, 2021
6165910
Merge branch 'master' into test49/min-max-metric
SkafteNicki Oct 25, 2021
1a2191b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
ba83223
fix doctest
SkafteNicki Oct 25, 2021
667629d
Merge branch 'master' into test49/min-max-metric
SkafteNicki Oct 25, 2021
c3016b1
added pprint
janhenriklambrechts Oct 25, 2021
eb3e478
moved base test to parametrize
janhenriklambrechts Oct 25, 2021
3067d36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
9f755ae
Merge branch 'master' into test49/min-max-metric
Borda Oct 25, 2021
8df7968
fix doctest with pprint
janhenriklambrechts Oct 26, 2021
c380d23
prune
Borda Oct 27, 2021
03ca76b
docs
Borda Oct 27, 2021
c61ba5d
fixing
Borda Oct 27, 2021
ed962df
.
Borda Oct 27, 2021
1201a3a
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 27, 2021
6e1b862
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 27, 2021
e522f3b
Merge branch 'master' into test49/min-max-metric
Borda Oct 28, 2021
d34692a
Merge branch 'master' into test49/min-max-metric
Borda Oct 28, 2021
b8760fd
release v0.6.0
Borda Oct 28, 2021
c0636ec
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 28, 2021
f1eea2b
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 28, 2021
d946cd2
release v0.6.0
Borda Oct 28, 2021
d3a7891
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 28, 2021
a81f00d
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 28, 2021
4ed2bce
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 29, 2021
749724b
Merge branch 'master' into test49/min-max-metric
mergify[bot] Oct 29, 2021
6940092
update setup
Borda Nov 1, 2021
bab618f
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 1, 2021
352afe6
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 1, 2021
e68aaac
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 1, 2021
3aaaba8
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 2, 2021
9e4306a
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 8, 2021
f5e09bd
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 8, 2021
606404a
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 13, 2021
5ab8e09
Merge branch 'master' into test49/min-max-metric
mergify[bot] Nov 13, 2021
d6749ea
Merge branch 'master' into test49/min-max-metric
SkafteNicki Nov 15, 2021
9f54e3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2021
1c7bd07
update
SkafteNicki Nov 15, 2021
0052bac
fix tests
SkafteNicki Nov 15, 2021
b4e8b2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2021
60d164a
Update tests/wrappers/test_minmax.py
SkafteNicki Nov 15, 2021
297943b
update
SkafteNicki Nov 15, 2021
65 changes: 15 additions & 50 deletions CHANGELOG.md
@@ -6,67 +6,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**

## [unReleased] - 2021-MM-DD
## [0.6.0] - 2021-10-DD

### Added

- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))


- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))


- Added audio metrics:
- Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
- Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
- Added Information retrieval metrics:
- `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))
- `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576))
- Added NLP metrics:
- `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546))
- `CharErrorRate` ([#575](https://github.com/PyTorchLightning/metrics/pull/575))
- Added other metrics:
- Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))
- Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))
- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))


- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


- Added Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


- Added `average` argument to `AveragePrecision` metric for reducing multi-label and multi-class problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
- Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510))


- Added metric sweeping `higher_is_better` as constant attribute ([#544](https://github.com/PyTorchLightning/metrics/pull/544))


- Added `SacreBLEUScore` metric to text package ([#546](https://github.com/PyTorchLightning/metrics/pull/546))


- Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))


- Added pairwise submodule with metrics ([#553](https://github.com/PyTorchLightning/metrics/pull/553))
- `pairwise_cosine_similarity`
- `pairwise_euclidean_distance`
- `pairwise_linear_similarity`
- `pairwise_manhatten_distance`


- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))


- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576))


- Added `CharErrorRate` metric to text package ([#575](https://github.com/PyTorchLightning/metrics/pull/575))

- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))

### Changed

- `AveragePrecision` will now by default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


- Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))


- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551))

### Deprecated
@@ -77,18 +49,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed `dtype` property ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


### Fixed

- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))


- Fixed bug in `pit` by using the returned first result to initialize device and type ([#533](https://github.com/PyTorchLightning/metrics/pull/533))


- Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539))


- Fixed bug where `device` property was not properly update when metric was a child of a module ([#542](https://github.com/PyTorchLightning/metrics/pull/542))

## [0.5.1] - 2021-08-30
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -643,6 +643,12 @@ MetricTracker
.. autoclass:: torchmetrics.MetricTracker
    :noindex:

MinMaxMetric
~~~~~~~~~~~~

.. autoclass:: torchmetrics.MinMaxMetric
    :noindex:

MultioutputWrapper
~~~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion tests/classification/test_average_precision.py
@@ -24,7 +24,7 @@
from tests.classification.inputs import _input_multilabel
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.average_precision import AveragePrecision
from torchmetrics.classification.avg_precision import AveragePrecision
from torchmetrics.functional import average_precision

seed_all(42)
107 changes: 107 additions & 0 deletions tests/wrappers/test_minmax.py
@@ -0,0 +1,107 @@
from functools import partial

import pytest
import torch

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError
from torchmetrics.wrappers import MinMaxMetric

seed_all(42)


class TestingMinMaxMetric(MinMaxMetric):
    """Wrap metric to fit the testing framework."""

    def compute(self):
        """Instead of returning a dict, return as a list."""
        output_dict = super().compute()
        return [output_dict["raw"], output_dict["min"], output_dict["max"]]

    def forward(self, *args, **kwargs):
        self.update(*args, **kwargs)
        return self.compute()


def compare_fn(preds, target, base_fn):
    """Comparison function for the minmax wrapper."""
    min_val, max_val = 1e6, -1e6  # extreme starting bounds, so the first computed value replaces both
    for i in range(NUM_BATCHES):
        val = base_fn(preds[: (i + 1) * BATCH_SIZE], target[: (i + 1) * BATCH_SIZE]).cpu().numpy()
        min_val = min_val if min_val < val else val
        max_val = max_val if max_val > val else val
    raw = base_fn(preds, target)
    return [raw.cpu().numpy(), min_val, max_val]


@pytest.mark.parametrize(
    "preds, target, base_metric",
    [
        (
            torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(dim=-1),
            torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)),
            Accuracy(num_classes=NUM_CLASSES),
        ),
        (torch.randn(NUM_BATCHES, BATCH_SIZE), torch.randn(NUM_BATCHES, BATCH_SIZE), MeanSquaredError()),
    ],
)
class TestMinMaxWrapper(MetricTester):
    """Test that the MinMaxMetric wrapper works as expected."""

    @pytest.mark.parametrize("ddp", [True, False])
    def test_minmax_wrapper(self, preds, target, base_metric, ddp):
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            TestingMinMaxMetric,
            partial(compare_fn, base_fn=base_metric),
            dist_sync_on_step=False,
            metric_args=dict(base_metric=base_metric),
            check_batch=False,
            check_scriptable=False,
        )


def test_basic_example() -> None:
    """Test that both min and max versions of MinMaxMetric operate correctly after calling compute."""
    acc = Accuracy()
    min_max_acc = MinMaxMetric(acc)

    preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
    preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
    preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]])
    labels = torch.Tensor([[0, 1], [0, 1]]).long()

    min_max_acc(preds_1, labels)
    acc = min_max_acc.compute()
    assert acc["raw"] == 0.5
    assert acc["max"] == 0.5
    assert acc["min"] == 0.5

    min_max_acc(preds_2, labels)
    acc = min_max_acc.compute()
    assert acc["raw"] == 1.0
    assert acc["max"] == 1.0
    assert acc["min"] == 0.5

    min_max_acc(preds_3, labels)
    acc = min_max_acc.compute()
    assert acc["raw"] == 0.5
    assert acc["max"] == 1.0
    assert acc["min"] == 0.5


def test_no_base_metric() -> None:
    """Test that ValueError is raised when no base_metric is passed."""
    with pytest.raises(ValueError, match=r"Expected base metric to be an instance .*"):
        MinMaxMetric([])


def test_no_scalar_compute() -> None:
    """Test that a RuntimeError is raised if the wrapped base metric returns a non-scalar on compute."""
    min_max_nsm = MinMaxMetric(ConfusionMatrix(num_classes=2))

    with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"):
        min_max_nsm.compute()
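The tests above exercise the wrapper through the MetricTester harness. For a reviewer who wants to poke at it directly, here is a minimal usage sketch; the epoch loop, tensor shapes, and seed are illustrative assumptions, not part of this diff:

import torch
from torchmetrics import Accuracy
from torchmetrics.wrappers import MinMaxMetric

torch.manual_seed(42)
metric = MinMaxMetric(Accuracy())

for epoch in range(3):
    preds = torch.rand(8, 2).softmax(dim=-1)  # stand-in per-epoch class probabilities
    target = torch.randint(2, (8,))  # stand-in integer labels
    metric.update(preds, target)  # forwards to the wrapped Accuracy
    out = metric.compute()  # min/max bounds are refreshed only on compute()
    print(f"epoch {epoch}: raw={out['raw']:.4f} min={out['min']:.4f} max={out['max']:.4f}")

Note that the wrapped metric keeps accumulating across the loop, and MinMaxMetric.reset() clears the min/max bounds as well as the base metric, so tracking the min/max of per-epoch values would require resetting only the underlying metric.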
2 changes: 1 addition & 1 deletion torchmetrics/__about__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.0dev"
__version__ = "0.6.0rc0"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
@@ -66,7 +66,7 @@
    RetrievalRPrecision,
)
from torchmetrics.text import WER, BERTScore, BLEUScore, CharErrorRate, ROUGEScore, SacreBLEUScore # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402

__all__ = [
"functional",
Expand Down Expand Up @@ -107,6 +107,7 @@
"Metric",
"MetricCollection",
"MetricTracker",
"MinMaxMetric",
"MinMetric",
"MultioutputWrapper",
"PearsonCorrcoef",
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
@@ -14,7 +14,7 @@
from torchmetrics.classification.accuracy import Accuracy # noqa: F401
from torchmetrics.classification.auc import AUC # noqa: F401
from torchmetrics.classification.auroc import AUROC # noqa: F401
from torchmetrics.classification.average_precision import AveragePrecision # noqa: F401
from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/wrappers/__init__.py
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
from torchmetrics.wrappers.minmax import MinMaxMetric # noqa: F401
from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401
from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401
119 changes: 119 additions & 0 deletions torchmetrics/wrappers/minmax.py
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import Tensor

from torchmetrics.metric import Metric


class MinMaxMetric(Metric):
    """Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The
    min/max value will be updated each time ``.compute`` is called.

    Args:
        base_metric:
            The metric of which you want to keep track of its maximum and minimum values.
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called.
            default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather.

    Raises:
        ValueError:
            If ``base_metric`` argument is not an instance of ``torchmetrics.Metric``

    Example::
        >>> import torch
        >>> from torchmetrics import Accuracy, MinMaxMetric
        >>> base_metric = Accuracy()
        >>> minmax_metric = MinMaxMetric(base_metric)
        >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
        >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
        >>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
        >>> _ = minmax_metric(preds_1, labels)  # accuracy is 0.5
        >>> output = minmax_metric.compute()
        >>> print(output)
        {'raw': tensor(0.5000), 'max': tensor(0.5000), 'min': tensor(0.5000)}
        >>> _ = minmax_metric(preds_2, labels)  # accuracy is 1.0
        >>> output = minmax_metric.compute()
        >>> print(output)
        {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)}
    """

    min_val: Tensor
    max_val: Tensor

    def __init__(
        self,
        base_metric: Metric,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        if not isinstance(base_metric, Metric):
            raise ValueError(
                f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}"
            )
        self._base_metric = base_metric
        self.add_state("min_val", default=torch.tensor(float("inf")), dist_reduce_fx="min")
        self.add_state("max_val", default=torch.tensor(float("-inf")), dist_reduce_fx="max")

    def update(self, *args: Any, **kwargs: Any) -> None:  # type: ignore
        """Updates the underlying metric."""
        self._base_metric.update(*args, **kwargs)

    def compute(self) -> Dict[str, Tensor]:  # type: ignore
        """Computes the underlying metric as well as max and min values for this metric.

        Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``)
        and maximum (``max``) values.
        """
        val = self._base_metric.compute()
        if not self._is_suitable_val(val):
            raise RuntimeError(
                f"Returned value from base metric should be a scalar (int, float or tensor of size 1), but got {val}"
            )
        self.max_val = val if self.max_val < val else self.max_val
        self.min_val = val if self.min_val > val else self.min_val
        return {"raw": val, "max": self.max_val, "min": self.min_val}

    def reset(self) -> None:
        """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric."""
        super().reset()
        self._base_metric.reset()

    @staticmethod
    def _is_suitable_val(val: Union[int, float, Tensor]) -> bool:
        """Utility function that checks whether the computed value is a scalar."""
        if isinstance(val, (int, float)):
            return True
        if isinstance(val, Tensor):
            return val.numel() == 1
        return False
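As a final sanity check of the state handling defined above, a small sketch; the expected outputs in the comments follow from this diff's implementation, assuming the branch is installed:

import torch
from torchmetrics import Accuracy
from torchmetrics.wrappers import MinMaxMetric

m = MinMaxMetric(Accuracy())
m.update(torch.tensor([[0.9, 0.1]]), torch.tensor([0]))  # one correct prediction
print(m.compute())  # {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(1.)}

m.reset()  # restores min_val/max_val to +/-inf and resets the wrapped Accuracy
print(m.min_val, m.max_val)  # tensor(inf) tensor(-inf)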