From cc91ef3308bf417a3c2fc487a195a22655af9423 Mon Sep 17 00:00:00 2001 From: Ihar Date: Sun, 2 May 2021 11:46:48 +0300 Subject: [PATCH 1/2] added add_metrics method to MetricCollection --- tests/bases/test_collections.py | 14 ++++++ torchmetrics/collections.py | 85 ++++++++++++++++++--------------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index cb3d3d2c4b1..cbe29d706fe 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -184,6 +184,20 @@ def test_metric_collection_same_order(): assert k1 == k2 +def test_collection_add_metrics(): + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + + collection = MetricCollection([m1]) + collection.add_metrics({'m1_': DummyMetricSum()}) + collection.add_metrics(m2) + + collection.update(5) + results = collection.compute() + assert results['DummyMetricSum'] == results['m1_'] and results['m1_'] == 5 + assert results['DummyMetricDiff'] == -5 + + def test_collection_check_arg(): assert MetricCollection._check_arg(None, 'prefix') is None assert MetricCollection._check_arg('sample', 'prefix') == 'sample' diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index fd7ce8945b9..00ac61e7bcf 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -94,46 +94,8 @@ def __init__( postfix: Optional[str] = None ): super().__init__() - if isinstance(metrics, Metric): - # set compatible with original type expectations - metrics = [metrics] - if isinstance(metrics, Sequence): - # prepare for optional additions - metrics = list(metrics) - remain = [] - for m in additional_metrics: - (metrics if isinstance(m, Metric) else remain).append(m) - if remain: - rank_zero_warn( - f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." - ) - elif additional_metrics: - raise ValueError( - f"You have passes extra arguments {additional_metrics} which are not compatible" - f" with first passed dictionary {metrics} so they will be ignored." - ) - - if isinstance(metrics, dict): - # Check all values are metrics - # Make sure that metrics are added in deterministic order - for name in sorted(metrics.keys()): - metric = metrics[name] - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, Sequence): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") + self.add_metrics(metrics, *additional_metrics) self.prefix = self._check_arg(prefix, 'prefix') self.postfix = self._check_arg(postfix, 'postfix') @@ -185,6 +147,51 @@ def persistent(self, mode: bool = True) -> None: for _, m in self.items(): m.persistent(mode) + def add_metrics(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + *additional_metrics: Metric) -> None: + """Add new metrics to Metric Collection + """ + if isinstance(metrics, Metric): + # set compatible with original type expectations + metrics = [metrics] + if isinstance(metrics, Sequence): + # prepare for optional additions + metrics = list(metrics) + remain = [] + for m in additional_metrics: + (metrics if isinstance(m, Metric) else remain).append(m) + + if remain: + rank_zero_warn( + f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." + ) + elif additional_metrics: + raise ValueError( + f"You have passes extra arguments {additional_metrics} which are not compatible" + f" with first passed dictionary {metrics} so they will be ignored." + ) + + if isinstance(metrics, dict): + # Check all values are metrics + # Make sure that metrics are added in deterministic order + for name in sorted(metrics.keys()): + metric = metrics[name] + if not isinstance(metric, Metric): + raise ValueError( + f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" + ) + self[name] = metric + elif isinstance(metrics, Sequence): + for metric in metrics: + if not isinstance(metric, Metric): + raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") + name = metric.__class__.__name__ + if name in self: + raise ValueError(f"Encountered two metrics both named {name}") + self[name] = metric + else: + raise ValueError("Unknown input to MetricCollection.") + def _set_name(self, base: str) -> str: name = base if self.prefix is None else self.prefix + base name = name if self.postfix is None else name + self.postfix From fb8e815c2e7f67063bc2ffe92ff34881c013032e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 3 May 2021 08:53:42 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e137e8b7595..315db8afa55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `is_differentiable` property to `AUC`, `AUROC`, `CohenKappa` and `AveragePrecision` ([#178](https://github.com/PyTorchLightning/metrics/pull/178)) + +- Added `add_metrics` method to `MetricCollection` for adding additional metrics after initialization ([#221](https://github.com/PyTorchLightning/metrics/pull/221)) + + ### Changed