-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
precision_recall_curve.py
221 lines (186 loc) · 8.78 KB
/
precision_recall_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Tuple, List, Union
import torch
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_warn
def _binary_clf_curve(
preds: torch.Tensor,
target: torch.Tensor,
sample_weights: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
# remove class dimension if necessary
if preds.ndim > target.ndim:
preds = preds[:, 0]
desc_score_indices = torch.argsort(preds, descending=True)
preds = preds[desc_score_indices]
target = target[desc_score_indices]
if sample_weights is not None:
weight = sample_weights[desc_score_indices]
else:
weight = 1.
# pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
if sample_weights is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps
return fps, tps, preds[threshold_idxs]
def _precision_recall_curve_update(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)
# single class evaluation
if len(preds.shape) == len(target.shape):
if num_classes is not None and num_classes != 1:
raise ValueError('Preds and target have equal shape, but number of classes is different from 1')
num_classes = 1
if pos_label is None:
rank_zero_warn('`pos_label` automatically set 1.')
pos_label = 1
preds = preds.flatten()
target = target.flatten()
# multi class evaluation
if len(preds.shape) == len(target.shape) + 1:
if pos_label is not None:
rank_zero_warn('Argument `pos_label` should be `None` when running'
f'multiclass precision recall curve. Got {pos_label}')
if num_classes != preds.shape[1]:
raise ValueError(f'Argument `num_classes` was set to {num_classes} in'
f'metric `precision_recall_curve` but detected {preds.shape[1]}'
'number of classes from predictions')
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
target = target.flatten()
return preds, target, num_classes, pos_label
def _precision_recall_curve_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Builds the precision-recall curve(s) from already validated inputs.

    With ``num_classes == 1`` a single ``(precision, recall, thresholds)`` tuple of
    tensors is returned. Otherwise each column ``preds[:, cls]`` is scored
    one-vs-rest with ``pos_label=cls`` and three lists (one entry per class)
    are returned.
    """
    if num_classes != 1:
        # one-vs-rest: evaluate every class column as its own binary problem
        per_class = [
            precision_recall_curve(
                preds=preds[:, cls],
                target=target,
                num_classes=1,
                pos_label=cls,
                sample_weights=sample_weights,
            )
            for cls in range(num_classes)
        ]
        return (
            [curve[0] for curve in per_class],
            [curve[1] for curve in per_class],
            [curve[2] for curve in per_class],
        )

    fps, tps, thresholds = _binary_clf_curve(
        preds=preds,
        target=target,
        sample_weights=sample_weights,
        pos_label=pos_label,
    )
    precision = tps / (tps + fps)
    recall = tps / tps[-1]

    # keep points up to (and including) the first one reaching full recall
    first_full = torch.nonzero(tps == tps[-1])[0, 0].item()
    head = slice(0, first_full + 1)

    # reverse so recall is decreasing; ``flip`` is needed because tensors do not
    # support negative slice steps
    one = torch.ones(1, dtype=precision.dtype, device=precision.device)
    zero = torch.zeros(1, dtype=recall.dtype, device=recall.device)
    precision = torch.cat([precision[head].flip(0), one])
    recall = torch.cat([recall[head].flip(0), zero])
    thresholds = thresholds[head].flip(0)
    return precision, recall, thresholds
def precision_recall_curve(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        preds: predictions from the model (scores or probabilities). Either the
            same number of dimensions as ``target`` (binary problem) or one
            additional class dimension at position 1 (multiclass problem).
        target: ground-truth labels.
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems; for multiclass problems it must equal
            ``preds.shape[1]``.
        pos_label: integer determining the positive class. Default is ``None``,
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set, as we iteratively change it in the
            range ``[0, num_classes - 1]``.
        sample_weights: sample weights for each data point.

    Returns:
        3-element tuple containing

        precision:
            tensor where element i is the precision of predictions with
            score >= thresholds[i] and the last element is 1.
            If multiclass, this is a list of such tensors, one for each class.
        recall:
            tensor where element i is the recall of predictions with
            score >= thresholds[i] and the last element is 0.
            If multiclass, this is a list of such tensors, one for each class.
            For a class without positive samples the recall values are ``nan``
            (apart from the appended final 0).
        thresholds:
            Thresholds used for computing precision/recall scores, in
            increasing order.

    Raises:
        ValueError: if ``preds`` and ``target`` have incompatible numbers of
            dimensions, or ``num_classes`` does not match the input.

    Example (binary case):
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([1, 2, 3])

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
        >>> precision
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
    """
    # validate/reshape the inputs, then build the curve(s)
    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target,
                                                                           num_classes, pos_label)
    return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights)