Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dice and BinaryDice metrics #174

Merged
merged 8 commits into from
May 2, 2022
Merged
6 changes: 5 additions & 1 deletion src/super_gradients/common/factories/metrics_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from super_gradients.common.factories.base_factory import BaseFactory
from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy
from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,\
BinaryDice


class MetricsFactory(BaseFactory):
Expand All @@ -10,6 +11,9 @@ def __init__(self):
'Top5': Top5,
'DetectionMetrics': DetectionMetrics,
'IoU': IoU,
"BinaryIOU": BinaryIOU,
"Dice": Dice,
"BinaryDice": BinaryDice,
'PixelAccuracy': PixelAccuracy,
}
super().__init__(type_dict)
141 changes: 127 additions & 14 deletions src/super_gradients/training/metrics/segmentation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
import torchmetrics
from torchmetrics import Metric
from typing import Optional
from torchmetrics.utilities.distributed import reduce


def batch_pix_accuracy(predict, target):
Expand Down Expand Up @@ -59,6 +61,51 @@ def pixel_accuracy(im_pred, im_lab):
return pixel_correct, pixel_labeled


def _dice_from_confmat(
confmat: torch.Tensor,
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = "elementwise_mean",
) -> torch.Tensor:
"""Computes Dice coefficient from confusion matrix.

Args:
confmat: Confusion matrix without normalization
num_classes: Number of classes for a given prediction and target tensor
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method.
absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
AND no instances of the class index were present in `target`.
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.0

intersection = torch.diag(confmat)
denominator = confmat.sum(0) + confmat.sum(1)

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = 2 * intersection.float() / denominator.float()
scores[denominator == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1:],
]
)

return reduce(scores, reduction=reduction)


def intersection_and_union(im_pred, im_lab, num_class):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
Expand All @@ -77,6 +124,26 @@ def intersection_and_union(im_pred, im_lab, num_class):
return area_inter, area_union


def _preprocess_segmentation_inputs(preds, target: torch.Tensor,
lkdci marked this conversation as resolved.
Show resolved Hide resolved
apply_arg_max: bool = False,
apply_sigmoid: bool = False):
"""
preprocess segmentation predictions and target before updating segmentation metrics, handles multiple inputs and
apply normalizations.
:param apply_arg_max: Whether to apply argmax on predictions tensor.
:param apply_sigmoid: Whether to apply sigmoid on predictions tensor.
"""
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
if isinstance(preds, (tuple, list)):
preds = preds[0]
if apply_arg_max:
_, preds = torch.max(preds, 1)
elif apply_sigmoid:
preds = torch.sigmoid(preds)

return preds, target


class PixelAccuracy(Metric):
def __init__(self, ignore_label=-100, dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
Expand All @@ -85,9 +152,8 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False):
self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
if isinstance(preds, tuple):
preds = preds[0]
_, predict = torch.max(preds, 1)
predict, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True)

labeled_mask = target.ne(self.ignore_label)
pixel_labeled = torch.sum(labeled_mask)
pixel_correct = torch.sum((predict == target) * labeled_mask)
Expand All @@ -101,26 +167,73 @@ def compute(self):
return pix_acc


class IoU(torchmetrics.IoU):
def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index)
class IoU(torchmetrics.JaccardIndex):
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self,
num_classes: int,
dist_sync_on_step: bool = True,
ignore_index: Optional[int] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction=reduction, threshold=threshold)

def update(self, preds, target: torch.Tensor):
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
if isinstance(preds, tuple):
preds = preds[0]
_, preds = torch.max(preds, 1)
preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True)
super().update(preds=preds, target=target)


class BinaryIOU(torchmetrics.IoU):
def __init__(self, dist_sync_on_step=True, ignore_index=None):
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction="none", threshold=0.5)
class BinaryIOU(IoU):
def __init__(self,
dist_sync_on_step=True,
ignore_index: Optional[int] = None,
threshold: float = 0.5):
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction="none", threshold=threshold)
self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]

def update(self, preds, target: torch.Tensor):
super().update(preds=torch.sigmoid(preds), target=target.long())
preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True)
super(IoU, self).update(preds=preds, target=target.long())

def compute(self):
ious = super(BinaryIOU, self).compute()
return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()}


class Dice(torchmetrics.JaccardIndex):
lkdci marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self,
num_classes: int,
dist_sync_on_step: bool = True,
ignore_index: Optional[int] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction=reduction, threshold=threshold)

def update(self, preds, target: torch.Tensor):
preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True)
super().update(preds=preds, target=target)

def compute(self) -> torch.Tensor:
"""Computes Dice coefficient"""
return _dice_from_confmat(
self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
)


class BinaryDice(Dice):
def __init__(self,
dist_sync_on_step=True,
ignore_index: Optional[int] = None,
threshold: float = 0.5):
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction="none", threshold=threshold)
self.component_names = ["target_Dice", "background_Dice", "mean_Dice"]

def update(self, preds, target: torch.Tensor):
preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True)
super(Dice, self).update(preds=preds, target=target.long())

def compute(self):
dices = super().compute()
return {"target_Dice": dices[1], "background_Dice": dices[0], "mean_Dice": dices.mean()}