diff --git a/src/super_gradients/common/factories/metrics_factory.py b/src/super_gradients/common/factories/metrics_factory.py
index 2814189f12..77003ec89b 100644
--- a/src/super_gradients/common/factories/metrics_factory.py
+++ b/src/super_gradients/common/factories/metrics_factory.py
@@ -1,5 +1,6 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy
+from super_gradients.training.metrics import (Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,
+                                              BinaryDice)
 
 
 class MetricsFactory(BaseFactory):
@@ -10,6 +11,9 @@ def __init__(self):
             'Top5': Top5,
             'DetectionMetrics': DetectionMetrics,
             'IoU': IoU,
+            'BinaryIOU': BinaryIOU,
+            'Dice': Dice,
+            'BinaryDice': BinaryDice,
             'PixelAccuracy': PixelAccuracy,
         }
         super().__init__(type_dict)
diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py
index 5b1a3a87c7..3cc5a2ebef 100755
--- a/src/super_gradients/training/metrics/segmentation_metrics.py
+++ b/src/super_gradients/training/metrics/segmentation_metrics.py
@@ -2,6 +2,9 @@
 import torch
 import torchmetrics
 from torchmetrics import Metric
+from typing import Optional, Tuple
+from torchmetrics.utilities.distributed import reduce
+from abc import ABC, abstractmethod
 
 
 def batch_pix_accuracy(predict, target):
@@ -59,6 +62,57 @@ def pixel_accuracy(im_pred, im_lab):
     return pixel_correct, pixel_labeled
 
 
+def _dice_from_confmat(
+    confmat: torch.Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+    absent_score: float = 0.0,
+    reduction: str = "elementwise_mean",
+) -> torch.Tensor:
+    """Computes Dice coefficient from confusion matrix.
+
+    Args:
+        confmat: Confusion matrix without normalization
+        num_classes: Number of classes for a given prediction and target tensor
+        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method.
+        absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
+            AND no instances of the class index were present in `target`.
+        reduction: a method to reduce metric score over labels.
+
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied
+
+    Returns:
+        The per-class Dice scores (with the ignored class removed), reduced according to `reduction`.
+    """
+    ignore_class = ignore_index is not None and 0 <= ignore_index < num_classes
+
+    # Remove the ignored class index from the scores. Work on a copy so the caller's confusion matrix (typically the
+    # accumulated metric state) is not modified as a side effect of computing the score.
+    if ignore_class:
+        confmat = confmat.clone()
+        confmat[ignore_index] = 0.0
+
+    intersection = torch.diag(confmat)
+    denominator = confmat.sum(0) + confmat.sum(1)
+
+    # If this class is absent in both target AND pred (denominator == 0), then use the absent_score for this class.
+    scores = 2 * intersection.float() / denominator.float()
+    scores[denominator == 0] = absent_score
+
+    if ignore_class:
+        scores = torch.cat(
+            [
+                scores[:ignore_index],
+                scores[ignore_index + 1:],
+            ]
+        )
+
+    return reduce(scores, reduction=reduction)
+
+
 def intersection_and_union(im_pred, im_lab, num_class):
     im_pred = np.asarray(im_pred)
     im_lab = np.asarray(im_lab)
@@ -77,17 +131,60 @@ def intersection_and_union(im_pred, im_lab, num_class):
     return area_inter, area_union
 
 
