diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 2bc6e35ff6..3434818adc 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -98,7 +98,7 @@ def compute(self) -> torch.Tensor: def spro_score( predictions: torch.Tensor, - targets: torch.Tensor, + targets: list[torch.Tensor], threshold: float = 0.5, saturation_config: dict | None = None, ) -> torch.Tensor: @@ -106,7 +106,8 @@ def spro_score( Args: predictions (torch.Tensor): Predicted anomaly masks. - targets: (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width + targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are