Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper multilabel support for confmat #134

Merged
merged 20 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))


- Changed behaviour of `ConfusionMatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134))

carmocca marked this conversation as resolved.
Show resolved Hide resolved
### Deprecated


Expand Down
61 changes: 42 additions & 19 deletions tests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
import torch
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
Expand Down Expand Up @@ -48,17 +49,35 @@ def _sk_cm_binary(preds, target, normalize=None):


def _sk_cm_multilabel_prob(preds, target, normalize=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.numpy()

return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
if normalize is not None:
if normalize == 'true':
cm = cm / cm.sum(axis=1, keepdims=True)
elif normalize == 'pred':
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == 'all':
cm = cm / cm.sum()
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
cm[np.isnan(cm)] = 0
return cm


def _sk_cm_multilabel(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
sk_preds = preds.numpy()
sk_target = target.numpy()

return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds)
if normalize is not None:
if normalize == 'true':
cm = cm / cm.sum(axis=1, keepdims=True)
elif normalize == 'pred':
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == 'all':
cm = cm / cm.sum()
cm[np.isnan(cm)] = 0
return cm


def _sk_cm_multiclass_prob(preds, target, normalize=None):
Expand Down Expand Up @@ -91,21 +110,23 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None):

@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)]
"preds, target, sk_metric, num_classes, multilabel",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False)]
)
class TestConfusionMatrix(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_confusion_matrix(
self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step
):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
Expand All @@ -116,11 +137,12 @@ def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
"normalize": normalize,
"multilabel": multilabel
}
)

def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes):
def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel):
self.run_functional_metric_test(
preds,
target,
Expand All @@ -129,7 +151,8 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
"normalize": normalize,
"multilabel": multilabel
}
)

Expand Down
46 changes: 36 additions & 10 deletions torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@ class ConfusionMatrix(Metric):
"""
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or
integer class values in prediction. Works with multi-dimensional preds and
target.

Note:
This metric produces a multi-dimensional output, so it can not be directly logged.
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.

Forward accepts

Expand All @@ -41,6 +37,10 @@ class ConfusionMatrix(Metric):

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

If working with multilabel data, setting the `multilabel` argument to `True` will make sure that a
`confusion matrix gets calculated per label
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.

Args:
num_classes: Number of classes in the dataset.
normalize: Normalization mode for confusion matrix. Choose from
Expand All @@ -52,6 +52,8 @@ class ConfusionMatrix(Metric):

threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
multilabel:
determines if data is multilabel or not.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Expand All @@ -60,7 +62,7 @@ class ConfusionMatrix(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:
Example (binary data):
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
Expand All @@ -69,13 +71,31 @@ class ConfusionMatrix(Metric):
tensor([[2., 0.],
[1., 1.]])

Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])

Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target) # doctest: +NORMALIZE_WHITESPACE
tensor([[[1., 0.], [0., 1.]],
[[1., 0.], [1., 0.]],
[[0., 1.], [0., 1.]]])
"""

def __init__(
self,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5,
multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
Expand All @@ -89,12 +109,14 @@ def __init__(
self.num_classes = num_classes
self.normalize = normalize
self.threshold = threshold
self.multilabel = multilabel

allowed_normalize = ('true', 'pred', 'all', 'none', None)
assert self.normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
default = torch.zeros(num_classes, 2, 2) if multilabel else torch.zeros(num_classes, num_classes)
self.add_state("confmat", default=default, dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor):
"""
Expand All @@ -104,11 +126,15 @@ def update(self, preds: Tensor, target: Tensor):
preds: Predictions from model
target: Ground truth values
"""
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel)
self.confmat += confmat

def compute(self) -> Tensor:
"""
Computes confusion matrix
Computes confusion matrix.

Returns:
If `multilabel=False` this will be a `[n_classes, n_classes]` tensor and if `multilabel=True`
this will be a `[n_classes, 2, 2]` tensor
"""
return _confusion_matrix_compute(self.confmat, self.normalize)
67 changes: 55 additions & 12 deletions torchmetrics/functional/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,26 @@
from torchmetrics.utilities.enums import DataType


def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
def _confusion_matrix_update(
preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False
) -> Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
bins = torch.bincount(unique_mapping, minlength=num_classes**2)
confmat = bins.reshape(num_classes, num_classes)

if multilabel:
unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_classes, device=preds.device)).flatten()
minlength = 4 * num_classes
else:
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
minlength = num_classes ** 2

bins = torch.bincount(unique_mapping, minlength=minlength)
if multilabel:
confmat = bins.reshape(num_classes, 2, 2)
else:
confmat = bins.reshape(num_classes, num_classes)
return confmat


Expand All @@ -54,18 +66,28 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)


def confusion_matrix(
preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
preds: Tensor,
target: Tensor,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5,
multilabel: bool = False
) -> Tensor:
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.

If preds and target are the same shape and preds is a float tensor, we use the ``threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

If working with multilabel data, setting the `multilabel` argument to `True` will make sure that a
`confusion matrix gets calculated per label
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html>`_.

Args:
preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities
Expand All @@ -80,14 +102,35 @@ def confusion_matrix(

threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
multilabel:
determines if data is multilabel or not.

Example:
>>> from torchmetrics.functional import confusion_matrix
Example (binary data):
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, num_classes=2)
tensor([[2., 0.],
[1., 1.]])

Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confusion_matrix(preds, target, num_classes=3)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])

Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confusion_matrix(preds, target, num_classes=3, multilabel=True) # doctest: +NORMALIZE_WHITESPACE
tensor([[[1., 0.], [0., 1.]],
[[1., 0.], [1., 0.]],
[[0., 1.], [0., 1.]]])
"""
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel)
return _confusion_matrix_compute(confmat, normalize)