+class AbstractMetricsArgsPrepFn(ABC):
+    """
+    Abstract preprocess metrics arguments class.
+    """
+    @abstractmethod
+    def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        All subclasses must implement this function and return a tuple of torch tensors (predictions, target).
+        """
+        raise NotImplementedError()
+
+
+class PreprocessSegmentationMetricsArgs(AbstractMetricsArgsPrepFn):
+    """
+    Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and
+    applies normalizations.
+    """
+    def __init__(self,
+                 apply_arg_max: bool = False,
+                 apply_sigmoid: bool = False):
+        """
+        :param apply_arg_max: Whether to apply argmax on predictions tensor.
+        :param apply_sigmoid: Whether to apply sigmoid on predictions tensor. Ignored when apply_arg_max is True.
+        """
+        self.apply_arg_max = apply_arg_max
+        self.apply_sigmoid = apply_sigmoid
+
+    def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
+        if isinstance(preds, (tuple, list)):
+            preds = preds[0]
+        if self.apply_arg_max:
+            _, preds = torch.max(preds, 1)
+        elif self.apply_sigmoid:
+            preds = torch.sigmoid(preds)
+
+        target = target.long()
+        return preds, target
+
+
 class PixelAccuracy(Metric):
-    def __init__(self, ignore_label=-100, dist_sync_on_step=False):
+    def __init__(self,
+                 ignore_label: int = -100,
+                 dist_sync_on_step: bool = False,
+                 metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
         self.ignore_label = ignore_label
         self.add_state("total_correct", default=torch.tensor(0.), dist_reduce_fx="sum")
         self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
 
     def update(self, preds: torch.Tensor, target: torch.Tensor):
-        if isinstance(preds, tuple):
-            preds = preds[0]
-        _, predict = torch.max(preds, 1)
+        predict, target = self.metrics_args_prep_fn(preds, target)
+
         labeled_mask = target.ne(self.ignore_label)
         pixel_labeled = torch.sum(labeled_mask)
         pixel_correct = torch.sum((predict == target) * labeled_mask)
@@ -101,29 +198,88 @@ def compute(self):
         return pix_acc
 
 
-class IoU(torchmetrics.IoU):
-    def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None):
-        super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index)
+class IoU(torchmetrics.JaccardIndex):
+    """
+    Jaccard index (IoU) that passes predictions and target through metrics_args_prep_fn before every update.
+    When no metrics_args_prep_fn is given, the first output is taken and arg-maxed over dim 1.
+    """
+    def __init__(self,
+                 num_classes: int,
+                 dist_sync_on_step: bool = False,
+                 ignore_index: Optional[int] = None,
+                 reduction: str = "elementwise_mean",
+                 threshold: float = 0.5,
+                 metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
+        super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
+                         reduction=reduction, threshold=threshold)
+        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
 
     def update(self, preds, target: torch.Tensor):
-        # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
-        if isinstance(preds, tuple):
-            preds = preds[0]
-        _, preds = torch.max(preds, 1)
+        preds, target = self.metrics_args_prep_fn(preds, target)
         super().update(preds=preds, target=target)
 
 
-class BinaryIOU(torchmetrics.IoU):
-    def __init__(self, dist_sync_on_step=True, ignore_index=None):
-        super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction="none", threshold=0.5)
-        self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]
+class Dice(torchmetrics.JaccardIndex):
+    """
+    Dice coefficient. Subclasses JaccardIndex to reuse its confusion-matrix accumulation; only compute() differs.
+    """
+    def __init__(self,
+                 num_classes: int,
+                 dist_sync_on_step: bool = False,
+                 ignore_index: Optional[int] = None,
+                 reduction: str = "elementwise_mean",
+                 threshold: float = 0.5,
+                 metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
+        super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
+                         reduction=reduction, threshold=threshold)
+        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
 
     def update(self, preds, target: torch.Tensor):
-        # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
-        if isinstance(preds, tuple):
-            preds = preds[0]
-        super().update(preds=torch.sigmoid(preds), target=target.long())
+        preds, target = self.metrics_args_prep_fn(preds, target)
+        super().update(preds=preds, target=target)
+
+    def compute(self) -> torch.Tensor:
+        """Computes Dice coefficient"""
+        return _dice_from_confmat(
+            self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
+        )
+
+
+class BinaryIOU(IoU):
+    """
+    Two-class IoU that reports the target (class 1), background (class 0) and mean scores.
+    When no metrics_args_prep_fn is given, sigmoid is applied to the first output.
+    """
+    def __init__(self,
+                 dist_sync_on_step: bool = True,
+                 ignore_index: Optional[int] = None,
+                 threshold: float = 0.5,
+                 metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
+        metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True)
+        super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
+                         reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn)
+        self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]
 
     def compute(self):
         ious = super(BinaryIOU, self).compute()
         return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()}
+
+
+class BinaryDice(Dice):
+    """
+    Two-class Dice that reports the target (class 1), background (class 0) and mean scores.
+    When no metrics_args_prep_fn is given, sigmoid is applied to the first output.
+    """
+    def __init__(self,
+                 dist_sync_on_step: bool = True,
+                 ignore_index: Optional[int] = None,
+                 threshold: float = 0.5,
+                 metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
+        metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True)
+        super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index,
+                         reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn)
+        self.component_names = ["target_Dice", "background_Dice", "mean_Dice"]
+
+    def compute(self):
+        dices = super().compute()
+        return {"target_Dice": dices[1], "background_Dice": dices[0], "mean_Dice": dices.mean()}