
Support for collection in Tracker #718

Merged 11 commits on Jan 20, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))

### Changed

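For context, here is a minimal usage sketch of the feature this changelog entry refers to, assuming the API shown in the diffs below; the metric choices and tensor shapes are illustrative only.

```python
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchmetrics.wrappers import MetricTracker

# Track a whole MetricCollection; `maximize` takes one flag per metric in the collection.
tracker = MetricTracker(
    MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
    maximize=[False, False],
)
for _ in range(3):
    tracker.increment()  # start tracking a fresh copy of the collection
    preds, target = torch.randn(100), torch.randn(100)
    tracker.update(preds, target)

# With a collection, compute_all() comes back as a dict keyed by metric name,
# one tensor of tracked values per metric.
history = tracker.compute_all()  # {'MeanSquaredError': tensor([...]), 'MeanAbsoluteError': tensor([...])}
```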
68 changes: 56 additions & 12 deletions tests/wrappers/test_tracker.py
@@ -17,16 +17,25 @@
import torch

from tests.helpers import seed_all
from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall
from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection, Precision, Recall
from torchmetrics.wrappers import MetricTracker

seed_all(42)


def test_raises_error_on_wrong_input():
with pytest.raises(TypeError, match="metric arg need to be an instance of a torchmetrics metric .*"):
"""make sure that input type errors are raised on wrong input."""
with pytest.raises(TypeError, match="Metric arg need to be an instance of a .*"):
MetricTracker([1, 2, 3])

with pytest.raises(ValueError, match="Argument `maximize` should either be a single bool or list of bool"):
MetricTracker(MeanAbsoluteError(), maximize=2)

with pytest.raises(
ValueError, match="The len of argument `maximize` should match the length of the metric collection"
):
MetricTracker(MetricCollection([MeanAbsoluteError(), MeanSquaredError()]), maximize=[False, False, False])


@pytest.mark.parametrize(
"method, method_input",
@@ -48,15 +57,32 @@ def test_raises_error_if_increment_not_called(method, method_input):
@pytest.mark.parametrize(
"base_metric, metric_input, maximize",
[
(partial(Accuracy, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(partial(Precision, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(partial(Recall, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(MeanSquaredError, (torch.randn(50), torch.randn(50)), False),
(MeanAbsoluteError, (torch.randn(50), torch.randn(50)), False),
(Accuracy(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(Precision(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(Recall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(MeanSquaredError(), (torch.randn(50), torch.randn(50)), False),
(MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False),
(
MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]),
(torch.randint(10, (50,)), torch.randint(10, (50,))),
True,
),
(
MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]),
(torch.randint(10, (50,)), torch.randint(10, (50,))),
[True, True, True],
),
(MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), False),
(
MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
(torch.randn(50), torch.randn(50)),
[False, False],
),
],
)
def test_tracker(base_metric, metric_input, maximize):
tracker = MetricTracker(base_metric(), maximize=maximize)
"""Test that arguments gets passed correctly to child modules."""
tracker = MetricTracker(base_metric, maximize=maximize)
for i in range(5):
tracker.increment()
# check both update and forward works
@@ -65,12 +91,30 @@ def test_tracker(base_metric, metric_input, maximize):
for _ in range(5):
tracker(*metric_input)

# Make sure we have computed something
val = tracker.compute()
assert val != 0.0
if isinstance(val, dict):
for v in val.values():
assert v != 0.0
else:
assert val != 0.0
assert tracker.n_steps == i + 1

# Assert that compute_all returns all tracked values
assert tracker.n_steps == 5
assert tracker.compute_all().shape[0] == 5
all_computed_val = tracker.compute_all()
if isinstance(all_computed_val, dict):
for v in all_computed_val.values():
assert v.numel() == 5
else:
assert all_computed_val.numel() == 5

# Assert that best_metric returns both index and value
val, idx = tracker.best_metric(return_step=True)
assert val != 0.0
assert idx in list(range(5))
if isinstance(val, dict):
for v, i in zip(val.values(), idx.values()):
assert v != 0.0
assert i in list(range(5))
else:
assert val != 0.0
assert idx in list(range(5))
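The assertions above branch on the return type, since `compute_all` yields a tensor for a single metric but a dict of tensors for a collection. A small hedged sketch of the two cases (metric choices are illustrative):

```python
import torch
from torchmetrics import Accuracy, MetricCollection, Precision
from torchmetrics.wrappers import MetricTracker

preds, target = torch.randint(10, (50,)), torch.randint(10, (50,))

# Single metric: compute_all() stacks one value per increment() into a tensor.
single = MetricTracker(Accuracy(num_classes=10))
for _ in range(5):
    single.increment()
    single.update(preds, target)
assert single.compute_all().numel() == 5

# MetricCollection: compute_all() returns a dict of such tensors, keyed by metric name.
multi = MetricTracker(MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10)]))
for _ in range(5):
    multi.increment()
    multi.update(preds, target)
assert all(v.numel() == 5 for v in multi.compute_all().values())
```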
95 changes: 76 additions & 19 deletions torchmetrics/wrappers/tracker.py
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import Tensor, nn

from torchmetrics.collections import MetricCollection
from torchmetrics.metric import Metric


class MetricTracker(nn.ModuleList):
"""A wrapper class that can help keeping track of a metric over time and implement useful methods. The wrapper
implements the standard `update`, `compute`, `reset` methods that just calls corresponding method of the
currently tracked metric. However, the following additional methods are provided:
"""A wrapper class that can help keeping track of a metric or metric collection over time and implement useful
methods. The wrapper implements the standard `update`, `compute`, `reset` methods that just calls corresponding
method of the currently tracked metric. However, the following additional methods are provided:

-``MetricTracker.n_steps``: number of metrics being tracked

@@ -34,12 +35,12 @@ class MetricTracker(nn.ModuleList):
-``MetricTracker.best_metric()``: returns the best value

Args:
metric: instance of a torchmetric modular to keep track of at each timestep.
maximize: bool indicating if higher metric values are better (`True`) or lower
is better (`False`)

Example:
metric: instance of a `torchmetrics.Metric` or `torchmetrics.MetricCollection` to keep track
of at each timestep.
maximize: either a single bool or a list of bools indicating whether higher metric values are
better (`True`) or lower values are better (`False`).

Example (single metric):
>>> from torchmetrics import Accuracy, MetricTracker
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(Accuracy(num_classes=10))
@@ -55,15 +56,50 @@ class MetricTracker(nn.ModuleList):
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc
0.12600000202655792
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])

Example (multiple metrics using MetricCollection):
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
... tracker.increment()
... for batch_idx in range(5):
... preds, target = torch.randn(100), torch.randn(100)
... tracker.update(preds, target)
... print(f"current stats={tracker.compute()}") # doctest: +NORMALIZE_WHITESPACE
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> best_res
{'MeanSquaredError': 1.8218144178390503, 'ExplainedVariance': -0.8297995328903198}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> tracker.compute_all() # doctest: +NORMALIZE_WHITESPACE
{'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481]),
'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622])}
"""

def __init__(self, metric: Metric, maximize: bool = True) -> None:
def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None:
super().__init__()
if not isinstance(metric, Metric):
raise TypeError("metric arg need to be an instance of a torchmetrics metric" f" but got {metric}")
if not isinstance(metric, (Metric, MetricCollection)):
raise TypeError(
"Metric arg need to be an instance of a torchmetrics"
f" `Metric` or `MetricCollection` but got {metric}"
)
self._base_metric = metric
if not isinstance(maximize, (bool, list)):
raise ValueError("Argument `maximize` should either be a single bool or list of bool")
if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
raise ValueError("The len of argument `maximize` should match the length of the metric collection")
self.maximize = maximize

self._increment_called = False
@@ -96,7 +132,13 @@ def compute(self) -> Any:
def compute_all(self) -> Tensor:
"""Compute the metric value for all tracked metrics."""
self._check_for_increment("compute_all")
return torch.stack([metric.compute() for i, metric in enumerate(self) if i != 0], dim=0)
# The i != 0 check skips self._base_metric, which is stored as the first module in the list
res = [metric.compute() for i, metric in enumerate(self) if i != 0]
if isinstance(self._base_metric, MetricCollection):
keys = res[0].keys()
return {k: torch.stack([r[k] for r in res], dim=0) for k in keys}
else:
return torch.stack(res, dim=0)

def reset(self) -> None:
"""Resets the current metric being tracked."""
Expand All @@ -107,7 +149,9 @@ def reset_all(self) -> None:
for metric in self:
metric.reset()

def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, float]]:
def best_metric(
self, return_step: bool = False
) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]:
"""Returns the highest metric out of all tracked.

Args:
@@ -116,11 +160,24 @@ def best_metric(self, return_step: bool = False) -> Union[float, Tuple[int, floa
Returns:
The best metric value, and optionally the timestep.
"""
fn = torch.max if self.maximize else torch.min
idx, max = fn(self.compute_all(), 0)
if return_step:
return idx.item(), max.item()
return max.item()
if isinstance(self._base_metric, Metric):
fn = torch.max if self.maximize else torch.min
idx, best = fn(self.compute_all(), 0)
if return_step:
return idx.item(), best.item()
return best.item()
else:
res = self.compute_all()
maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize]
idx, best = {}, {}
for i, (k, v) in enumerate(res.items()):
fn = torch.max if maximize[i] else torch.min
out = fn(v, 0)
idx[k], best[k] = out[0].item(), out[1].item()

if return_step:
return idx, best
return best

def _check_for_increment(self, method: str) -> None:
if not self._increment_called:
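To make the per-metric `maximize` handling in `best_metric` concrete, here is a hedged sketch built on the docstring example above; the seed and metric choices are illustrative, and exact numbers will differ.

```python
import torch
from torchmetrics import ExplainedVariance, MeanSquaredError, MetricCollection
from torchmetrics.wrappers import MetricTracker

_ = torch.manual_seed(0)
tracker = MetricTracker(
    MetricCollection([MeanSquaredError(), ExplainedVariance()]),
    maximize=[False, True],  # minimize MSE, maximize explained variance
)
for _ in range(5):
    tracker.increment()
    preds, target = torch.randn(100), torch.randn(100)
    tracker.update(preds, target)

# Both returned dicts are keyed by metric name; each metric can reach its
# best value at a different tracked step.
best, step = tracker.best_metric(return_step=True)
print(best)  # dict of best values per metric
print(step)  # dict of the steps at which each best occurred
```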