diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index c637a8a94b6..bf8e65fc173 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -168,7 +168,7 @@ def __init__( self.average = average def compute(self) -> Tensor: - return _jaccard_index_reduce(self.confmat, average=self.average) + return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index) class MultilabelJaccardIndex(MultilabelConfusionMatrix): diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index d9c70be4169..f06d41290d1 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -37,6 +37,7 @@ def _jaccard_index_reduce( confmat: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], + ignore_index=None, ) -> Tensor: """Perform reduction of an un-normalized confusion matrix into jaccard score. @@ -64,6 +65,10 @@ def _jaccard_index_reduce( num = confmat[:, 1, 1] denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0] else: # multiclass + if ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]: + cond = torch.arange(confmat.shape[0]) != ignore_index + confmat = confmat[cond, :] + confmat = confmat[:, cond] num = torch.diag(confmat) denom = confmat.sum(0) + confmat.sum(1) - num @@ -217,7 +222,7 @@ def multiclass_jaccard_index( _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) - return _jaccard_index_reduce(confmat, average=average) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) def _multilabel_jaccard_index_arg_validation( @@ -293,7 +298,7 @@ def multilabel_jaccard_index( _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) - return _jaccard_index_reduce(confmat, average=average) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) def jaccard_index(