From 300a80975bf6c1b3fa9bdf0679564031b1d5d888 Mon Sep 17 00:00:00 2001
From: gagewrye <95107220+gagewrye@users.noreply.github.com>
Date: Fri, 15 Nov 2024 13:17:16 -0800
Subject: [PATCH] Add Distance Jaccard Loss function

---
 Drone Classification/models/__init__.py |  4 +-
 Drone Classification/models/loss.py     | 79 ++++++++++++++++++++++++-
 2 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/Drone Classification/models/__init__.py b/Drone Classification/models/__init__.py
index 76acdcc..75ec990 100755
--- a/Drone Classification/models/__init__.py
+++ b/Drone Classification/models/__init__.py
@@ -1,2 +1,2 @@
-from .models import ResNet_UNet, ResNet_FC, SegmentModelWrapper
-from .loss import JaccardLoss, FocalLoss
\ No newline at end of file
+from .models import ResNet_UNet, ResNet_FC, DenseNet_UNet, SegmentModelWrapper
+from .loss import JaccardLoss, FocalLoss, DistanceCountLoss
\ No newline at end of file
diff --git a/Drone Classification/models/loss.py b/Drone Classification/models/loss.py
index 23e6fcb..34a7cdb 100644
--- a/Drone Classification/models/loss.py
+++ b/Drone Classification/models/loss.py
@@ -67,6 +67,81 @@ def forward(self, y_pred : torch.tensor, y_true: torch.tensor):
         # Return the Jaccard loss (1 - Jaccard index)
         return 1 - jaccard_index
 
+# This loss checks how close a positive pixel is to a true positive.
+# Total positive count is used to check accuracy to mangrove size (Overall size is what we care about to monitor health).
+# It also incorporates Jaccard loss to increase IOU
+class DistanceCountLoss(nn.Module):
+    def __init__(self, smooth=1e-10, weight_jaccard=0.9, weight_distance_count=0.1):
+        super(DistanceCountLoss, self).__init__()
+        self.max_samples = 100
+        self.smooth = smooth
+        self.weight_jaccard = weight_jaccard
+        self.weight_distance_count = weight_distance_count
+
+    def jaccard_loss(self, y_pred, y_true):
+        # Apply sigmoid to predictions to get probabilities
+        y_pred = torch.sigmoid(y_pred)
+
+        # Flatten the tensors
+        y_pred = y_pred.view(-1)
+        y_true = y_true.view(-1)
+
+        # Calculate intersection and union
+        intersection = (y_pred * y_true).sum()
+        union = y_pred.sum() + y_true.sum() - intersection
+
+        # Jaccard index and loss
+        jaccard_index = (intersection + self.smooth) / (union + self.smooth)
+        return 1 - jaccard_index
+
+    def chamfer_distance(self, pred_points, label_points):
+        # Check if either set is empty
+        if pred_points.shape[0] == 0:
+            return label_points.shape[0]  # Penalize based on the count of positives in label
+        if label_points.shape[0] == 0:
+            return pred_points.shape[0]  # Penalize based on the count of positives in prediction
+
+        # Compute distances only one way (pred -> label) at a time
+        dist_pred_to_label = torch.cdist(pred_points, label_points).min(dim=1)[0]
+        dist_label_to_pred = torch.cdist(label_points, pred_points).min(dim=1)[0]
+
+        # Chamfer distance is the mean of both sets' minimum distances
+        return dist_pred_to_label.mean() + dist_label_to_pred.mean()
+
+    def sample_points(self, points):
+        # Randomly sample points if there are more than max_samples
+        if points.shape[0] > self.max_samples:
+            indices = torch.randperm(points.shape[0])[:self.max_samples]
+            points = points[indices]
+        return points
+
+    def forward(self, y_pred, y_true):
+        # Calculate Jaccard loss
+        jaccard_loss_value = self.jaccard_loss(y_pred, y_true)
+
+        # Threshold to obtain positive pixels in predictions and labels
+        pred_positives = self.sample_points((y_pred > 0.5).nonzero(as_tuple=False).float())
+        label_positives = self.sample_points((y_true > 0.5).nonzero(as_tuple=False).float())
+
+        # Count of positive pixels in label
+        label_count = label_positives.shape[0]
+
+        # If there are no positives in the labels, avoid distance calculation
+        if label_count == 0:
+            distance_loss = pred_positives.shape[0]  # Penalize only false positives
+        else:
+            # Calculate distance loss using Chamfer Distance
+            distance_loss = self.chamfer_distance(pred_positives, label_positives)
+
+        # Combined loss
+        total_loss = (
+            self.weight_jaccard * jaccard_loss_value
+            + self.weight_distance_count * distance_loss
+        )
+
+        return total_loss
+
+
 class FocalLoss(nn.Module):
     """
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
@@ -118,4 +193,6 @@ def forward(self, inputs, targets):
             raise ValueError(
                 f"Invalid Value for arg 'reduction': '{self.reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
            )
-        return loss
\ No newline at end of file
+        return loss
+
+
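
Usage sketch (illustration only, not part of the patch). This assumes the models package is importable and that
DistanceCountLoss is called the same way as the existing JaccardLoss, i.e. on raw logits and binary masks; the
tensor shapes and variable names below are hypothetical.

    import torch
    from models import DistanceCountLoss

    criterion = DistanceCountLoss(weight_jaccard=0.9, weight_distance_count=0.1)

    logits = torch.randn(2, 1, 128, 128, requires_grad=True)   # raw model output
    masks = torch.randint(0, 2, (2, 1, 128, 128)).float()      # binary ground-truth masks

    # Combined loss: 0.9 * Jaccard loss + 0.1 * Chamfer-style distance/count term
    loss = criterion(logits, masks)
    loss.backward()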