Skip to content

Commit

Permalink
Bugfix for classwise wrapper compute logic (#1225)
Browse files Browse the repository at this point in the history
* update wrapped methods
* add testing
  • Loading branch information
SkafteNicki authored Sep 16, 2022
1 parent 3fd3dc4 commit 5659805
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205))


- Fixed bug in `ClasswiseWrapper` such that `compute` gave wrong result ([#1225](https://github.com/Lightning-AI/metrics/pull/1225))


- Fixed synchronization of empty list states ([#1219](https://github.com/Lightning-AI/metrics/pull/1219))


Expand Down
29 changes: 26 additions & 3 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from typing import Any, Dict, List, Optional
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

Expand Down Expand Up @@ -51,8 +64,6 @@ class ClasswiseWrapper(Metric):
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
"""

full_state_update: Optional[bool] = True

def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
super().__init__()
if not isinstance(metric, Metric):
Expand All @@ -61,13 +72,17 @@ def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
self.metric = metric
self.labels = labels
self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
name = self.metric.__class__.__name__.lower()
if self.labels is None:
return {f"{name}_{i}": val for i, val in enumerate(x)}
return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}

def forward(self, *args: Any, **kwargs: Any) -> Any:
return self._convert(self.metric(*args, **kwargs))

def update(self, *args: Any, **kwargs: Any) -> None:
self.metric.update(*args, **kwargs)

Expand All @@ -76,3 +91,11 @@ def compute(self) -> Dict[str, Tensor]:

def reset(self) -> None:
self.metric.reset()

def _wrap_update(self, update: Callable) -> Callable:
"""Overwrite to do nothing."""
return update

def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute
15 changes: 11 additions & 4 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def test_output_no_labels():
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
Expand All @@ -35,15 +35,22 @@ def test_output_with_labels():
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i, lab in enumerate(labels):
assert f"accuracy_{lab}" in val
assert val[f"accuracy_{lab}"] == val_base[i]
val = metric.compute()
val_base = base.compute()
assert isinstance(val, dict)
assert len(val) == 3
for i, lab in enumerate(labels):
assert f"accuracy_{lab}" in val
assert val[f"accuracy_{lab}"] == val_base[i]


@pytest.mark.parametrize("prefix", [None, "pre_"])
Expand Down

0 comments on commit 5659805

Please sign in to comment.