From 3082e83f07a698666e678a7b2f489099553f081b Mon Sep 17 00:00:00 2001 From: liork Date: Sun, 17 Apr 2022 10:35:10 +0300 Subject: [PATCH 1/3] Add dice metrics --- .../common/factories/metrics_factory.py | 6 +- .../training/metrics/segmentation_metrics.py | 141 ++++++++++++++++-- 2 files changed, 132 insertions(+), 15 deletions(-) diff --git a/src/super_gradients/common/factories/metrics_factory.py b/src/super_gradients/common/factories/metrics_factory.py index 2814189f12..77003ec89b 100644 --- a/src/super_gradients/common/factories/metrics_factory.py +++ b/src/super_gradients/common/factories/metrics_factory.py @@ -1,5 +1,6 @@ from super_gradients.common.factories.base_factory import BaseFactory -from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy +from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,\ + BinaryDice class MetricsFactory(BaseFactory): @@ -10,6 +11,9 @@ def __init__(self): 'Top5': Top5, 'DetectionMetrics': DetectionMetrics, 'IoU': IoU, + "BinaryIOU": BinaryIOU, + "Dice": Dice, + "BinaryDice": BinaryDice, 'PixelAccuracy': PixelAccuracy, } super().__init__(type_dict) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 0e94a04218..eff331b1af 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -2,6 +2,8 @@ import torch import torchmetrics from torchmetrics import Metric +from typing import Optional +from torchmetrics.utilities.distributed import reduce def batch_pix_accuracy(predict, target): @@ -59,6 +61,51 @@ def pixel_accuracy(im_pred, im_lab): return pixel_correct, pixel_labeled +def _dice_from_confmat( + confmat: torch.Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + reduction: str = "elementwise_mean", +) -> torch.Tensor: + """Computes Dice coefficient from confusion matrix. + + Args: + confmat: Confusion matrix without normalization + num_classes: Number of classes for a given prediction and target tensor + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. + absent_score: score to use for an individual class, if no instances of the class index were present in `pred` + AND no instances of the class index were present in `target`. + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + """ + + # Remove the ignored class index from the scores. + if ignore_index is not None and 0 <= ignore_index < num_classes: + confmat[ignore_index] = 0.0 + + intersection = torch.diag(confmat) + denominator = confmat.sum(0) + confmat.sum(1) + + # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. 
+ scores = 2 * intersection.float() / denominator.float() + scores[denominator == 0] = absent_score + + if ignore_index is not None and 0 <= ignore_index < num_classes: + scores = torch.cat( + [ + scores[:ignore_index], + scores[ignore_index + 1 :], + ] + ) + + return reduce(scores, reduction=reduction) + + def intersection_and_union(im_pred, im_lab, num_class): im_pred = np.asarray(im_pred) im_lab = np.asarray(im_lab) @@ -77,6 +124,26 @@ def intersection_and_union(im_pred, im_lab, num_class): return area_inter, area_union +def _preprocess_segmentation_inputs(preds, target: torch.Tensor, + apply_arg_max: bool = False, + apply_sigmoid: bool = False): + """ + preprocess segmentation predictions and target before updating segmentation metrics, handles multiple inputs and + apply normalizations. + :param apply_arg_max: Whether to apply argmax on predictions tensor. + :param apply_sigmoid: Whether to apply sigmoid on predictions tensor. + """ + # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP + if isinstance(preds, (tuple, list)): + preds = preds[0] + if apply_arg_max: + _, preds = torch.max(preds, 1) + elif apply_sigmoid: + preds = torch.sigmoid(preds) + + return preds, target + + class PixelAccuracy(Metric): def __init__(self, ignore_label=-100, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) @@ -85,9 +152,8 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False): self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): - if isinstance(preds, tuple): - preds = preds[0] - _, predict = torch.max(preds, 1) + predict, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) + labeled_mask = target.ne(self.ignore_label) pixel_labeled = torch.sum(labeled_mask) pixel_correct = torch.sum((predict == target) * labeled_mask) @@ -101,26 +167,73 @@ def compute(self): return pix_acc -class IoU(torchmetrics.IoU): - def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None): - super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index) +class IoU(torchmetrics.JaccardIndex): + def __init__(self, + num_classes: int, + dist_sync_on_step: bool = True, + ignore_index: Optional[int] = None, + reduction: str = "elementwise_mean", + threshold: float = 0.5): + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, + reduction=reduction, threshold=threshold) def update(self, preds, target: torch.Tensor): - # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP - if isinstance(preds, tuple): - preds = preds[0] - _, preds = torch.max(preds, 1) + preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) super().update(preds=preds, target=target) -class BinaryIOU(torchmetrics.IoU): - def __init__(self, dist_sync_on_step=True, ignore_index=None): - super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction="none", threshold=0.5) +class BinaryIOU(IoU): + def __init__(self, + dist_sync_on_step=True, + ignore_index: Optional[int] = None, + threshold: float = 0.5): + super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, + reduction="none", threshold=threshold) self.component_names = ["target_IOU", "background_IOU", "mean_IOU"] def update(self, preds, target: torch.Tensor): - 
super().update(preds=torch.sigmoid(preds), target=target.long()) + preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True) + super(IoU, self).update(preds=preds, target=target.long()) def compute(self): ious = super(BinaryIOU, self).compute() return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()} + + +class Dice(torchmetrics.JaccardIndex): + def __init__(self, + num_classes: int, + dist_sync_on_step: bool = True, + ignore_index: Optional[int] = None, + reduction: str = "elementwise_mean", + threshold: float = 0.5): + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, + reduction=reduction, threshold=threshold) + + def update(self, preds, target: torch.Tensor): + preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) + super().update(preds=preds, target=target) + + def compute(self) -> torch.Tensor: + """Computes Dice coefficient""" + return _dice_from_confmat( + self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction + ) + + +class BinaryDice(Dice): + def __init__(self, + dist_sync_on_step=True, + ignore_index: Optional[int] = None, + threshold: float = 0.5): + super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, + reduction="none", threshold=threshold) + self.component_names = ["target_Dice", "background_Dice", "mean_Dice"] + + def update(self, preds, target: torch.Tensor): + preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True) + super(Dice, self).update(preds=preds, target=target.long()) + + def compute(self): + dices = super().compute() + return {"target_Dice": dices[1], "background_Dice": dices[0], "mean_Dice": dices.mean()} From 401b2476019041cad407847e1ce902fe8758c64a Mon Sep 17 00:00:00 2001 From: liork Date: Sun, 17 Apr 2022 10:35:48 +0300 Subject: [PATCH 2/3] lint --- src/super_gradients/training/metrics/segmentation_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index eff331b1af..398390379f 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -99,7 +99,7 @@ def _dice_from_confmat( scores = torch.cat( [ scores[:ignore_index], - scores[ignore_index + 1 :], + scores[ignore_index + 1:], ] ) From 9c05c2dbc5b40d871035c00f61ba52538e336f2c Mon Sep 17 00:00:00 2001 From: liork Date: Tue, 19 Apr 2022 16:39:08 +0300 Subject: [PATCH 3/3] Add MetricsArgsPrepFn --- .../training/metrics/segmentation_metrics.py | 121 +++++++++++------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 398390379f..3cc5a2ebef 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -2,8 +2,9 @@ import torch import torchmetrics from torchmetrics import Metric -from typing import Optional +from typing import Optional, Tuple from torchmetrics.utilities.distributed import reduce +from abc import ABC, abstractmethod def batch_pix_accuracy(predict, target): @@ -124,35 +125,59 @@ def intersection_and_union(im_pred, im_lab, num_class): return area_inter, area_union -def _preprocess_segmentation_inputs(preds, target: 
torch.Tensor, - apply_arg_max: bool = False, - apply_sigmoid: bool = False): +class AbstractMetricsArgsPrepFn(ABC): """ - preprocess segmentation predictions and target before updating segmentation metrics, handles multiple inputs and - apply normalizations. - :param apply_arg_max: Whether to apply argmax on predictions tensor. - :param apply_sigmoid: Whether to apply sigmoid on predictions tensor. + Abstract preprocess metrics arguments class. """ - # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP - if isinstance(preds, (tuple, list)): - preds = preds[0] - if apply_arg_max: - _, preds = torch.max(preds, 1) - elif apply_sigmoid: - preds = torch.sigmoid(preds) + @abstractmethod + def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + All base classes must implement this function and return a tuple of torch tensors (predictions, target). + """ + raise NotImplementedError() - return preds, target + +class PreprocessSegmentationMetricsArgs(AbstractMetricsArgsPrepFn): + """ + Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and + apply normalizations. + """ + def __init__(self, + apply_arg_max: bool = False, + apply_sigmoid: bool = False): + """ + :param apply_arg_max: Whether to apply argmax on predictions tensor. + :param apply_sigmoid: Whether to apply sigmoid on predictions tensor. + """ + self.apply_arg_max = apply_arg_max + self.apply_sigmoid = apply_sigmoid + + def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP + if isinstance(preds, (tuple, list)): + preds = preds[0] + if self.apply_arg_max: + _, preds = torch.max(preds, 1) + elif self.apply_sigmoid: + preds = torch.sigmoid(preds) + + target = target.long() + return preds, target class PixelAccuracy(Metric): - def __init__(self, ignore_label=-100, dist_sync_on_step=False): + def __init__(self, + ignore_label=-100, + dist_sync_on_step=False, + metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_label = ignore_label self.add_state("total_correct", default=torch.tensor(0.), dist_reduce_fx="sum") self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum") + self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) def update(self, preds: torch.Tensor, target: torch.Tensor): - predict, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) + predict, target = self.metrics_args_prep_fn(preds, target) labeled_mask = target.ne(self.ignore_label) pixel_labeled = torch.sum(labeled_mask) @@ -170,48 +195,34 @@ def compute(self): class IoU(torchmetrics.JaccardIndex): def __init__(self, num_classes: int, - dist_sync_on_step: bool = True, + dist_sync_on_step: bool = False, ignore_index: Optional[int] = None, reduction: str = "elementwise_mean", - threshold: float = 0.5): + threshold: float = 0.5, + metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) def update(self, preds, target: torch.Tensor): - preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) + 
preds, target = self.metrics_args_prep_fn(preds, target) super().update(preds=preds, target=target) -class BinaryIOU(IoU): - def __init__(self, - dist_sync_on_step=True, - ignore_index: Optional[int] = None, - threshold: float = 0.5): - super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, - reduction="none", threshold=threshold) - self.component_names = ["target_IOU", "background_IOU", "mean_IOU"] - - def update(self, preds, target: torch.Tensor): - preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True) - super(IoU, self).update(preds=preds, target=target.long()) - - def compute(self): - ious = super(BinaryIOU, self).compute() - return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()} - - class Dice(torchmetrics.JaccardIndex): def __init__(self, num_classes: int, - dist_sync_on_step: bool = True, + dist_sync_on_step: bool = False, ignore_index: Optional[int] = None, reduction: str = "elementwise_mean", - threshold: float = 0.5): + threshold: float = 0.5, + metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) def update(self, preds, target: torch.Tensor): - preds, target = _preprocess_segmentation_inputs(preds, target, apply_arg_max=True) + preds, target = self.metrics_args_prep_fn(preds, target) super().update(preds=preds, target=target) def compute(self) -> torch.Tensor: @@ -221,19 +232,33 @@ def compute(self) -> torch.Tensor: ) +class BinaryIOU(IoU): + def __init__(self, + dist_sync_on_step=True, + ignore_index: Optional[int] = None, + threshold: float = 0.5, + metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): + metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True) + super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, + reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn) + self.component_names = ["target_IOU", "background_IOU", "mean_IOU"] + + def compute(self): + ious = super(BinaryIOU, self).compute() + return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()} + + class BinaryDice(Dice): def __init__(self, dist_sync_on_step=True, ignore_index: Optional[int] = None, - threshold: float = 0.5): + threshold: float = 0.5, + metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): + metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True) super().__init__(num_classes=2, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, - reduction="none", threshold=threshold) + reduction="none", threshold=threshold, metrics_args_prep_fn=metrics_args_prep_fn) self.component_names = ["target_Dice", "background_Dice", "mean_Dice"] - def update(self, preds, target: torch.Tensor): - preds, target = _preprocess_segmentation_inputs(preds, target, apply_sigmoid=True) - super(Dice, self).update(preds=preds, target=target.long()) - def compute(self): dices = super().compute() return {"target_Dice": dices[1], "background_Dice": dices[0], "mean_Dice": dices.mean()}
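
A minimal usage sketch of the new Dice / BinaryDice metrics and the injectable prep fn introduced by these patches. It assumes multi-class logits of shape [B, C, H, W] with an integer mask of shape [B, H, W], and binary logits and mask of matching shape; the shapes, the variable names, and the DictOutputPrepFn class below are illustrative assumptions, not part of the patches themselves.

    import torch

    from super_gradients.training.metrics.segmentation_metrics import (
        AbstractMetricsArgsPrepFn,
        BinaryDice,
        Dice,
    )

    # Multi-class Dice: raw logits [B, C, H, W] and an integer mask [B, H, W].
    # The default prep fn (PreprocessSegmentationMetricsArgs(apply_arg_max=True))
    # takes the argmax over the class dimension before updating the confusion matrix.
    num_classes = 3
    logits = torch.randn(4, num_classes, 64, 64)
    mask = torch.randint(0, num_classes, (4, 64, 64))

    dice = Dice(num_classes=num_classes)      # registered in MetricsFactory as 'Dice'
    dice.update(logits, mask)
    mean_dice = dice.compute()                # scalar mean Dice over the 3 classes

    # Binary Dice: single-channel logits and a {0, 1} mask of the same shape.
    # The default prep fn applies sigmoid; JaccardIndex then thresholds at 0.5.
    bin_logits = torch.randn(4, 1, 64, 64)
    bin_mask = torch.randint(0, 2, (4, 1, 64, 64))

    binary_dice = BinaryDice()                # registered as 'BinaryDice'
    binary_dice.update(bin_logits, bin_mask)
    per_class = binary_dice.compute()         # {"target_Dice": ..., "background_Dice": ..., "mean_Dice": ...}

    # Hypothetical custom prep fn: lets the same metric consume a model whose
    # forward returns a dict of heads instead of a tensor or tuple.
    class DictOutputPrepFn(AbstractMetricsArgsPrepFn):
        def __call__(self, preds, target):
            preds = preds["segmentation"]     # pick the main segmentation head (illustrative key)
            _, preds = torch.max(preds, 1)    # argmax over the class dimension
            return preds, target.long()

    dice_from_dict = Dice(num_classes=num_classes, metrics_args_prep_fn=DictOutputPrepFn())

Injecting the prep fn through the constructor keeps a single IoU/Dice implementation usable with models whose forward returns a tuple, a list, or some other container, without subclassing the metric for each output format.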