Fix bug in macro average for a number of classification metrics (#1821)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
3 people authored Jun 15, 2023
1 parent 181e112 commit 17a4920
Showing 18 changed files with 228 additions and 83 deletions.
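The bug in a nutshell: with `average="macro"` (and `average="weighted"`), classes that never occur in either the predictions or the target (tp + fp + fn == 0) received a score of 0 and were still counted in the mean, deflating the macro average. The new `_adjust_weights_safe_divide` helper (see `src/torchmetrics/utilities/compute.py` below) gives such classes zero weight so they are excluded. A minimal sketch of the effect in plain PyTorch, with made-up per-class counts (not the torchmetrics API):

import torch

# Per-class true positives / false positives / false negatives for 4 classes;
# class 3 never appears in predictions or target (tp = fp = fn = 0).
tp = torch.tensor([5.0, 3.0, 2.0, 0.0])
fp = torch.tensor([1.0, 2.0, 0.0, 0.0])
fn = torch.tensor([0.0, 1.0, 3.0, 0.0])

recall = tp / (tp + fn).clamp(min=1)  # per-class recall, 0.0 for the absent class

# Old macro reduction: the absent class is averaged in with score 0
old_macro = recall.mean()  # (1.0 + 0.75 + 0.4 + 0.0) / 4 = 0.5375

# Fixed macro reduction: the absent class gets zero weight and is excluded
weights = torch.ones_like(recall)
weights[tp + fp + fn == 0] = 0.0
new_macro = (weights * recall).sum() / weights.sum()  # (1.0 + 0.75 + 0.4) / 3 ≈ 0.7167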
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -218,6 +218,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for half precision in `PearsonCorrCoef` ([#1819](https://github.com/Lightning-AI/torchmetrics/pull/1819))


+ - Fixed number of bugs related to `average="macro"` in classification metrics ([#1821](https://github.com/Lightning-AI/torchmetrics/pull/1821))


## [0.11.4] - 2023-03-10

### Fixed
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/f_beta.py
@@ -488,7 +488,9 @@ def __init__(
def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
- return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)
+ return _fbeta_reduce(
+ tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average, multilabel=True
+ )

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/precision_recall.py
@@ -405,7 +405,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
)

def plot(
@@ -819,7 +819,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
)

def plot(
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/specificity.py
@@ -383,7 +383,9 @@ class MultilabelSpecificity(MultilabelStatScores):
def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
- return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
+ return _specificity_reduce(
+ tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+ )

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
7 changes: 2 additions & 5 deletions src/torchmetrics/functional/classification/accuracy.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -83,10 +83,7 @@ def _accuracy_reduce(
return _safe_divide(tp, tp + fn)

score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
- if average is None or average == "none":
- return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_accuracy(
10 changes: 4 additions & 6 deletions src/torchmetrics/functional/classification/f_beta.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -43,6 +43,7 @@ def _fbeta_reduce(
beta: float,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
beta2 = beta**2
if average == "binary":
@@ -54,10 +55,7 @@
return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)

fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)
- if average is None or average == "none":
- return fbeta_score
- weights = tp + fn if average == "weighted" else torch.ones_like(fbeta_score)
- return _safe_divide(weights * fbeta_score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn)


def _binary_fbeta_score_arg_validation(
@@ -375,7 +373,7 @@ def multilabel_fbeta_score(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
- return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average)
+ return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, multilabel=True)


def binary_f1_score(
7 changes: 2 additions & 5 deletions src/torchmetrics/functional/classification/hamming.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -80,10 +80,7 @@ def _hamming_distance_reduce(
return 1 - _safe_divide(tp, tp + fn)

score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else 1 - _safe_divide(tp, tp + fn)
- if average is None or average == "none":
- return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_hamming_distance(
5 changes: 4 additions & 1 deletion src/torchmetrics/functional/classification/jaccard.py
@@ -66,7 +66,8 @@ def _jaccard_index_reduce(
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])

ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]
- if confmat.ndim == 3: # multilabel
+ multilabel = confmat.ndim == 3
+ if multilabel:
num = confmat[:, 1, 1]
denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0]
else: # multiclass
@@ -87,6 +88,8 @@
weights = torch.ones_like(jaccard)
if ignore_index_cond:
weights[ignore_index] = 0.0
+ if not multilabel:
+ weights[confmat.sum(1) + confmat.sum(0) == 0] = 0.0
return ((weights * jaccard) / weights.sum()).sum()


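For the Jaccard index the same problem came in through the confusion matrix: a class whose row and column are all zeros contributed a Jaccard score of 0 to the macro mean. With the change above its weight is zeroed instead. A small illustrative sketch in plain PyTorch (made-up confusion matrix, roughly mirroring the multiclass branch of `_jaccard_index_reduce`):

import torch

# 3x3 multiclass confusion matrix; class 2 never appears in predictions or target.
confmat = torch.tensor([[4.0, 1.0, 0.0],
                        [2.0, 3.0, 0.0],
                        [0.0, 0.0, 0.0]])

num = confmat.diag()
denom = confmat.sum(0) + confmat.sum(1) - confmat.diag()
jaccard = num / denom.clamp(min=1)  # per-class IoU, 0.0 for the absent class

old_macro = jaccard.mean()  # (0.571 + 0.5 + 0.0) / 3 ≈ 0.357

weights = torch.ones_like(jaccard)
weights[confmat.sum(1) + confmat.sum(0) == 0] = 0.0  # exclude absent classes
new_macro = (weights * jaccard).sum() / weights.sum()  # (0.571 + 0.5) / 2 ≈ 0.536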
16 changes: 9 additions & 7 deletions src/torchmetrics/functional/classification/precision_recall.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -43,6 +43,7 @@ def _precision_recall_reduce(
fn: Tensor,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
if average == "binary":
@@ -54,10 +55,7 @@
return _safe_divide(tp, tp + different_stat)

score = _safe_divide(tp, tp + different_stat)
- if average is None or average == "none":
- return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_precision(
@@ -336,7 +334,7 @@ def multilabel_precision(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True
)


def binary_recall(
@@ -615,7 +615,9 @@ def multilabel_recall(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True
)


def precision(
10 changes: 4 additions & 6 deletions src/torchmetrics/functional/classification/specificity.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -42,6 +42,7 @@ def _specificity_reduce(
fn: Tensor,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
if average == "binary":
return _safe_divide(tn, tn + fp)
@@ -51,10 +52,7 @@
return _safe_divide(tn, tn + fp)

specificity_score = _safe_divide(tn, tn + fp)
- if average is None or average == "none":
- return specificity_score
- weights = tp + fn if average == "weighted" else torch.ones_like(specificity_score)
- return _safe_divide(weights * specificity_score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn)


def binary_specificity(
@@ -333,7 +331,7 @@ def multilabel_specificity(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
- return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average)
+ return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True)


def specificity(
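Note that for specificity the change can move the macro value in the opposite direction: a class that never occurs in predictions or target has fp = 0 and tn equal to the total sample count, so its per-class specificity is 1.0; previously that 1.0 was averaged in, now the class is simply dropped like everywhere else. A quick sketch with made-up counts (10 samples, 3 classes, class 2 absent):

import torch

tp = torch.tensor([4.0, 3.0, 0.0])
fp = torch.tensor([2.0, 1.0, 0.0])
fn = torch.tensor([1.0, 2.0, 0.0])
tn = torch.tensor([3.0, 4.0, 10.0])

spec = tn / (tn + fp)  # per-class specificity: [0.6, 0.8, 1.0]

old_macro = spec.mean()  # 0.8, the absent class contributes a perfect 1.0

weights = torch.ones_like(spec)
weights[tp + fp + fn == 0] = 0.0
new_macro = (weights * spec).sum() / weights.sum()  # (0.6 + 0.8) / 2 = 0.7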
16 changes: 15 additions & 1 deletion src/torchmetrics/utilities/compute.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Tuple
+ from typing import Optional, Tuple

import torch
from torch import Tensor
@@ -55,6 +55,20 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
return num / denom


+ def _adjust_weights_safe_divide(
+ score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor
+ ) -> Tensor:
+ if average is None or average == "none":
+ return score
+ if average == "weighted":
+ weights = tp + fn
+ else:
+ weights = torch.ones_like(score)
+ if not multilabel:
+ weights[tp + fp + fn == 0] = 0.0
+ return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
"""Check that auc input is correct."""
x = x.squeeze() if x.ndim > 1 else x
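Finally, a usage sketch of the new helper itself, assuming a torchmetrics build that already contains this commit (`_adjust_weights_safe_divide` and `_safe_divide` are private utilities, imported here only for illustration):

import torch
from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide

# Toy per-class stats for 3 classes; class 2 is absent (tp = fp = fn = 0).
tp = torch.tensor([2.0, 1.0, 0.0])
fp = torch.tensor([1.0, 0.0, 0.0])
fn = torch.tensor([0.0, 2.0, 0.0])

precision = _safe_divide(tp, tp + fp)  # [0.6667, 1.0, 0.0]

# Multiclass-style macro: the absent class is excluded -> (0.6667 + 1.0) / 2 ≈ 0.8333
macro = _adjust_weights_safe_divide(precision, "macro", False, tp, fp, fn)

# Multilabel metrics pass multilabel=True, so every label keeps weight 1 -> mean over 3 labels ≈ 0.5556
macro_multilabel = _adjust_weights_safe_divide(precision, "macro", True, tp, fp, fn)

# average="weighted" uses the support (tp + fn) as weights -> (2 * 0.6667 + 3 * 1.0) / 5 ≈ 0.8667
weighted = _adjust_weights_safe_divide(precision, "weighted", False, tp, fp, fn)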