diff --git a/CHANGELOG.md b/CHANGELOG.md index d1aefa21b31..65a30703b41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `is_differentiable` property to `AUC`, `AUROC`, `CohenKappa` and `AveragePrecision` ([#178](https://github.com/PyTorchLightning/metrics/pull/178)) +- Added `add_metrics` method to `MetricCollection` for adding additional metrics after initialization ([#221](https://github.com/PyTorchLightning/metrics/pull/221)) + + - Added pre-gather reduction in the case of `dist_reduce_fx="cat"` to reduce communication cost ([#217](https://github.com/PyTorchLightning/metrics/pull/217)) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 8ee8dae77f3..7e2b237f99e 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -231,6 +231,20 @@ def test_metric_collection_same_order(): assert k1 == k2 +def test_collection_add_metrics(): + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + + collection = MetricCollection([m1]) + collection.add_metrics({'m1_': DummyMetricSum()}) + collection.add_metrics(m2) + + collection.update(5) + results = collection.compute() + assert results['DummyMetricSum'] == results['m1_'] and results['m1_'] == 5 + assert results['DummyMetricDiff'] == -5 + + def test_collection_check_arg(): assert MetricCollection._check_arg(None, 'prefix') is None assert MetricCollection._check_arg('sample', 'prefix') == 'sample' diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 105b3d282ad..0793d20def9 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -95,46 +95,8 @@ def __init__( postfix: Optional[str] = None ): super().__init__() - if isinstance(metrics, Metric): - # set compatible with original type expectations - metrics = [metrics] - if isinstance(metrics, Sequence): - # prepare for optional additions - metrics = list(metrics) - remain = [] - for m in additional_metrics: - (metrics if isinstance(m, Metric) else remain).append(m) - if remain: - rank_zero_warn( - f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." - ) - elif additional_metrics: - raise ValueError( - f"You have passes extra arguments {additional_metrics} which are not compatible" - f" with first passed dictionary {metrics} so they will be ignored." - ) - - if isinstance(metrics, dict): - # Check all values are metrics - # Make sure that metrics are added in deterministic order - for name in sorted(metrics.keys()): - metric = metrics[name] - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, Sequence): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") + self.add_metrics(metrics, *additional_metrics) self.prefix = self._check_arg(prefix, 'prefix') self.postfix = self._check_arg(postfix, 'postfix') @@ -186,6 +148,51 @@ def persistent(self, mode: bool = True) -> None: for _, m in self.items(keep_base=True): m.persistent(mode) + def add_metrics(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + *additional_metrics: Metric) -> None: + """Add new metrics to Metric Collection + """ + if isinstance(metrics, Metric): + # set compatible with original type expectations + metrics = [metrics] + if isinstance(metrics, Sequence): + # prepare for optional additions + metrics = list(metrics) + remain = [] + for m in additional_metrics: + (metrics if isinstance(m, Metric) else remain).append(m) + + if remain: + rank_zero_warn( + f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." + ) + elif additional_metrics: + raise ValueError( + f"You have passes extra arguments {additional_metrics} which are not compatible" + f" with first passed dictionary {metrics} so they will be ignored." + ) + + if isinstance(metrics, dict): + # Check all values are metrics + # Make sure that metrics are added in deterministic order + for name in sorted(metrics.keys()): + metric = metrics[name] + if not isinstance(metric, Metric): + raise ValueError( + f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" + ) + self[name] = metric + elif isinstance(metrics, Sequence): + for metric in metrics: + if not isinstance(metric, Metric): + raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") + name = metric.__class__.__name__ + if name in self: + raise ValueError(f"Encountered two metrics both named {name}") + self[name] = metric + else: + raise ValueError("Unknown input to MetricCollection.") + def _set_name(self, base: str) -> str: name = base if self.prefix is None else self.prefix + base name = name if self.postfix is None else name + self.postfix