Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Dec 10, 2022
1 parent b7ace05 commit 9bc9a17
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def __init__(
self.average = average

def compute(self) -> Tensor:
return _jaccard_index_reduce(self.confmat, average=self.average)
return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)


class MultilabelJaccardIndex(MultilabelConfusionMatrix):
Expand Down
9 changes: 7 additions & 2 deletions src/torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
def _jaccard_index_reduce(
confmat: Tensor,
average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]],
ignore_index=None,
) -> Tensor:
"""Perform reduction of an un-normalized confusion matrix into jaccard score.
Expand Down Expand Up @@ -64,6 +65,10 @@ def _jaccard_index_reduce(
num = confmat[:, 1, 1]
denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0]
else: # multiclass
if ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]:
cond = torch.arange(confmat.shape[0]) != ignore_index
confmat = confmat[cond, :]
confmat = confmat[:, cond]
num = torch.diag(confmat)
denom = confmat.sum(0) + confmat.sum(1) - num

Expand Down Expand Up @@ -217,7 +222,7 @@ def multiclass_jaccard_index(
_multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
return _jaccard_index_reduce(confmat, average=average)
return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def _multilabel_jaccard_index_arg_validation(
Expand Down Expand Up @@ -293,7 +298,7 @@ def multilabel_jaccard_index(
_multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
return _jaccard_index_reduce(confmat, average=average)
return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def jaccard_index(
Expand Down

0 comments on commit 9bc9a17

Please sign in to comment.