diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c8ba845fa1..15a97e82b2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205)) +- Fixed bug in `ClasswiseWrapper` such that `compute` gave wrong result ([#1225](https://github.com/Lightning-AI/metrics/pull/1225)) + + - Fixed synchronization of empty list states ([#1219](https://github.com/Lightning-AI/metrics/pull/1219)) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index b24c3c9429e..82fc867e80c 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -1,4 +1,17 @@ -from typing import Any, Dict, List, Optional +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Optional from torch import Tensor @@ -51,8 +64,6 @@ class ClasswiseWrapper(Metric): 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)} """ - full_state_update: Optional[bool] = True - def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: super().__init__() if not isinstance(metric, Metric): @@ -61,6 +72,7 @@ def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}") self.metric = metric self.labels = labels + self._update_count = 1 def _convert(self, x: Tensor) -> Dict[str, Any]: name = self.metric.__class__.__name__.lower() @@ -68,6 +80,9 @@ def _convert(self, x: Tensor) -> Dict[str, Any]: return {f"{name}_{i}": val for i, val in enumerate(x)} return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self._convert(self.metric(*args, **kwargs)) + def update(self, *args: Any, **kwargs: Any) -> None: self.metric.update(*args, **kwargs) @@ -76,3 +91,11 @@ def compute(self) -> Dict[str, Tensor]: def reset(self) -> None: self.metric.reset() + + def _wrap_update(self, update: Callable) -> Callable: + """Overwrite to do nothing.""" + return update + + def _wrap_compute(self, compute: Callable) -> Callable: + """Overwrite to do nothing.""" + return compute diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index 9e2891b1ad7..e5c8fbf2415 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -18,8 +18,8 @@ def test_output_no_labels(): base = Accuracy(num_classes=3, average=None) metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) for _ in range(2): - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) + preds = torch.randn(20, 3).softmax(dim=-1) + target = torch.randint(3, (20,)) val = metric(preds, target) val_base = base(preds, target) assert isinstance(val, dict) @@ -35,8 +35,8 @@ def test_output_with_labels(): base = Accuracy(num_classes=3, average=None) metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels) for _ in range(2): - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) + preds = torch.randn(20, 3).softmax(dim=-1) + target = torch.randint(3, (20,)) val = metric(preds, target) val_base = base(preds, target) assert isinstance(val, dict) @@ -44,6 +44,13 @@ def test_output_with_labels(): for i, lab in enumerate(labels): assert f"accuracy_{lab}" in val assert val[f"accuracy_{lab}"] == val_base[i] + val = metric.compute() + val_base = base.compute() + assert isinstance(val, dict) + assert len(val) == 3 + for i, lab in enumerate(labels): + assert f"accuracy_{lab}" in val + assert val[f"accuracy_{lab}"] == val_base[i] @pytest.mark.parametrize("prefix", [None, "pre_"])