diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 77b7e9b1cb..b383805969 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -109,7 +109,7 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, predictions = predictions > threshold score = torch.tensor(0.0) - + m = 0 # Iterate for each image in the batch for i, target in enumerate(targets): unique_labels = torch.unique(target) @@ -137,6 +137,9 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) + m += 1 - # Calculate the mean score - return torch.mean(score) + # If there are only backgrounds + if m == 0: + return torch.tensor(1.0) + return score / m