Proper multilabel support for confmat (#134)
* multilabel

* multilabel

* remove

* Update CHANGELOG.md

* Update torchmetrics/classification/confusion_matrix.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update tests/classification/test_confusion_matrix.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update torchmetrics/functional/classification/confusion_matrix.py

* fix docstring

* revert

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>
SkafteNicki and Borda authored Mar 29, 2021
1 parent 19b77cc commit d1af80f
Showing 4 changed files with 135 additions and 41 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -52,6 +52,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))


+- Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134))

### Deprecated


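For context on the behaviour change above: sklearn's `multilabel_confusion_matrix` returns one 2x2 matrix per label, stacked into an `(n_labels, 2, 2)` array, rather than a single flattened `n x n` matrix. A minimal standalone sketch of that convention (not part of the diff, using the same toy data as the docstring examples below):

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Two samples, three labels.
target = np.array([[0, 1, 0], [1, 0, 1]])
preds = np.array([[0, 0, 1], [1, 0, 1]])

cm = multilabel_confusion_matrix(target, preds)
print(cm.shape)  # (3, 2, 2) -- cm[i] is [[tn, fp], [fn, tp]] for label i
```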
61 changes: 42 additions & 19 deletions tests/classification/test_confusion_matrix.py
@@ -17,6 +17,7 @@
import pytest
import torch
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
+from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@@ -48,17 +49,35 @@ def _sk_cm_binary(preds, target, normalize=None):


def _sk_cm_multilabel_prob(preds, target, normalize=None):
-    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
-    sk_target = target.view(-1).numpy()
+    sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8)
+    sk_target = target.numpy()

-    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
+    cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
+    if normalize is not None:
+        if normalize == 'true':
+            cm = cm / cm.sum(axis=1, keepdims=True)
+        elif normalize == 'pred':
+            cm = cm / cm.sum(axis=0, keepdims=True)
+        elif normalize == 'all':
+            cm = cm / cm.sum()
+        cm[np.isnan(cm)] = 0
+    return cm


def _sk_cm_multilabel(preds, target, normalize=None):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
+    sk_preds = preds.numpy()
+    sk_target = target.numpy()

-    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
+    cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
+    if normalize is not None:
+        if normalize == 'true':
+            cm = cm / cm.sum(axis=1, keepdims=True)
+        elif normalize == 'pred':
+            cm = cm / cm.sum(axis=0, keepdims=True)
+        elif normalize == 'all':
+            cm = cm / cm.sum()
+        cm[np.isnan(cm)] = 0
+    return cm


def _sk_cm_multiclass_prob(preds, target, normalize=None):
@@ -91,21 +110,23 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None):

@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize(
-    "preds, target, sk_metric, num_classes",
-    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2),
-     (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2),
-     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2),
-     (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2),
-     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES),
-     (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES),
-     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES),
-     (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)]
+    "preds, target, sk_metric, num_classes, multilabel",
+    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False),
+     (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False),
+     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
+     (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True),
+     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
+     (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False),
+     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False),
+     (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False)]
)
class TestConfusionMatrix(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_confusion_matrix(
self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step
):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
@@ -116,11 +137,12 @@ def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
-                "normalize": normalize
+                "normalize": normalize,
+                "multilabel": multilabel
            }
        )

-    def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes):
+    def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel):
        self.run_functional_metric_test(
            preds,
            target,
@@ -129,7 +151,8 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
-                "normalize": normalize
+                "normalize": normalize,
+                "multilabel": multilabel
            }
        )

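A side note on the test helpers above: `sklearn.metrics.multilabel_confusion_matrix` has no `normalize` argument (unlike `confusion_matrix`), so the tests re-implement normalization and then zero out the `nan`s that 0/0 division produces for labels with an all-zero row or column. A minimal standalone sketch of that guard (illustrative values, not from the test suite):

```python
import numpy as np

cm = np.array([[[4., 0.], [0., 0.]]])  # a label that never occurs and is never predicted
with np.errstate(invalid="ignore"):    # the 0/0 warning is expected in this demo
    normalized = cm / cm.sum(axis=1, keepdims=True)
normalized[np.isnan(normalized)] = 0   # same nan guard as the helpers above
print(normalized)  # [[[1. 0.]
                   #   [0. 0.]]]
```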
46 changes: 36 additions & 10 deletions torchmetrics/classification/confusion_matrix.py
@@ -24,12 +24,8 @@ class ConfusionMatrix(Metric):
"""
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or
integer class values in prediction. Works with multi-dimensional preds and
target.
Note:
This metric produces a multi-dimensional output, so it can not be directly logged.
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.
Forward accepts
@@ -41,6 +37,10 @@ class ConfusionMatrix(Metric):
    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

+    If working with multilabel data, setting the `multilabel` argument to `True` will make sure that a
+    `confusion matrix gets calculated per label
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.

    Args:
        num_classes: Number of classes in the dataset.
        normalize: Normalization mode for confusion matrix. Choose from
@@ -52,6 +52,8 @@ class ConfusionMatrix(Metric):
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5
+        multilabel:
+            determines if data is multilabel or not.
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False. default: True
        dist_sync_on_step:
@@ -60,7 +62,7 @@ class ConfusionMatrix(Metric):
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
-    Example:
+    Example (binary data):
        >>> from torchmetrics import ConfusionMatrix
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
@@ -69,13 +71,31 @@
        tensor([[2., 0.],
                [1., 1.]])

+    Example (multiclass data):
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> confmat = ConfusionMatrix(num_classes=3)
+        >>> confmat(preds, target)
+        tensor([[1., 1., 0.],
+                [0., 1., 0.],
+                [0., 0., 1.]])
+
+    Example (multilabel data):
+        >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
+        >>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
+        >>> confmat(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        tensor([[[1., 0.], [0., 1.]],
+                [[1., 0.], [1., 0.]],
+                [[0., 1.], [0., 1.]]])
    """

    def __init__(
        self,
        num_classes: int,
        normalize: Optional[str] = None,
        threshold: float = 0.5,
+        multilabel: bool = False,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
@@ -89,12 +109,14 @@ def __init__(
        self.num_classes = num_classes
        self.normalize = normalize
        self.threshold = threshold
+        self.multilabel = multilabel

        allowed_normalize = ('true', 'pred', 'all', 'none', None)
        assert self.normalize in allowed_normalize, \
            f"Argument average needs to one of the following: {allowed_normalize}"

-        self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
+        default = torch.zeros(num_classes, 2, 2) if multilabel else torch.zeros(num_classes, num_classes)
+        self.add_state("confmat", default=default, dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        """
@@ -104,11 +126,15 @@ def update(self, preds: Tensor, target: Tensor):
            preds: Predictions from model
            target: Ground truth values
        """
-        confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
+        confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel)
        self.confmat += confmat

    def compute(self) -> Tensor:
        """
-        Computes confusion matrix
+        Computes confusion matrix.
+
+        Returns:
+            If `multilabel=False` this will be a `[n_classes, n_classes]` tensor and if `multilabel=True`
+            this will be a `[n_classes, 2, 2]` tensor
        """
        return _confusion_matrix_compute(self.confmat, self.normalize)
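With `multilabel=True` the metric thus returns a `[num_classes, 2, 2]` tensor where, following the sklearn layout, `confmat[i]` holds `[[tn, fp], [fn, tp]]` for label `i`. A usage sketch of unpacking that output (the loop is illustrative, not part of this PR):

```python
import torch
from torchmetrics import ConfusionMatrix

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

confmat = ConfusionMatrix(num_classes=3, multilabel=True)
mat = confmat(preds, target)  # shape [3, 2, 2]

# mat[i] follows the sklearn per-label layout [[tn, fp], [fn, tp]].
for i, label_mat in enumerate(mat):
    tn, fp, fn, tp = label_mat.flatten().tolist()
    print(f"label {i}: tn={tn} fp={fp} fn={fn} tp={tp}")
```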
67 changes: 55 additions & 12 deletions torchmetrics/functional/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -21,14 +21,26 @@
from torchmetrics.utilities.enums import DataType


-def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
+def _confusion_matrix_update(
+    preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False
+) -> Tensor:
    preds, target, mode = _input_format_classification(preds, target, threshold)
    if mode not in (DataType.BINARY, DataType.MULTILABEL):
        preds = preds.argmax(dim=1)
        target = target.argmax(dim=1)
-    unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
-    bins = torch.bincount(unique_mapping, minlength=num_classes**2)
-    confmat = bins.reshape(num_classes, num_classes)

+    if multilabel:
+        unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_classes, device=preds.device)).flatten()
+        minlength = 4 * num_classes
+    else:
+        unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
+        minlength = num_classes ** 2
+
+    bins = torch.bincount(unique_mapping, minlength=minlength)
+    if multilabel:
+        confmat = bins.reshape(num_classes, 2, 2)
+    else:
+        confmat = bins.reshape(num_classes, num_classes)
    return confmat
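
The multilabel branch above relies on an index-encoding trick: `2 * target + preds` maps each (target, pred) pair to one of four cells (0: tn, 1: fp, 2: fn, 3: tp), and the `4 * arange(num_classes)` offset gives each label its own block of four bincount bins. A standalone sketch of the same trick (variable names are mine, not from the file):

```python
import torch

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
num_classes = target.shape[1]

cell = 2 * target + preds               # per-element cell index: 0 tn, 1 fp, 2 fn, 3 tp
offset = 4 * torch.arange(num_classes)  # label i owns bins 4*i .. 4*i + 3
bins = torch.bincount((cell + offset).flatten(), minlength=4 * num_classes)
print(bins.reshape(num_classes, 2, 2))  # one [[tn, fp], [fn, tp]] block per label
```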


Expand All @@ -54,18 +66,28 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)


def confusion_matrix(
-    preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    normalize: Optional[str] = None,
+    threshold: float = 0.5,
+    multilabel: bool = False
) -> Tensor:
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
If working with multilabel data, setting the `is_multilabel` argument to `True` will make sure that a
`confusion matrix gets calculated per label
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.
Args:
preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities
@@ -80,14 +102,35 @@ def confusion_matrix(
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5
+        multilabel:
+            determines if data is multilabel or not.

-    Example:
-        >>> from torchmetrics.functional import confusion_matrix
+    Example (binary data):
+        >>> from torchmetrics import ConfusionMatrix
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
-        >>> confusion_matrix(preds, target, num_classes=2)
+        >>> confmat = ConfusionMatrix(num_classes=2)
+        >>> confmat(preds, target)
        tensor([[2., 0.],
                [1., 1.]])

+    Example (multiclass data):
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> confmat = ConfusionMatrix(num_classes=3)
+        >>> confmat(preds, target)
+        tensor([[1., 1., 0.],
+                [0., 1., 0.],
+                [0., 0., 1.]])
+
+    Example (multilabel data):
+        >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
+        >>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
+        >>> confmat(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        tensor([[[1., 0.], [0., 1.]],
+                [[1., 0.], [1., 0.]],
+                [[0., 1.], [0., 1.]]])
    """
-    confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
+    confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel)
    return _confusion_matrix_compute(confmat, normalize)
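
For completeness, the functional form gains the same flag (note the committed docstring examples above instantiate the class instead); a minimal usage sketch:

```python
import torch
from torchmetrics.functional import confusion_matrix

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

# Per-label confusion matrices, shape [3, 2, 2]
mat = confusion_matrix(preds, target, num_classes=3, multilabel=True)

# The same normalize modes ('true', 'pred', 'all') apply to the multilabel output.
mat_norm = confusion_matrix(preds, target, num_classes=3, multilabel=True, normalize='true')
```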
