From 75e72305eae8ec83a2e6254104a9e1aefbf84246 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Wed, 15 Dec 2021 14:42:08 +0100 Subject: [PATCH 1/8] impl --- torchmetrics/wrappers/tracker.py | 78 ++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 90be9237fcc..798631f8d35 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import torch from torch import Tensor, nn from torchmetrics.metric import Metric +from torchmetrics.collections import MetricCollection class MetricTracker(nn.ModuleList): @@ -38,8 +39,7 @@ class MetricTracker(nn.ModuleList): maximize: bool indicating if higher metric values are better (`True`) or lower is better (`False`) - Example: - + Example (single metric): >>> from torchmetrics import Accuracy, MetricTracker >>> _ = torch.manual_seed(42) >>> tracker = MetricTracker(Accuracy(num_classes=10)) @@ -55,15 +55,48 @@ class MetricTracker(nn.ModuleList): current acc=0.07999999821186066 current acc=0.10199999809265137 >>> best_acc, which_epoch = tracker.best_metric(return_step=True) + >>> best_acc + 0.12600000202655792 + >>> which_epoch + 2 >>> tracker.compute_all() tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020]) + + Example (multiple metrics using MetricCollection): + >>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, R2Score + >>> _ = torch.manual_seed(42) + >>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), R2Score()]), maximize=[False, True]) + >>> for epoch in range(5): + ... tracker.increment() + ... for batch_idx in range(5): + ... preds, target = torch.randn(100), torch.randn(100) + ... tracker.update(preds, target) + ... 
print(f"current stats={tracker.compute()}") # doctest: +NORMALIZE_WHITESPACE + current stats={'MeanSquaredError': tensor(1.8218), 'MeanAbsoluteError': tensor(1.0769)} + current stats={'MeanSquaredError': tensor(2.0268), 'MeanAbsoluteError': tensor(1.1410)} + current stats={'MeanSquaredError': tensor(1.9491), 'MeanAbsoluteError': tensor(1.1089)} + current stats={'MeanSquaredError': tensor(1.9800), 'MeanAbsoluteError': tensor(1.1353)} + current stats={'MeanSquaredError': tensor(2.2481), 'MeanAbsoluteError': tensor(1.2001)} + >>> best_res, which_epoch = tracker.best_metric(return_step=True) + >>> best_res + {'MeanSquaredError': 1.8218141794204712, 'MeanAbsoluteError': 1.07692551612854} + >>> which_epoch + {'MeanSquaredError': 0, 'MeanAbsoluteError': 0} + >>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE + {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), + 'MeanAbsoluteError': tensor([1.0769, 1.1410, 1.1089, 1.1353, 1.2001])} """ - def __init__(self, metric: Metric, maximize: bool = True) -> None: + def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: super().__init__() - if not isinstance(metric, Metric): - raise TypeError("metric arg need to be an instance of a torchmetrics metric" f" but got {metric}") + if not isinstance(metric, (Metric, MetricCollection)): + raise TypeError("metric arg need to be an instance of a torchmetrics" + f" `Metric` or `MetricCollection` but got {metric}") self._base_metric = metric + if not isinstance(maximize, (bool, list)): + raise ValueError("Argument `maximize` should either be a single bool or list of bool") + if isinstance(maximize, list) and not isinstance(metric, MetricCollection) and len(maximize) != len(metric): + raise ValueError("The len of argument `maximize` should match the length of th metric collection") self.maximize = maximize self._increment_called = False @@ -96,7 +129,13 @@ def compute(self) -> Any: def compute_all(self) -> Tensor: """Compute the metric value for all tracked metrics.""" self._check_for_increment("compute_all") - return torch.stack([metric.compute() for i, metric in enumerate(self) if i != 0], dim=0) + # The i!=0 accounts for the self._base_metric should be ignored + res = [metric.compute() for i, metric in enumerate(self) if i != 0] + if isinstance(self._base_metric, MetricCollection): + keys = res[0].keys() + return {k:torch.stack([r[k] for r in res], dim=0) for k in keys} + else: + return torch.stack(res, dim=0) def reset(self) -> None: """Resets the current metric being tracked.""" @@ -107,7 +146,7 @@ def reset_all(self) -> None: for metric in self: metric.reset() - def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, float]]: + def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, float], Dict[str,float], Tuple[Dict[str,int], Dict[str,float]]]: """Returns the highest metric out of all tracked. Args: @@ -116,11 +155,24 @@ def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, floa Returns: The best metric value, and optionally the timestep. 
""" - fn = torch.max if self.maximize else torch.min - idx, max = fn(self.compute_all(), 0) - if return_step: - return idx.item(), max.item() - return max.item() + if isinstance(self._base_metric, Metric): + fn = torch.max if self.maximize else torch.min + idx, best = fn(self.compute_all(), 0) + if return_step: + return idx.item(), best.item() + return best.item() # doctest: +NORMALIZE_WHITESPACE + else: + res = self.compute_all() + maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] + idx, best = {}, {} + for i, (k, v) in enumerate(res.items()): + fn = torch.max if maximize[i] else torch.min + out = fn(v, 0) + idx[k], best[k] = out[0].item(), out[1].item() + + if return_step: + return idx, best + return best def _check_for_increment(self, method: str) -> None: if not self._increment_called: From cc896ea12b6c8c528f2e72720121545b6e441337 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 4 Jan 2022 12:35:42 +0100 Subject: [PATCH 2/8] update --- tests/wrappers/test_tracker.py | 39 ++++++++++++++++++++++++-------- torchmetrics/wrappers/tracker.py | 33 ++++++++++++++------------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index ce3f977811c..7fe50cd8bc1 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -17,7 +17,7 @@ import torch from tests.helpers import seed_all -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall, MetricCollection from torchmetrics.wrappers import MetricTracker seed_all(42) @@ -48,11 +48,15 @@ def test_raises_error_if_increment_not_called(method, method_input): @pytest.mark.parametrize( "base_metric, metric_input, maximize", [ - (partial(Accuracy, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), - (partial(Precision, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), - (partial(Recall, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), - (MeanSquaredError, (torch.randn(50), torch.randn(50)), False), - (MeanAbsoluteError, (torch.randn(50), torch.randn(50)), False), + (Accuracy(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (Precision(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (Recall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (MeanSquaredError(), (torch.randn(50), torch.randn(50)), False), + (MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False), + (MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), (torch.randint(10, (50,)), torch.randint(10, (50,))), [True, True, True]), + (MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), False), + (MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), [False, False]), ], ) def test_tracker(base_metric, metric_input, maximize): @@ -66,11 +70,26 @@ def test_tracker(base_metric, metric_input, maximize): tracker(*metric_input) val = tracker.compute() - assert val != 0.0 + if isinstance(val, dict): + for v in val.values(): + assert v != 0.0 + else: + assert val != 0.0 
assert tracker.n_steps == i + 1 assert tracker.n_steps == 5 - assert tracker.compute_all().shape[0] == 5 + all_computed_val = tracker.compute_all() + if isinstance(all_computed_val, dict): + for v in all_computed_val.values(): + assert v.shape[0] == 5 + else: + assert all_computed_val == 5 + val, idx = tracker.best_metric(return_step=True) - assert val != 0.0 - assert idx in list(range(5)) + if isinstance(val, dict): + for v, i in zip(val.values(), idx.values()): + assert v != 0.0 + assert i in list(range(5)) + else: + assert val != 0.0 + assert idx in list(range(5)) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 798631f8d35..56ca19bf831 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -22,9 +22,10 @@ class MetricTracker(nn.ModuleList): - """A wrapper class that can help keeping track of a metric over time and implement useful methods. The wrapper - implements the standard `update`, `compute`, `reset` methods that just calls corresponding method of the - currently tracked metric. However, the following additional methods are provided: + """A wrapper class that can help keeping track of a metric or metric collection over time and implement + useful methods. The wrapper implements the standard `update`, `compute`, `reset` methods that just calls + corresponding method of the currently tracked metric. + However, the following additional methods are provided: -``MetricTracker.n_steps``: number of metrics being tracked @@ -63,28 +64,28 @@ class MetricTracker(nn.ModuleList): tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020]) Example (multiple metrics using MetricCollection): - >>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, R2Score + >>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance >>> _ = torch.manual_seed(42) - >>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), R2Score()]), maximize=[False, True]) + >>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True]) >>> for epoch in range(5): ... tracker.increment() ... for batch_idx in range(5): ... preds, target = torch.randn(100), torch.randn(100) ... tracker.update(preds, target) ... 
print(f"current stats={tracker.compute()}") # doctest: +NORMALIZE_WHITESPACE - current stats={'MeanSquaredError': tensor(1.8218), 'MeanAbsoluteError': tensor(1.0769)} - current stats={'MeanSquaredError': tensor(2.0268), 'MeanAbsoluteError': tensor(1.1410)} - current stats={'MeanSquaredError': tensor(1.9491), 'MeanAbsoluteError': tensor(1.1089)} - current stats={'MeanSquaredError': tensor(1.9800), 'MeanAbsoluteError': tensor(1.1353)} - current stats={'MeanSquaredError': tensor(2.2481), 'MeanAbsoluteError': tensor(1.2001)} + current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)} + current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)} + current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)} + current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)} + current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)} >>> best_res, which_epoch = tracker.best_metric(return_step=True) >>> best_res - {'MeanSquaredError': 1.8218141794204712, 'MeanAbsoluteError': 1.07692551612854} + {'MeanSquaredError': 1.8218144178390503, 'ExplainedVariance': -0.8297995328903198} >>> which_epoch - {'MeanSquaredError': 0, 'MeanAbsoluteError': 0} + {'MeanSquaredError': 0, 'ExplainedVariance': 2} >>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), - 'MeanAbsoluteError': tensor([1.0769, 1.1410, 1.1089, 1.1353, 1.2001])} + 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} """ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: @@ -133,7 +134,7 @@ def compute_all(self) -> Tensor: res = [metric.compute() for i, metric in enumerate(self) if i != 0] if isinstance(self._base_metric, MetricCollection): keys = res[0].keys() - return {k:torch.stack([r[k] for r in res], dim=0) for k in keys} + return {k : torch.stack([r[k] for r in res], dim=0) for k in keys} else: return torch.stack(res, dim=0) @@ -160,7 +161,7 @@ def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, floa idx, best = fn(self.compute_all(), 0) if return_step: return idx.item(), best.item() - return best.item() # doctest: +NORMALIZE_WHITESPACE + return best.item() else: res = self.compute_all() maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] @@ -169,7 +170,7 @@ def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, floa fn = torch.max if maximize[i] else torch.min out = fn(v, 0) idx[k], best[k] = out[0].item(), out[1].item() - + if return_step: return idx, best return best From 0a01e95be1923e80e7e8847f87c35f243a9b864e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Jan 2022 09:44:03 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_tracker.py | 24 ++++++++++++++++++------ torchmetrics/wrappers/tracker.py | 25 ++++++++++++++----------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index 7fe50cd8bc1..8a79c982623 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -17,7 +17,7 @@ import torch from tests.helpers import seed_all -from torchmetrics import Accuracy, MeanAbsoluteError, 
MeanSquaredError, Precision, Recall, MetricCollection +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection, Precision, Recall from torchmetrics.wrappers import MetricTracker seed_all(42) @@ -53,10 +53,22 @@ def test_raises_error_if_increment_not_called(method, method_input): (Recall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), (MeanSquaredError(), (torch.randn(50), torch.randn(50)), False), (MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False), - (MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), - (MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), (torch.randint(10, (50,)), torch.randint(10, (50,))), [True, True, True]), + ( + MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), + (torch.randint(10, (50,)), torch.randint(10, (50,))), + True, + ), + ( + MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), + (torch.randint(10, (50,)), torch.randint(10, (50,))), + [True, True, True], + ), (MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), False), - (MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), [False, False]), + ( + MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), + (torch.randn(50), torch.randn(50)), + [False, False], + ), ], ) def test_tracker(base_metric, metric_input, maximize): @@ -71,7 +83,7 @@ def test_tracker(base_metric, metric_input, maximize): val = tracker.compute() if isinstance(val, dict): - for v in val.values(): + for v in val.values(): assert v != 0.0 else: assert val != 0.0 @@ -89,7 +101,7 @@ def test_tracker(base_metric, metric_input, maximize): if isinstance(val, dict): for v, i in zip(val.values(), idx.values()): assert v != 0.0 - assert i in list(range(5)) + assert i in list(range(5)) else: assert val != 0.0 assert idx in list(range(5)) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 56ca19bf831..9e810695821 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -17,15 +17,14 @@ import torch from torch import Tensor, nn -from torchmetrics.metric import Metric from torchmetrics.collections import MetricCollection +from torchmetrics.metric import Metric class MetricTracker(nn.ModuleList): - """A wrapper class that can help keeping track of a metric or metric collection over time and implement - useful methods. The wrapper implements the standard `update`, `compute`, `reset` methods that just calls - corresponding method of the currently tracked metric. - However, the following additional methods are provided: + """A wrapper class that can help keeping track of a metric or metric collection over time and implement useful + methods. The wrapper implements the standard `update`, `compute`, `reset` methods that just calls corresponding + method of the currently tracked metric. 
However, the following additional methods are provided: -``MetricTracker.n_steps``: number of metrics being tracked @@ -62,7 +61,7 @@ class MetricTracker(nn.ModuleList): 2 >>> tracker.compute_all() tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020]) - + Example (multiple metrics using MetricCollection): >>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance >>> _ = torch.manual_seed(42) @@ -84,15 +83,17 @@ class MetricTracker(nn.ModuleList): >>> which_epoch {'MeanSquaredError': 0, 'ExplainedVariance': 2} >>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE - {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), + {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} """ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: super().__init__() if not isinstance(metric, (Metric, MetricCollection)): - raise TypeError("metric arg need to be an instance of a torchmetrics" - f" `Metric` or `MetricCollection` but got {metric}") + raise TypeError( + "metric arg need to be an instance of a torchmetrics" + f" `Metric` or `MetricCollection` but got {metric}" + ) self._base_metric = metric if not isinstance(maximize, (bool, list)): raise ValueError("Argument `maximize` should either be a single bool or list of bool") @@ -134,7 +135,7 @@ def compute_all(self) -> Tensor: res = [metric.compute() for i, metric in enumerate(self) if i != 0] if isinstance(self._base_metric, MetricCollection): keys = res[0].keys() - return {k : torch.stack([r[k] for r in res], dim=0) for k in keys} + return {k: torch.stack([r[k] for r in res], dim=0) for k in keys} else: return torch.stack(res, dim=0) @@ -147,7 +148,9 @@ def reset_all(self) -> None: for metric in self: metric.reset() - def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, float], Dict[str,float], Tuple[Dict[str,int], Dict[str,float]]]: + def best_metric( + self, return_step: bool = False + ) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]: """Returns the highest metric out of all tracked. 
Args: From 989c0ae5618257a621884a3b41fd5b0432b83366 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 20 Jan 2022 15:02:09 +0100 Subject: [PATCH 4/8] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9df5c265111..81170772de7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718)) ### Changed From 575d664144af7850e25501a394b8880f28ee6536 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 20 Jan 2022 15:12:28 +0100 Subject: [PATCH 5/8] fix tests --- tests/wrappers/test_tracker.py | 19 +++++++++++++++---- torchmetrics/wrappers/tracker.py | 15 ++++++++------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index 8a79c982623..21f9ee4ce1c 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -24,9 +24,16 @@ def test_raises_error_on_wrong_input(): - with pytest.raises(TypeError, match="metric arg need to be an instance of a torchmetrics metric .*"): + """ make sure that input type errors are raised on wrong input """ + with pytest.raises(TypeError, match="Metric arg need to be an instance of a .*"): MetricTracker([1, 2, 3]) + with pytest.raises(ValueError, match="Argument `maximize` should either be a single bool or list of bool"): + MetricTracker(MeanAbsoluteError(), maximize=2) + + with pytest.raises(ValueError, match="The len of argument `maximize` should match the length of the metric collection"): + MetricTracker(MetricCollection([MeanAbsoluteError(), MeanSquaredError()]), maximize=[False, False, False]) + @pytest.mark.parametrize( "method, method_input", @@ -72,7 +79,8 @@ def test_raises_error_if_increment_not_called(method, method_input): ], ) def test_tracker(base_metric, metric_input, maximize): - tracker = MetricTracker(base_metric(), maximize=maximize) + """ Test that arguments gets passed correctly to child modules""" + tracker = MetricTracker(base_metric, maximize=maximize) for i in range(5): tracker.increment() # check both update and forward works @@ -81,6 +89,7 @@ def test_tracker(base_metric, metric_input, maximize): for _ in range(5): tracker(*metric_input) + # Make sure we have computed something val = tracker.compute() if isinstance(val, dict): for v in val.values(): @@ -89,14 +98,16 @@ def test_tracker(base_metric, metric_input, maximize): assert val != 0.0 assert tracker.n_steps == i + 1 + # Assert that compute all returns all values assert tracker.n_steps == 5 all_computed_val = tracker.compute_all() if isinstance(all_computed_val, dict): for v in all_computed_val.values(): - assert v.shape[0] == 5 + assert v.numel() == 5 else: - assert all_computed_val == 5 + assert all_computed_val.numel() == 5 + # Assert that best_metric returns both index and value val, idx = tracker.best_metric(return_step=True) if isinstance(val, dict): for v, i in zip(val.values(), idx.values()): diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index c22cb49ba8c..8ab19b500a5 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -35,9 +35,10 @@ class MetricTracker(nn.ModuleList): -``MetricTracker.best_metric()``: returns the best value Args: - metric: instance of a torchmetric modular to keep track of at each timestep. 
- maximize: bool indicating if higher metric values are better (`True`) or lower - is better (`False`) + metric: instance of a `torchmetrics.Metric` or `torchmetrics.MetricCollection` to keep track + of at each timestep. + maximize: either single bool or list of bool indicating if higher metric values are + better (`True`) or lower is better (`False`). Example (single metric): >>> from torchmetrics import Accuracy, MetricTracker @@ -84,21 +85,21 @@ class MetricTracker(nn.ModuleList): {'MeanSquaredError': 0, 'ExplainedVariance': 2} >>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), - 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} + 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} """ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: super().__init__() if not isinstance(metric, (Metric, MetricCollection)): raise TypeError( - "metric arg need to be an instance of a torchmetrics" + "Metric arg need to be an instance of a torchmetrics" f" `Metric` or `MetricCollection` but got {metric}" ) self._base_metric = metric if not isinstance(maximize, (bool, list)): raise ValueError("Argument `maximize` should either be a single bool or list of bool") - if isinstance(maximize, list) and not isinstance(metric, MetricCollection) and len(maximize) != len(metric): - raise ValueError("The len of argument `maximize` should match the length of th metric collection") + if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric): + raise ValueError("The len of argument `maximize` should match the length of the metric collection") self.maximize = maximize self._increment_called = False From ddf2df91882acd8fd2bed7057296d89152acd3eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jan 2022 14:13:27 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_tracker.py | 8 +++++--- torchmetrics/wrappers/tracker.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index 21f9ee4ce1c..6c29e46b343 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -24,14 +24,16 @@ def test_raises_error_on_wrong_input(): - """ make sure that input type errors are raised on wrong input """ + """make sure that input type errors are raised on wrong input.""" with pytest.raises(TypeError, match="Metric arg need to be an instance of a .*"): MetricTracker([1, 2, 3]) with pytest.raises(ValueError, match="Argument `maximize` should either be a single bool or list of bool"): MetricTracker(MeanAbsoluteError(), maximize=2) - with pytest.raises(ValueError, match="The len of argument `maximize` should match the length of the metric collection"): + with pytest.raises( + ValueError, match="The len of argument `maximize` should match the length of the metric collection" + ): MetricTracker(MetricCollection([MeanAbsoluteError(), MeanSquaredError()]), maximize=[False, False, False]) @@ -79,7 +81,7 @@ def test_raises_error_if_increment_not_called(method, method_input): ], ) def test_tracker(base_metric, metric_input, maximize): - """ Test that arguments gets passed correctly to child modules""" + """Test that arguments gets passed correctly to 
child modules.""" tracker = MetricTracker(base_metric, maximize=maximize) for i in range(5): tracker.increment() diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 8ab19b500a5..47fb2f8f0bb 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -35,9 +35,9 @@ class MetricTracker(nn.ModuleList): -``MetricTracker.best_metric()``: returns the best value Args: - metric: instance of a `torchmetrics.Metric` or `torchmetrics.MetricCollection` to keep track + metric: instance of a `torchmetrics.Metric` or `torchmetrics.MetricCollection` to keep track of at each timestep. - maximize: either single bool or list of bool indicating if higher metric values are + maximize: either single bool or list of bool indicating if higher metric values are better (`True`) or lower is better (`False`). Example (single metric): From 4d83e0236620f9f3266b79f5e7db8820fb6a6af2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 20 Jan 2022 17:14:17 +0100 Subject: [PATCH 7/8] Apply suggestions from code review --- tests/wrappers/test_tracker.py | 2 +- torchmetrics/wrappers/tracker.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index 6c29e46b343..8b4a8804c7f 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -24,7 +24,7 @@ def test_raises_error_on_wrong_input(): - """make sure that input type errors are raised on wrong input.""" + """Make sure that input type errors are raised on the wrong input.""" with pytest.raises(TypeError, match="Metric arg need to be an instance of a .*"): MetricTracker([1, 2, 3]) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 47fb2f8f0bb..f60a1b29821 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -56,8 +56,8 @@ class MetricTracker(nn.ModuleList): current acc=0.07999999821186066 current acc=0.10199999809265137 >>> best_acc, which_epoch = tracker.best_metric(return_step=True) - >>> best_acc - 0.12600000202655792 + >>> best_acc # doctest: +ELLIPSIS + 0.1260... 
>>> which_epoch 2 >>> tracker.compute_all() @@ -78,12 +78,13 @@ class MetricTracker(nn.ModuleList): current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)} current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)} current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)} + >>> from pprint import pprint >>> best_res, which_epoch = tracker.best_metric(return_step=True) - >>> best_res + >>> pprint(best_res) {'MeanSquaredError': 1.8218144178390503, 'ExplainedVariance': -0.8297995328903198} >>> which_epoch {'MeanSquaredError': 0, 'ExplainedVariance': 2} - >>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE + >>> pprint(tracker.compute_all()) {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} """ @@ -137,8 +138,7 @@ def compute_all(self) -> Tensor: if isinstance(self._base_metric, MetricCollection): keys = res[0].keys() return {k: torch.stack([r[k] for r in res], dim=0) for k in keys} - else: - return torch.stack(res, dim=0) + return torch.stack(res, dim=0) def reset(self) -> None: """Resets the current metric being tracked.""" From f65844ce0a22e06b67699028b9596b2f1075808e Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 20 Jan 2022 17:17:43 +0100 Subject: [PATCH 8/8] doctest --- tests/wrappers/test_tracker.py | 1 - torchmetrics/wrappers/tracker.py | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index 8b4a8804c7f..5e46528f09d 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import pytest import torch diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index f60a1b29821..273cb9019f5 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -80,13 +80,14 @@ class MetricTracker(nn.ModuleList): current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)} >>> from pprint import pprint >>> best_res, which_epoch = tracker.best_metric(return_step=True) - >>> pprint(best_res) - {'MeanSquaredError': 1.8218144178390503, 'ExplainedVariance': -0.8297995328903198} + >>> pprint(best_res) # doctest: +ELLIPSIS + {'ExplainedVariance': -0.829..., + 'MeanSquaredError': 1.821...} >>> which_epoch {'MeanSquaredError': 0, 'ExplainedVariance': 2} >>> pprint(tracker.compute_all()) - {'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]), - 'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])} + {'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]), + 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])} """ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: