Fix metrics in macro average (#303)
* fix weights for nonexisting classes

* fix division by zero

* part fix

* add test case

* Apply suggestions from code review

* fix

* changelog

* please fix

* dist_sync not working

* trying to fix

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored Aug 2, 2021
1 parent d8b89e0 commit 79cb5e2
Showing 8 changed files with 104 additions and 7 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug where classification metrics with `average='macro'` would lead to wrong result if a class was missing ([#303](https://github.com/PyTorchLightning/metrics/pull/303))


- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


@@ -85,7 +88,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))


## [0.4.1] - 2021-07-05

### Changed
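For context, a minimal reproduction of the first entry above via the public API (a sketch assuming the 0.4.x functional interface; the pre-fix value is inferred from the bug description):

```python
import torch
from torchmetrics.functional import accuracy

preds = torch.tensor([0, 1, 2, 0])   # class 3 never appears
target = torch.tensor([0, 1, 2, 0])

# With the fix, the absent class is excluded from the macro mean,
# so a perfect prediction scores exactly 1.0:
acc = accuracy(preds, target, num_classes=4, average="macro")  # tensor(1.)
# Before the fix, the absent class contributed a zero term,
# dragging the macro mean below 1 (0.75 for this example).
```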
7 changes: 7 additions & 0 deletions tests/classification/inputs.py
@@ -116,3 +116,10 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_S
_input_multilabel_prob_plausible = generate_plausible_inputs_multilabel()

_input_binary_prob_plausible = generate_plausible_inputs_binary()

# randomly remove one class from the input
_temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
_class_remove, _class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
_temp[_temp == _class_remove] = _class_replace

_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())
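The construction above guarantees that one randomly chosen class never occurs in either `preds` or `target`: every sample of `_class_remove` is relabeled as `_class_replace`. A standalone sketch of the same idea (the constants and the `Input` container are stand-ins for the test helpers, not the real ones):

```python
from collections import namedtuple

import torch

# Assumed stand-ins for the test-suite constants and container
Input = namedtuple("Input", ["preds", "target"])
NUM_CLASSES, NUM_BATCHES, BATCH_SIZE = 5, 10, 32

# Draw random labels, then pick two distinct classes:
# one to remove entirely, one to relabel its samples as.
temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
class_remove, class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
temp[temp == class_remove] = class_replace

inp = Input(temp.clone(), temp.clone())
assert not (inp.target == class_remove).any()  # the removed class never appears
```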
21 changes: 20 additions & 1 deletion tests/classification/test_accuracy.py
@@ -23,6 +23,7 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
@@ -31,7 +32,7 @@
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy
from torchmetrics.utilities.checks import _input_format_classification
@@ -342,3 +343,21 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
    cl_metric(preds, target)
    result_cl = cl_metric.compute()
    assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
def test_same_input(average):
    preds = _input_miss_class.preds
    target = _input_miss_class.target
    preds_flat = torch.cat([p for p in preds], dim=0)
    target_flat = torch.cat([t for t in target], dim=0)

    mc = Accuracy(num_classes=NUM_CLASSES, average=average)
    for i in range(NUM_BATCHES):
        mc.update(preds[i], target[i])
    class_res = mc.compute()
    func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
    sk_res = sk_accuracy(target_flat, preds_flat)

    assert torch.allclose(class_res, torch.tensor(sk_res).float())
    assert torch.allclose(func_res, torch.tensor(sk_res).float())
26 changes: 24 additions & 2 deletions tests/classification/test_f_beta.py
@@ -24,13 +24,14 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import F1, FBeta, Metric
from torchmetrics.functional import f1, fbeta
from torchmetrics.utilities.checks import _input_format_classification
@@ -55,7 +56,6 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_
        preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
    )
    sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

    sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)

    if len(labels) != num_classes and not average:
@@ -425,3 +425,25 @@ def test_top_k(

    assert torch.isclose(class_metric.compute(), result)
    assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)


@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
    "metric_class, metric_functional, sk_fn",
    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
    preds = _input_miss_class.preds
    target = _input_miss_class.target
    preds_flat = torch.cat([p for p in preds], dim=0)
    target_flat = torch.cat([t for t in target], dim=0)

    mc = metric_class(num_classes=NUM_CLASSES, average=average)
    for i in range(NUM_BATCHES):
        mc.update(preds[i], target[i])
    class_res = mc.compute()
    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
    sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)

    assert torch.allclose(class_res, torch.tensor(sk_res).float())
    assert torch.allclose(func_res, torch.tensor(sk_res).float())
26 changes: 24 additions & 2 deletions tests/classification/test_precision_recall.py
@@ -24,13 +24,14 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Metric, Precision, Recall
from torchmetrics.functional import precision, precision_recall, recall
from torchmetrics.utilities.checks import _input_format_classification
@@ -209,7 +210,7 @@ def test_no_support(metric_class, metric_fn):
)
class TestPrecisionRecall(MetricTester):
    @pytest.mark.parametrize("ddp", [False, True])
-    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [False])
    def test_precision_recall_class(
        self,
        ddp: bool,
@@ -437,3 +438,24 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
    cl_metric(preds, target)
    result_cl = cl_metric.compute()
    assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
    "metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)]
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
    preds = _input_miss_class.preds
    target = _input_miss_class.target
    preds_flat = torch.cat([p for p in preds], dim=0)
    target_flat = torch.cat([t for t in target], dim=0)

    mc = metric_class(num_classes=NUM_CLASSES, average=average)
    for i in range(NUM_BATCHES):
        mc.update(preds[i], target[i])
    class_res = mc.compute()
    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
    sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1)

    assert torch.allclose(class_res, torch.tensor(sk_res).float())
    assert torch.allclose(func_res, torch.tensor(sk_res).float())
6 changes: 6 additions & 0 deletions torchmetrics/functional/classification/accuracy.py
@@ -85,6 +85,12 @@ def _accuracy_compute(
    else:
        numerator = tp
        denominator = tp + fn

    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        cond = tp + fp + fn == 0
        numerator = numerator[~cond]
        denominator = denominator[~cond]

    if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        # a class is not present if there exists no TPs, no FPs, and no FNs
        meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
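A worked example of the guard added above (the stat counts are invented for illustration, not library output):

```python
import torch

# Four classes; class 3 never appears in preds or target
tp = torch.tensor([3, 1, 2, 0])
fp = torch.tensor([1, 0, 0, 0])
fn = torch.tensor([0, 1, 1, 0])

cond = tp + fp + fn == 0        # tensor([False, False, False,  True])
numerator = tp[~cond]           # the absent class is dropped entirely,
denominator = (tp + fn)[~cond]  # so it cannot contribute a 0/0 term
macro_acc = (numerator / denominator).mean()  # mean over the 3 present classes, ~0.7222
```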
8 changes: 7 additions & 1 deletion torchmetrics/functional/classification/f_beta.py
@@ -46,9 +46,15 @@ def _fbeta_compute(
    precision = _safe_divide(tp.float(), tp + fp)
    recall = _safe_divide(tp.float(), tp + fn)

    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        cond = tp + fp + fn == 0
        precision = precision[~cond]
        recall = recall[~cond]

    num = (1 + beta ** 2) * precision * recall
    denom = beta ** 2 * precision + recall
-    denom[denom == 0.0] = 1  # avoid division by 0
+    denom[denom == 0.0] = 1.0  # avoid division by 0

    # if classes matter and a given class is not present in both the preds and the target,
    # computing the score for this class is meaningless, thus they should be ignored
    if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
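For reference, the quantity computed from the masked `precision` and `recall` is the standard F-beta score,

```latex
F_\beta = (1 + \beta^2) \cdot \frac{P \cdot R}{\beta^2 \cdot P + R}
```

so masking absent classes out of P and R before this formula keeps them out of the macro mean, while the `denom == 0.0` guard still covers present classes where both terms happen to be zero.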
13 changes: 13 additions & 0 deletions torchmetrics/functional/classification/precision_recall.py
@@ -29,6 +29,12 @@ def _precision_compute(
) -> Tensor:
    numerator = tp
    denominator = tp + fp

    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        cond = tp + fp + fn == 0
        numerator = numerator[~cond]
        denominator = denominator[~cond]

    if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        # a class is not present if there exists no TPs, no FPs, and no FNs
        meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
@@ -199,11 +205,18 @@ def _recall_compute(
) -> Tensor:
    numerator = tp
    denominator = tp + fn

    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        cond = tp + fp + fn == 0
        numerator = numerator[~cond]
        denominator = denominator[~cond]

    if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        # a class is not present if there exists no TPs, no FPs, and no FNs
        meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu()
        numerator[meaningless_indeces, ...] = -1
        denominator[meaningless_indeces, ...] = -1

    return _reduce_stat_scores(
        numerator=numerator,
        denominator=denominator,
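A usage sketch of the corrected macro behavior for both metrics (tensor values invented; keyword arguments per the 0.4.x functional API):

```python
import torch
from torchmetrics.functional import precision, recall

preds = torch.tensor([0, 1, 2, 2, 0])   # class 3 never appears
target = torch.tensor([0, 1, 1, 2, 0])

# Macro averages are now taken over the 3 present classes only:
# per-class precision = [1.0, 1.0, 0.5], per-class recall = [1.0, 0.5, 1.0]
p = precision(preds, target, num_classes=4, average="macro")  # ~0.8333
r = recall(preds, target, num_classes=4, average="macro")     # ~0.8333
# Before the fix, both sums would have been divided by 4, giving 0.625.
```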
