Nested metric collections #1003

Merged: 7 commits, May 5, 2022

Changes from 3 commits
CHANGELOG.md (3 changes: 1 addition & 2 deletions)
@@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


--
-
+- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))

 ### Changed
tests/bases/test_collections.py (33 changes: 33 additions & 0 deletions)
@@ -401,3 +401,36 @@ def test_error_on_wrong_specified_compute_groups():
         MetricCollection(
             ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Accuracy"]]
         )
+
+
+@pytest.mark.parametrize(
+    "input_collections",
+    [
+        [
+            MetricCollection(
+                [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")], prefix="macro_"
+            ),
+            MetricCollection(
+                [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")], prefix="micro_"
+            ),
+        ],
+        {
+            "macro": MetricCollection(
+                [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")]
+            ),
+            "micro": MetricCollection(
+                [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")]
+            ),
+        },
+    ],
+)
+def test_nested_collections(input_collections):
+    """Test that nested collections get flattened to a single collection."""
+    metrics = MetricCollection(input_collections, prefix="valmetrics/")
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    val = metrics(preds, target)
+    assert "valmetrics/macro_Accuracy" in val
+    assert "valmetrics/macro_Precision" in val
+    assert "valmetrics/micro_Accuracy" in val
+    assert "valmetrics/micro_Precision" in val
torchmetrics/collections.py (55 changes: 46 additions & 9 deletions)
@@ -54,6 +54,10 @@ class name as key for the output dict.
            this argument is ``True`` which enables this feature. Set this argument to ``False`` to disable
            this behaviour. Can also be set to a list of list of metrics for setting the compute groups yourself.

+    .. note::
+        Metric collections can be nested at initialization (see last example), but the output of the collection will
+        still be a single flattened dictionary, combining the prefix and postfix arguments of the nested collections.
+
    Raises:
        ValueError:
            If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
@@ -104,6 +108,23 @@ class name as key for the output dict.
        ...     )
        >>> pprint(metrics(preds, target))
        {'Accuracy': tensor(0.1250), 'MeanSquaredError': tensor(2.3750), 'Precision': tensor(0.0667)}
+
+    Example (nested metric collections):
+        >>> metrics = MetricCollection([
+        ...     MetricCollection([
+        ...         Accuracy(num_classes=3, average='macro'),
+        ...         Precision(num_classes=3, average='macro')
+        ...     ], postfix='_macro'),
+        ...     MetricCollection([
+        ...         Accuracy(num_classes=3, average='micro'),
+        ...         Precision(num_classes=3, average='micro')
+        ...     ], postfix='_micro'),
+        ... ], prefix='valmetrics/')
+        >>> pprint(metrics(preds, target))  # doctest: +NORMALIZE_WHITESPACE
+        {'valmetrics/Accuracy_macro': tensor(0.1111),
+         'valmetrics/Accuracy_micro': tensor(0.1250),
+         'valmetrics/Precision_macro': tensor(0.0667),
+         'valmetrics/Precision_micro': tensor(0.1250)}
    """

    _groups: Dict[int, List[str]]
@@ -195,6 +216,10 @@ def _merge_compute_groups(self) -> None:
    @staticmethod
    def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
        """Check if the metric states of two metrics are the same."""
+        # empty state: nothing to compare, so do not consider the states equal
+        if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
+            return False
+
        if metric1._defaults.keys() != metric2._defaults.keys():
            return False

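For context, here is a minimal standalone sketch of the guard added above. The helper name `equal_metric_states` and the plain-dict signature are simplifications for illustration: the real private `_equal_metric_states` takes two `Metric` objects and also compares their accumulated state values, not just the registered defaults.

```python
from typing import Dict

import torch
from torch import Tensor


def equal_metric_states(defaults1: Dict[str, Tensor], defaults2: Dict[str, Tensor]) -> bool:
    """Simplified stand-in for the private helper patched above."""
    # a metric with no registered state can never be proven to share state,
    # so it must not be merged into a compute group
    if len(defaults1) == 0 or len(defaults2) == 0:
        return False
    if defaults1.keys() != defaults2.keys():
        return False
    return all(torch.allclose(defaults1[k], defaults2[k]) for k in defaults1)


print(equal_metric_states({}, {}))  # False: the new early return fires
print(equal_metric_states({"tp": torch.zeros(3)}, {"tp": torch.zeros(3)}))  # True
```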
@@ -280,19 +305,31 @@ def add_metrics(
             # Make sure that metrics are added in deterministic order
             for name in sorted(metrics.keys()):
                 metric = metrics[name]
-                if not isinstance(metric, Metric):
+                if not isinstance(metric, (Metric, MetricCollection)):
                     raise ValueError(
-                        f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
+                        f"Value {metric} belonging to key {name} is not an instance of"
+                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
                     )
-                self[name] = metric
+                if isinstance(metric, Metric):
+                    self[name] = metric
+                else:
+                    for k, v in metric.items(keep_base=False):
+                        self[f"{name}_{k}"] = v
         elif isinstance(metrics, Sequence):
             for metric in metrics:
-                if not isinstance(metric, Metric):
-                    raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`")
-                name = metric.__class__.__name__
-                if name in self:
-                    raise ValueError(f"Encountered two metrics both named {name}")
-                self[name] = metric
+                if not isinstance(metric, (Metric, MetricCollection)):
+                    raise ValueError(
+                        f"Input {metric} to `MetricCollection` is not an instance of"
+                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
+                    )
+                if isinstance(metric, Metric):
+                    name = metric.__class__.__name__
+                    if name in self:
+                        raise ValueError(f"Encountered two metrics both named {name}")
+                    self[name] = metric
+                else:
+                    for k, v in metric.items(keep_base=False):
+                        self[k] = v
         else:
             raise ValueError("Unknown input to MetricCollection.")

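Taken together, the flattening in `add_metrics` can be exercised end to end. A short usage sketch mirroring the dict case from the new test above (with a dict input, the dict key and each inner metric name are joined with an underscore, and the outer prefix is applied on top; `Accuracy`/`Precision` use the signatures from the torchmetrics version this PR targets):

```python
import torch
from torchmetrics import Accuracy, MetricCollection, Precision

# nested collections are flattened at construction time
metrics = MetricCollection(
    {
        "macro": MetricCollection(
            [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")]
        ),
        "micro": MetricCollection(
            [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")]
        ),
    },
    prefix="valmetrics/",
)
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metrics(preds, target)
print(sorted(val))
# ['valmetrics/macro_Accuracy', 'valmetrics/macro_Precision',
#  'valmetrics/micro_Accuracy', 'valmetrics/micro_Precision']
```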