Skip to content

Commit

Permalink
Merge pull request #174 from Deci-AI/feature/ALG-391_dice_metrics
Browse files Browse the repository at this point in the history
Add Dice and BinaryDice metrics
  • Loading branch information
lkdci authored May 2, 2022
2 parents 6f4f13b + f8a3fa7 commit 018bd0c
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 20 deletions.
6 changes: 5 additions & 1 deletion src/super_gradients/common/factories/metrics_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from super_gradients.common.factories.base_factory import BaseFactory
from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy
from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,\
BinaryDice


class MetricsFactory(BaseFactory):
Expand All @@ -10,6 +11,9 @@ def __init__(self):
'Top5': Top5,
'DetectionMetrics': DetectionMetrics,
'IoU': IoU,
"BinaryIOU": BinaryIOU,
"Dice": Dice,
"BinaryDice": BinaryDice,
'PixelAccuracy': PixelAccuracy,
}
super().__init__(type_dict)
173 changes: 154 additions & 19 deletions src/super_gradients/training/metrics/segmentation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import torch
import torchmetrics
from torchmetrics import Metric
from typing import Optional, Tuple
from torchmetrics.utilities.distributed import reduce
from abc import ABC, abstractmethod


def batch_pix_accuracy(predict, target):
Expand Down Expand Up @@ -59,6 +62,51 @@ def pixel_accuracy(im_pred, im_lab):
return pixel_correct, pixel_labeled


def _dice_from_confmat(
confmat: torch.Tensor,
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = "elementwise_mean",
) -> torch.Tensor:
"""Computes Dice coefficient from confusion matrix.
Args:
confmat: Confusion matrix without normalization
num_classes: Number of classes for a given prediction and target tensor
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method.
absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
AND no instances of the class index were present in `target`.
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.0

intersection = torch.diag(confmat)
denominator = confmat.sum(0) + confmat.sum(1)

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = 2 * intersection.float() / denominator.float()
scores[denominator == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1:],
]
)

return reduce(scores, reduction=reduction)


def intersection_and_union(im_pred, im_lab, num_class):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
Expand All @@ -77,17 +125,60 @@ def intersection_and_union(im_pred, im_lab, num_class):
return area_inter, area_union


class AbstractMetricsArgsPrepFn(ABC):
"""
Abstract preprocess metrics arguments class.
"""
@abstractmethod
def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
All base classes must implement this function and return a tuple of torch tensors (predictions, target).
"""
raise NotImplementedError()


class PreprocessSegmentationMetricsArgs(AbstractMetricsArgsPrepFn):
"""
Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and
apply normalizations.
"""
def __init__(self,
apply_arg_max: bool = False,
apply_sigmoid: bool = False):
"""
:param apply_arg_max: Whether to apply argmax on predictions tensor.
:param apply_sigmoid: Whether to apply sigmoid on predictions tensor.
"""
self.apply_arg_max = apply_arg_max
self.apply_sigmoid = apply_sigmoid

def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
if isinstance(preds, (tuple, list)):
preds = preds[0]
if self.apply_arg_max:
_, preds = torch.max(preds, 1)
elif self.apply_sigmoid:
preds = torch.sigmoid(preds)

target = target.long()
return preds, target


class PixelAccuracy(Metric):
def __init__(self, ignore_label=-100, dist_sync_on_step=False):
def __init__(self,
ignore_label=-100,
dist_sync_on_step=False,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.ignore_label = ignore_label
self.add_state("total_correct", default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum")
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

def update(self, preds: torch.Tensor, target: torch.Tensor):
if isinstance(preds, tuple):
preds = preds[0]
_, predict = torch.max(preds, 1)
predict, target = self.metrics_args_prep_fn(preds, target)

labeled_mask = target.ne(self.ignore_label)
pixel_labeled = torch.sum(labeled_mask)
pixel_correct = torch.sum((predict == target) * labeled_mask)
Expand All @@ -101,29 +192,73 @@ def compute(self):
return pix_acc


class IoU(torchmetrics.IoU):
def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index)
class IoU(torchmetrics.JaccardIndex):
def __init__(self,
num_classes: int,
dist_sync_on_step: bool = False,
ignore_index: Optional[int] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction=reduction, threshold=threshold)
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

def update(self, preds, target: torch.Tensor):
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
if isinstance(preds, tuple):
preds = preds[0]
_, preds = torch.max(preds, 1)
preds, target = self.metrics_args_prep_fn(preds, target)
super().update(preds=preds, target=target)


class BinaryIOU(torchmetrics.IoU):
def __init__(self, dist_sync_on_step=True, ignore_index=None):
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction="none", threshold=0.5)
self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]
class Dice(torchmetrics.JaccardIndex):
def __init__(self,
num_classes: int,
dist_sync_on_step: bool = False,
ignore_index: Optional[int] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction=reduction, threshold=threshold)
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

def update(self, preds, target: torch.Tensor):
# WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
if isinstance(preds, tuple):
preds = preds[0]
super().update(preds=torch.sigmoid(preds), target=target.long())
preds, target = self.metrics_args_prep_fn(preds, target)
super().update(preds=preds, target=target)

def compute(self) -> torch.Tensor:
"""Computes Dice coefficient"""
return _dice_from_confmat(
self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
)


class BinaryIOU(IoU):
def __init__(self,
dist_sync_on_step=True,
ignore_index: Optional[int] = None,
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True)
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn)
self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]

def compute(self):
ious = super(BinaryIOU, self).compute()
return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()}


class BinaryDice(Dice):
def __init__(self,
dist_sync_on_step=True,
ignore_index: Optional[int] = None,
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True)
super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn)
self.component_names = ["target_Dice", "background_Dice", "mean_Dice"]

def compute(self):
dices = super().compute()
return {"target_Dice": dices[1], "background_Dice": dices[0], "mean_Dice": dices.mean()}

0 comments on commit 018bd0c

Please sign in to comment.