Fix metrics in macro average #303

Merged: 44 commits merged into `master` from `bug/macro_metrics` on Aug 2, 2021

Commits (44)
ef0e947
fix weights for nonexisting classes
vatch123 Jun 17, 2021
df136e0
Merge branch 'master' into bug/macro_metrics
vatch123 Jun 18, 2021
a8ad678
fix division by zero
vatch123 Jun 18, 2021
7c57569
Merge branch 'master' into bug/macro_metrics
Borda Jun 21, 2021
4c9bc45
part fix
SkafteNicki Jun 23, 2021
1668e1d
Merge branch 'bug/macro_metrics' of https://github.com/vatch123/metri…
SkafteNicki Jun 23, 2021
b62c82b
add test case
SkafteNicki Jun 23, 2021
ab29824
Merge branch 'master' into bug/macro_metrics
Borda Jun 23, 2021
ec186d3
Merge branch 'master' into bug/macro_metrics
Borda Jun 23, 2021
eaff5bb
Merge branch 'master' into bug/macro_metrics
Borda Jul 1, 2021
e90df07
Apply suggestions from code review
Borda Jul 7, 2021
4c3fa09
Merge branch 'master' into bug/macro_metrics
Borda Jul 7, 2021
067e228
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
7d8cc1a
fix
SkafteNicki Jul 7, 2021
e218174
changelog
SkafteNicki Jul 7, 2021
6f806f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
2cecec2
please fix
SkafteNicki Jul 7, 2021
fc971a3
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 7, 2021
4e2f76d
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 9, 2021
fd160b6
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 9, 2021
4e68f7a
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 13, 2021
dc7c805
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 13, 2021
ff3eb2a
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 14, 2021
a913f97
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 16, 2021
3f328b9
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 16, 2021
36d64d8
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 16, 2021
328ee51
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 16, 2021
99853bb
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 19, 2021
b6f0a6a
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 24, 2021
bf16f7c
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 24, 2021
ca50e71
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 24, 2021
69bf988
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 24, 2021
20abef7
Merge branch 'master' into bug/macro_metrics
mergify[bot] Jul 24, 2021
6df1eb2
Merge branch 'master' into bug/macro_metrics
Borda Jul 26, 2021
324982f
Merge branch 'master' into bug/macro_metrics
SkafteNicki Jul 26, 2021
a48f638
Merge branch 'master' into bug/macro_metrics
SkafteNicki Jul 29, 2021
51b2147
Merge branch 'master' into bug/macro_metrics
SkafteNicki Aug 2, 2021
de57083
dist_sync not working
SkafteNicki Aug 2, 2021
7d04724
trying to fix
SkafteNicki Aug 2, 2021
48e8e9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2021
6704190
Merge branch 'master' into bug/macro_metrics
SkafteNicki Aug 2, 2021
2cd3c19
Merge branch 'master' into bug/macro_metrics
SkafteNicki Aug 2, 2021
289f347
Merge branch 'master' into bug/macro_metrics
Borda Aug 2, 2021
4b01415
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug where classification metrics with `average='macro'` would lead to wrong result if a class was missing ([#303](https://github.com/PyTorchLightning/metrics/pull/303))


- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


@@ -85,7 +88,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))


## [0.4.1] - 2021-07-05

### Changed
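To illustrate the behaviour this changelog entry describes, here is a minimal sketch (the concrete numbers are illustrative, not taken from the PR): with `average='macro'`, a class that never occurs in either `preds` or `target` should be excluded from the average rather than dragging it down.

```python
import torch
from torchmetrics.functional import accuracy

# Hypothetical 3-class problem in which class 2 never appears in preds or target.
preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([0, 1, 1, 1])

# Per-class accuracy: class 0 -> 1/1, class 1 -> 2/3, class 2 -> no samples at all.
# Before this fix the absent class contributed a zero score, pulling the macro
# average down to (1 + 2/3 + 0) / 3; with the fix it is dropped, giving (1 + 2/3) / 2.
print(accuracy(preds, target, num_classes=3, average="macro"))
```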
7 changes: 7 additions & 0 deletions tests/classification/inputs.py
@@ -116,3 +116,10 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_S
_input_multilabel_prob_plausible = generate_plausible_inputs_multilabel()

_input_binary_prob_plausible = generate_plausible_inputs_binary()

# randomly remove one class from the input
_temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
_class_remove, _class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
_temp[_temp == _class_remove] = _class_replace

_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())
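The fixture added above guarantees that one class label is missing from both `preds` and `target` by replacing every occurrence of one randomly chosen class with another. A standalone sketch of the same trick (the constants here are illustrative, not the test-suite values):

```python
import torch

NUM_CLASSES, NUM_BATCHES, BATCH_SIZE = 4, 3, 8  # illustrative sizes only

temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
class_remove, class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
temp[temp == class_remove] = class_replace

assert not (temp == class_remove).any()  # the chosen class no longer occurs anywhere
```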
21 changes: 20 additions & 1 deletion tests/classification/test_accuracy.py
@@ -23,6 +23,7 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
@@ -31,7 +32,7 @@
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy
from torchmetrics.utilities.checks import _input_format_classification
@@ -350,3 +351,21 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
cl_metric(preds, target)
result_cl = cl_metric.compute()
assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
def test_same_input(average):
preds = _input_miss_class.preds
target = _input_miss_class.target
preds_flat = torch.cat([p for p in preds], dim=0)
target_flat = torch.cat([t for t in target], dim=0)

mc = Accuracy(num_classes=NUM_CLASSES, average=average)
for i in range(NUM_BATCHES):
mc.update(preds[i], target[i])
class_res = mc.compute()
func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
sk_res = sk_accuracy(target_flat, preds_flat)

assert torch.allclose(class_res, torch.tensor(sk_res).float())
assert torch.allclose(func_res, torch.tensor(sk_res).float())
26 changes: 24 additions & 2 deletions tests/classification/test_f_beta.py
@@ -24,13 +24,14 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import F1, FBeta, Metric
from torchmetrics.functional import f1, fbeta
from torchmetrics.utilities.checks import _input_format_classification
@@ -55,7 +56,6 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)

if len(labels) != num_classes and not average:
@@ -408,3 +408,25 @@ def test_top_k(

assert torch.isclose(class_metric.compute(), result)
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize(
'metric_class, metric_functional, sk_fn',
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)]
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
preds = _input_miss_class.preds
target = _input_miss_class.target
preds_flat = torch.cat([p for p in preds], dim=0)
target_flat = torch.cat([t for t in target], dim=0)

mc = metric_class(num_classes=NUM_CLASSES, average=average)
for i in range(NUM_BATCHES):
mc.update(preds[i], target[i])
class_res = mc.compute()
func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)

assert torch.allclose(class_res, torch.tensor(sk_res).float())
assert torch.allclose(func_res, torch.tensor(sk_res).float())
26 changes: 24 additions & 2 deletions tests/classification/test_precision_recall.py
@@ -24,13 +24,14 @@
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Metric, Precision, Recall
from torchmetrics.functional import precision, precision_recall, recall
from torchmetrics.utilities.checks import _input_format_classification
@@ -202,7 +203,7 @@ def test_no_support(metric_class, metric_fn):
class TestPrecisionRecall(MetricTester):

@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_precision_recall_class(
self,
ddp: bool,
@@ -430,3 +431,24 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
cl_metric(preds, target)
result_cl = cl_metric.compute()
assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize(
'metric_class, metric_functional, sk_fn', [(Precision, precision, precision_score), (Recall, recall, recall_score)]
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
preds = _input_miss_class.preds
target = _input_miss_class.target
preds_flat = torch.cat([p for p in preds], dim=0)
target_flat = torch.cat([t for t in target], dim=0)

mc = metric_class(num_classes=NUM_CLASSES, average=average)
for i in range(NUM_BATCHES):
mc.update(preds[i], target[i])
class_res = mc.compute()
func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1)

assert torch.allclose(class_res, torch.tensor(sk_res).float())
assert torch.allclose(func_res, torch.tensor(sk_res).float())
6 changes: 6 additions & 0 deletions torchmetrics/functional/classification/accuracy.py
@@ -85,6 +85,12 @@ def _accuracy_compute(
else:
numerator = tp
denominator = tp + fn

if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
cond = tp + fp + fn == 0
numerator = numerator[~cond]
denominator = denominator[~cond]

if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
# a class is not present if there exists no TPs, no FPs, and no FNs
meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
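The new block above is the heart of the fix: for `average='macro'` (outside the samplewise mdmc case), classes with `tp + fp + fn == 0` are removed from both numerator and denominator before the reduction, so absent classes no longer dilute the mean. A small self-contained sketch of the same idea with made-up per-class counts (plain tensors, not the torchmetrics internals):

```python
import torch

# Hypothetical per-class statistics for 4 declared classes; class 3 never occurs.
tp = torch.tensor([3, 1, 2, 0])
fp = torch.tensor([1, 0, 2, 0])
fn = torch.tensor([0, 2, 1, 0])

numerator, denominator = tp, tp + fn        # accuracy-style per-class score
cond = tp + fp + fn == 0                    # classes absent from preds and target
numerator, denominator = numerator[~cond], denominator[~cond]

macro = (numerator.float() / denominator.float()).mean()
print(macro)  # averaged over the three present classes only
```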
7 changes: 6 additions & 1 deletion torchmetrics/functional/classification/f_beta.py
@@ -46,9 +46,14 @@ def _fbeta_compute(
precision = _safe_divide(tp.float(), tp + fp)
recall = _safe_divide(tp.float(), tp + fn)

if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
cond = tp + fp + fn == 0
precision = precision[~cond]
recall = recall[~cond]

num = (1 + beta**2) * precision * recall
denom = beta**2 * precision + recall
-denom[denom == 0.] = 1 # avoid division by 0
+denom[denom == 0.] = 1.0 # avoid division by 0
# if classes matter and a given class is not present in both the preds and the target,
# computing the score for this class is meaningless, thus they should be ignored
if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
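For reference, `num` and `denom` implement the standard F-beta combination, F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall), and after the new masking only classes that actually occur enter the macro mean. A tiny numeric sketch with made-up per-class values:

```python
import torch

beta = 2.0
precision = torch.tensor([0.75, 0.50])  # per-class precision for the present classes
recall = torch.tensor([1.00, 0.40])     # per-class recall for the present classes

num = (1 + beta ** 2) * precision * recall
denom = beta ** 2 * precision + recall
denom[denom == 0.0] = 1.0               # avoid division by zero, as in the diff
print((num / denom).mean())             # macro F2 over the present classes only
```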
13 changes: 13 additions & 0 deletions torchmetrics/functional/classification/precision_recall.py
@@ -29,6 +29,12 @@ def _precision_compute(
) -> Tensor:
numerator = tp
denominator = tp + fp

if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
cond = tp + fp + fn == 0
numerator = numerator[~cond]
denominator = denominator[~cond]

if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
# a class is not present if there exists no TPs, no FPs, and no FNs
meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
@@ -199,11 +205,18 @@ def _recall_compute(
) -> Tensor:
numerator = tp
denominator = tp + fn

if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
cond = tp + fp + fn == 0
numerator = numerator[~cond]
denominator = denominator[~cond]

if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
# a class is not present if there exists no TPs, no FPs, and no FNs
meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu()
numerator[meaningless_indeces, ...] = -1
denominator[meaningless_indeces, ...] = -1

return _reduce_stat_scores(
numerator=numerator,
denominator=denominator,
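Precision and recall receive the same treatment as accuracy; the only difference is the per-class denominator (`tp + fp` for precision, `tp + fn` for recall), while the `tp + fp + fn == 0` mask is shared. A brief sketch with made-up counts, where class 2 is entirely absent:

```python
import torch

tp = torch.tensor([4, 1, 0])
fp = torch.tensor([1, 1, 0])
fn = torch.tensor([0, 3, 0])

present = (tp + fp + fn) != 0   # same mask as in the diff above
precision_macro = (tp[present].float() / (tp + fp)[present]).mean()
recall_macro = (tp[present].float() / (tp + fn)[present]).mean()
print(precision_macro, recall_macro)  # averaged over the two present classes
```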