From 55236550ec286766f313977a2baa02d24b9b5836 Mon Sep 17 00:00:00 2001 From: Ramon Emiliani Date: Mon, 17 Jan 2022 16:40:05 -0500 Subject: [PATCH 01/13] Improve speed by removing for loop and using bucketize + scatter_add. --- .../classification/calibration_error.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 32d493f9839..c4999ffaaf5 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -46,17 +46,21 @@ def _ce_compute( if norm not in {"l1", "l2", "max"}: raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") - conf_bin = torch.zeros_like(bin_boundaries) - acc_bin = torch.zeros_like(bin_boundaries) - prop_bin = torch.zeros_like(bin_boundaries) - for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])): - # Calculated confidence and accuracy in each bin - in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) - prop_in_bin = in_bin.float().mean() - if prop_in_bin.item() > 0: - acc_bin[i] = accuracies[in_bin].float().mean() - conf_bin[i] = confidences[in_bin].mean() - prop_bin[i] = prop_in_bin + acc_bin = torch.zeros(len(bin_boundaries) - 1) + conf_bin = torch.zeros(len(bin_boundaries) - 1) + count_bin = torch.zeros(len(bin_boundaries) - 1) + + indices = torch.bucketize(confidences, bin_boundaries) - 1 + + count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences)) + + conf_bin.scatter_add_(dim=0, index=indices, src=confidences) + conf_bin = torch.nan_to_num(conf_bin / count_bin) + + acc_bin.scatter_add_(dim=0, index=indices, src=accuracies) + acc_bin = torch.nan_to_num(acc_bin / count_bin) + + prop_bin = count_bin / count_bin.sum() if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) From b3a93627e4b51b6aea8c17e4057a9df2cd267ee0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 18 Jan 2022 16:19:51 +0100 Subject: [PATCH 02/13] device Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/functional/classification/calibration_error.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index c4999ffaaf5..48460e6c809 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -46,9 +46,9 @@ def _ce_compute( if norm not in {"l1", "l2", "max"}: raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. 
") - acc_bin = torch.zeros(len(bin_boundaries) - 1) - conf_bin = torch.zeros(len(bin_boundaries) - 1) - count_bin = torch.zeros(len(bin_boundaries) - 1) + acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) + conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) + count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) indices = torch.bucketize(confidences, bin_boundaries) - 1 From 0848523aef63c02b89127a70b2b076e94553838d Mon Sep 17 00:00:00 2001 From: Ashutosh Kumar Date: Wed, 19 Jan 2022 03:38:22 +0530 Subject: [PATCH 03/13] Remove deprecated functions, and warnings - Text (#773) * Remove deprecated functions, and warnings * Update links for docstring Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: Jirka Borovec --- docs/source/links.rst | 2 +- docs/source/references/functional.rst | 6 +-- docs/source/references/modules.rst | 6 +-- torchmetrics/__init__.py | 2 - torchmetrics/functional/__init__.py | 3 +- torchmetrics/functional/text/__init__.py | 2 +- torchmetrics/functional/text/bert.py | 16 -------- torchmetrics/functional/text/bleu.py | 23 ----------- torchmetrics/functional/text/cer.py | 17 +-------- torchmetrics/functional/text/sacre_bleu.py | 21 ----------- torchmetrics/functional/text/wer.py | 41 ++------------------ torchmetrics/text/__init__.py | 2 +- torchmetrics/text/bert.py | 16 -------- torchmetrics/text/bleu.py | 16 -------- torchmetrics/text/cer.py | 16 -------- torchmetrics/text/sacre_bleu.py | 21 ----------- torchmetrics/text/wer.py | 44 +--------------------- 17 files changed, 15 insertions(+), 239 deletions(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index eeaced1807c..d0d54d5223b 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -28,7 +28,7 @@ .. _sklearn averaging methods: https://scikit-learn.org/stable/modules/model_evaluation.html#multiclass-and-multilabel-classification .. _Cosine Similarity: https://en.wikipedia.org/wiki/Cosine_similarity .. _spearmans rank correlation coefficient: https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient -.. _WER: https://en.wikipedia.org/wiki/Word_error_rate +.. _WordErrorRate: https://en.wikipedia.org/wiki/Word_error_rate .. _FID: https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance .. _mean-squared-error: https://en.wikipedia.org/wiki/Mean_squared_error .. _SSIM: https://en.wikipedia.org/wiki/Structural_similarity diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 7f3a105a2ab..66d1e20c97a 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -507,10 +507,10 @@ translation_edit_rate [func] .. autofunction:: torchmetrics.functional.translation_edit_rate :noindex: -wer [func] -~~~~~~~~~~ +word_error_rate [func] +~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.wer +.. autofunction:: torchmetrics.functional.word_error_rate :noindex: word_information_lost [func] diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 2d6d375b228..65e5bb11802 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -678,10 +678,10 @@ TranslationEditRate .. autoclass:: torchmetrics.TranslationEditRate :noindex: -WER -~~~ +WordErrorRate +~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.WER +.. 
autoclass:: torchmetrics.WordErrorRate :noindex: WordInfoLost diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index c3c1b43f5da..75811b31cab 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -90,7 +90,6 @@ RetrievalRPrecision, ) from torchmetrics.text import ( # noqa: E402 - WER, BLEUScore, CharErrorRate, CHRFScore, @@ -187,7 +186,6 @@ "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", - "WER", "WordErrorRate", "CharErrorRate", "MatchErrorRate", diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 95660715f2c..fd2113295ed 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -76,7 +76,7 @@ from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score from torchmetrics.functional.text.squad import squad from torchmetrics.functional.text.ter import translation_edit_rate -from torchmetrics.functional.text.wer import wer, word_error_rate +from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE @@ -158,7 +158,6 @@ "stat_scores", "symmetric_mean_absolute_percentage_error", "translation_edit_rate", - "wer", "word_error_rate", "char_error_rate", "match_error_rate", diff --git a/torchmetrics/functional/text/__init__.py b/torchmetrics/functional/text/__init__.py index e4e0161443a..14b982f90eb 100644 --- a/torchmetrics/functional/text/__init__.py +++ b/torchmetrics/functional/text/__init__.py @@ -20,7 +20,7 @@ from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401 from torchmetrics.functional.text.squad import squad # noqa: F401 from torchmetrics.functional.text.ter import translation_edit_rate # noqa: F401 -from torchmetrics.functional.text.wer import wer, word_error_rate # noqa: F401 +from torchmetrics.functional.text.wer import word_error_rate # noqa: F401 from torchmetrics.functional.text.wil import word_information_lost # noqa: F401 from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401 from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 18f93b84064..ace4be8d914 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -19,11 +19,9 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor from torch.utils.data import DataLoader, Dataset -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE if _TRANSFORMERS_AUTO_AVAILABLE: @@ -457,13 +455,6 @@ def _rescale_metrics_with_baseline( return all_metrics[..., 0], all_metrics[..., 1], all_metrics[..., 2] -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def bert_score( preds: Union[List[str], Dict[str, Tensor]], target: Union[List[str], Dict[str, Tensor]], @@ -549,13 +540,6 @@ def bert_score( Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. 
- references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Raises: ValueError: If `len(preds) != len(target)`. diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 880ef4ed7b8..525e6347c12 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -18,14 +18,10 @@ # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter from typing import Callable, Sequence, Tuple, Union -from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor -from torchmetrics.utilities import _future_warning - def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter: """Counting how many times each word appears in a given text with ngram. @@ -146,13 +142,6 @@ def _bleu_score_compute( return bleu -@deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def bleu_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], @@ -174,13 +163,6 @@ def bleu_score( Return: Tensor with BLEU Score - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Example: >>> from torchmetrics.functional import bleu_score >>> preds = ['the cat is on the mat'] @@ -195,11 +177,6 @@ def bleu_score( [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warn( - "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." - " Warning will be removed in v0.8." - ) - preds_ = [preds] if isinstance(preds, str) else preds target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] diff --git a/torchmetrics/functional/text/cer.py b/torchmetrics/functional/text/cer.py index ba0bb47bfc4..59d1801e933 100644 --- a/torchmetrics/functional/text/cer.py +++ b/torchmetrics/functional/text/cer.py @@ -15,11 +15,9 @@ from typing import List, Tuple, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance -from torchmetrics.utilities import _future_warning def _cer_update( @@ -61,13 +59,6 @@ def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: """character error rate is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the @@ -75,16 +66,10 @@ def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings + Returns: Character error rate score - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. 
- references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Examples: >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 730ddcbecad..f96db5d5b29 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -40,15 +40,12 @@ import re from functools import partial from typing import Sequence -from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _REGEX_AVAILABLE AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -278,13 +275,6 @@ def _lower(line: str, lowercase: bool) -> str: return line -@deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def sacre_bleu_score( preds: Sequence[str], target: Sequence[Sequence[str]], @@ -314,13 +304,6 @@ def sacre_bleu_score( Return: Tensor with BLEU Score - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Example: >>> from torchmetrics.functional import sacre_bleu_score >>> preds = ['the cat is on the mat'] @@ -337,10 +320,6 @@ def sacre_bleu_score( [3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warn( - "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." - " Warning will be removed in v0.8." - ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") diff --git a/torchmetrics/functional/text/wer.py b/torchmetrics/functional/text/wer.py index 093a96f68b7..851d3896476 100644 --- a/torchmetrics/functional/text/wer.py +++ b/torchmetrics/functional/text/wer.py @@ -15,11 +15,9 @@ from typing import List, Tuple, Union import torch -from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance -from torchmetrics.utilities import _future_warning def _wer_update( @@ -63,17 +61,10 @@ def _wer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: - """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. This - value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the - performance of the ASR system with a WER of 0 being a perfect score. + """Word error rate (WordErrorRate_) is a common metric of the performance of an automatic speech recognition + system. 
This value indicates the percentage of words that were incorrectly predicted. The lower the value, the + better the performance of the ASR system with a WER of 0 being a perfect score. Args: preds: Transcription(s) to score as a string or list of strings @@ -82,13 +73,6 @@ def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) Returns: Word error rate score - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Examples: >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] @@ -97,22 +81,3 @@ def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) """ errors, total = _wer_update(preds, target) return _wer_compute(errors, total) - - -@deprecated(target=word_error_rate, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) -def wer( - predictions: Union[str, List[str]], - references: Union[str, List[str]], -) -> Tensor: - """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. - - .. deprecated:: v0.7 - Use :func:`torchmetrics.fuctional.word_error_rate`. Will be removed in v0.8. - - Examples: - >>> preds = ["this is the prediction", "there is an other sample"] - >>> target = ["this is the reference", "there is another one"] - >>> wer(preds=preds, target=target) - tensor(0.5000) - """ - return void(predictions, references) diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py index f027218e9f6..27e95d6ab6b 100644 --- a/torchmetrics/text/__init__.py +++ b/torchmetrics/text/__init__.py @@ -19,7 +19,7 @@ from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401 from torchmetrics.text.squad import SQuAD # noqa: F401 from torchmetrics.text.ter import TranslationEditRate # noqa: F401 -from torchmetrics.text.wer import WER, WordErrorRate # noqa: F401 +from torchmetrics.text.wer import WordErrorRate # noqa: F401 from torchmetrics.text.wil import WordInfoLost # noqa: F401 from torchmetrics.text.wip import WordInfoPreserved # noqa: F401 from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index fa34cc55657..99cc0f1563d 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -15,12 +15,10 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor from torchmetrics.functional.text.bert import _preprocess_text, bert_score from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE if _TRANSFORMERS_AUTO_AVAILABLE: @@ -203,13 +201,6 @@ def __init__( self.add_state("target_input_ids", [], dist_reduce_fx="cat") self.add_state("target_attention_mask", [], dist_reduce_fx="cat") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: List[str], target: List[str]) -> None: # type: ignore """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. 
@@ -219,13 +210,6 @@ def update(self, preds: List[str], target: List[str]) -> None: # type: ignore An iterable of predicted sentences. target: An iterable of reference sentences. - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ preds_dict = _preprocess_text( preds, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 8d104a464e6..31b0c6709dc 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -20,12 +20,10 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics import Metric from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn -from torchmetrics.utilities import _future_warning class BLEUScore(Metric): @@ -97,26 +95,12 @@ def __init__( self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - @deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/cer.py b/torchmetrics/text/cer.py index 8771caf7dc4..078bf58b6fc 100644 --- a/torchmetrics/text/cer.py +++ b/torchmetrics/text/cer.py @@ -15,12 +15,10 @@ from typing import Any, Callable, List, Optional, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.cer import _cer_compute, _cer_update from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning class CharErrorRate(Metric): @@ -86,26 +84,12 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Character Error Rate scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. 
""" errors, total = _cer_update(preds, target) self.errors += errors diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index aca3a7b9055..196b173d9ff 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -18,15 +18,12 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from typing import Any, Callable, Optional, Sequence -from warnings import warn -from deprecate import deprecated from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_update from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer from torchmetrics.text.bleu import BLEUScore -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _REGEX_AVAILABLE AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -103,10 +100,6 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - warn( - "Input order of targets and preds were changed to predictions firsts and targets \ - second in v0.7. Warning will be removed in v0.8" - ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") @@ -117,26 +110,12 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) - @deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index ac0a012102d..630359f8114 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -14,17 +14,15 @@ from typing import Any, Callable, List, Optional, Union import torch -from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.text.wer import _wer_compute, _wer_update from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning class WordErrorRate(Metric): r""" - Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. + Word error rate (WordErrorRate_) is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score. 
Word error rate can then be computed as: @@ -84,26 +82,12 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Word Error Rate scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ errors, total = _wer_update(preds, target) self.errors += errors @@ -116,29 +100,3 @@ def compute(self) -> Tensor: Word error rate score """ return _wer_compute(self.errors, self.total) - - -class WER(WordErrorRate): - r""" - Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. - - .. deprecated:: v0.7 - Use :class:`torchmetrics.WordErrorRate`. Will be removed in v0.8. - - Examples: - >>> preds = ["this is the prediction", "there is an other sample"] - >>> target = ["this is the reference", "there is another one"] - >>> metric = WER() - >>> metric(preds, target) - tensor(0.5000) - """ - - @deprecated(target=WordErrorRate, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) - def __init__( - self, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ) -> None: - void(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) From 43a226101e05ec8e5a5fe4ad64b90c44af1a3b34 Mon Sep 17 00:00:00 2001 From: Ashutosh Kumar Date: Wed, 19 Jan 2022 03:38:22 +0530 Subject: [PATCH 04/13] Remove deprecated functions, and warnings - Text (#773) * Remove deprecated functions, and warnings * Update links for docstring * chlog Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 34 +++++++++-------- docs/source/links.rst | 2 +- docs/source/references/functional.rst | 6 +-- docs/source/references/modules.rst | 6 +-- torchmetrics/__init__.py | 2 - torchmetrics/functional/__init__.py | 3 +- torchmetrics/functional/text/__init__.py | 2 +- torchmetrics/functional/text/bert.py | 16 -------- torchmetrics/functional/text/bleu.py | 23 ----------- torchmetrics/functional/text/cer.py | 17 +-------- torchmetrics/functional/text/sacre_bleu.py | 21 ----------- torchmetrics/functional/text/wer.py | 41 ++------------------ torchmetrics/text/__init__.py | 2 +- torchmetrics/text/bert.py | 16 -------- torchmetrics/text/bleu.py | 16 -------- torchmetrics/text/cer.py | 16 -------- torchmetrics/text/sacre_bleu.py | 21 ----------- torchmetrics/text/wer.py | 44 +--------------------- 18 files changed, 34 insertions(+), 254 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c29df43b51..6a35f752030 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +- Removed deprecated functions, and warnings in Text ([#773](https://github.com/PyTorchLightning/metrics/pull/773)) + * 
`functional.wer` + * `WER` + ### Fixed @@ -58,8 +62,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed IoU -> Jaccard Index ([#662](https://github.com/PyTorchLightning/metrics/pull/662)) - Renamed text WER metric ([#714](https://github.com/PyTorchLightning/metrics/pull/714)) - * `functional.wer` -> `functional.word_error_rate` - * `WER` -> `WordErrorRate` + * `functional.wer` -> `functional.word_error_rate` + * `WER` -> `WordErrorRate` - Renamed correlation coefficient classes: ([#710](https://github.com/PyTorchLightning/metrics/pull/710)) * `MatthewsCorrcoef` -> `MatthewsCorrCoef` * `PearsonCorrcoef` -> `PearsonCorrCoef` @@ -81,27 +85,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `SNR` -> `SignalNoiseRatio` * `SI_SNR` -> `ScaleInvariantSignalNoiseRatio` - Renamed F-score metrics: ([#731](https://github.com/PyTorchLightning/metrics/pull/731), [#740](https://github.com/PyTorchLightning/metrics/pull/740)) - * `functional.f1` -> `functional.f1_score` - * `F1` -> `F1Score` - * `functional.fbeta` -> `functional.fbeta_score` - * `FBeta` -> `FBetaScore` + * `functional.f1` -> `functional.f1_score` + * `F1` -> `F1Score` + * `functional.fbeta` -> `functional.fbeta_score` + * `FBeta` -> `FBetaScore` - Renamed Hinge metric: ([#734](https://github.com/PyTorchLightning/metrics/pull/734)) - * `functional.hinge` -> `functional.hinge_loss` - * `Hinge` -> `HingeLoss` + * `functional.hinge` -> `functional.hinge_loss` + * `Hinge` -> `HingeLoss` - Renamed image PSNR metrics ([#732](https://github.com/PyTorchLightning/metrics/pull/732)) * `functional.psnr` -> `functional.peak_signal_noise_ratio` * `PSNR` -> `PeakSignalNoiseRatio` - Renamed image PIT metric: ([#737](https://github.com/PyTorchLightning/metrics/pull/737)) - * `functional.pit` -> `functional.permutation_invariant_training` - * `PIT` -> `PermutationInvariantTraining` + * `functional.pit` -> `functional.permutation_invariant_training` + * `PIT` -> `PermutationInvariantTraining` - Renamed image SSIM metric: ([#747](https://github.com/PyTorchLightning/metrics/pull/747)) - * `functional.ssim` -> `functional.scale_invariant_signal_noise_ratio` - * `SSIM` -> `StructuralSimilarityIndexMeasure` + * `functional.ssim` -> `functional.scale_invariant_signal_noise_ratio` + * `SSIM` -> `StructuralSimilarityIndexMeasure` - Renamed detection `MAP` to `MeanAveragePrecision` metric ([#754](https://github.com/PyTorchLightning/metrics/pull/754)) - Renamed Fidelity & LPIPS image metric: ([#752](https://github.com/PyTorchLightning/metrics/pull/752)) - * `image.FID` -> `image.FrechetInceptionDistance` - * `image.KID` -> `image.KernelInceptionDistance` - * `image.LPIPS` -> `image.LearnedPerceptualImagePatchSimilarity` + * `image.FID` -> `image.FrechetInceptionDistance` + * `image.KID` -> `image.KernelInceptionDistance` + * `image.LPIPS` -> `image.LearnedPerceptualImagePatchSimilarity` ### Removed diff --git a/docs/source/links.rst b/docs/source/links.rst index eeaced1807c..d0d54d5223b 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -28,7 +28,7 @@ .. _sklearn averaging methods: https://scikit-learn.org/stable/modules/model_evaluation.html#multiclass-and-multilabel-classification .. _Cosine Similarity: https://en.wikipedia.org/wiki/Cosine_similarity .. _spearmans rank correlation coefficient: https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient -.. 
_WER: https://en.wikipedia.org/wiki/Word_error_rate +.. _WordErrorRate: https://en.wikipedia.org/wiki/Word_error_rate .. _FID: https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance .. _mean-squared-error: https://en.wikipedia.org/wiki/Mean_squared_error .. _SSIM: https://en.wikipedia.org/wiki/Structural_similarity diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 7f3a105a2ab..66d1e20c97a 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -507,10 +507,10 @@ translation_edit_rate [func] .. autofunction:: torchmetrics.functional.translation_edit_rate :noindex: -wer [func] -~~~~~~~~~~ +word_error_rate [func] +~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.wer +.. autofunction:: torchmetrics.functional.word_error_rate :noindex: word_information_lost [func] diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 2d6d375b228..65e5bb11802 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -678,10 +678,10 @@ TranslationEditRate .. autoclass:: torchmetrics.TranslationEditRate :noindex: -WER -~~~ +WordErrorRate +~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.WER +.. autoclass:: torchmetrics.WordErrorRate :noindex: WordInfoLost diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index c3c1b43f5da..75811b31cab 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -90,7 +90,6 @@ RetrievalRPrecision, ) from torchmetrics.text import ( # noqa: E402 - WER, BLEUScore, CharErrorRate, CHRFScore, @@ -187,7 +186,6 @@ "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", - "WER", "WordErrorRate", "CharErrorRate", "MatchErrorRate", diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 95660715f2c..fd2113295ed 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -76,7 +76,7 @@ from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score from torchmetrics.functional.text.squad import squad from torchmetrics.functional.text.ter import translation_edit_rate -from torchmetrics.functional.text.wer import wer, word_error_rate +from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE @@ -158,7 +158,6 @@ "stat_scores", "symmetric_mean_absolute_percentage_error", "translation_edit_rate", - "wer", "word_error_rate", "char_error_rate", "match_error_rate", diff --git a/torchmetrics/functional/text/__init__.py b/torchmetrics/functional/text/__init__.py index e4e0161443a..14b982f90eb 100644 --- a/torchmetrics/functional/text/__init__.py +++ b/torchmetrics/functional/text/__init__.py @@ -20,7 +20,7 @@ from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401 from torchmetrics.functional.text.squad import squad # noqa: F401 from torchmetrics.functional.text.ter import translation_edit_rate # noqa: F401 -from torchmetrics.functional.text.wer import wer, word_error_rate # noqa: F401 +from torchmetrics.functional.text.wer import word_error_rate # noqa: F401 from torchmetrics.functional.text.wil import word_information_lost # noqa: F401 from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401 from torchmetrics.utilities.imports import 
_NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 18f93b84064..ace4be8d914 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -19,11 +19,9 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor from torch.utils.data import DataLoader, Dataset -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE if _TRANSFORMERS_AUTO_AVAILABLE: @@ -457,13 +455,6 @@ def _rescale_metrics_with_baseline( return all_metrics[..., 0], all_metrics[..., 1], all_metrics[..., 2] -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def bert_score( preds: Union[List[str], Dict[str, Tensor]], target: Union[List[str], Dict[str, Tensor]], @@ -549,13 +540,6 @@ def bert_score( Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Raises: ValueError: If `len(preds) != len(target)`. diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 880ef4ed7b8..525e6347c12 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -18,14 +18,10 @@ # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter from typing import Callable, Sequence, Tuple, Union -from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor -from torchmetrics.utilities import _future_warning - def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter: """Counting how many times each word appears in a given text with ngram. @@ -146,13 +142,6 @@ def _bleu_score_compute( return bleu -@deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def bleu_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], @@ -174,13 +163,6 @@ def bleu_score( Return: Tensor with BLEU Score - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Example: >>> from torchmetrics.functional import bleu_score >>> preds = ['the cat is on the mat'] @@ -195,11 +177,6 @@ def bleu_score( [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warn( - "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." - " Warning will be removed in v0.8." 
- ) - preds_ = [preds] if isinstance(preds, str) else preds target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] diff --git a/torchmetrics/functional/text/cer.py b/torchmetrics/functional/text/cer.py index ba0bb47bfc4..59d1801e933 100644 --- a/torchmetrics/functional/text/cer.py +++ b/torchmetrics/functional/text/cer.py @@ -15,11 +15,9 @@ from typing import List, Tuple, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance -from torchmetrics.utilities import _future_warning def _cer_update( @@ -61,13 +59,6 @@ def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: """character error rate is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the @@ -75,16 +66,10 @@ def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings + Returns: Character error rate score - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Examples: >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 730ddcbecad..f96db5d5b29 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -40,15 +40,12 @@ import re from functools import partial from typing import Sequence -from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _REGEX_AVAILABLE AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -278,13 +275,6 @@ def _lower(line: str, lowercase: bool) -> str: return line -@deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def sacre_bleu_score( preds: Sequence[str], target: Sequence[Sequence[str]], @@ -314,13 +304,6 @@ def sacre_bleu_score( Return: Tensor with BLEU Score - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. 
- Example: >>> from torchmetrics.functional import sacre_bleu_score >>> preds = ['the cat is on the mat'] @@ -337,10 +320,6 @@ def sacre_bleu_score( [3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warn( - "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." - " Warning will be removed in v0.8." - ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") diff --git a/torchmetrics/functional/text/wer.py b/torchmetrics/functional/text/wer.py index 093a96f68b7..851d3896476 100644 --- a/torchmetrics/functional/text/wer.py +++ b/torchmetrics/functional/text/wer.py @@ -15,11 +15,9 @@ from typing import List, Tuple, Union import torch -from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance -from torchmetrics.utilities import _future_warning def _wer_update( @@ -63,17 +61,10 @@ def _wer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, -) def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: - """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. This - value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the - performance of the ASR system with a WER of 0 being a perfect score. + """Word error rate (WordErrorRate_) is a common metric of the performance of an automatic speech recognition + system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the + better the performance of the ASR system with a WER of 0 being a perfect score. Args: preds: Transcription(s) to score as a string or list of strings @@ -82,13 +73,6 @@ def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) Returns: Word error rate score - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. - Examples: >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] @@ -97,22 +81,3 @@ def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) """ errors, total = _wer_update(preds, target) return _wer_compute(errors, total) - - -@deprecated(target=word_error_rate, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) -def wer( - predictions: Union[str, List[str]], - references: Union[str, List[str]], -) -> Tensor: - """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. - - .. deprecated:: v0.7 - Use :func:`torchmetrics.fuctional.word_error_rate`. Will be removed in v0.8. 
- - Examples: - >>> preds = ["this is the prediction", "there is an other sample"] - >>> target = ["this is the reference", "there is another one"] - >>> wer(preds=preds, target=target) - tensor(0.5000) - """ - return void(predictions, references) diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py index f027218e9f6..27e95d6ab6b 100644 --- a/torchmetrics/text/__init__.py +++ b/torchmetrics/text/__init__.py @@ -19,7 +19,7 @@ from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401 from torchmetrics.text.squad import SQuAD # noqa: F401 from torchmetrics.text.ter import TranslationEditRate # noqa: F401 -from torchmetrics.text.wer import WER, WordErrorRate # noqa: F401 +from torchmetrics.text.wer import WordErrorRate # noqa: F401 from torchmetrics.text.wil import WordInfoLost # noqa: F401 from torchmetrics.text.wip import WordInfoPreserved # noqa: F401 from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index fa34cc55657..99cc0f1563d 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -15,12 +15,10 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor from torchmetrics.functional.text.bert import _preprocess_text, bert_score from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE if _TRANSFORMERS_AUTO_AVAILABLE: @@ -203,13 +201,6 @@ def __init__( self.add_state("target_input_ids", [], dist_reduce_fx="cat") self.add_state("target_attention_mask", [], dist_reduce_fx="cat") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: List[str], target: List[str]) -> None: # type: ignore """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. @@ -219,13 +210,6 @@ def update(self, preds: List[str], target: List[str]) -> None: # type: ignore An iterable of predicted sentences. target: An iterable of reference sentences. - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ preds_dict = _preprocess_text( preds, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 8d104a464e6..31b0c6709dc 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -20,12 +20,10 @@ from warnings import warn import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics import Metric from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn -from torchmetrics.utilities import _future_warning class BLEUScore(Metric): @@ -97,26 +95,12 @@ def __init__( self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - @deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. 
Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/cer.py b/torchmetrics/text/cer.py index 8771caf7dc4..078bf58b6fc 100644 --- a/torchmetrics/text/cer.py +++ b/torchmetrics/text/cer.py @@ -15,12 +15,10 @@ from typing import Any, Callable, List, Optional, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.cer import _cer_compute, _cer_update from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning class CharErrorRate(Metric): @@ -86,26 +84,12 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Character Error Rate scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ errors, total = _cer_update(preds, target) self.errors += errors diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index aca3a7b9055..196b173d9ff 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -18,15 +18,12 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from typing import Any, Callable, Optional, Sequence -from warnings import warn -from deprecate import deprecated from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_update from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer from torchmetrics.text.bleu import BLEUScore -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _REGEX_AVAILABLE AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -103,10 +100,6 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - warn( - "Input order of targets and preds were changed to predictions firsts and targets \ - second in v0.7. Warning will be removed in v0.8" - ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") @@ -117,26 +110,12 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) - @deprecated( - args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. 
Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - - .. deprecated:: v0.7 - Args: - translate_corpus: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index ac0a012102d..630359f8114 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -14,17 +14,15 @@ from typing import Any, Callable, List, Optional, Union import torch -from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.text.wer import _wer_compute, _wer_update from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning class WordErrorRate(Metric): r""" - Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. + Word error rate (WordErrorRate_) is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score. Word error rate can then be computed as: @@ -84,26 +82,12 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - deprecated_in="0.7", - remove_in="0.8", - stream=_future_warning, - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Word Error Rate scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - - .. deprecated:: v0.7 - Args: - predictions: - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - This argument is deprecated in favor of `target` and will be removed in v0.8. """ errors, total = _wer_update(preds, target) self.errors += errors @@ -116,29 +100,3 @@ def compute(self) -> Tensor: Word error rate score """ return _wer_compute(self.errors, self.total) - - -class WER(WordErrorRate): - r""" - Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. - - .. deprecated:: v0.7 - Use :class:`torchmetrics.WordErrorRate`. Will be removed in v0.8. 
- - Examples: - >>> preds = ["this is the prediction", "there is an other sample"] - >>> target = ["this is the reference", "there is another one"] - >>> metric = WER() - >>> metric(preds, target) - tensor(0.5000) - """ - - @deprecated(target=WordErrorRate, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) - def __init__( - self, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ) -> None: - void(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) From cfe5e87797e07cf6c429e99b07648db1ddfb3e4c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 19 Jan 2022 22:34:39 +0100 Subject: [PATCH 05/13] Fix Matthews correlation coefficient when the denominator is 0 (#781) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 +++ tests/classification/test_matthews_corrcoef.py | 7 +++++++ .../functional/classification/matthews_corrcoef.py | 10 +++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a35f752030..9df5c265111 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed check for available modules ([#772](https://github.com/PyTorchLightning/metrics/pull/772)) +- Fixed Matthews correlation coefficient when the denominator is 0 ([#781](https://github.com/PyTorchLightning/metrics/pull/781)) + + ## [0.7.0] - 2022-01-17 ### Added diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index 692c11a8419..b5cb8aae8ea 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -140,3 +140,10 @@ def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num "threshold": THRESHOLD, }, ) + + +def test_zero_case(): + """Cases where the denominator in the matthews corrcoef is 0, the score should return 0.""" + # Example where neither 1 or 2 is present in the target tensor + out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) + assert out == 0.0 diff --git a/torchmetrics/functional/classification/matthews_corrcoef.py b/torchmetrics/functional/classification/matthews_corrcoef.py index 00e396d97db..a0807699806 100644 --- a/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/torchmetrics/functional/classification/matthews_corrcoef.py @@ -37,7 +37,15 @@ def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: pk = confmat.sum(dim=0).float() c = torch.trace(confmat).float() s = confmat.sum().float() - return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk))) + + cov_ytyp = c * s - sum(tk * pk) + cov_ypyp = s ** 2 - sum(pk * pk) + cov_ytyt = s ** 2 - sum(tk * tk) + + if cov_ypyp * cov_ytyt == 0: + return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) + else: + return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp) def matthews_corrcoef( From ec75b14b59bb865fa9d9678785348cabf602b923 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 20 Jan 2022 14:50:23 +0100 Subject: [PATCH 06/13] fast and slow binning --- .../classification/calibration_error.py | 90 +++++++++++++++---- 1 file changed, 74 insertions(+), 16 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py 
From ec75b14b59bb865fa9d9678785348cabf602b923 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Thu, 20 Jan 2022 14:50:23 +0100
Subject: [PATCH 06/13] fast and slow binning

---
 .../classification/calibration_error.py | 90 +++++++++++++++----
 1 file changed, 74 insertions(+), 16 deletions(-)

diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py
index 48460e6c809..4797b253ff1 100644
--- a/torchmetrics/functional/classification/calibration_error.py
+++ b/torchmetrics/functional/classification/calibration_error.py
@@ -18,6 +18,71 @@
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
+
+
+def _slow_binning(
+    confidences: FloatTensor,
+    accuracies: FloatTensor,
+    bin_boundaries: FloatTensor
+) -> Tuple[FloatTensor, FloatTensor, FloatTensor]:
+    """
+    Compute calibration bins using for loops. Use for pytorch < 1.6
+    Args:
+        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
+
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+    """
+    conf_bin = torch.zeros_like(bin_boundaries)
+    acc_bin = torch.zeros_like(bin_boundaries)
+    prop_bin = torch.zeros_like(bin_boundaries)
+    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
+        # Calculated confidence and accuracy in each bin
+        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+        prop_in_bin = in_bin.float().mean()
+        if prop_in_bin.item() > 0:
+            acc_bin[i] = accuracies[in_bin].float().mean()
+            conf_bin[i] = confidences[in_bin].mean()
+            prop_bin[i] = prop_in_bin
+    return acc_bin, conf_bin, prop_bin
+
+
+def _fast_binning(
+    confidences: FloatTensor,
+    accuracies: FloatTensor,
+    bin_boundaries: FloatTensor
+) -> Tuple[FloatTensor, FloatTensor, FloatTensor]:
+    """
+    Compute calibration bins using torch.bucketize. Use for pytorch >= 1.6.
+
+    Args:
+        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
+
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+
+    """
+    acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
+    conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
+    count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
+
+    indices = torch.bucketize(confidences, bin_boundaries) - 1
+
+    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
+
+    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
+    conf_bin = torch.nan_to_num(conf_bin / count_bin)
+
+    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
+    acc_bin = torch.nan_to_num(acc_bin / count_bin)
+
+    prop_bin = count_bin / count_bin.sum()
+    return acc_bin, conf_bin, prop_bin
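Note: a toy walk-through of the vectorized binning introduced in the hunk above. This is an
illustrative sketch with made-up data and bin count, not part of the patch; torch.bucketize
needs PyTorch >= 1.6 and torch.nan_to_num needs >= 1.8:

    import torch

    confidences = torch.tensor([0.15, 0.85, 0.92, 0.55])
    accuracies = torch.tensor([0.0, 1.0, 1.0, 0.0])
    bin_boundaries = torch.linspace(0, 1, 5 + 1)  # 5 bins over [0, 1]

    # One call assigns every sample to its bin, replacing the Python loop.
    indices = torch.bucketize(confidences, bin_boundaries) - 1  # tensor([0, 4, 4, 2])

    count_bin = torch.zeros(5).scatter_add_(0, indices, torch.ones_like(confidences))
    conf_bin = torch.zeros(5).scatter_add_(0, indices, confidences)
    acc_bin = torch.zeros(5).scatter_add_(0, indices, accuracies)

    # Empty bins divide 0 by 0; nan_to_num maps the resulting nans back to 0.
    acc_bin = torch.nan_to_num(acc_bin / count_bin)    # per-bin accuracy
    conf_bin = torch.nan_to_num(conf_bin / count_bin)  # per-bin mean confidence
    prop_bin = count_bin / count_bin.sum()             # per-bin sample proportion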
") - acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) - conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) - count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) - - indices = torch.bucketize(confidences, bin_boundaries) - 1 - - count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences)) - - conf_bin.scatter_add_(dim=0, index=indices, src=confidences) - conf_bin = torch.nan_to_num(conf_bin / count_bin) - - acc_bin.scatter_add_(dim=0, index=indices, src=accuracies) - acc_bin = torch.nan_to_num(acc_bin / count_bin) - - prop_bin = count_bin / count_bin.sum() + if _TORCH_GREATER_EQUAL_1_6: + acc_bin, conf_bin, prop_bin = _fast_binning( + confidences, accuracies, bin_boundaries + ) + else: + acc_bin, conf_bin, prop_bin = _slow_binning( + confidences, accuracies, bin_boundaries + ) if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) @@ -90,7 +148,7 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor] ValueError: If the dataset shape is not binary, multiclass, or multidimensional-multiclass. Returns: - Tuple[FloatTensor, FloatTensor]: [description] + tuple with confidences and accuracies """ _, _, mode = _input_format_classification(preds, target) From 102d772318ea0b8c6d153b059e786a67839ab444 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jan 2022 13:51:23 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/calibration_error.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 4797b253ff1..777bd340017 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -22,9 +22,7 @@ def _slow_binning( - confidences: FloatTensor, - accuracies: FloatTensor, - bin_boundaries: FloatTensor + confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor ) -> Tuple[FloatTensor, FloatTensor, FloatTensor]: """ Compute calibration bins using for loops. Use for pytorch < 1.6 @@ -51,12 +49,9 @@ def _slow_binning( def _fast_binning( - confidences: FloatTensor, - accuracies: FloatTensor, - bin_boundaries: FloatTensor + confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor ) -> Tuple[FloatTensor, FloatTensor, FloatTensor]: - """ - Compute calibration bins using torch.bucketize. Use for pytorch >= 1.6. + """Compute calibration bins using torch.bucketize. Use for pytorch >= 1.6. Args: confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction. @@ -65,7 +60,6 @@ def _fast_binning( Returns: tuple with binned accuracy, binned confidence and binned probabilities - """ acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) @@ -112,13 +106,9 @@ def _ce_compute( raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. 
") if _TORCH_GREATER_EQUAL_1_6: - acc_bin, conf_bin, prop_bin = _fast_binning( - confidences, accuracies, bin_boundaries - ) + acc_bin, conf_bin, prop_bin = _fast_binning(confidences, accuracies, bin_boundaries) else: - acc_bin, conf_bin, prop_bin = _slow_binning( - confidences, accuracies, bin_boundaries - ) + acc_bin, conf_bin, prop_bin = _slow_binning(confidences, accuracies, bin_boundaries) if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) From 0163b3aba4a71749881c4619c05c7dc10f37f03d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 20 Jan 2022 14:52:34 +0100 Subject: [PATCH 08/13] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9df5c265111..9b5e8310422 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Use torch.bucketize in calibration error when torch>1.6 for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769)) ### Deprecated From b77b7518f1289d8c2ca859e9ead7baa0c2a017ff Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 20 Jan 2022 15:38:19 +0100 Subject: [PATCH 09/13] Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- torchmetrics/functional/classification/calibration_error.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 777bd340017..c555066dc5f 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -61,9 +61,9 @@ def _fast_binning( Returns: tuple with binned accuracy, binned confidence and binned probabilities """ - acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) - conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) - count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device) + acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) + conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) + count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) indices = torch.bucketize(confidences, bin_boundaries) - 1 From f6741ae35cbffde62f5c0b0be72a7e0d03468796 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 20 Jan 2022 15:44:19 +0100 Subject: [PATCH 10/13] cleaning --- .../classification/calibration_error.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index c555066dc5f..5c0cd79d0b2 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -21,15 +21,15 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -def _slow_binning( - confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor -) -> Tuple[FloatTensor, FloatTensor, FloatTensor]: +def _binning_with_loop( + confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: """ Compute calibration bins using for loops. 
From f6741ae35cbffde62f5c0b0be72a7e0d03468796 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 20 Jan 2022 15:44:19 +0100
Subject: [PATCH 10/13] cleaning

---
 .../classification/calibration_error.py | 58 +++++++++----------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py
index c555066dc5f..5c0cd79d0b2 100644
--- a/torchmetrics/functional/classification/calibration_error.py
+++ b/torchmetrics/functional/classification/calibration_error.py
@@ -21,15 +21,15 @@
 from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6


-def _slow_binning(
-    confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor
-) -> Tuple[FloatTensor, FloatTensor, FloatTensor]:
+def _binning_with_loop(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Compute calibration bins using for loops. Use for pytorch < 1.6
     Args:
-        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
-        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
-        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.

     Returns:
         tuple with binned accuracy, binned confidence and binned probabilities
@@ -48,15 +48,15 @@
     return acc_bin, conf_bin, prop_bin


-def _fast_binning(
-    confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor
-) -> Tuple[FloatTensor, FloatTensor, FloatTensor]:
+def _binning_bucketize(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
     """Compute calibration bins using torch.bucketize. Use for pytorch >= 1.6.

     Args:
-        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
-        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
-        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.

     Returns:
         tuple with binned accuracy, binned confidence and binned probabilities
     """
@@ -80,20 +80,20 @@
 def _ce_compute(
-    confidences: FloatTensor,
-    accuracies: FloatTensor,
-    bin_boundaries: FloatTensor,
+    confidences: Tensor,
+    accuracies: Tensor,
+    bin_boundaries: Tensor,
     norm: str = "l1",
     debias: bool = False,
 ) -> Tensor:
     """Computes the calibration error given the provided bin boundaries and norm.

     Args:
-        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
-        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
-        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
-        norm (str, optional): Norm function to use when computing calibration error. Defaults to "l1".
-        debias (bool, optional): Apply debiasing to L2 norm computation as in
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+        norm: Norm function to use when computing calibration error. Defaults to "l1".
+        debias: Apply debiasing to L2 norm computation as in
         `Verified Uncertainty Calibration`_. Defaults to False.

     Raises:
@@ -106,9 +106,9 @@
         raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")
") if _TORCH_GREATER_EQUAL_1_6: - acc_bin, conf_bin, prop_bin = _fast_binning(confidences, accuracies, bin_boundaries) + acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) else: - acc_bin, conf_bin, prop_bin = _slow_binning(confidences, accuracies, bin_boundaries) + acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries) if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) @@ -126,13 +126,13 @@ def _ce_compute( return ce -def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]: +def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their correctness. Args: - preds (Tensor): Input softmaxed predictions. - target (Tensor): Labels. + preds: Input softmaxed predictions. + target: Labels. Raises: ValueError: If the dataset shape is not binary, multiclass, or multidimensional-multiclass. @@ -190,10 +190,10 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str L2-norm debiasing is not yet supported. Args: - preds (Tensor): Model output probabilities. - target (Tensor): Ground-truth target class labels. - n_bins (int, optional): Number of bins to use when computing t. Defaults to 15. - norm (str, optional): Norm used to compare empirical and expected probability bins. + preds: Model output probabilities. + target: Ground-truth target class labels. + n_bins: Number of bins to use when computing t. Defaults to 15. + norm: Norm used to compare empirical and expected probability bins. Defaults to "l1", or Expected Calibration Error. """ if norm not in ("l1", "l2", "max"): From 7ff78c4ef517e780e83651db268fa3a639ba7797 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 20 Jan 2022 15:46:41 +0100 Subject: [PATCH 11/13] flake8 --- torchmetrics/functional/classification/calibration_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 5c0cd79d0b2..321205937a7 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -14,7 +14,7 @@ from typing import Tuple import torch -from torch import FloatTensor, Tensor +from torch import Tensor from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType From 8f4ec578b7727d59826a6962072c9052099eee79 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 20 Jan 2022 16:16:21 +0100 Subject: [PATCH 12/13] increase to 1.8 --- torchmetrics/functional/classification/calibration_error.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 321205937a7..2a5748e7163 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -18,7 +18,7 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 def _binning_with_loop( @@ -105,7 +105,7 @@ def _ce_compute( if norm not in {"l1", "l2", "max"}: raise 
         raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")

-    if _TORCH_GREATER_EQUAL_1_6:
+    if _TORCH_GREATER_EQUAL_1_8:
         acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)
     else:
         acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries)

From 97be5af0ba71f215dfeb7cd64cb106fe78b46524 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 20 Jan 2022 20:14:42 +0100
Subject: [PATCH 13/13] chlog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b5ec016459f..a415e0ae348 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Changed

-- Use torch.bucketize in calibration error when torch>1.6 for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769))
+- Used `torch.bucketize` in calibration error when `torch>1.8` for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769))
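Note on the 1.6 -> 1.8 bump in the last two patches: torch.bucketize itself shipped with
PyTorch 1.6, but torch.nan_to_num, which the bucketized path also relies on, only arrived
in PyTorch 1.8, which is likely the reason the gate moved. A minimal sketch of the dispatch
pattern, assuming the two helpers from the patches above are importable (the real flag lives
in torchmetrics.utilities.imports):

    import torch
    from packaging import version

    _TORCH_GREATER_EQUAL_1_8 = version.parse(torch.__version__) >= version.parse("1.8.0")

    def _binned_stats(confidences, accuracies, bin_boundaries):
        """Dispatch to the vectorized path when available, else the loop fallback."""
        if _TORCH_GREATER_EQUAL_1_8:
            return _binning_bucketize(confidences, accuracies, bin_boundaries)
        return _binning_with_loop(confidences, accuracies, bin_boundaries)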