Classification metrics overhaul: stat scores (3/n) #4839

Merged: 175 commits, Dec 30, 2020. The diff below shows changes from 7 of the 175 commits.

Commits (175)
6959ea0  Add stuff (tadejsv, Nov 24, 2020)
0679015  Change metrics documentation layout (tadejsv, Nov 24, 2020)
35627b5  Add stuff (tadejsv, Nov 24, 2020)
0282f3c  Add stat scores (tadejsv, Nov 24, 2020)
55fdaaf  Change testing utils (tadejsv, Nov 24, 2020)
35f8320  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 24, 2020)
dd05912  Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores (tadejsv, Nov 24, 2020)
5cbf56a  Replace len(*.shape) with *.ndim (tadejsv, Nov 24, 2020)
9c33d0b  More descriptive error message for input formatting (tadejsv, Nov 24, 2020)
6562205  Replace movedim with permute (tadejsv, Nov 24, 2020)
b97aef2  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 24, 2020)
74261f7  Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores (tadejsv, Nov 24, 2020)
cbbc769  PEP 8 compliance (tadejsv, Nov 24, 2020)
33166c5  WIP (tadejsv, Nov 24, 2020)
801abe8  Add reduce_scores function (tadejsv, Nov 24, 2020)
fb181ed  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Nov 24, 2020)
fbebd34  Temporarily add back legacy class_reduce (tadejsv, Nov 24, 2020)
b3d1b8b  Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall (tadejsv, Nov 24, 2020)
f45fc81  Division with float (tadejsv, Nov 24, 2020)
3fdef40  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Nov 24, 2020)
452df32  Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall (tadejsv, Nov 24, 2020)
9d44a26  PEP 8 compliance (tadejsv, Nov 24, 2020)
82c3460  Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall (tadejsv, Nov 24, 2020)
5ce7cd9  Remove precision recall (tadejsv, Nov 24, 2020)
3b70270  Replace movedim with permute (tadejsv, Nov 24, 2020)
f1ae7b2  Add back tests (tadejsv, Nov 24, 2020)
04a5066  Add empty newlines (tadejsv, Nov 25, 2020)
9dc7bea  Add empty line (tadejsv, Nov 25, 2020)
a9640f6  Fix permute (tadejsv, Nov 25, 2020)
692392c  Fix some issues with old versions of PyTorch (tadejsv, Nov 25, 2020)
a04a71e  Style changes in error messages (tadejsv, Nov 25, 2020)
eaac5d7  More error message style improvements (tadejsv, Nov 25, 2020)
c1108f0  Fix typo in docs (tadejsv, Nov 25, 2020)
277769b  Add more descriptive variable names in utils (tadejsv, Nov 25, 2020)
4849298  Change internal var names (tadejsv, Nov 25, 2020)
22906a4  Merge remote-tracking branch 'upstream/master' into cls_metrics_input… (tadejsv, Nov 25, 2020)
1034a71  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 25, 2020)
ebcdbeb  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Nov 25, 2020)
02bd636  Break down error checking for inputs into separate functions (tadejsv, Nov 25, 2020)
f97145b  Remove the (N, ..., C) option in MD-MC (tadejsv, Nov 25, 2020)
536feaf  Simplify select_topk (tadejsv, Nov 25, 2020)
4241d7c  Remove detach for inputs (tadejsv, Nov 25, 2020)
99d3c81  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 25, 2020)
86d6c4d  Fix typos (tadejsv, Nov 25, 2020)
54c98a0  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 25, 2020)
bb11677  Merge branch 'master' into cls_metrics_input_formatting (teddykoker, Nov 25, 2020)
bdc4111  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 25, 2020)
cde3997  Update pytorch_lightning/metrics/classification/utils.py (tadejsv, Nov 26, 2020)
05a54da  Update docs/source/metrics.rst (tadejsv, Nov 26, 2020)
9a43a5e  Minor error message changes (tadejsv, Nov 26, 2020)
3f4ad3c  Update pytorch_lightning/metrics/utils.py (tadejsv, Nov 26, 2020)
a654e6a  Reuse case from validation in formatting (tadejsv, Nov 26, 2020)
7b2ef2b  Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt… (tadejsv, Nov 26, 2020)
16ab8f7  Refactor code in _input_format_classification (tadejsv, Nov 27, 2020)
558276f  Merge branch 'master' into cls_metrics_input_formatting (tchaton, Nov 27, 2020)
ecffe18  Small improvements (tadejsv, Nov 27, 2020)
a907ade  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 27, 2020)
725c7dd  PEP 8 (tadejsv, Nov 27, 2020)
41ad0b7  Update pytorch_lightning/metrics/classification/utils.py (tadejsv, Nov 27, 2020)
ca13e76  Update pytorch_lightning/metrics/classification/utils.py (tadejsv, Nov 27, 2020)
ede2c7f  Update docs/source/metrics.rst (tadejsv, Nov 27, 2020)
c6e4de4  Update pytorch_lightning/metrics/classification/utils.py (tadejsv, Nov 27, 2020)
201d0de  Apply suggestions from code review (tadejsv, Nov 27, 2020)
f08edbc  Alphabetical reordering of regression metrics (tadejsv, Nov 27, 2020)
523bae3  Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt… (tadejsv, Nov 27, 2020)
db24fae  Merge branch 'master' into cls_metrics_input_formatting (Borda, Nov 27, 2020)
35e3eff  Change default value of top_k and add error checking (tadejsv, Nov 28, 2020)
dd6f8ea  Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt… (tadejsv, Nov 28, 2020)
c28aadf  Extract basic validation into separate function (tadejsv, Nov 28, 2020)
4bfc688  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Nov 28, 2020)
323285e  Update to new top_k default (tadejsv, Nov 28, 2020)
0cb0eac  Update desciption of parameters in input formatting (tadejsv, Nov 29, 2020)
28acf4c  Merge branch 'master' into cls_metrics_input_formatting (tchaton, Nov 30, 2020)
8e7a85a  Apply suggestions from code review (tadejsv, Nov 30, 2020)
829155e  Check that probabilities in preds sum to 1 (for MC) (tadejsv, Nov 30, 2020)
768879d  Fix coverage (tadejsv, Nov 30, 2020)
e4d88e2  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Dec 1, 2020)
eeded45  Split accuracy and hamming loss (tadejsv, Dec 1, 2020)
b49cfdc  Remove old redundant accuracy (tadejsv, Dec 1, 2020)
15ef14d  Merge branch 'master' into cls_metrics_input_formatting (teddykoker, Dec 2, 2020)
479114f  Merge branch 'master' into cls_metrics_stat_scores (tchaton, Dec 3, 2020)
3d8f584  Merge branch 'master' into cls_metrics_accuracy (tchaton, Dec 3, 2020)
1568970  Merge branch 'master' into cls_metrics_input_formatting (tchaton, Dec 3, 2020)
a9fa730  Merge with master and resolve conflicts (tadejsv, Dec 6, 2020)
44ad276  Merge branch 'master' into cls_metrics_input_formatting (Borda, Dec 6, 2020)
96d40c8  Minor changes (tadejsv, Dec 6, 2020)
cca430a  Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt… (tadejsv, Dec 6, 2020)
b0bde16  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Dec 6, 2020)
627d99a  Fix imports (tadejsv, Dec 6, 2020)
de3defb  Improve docstring descriptions (tadejsv, Dec 6, 2020)
218ff56  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Dec 6, 2020)
c24d47b  Fix imports (tadejsv, Dec 6, 2020)
f3c47f9  Fix edge case and simplify testing (tadejsv, Dec 6, 2020)
a7e91a9  Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy (tadejsv, Dec 6, 2020)
94c1af6  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Dec 6, 2020)
b7ced6e  Fix docs (tadejsv, Dec 6, 2020)
e91e564  PEP8 (tadejsv, Dec 6, 2020)
798ec03  Reorder imports (tadejsv, Dec 6, 2020)
ccdc421  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Dec 6, 2020)
658bfb1  Add top_k parameter (tadejsv, Dec 6, 2020)
7217924  Merge remote-tracking branch 'upstream/master' into cls_metrics_accuracy (tadejsv, Dec 7, 2020)
a7c143e  Update changelog (tadejsv, Dec 7, 2020)
531ae33  Update docstring (tadejsv, Dec 7, 2020)
2eba226  Merge branch 'master' into cls_metrics_accuracy (tadejsv, Dec 7, 2020)
a66cf31  Update docstring (tadejsv, Dec 7, 2020)
e93f83e  Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig… (tadejsv, Dec 7, 2020)
89b09f8  Reverse formatting changes for tests (tadejsv, Dec 7, 2020)
e715437  Change parameter order (tadejsv, Dec 7, 2020)
d5daec8  Remove formatting changes 2/2 (tadejsv, Dec 7, 2020)
c820060  Remove formatting 3/3 (tadejsv, Dec 7, 2020)
b576de0  . (tadejsv, Dec 7, 2020)
dae341b  Improve description of top_k parameter (tadejsv, Dec 7, 2020)
43136b2  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Dec 7, 2020)
b2d2b71  Apply suggestions from code review (Borda, Dec 7, 2020)
9b2a399  Apply suggestions from code review (tadejsv, Dec 7, 2020)
0952df2  Remove unneeded assert (tadejsv, Dec 7, 2020)
c7fe698  Update pytorch_lightning/metrics/functional/accuracy.py (tadejsv, Dec 7, 2020)
e2bc0ab  Remove unneeded assert (tadejsv, Dec 7, 2020)
acbd1ca  Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig… (tadejsv, Dec 7, 2020)
8801f8a  Explicit checking of parameter values (tadejsv, Dec 7, 2020)
c32b36e  Apply suggestions from code review (Borda, Dec 7, 2020)
0314c7d  Apply suggestions from code review (tadejsv, Dec 7, 2020)
152cadf  Fix top_k checking (tadejsv, Dec 7, 2020)
022d6a6  PEP8 (tadejsv, Dec 7, 2020)
9efc963  Don't check dist_sync in test (tadejsv, Dec 8, 2020)
d992f7d  add back check_dist_sync_on_step (tadejsv, Dec 8, 2020)
a726060  Make sure half-precision inputs are transformed (#5013) (tadejsv, Dec 8, 2020)
93c5d02  Fix typo (tadejsv, Dec 8, 2020)
0813055  Rename hamming loss to hamming distance (tadejsv, Dec 8, 2020)
6bf714b  Fix tests for half precision (tadejsv, Dec 8, 2020)
d12f1d6  Fix docs underline length (tadejsv, Dec 8, 2020)
a55cb46  Fix doc undeline length (tadejsv, Dec 8, 2020)
d75eec3  Merge branch 'master' into cls_metrics_accuracy (justusschock, Dec 8, 2020)
6b3b057  Replace mdmc_accuracy parameter with subset_accuracy (tadejsv, Dec 8, 2020)
6f218d4  Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig… (tadejsv, Dec 8, 2020)
98cb5f4  Update changelog (tadejsv, Dec 8, 2020)
778aeae  Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores (tadejsv, Dec 8, 2020)
d129ccb  Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr… (tadejsv, Dec 21, 2020)
7eb1457  Fix unwanted accuracy change (tadejsv, Dec 21, 2020)
207a762  Enable top_k for ML prob inputs (tadejsv, Dec 21, 2020)
3b79348  Test that default threshold is 0.5 (tadejsv, Dec 21, 2020)
b609b35  Fix typo (tadejsv, Dec 21, 2020)
633e3ff  Update top_k description in helpers (tadejsv, Dec 23, 2020)
82879d0  Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr… (tadejsv, Dec 23, 2020)
103dfc6  updates (tadejsv, Dec 23, 2020)
9be50aa  Update styling and add back tests (tadejsv, Dec 23, 2020)
d3f851c  Remove excess spaces (tadejsv, Dec 23, 2020)
1612139  fix torch.where for old versions (tadejsv, Dec 23, 2020)
ca03c4a  fix linting (tadejsv, Dec 23, 2020)
aea4c66  Update docstring (tadejsv, Dec 23, 2020)
7b4dcc1  Fix docstring (tadejsv, Dec 23, 2020)
9cd07a8  Apply suggestions from code review (mostly docs) (tadejsv, Dec 24, 2020)
a713fc7  Default threshold to None, accept only (0,1) (tadejsv, Dec 24, 2020)
075ed53  Change wrong threshold message (tadejsv, Dec 24, 2020)
c289f0c  Improve documentation and add tests (tadejsv, Dec 25, 2020)
aae5141  Merge branch 'tests_mprc' into cls_metrics_stat_scores (tadejsv, Dec 25, 2020)
e665f89  Add back ddp tests (tadejsv, Dec 27, 2020)
16d29bf  Change stat reduce method and default (tadejsv, Dec 27, 2020)
7e8fb8e  Remove DDP tests and fix doctests (tadejsv, Dec 28, 2020)
d1a4eff  Fix doctest (tadejsv, Dec 28, 2020)
01e8e63  Update changelog (tadejsv, Dec 28, 2020)
3e58244  Refactoring (tadejsv, Dec 28, 2020)
475c706  Fix typo (tadejsv, Dec 28, 2020)
d387eb1  Refactor (tadejsv, Dec 28, 2020)
d2a92e8  Increase coverage (tadejsv, Dec 28, 2020)
c178cb6  Fix linting (tadejsv, Dec 28, 2020)
8bf6cf1  Consistent use of backticks (tadejsv, Dec 29, 2020)
b2fcd55  Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr… (tadejsv, Dec 29, 2020)
169fc7c  Fix too long line in docs (tadejsv, Dec 29, 2020)
21551f1  Apply suggestions from code review (tadejsv, Dec 29, 2020)
e52fa9c  Fix deprecation test (tadejsv, Dec 29, 2020)
85d6e3a  Fix deprecation test (tadejsv, Dec 29, 2020)
3461159  Default threshold back to 0.5 (tadejsv, Dec 29, 2020)
fe48912  Minor documentation fixes (tadejsv, Dec 30, 2020)
c2c45f1  Add types to tests (tadejsv, Dec 30, 2020)
64 changes: 44 additions & 20 deletions docs/source/metrics.rst
@@ -268,6 +268,30 @@ Accuracy
.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:

AveragePrecision
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

F1
~~

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

FBeta
~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

Hamming Loss
~~~~~~~~~~~~

@@ -280,29 +304,24 @@ Precision
.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:

Recall
~~~~~~
PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Recall
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
:noindex:

FBeta
~~~~~
Recall
~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

F1
~~
ROC
~~~

.. autoclass:: pytorch_lightning.metrics.classification.F1
.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

Functional Metrics (Classification)
-----------------------------------
@@ -313,12 +332,6 @@ accuracy [func]
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:

hamming_loss [func]
~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.hamming_loss
:noindex:

auc [func]
~~~~~~~~~~

@@ -374,6 +387,11 @@ fbeta [func]
.. autofunction:: pytorch_lightning.metrics.functional.fbeta
:noindex:

hamming_loss [func]
~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.hamming_loss
:noindex:

iou [func]
~~~~~~~~~~
@@ -416,6 +434,12 @@ recall [func]
.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
:noindex:

select_topk [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.select_topk
:noindex:


stat_scores [func]
~~~~~~~~~~~~~~~~~~
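For orientation, the hunk above adds docs entries for the `select_topk` helper and the `stat_scores` functional. A minimal usage sketch of `select_topk` follows; the exact signature is an assumption here (a probability tensor plus a `topk` count, returning a binary integer tensor that marks the top-k entries per sample):

```python
import torch
from pytorch_lightning.metrics.utils import select_topk

# Probabilities for 2 samples over 3 classes.
probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])

# Mark the top-2 entries in each row with 1, everything else with 0.
print(select_topk(probs, topk=2))
# Expected (up to integer dtype):
# tensor([[0, 1, 1],
#         [1, 1, 0]])
```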
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/__init__.py
@@ -14,7 +14,7 @@
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.classification.hamming_loss import HammingLoss
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/accuracy.py
@@ -14,6 +14,7 @@
from typing import Any, Callable, Optional

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute

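The import above wires `Accuracy` to the functional `_accuracy_update`/`_accuracy_compute` helpers, reflecting the update/compute split used throughout this overhaul: a class-based metric accumulates synced state in `update` and derives the final value in `compute`. A minimal sketch of that pattern, with illustrative state names and thresholding rather than the PR's exact implementation:

```python
import torch
from pytorch_lightning.metrics.metric import Metric


class SketchAccuracy(Metric):
    """Illustrative accuracy metric following the update/compute split."""

    def __init__(self):
        super().__init__()
        # States are reduced across processes with the given op when syncing.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Binarize probabilistic predictions at 0.5, mirroring the default
        # threshold discussed in the commit history above.
        if preds.is_floating_point():
            preds = (preds >= 0.5).long()
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
```

Calling the metric object per batch (for example `SketchAccuracy()(preds, target)`) runs `update`, while `compute` aggregates over all batches seen so far.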
226 changes: 5 additions & 221 deletions pytorch_lightning/metrics/functional/classification.py
@@ -21,90 +21,6 @@
from pytorch_lightning.utilities import rank_zero_warn


def to_onehot(
tensor: torch.Tensor,
num_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Converts a dense label tensor to one-hot format

Args:
tensor: dense label tensor, with shape [N, d1, d2, ...]
num_classes: number of classes C

Output:
A sparse label tensor with shape [N, C, d1, d2, ...]

Example:

>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])

"""
if num_classes is None:
num_classes = int(tensor.max().detach().item() + 1)
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device)
index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
return tensor_onehot.scatter_(1, index, 1.0)


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
Converts a tensor of probabilities to a dense label tensor

Args:
tensor: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension to apply

Return:
A tensor with categorical labels [N, d2, ...]

Example:

>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])

"""
return torch.argmax(tensor, dim=argmax_dim)


def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
) -> int:
"""
Calculates the number of classes for a given prediction and target tensor.

Args:
pred: predicted values
target: true labels
num_classes: number of classes if known

Return:
An integer that represents the number of classes.
"""
num_target_classes = int(target.max().detach().item() + 1)
num_pred_classes = int(pred.max().detach().item() + 1)
num_all_classes = max(num_target_classes, num_pred_classes)

if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(
f"You have set {num_classes} number of classes which is"
f" different from predicted ({num_pred_classes}) and"
f" target ({num_target_classes}) number of classes",
RuntimeWarning,
)
return num_classes


def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
@@ -418,7 +334,8 @@ def _binary_clf_curve(
return fps, tps, pred[threshold_idxs]


def roc(
# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
@@ -472,7 +389,8 @@ def roc(
return fpr, tpr, thresholds


def multiclass_roc(
# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __multiclass_roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
@@ -512,111 +430,11 @@ def multiclass_roc(
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))
class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))

return tuple(class_roc_vals)


def precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes precision-recall pairs for different thresholds.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class

Return:
precision, recall, thresholds

Example:

>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 0])
>>> precision, recall, thresholds = precision_recall_curve(pred, target)
>>> precision
tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])

"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

precision = tps / (tps + fps)
recall = tps / tps[-1]

# stop when full recall attained
# and reverse the outputs so recall is decreasing
last_ind = torch.where(tps == tps[-1])[0][0]
sl = slice(0, last_ind.item() + 1)

# need to call reversed explicitly, since including that to slice would
# introduce negative strides that are not yet supported in pytorch
precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)])

recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)])

thresholds = torch.tensor(reversed(thresholds[sl]))

return precision, recall, thresholds


def multiclass_precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes precision-recall pairs for different thresholds given a multiclass scores.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weight
num_classes: number of classes

Return:
number of classes, precision, recall, thresholds

Example:

>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target)
>>> nb_classes
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
>>> precision
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
>>> recall
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
"""
num_classes = get_num_classes(pred, target, num_classes)

class_pr_vals = []
for c in range(num_classes):
pred_c = pred[:, c]

class_pr_vals.append(
precision_recall_curve(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)
)

return tuple(class_pr_vals)


def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor:
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule
@@ -790,40 +608,6 @@ def _multiclass_auroc(pred, target, sample_weight, num_classes):
return torch.mean(class_aurocs)


def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.0,
) -> torch.Tensor:
"""
Compute average precision from prediction scores

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class

Return:
Tensor containing average precision score

Example:

>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> average_precision(x, y)
tensor(0.3333)
"""
precision, recall, _ = precision_recall_curve(
pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label
)
# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve
return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])


def dice_score(
pred: torch.Tensor,
target: torch.Tensor,
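For reference, the removed `average_precision` reduces the precision-recall curve to a step-function integral, AP = -Σ (R[n+1] − R[n]) · P[n] over a recall sequence ordered decreasingly. A self-contained sketch of just that integration step, reusing the formula from the deleted body above (the toy inputs are illustrative):

```python
import torch


def step_average_precision(precision: torch.Tensor, recall: torch.Tensor) -> torch.Tensor:
    """Step-function integral of a PR curve, as in the removed implementation.

    Assumes the ordering produced by the removed precision_recall_curve:
    recall decreasing toward 0, precision ending at 1.
    """
    # The recall differences are <= 0 here, hence the leading minus sign.
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])


# Toy check: a perfect classifier's PR curve integrates to AP = 1.
precision = torch.tensor([1.0, 1.0, 1.0])
recall = torch.tensor([1.0, 0.5, 0.0])
print(step_average_precision(precision, recall))  # tensor(1.)
```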