For the purposes of classification metrics, inputs (predictions and targets) are split
into the following categories (``N`` stands for the batch size and ``C`` for the number of classes).
In the table, ``binary *`` means integers that are either 0 or 1.
Type | preds shape | preds dtype | target shape | target dtype |
---|---|---|---|---|
Binary | (N,) | float | (N,) | binary * |
Multi-class | (N,) | int | (N,) | int |
Multi-class with logits or probabilities | (N, C) | float | (N,) | int |
Multi-label | (N, ...) | float | (N, ...) | binary * |
Multi-dimensional multi-class | (N, ...) | int | (N, ...) | int |
Multi-dimensional multi-class with logits or probabilities | (N, C, ...) | float | (N, ...) | int |
.. note::
    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
    that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N,)``.
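
The table and the squeezing rule can be read as a small decision procedure. The sketch below is only an illustration in plain Python, working on shape tuples rather than tensors; ``squeeze_shape`` and ``classify_inputs`` are hypothetical helpers and are not part of torchmetrics, whose actual checks may differ.

```python
def squeeze_shape(shape):
    """Drop every size-1 dimension except the leading batch dimension N."""
    return tuple(shape[:1]) + tuple(d for d in shape[1:] if d != 1)


def classify_inputs(preds_shape, preds_is_float, target_shape):
    """Return the input category from the table (hypothetical helper)."""
    p = squeeze_shape(tuple(preds_shape))
    t = squeeze_shape(tuple(target_shape))

    if p[:1] != t[:1]:
        raise ValueError("preds and target must share the batch dimension N")

    # Float preds with one extra dimension: (N, C, ...) against (N, ...)
    if preds_is_float and len(p) == len(t) + 1 and p[2:] == t[1:]:
        if len(t) == 1:
            return "multi-class with logits or probabilities"
        return "multi-dimensional multi-class with logits or probabilities"

    # Otherwise preds and target must have identical shapes
    if p == t:
        if preds_is_float:
            return "binary" if len(p) == 1 else "multi-label"
        return "multi-class" if len(p) == 1 else "multi-dimensional multi-class"

    raise ValueError("shapes do not match any category in the table")


print(classify_inputs((3, 1), True, (3,)))   # (N, 1) is squeezed to (N,)
print(classify_inputs((3, 3), True, (3,)))
print(classify_inputs((3, 3), True, (3, 3)))
```

Note that shape and dtype alone cannot tell integer binary inputs apart from 2-class multi-class inputs, which is the ambiguity the ``multiclass`` argument addresses.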
When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of the different input types:
.. testcode::

    import torch

    # Binary inputs (targets may only contain 0 or 1)
    binary_preds = torch.tensor([0.6, 0.1, 0.9])
    binary_target = torch.tensor([1, 0, 1])

    # Multi-class inputs
    mc_preds = torch.tensor([0, 2, 1])
    mc_target = torch.tensor([0, 1, 2])

    # Multi-class inputs with probabilities
    mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
    mc_target_probs = torch.tensor([0, 1, 2])

    # Multi-label inputs
    ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
    ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label, for example when both predictions and targets are integer (binary) tensors. The reverse is also possible: you may want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
For these cases, the metrics where this distinction makes a difference expose the
``multiclass`` argument. Let's see how this is used, taking the
:class:`~torchmetrics.StatScores` metric as an example.
First, let's consider the case with label predictions with 2 classes, which we want to treat as binary.
.. testcode::

    from torchmetrics.functional import stat_scores

    # These inputs are supposed to be binary, but appear as multi-class
    preds = torch.tensor([0, 1, 0])
    target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated as multi-class. We can set
``multiclass=False`` to treat the inputs as binary, which is the same as converting
the predictions to float beforehand.
>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
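
To make the rows in the outputs above easier to read, here is a minimal plain-Python sketch that counts the statistics by hand. It assumes each row is laid out as ``[tp, fp, tn, fn, support]`` with ``support = tp + fn``, and that each class is scored one-vs-rest; ``stat_scores_per_class`` is a hypothetical helper, not the library's implementation.

```python
def stat_scores_per_class(preds, target, num_classes):
    """Count [tp, fp, tn, fn, support] for each class, one-vs-rest."""
    rows = []
    for c in range(num_classes):
        pairs = list(zip(preds, target))
        tp = sum(p == c and t == c for p, t in pairs)
        fp = sum(p == c and t != c for p, t in pairs)
        tn = sum(p != c and t != c for p, t in pairs)
        fn = sum(p != c and t == c for p, t in pairs)
        rows.append([tp, fp, tn, fn, tp + fn])  # support = tp + fn
    return rows


preds = [0, 1, 0]
target = [1, 1, 0]

# Multi-class view: one row per class
multiclass_rows = stat_scores_per_class(preds, target, num_classes=2)
print(multiclass_rows)  # [[1, 1, 1, 0, 1], [1, 0, 1, 1, 2]]

# Binary view: only the positive class (label 1) is reported
binary_row = multiclass_rows[1]
print(binary_row)  # [1, 0, 1, 1, 2]
```

Under these assumptions, the binary result is simply the row for class 1 of the multi-class result, which matches the outputs shown above.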
Next, consider the opposite example: inputs are binary (as predictions are probabilities), but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
.. testcode::

    preds = torch.tensor([0.2, 0.7, 0.3])
    target = torch.tensor([1, 1, 0])

In this case we can set ``multiclass=True`` to treat the inputs as multi-class.
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])
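
One way to picture why the probability inputs give the same result as the earlier integer inputs is sketched below in plain Python. It assumes probabilities are turned into labels with a threshold of 0.5 and then expanded to one column per class; ``threshold_probs`` and ``one_hot`` are hypothetical helpers, and the library's internal conversion may differ in detail.

```python
def threshold_probs(probs, threshold=0.5):
    """Map each probability to label 1 if it reaches the threshold, else 0."""
    return [int(p >= threshold) for p in probs]


def one_hot(labels, num_classes):
    """Expand integer labels into one indicator column per class."""
    return [[int(label == c) for c in range(num_classes)] for label in labels]


preds = [0.2, 0.7, 0.3]

labels = threshold_probs(preds)
print(labels)  # [0, 1, 0]

two_class = one_hot(labels, num_classes=2)
print(two_class)  # [[1, 0], [0, 1], [1, 0]]
```

The thresholded labels ``[0, 1, 0]`` are the same as the integer predictions used in the first example, so both calls report identical statistics.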