Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jul 13, 2021
1 parent 12a09ce commit e15338d
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ def _recall_at_precision(

except ValueError:
max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype)
best_threshold = torch.tensor(0)

if max_recall == 0.0:
best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
else:
best_threshold = torch.tensor(0)

return max_recall, best_threshold

Expand Down Expand Up @@ -128,7 +127,7 @@ class BinnedPrecisionRecallCurve(Metric):
def __init__(
self,
num_classes: int,
thresholds: Optional[Union[Tensor, List[float]]] = None,
thresholds: Union[float, Tensor, List[float], None] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
Expand Down Expand Up @@ -187,7 +186,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.FPs[:, i] += ((~target) & (predictions)).sum(dim=0)
self.FNs[:, i] += ((target) & (~predictions)).sum(dim=0)

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Returns float tensor of size n_classes"""
precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS)
recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)
Expand All @@ -198,8 +197,8 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)
recalls = torch.cat([recalls, t_zeros], dim=1)
if self.num_classes == 1:
return (precisions[0, :], recalls[0, :], self.thresholds)
return (list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)])
return precisions[0, :], recalls[0, :], self.thresholds
return list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)]


class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
Expand Down

0 comments on commit e15338d

Please sign in to comment.