Proper multilabel support for confmat #134

Merged 20 commits on Mar 29, 2021

Changes from 5 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))


- Changed behaviour of `ConfusionMatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134))

### Deprecated


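For reference, the sklearn behaviour the CHANGELOG entry above refers to can be sketched as follows (illustrative only, not part of this diff): `multilabel_confusion_matrix` returns one 2x2 matrix per label, and this is the output shape the updated metric now mirrors for multilabel inputs.

# Sketch: sklearn's per-label confusion matrices, the reference behaviour for this change.
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

target = np.array([[0, 1, 0],
                   [1, 0, 1]])
preds = np.array([[0, 0, 1],
                  [1, 0, 1]])

# Shape (n_labels, 2, 2); each slice is [[TN, FP], [FN, TP]] for one label.
print(multilabel_confusion_matrix(target, preds))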
61 changes: 42 additions & 19 deletions tests/classification/test_confusion_matrix.py
@@ -17,6 +17,7 @@
import pytest
import torch
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@@ -48,17 +49,35 @@ def _sk_cm_binary(preds, target, normalize=None):


def _sk_cm_multilabel_prob(preds, target, normalize=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.numpy()

return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
if normalize is not None:
if normalize == 'true':
cm = cm / cm.sum(axis=1, keepdims=True)
elif normalize == 'pred':
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == 'all':
cm = cm / cm.sum()
cm[np.isnan(cm)] = 0
return cm


def _sk_cm_multilabel(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
sk_preds = preds.numpy()
sk_target = target.numpy()

return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
if normalize is not None:
if normalize == 'true':
cm = cm / cm.sum(axis=1, keepdims=True)
elif normalize == 'pred':
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == 'all':
cm = cm / cm.sum()
cm[np.isnan(cm)] = 0
return cm


def _sk_cm_multiclass_prob(preds, target, normalize=None):
@@ -91,21 +110,23 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None):

@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)]
"preds, target, sk_metric, num_classes, is_multilabel",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False)]
)
class TestConfusionMatrix(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_confusion_matrix(
self, normalize, preds, target, sk_metric, num_classes, is_multilabel, ddp, dist_sync_on_step
):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
@@ -116,11 +137,12 @@ def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
"normalize": normalize,
"is_multilabel": is_multilabel
}
)

def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes):
def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, is_multilabel):
self.run_functional_metric_test(
preds,
target,
@@ -129,7 +151,8 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
"normalize": normalize,
"is_multilabel": is_multilabel
}
)

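A possible follow-up, sketched here only (not part of this diff, and the helper name is hypothetical): the two multilabel reference helpers above repeat the same normalization branch, which could be shared.

import numpy as np

def _sk_cm_normalize(cm, normalize=None):
    # Hypothetical shared helper mirroring the branches in _sk_cm_multilabel_prob / _sk_cm_multilabel.
    if normalize == 'true':
        cm = cm / cm.sum(axis=1, keepdims=True)
    elif normalize == 'pred':
        cm = cm / cm.sum(axis=0, keepdims=True)
    elif normalize == 'all':
        cm = cm / cm.sum()
    cm[np.isnan(cm)] = 0
    return cm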
57 changes: 47 additions & 10 deletions torchmetrics/classification/confusion_matrix.py
@@ -24,12 +24,8 @@ class ConfusionMatrix(Metric):
"""
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or
integer class values in prediction. Works with multi-dimensional preds and
target.

Note:
This metric produces a multi-dimensional output, so it can not be directly logged.
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.

Forward accepts

@@ -41,6 +37,10 @@ class ConfusionMatrix(Metric):

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

If working with multilabel data, setting the `is_multilabel` argument to `True` will make sure that a
`confusion matrix gets calculated per label
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.

Args:
num_classes: Number of classes in the dataset.
normalize: Normalization mode for confusion matrix. Choose from
@@ -52,6 +52,8 @@

threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
is_multilabel:
determines if data is multilabel or not; if ``True`` a confusion matrix is computed for each label. default: False
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
@@ -60,7 +62,7 @@
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:
Example (binary data):
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
@@ -69,13 +71,40 @@
tensor([[2., 0.],
[1., 1.]])

Example (multiclass data):
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])

Example (multilabel data):
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([[0, 1, 0],
... [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1],
... [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, is_multilabel=True)
>>> confmat(preds, target)
tensor([[[1., 0.],
[0., 1.]],
<BLANKLINE>
[[1., 0.],
[1., 0.]],
<BLANKLINE>
[[0., 1.],
[0., 1.]]])
"""

def __init__(
self,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5,
is_multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -89,12 +118,16 @@ def __init__(
self.num_classes = num_classes
self.normalize = normalize
self.threshold = threshold
self.is_multilabel = is_multilabel

allowed_normalize = ('true', 'pred', 'all', 'none', None)
assert self.normalize in allowed_normalize, \
f"Argument normalize needs to be one of the following: {allowed_normalize}"

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
if is_multilabel:
self.add_state("confmat", default=torch.zeros(num_classes, 2, 2), dist_reduce_fx="sum")
else:
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor):
"""
@@ -104,11 +137,15 @@ def update(self, preds: Tensor, target: Tensor):
preds: Predictions from model
target: Ground truth values
"""
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.is_multilabel)
self.confmat += confmat

def compute(self) -> Tensor:
"""
Computes confusion matrix
Computes confusion matrix.

Returns:
If `is_multilabel=False` this will be a `[n_classes, n_classes]` tensor and if `is_multilabel=True`
this will be a `[n_classes, 2, 2]` tensor
"""
return _confusion_matrix_compute(self.confmat, self.normalize)
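Usage note, as a sketch rather than part of the diff: with `is_multilabel=True` the state registered above has shape `[num_classes, 2, 2]`, and, following the `2 * target + preds` encoding used in the functional update below, each slice reads as `[[TN, FP], [FN, TP]]` for its label.

import torch
from torchmetrics import ConfusionMatrix  # assumes the version introduced in this PR

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

confmat = ConfusionMatrix(num_classes=3, is_multilabel=True)
mat = confmat(preds, target)       # shape [3, 2, 2], one 2x2 matrix per label
tn, fp, fn, tp = mat[2].flatten()  # counts for the third label: 0., 1., 0., 1.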
76 changes: 64 additions & 12 deletions torchmetrics/functional/classification/confusion_matrix.py
@@ -21,14 +21,26 @@
from torchmetrics.utilities.enums import DataType


def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
def _confusion_matrix_update(
preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, is_multilabel: bool = False
) -> Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
bins = torch.bincount(unique_mapping, minlength=num_classes**2)
confmat = bins.reshape(num_classes, num_classes)

if is_multilabel:
unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_classes, device=preds.device)).flatten()
minlength = 4 * num_classes
else:
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
minlength = num_classes ** 2

bins = torch.bincount(unique_mapping, minlength=minlength)
if is_multilabel:
confmat = bins.reshape(num_classes, 2, 2)
else:
confmat = bins.reshape(num_classes, num_classes)
return confmat
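To make the multilabel branch above explicit, here is a standalone sketch of the same encoding (not part of the diff): `2 * target + preds` maps each (target, pred) pair to a cell index (0=TN, 1=FP, 2=FN, 3=TP) and `4 * arange(num_classes)` shifts it into the block for its label, so a single `bincount` fills every per-label 2x2 matrix at once.

import torch

# Standalone illustration of the bin encoding used in the multilabel branch above.
target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
num_classes = target.shape[1]

unique_mapping = (2 * target + preds + 4 * torch.arange(num_classes)).flatten()
confmat = torch.bincount(unique_mapping, minlength=4 * num_classes).reshape(num_classes, 2, 2)
# confmat[i] holds [[TN, FP], [FN, TP]] for label i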


@@ -54,18 +66,28 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)


def confusion_matrix(
preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
preds: Tensor,
target: Tensor,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5,
is_multilabel: bool = False
) -> Tensor:
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

If working with multilabel data, setting the `is_multilabel` argument to `True` will make sure that a
`confusion matrix gets calculated per label
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.

Args:
preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities
@@ -80,14 +102,44 @@ def confusion_matrix(

threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
is_multilabel:
determines if data is multilabel or not; if ``True`` a confusion matrix is computed for each label. default: False

Example (binary data):
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, num_classes=2)
tensor([[2., 0.],
[1., 1.]])

Example (multiclass data):
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confusion_matrix(preds, target, num_classes=3)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])

Example (multilabel data):
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([[0, 1, 0],
... [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1],
... [1, 0, 1]])
>>> confusion_matrix(preds, target, num_classes=3, is_multilabel=True)
tensor([[[1., 0.],
[0., 1.]],
<BLANKLINE>
[[1., 0.],
[1., 0.]],
<BLANKLINE>
[[0., 1.],
[0., 1.]]])
"""
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold, is_multilabel)
return _confusion_matrix_compute(confmat, normalize)
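As a final cross-check, a small sketch (assuming the `is_multilabel` argument added in this PR) comparing the functional path against sklearn:

import torch
from sklearn.metrics import multilabel_confusion_matrix
from torchmetrics.functional import confusion_matrix  # assumes the version from this PR

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

tm_cm = confusion_matrix(preds, target, num_classes=3, is_multilabel=True)
sk_cm = multilabel_confusion_matrix(target.numpy(), preds.numpy())
assert torch.equal(tm_cm.long(), torch.tensor(sk_cm))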