
Commit

Add wer details - insertion, deletion, substitution rate (#5557)
* add wer insertion, deletion, substitution

Signed-off-by: fayejf <[email protected]>

* style

Signed-off-by: fayejf <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Levenshtein -> jiwer due to license

Signed-off-by: fayejf <[email protected]>

* typo

Signed-off-by: fayejf <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: fayejf <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
fayejf and pre-commit-ci[bot] authored Dec 7, 2022
1 parent 52aac8e commit f1c8714
Showing 3 changed files with 106 additions and 2 deletions.
74 changes: 73 additions & 1 deletion nemo/collections/asr/metrics/wer.py
@@ -17,6 +17,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import editdistance
+import jiwer
 import numpy as np
 import torch
 from omegaconf import DictConfig, OmegaConf
@@ -27,7 +28,7 @@
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
 from nemo.utils import logging

-__all__ = ['word_error_rate', 'WER', 'move_dimension_to_the_front']
+__all__ = ['word_error_rate', 'word_error_rate_detail', 'WER', 'move_dimension_to_the_front']


 def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float:
@@ -58,6 +59,8 @@ def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False)
             h_list = h.split()
             r_list = r.split()
         words += len(r_list)
+        # May deprecate editdistance here and in the rest of the codebase in a
+        # future release, once jiwer is confirmed to be reliable.
         scores += editdistance.eval(h_list, r_list)
     if words != 0:
         wer = 1.0 * scores / words
@@ -66,6 +69,75 @@ def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False)
     return wer


+def word_error_rate_detail(
+    hypotheses: List[str], references: List[str], use_cer=False
+) -> Tuple[float, int, float, float, float]:
+    """
+    Computes the average word error rate (WER) with details (insertion rate, deletion rate,
+    substitution rate) between two texts represented as corresponding lists of strings.
+    Hypotheses and references must have the same length.
+    Args:
+        hypotheses (list): list of hypotheses
+        references (list): list of references
+        use_cer (bool): set to True to compute character error rate (CER) instead of WER
+    Returns:
+        wer (float): average word error rate
+        words (int): total number of words/characters in the given reference texts
+        ins_rate (float): average insertion error rate
+        del_rate (float): average deletion error rate
+        sub_rate (float): average substitution error rate
+    """
+    scores = 0
+    words = 0
+    ops_count = {'substitutions': 0, 'insertions': 0, 'deletions': 0}

+    if len(hypotheses) != len(references):
+        raise ValueError(
+            "In word error rate calculation, hypotheses and reference"
+            " lists must have the same number of elements, but got"
+            " {0} and {1} respectively.".format(len(hypotheses), len(references))
+        )

+    for h, r in zip(hypotheses, references):
+        if use_cer:
+            h_list = list(h)
+            r_list = list(r)
+        else:
+            h_list = h.split()
+            r_list = r.split()

+        # jiwer does not accept an empty reference string, so handle that case
+        # manually: every hypothesis token then counts as an insertion.
+        if len(r_list) == 0:
+            errors = len(h_list)
+            ops_count['insertions'] += errors
+        else:
+            if use_cer:
+                measures = jiwer.cer(r, h, return_dict=True)
+            else:
+                measures = jiwer.compute_measures(r, h)

+            errors = measures['insertions'] + measures['deletions'] + measures['substitutions']
+            ops_count['insertions'] += measures['insertions']
+            ops_count['deletions'] += measures['deletions']
+            ops_count['substitutions'] += measures['substitutions']

+        scores += errors
+        words += len(r_list)

+    if words != 0:
+        wer = 1.0 * scores / words
+        ins_rate = 1.0 * ops_count['insertions'] / words
+        del_rate = 1.0 * ops_count['deletions'] / words
+        sub_rate = 1.0 * ops_count['substitutions'] / words
+    else:
+        wer, ins_rate, del_rate, sub_rate = float('inf'), float('inf'), float('inf'), float('inf')

+    return wer, words, ins_rate, del_rate, sub_rate


 def move_dimension_to_the_front(tensor, dim_index):
     all_dims = list(range(tensor.ndim))
     return tensor.permute(*([dim_index] + all_dims[:dim_index] + all_dims[dim_index + 1 :]))
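
For orientation, here is a minimal usage sketch of word_error_rate_detail as added above. The import path is the one exercised by the updated test file below; the example hypothesis/reference strings are purely illustrative:

    from nemo.collections.asr.metrics.wer import word_error_rate_detail

    # Illustrative, hand-made hypothesis/reference pairs, aligned element by element.
    hypotheses = ['i see cat', 'g p u']
    references = ['i saw a cat', 'gpu']

    wer, words, ins_rate, del_rate, sub_rate = word_error_rate_detail(
        hypotheses=hypotheses, references=references, use_cer=False
    )
    print(wer, words, ins_rate, del_rate, sub_rate)

Each per-category rate is normalized by the total number of reference words (or characters when use_cer=True), so insertions alone can push the reported WER above 1.0.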
1 change: 1 addition & 0 deletions requirements/requirements_asr.txt
@@ -3,6 +3,7 @@ editdistance
 g2p_en
 inflect
 ipywidgets
+jiwer
 kaldi-python-io
 kaldiio
 librosa>=0.9.0
33 changes: 32 additions & 1 deletion tests/collections/asr/test_asr_metrics.py
@@ -24,7 +24,13 @@

 from nemo.collections.asr.metrics.rnnt_wer import RNNTWER
 from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER
-from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig, word_error_rate
+from nemo.collections.asr.metrics.wer import (
+    WER,
+    CTCDecoding,
+    CTCDecodingConfig,
+    word_error_rate,
+    word_error_rate_detail,
+)
 from nemo.collections.asr.metrics.wer_bpe import WERBPE, CTCBPEDecoding, CTCBPEDecodingConfig
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from nemo.collections.common.tokenizers import CharTokenizer
@@ -103,6 +109,31 @@ def test_wer_function(self):
         assert word_error_rate(hypotheses=['ducati motorcycle'], references=['ducuti motorcycle']) == 0.5
         assert word_error_rate(hypotheses=['a B c'], references=['a b c']) == 1.0 / 3.0

+        assert word_error_rate_detail(hypotheses=['cat'], references=['cot'])[0] == 1.0
+        assert word_error_rate_detail(hypotheses=['GPU'], references=['G P U'])[0] == 1.0
+        assert word_error_rate_detail(hypotheses=['G P U'], references=['GPU'])[0] == 3.0
+        assert word_error_rate_detail(hypotheses=['ducati motorcycle'], references=['motorcycle'])[0] == 1.0
+        assert word_error_rate_detail(hypotheses=['ducati motorcycle'], references=['ducuti motorcycle'])[0] == 0.5
+        assert word_error_rate_detail(hypotheses=['a B c'], references=['a b c'])[0] == 1.0 / 3.0

+        assert word_error_rate_detail(hypotheses=['cat'], references=['']) == (
+            float("inf"),
+            0,
+            float("inf"),
+            float("inf"),
+            float("inf"),
+        )
+        assert word_error_rate_detail(hypotheses=['cat', ''], references=['', 'gpu']) == (2.0, 1, 1.0, 1.0, 0.0,)
+        assert word_error_rate_detail(hypotheses=['cat'], references=['cot']) == (1.0, 1, 0.0, 0.0, 1.0)
+        assert word_error_rate_detail(hypotheses=['G P U'], references=['GPU']) == (3.0, 1, 2.0, 0.0, 1.0)
+        assert word_error_rate_detail(hypotheses=[''], references=['ducuti motorcycle'], use_cer=True) == (
+            1.0,
+            17,
+            0.0,
+            1.0,
+            0.0,
+        )

     @pytest.mark.unit
     @pytest.mark.parametrize("batch_dim_index", [0, 1])
     @pytest.mark.parametrize("test_wer_bpe", [False, True])
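
As a sanity check on the ['G P U'] vs ['GPU'] assertion above, the expected tuple (3.0, 1, 2.0, 0.0, 1.0) can be reproduced by hand from jiwer's edit-operation counts. This is only a sketch, assuming a jiwer release that still exposes compute_measures (the same call word_error_rate_detail relies on):

    import jiwer

    ref, hyp = 'GPU', 'G P U'
    measures = jiwer.compute_measures(ref, hyp)  # 1 substitution, 2 insertions, 0 deletions
    ref_words = len(ref.split())                 # 1 reference word

    errors = measures['substitutions'] + measures['deletions'] + measures['insertions']
    wer = errors / ref_words                            # 3.0
    ins_rate = measures['insertions'] / ref_words       # 2.0
    del_rate = measures['deletions'] / ref_words        # 0.0
    sub_rate = measures['substitutions'] / ref_words    # 1.0
    print(wer, ref_words, ins_rate, del_rate, sub_rate)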
