JaccardIndex when multilabel=True not working #1172

Closed
claverru opened this issue Aug 4, 2022 · 2 comments · Fixed by #1195
Labels
bug / fix · help wanted

Comments

claverru commented Aug 4, 2022

🐛 Bug

JaccardIndex is not correctly implemented when multilabel=True and average is anything other than "none".

To Reproduce

Instantiate JaccardIndex(..., multilabel=True, average="micro") and call it as usual with multilabel classification data.

Code sample

import torch
import torchmetrics

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)

ji = torchmetrics.classification.JaccardIndex(num_classes=3, multilabel=True, average="micro")

ji(preds, target)

Environment

  • TorchMetrics version 0.9.3

Additional context

When multilabel=True, _jaccard_from_confmat is called and the result is then indexed at [1] (the indexing itself is correct):
https://github.com/Lightning-AI/metrics/blob/v0.9.3/torchmetrics/classification/jaccard.py#L117
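To see why this breaks for the averaged modes, here is a minimal sketch (the helper below is a hypothetical simplification, not the actual torchmetrics source): any averaging other than "none" collapses the per-class scores into a 0-dim tensor, so the subsequent [1] indexing fails.

import torch

# Hypothetical stand-in for _jaccard_from_confmat, simplified for illustration.
# The real reduction differs per averaging mode, but every mode other than
# "none" collapses the per-class scores to a 0-dim scalar.
def jaccard_from_confmat_sketch(cm: torch.Tensor, average: str) -> torch.Tensor:
    intersection = cm.diagonal()
    union = cm.sum(dim=-1) + cm.sum(dim=-2) - intersection
    scores = intersection.float() / union.float()
    return scores if average == "none" else scores.mean()

cm = torch.tensor([[3, 1], [2, 2]])          # one label's 2x2 confusion matrix
jaccard_from_confmat_sketch(cm, "none")[1]   # OK: picks the positive-class IoU
jaccard_from_confmat_sketch(cm, "micro")[1]  # IndexError: 0-dim tensor cannot be indexed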

Possible implementation snippet

With this base IoU (Jaccard) implementation, you can easily organize the different combinations (multilabel=True + macro, multilabel=False + micro, etc.):

def _compute_iou(cm: torch.Tensor) -> torch.Tensor:
    # Diagonal of the confusion matrix: per-class true positives.
    intersection = cm.diagonal(dim1=-2, dim2=-1)
    # Row sums + column sums count the intersection twice, so subtract it once.
    union = torch.sum(cm, dim=-1) + torch.sum(cm, dim=-2) - intersection
    return intersection.float() / union.float()

For example:

Micro + multilabel=True

# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm.sum(0))[1]

Macro + multilabel=True

# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm)[:, 1].mean()

Macro + multilabel=False

# cm: (3, 3)
iou = _compute_iou(cm).mean()

And so on.
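For completeness, the one combination listed above but not shown, micro + multilabel=False, could follow the same pattern by pooling the confusion-matrix statistics before taking the ratio (a sketch, not from the original issue):

# cm: (3, 3)
intersection = cm.diagonal().sum()
union = cm.sum() * 2 - intersection  # row totals + column totals, overlap counted once
iou = intersection.float() / union.float()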

claverru added the bug / fix and help wanted labels on Aug 4, 2022
github-actions bot commented Aug 4, 2022

Hi! Thanks for your contribution, great first issue!

SkafteNicki (Member) commented

This issue will be fixed by the classification refactor: see issue #1001 and PR #1195 for all changes.

Small recap: this issue describes that jaccard_index is wrongly calculated in the multilabel setting, which was simply due to a wrong implementation. The issue has been fixed in the refactor, so everything should now be correct (our implementation is better tested against sklearn now). The only difference is that instead of jaccard_index, the specialized version multilabel_jaccard_index should be used:

from torchmetrics.functional import multilabel_jaccard_index
import torch

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)

multilabel_jaccard_index(preds, target, num_labels=3, average="micro")  # tensor(0.2632)
multilabel_jaccard_index(preds, target, num_labels=3, average="macro")  # tensor(0.2762)
multilabel_jaccard_index(preds, target, num_labels=3, average=None)     # tensor([0.1429, 0.2857, 0.4000])

which gives the correct results. The issue will be closed when #1195 is merged.
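As a sanity check (not part of the original comment), the same numbers can be reproduced with sklearn, continuing from the snippet above and assuming the default 0.5 threshold that torchmetrics applies to probabilistic predictions:

from sklearn.metrics import jaccard_score

hard_preds = (preds > 0.5).int().numpy()
jaccard_score(target.numpy(), hard_preds, average="micro")  # should match the micro result above
jaccard_score(target.numpy(), hard_preds, average=None)     # should match the per-label results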
