Skip to content

Commit

Permalink
Weighted AUROC to omit empty classes (#376)
Browse files Browse the repository at this point in the history
* Skipping empty classes in weighted auroc.py

added logic to omit empty classes with weighted AUROC, as they should have 0 weight anyway

* formatting

* Added test to test_auroc.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed comment, reformatted prediction matrix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fmt

* debug -- fixing binary case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updated tests

Now checks for ValueError and UserWarning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated changelog

* debug

* final bug

* Update auroc.py

* removed f-string

* Docstring; removed f-string

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jul 26, 2021
1 parent 397a089 commit f42256d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `weighted`, `multi-class` AUROC computation to allow for classes with 0 observations, as their contribution to the final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


## [0.4.1] - 2021-07-05
Expand Down
28 changes: 28 additions & 0 deletions tests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,31 @@ def test_error_multiclass_no_num_classes():
ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument"
):
_ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20, )))


def test_weighted_with_empty_classes():
    """Check weighted multiclass AUROC when one class has no observations.

    The weighted score is expected to be identical whether or not an additional,
    never-observed class is present. A ``UserWarning`` is expected for the omitted
    class, and a ``ValueError`` is expected when only a single class is observed.
    """
    n_cls = 3
    scores = torch.tensor([
        [0.90, 0.05, 0.05],
        [0.05, 0.90, 0.05],
        [0.05, 0.05, 0.90],
        [0.85, 0.05, 0.10],
        [0.10, 0.10, 0.80],
    ])
    labels = torch.tensor([0, 1, 1, 2, 2])
    baseline = auroc(scores, labels, average="weighted", num_classes=n_cls)

    # Insert a column for a never-observed class at the second-to-last index.
    leading_cols = scores[:, :n_cls - 1]
    empty_col = torch.rand_like(scores[:, 0:1])
    trailing_cols = scores[:, n_cls - 1:]
    scores = torch.cat((leading_cols, empty_col, trailing_cols), axis=1)
    # The former last class (2) is relabelled to the new last index (3).
    labels[labels == n_cls - 1] = n_cls

    expected_warning = 'Class 2 had 0 observations, omitted from AUROC calculation'
    with pytest.warns(UserWarning, match=expected_warning):
        with_empty = auroc(scores, labels, average="weighted", num_classes=n_cls + 1)
    assert baseline == with_empty

    # Collapse every label to class 0 so that only one class is observed.
    labels = torch.zeros_like(labels)
    expected_error = 'Found 1 non-empty class in `multiclass` AUROC calculation'
    with pytest.raises(ValueError, match=expected_error):
        _ = auroc(scores, labels, average="weighted", num_classes=n_cls + 1)
20 changes: 18 additions & 2 deletions torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Sequence, Tuple

import torch
Expand Down Expand Up @@ -87,8 +88,23 @@ def _auroc_compute(
else:
raise ValueError('Detected input to be `multilabel` but you did not provide `num_classes` argument')
else:
if mode != DataType.BINARY and num_classes is None:
raise ValueError('Detected input to `multiclass` but you did not provide `num_classes` argument')
if mode != DataType.BINARY:
if num_classes is None:
raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument")
if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes:
# If one or more classes have 0 observations, we should exclude them, as their weight will be 0
target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool)
target_bool_mat[torch.arange(len(target)), target.long()] = 1
class_observed = target_bool_mat.sum(axis=0) > 0
for c in range(num_classes):
if not class_observed[c]:
warnings.warn(f'Class {c} had 0 observations, omitted from AUROC calculation', UserWarning)
preds = preds[:, class_observed]
target = target_bool_mat[:, class_observed]
target = torch.where(target)[1]
num_classes = class_observed.sum()
if num_classes == 1:
raise ValueError('Found 1 non-empty class in `multiclass` AUROC calculation')
fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

# calculate standard roc auc score
Expand Down

0 comments on commit f42256d

Please sign in to comment.