From e43dbb5e80a2229da3c1e4cbca70bd82d813c679 Mon Sep 17 00:00:00 2001
From: Jan-Henrik Lambrechts
Date: Fri, 1 Oct 2021 00:33:17 +0100
Subject: [PATCH 01/56] scaffolding of PR

---
 tests/wrappers/test_minmax.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)
 create mode 100644 tests/wrappers/test_minmax.py

diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py
new file mode 100644
index 00000000000..6a9333b1ee5
--- /dev/null
+++ b/tests/wrappers/test_minmax.py
@@ -0,0 +1,26 @@
+import torch
+import pytest
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("type", ["min", "max"])
+def test_minmax(device, type):
+    "test that both min and max versions of MinMaxMetric operate correctly after calling compute"
+    m = MinMaxMetric()
+    acc = Accuracy()
+
+    preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
+    preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
+    preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]])
+    labels = torch.Tensor([[0, 1], [0, 1]]).long()
+
+    acc(preds_1, labels) # acc is 0.5
+    m(acc.compute()) # max_metrix is 0.5
+    assert m.compute() == 0.5
+
+    acc(preds_2, labels) # acc is 1.
+    m(acc.compute()) # max_metrix is 1.
+    assert m.compute() == 1.
+
+    acc(preds_3, labels) # acc is 0.5
+    m(acc.compute()) # max_metrix is 1.
+    assert m.compute() == 1.
\ No newline at end of file

From d3d7ff99df148b8ce9cf620dce12f950c5596b8c Mon Sep 17 00:00:00 2001
From: Jan-Henrik Lambrechts
Date: Fri, 1 Oct 2021 00:33:37 +0100
Subject: [PATCH 02/56] more scaffolding

---
 torchmetrics/__init__.py          |  2 +-
 torchmetrics/wrappers/__init__.py |  1 +
 torchmetrics/wrappers/minmax.py   | 42 +++++++++++++++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)
 create mode 100644 torchmetrics/wrappers/minmax.py

diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 4a95a0f6db8..e2bacd2d5b8 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -64,7 +64,7 @@
     RetrievalRecall,
 )
 from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore, SacreBLEUScore  # noqa: E402
-from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper  # noqa: E402
+from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper, MinMaxMetric  # noqa: E402

 __all__ = [
     "functional",
diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py
index 5bca8460c89..4242cc3785d 100644
--- a/torchmetrics/wrappers/__init__.py
+++ b/torchmetrics/wrappers/__init__.py
@@ -14,3 +14,4 @@
 from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
 from torchmetrics.wrappers.multioutput import MultioutputWrapper  # noqa: F401
 from torchmetrics.wrappers.tracker import MetricTracker  # noqa: F401
+from torchmetrics.wrappers.minmax import MinMaxMetric
diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py
new file mode 100644
index 00000000000..b1538901642
--- /dev/null
+++ b/torchmetrics/wrappers/minmax.py
@@ -0,0 +1,42 @@
+from torchmetrics.metric import Metric
+from pytorch_lightning.utilities.distributed import gather_all_tensors
+
+
+class MaxMetric(Metric):
+    """
+    Pytorch-Lightning Metric that tracks the maximum value of a scalar/tensor across an experiment
+    """
+    def __init__(self, dist_sync_on_step=False):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.add_state("max_val", default=torch.tensor(0))
+
+    def _wrap_compute(self, compute):
+        def wrapped_func(*args, **kwargs):
+            # return cached value
+            if self._computed is not None:
return self._computed + + dist_sync_fn = self.dist_sync_fn + if ( + dist_sync_fn is None + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + # User provided a bool, so we assume DDP if available + dist_sync_fn = gather_all_tensors + + if self._to_sync and dist_sync_fn is not None: + self._sync_dist(dist_sync_fn) + + self._computed = compute(*args, **kwargs) + # removed the auto-reset + + return self._computed + + return wrapped_func + + def update(self, val): + self.max_val = val if self.max_val < val else self.max_val + + def compute(self): + return self.max_val \ No newline at end of file From 8e2457e35bd26e9a3f46dddbe428fd653fbb847f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Sep 2021 23:35:06 +0000 Subject: [PATCH 03/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 21 +++++++++++---------- torchmetrics/__init__.py | 2 +- torchmetrics/wrappers/__init__.py | 2 +- torchmetrics/wrappers/minmax.py | 16 ++++++---------- 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 6a9333b1ee5..978e5f4ef66 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,10 +1,11 @@ -import torch import pytest +import torch + @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("type", ["min", "max"]) def test_minmax(device, type): - "test that both min and max versions of MinMaxMetric operate correctly after calling compute" + """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" m = MinMaxMetric() acc = Accuracy() @@ -13,14 +14,14 @@ def test_minmax(device, type): preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]]) labels = torch.Tensor([[0, 1], [0, 1]]).long() - acc(preds_1, labels) # acc is 0.5 - m(acc.compute()) # max_metrix is 0.5 + acc(preds_1, labels) # acc is 0.5 + m(acc.compute()) # max_metrix is 0.5 assert m.compute() == 0.5 - acc(preds_2, labels) # acc is 1. - m(acc.compute()) # max_metrix is 1. - assert m.compute() == 1. + acc(preds_2, labels) # acc is 1. + m(acc.compute()) # max_metrix is 1. + assert m.compute() == 1.0 - acc(preds_3, labels) # acc is 0.5 - m(acc.compute()) # max_metrix is 1. - assert m.compute() == 1. \ No newline at end of file + acc(preds_3, labels) # acc is 0.5 + m(acc.compute()) # max_metrix is 1. + assert m.compute() == 1.0 diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index e2bacd2d5b8..60de14e1029 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -64,7 +64,7 @@ RetrievalRecall, ) from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore, SacreBLEUScore # noqa: E402 -from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper, MinMaxMetric # noqa: E402 +from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402 __all__ = [ "functional", diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 4242cc3785d..a0c4f68955d 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 +from torchmetrics.wrappers.minmax import MinMaxMetric from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 -from torchmetrics.wrappers.minmax import MinMaxMetric diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index b1538901642..e3357680fa2 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -1,11 +1,11 @@ -from torchmetrics.metric import Metric from pytorch_lightning.utilities.distributed import gather_all_tensors +from torchmetrics.metric import Metric + class MaxMetric(Metric): - """ - Pytorch-Lightning Metric that tracks the maximum value of a scalar/tensor across an experiment - """ + """Pytorch-Lightning Metric that tracks the maximum value of a scalar/tensor across an experiment.""" + def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state("max_val", default=torch.tensor(0)) @@ -17,11 +17,7 @@ def wrapped_func(*args, **kwargs): return self._computed dist_sync_fn = self.dist_sync_fn - if ( - dist_sync_fn is None - and torch.distributed.is_available() - and torch.distributed.is_initialized() - ): + if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors @@ -39,4 +35,4 @@ def update(self, val): self.max_val = val if self.max_val < val else self.max_val def compute(self): - return self.max_val \ No newline at end of file + return self.max_val From 79917a1e39374c6910921153e64aba7e9ac8e885 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Thu, 7 Oct 2021 09:47:36 +0100 Subject: [PATCH 04/56] Add linting skip comment Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/wrappers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index a0c4f68955d..1f9dab6da45 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 -from torchmetrics.wrappers.minmax import MinMaxMetric +from torchmetrics.wrappers.minmax import MinMaxMetric # noqa: F401 from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 From 8b8624ffd8003a32dc69a785d532f53b64bb0d1e Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Wed, 13 Oct 2021 10:46:38 +0100 Subject: [PATCH 05/56] changed name to minmax --- torchmetrics/wrappers/minmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index e3357680fa2..68de5a832b5 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -3,7 +3,7 @@ from torchmetrics.metric import Metric -class MaxMetric(Metric): +class MinMaxMetric(Metric): """Pytorch-Lightning Metric that tracks the maximum value of a scalar/tensor across an experiment.""" def __init__(self, dist_sync_on_step=False): From 26da7455fd4ae84446397635734ee55f25d88b64 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Wed, 13 Oct 2021 11:20:24 +0100 Subject: [PATCH 06/56] implemented wrapper design and modified tests --- tests/wrappers/test_minmax.py | 28 +++++++++----- torchmetrics/wrappers/minmax.py | 66 +++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 978e5f4ef66..61ed39e69ed 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,27 +1,35 @@ import pytest import torch +from torchmetrics.wrappers import MinMaxMetric +from torchmetrics.classification import Accuracy @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("type", ["min", "max"]) def test_minmax(device, type): """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" - m = MinMaxMetric() acc = Accuracy() + min_max_acc = MinMaxMetric(acc) preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]]) labels = torch.Tensor([[0, 1], [0, 1]]).long() - acc(preds_1, labels) # acc is 0.5 - m(acc.compute()) # max_metrix is 0.5 - assert m.compute() == 0.5 + min_max_acc(preds_1, labels) + acc = min_max_acc.compute() + assert acc["raw"] == 0.5 + assert acc["max"] == 0.5 + assert acc["min"] == 0.5 - acc(preds_2, labels) # acc is 1. - m(acc.compute()) # max_metrix is 1. - assert m.compute() == 1.0 + min_max_acc(preds_2, labels) + acc = min_max_acc.compute() + assert acc["raw"] == 1.0 + assert acc["max"] == 1.0 + assert acc["min"] == 0.5 - acc(preds_3, labels) # acc is 0.5 - m(acc.compute()) # max_metrix is 1. - assert m.compute() == 1.0 + min_max_acc(preds_3, labels) + acc = min_max_acc.compute() + assert acc["raw"] == 0.5 + assert acc["max"] == 1.0 + assert acc["min"] == 0.5 diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 68de5a832b5..40365383e25 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -1,38 +1,48 @@ -from pytorch_lightning.utilities.distributed import gather_all_tensors - + # Copyright The PyTorch Lightning team. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. 
+ # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +import torch from torchmetrics.metric import Metric - +from typing import Any class MinMaxMetric(Metric): - """Pytorch-Lightning Metric that tracks the maximum value of a scalar/tensor across an experiment.""" + """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.""" - def __init__(self, dist_sync_on_step=False): + def __init__(self, base_metric: Metric, dist_sync_on_step:bool=False, min_bound_init:float=1., max_bound_init:float=0.): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.add_state("max_val", default=torch.tensor(0)) - - def _wrap_compute(self, compute): - def wrapped_func(*args, **kwargs): - # return cached value - if self._computed is not None: - return self._computed - - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - if self._to_sync and dist_sync_fn is not None: - self._sync_dist(dist_sync_fn) + self._base_metric = base_metric + self.add_state("min_val", default=torch.tensor(min_bound_init)) + self.add_state("max_val", default=torch.tensor(max_bound_init)) + self.min_bound_init = min_bound_init + self.max_bound_init = max_bound_init - self._computed = compute(*args, **kwargs) - # removed the auto-reset + def update(self, *args: Any, **kwargs: Any): + "Update underlying metric" + self._base_metric.update(*args, **kwargs) + - return self._computed + def compute(self): + "Compute underlying metric as well as max and min values." + val = self._base_metric.compute() + self.max_val = val if self.max_val < val else self.max_val + self.min_val = val if self.min_val > val else self.min_val + return {"raw" : val, "max" : self.max_val, "min" : self.min_val} - return wrapped_func + def reset(self): + "Sets max_val and min_val to 0. and resets the base metric." 
+ self.max_val = self.max_bound_init + self.min_val = self.min_bound_init + self._base_metric.reset() - def update(self, val): - self.max_val = val if self.max_val < val else self.max_val - def compute(self): - return self.max_val From 3a076578bc908bf04aabb97b2f735e8ee229fe99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Oct 2021 10:20:55 +0000 Subject: [PATCH 07/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 7 +++-- torchmetrics/wrappers/minmax.py | 53 +++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 61ed39e69ed..4903eb65476 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,7 +1,8 @@ import pytest import torch -from torchmetrics.wrappers import MinMaxMetric + from torchmetrics.classification import Accuracy +from torchmetrics.wrappers import MinMaxMetric @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -16,8 +17,8 @@ def test_minmax(device, type): preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]]) labels = torch.Tensor([[0, 1], [0, 1]]).long() - min_max_acc(preds_1, labels) - acc = min_max_acc.compute() + min_max_acc(preds_1, labels) + acc = min_max_acc.compute() assert acc["raw"] == 0.5 assert acc["max"] == 0.5 assert acc["min"] == 0.5 diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 40365383e25..786ea6be437 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -1,25 +1,34 @@ - # Copyright The PyTorch Lightning team. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any import torch + from torchmetrics.metric import Metric -from typing import Any + class MinMaxMetric(Metric): """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.""" - def __init__(self, base_metric: Metric, dist_sync_on_step:bool=False, min_bound_init:float=1., max_bound_init:float=0.): + def __init__( + self, + base_metric: Metric, + dist_sync_on_step: bool = False, + min_bound_init: float = 1.0, + max_bound_init: float = 0.0, + ): super().__init__(dist_sync_on_step=dist_sync_on_step) self._base_metric = base_metric self.add_state("min_val", default=torch.tensor(min_bound_init)) @@ -28,21 +37,21 @@ def __init__(self, base_metric: Metric, dist_sync_on_step:bool=False, min_bound_ self.max_bound_init = max_bound_init def update(self, *args: Any, **kwargs: Any): - "Update underlying metric" + """Update underlying metric.""" self._base_metric.update(*args, **kwargs) - def compute(self): - "Compute underlying metric as well as max and min values." + """Compute underlying metric as well as max and min values.""" val = self._base_metric.compute() self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val - return {"raw" : val, "max" : self.max_val, "min" : self.min_val} + return {"raw": val, "max": self.max_val, "min": self.min_val} def reset(self): - "Sets max_val and min_val to 0. and resets the base metric." + """Sets max_val and min_val to 0. + + and resets the base metric. + """ self.max_val = self.max_bound_init self.min_val = self.min_bound_init self._base_metric.reset() - - From 6febac740a60f9d206cba83f70c7e076303ed79d Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Wed, 13 Oct 2021 11:21:35 +0100 Subject: [PATCH 08/56] removed useless parameter from test --- tests/wrappers/test_minmax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 61ed39e69ed..c5ef9983a41 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -5,8 +5,7 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("type", ["min", "max"]) -def test_minmax(device, type): +def test_minmax(device): """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) From 519e714fa8a879f03b9fd09dcade17865205356d Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Wed, 13 Oct 2021 11:30:14 +0100 Subject: [PATCH 09/56] flake + typing fixes --- tests/wrappers/test_minmax.py | 2 +- torchmetrics/wrappers/minmax.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index c5ef9983a41..278ee5ef7cd 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_minmax(device): +def test_minmax(device: str) -> None: """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 40365383e25..e57d9ff2912 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -13,8 +13,9 @@ # limitations under the License. 
import torch +from torch import Tensor from torchmetrics.metric import Metric -from typing import Any +from typing import Any, Dict class MinMaxMetric(Metric): """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.""" @@ -27,19 +28,19 @@ def __init__(self, base_metric: Metric, dist_sync_on_step:bool=False, min_bound_ self.min_bound_init = min_bound_init self.max_bound_init = max_bound_init - def update(self, *args: Any, **kwargs: Any): + def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore "Update underlying metric" self._base_metric.update(*args, **kwargs) - def compute(self): + def compute(self) -> Dict[str, Tensor]: # type: ignore "Compute underlying metric as well as max and min values." val = self._base_metric.compute() self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val return {"raw" : val, "max" : self.max_val, "min" : self.min_val} - def reset(self): + def reset(self) -> None: "Sets max_val and min_val to 0. and resets the base metric." self.max_val = self.max_bound_init self.min_val = self.min_bound_init From e297c66b0a13c849a02f60c5ae188d25a969ace8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Oct 2021 10:32:56 +0000 Subject: [PATCH 10/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/minmax.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 826de0b88d9..b2b68b3f4a8 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any +from typing import Any, Dict import torch from torch import Tensor + from torchmetrics.metric import Metric -from typing import Any, Dict + class MinMaxMetric(Metric): """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.""" @@ -36,11 +37,11 @@ def __init__( self.min_bound_init = min_bound_init self.max_bound_init = max_bound_init - def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore - """Update underlying metric""" + def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore + """Update underlying metric.""" self._base_metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: # type: ignore + def compute(self) -> Dict[str, Tensor]: # type: ignore """Compute underlying metric as well as max and min values.""" val = self._base_metric.compute() self.max_val = val if self.max_val < val else self.max_val From 8e38515c340a63324cbfbcb28419e493b6ffa661 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Wed, 13 Oct 2021 12:02:52 +0100 Subject: [PATCH 11/56] clean descriptions of minmax for docs --- torchmetrics/wrappers/minmax.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 826de0b88d9..2feaa476f87 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -20,7 +20,25 @@ from typing import Any, Dict class MinMaxMetric(Metric): - """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.""" + """Wrapper Metric that tracks both the minimum and maximum of a + scalar/tensor across an experiment. + + Note: + Make sure you pass proper initialization values to the ``min_bound_init`` and ``max_bound_init`` parameters. + For the ``Accuracy`` metric, the defaults of ``0.0`` and ``1.0`` make sense, + however, for other metrics you will likely want to use different initialization values. + + Args: + base_metric: + The metric of which you want to keep track of its maximum and minimum values. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + min_bound_init: + Initialization value of the ``min`` parameter. default: 0.0 + max_bound_init: + Initialization value of the ``max`` parameter. default: 1.0 + """ def __init__( self, @@ -37,18 +55,22 @@ def __init__( self.max_bound_init = max_bound_init def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore - """Update underlying metric""" + """Updates the underlying metric""" self._base_metric.update(*args, **kwargs) def compute(self) -> Dict[str, Tensor]: # type: ignore - """Compute underlying metric as well as max and min values.""" + """Computes the underlying metric as well as max and min values for this metric. + + Returns a dictionary that consists of the computed value (``raw``), as well as the + minimum (``min``) and maximum (``max``) values. 
+ """ val = self._base_metric.compute() self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} def reset(self) -> None: - """Sets max_val and min_val to the initialization bounds and resets the base metric.""" + """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric.""" self.max_val = self.max_bound_init self.min_val = self.min_bound_init self._base_metric.reset() From 649ba68dd388e585c204d2bf30044541a587d2dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Oct 2021 11:04:17 +0000 Subject: [PATCH 12/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/minmax.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 310a9f4a514..b3867cb8b9a 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -21,16 +21,15 @@ class MinMaxMetric(Metric): - """Wrapper Metric that tracks both the minimum and maximum of a - scalar/tensor across an experiment. + """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. Note: Make sure you pass proper initialization values to the ``min_bound_init`` and ``max_bound_init`` parameters. - For the ``Accuracy`` metric, the defaults of ``0.0`` and ``1.0`` make sense, + For the ``Accuracy`` metric, the defaults of ``0.0`` and ``1.0`` make sense, however, for other metrics you will likely want to use different initialization values. - Args: - base_metric: + Args: + base_metric: The metric of which you want to keep track of its maximum and minimum values. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` @@ -55,15 +54,15 @@ def __init__( self.min_bound_init = min_bound_init self.max_bound_init = max_bound_init - def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore - """Updates the underlying metric""" + def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore + """Updates the underlying metric.""" self._base_metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: # type: ignore + def compute(self) -> Dict[str, Tensor]: # type: ignore """Computes the underlying metric as well as max and min values for this metric. - - Returns a dictionary that consists of the computed value (``raw``), as well as the - minimum (``min``) and maximum (``max``) values. + + Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``) and maximum + (``max``) values. 
""" val = self._base_metric.compute() self.max_val = val if self.max_val < val else self.max_val From f53b1f46c0f98d71ac6b50f0462e860cc6522bc2 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Thu, 14 Oct 2021 13:26:43 +0100 Subject: [PATCH 13/56] added MinMaxMetric to __all__ --- torchmetrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 906a14bb52a..b0ea32a6277 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -105,6 +105,7 @@ "Metric", "MetricCollection", "MetricTracker", + "MinMaxMetric", "MinMetric", "MultioutputWrapper", "PearsonCorrcoef", From 341ddf8ceeaf9b784f7f6a15ef1ef2a297430a6c Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 15 Oct 2021 11:17:04 +0100 Subject: [PATCH 14/56] removed redundant device flag in test --- tests/wrappers/test_minmax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index d6d28805796..d5afc8130f2 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -4,9 +4,7 @@ from torchmetrics.classification import Accuracy from torchmetrics.wrappers import MinMaxMetric - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_minmax(device: str) -> None: +def test_minmax() -> None: """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) From cdb44b371bd3c465f5c38093a67c3d215a73213f Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 15 Oct 2021 11:42:39 +0100 Subject: [PATCH 15/56] added test and assertion when compute is not a scalar: --- tests/wrappers/test_minmax.py | 21 ++++++++++++++++++++- torchmetrics/wrappers/minmax.py | 23 +++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index d5afc8130f2..49eff8dae35 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -3,8 +3,9 @@ from torchmetrics.classification import Accuracy from torchmetrics.wrappers import MinMaxMetric +from torchmetrics.metric import Metric -def test_minmax() -> None: +def test_base() -> None: """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) @@ -31,3 +32,21 @@ def test_minmax() -> None: assert acc["raw"] == 0.5 assert acc["max"] == 1.0 assert acc["min"] == 0.5 + +def test_no_scalar_compute() -> None: + """test that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute""" + + class NonScalarMetric(Metric): + def __init__(self): + super().__init__() + pass + def update(self): + pass + def compute(self): + return "" + + nsm = NonScalarMetric() + min_max_nsm = MinMaxMetric(nsm) + + with pytest.raises(AssertionError): + min_max_nsm.compute() \ No newline at end of file diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index b3867cb8b9a..5150568490e 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -19,6 +19,26 @@ from torchmetrics.metric import Metric +def _is_suitable_val(val: Any) -> bool: + """Utility function that checks whether min/max value is either a: + - int + - float + - tensor with 1 element + """ + print(val) + print(type(val)) + if (type(val) == int) or (type(val) == float): + return True + elif type(val) == torch.Tensor: + 
print(val.size()) + if val.size() == torch.Size([]): + return True + else: + return False + else: + return False + + class MinMaxMetric(Metric): """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. @@ -65,6 +85,9 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore (``max``) values. """ val = self._base_metric.compute() + isv = _is_suitable_val(val) + print(isv) + assert _is_suitable_val(val), "Computed Base Metric should be a scalar (Int, Float or Tensor of Size 1)" self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} From a55c83e8e65300a284a8d72d2a1798c9fd389d40 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 15 Oct 2021 14:02:25 +0100 Subject: [PATCH 16/56] introduced infinity as bounds --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 5150568490e..3858a0b0e2c 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -64,8 +64,8 @@ def __init__( self, base_metric: Metric, dist_sync_on_step: bool = False, - min_bound_init: float = 1.0, - max_bound_init: float = 0.0, + min_bound_init: float = float("inf"), + max_bound_init: float = float("-inf") ): super().__init__(dist_sync_on_step=dist_sync_on_step) self._base_metric = base_metric From dd4c9ce56b5e88144d133ceb67fdfef7225bc226 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:05:02 +0000 Subject: [PATCH 17/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 14 +++++++++----- torchmetrics/wrappers/minmax.py | 13 +++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 49eff8dae35..3f572b55632 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -2,8 +2,9 @@ import torch from torchmetrics.classification import Accuracy -from torchmetrics.wrappers import MinMaxMetric from torchmetrics.metric import Metric +from torchmetrics.wrappers import MinMaxMetric + def test_base() -> None: """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" @@ -33,20 +34,23 @@ def test_base() -> None: assert acc["max"] == 1.0 assert acc["min"] == 0.5 + def test_no_scalar_compute() -> None: - """test that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute""" + """test that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" class NonScalarMetric(Metric): def __init__(self): super().__init__() pass + def update(self): pass + def compute(self): return "" - + nsm = NonScalarMetric() min_max_nsm = MinMaxMetric(nsm) - + with pytest.raises(AssertionError): - min_max_nsm.compute() \ No newline at end of file + min_max_nsm.compute() diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 3858a0b0e2c..8f67c9f45cd 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -19,11 +19,13 @@ from torchmetrics.metric import Metric + def _is_suitable_val(val: Any) -> bool: """Utility function that checks whether min/max value is either a: - - int - - float - - tensor with 1 element + 
+ - int + - float + - tensor with 1 element """ print(val) print(type(val)) @@ -36,8 +38,7 @@ def _is_suitable_val(val: Any) -> bool: else: return False else: - return False - + return False class MinMaxMetric(Metric): @@ -65,7 +66,7 @@ def __init__( base_metric: Metric, dist_sync_on_step: bool = False, min_bound_init: float = float("inf"), - max_bound_init: float = float("-inf") + max_bound_init: float = float("-inf"), ): super().__init__(dist_sync_on_step=dist_sync_on_step) self._base_metric = base_metric From 1dc9a7c58f71bcd1ef610fd3072bab0eb9f37557 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 20 Oct 2021 19:59:02 +0200 Subject: [PATCH 18/56] Apply suggestions from code review --- torchmetrics/wrappers/minmax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 8f67c9f45cd..0d362250bc4 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -33,12 +33,8 @@ def _is_suitable_val(val: Any) -> bool: return True elif type(val) == torch.Tensor: print(val.size()) - if val.size() == torch.Size([]): - return True - else: - return False - else: - return False + return val.size() == torch.Size([]) + return False class MinMaxMetric(Metric): From 7078a16c2005f9cc0a3d70ca9277a608b49f248e Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Thu, 21 Oct 2021 00:08:20 +0100 Subject: [PATCH 19/56] update typing in helper function Co-authored-by: Jirka Borovec --- torchmetrics/wrappers/minmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 0d362250bc4..2b1b907cb45 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -20,7 +20,7 @@ from torchmetrics.metric import Metric -def _is_suitable_val(val: Any) -> bool: +def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: """Utility function that checks whether min/max value is either a: - int From 306f6906f08c67d9c15ef0097be78a2c4fad0801 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Thu, 21 Oct 2021 00:08:57 +0100 Subject: [PATCH 20/56] summarize helper function Co-authored-by: Jirka Borovec --- torchmetrics/wrappers/minmax.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 2b1b907cb45..a987bb6f659 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -21,12 +21,7 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: - """Utility function that checks whether min/max value is either a: - - - int - - float - - tensor with 1 element - """ + """Utility function that checks whether min/max value.""" print(val) print(type(val)) if (type(val) == int) or (type(val) == float): From 6890b36a289690ee043edf6212f0b7a1ec5e3858 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Thu, 21 Oct 2021 10:36:32 +0100 Subject: [PATCH 21/56] added example of minmaxmetric and removed debugging print statements --- torchmetrics/wrappers/minmax.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index a987bb6f659..05ab4bc14d7 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -12,7 +12,7 @@ # See the 
License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Union import torch from torch import Tensor @@ -22,12 +22,9 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: """Utility function that checks whether min/max value.""" - print(val) - print(type(val)) if (type(val) == int) or (type(val) == float): return True elif type(val) == torch.Tensor: - print(val.size()) return val.size() == torch.Size([]) return False @@ -35,11 +32,6 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: class MinMaxMetric(Metric): """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. - Note: - Make sure you pass proper initialization values to the ``min_bound_init`` and ``max_bound_init`` parameters. - For the ``Accuracy`` metric, the defaults of ``0.0`` and ``1.0`` make sense, - however, for other metrics you will likely want to use different initialization values. - Args: base_metric: The metric of which you want to keep track of its maximum and minimum values. @@ -47,9 +39,26 @@ class MinMaxMetric(Metric): Synchronize metric state across processes at each ``forward()`` before returning the value at the step. min_bound_init: - Initialization value of the ``min`` parameter. default: 0.0 + Initialization value of the ``min`` parameter. default: -inf max_bound_init: - Initialization value of the ``max`` parameter. default: 1.0 + Initialization value of the ``max`` parameter. default: inf + + Example:: + >>> import torch + >>> from torchmetrics import Accuracy, MinMaxMetric + >>> base_metric = Accuracy() + >>> minmax_metric = MinMaxMetric(base_metric) + >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) + >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) + >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() + >>> minmax_metric(preds_1,labels) # Accuracy is 0.5 + >>> output = minmax_metric.compute() + >>> print(output) + {'raw': tensor(0.5000), 'max': tensor(0.5000), 'min': tensor(0.5000)} + >>> minmax_metric(preds_2,labels) # Accuracy is 1.0 + >>> output = minmax_metric.compute() + >>> print(output) + {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} """ def __init__( @@ -77,8 +86,6 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore (``max``) values. """ val = self._base_metric.compute() - isv = _is_suitable_val(val) - print(isv) assert _is_suitable_val(val), "Computed Base Metric should be a scalar (Int, Float or Tensor of Size 1)" self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val From ab76517d170b1a859bf63a2c88af88daeeb41ce6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Oct 2021 09:39:53 +0000 Subject: [PATCH 22/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 05ab4bc14d7..679fde36651 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -42,7 +42,7 @@ class MinMaxMetric(Metric): Initialization value of the ``min`` parameter. default: -inf max_bound_init: Initialization value of the ``max`` parameter. 
default: inf - + Example:: >>> import torch >>> from torchmetrics import Accuracy, MinMaxMetric @@ -50,7 +50,7 @@ class MinMaxMetric(Metric): >>> minmax_metric = MinMaxMetric(base_metric) >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) - >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() + >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() >>> minmax_metric(preds_1,labels) # Accuracy is 0.5 >>> output = minmax_metric.compute() >>> print(output) From 80456ccc4ec64c1830e1b4035049dd8fe211edc5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 22 Oct 2021 14:53:37 +0200 Subject: [PATCH 23/56] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bb38b37cf4..c2b5c61c474 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) +- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) From bde7ae5afb10816ea6379dc904d0a89b1c286ea4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 22 Oct 2021 14:54:58 +0200 Subject: [PATCH 24/56] docs --- docs/source/references/modules.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index a6566018384..446e7ebc1dd 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -622,6 +622,12 @@ MetricTracker .. autoclass:: torchmetrics.MetricTracker :noindex: +MinMaxMetric +~~~~~~~~~~~~ + +.. 
autoclass:: torchmetrics.MinMaxMetric + :noindex: + MultioutputWrapper ~~~~~~~~~~~~~~~~~~ From a899406319e089ae3d4c9f997960f3b2aaa8fb14 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Fri, 22 Oct 2021 14:34:28 +0100 Subject: [PATCH 25/56] Update torchmetrics/wrappers/minmax.py Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/wrappers/minmax.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 679fde36651..2d35e707fa2 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -69,6 +69,10 @@ def __init__( max_bound_init: float = float("-inf"), ): super().__init__(dist_sync_on_step=dist_sync_on_step) + if not isinstance(base_metric, Metric): + raise raise ValueError( + f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" + ) self._base_metric = base_metric self.add_state("min_val", default=torch.tensor(min_bound_init)) self.add_state("max_val", default=torch.tensor(max_bound_init)) From 0887a70333e8989661aa02ec29d757e776b30b31 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Fri, 22 Oct 2021 14:35:14 +0100 Subject: [PATCH 26/56] Update torchmetrics/wrappers/minmax.py Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/wrappers/minmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 2d35e707fa2..4feddf12aea 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -25,7 +25,7 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: if (type(val) == int) or (type(val) == float): return True elif type(val) == torch.Tensor: - return val.size() == torch.Size([]) + return val.numel() == 1 return False From b9e34aa24c2f63fff6f62e07aac76965b6e8d081 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Fri, 22 Oct 2021 14:35:37 +0100 Subject: [PATCH 27/56] Update torchmetrics/wrappers/minmax.py Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/wrappers/minmax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 4feddf12aea..9897e46822f 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -97,6 +97,5 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore def reset(self) -> None: """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric.""" - self.max_val = self.max_bound_init - self.min_val = self.min_bound_init + super().reset() self._base_metric.reset() From 4d8c29651f52167915fbc5780f0abe4fc41af7aa Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Mon, 25 Oct 2021 10:01:24 +0100 Subject: [PATCH 28/56] remove personalizable metric values, implemented nicki comments --- tests/wrappers/test_minmax.py | 25 ++++++++++--------------- torchmetrics/wrappers/minmax.py | 12 +++--------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 3f572b55632..0616217b310 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,13 +1,13 @@ import pytest import torch -from torchmetrics.classification import Accuracy +from torchmetrics.classification import Accuracy, 
ConfusionMatrix from torchmetrics.metric import Metric from torchmetrics.wrappers import MinMaxMetric def test_base() -> None: - """test that both min and max versions of MinMaxMetric operate correctly after calling compute.""" + """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) @@ -34,22 +34,17 @@ def test_base() -> None: assert acc["max"] == 1.0 assert acc["min"] == 0.5 +def test_no_base_metric() -> None: + """tests that ValueError is raised when no base_metric is passed""" + x = "" + with pytest.raises(ValueError): + MinMaxMetric(x) -def test_no_scalar_compute() -> None: - """test that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" - - class NonScalarMetric(Metric): - def __init__(self): - super().__init__() - pass - def update(self): - pass - - def compute(self): - return "" +def test_no_scalar_compute() -> None: + """tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" - nsm = NonScalarMetric() + nsm = ConfusionMatrix(num_classes=2) min_max_nsm = MinMaxMetric(nsm) with pytest.raises(AssertionError): diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 9897e46822f..cc68ef7cb00 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -38,10 +38,6 @@ class MinMaxMetric(Metric): dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. - min_bound_init: - Initialization value of the ``min`` parameter. default: -inf - max_bound_init: - Initialization value of the ``max`` parameter. default: inf Example:: >>> import torch @@ -65,19 +61,17 @@ def __init__( self, base_metric: Metric, dist_sync_on_step: bool = False, - min_bound_init: float = float("inf"), - max_bound_init: float = float("-inf"), ): super().__init__(dist_sync_on_step=dist_sync_on_step) if not isinstance(base_metric, Metric): - raise raise ValueError( + raise ValueError( f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" ) self._base_metric = base_metric self.add_state("min_val", default=torch.tensor(min_bound_init)) self.add_state("max_val", default=torch.tensor(max_bound_init)) - self.min_bound_init = min_bound_init - self.max_bound_init = max_bound_init + self.min_bound_init = float("inf") + self.max_bound_init = float("-inf") def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore """Updates the underlying metric.""" From 2643fb7a965591e233f05bcc5c61c62398e11e3b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 09:01:58 +0000 Subject: [PATCH 29/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 0616217b310..ac32e7816be 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -34,8 +34,9 @@ def test_base() -> None: assert acc["max"] == 1.0 assert acc["min"] == 0.5 + def test_no_base_metric() -> None: - """tests that ValueError is raised when no base_metric is passed""" + """tests that ValueError is raised when no base_metric is passed.""" x = "" with pytest.raises(ValueError): MinMaxMetric(x) From 8afd4b53e78290cff399ee8752a722499bbe33b2 Mon 
Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts <31068156+janhenriklambrechts@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:02:04 +0100 Subject: [PATCH 30/56] Update torchmetrics/wrappers/minmax.py Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index cc68ef7cb00..b12ebe7aa73 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -68,8 +68,8 @@ def __init__( f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" ) self._base_metric = base_metric - self.add_state("min_val", default=torch.tensor(min_bound_init)) - self.add_state("max_val", default=torch.tensor(max_bound_init)) + self.add_state("min_val", default=torch.tensor(min_bound_init), dist_reduce_fx='min') + self.add_state("max_val", default=torch.tensor(max_bound_init), dist_reduce_fx='max') self.min_bound_init = float("inf") self.max_bound_init = float("-inf") From 321e87c325d924bc66d64e1811477ddab3b077ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 09:02:57 +0000 Subject: [PATCH 31/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index b12ebe7aa73..76f4f62be6b 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -68,8 +68,8 @@ def __init__( f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" ) self._base_metric = base_metric - self.add_state("min_val", default=torch.tensor(min_bound_init), dist_reduce_fx='min') - self.add_state("max_val", default=torch.tensor(max_bound_init), dist_reduce_fx='max') + self.add_state("min_val", default=torch.tensor(min_bound_init), dist_reduce_fx="min") + self.add_state("max_val", default=torch.tensor(max_bound_init), dist_reduce_fx="max") self.min_bound_init = float("inf") self.max_bound_init = float("-inf") From 68698150339e0ac92484b3d9c9fa3db24bd70e02 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 25 Oct 2021 12:13:49 +0200 Subject: [PATCH 32/56] fix implementation --- torchmetrics/wrappers/minmax.py | 58 ++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 76f4f62be6b..3b623890353 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Callable, Dict, Union, Optional import torch from torch import Tensor @@ -20,24 +20,27 @@ from torchmetrics.metric import Metric -def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: - """Utility function that checks whether min/max value.""" - if (type(val) == int) or (type(val) == float): - return True - elif type(val) == torch.Tensor: - return val.numel() == 1 - return False - - class MinMaxMetric(Metric): - """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. + """ Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. 
The min/max + value will be updated each time `.compute` is called. Args: base_metric: The metric of which you want to keep track of its maximum and minimum values. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather + + Raises: + ValueError + If ``base_metric` argument is not an subclasses instance of ``torchmetrics.Metric`` Example:: >>> import torch @@ -60,18 +63,24 @@ class MinMaxMetric(Metric): def __init__( self, base_metric: Metric, + compute_on_step: bool = True, dist_sync_on_step: bool = False, - ): - super().__init__(dist_sync_on_step=dist_sync_on_step) + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) if not isinstance(base_metric, Metric): raise ValueError( f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" ) self._base_metric = base_metric - self.add_state("min_val", default=torch.tensor(min_bound_init), dist_reduce_fx="min") - self.add_state("max_val", default=torch.tensor(max_bound_init), dist_reduce_fx="max") - self.min_bound_init = float("inf") - self.max_bound_init = float("-inf") + self.add_state("min_val", default=torch.tensor(float("inf")), dist_reduce_fx="min") + self.add_state("max_val", default=torch.tensor(float("-inf")), dist_reduce_fx="max") def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore """Updates the underlying metric.""" @@ -84,7 +93,8 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore (``max``) values. 
""" val = self._base_metric.compute() - assert _is_suitable_val(val), "Computed Base Metric should be a scalar (Int, Float or Tensor of Size 1)" + if not self._is_suitable_val(val): + raise RuntimeError('Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}') self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} @@ -93,3 +103,13 @@ def reset(self) -> None: """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric.""" super().reset() self._base_metric.reset() + + @staticmethod + def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: + """Utility function that checks whether min/max value.""" + if (type(val) == int) or (type(val) == float): + return True + elif type(val) == torch.Tensor: + return val.numel() == 1 + return False + From 73700bd0a37256c8de5f3c67b70903ae4021cf46 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 25 Oct 2021 12:13:57 +0200 Subject: [PATCH 33/56] improve tests --- tests/wrappers/test_minmax.py | 67 +++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index ac32e7816be..12c6c84a2cd 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,12 +1,62 @@ import pytest import torch -from torchmetrics.classification import Accuracy, ConfusionMatrix -from torchmetrics.metric import Metric +from functools import partial + +from tests.helpers import seed_all +from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError from torchmetrics.wrappers import MinMaxMetric +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester + + +seed_all(42) + + +class TestingMinMaxMetric(MinMaxMetric): + """ wrap metric to fit testing framework """ + def compute(self): + """ instead of returning dict, return as list """ + output_dict = super().compute() + return [output_dict['raw'], output_dict['min'], output_dict['max']] + + def forward(self, *args, **kwargs): + self.update(*args, **kwargs) + return self.compute() -def test_base() -> None: +def compare_fn(preds, target, base_fn): + """ comparing function for minmax wrapper""" + min, max = 1e6, -1e6 # pick some very large numbers for comparing + for i in range(NUM_BATCHES): + val = base_fn(preds[:(i+1)*BATCH_SIZE], target[:(i+1)*BATCH_SIZE]).cpu().numpy() + min = min if min < val else val + max = max if max > val else val + raw = base_fn(preds, target) + return [raw.cpu().numpy(), min, max] + + +@pytest.mark.parametrize("preds, target, base_metric", [ + (torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(dim=-1), torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)), Accuracy(num_classes=NUM_CLASSES)), + (torch.randn(NUM_BATCHES, BATCH_SIZE), torch.randn(NUM_BATCHES, BATCH_SIZE), MeanSquaredError()) +]) +class TestMultioutputWrapper(MetricTester): + """ Test the MinMaxMetric wrapper works as expected """ + @pytest.mark.parametrize("ddp", [True, False]) + def test_multioutput_wrapper(self, preds, target, base_metric, ddp): + self.run_class_metric_test( + ddp, + preds, + target, + TestingMinMaxMetric, + partial(compare_fn, base_fn=base_metric), + dist_sync_on_step=False, + metric_args=dict(base_metric=base_metric), + check_batch=False, + check_scriptable=False, + ) + + +def test_basic_example() -> None: """tests that both min and max versions of MinMaxMetric 
operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) @@ -37,16 +87,13 @@ def test_base() -> None: def test_no_base_metric() -> None: """tests that ValueError is raised when no base_metric is passed.""" - x = "" - with pytest.raises(ValueError): - MinMaxMetric(x) + with pytest.raises(ValueError, match=r'Expected base metric to be an instance .*'): + MinMaxMetric([]) def test_no_scalar_compute() -> None: """tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" + min_max_nsm = MinMaxMetric(ConfusionMatrix(num_classes=2)) - nsm = ConfusionMatrix(num_classes=2) - min_max_nsm = MinMaxMetric(nsm) - - with pytest.raises(AssertionError): + with pytest.raises(RuntimeError, match=r'Returned value from base metric should be a scalar .*'): min_max_nsm.compute() From 76c8273bf98b2032d8484b802ac089f1c515d98a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:14:53 +0000 Subject: [PATCH 34/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 40 ++++++++++++++++++++------------- torchmetrics/wrappers/minmax.py | 9 ++++---- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 12c6c84a2cd..866cc320a20 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,23 +1,23 @@ +from functools import partial + import pytest import torch -from functools import partial - from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError from torchmetrics.wrappers import MinMaxMetric -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester - seed_all(42) class TestingMinMaxMetric(MinMaxMetric): - """ wrap metric to fit testing framework """ + """wrap metric to fit testing framework.""" + def compute(self): - """ instead of returning dict, return as list """ + """instead of returning dict, return as list.""" output_dict = super().compute() - return [output_dict['raw'], output_dict['min'], output_dict['max']] + return [output_dict["raw"], output_dict["min"], output_dict["max"]] def forward(self, *args, **kwargs): self.update(*args, **kwargs) @@ -25,22 +25,30 @@ def forward(self, *args, **kwargs): def compare_fn(preds, target, base_fn): - """ comparing function for minmax wrapper""" + """comparing function for minmax wrapper.""" min, max = 1e6, -1e6 # pick some very large numbers for comparing for i in range(NUM_BATCHES): - val = base_fn(preds[:(i+1)*BATCH_SIZE], target[:(i+1)*BATCH_SIZE]).cpu().numpy() + val = base_fn(preds[: (i + 1) * BATCH_SIZE], target[: (i + 1) * BATCH_SIZE]).cpu().numpy() min = min if min < val else val max = max if max > val else val raw = base_fn(preds, target) return [raw.cpu().numpy(), min, max] -@pytest.mark.parametrize("preds, target, base_metric", [ - (torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(dim=-1), torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)), Accuracy(num_classes=NUM_CLASSES)), - (torch.randn(NUM_BATCHES, BATCH_SIZE), torch.randn(NUM_BATCHES, BATCH_SIZE), MeanSquaredError()) -]) +@pytest.mark.parametrize( + "preds, target, base_metric", + [ + ( + torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(dim=-1), + torch.randint(NUM_CLASSES, (NUM_BATCHES, 
BATCH_SIZE)), + Accuracy(num_classes=NUM_CLASSES), + ), + (torch.randn(NUM_BATCHES, BATCH_SIZE), torch.randn(NUM_BATCHES, BATCH_SIZE), MeanSquaredError()), + ], +) class TestMultioutputWrapper(MetricTester): - """ Test the MinMaxMetric wrapper works as expected """ + """Test the MinMaxMetric wrapper works as expected.""" + @pytest.mark.parametrize("ddp", [True, False]) def test_multioutput_wrapper(self, preds, target, base_metric, ddp): self.run_class_metric_test( @@ -87,7 +95,7 @@ def test_basic_example() -> None: def test_no_base_metric() -> None: """tests that ValueError is raised when no base_metric is passed.""" - with pytest.raises(ValueError, match=r'Expected base metric to be an instance .*'): + with pytest.raises(ValueError, match=r"Expected base metric to be an instance .*"): MinMaxMetric([]) @@ -95,5 +103,5 @@ def test_no_scalar_compute() -> None: """tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" min_max_nsm = MinMaxMetric(ConfusionMatrix(num_classes=2)) - with pytest.raises(RuntimeError, match=r'Returned value from base metric should be a scalar .*'): + with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"): min_max_nsm.compute() diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 3b623890353..74e5b5fa02f 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Union, Optional +from typing import Any, Callable, Dict, Optional, Union import torch from torch import Tensor @@ -21,7 +21,7 @@ class MinMaxMetric(Metric): - """ Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The min/max + """Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The min/max value will be updated each time `.compute` is called. 
Args: @@ -94,7 +94,9 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore """ val = self._base_metric.compute() if not self._is_suitable_val(val): - raise RuntimeError('Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}') + raise RuntimeError( + "Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}" + ) self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} @@ -112,4 +114,3 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: elif type(val) == torch.Tensor: return val.numel() == 1 return False - From 4d09a22f853a43244efc92dfab66f8be92ab398f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 25 Oct 2021 12:23:37 +0200 Subject: [PATCH 35/56] fix mypy --- torchmetrics/wrappers/minmax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 74e5b5fa02f..674f97a1d96 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -59,6 +59,8 @@ class MinMaxMetric(Metric): >>> print(output) {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} """ + min_val: Tensor + max_val: Tensor def __init__( self, @@ -111,6 +113,6 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: """Utility function that checks whether min/max value.""" if (type(val) == int) or (type(val) == float): return True - elif type(val) == torch.Tensor: + elif isinstance(val, Tensor): return val.numel() == 1 return False From f9574c228b5a75ae4c441e18a6b19e95c462ebec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:24:32 +0000 Subject: [PATCH 36/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/minmax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 674f97a1d96..9a965e68311 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -59,6 +59,7 @@ class MinMaxMetric(Metric): >>> print(output) {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} """ + min_val: Tensor max_val: Tensor From 15f31e26cc3604901a1e8df9555ed373926e988f Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 25 Oct 2021 14:10:28 +0200 Subject: [PATCH 37/56] prepare 0.6 RC --- CHANGELOG.md | 64 ++++--------------- .../classification/test_average_precision.py | 2 +- torchmetrics/__about__.py | 2 +- torchmetrics/classification/__init__.py | 2 +- ...{average_precision.py => avg_precision.py} | 0 5 files changed, 17 insertions(+), 53 deletions(-) rename torchmetrics/classification/{average_precision.py => avg_precision.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ebaf56417..14bf326b303 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,67 +6,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [unReleased] - 2021-MM-DD +## [0.6.0] - 2021-10-DD ### Added -- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431)) - - -- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499)) - - 
+- Added audio metrics: + - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) + - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) +- Added Information retrieval metrics: + - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) + - `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) +- Added NLP metrics: + - `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546)) + - `CharErrorRate` ([#575](https://github.com/PyTorchLightning/metrics/pull/575)) +- Added other metrics: + - Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499)) + - Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431)) - Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437)) - - -- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - - -- Added Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) - - +- Added `average` argument to `AveragePrecision` metric for reducing multi-label and multi-class problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510)) - - - Added metric sweeping `higher_is_better` as constant attribute ([#544](https://github.com/PyTorchLightning/metrics/pull/544)) - - -- Added `SacreBLEUScore` metric to text package ([#546](https://github.com/PyTorchLightning/metrics/pull/546)) - - - Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) - - - Added pairwise submodule with metrics ([#553](https://github.com/PyTorchLightning/metrics/pull/553)) - `pairwise_cosine_similarity` - `pairwise_euclidean_distance` - `pairwise_linear_similarity` - `pairwise_manhatten_distance` - -- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) - - -- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) - - -- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) - - -- Added `CharErrorRate` metric to text package ([#575](https://github.com/PyTorchLightning/metrics/pull/575)) - - ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - - - `half`, `double`, `float` will no longer change the dtype of the metric states. 
Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) - - - Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) - - - Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551)) ### Deprecated @@ -77,18 +48,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed `dtype` property ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) - ### Fixed - Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495)) - - - Fixed bug in `pit` by using the returned first result to initialize device and type ([#533](https://github.com/PyTorchLightning/metrics/pull/533)) - - - Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539)) - - - Fixed bug where `device` property was not properly update when metric was a child of a module ([#542](https://github.com/PyTorchLightning/metrics/pull/542)) ## [0.5.1] - 2021-08-30 diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py index 5c65a2256cf..557fb92b154 100644 --- a/tests/classification/test_average_precision.py +++ b/tests/classification/test_average_precision.py @@ -24,7 +24,7 @@ from tests.classification.inputs import _input_multilabel from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics.classification.average_precision import AveragePrecision +from torchmetrics.classification.avg_precision import AveragePrecision from torchmetrics.functional import average_precision seed_all(42) diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index ea1d5dcd3c0..334914962d1 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.0dev" +__version__ = "0.6.0rc0" __author__ = "PyTorchLightning et al." 
__author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 35476172b06..0ed2d3d8d8b 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -14,7 +14,7 @@ from torchmetrics.classification.accuracy import Accuracy # noqa: F401 from torchmetrics.classification.auc import AUC # noqa: F401 from torchmetrics.classification.auroc import AUROC # noqa: F401 -from torchmetrics.classification.average_precision import AveragePrecision # noqa: F401 +from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/avg_precision.py similarity index 100% rename from torchmetrics/classification/average_precision.py rename to torchmetrics/classification/avg_precision.py From 1a2191be21d05d6d2e557c86b535cf6cf9b386f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 12:35:12 +0000 Subject: [PATCH 38/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 867c5629401..9a9f6bb6402 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -68,7 +68,6 @@ from torchmetrics.text import WER, BERTScore, BLEUScore, CharErrorRate, ROUGEScore, SacreBLEUScore # noqa: E402 from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402 - __all__ = [ "functional", "Accuracy", From ba83223bed5f20e42c0fa0c83a05f5be1e366cd6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 25 Oct 2021 14:45:56 +0200 Subject: [PATCH 39/56] fix doctest --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 9a965e68311..9d812ae8a11 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -50,11 +50,11 @@ class MinMaxMetric(Metric): >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() - >>> minmax_metric(preds_1,labels) # Accuracy is 0.5 + >>> _ = minmax_metric(preds_1, labels) # Accuracy is 0.5 >>> output = minmax_metric.compute() >>> print(output) {'raw': tensor(0.5000), 'max': tensor(0.5000), 'min': tensor(0.5000)} - >>> minmax_metric(preds_2,labels) # Accuracy is 1.0 + >>> _ = minmax_metric(preds_2, labels) # Accuracy is 1.0 >>> output = minmax_metric.compute() >>> print(output) {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} From c3016b13eefe71606ca9de0e431023825618abeb Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Mon, 25 Oct 2021 14:57:24 +0100 Subject: [PATCH 40/56] added pprint --- torchmetrics/wrappers/minmax.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 
9d812ae8a11..250ce50e8d2 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -45,6 +45,7 @@ class MinMaxMetric(Metric): Example:: >>> import torch >>> from torchmetrics import Accuracy, MinMaxMetric + >>> from pprint import pprint >>> base_metric = Accuracy() >>> minmax_metric = MinMaxMetric(base_metric) >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) @@ -52,11 +53,11 @@ class MinMaxMetric(Metric): >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() >>> _ = minmax_metric(preds_1, labels) # Accuracy is 0.5 >>> output = minmax_metric.compute() - >>> print(output) + >>> pprint(output) {'raw': tensor(0.5000), 'max': tensor(0.5000), 'min': tensor(0.5000)} >>> _ = minmax_metric(preds_2, labels) # Accuracy is 1.0 >>> output = minmax_metric.compute() - >>> print(output) + >>> pprint(output) {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} """ From eb3e478e5b79d2aeb2f5e247eada909998a9a8cf Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Mon, 25 Oct 2021 15:19:39 +0100 Subject: [PATCH 41/56] moved base test to parametrize --- tests/wrappers/test_minmax.py | 40 ++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 866cc320a20..1f6ea369f21 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -63,34 +63,44 @@ def test_multioutput_wrapper(self, preds, target, base_metric, ddp): check_scriptable=False, ) - -def test_basic_example() -> None: +@pytest.mark.parametrize( + "preds, labels, raws, maxs, mins", + [ + ( + ([[0.9, 0.1], [0.2, 0.8]],[[0.1, 0.9], [0.2, 0.8]], [[0.1, 0.9], [0.8, 0.2]]), + [[0, 1], [0, 1]], + (0.5, 1.0, 0.5), + (0.5, 1.0, 1.0), + (0.5, 0.5, 0.5) + ) + ]) +def test_basic_example(preds, labels, raws, maxs, mins) -> None: """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) - preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) - preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) - preds_3 = torch.Tensor([[0.1, 0.9], [0.8, 0.2]]) - labels = torch.Tensor([[0, 1], [0, 1]]).long() + preds_1 = torch.Tensor(preds[0]) + preds_2 = torch.Tensor(preds[1]) + preds_3 = torch.Tensor(preds[2]) + labels = torch.Tensor(labels).long() min_max_acc(preds_1, labels) acc = min_max_acc.compute() - assert acc["raw"] == 0.5 - assert acc["max"] == 0.5 - assert acc["min"] == 0.5 + assert acc["raw"] == raws[0] + assert acc["max"] == maxs[0] + assert acc["min"] == mins[0] min_max_acc(preds_2, labels) acc = min_max_acc.compute() - assert acc["raw"] == 1.0 - assert acc["max"] == 1.0 - assert acc["min"] == 0.5 + assert acc["raw"] == raws[1] + assert acc["max"] == maxs[1] + assert acc["min"] == mins[1] min_max_acc(preds_3, labels) acc = min_max_acc.compute() - assert acc["raw"] == 0.5 - assert acc["max"] == 1.0 - assert acc["min"] == 0.5 + assert acc["raw"] == raws[2] + assert acc["max"] == maxs[2] + assert acc["min"] == mins[2] def test_no_base_metric() -> None: From 3067d3676c36dc0840134b8accf286af4d97b3ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Oct 2021 14:20:11 +0000 Subject: [PATCH 42/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/wrappers/test_minmax.py 
b/tests/wrappers/test_minmax.py index 1f6ea369f21..e2098d15e8e 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -63,17 +63,19 @@ def test_multioutput_wrapper(self, preds, target, base_metric, ddp): check_scriptable=False, ) + @pytest.mark.parametrize( - "preds, labels, raws, maxs, mins", + "preds, labels, raws, maxs, mins", [ ( - ([[0.9, 0.1], [0.2, 0.8]],[[0.1, 0.9], [0.2, 0.8]], [[0.1, 0.9], [0.8, 0.2]]), + ([[0.9, 0.1], [0.2, 0.8]], [[0.1, 0.9], [0.2, 0.8]], [[0.1, 0.9], [0.8, 0.2]]), [[0, 1], [0, 1]], (0.5, 1.0, 0.5), (0.5, 1.0, 1.0), - (0.5, 0.5, 0.5) + (0.5, 0.5, 0.5), ) - ]) + ], +) def test_basic_example(preds, labels, raws, maxs, mins) -> None: """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() From 8df796802f9c9501d51dcbb5f5b068ae7363b8d5 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Tue, 26 Oct 2021 13:52:35 +0100 Subject: [PATCH 43/56] fix doctest with pprint --- torchmetrics/wrappers/minmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 250ce50e8d2..7e4ac3cc460 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -54,11 +54,11 @@ class MinMaxMetric(Metric): >>> _ = minmax_metric(preds_1, labels) # Accuracy is 0.5 >>> output = minmax_metric.compute() >>> pprint(output) - {'raw': tensor(0.5000), 'max': tensor(0.5000), 'min': tensor(0.5000)} + {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} >>> _ = minmax_metric(preds_2, labels) # Accuracy is 1.0 >>> output = minmax_metric.compute() >>> pprint(output) - {'raw': tensor(1.), 'max': tensor(1.), 'min': tensor(0.5000)} + {'max': tensor(1.), 'min': tensor(0.5000), 'raw': tensor(1.)} """ min_val: Tensor From c380d238ea378f15c2b362ea0a7bcc1087c5ef17 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 27 Oct 2021 13:26:57 +0200 Subject: [PATCH 44/56] prune --- tests/wrappers/test_minmax.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index e2098d15e8e..77edc391f72 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -80,29 +80,15 @@ def test_basic_example(preds, labels, raws, maxs, mins) -> None: """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = Accuracy() min_max_acc = MinMaxMetric(acc) - - preds_1 = torch.Tensor(preds[0]) - preds_2 = torch.Tensor(preds[1]) - preds_3 = torch.Tensor(preds[2]) labels = torch.Tensor(labels).long() - min_max_acc(preds_1, labels) - acc = min_max_acc.compute() - assert acc["raw"] == raws[0] - assert acc["max"] == maxs[0] - assert acc["min"] == mins[0] - - min_max_acc(preds_2, labels) - acc = min_max_acc.compute() - assert acc["raw"] == raws[1] - assert acc["max"] == maxs[1] - assert acc["min"] == mins[1] - - min_max_acc(preds_3, labels) - acc = min_max_acc.compute() - assert acc["raw"] == raws[2] - assert acc["max"] == maxs[2] - assert acc["min"] == mins[2] + for i in range(3): + preds_ = torch.Tensor(preds[i]) + min_max_acc(preds_, labels) + acc = min_max_acc.compute() + assert acc["raw"] == raws[i] + assert acc["max"] == maxs[i] + assert acc["min"] == mins[i] def test_no_base_metric() -> None: From 03ca76b8fcbaa9c5d4d50b7027040f18f0f67035 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 27 Oct 2021 13:30:11 +0200 Subject: [PATCH 45/56] docs --- 
torchmetrics/wrappers/minmax.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 7e4ac3cc460..608c96f2f34 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -44,20 +44,20 @@ class MinMaxMetric(Metric): Example:: >>> import torch - >>> from torchmetrics import Accuracy, MinMaxMetric + >>> from torchmetrics import Accuracy >>> from pprint import pprint >>> base_metric = Accuracy() >>> minmax_metric = MinMaxMetric(base_metric) >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() - >>> _ = minmax_metric(preds_1, labels) # Accuracy is 0.5 - >>> output = minmax_metric.compute() - >>> pprint(output) + >>> pprint(minmax_metric(preds_1, labels)) {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} - >>> _ = minmax_metric(preds_2, labels) # Accuracy is 1.0 - >>> output = minmax_metric.compute() - >>> pprint(output) + >>> pprint(minmax_metric.compute()) + {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} + >>> pprint(minmax_metric(preds_2, labels)) + {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)} + >>> pprint(minmax_metric.compute()) {'max': tensor(1.), 'min': tensor(0.5000), 'raw': tensor(1.)} """ From c61ba5d93f4fd83a2ed31f69630844cf9a6defc7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 27 Oct 2021 13:32:30 +0200 Subject: [PATCH 46/56] fixing --- torchmetrics/wrappers/minmax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 608c96f2f34..9df95995536 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -80,7 +80,7 @@ def __init__( ) if not isinstance(base_metric, Metric): raise ValueError( - f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}" + f"Expected base metric to be an instance of `torchmetrics.Metric` but received {base_metric}" ) self._base_metric = base_metric self.add_state("min_val", default=torch.tensor(float("inf")), dist_reduce_fx="min") @@ -99,7 +99,7 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore val = self._base_metric.compute() if not self._is_suitable_val(val): raise RuntimeError( - "Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}" + f"Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}" ) self.max_val = val if self.max_val < val else self.max_val self.min_val = val if self.min_val > val else self.min_val @@ -113,8 +113,8 @@ def reset(self) -> None: @staticmethod def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: """Utility function that checks whether min/max value.""" - if (type(val) == int) or (type(val) == float): + if isinstance(val, (int, float)): return True - elif isinstance(val, Tensor): + if isinstance(val, Tensor): return val.numel() == 1 return False From ed962df8b09190c12d8149ed5de4338cafd89170 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 27 Oct 2021 13:33:01 +0200 Subject: [PATCH 47/56] . 
--- torchmetrics/wrappers/minmax.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 9df95995536..513a527a64f 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -43,22 +43,22 @@ class MinMaxMetric(Metric): If ``base_metric` argument is not an subclasses instance of ``torchmetrics.Metric`` Example:: - >>> import torch - >>> from torchmetrics import Accuracy - >>> from pprint import pprint - >>> base_metric = Accuracy() - >>> minmax_metric = MinMaxMetric(base_metric) - >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) - >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) - >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() - >>> pprint(minmax_metric(preds_1, labels)) - {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} - >>> pprint(minmax_metric.compute()) - {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} - >>> pprint(minmax_metric(preds_2, labels)) - {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)} - >>> pprint(minmax_metric.compute()) - {'max': tensor(1.), 'min': tensor(0.5000), 'raw': tensor(1.)} + >>> import torch + >>> from torchmetrics import Accuracy + >>> from pprint import pprint + >>> base_metric = Accuracy() + >>> minmax_metric = MinMaxMetric(base_metric) + >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) + >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) + >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() + >>> pprint(minmax_metric(preds_1, labels)) + {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} + >>> pprint(minmax_metric.compute()) + {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)} + >>> pprint(minmax_metric(preds_2, labels)) + {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)} + >>> pprint(minmax_metric.compute()) + {'max': tensor(1.), 'min': tensor(0.5000), 'raw': tensor(1.)} """ min_val: Tensor From b8760fd30283592dde643cbbafb48d147431ae0e Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 28 Oct 2021 18:55:13 +0200 Subject: [PATCH 48/56] release v0.6.0 --- CHANGELOG.md | 2 +- torchmetrics/__about__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3bb255fbdb..c77f17bb0e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [0.6.0] - 2021-10-DD +## [0.6.0] - 2021-10-28 ### Added diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index ed3caea89b2..c01d1134307 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.0rc1" +__version__ = "0.6.0" __author__ = "PyTorchLightning et al." 
__author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" From d946cd2674424c3dc3f1201934ac6dff051558ff Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 28 Oct 2021 18:55:13 +0200 Subject: [PATCH 49/56] release v0.6.0 * fix parsing --- CHANGELOG.md | 2 +- torchmetrics/__about__.py | 2 +- torchmetrics/setup_tools.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3bb255fbdb..c77f17bb0e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [0.6.0] - 2021-10-DD +## [0.6.0] - 2021-10-28 ### Added diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index ed3caea89b2..c01d1134307 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.0rc1" +__version__ = "0.6.0" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/torchmetrics/setup_tools.py b/torchmetrics/setup_tools.py index 68c361c60be..6265b684bd6 100644 --- a/torchmetrics/setup_tools.py +++ b/torchmetrics/setup_tools.py @@ -18,7 +18,7 @@ _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) -def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]: +def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#@") -> List[str]: """Load requirements from a file. >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -30,7 +30,8 @@ def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comme for ln in lines: # filer all comments if comment_char in ln: - ln = ln[: ln.index(comment_char)].strip() + char_idx = min(ln.index(ch) for ch in comment_char) + ln = ln[:char_idx].strip() # skip directly installed dependencies if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"): continue From 6940092cc160ed1dc72e45a769e3891578a3268a Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 1 Nov 2021 09:23:03 +0100 Subject: [PATCH 50/56] update setup --- setup.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 920ac4fd1e7..31b7eee5372 100755 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def _load_py_module(fname, pkg="torchmetrics"): BASE_REQUIREMENTS = setup_tools._load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt") -def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")): +def _prepare_extras(skip_files: Tuple[str] = ("devel.txt")): # find all extra requirements _load_req = partial(setup_tools._load_requirements, path_dir=_PATH_REQUIRE) found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt"))) @@ -38,7 +38,7 @@ def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")) found_req_files = [n for n in found_req_files if n not in skip_files] found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files] # define basic and extra extras - extras_req = {name: base_req + _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files)} + extras_req = {name: _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files)} # filter the uniques extras_req = {n: list(set(req)) for n, 
req in extras_req.items()} # create an 'all' keyword that install all possible denpendencies @@ -69,7 +69,7 @@ def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")) python_requires=">=3.6", setup_requires=[], install_requires=BASE_REQUIREMENTS, - extras_require=_prepare_extras(BASE_REQUIREMENTS), + extras_require=_prepare_extras(), project_urls={ "Bug Tracker": os.path.join(about.__homepage__, "issues"), "Documentation": "https://torchmetrics.rtfd.io/en/latest/", @@ -80,7 +80,7 @@ def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")) "Natural Language :: English", # How mature is this project? Common values are # 3 - Alpha, 4 - Beta, 5 - Production/Stable - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", # Indicate who your project is intended for "Intended Audience :: Developers", "Topic :: Scientific/Engineering :: Artificial Intelligence", @@ -96,5 +96,6 @@ def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")) "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], ) From 9f54e3d333587b6c824d8677dc4e38df3e4a3013 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Nov 2021 13:21:09 +0000 Subject: [PATCH 51/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index cf42609fc47..eae171421c0 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -77,7 +77,6 @@ ) from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402 - __all__ = [ "functional", "Accuracy", From 1c7bd07b468fea17386a9a2f16680a45e60d5ec1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 15 Nov 2021 14:23:50 +0100 Subject: [PATCH 52/56] update --- CHANGELOG.md | 4 +++- torchmetrics/setup_tools.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6e92d39317..37a5efe54c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] - 2021-MM-DD ### Added + - Added NLP metrics: - `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619)) +- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556)) + ### Changed @@ -60,7 +63,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `pairwise_euclidean_distance` - `pairwise_linear_similarity` - `pairwise_manhatten_distance` -- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556)) ### Changed diff --git a/torchmetrics/setup_tools.py b/torchmetrics/setup_tools.py index aa77cdf2d9e..0739412de15 100644 --- a/torchmetrics/setup_tools.py +++ b/torchmetrics/setup_tools.py @@ -18,7 +18,7 @@ _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) -def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#@") -> List[str]: +def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]: """Load requirements from a file. 
>>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE From 0052bacd6be9bca5e6b7e1ddab8238644367f1fd Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 15 Nov 2021 15:28:37 +0100 Subject: [PATCH 53/56] fix tests --- tests/wrappers/test_minmax.py | 27 ++++++++++++++++++++++++--- torchmetrics/wrappers/minmax.py | 4 ++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 77edc391f72..b9d854e5cdc 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -7,7 +7,7 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError from torchmetrics.wrappers import MinMaxMetric - +from copy import deepcopy seed_all(42) @@ -28,11 +28,32 @@ def compare_fn(preds, target, base_fn): """comparing function for minmax wrapper.""" min, max = 1e6, -1e6 # pick some very large numbers for comparing for i in range(NUM_BATCHES): - val = base_fn(preds[: (i + 1) * BATCH_SIZE], target[: (i + 1) * BATCH_SIZE]).cpu().numpy() + val = base_fn(preds[:(i + 1) * BATCH_SIZE], target[:(i + 1) * BATCH_SIZE]).cpu().numpy() + min = min if min < val else val + max = max if max > val else val + raw = base_fn(preds, target) + return [raw.cpu().numpy(), min, max] + + +def compare_fn_ddp(preds, target, base_fn): + min, max = 1e6, -1e6 # pick some very large numbers for comparing + for i, j in zip(range(0, NUM_BATCHES, 2), range(1, NUM_BATCHES, 2)): + p = torch.cat([ + preds[i*BATCH_SIZE:(i+1)*BATCH_SIZE], + preds[j*BATCH_SIZE:(j+1)*BATCH_SIZE] + ]) + t = torch.cat([ + target[i*BATCH_SIZE:(i+1)*BATCH_SIZE], + target[j*BATCH_SIZE:(j+1)*BATCH_SIZE] + ]) + base_fn.update(p, t) + val = base_fn.compute().cpu().numpy() min = min if min < val else val max = max if max > val else val + print(min, max) raw = base_fn(preds, target) return [raw.cpu().numpy(), min, max] + @pytest.mark.parametrize( @@ -56,7 +77,7 @@ def test_multioutput_wrapper(self, preds, target, base_metric, ddp): preds, target, TestingMinMaxMetric, - partial(compare_fn, base_fn=base_metric), + partial(compare_fn_ddp if ddp else compare_fn, base_fn=deepcopy(base_metric)), dist_sync_on_step=False, metric_args=dict(base_metric=base_metric), check_batch=False, diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 513a527a64f..292b1dac443 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -83,8 +83,8 @@ def __init__( f"Expected base metric to be an instance of `torchmetrics.Metric` but received {base_metric}" ) self._base_metric = base_metric - self.add_state("min_val", default=torch.tensor(float("inf")), dist_reduce_fx="min") - self.add_state("max_val", default=torch.tensor(float("-inf")), dist_reduce_fx="max") + self.register_buffer("min_val", torch.tensor(float("inf"))) + self.register_buffer("max_val", torch.tensor(float("-inf"))) def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore """Updates the underlying metric.""" From b4e8b2ab5dce26d11cf2814beb3a76e906a5f930 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Nov 2021 14:40:32 +0000 Subject: [PATCH 54/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_minmax.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/wrappers/test_minmax.py 
b/tests/wrappers/test_minmax.py index b9d854e5cdc..02f80086b31 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -1,3 +1,4 @@ +from copy import deepcopy from functools import partial import pytest @@ -7,7 +8,7 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError from torchmetrics.wrappers import MinMaxMetric -from copy import deepcopy + seed_all(42) @@ -28,7 +29,7 @@ def compare_fn(preds, target, base_fn): """comparing function for minmax wrapper.""" min, max = 1e6, -1e6 # pick some very large numbers for comparing for i in range(NUM_BATCHES): - val = base_fn(preds[:(i + 1) * BATCH_SIZE], target[:(i + 1) * BATCH_SIZE]).cpu().numpy() + val = base_fn(preds[: (i + 1) * BATCH_SIZE], target[: (i + 1) * BATCH_SIZE]).cpu().numpy() min = min if min < val else val max = max if max > val else val raw = base_fn(preds, target) @@ -38,14 +39,8 @@ def compare_fn(preds, target, base_fn): def compare_fn_ddp(preds, target, base_fn): min, max = 1e6, -1e6 # pick some very large numbers for comparing for i, j in zip(range(0, NUM_BATCHES, 2), range(1, NUM_BATCHES, 2)): - p = torch.cat([ - preds[i*BATCH_SIZE:(i+1)*BATCH_SIZE], - preds[j*BATCH_SIZE:(j+1)*BATCH_SIZE] - ]) - t = torch.cat([ - target[i*BATCH_SIZE:(i+1)*BATCH_SIZE], - target[j*BATCH_SIZE:(j+1)*BATCH_SIZE] - ]) + p = torch.cat([preds[i * BATCH_SIZE : (i + 1) * BATCH_SIZE], preds[j * BATCH_SIZE : (j + 1) * BATCH_SIZE]]) + t = torch.cat([target[i * BATCH_SIZE : (i + 1) * BATCH_SIZE], target[j * BATCH_SIZE : (j + 1) * BATCH_SIZE]]) base_fn.update(p, t) val = base_fn.compute().cpu().numpy() min = min if min < val else val @@ -53,7 +48,6 @@ def compare_fn_ddp(preds, target, base_fn): print(min, max) raw = base_fn(preds, target) return [raw.cpu().numpy(), min, max] - @pytest.mark.parametrize( From 60d164abe81decf13cf4f88bad3d9d612a7ac83a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 15 Nov 2021 16:07:15 +0100 Subject: [PATCH 55/56] Update tests/wrappers/test_minmax.py --- tests/wrappers/test_minmax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index 02f80086b31..fcb55308ca2 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -45,7 +45,6 @@ def compare_fn_ddp(preds, target, base_fn): val = base_fn.compute().cpu().numpy() min = min if min < val else val max = max if max > val else val - print(min, max) raw = base_fn(preds, target) return [raw.cpu().numpy(), min, max] From 297943b59af04baa9a869113bbdc7d058e4e1b5d Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 15 Nov 2021 16:24:48 +0100 Subject: [PATCH 56/56] update --- torchmetrics/wrappers/minmax.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 292b1dac443..68503ffff96 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -48,17 +48,16 @@ class MinMaxMetric(Metric): >>> from pprint import pprint >>> base_metric = Accuracy() >>> minmax_metric = MinMaxMetric(base_metric) - >>> preds_1 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) - >>> preds_2 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) + >>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) + >>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) >>> labels = torch.Tensor([[0, 1], [0, 1]]).long() >>> pprint(minmax_metric(preds_1, labels)) - {'max': tensor(0.5000), 'min': 
tensor(0.5000), 'raw': tensor(0.5000)}
+        {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
         >>> pprint(minmax_metric.compute())
-        {'max': tensor(0.5000), 'min': tensor(0.5000), 'raw': tensor(0.5000)}
-        >>> pprint(minmax_metric(preds_2, labels))
         {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
+        >>> minmax_metric.update(preds_2, labels)
         >>> pprint(minmax_metric.compute())
-        {'max': tensor(1.), 'min': tensor(0.5000), 'raw': tensor(1.)}
+        {'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
     """

     min_val: Tensor
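
A quick usage sketch of the wrapper as it lands at the end of this series (torchmetrics 0.6). It is not part of any patch above: the prediction tensors are illustrative and simply mirror the values used in the final doctest, while the MinMaxMetric/Accuracy API and the raw/min/max dictionary returned by compute() come directly from the patches.

    # Illustrative sketch only; values mirror the doctest in the last patch.
    import torch
    from torchmetrics import Accuracy
    from torchmetrics.wrappers import MinMaxMetric

    tracked_acc = MinMaxMetric(Accuracy())

    preds_good = torch.tensor([[0.1, 0.9], [0.2, 0.8]])  # both samples classified correctly
    preds_bad = torch.tensor([[0.9, 0.1], [0.2, 0.8]])   # first sample classified incorrectly
    target = torch.tensor([1, 1])

    tracked_acc.update(preds_good, target)
    print(tracked_acc.compute())  # raw, min and max are all 1.0 after the first compute

    tracked_acc.update(preds_bad, target)
    print(tracked_acc.compute())  # raw drops to the running accuracy 0.75, max keeps 1.0

Because the wrapped Accuracy keeps accumulating across update calls, the second compute() reports the running accuracy (0.75) while "max" retains the best value observed so far (1.0), which is exactly the behaviour the updated doctest in the last patch documents.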