diff --git a/nemo/collections/asr/metrics.py b/nemo/collections/asr/metrics.py
index b0795c189a47..8fd711a95e28 100644
--- a/nemo/collections/asr/metrics.py
+++ b/nemo/collections/asr/metrics.py
@@ -1,30 +1,13 @@
 # Copyright (c) 2019 NVIDIA Corporation
 from typing import List, Optional
 
+import editdistance
 import torch
 
 
-def __levenshtein(a: List, b: List) -> int:
-    """Calculates the Levenshtein distance between a and b.
-    The code was copied from: http://hetland.org/coding/python/levenshtein.py
-    """
-    n, m = len(a), len(b)
-    if n > m:
-        # Make sure n <= m, to use O(min(n,m)) space
-        a, b = b, a
-        n, m = m, n
-
-    current = list(range(n + 1))
-    for i in range(1, m + 1):
-        previous, current = current, [i] + [0] * n
-        for j in range(1, n + 1):
-            add, delete = previous[j] + 1, current[j - 1] + 1
-            change = previous[j - 1]
-            if a[j - 1] != b[i - 1]:
-                change = change + 1
-            current[j] = min(add, delete, change)
-
-    return current[n]
+def __levenshtein(a: List[str], b: List[str]) -> int:
+    """Calculates the Levenshtein distance between a and b."""
+    return editdistance.eval(a, b)
 
 
 def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float:
diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt
index 0edb80411639..e237ec515bd0 100644
--- a/requirements/requirements_asr.txt
+++ b/requirements/requirements_asr.txt
@@ -1,4 +1,5 @@
 braceexpand
+editdistance
 frozendict
 inflect
 kaldi-io