Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for classwise wrapper compute logic #1225

Merged
merged 10 commits into from
Sep 16, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205))


- Fixed bug in `ClasswiseWrapper` such that `compute` gave wrong result ([#1225](https://github.com/Lightning-AI/metrics/pull/1225))


- Fixed synchronization of empty list states ([#1219](https://github.com/Lightning-AI/metrics/pull/1219))


Expand Down
29 changes: 26 additions & 3 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from typing import Any, Dict, List, Optional
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

Expand Down Expand Up @@ -51,8 +64,6 @@ class ClasswiseWrapper(Metric):
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
"""

full_state_update: Optional[bool] = True

def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
super().__init__()
if not isinstance(metric, Metric):
Expand All @@ -61,13 +72,17 @@ def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
self.metric = metric
self.labels = labels
self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
name = self.metric.__class__.__name__.lower()
if self.labels is None:
return {f"{name}_{i}": val for i, val in enumerate(x)}
return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}

def forward(self, *args: Any, **kwargs: Any) -> Any:
return self._convert(self.metric(*args, **kwargs))

def update(self, *args: Any, **kwargs: Any) -> None:
self.metric.update(*args, **kwargs)

Expand All @@ -76,3 +91,11 @@ def compute(self) -> Dict[str, Tensor]:

def reset(self) -> None:
self.metric.reset()

def _wrap_update(self, update: Callable) -> Callable:
"""Overwrite to do nothing."""
return update

def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute
15 changes: 11 additions & 4 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def test_output_no_labels():
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
Expand All @@ -35,15 +35,22 @@ def test_output_with_labels():
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i, lab in enumerate(labels):
assert f"accuracy_{lab}" in val
assert val[f"accuracy_{lab}"] == val_base[i]
val = metric.compute()
val_base = base.compute()
assert isinstance(val, dict)
assert len(val) == 3
for i, lab in enumerate(labels):
assert f"accuracy_{lab}" in val
assert val[f"accuracy_{lab}"] == val_base[i]


@pytest.mark.parametrize("prefix", [None, "pre_"])
Expand Down