Skip to content

Commit

Permalink
fix mistake from earlier PR
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored and Borda committed Sep 13, 2022
1 parent b16607c commit 1fa37a4
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,15 +700,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model (probabilities, logits or labels)
target: Ground truth values
"""
if self.task is not None:
if self.task == "binary":
BinaryStatScores.update(self, preds, target)
elif self.task == "multiclass":
MulticlassStatScores.update(self, preds, target)
elif self.task == "multilabel":
MultilabelStatScores.update(self, preds, target)
return

tp, fp, tn, fn = _stat_scores_update(
preds,
target,
Expand Down Expand Up @@ -772,12 +763,5 @@ def compute(self) -> Tensor:
- If ``reduce='macro'``, the shape will be ``(N, C, 5)``
- If ``reduce='samples'``, the shape will be ``(N, X, 5)``
"""
if self.task is not None:
if self.task == "binary":
return BinaryStatScores.compute(self)
elif self.task == "multiclass":
return MulticlassStatScores.compute(self)
elif self.task == "multilabel":
return MultilabelStatScores.compute(self)
tp, fp, tn, fn = self._get_final_stats()
return _stat_scores_compute(tp, fp, tn, fn)

0 comments on commit 1fa37a4

Please sign in to comment.