Commit
Add Distance Jaccard Loss function
gagewrye committed Nov 15, 2024
1 parent 49f9634 commit 300a809
Showing 2 changed files with 80 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Drone Classification/models/__init__.py
@@ -1,2 +1,2 @@
-from .models import ResNet_UNet, ResNet_FC, SegmentModelWrapper
-from .loss import JaccardLoss, FocalLoss
+from .models import ResNet_UNet, ResNet_FC, DenseNet_UNet, SegmentModelWrapper
+from .loss import JaccardLoss, FocalLoss, DistanceCountLoss
79 changes: 78 additions & 1 deletion Drone Classification/models/loss.py
@@ -67,6 +67,81 @@ def forward(self, y_pred : torch.tensor, y_true: torch.tensor):
# Return the Jaccard loss (1 - Jaccard index)
return 1 - jaccard_index

# This loss measures how close each predicted positive pixel is to a true positive.
# The total positive count ties the penalty to mangrove extent (overall size is
# what we care about when monitoring health). It also incorporates Jaccard loss
# to improve IoU.
class DistanceCountLoss(nn.Module):
def __init__(self, smooth=1e-10, weight_jaccard=0.9, weight_distance_count=0.1):
super(DistanceCountLoss, self).__init__()
        self.max_samples = 100  # cap on the number of points used in the Chamfer term
self.smooth = smooth
self.weight_jaccard = weight_jaccard
self.weight_distance_count = weight_distance_count

def jaccard_loss(self, y_pred, y_true):
# Apply sigmoid to predictions to get probabilities
y_pred = torch.sigmoid(y_pred)

# Flatten the tensors
y_pred = y_pred.view(-1)
y_true = y_true.view(-1)

# Calculate intersection and union
intersection = (y_pred * y_true).sum()
union = y_pred.sum() + y_true.sum() - intersection

# Jaccard index and loss
jaccard_index = (intersection + self.smooth) / (union + self.smooth)
return 1 - jaccard_index
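        # i.e. soft Jaccard J = (sum(P * T) + s) / (sum(P) + sum(T) - sum(P * T) + s),
        # with probabilities P, binary targets T, and smoothing s; the loss is 1 - J.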

def chamfer_distance(self, pred_points, label_points):
# Check if either set is empty
if pred_points.shape[0] == 0:
return label_points.shape[0] # Penalize based on the count of positives in label
if label_points.shape[0] == 0:
return pred_points.shape[0] # Penalize based on the count of positives in prediction

        # Nearest-neighbour distances in both directions (pred -> label and label -> pred)
dist_pred_to_label = torch.cdist(pred_points, label_points).min(dim=1)[0]
dist_label_to_pred = torch.cdist(label_points, pred_points).min(dim=1)[0]

        # Chamfer distance: sum of the mean nearest-neighbour distance in each direction
        return dist_pred_to_label.mean() + dist_label_to_pred.mean()
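        # Worked example with illustrative values: for pred_points = [[0., 0.]] and
        # label_points = [[0., 3.], [4., 0.]], the pred -> label term is 3.0 and the
        # label -> pred term is (3.0 + 4.0) / 2 = 3.5, giving a distance of 6.5.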

def sample_points(self, points):
# Randomly sample points if there are more than max_samples
if points.shape[0] > self.max_samples:
indices = torch.randperm(points.shape[0])[:self.max_samples]
points = points[indices]
return points

def forward(self, y_pred, y_true):
# Calculate Jaccard loss
jaccard_loss_value = self.jaccard_loss(y_pred, y_true)

        # Threshold to obtain positive pixel coordinates in predictions and labels.
        # y_pred holds raw logits, so convert to probabilities before thresholding.
        pred_positives = self.sample_points((torch.sigmoid(y_pred) > 0.5).nonzero(as_tuple=False).float())
        label_positives = self.sample_points((y_true > 0.5).nonzero(as_tuple=False).float())

# Count of positive pixels in label
label_count = label_positives.shape[0]

# If there are no positives in the labels, avoid distance calculation
if label_count == 0:
distance_loss = pred_positives.shape[0] # Penalize only false positives
else:
# Calculate distance loss using Chamfer Distance
distance_loss = self.chamfer_distance(pred_positives, label_positives)

        # Combined loss. Note that gradients flow only through the Jaccard term;
        # the distance/count term is built from thresholded indices, so it is
        # non-differentiable and shapes the loss value rather than the gradient.
total_loss = (
self.weight_jaccard * jaccard_loss_value
+ self.weight_distance_count * distance_loss
)

return total_loss

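# Minimal usage sketch (illustrative, not part of the committed module), assuming
# raw logits and a binary mask of shape (B, 1, H, W):
#
#     criterion = DistanceCountLoss()
#     logits = torch.randn(2, 1, 64, 64, requires_grad=True)  # raw model outputs
#     mask = torch.randint(0, 2, (2, 1, 64, 64)).float()      # ground-truth mask
#     loss = criterion(logits, mask)
#     loss.backward()  # gradient comes from the Jaccard term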

class FocalLoss(nn.Module):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
@@ -118,4 +193,6 @@ def forward(self, inputs, targets):
            raise ValueError(
                f"Invalid value for arg 'reduction': '{self.reduction}'\n"
                "Supported reduction modes: 'none', 'mean', 'sum'"
            )
        return loss

