diff --git a/CHANGELOG.md b/CHANGELOG.md
index d20a95708e5..999829ce749 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,8 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))
 
-
--
+- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))
 
 
 ### Changed
diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py
index 6427f58dcbc..88aeab375ac 100644
--- a/tests/bases/test_collections.py
+++ b/tests/bases/test_collections.py
@@ -401,3 +401,36 @@ def test_error_on_wrong_specified_compute_groups():
     MetricCollection(
         ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Accuracy"]]
     )
+
+
+@pytest.mark.parametrize(
+    "input_collections",
+    [
+        [
+            MetricCollection(
+                [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")], prefix="macro_"
+            ),
+            MetricCollection(
+                [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")], prefix="micro_"
+            ),
+        ],
+        {
+            "macro": MetricCollection(
+                [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")]
+            ),
+            "micro": MetricCollection(
+                [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")]
+            ),
+        },
+    ],
+)
+def test_nested_collections(input_collections):
+    """Test that nested collections get flattened to a single collection."""
+    metrics = MetricCollection(input_collections, prefix="valmetrics/")
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    val = metrics(preds, target)
+    assert "valmetrics/macro_Accuracy" in val
+    assert "valmetrics/macro_Precision" in val
+    assert "valmetrics/micro_Accuracy" in val
+    assert "valmetrics/micro_Precision" in val
diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 55492bdd3bd..f819db1d711 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -54,6 +54,10 @@ class name as key for the output dict.
             this argument is ``True`` which enables this feature. Set this argument to `False` for disabling
             this behaviour. Can also be set to a list of list of metrics for setting the compute groups yourself.
 
+    .. note::
+        Metric collections can be nested at initialization (see last example), but the output of the collection
+        will still be a single flattened dictionary combining the prefix and postfix arguments from the nested collections.
+
     Raises:
         ValueError:
             If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
@@ -104,6 +108,23 @@ class name as key for the output dict.
         ...     )
         >>> pprint(metrics(preds, target))
         {'Accuracy': tensor(0.1250), 'MeanSquaredError': tensor(2.3750), 'Precision': tensor(0.0667)}
+
+    Example (nested metric collections):
+        >>> metrics = MetricCollection([
+        ...     MetricCollection([
+        ...         Accuracy(num_classes=3, average='macro'),
+        ...         Precision(num_classes=3, average='macro')
+        ...     ], postfix='_macro'),
+        ...     MetricCollection([
+        ...         Accuracy(num_classes=3, average='micro'),
+        ...         Precision(num_classes=3, average='micro')
+        ...     ], postfix='_micro'),
+        ... ], prefix='valmetrics/')
+        >>> pprint(metrics(preds, target))  # doctest: +NORMALIZE_WHITESPACE
+        {'valmetrics/Accuracy_macro': tensor(0.1111),
+         'valmetrics/Accuracy_micro': tensor(0.1250),
+         'valmetrics/Precision_macro': tensor(0.0667),
+         'valmetrics/Precision_micro': tensor(0.1250)}
     """
 
     _groups: Dict[int, List[str]]
@@ -195,6 +216,10 @@ def _merge_compute_groups(self) -> None:
     @staticmethod
     def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
         """Check if the metric state of two metrics are the same."""
+        # empty state: a metric without any state can never share a compute group
+        if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
+            return False
+
         if metric1._defaults.keys() != metric2._defaults.keys():
             return False
 
@@ -280,19 +305,31 @@ def add_metrics(
             # Make sure that metrics are added in deterministic order
             for name in sorted(metrics.keys()):
                 metric = metrics[name]
-                if not isinstance(metric, Metric):
+                if not isinstance(metric, (Metric, MetricCollection)):
                     raise ValueError(
-                        f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
+                        f"Value {metric} belonging to key {name} is not an instance of"
+                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
                     )
-                self[name] = metric
+                if isinstance(metric, Metric):
+                    self[name] = metric
+                else:
+                    for k, v in metric.items(keep_base=False):
+                        self[f"{name}_{k}"] = v
         elif isinstance(metrics, Sequence):
             for metric in metrics:
-                if not isinstance(metric, Metric):
-                    raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`")
-                name = metric.__class__.__name__
-                if name in self:
-                    raise ValueError(f"Encountered two metrics both named {name}")
-                self[name] = metric
+                if not isinstance(metric, (Metric, MetricCollection)):
+                    raise ValueError(
+                        f"Input {metric} to `MetricCollection` is not an instance of"
+                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
+                    )
+                if isinstance(metric, Metric):
+                    name = metric.__class__.__name__
+                    if name in self:
+                        raise ValueError(f"Encountered two metrics both named {name}")
+                    self[name] = metric
+                else:
+                    for k, v in metric.items(keep_base=False):
+                        self[k] = v
         else:
             raise ValueError("Unknown input to MetricCollection.")
 
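For readers skimming the patch, here is a minimal standalone sketch of the behaviour it adds, mirroring the dict case from the new test. It assumes a torchmetrics build with this change applied; the variable names and the 3-class setup are illustrative only.

# Sketch only: assumes torchmetrics with this patch applied.
import torch
from torchmetrics import Accuracy, MetricCollection, Precision

# Nested collections are flattened at construction time; the dict keys
# ("macro", "micro") are combined with the inner metric names and the
# outer prefix to form the final result keys.
metrics = MetricCollection(
    {
        "macro": MetricCollection(
            [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")]
        ),
        "micro": MetricCollection(
            [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")]
        ),
    },
    prefix="valmetrics/",
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

# Calling the collection returns a single flat dict with keys such as
# "valmetrics/macro_Accuracy" and "valmetrics/micro_Precision".
print(sorted(metrics(preds, target)))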