diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 6ba2e0a7354..f683618d234 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -34,11 +34,10 @@ def _recall_at_precision( except ValueError: max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) + best_threshold = torch.tensor(0) if max_recall == 0.0: best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype) - else: - best_threshold = torch.tensor(0) return max_recall, best_threshold @@ -128,7 +127,7 @@ class BinnedPrecisionRecallCurve(Metric): def __init__( self, num_classes: int, - thresholds: Optional[Union[Tensor, List[float]]] = None, + thresholds: Union[float, Tensor, List[float], None] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -187,7 +186,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.FPs[:, i] += ((~target) & (predictions)).sum(dim=0) self.FNs[:, i] += ((target) & (~predictions)).sum(dim=0) - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Returns float tensor of size n_classes""" precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) @@ -198,8 +197,8 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device) recalls = torch.cat([recalls, t_zeros], dim=1) if self.num_classes == 1: - return (precisions[0, :], recalls[0, :], self.thresholds) - return (list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)]) + return precisions[0, :], recalls[0, :], self.thresholds + return list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)] class BinnedAveragePrecision(BinnedPrecisionRecallCurve):