Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Text Normalization Update #2356

Merged
merged 8 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo_text_processing/text_normalization/data/months/abbr.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ jan january
feb february
mar march
apr april
jun june
jun june
jul july
aug august
sep september
sept september
oct october
nov november
dec december
dec december
13 changes: 13 additions & 0 deletions nemo_text_processing/text_normalization/data/roman/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
49 changes: 49 additions & 0 deletions nemo_text_processing/text_normalization/data/roman/digit_teen.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
i 1
ii 2
iii 3
iv 4
v 5
vi 6
vii 7
viii 8
ix 9
x 10
xi 11
xii 12
xiii 13
xiv 14
xv 15
xvi 16
xvii 17
xviii 18
xix 19
xx 20
xxi 21
xxii 22
xxiii 23
xxiv 24
xxv 25
xxvi 26
xxvii 27
xxviii 28
xxix 29
xxx 30
xxxi 31
xxxii 32
xxxiii 33
xxxiv 34
xxxv 35
xxxvi 36
xxxvii 37
xxxviii 38
xxxix 39
xl 40
xli 41
xlii 42
xliii 43
xliv 44
xlv 45
xlvi 46
xlvii 47
xlviii 48
xlix 49
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
c 100
cc 200
ccc 300
cd 400
d 500
dc 600
dcc 700
dccc 800
cm 900
5 changes: 5 additions & 0 deletions nemo_text_processing/text_normalization/data/roman/ties.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
l 50
lx 60
lxx 70
lxxx 80
xc 90
ekmb marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions nemo_text_processing/text_normalization/data/whitelist.tsv
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Ph.D. p h d
Hon. honorable
& and
Mt. Mount
Maj. Major
Rev. Reverend
# hash
Gov. governor
7-eleven seven eleven
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ Mrs. Misses
Ms. Miss
Mr Mister
Mrs Misses
Ms Miss
Ms Miss
&Co. and Co.
§ section
= equals
21 changes: 21 additions & 0 deletions nemo_text_processing/text_normalization/data_loader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import csv
import json
import os
import re
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Set, Tuple

Expand Down Expand Up @@ -241,8 +242,28 @@ def post_process_punctuation(text: str) -> str:
.replace('“', '"')
.replace("‘", "'")
.replace('`', "'")
.replace('- -', "--")
)

for punct in "!,.:;?":
text = text.replace(f' {punct}', punct)
return text.strip()


def pre_process(text: str) -> str:
"""
Adds space around punctuation marks
Args:
text: string that may include semiotic classes
Returns: text with spaces around punctuation marks
"""
space_both = '*<=>^[]{}'
for punct in space_both:
text = text.replace(punct, ' ' + punct + ' ')

text = text.replace('--', ' ' + '--' + ' ')
# remove extra space
text = re.sub(r' +', ' ', text)
return text
23 changes: 19 additions & 4 deletions nemo_text_processing/text_normalization/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import OrderedDict
from typing import List

from nemo_text_processing.text_normalization.data_loader_utils import post_process_punctuation
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punctuation, pre_process
from nemo_text_processing.text_normalization.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
from nemo_text_processing.text_normalization.verbalizers.verbalize_final import VerbalizeFinalFst
Expand Down Expand Up @@ -67,18 +67,23 @@ def normalize_list(self, texts: List[str], verbose=False) -> List[str]:
res.append(text)
return res

def normalize(self, text: str, verbose: bool, punct_post_process: bool = False) -> str:
def normalize(
self, text: str, verbose: bool, punct_pre_process: bool = False, punct_post_process: bool = False
) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms
Args:
text: string that may include semiotic classes
verbose: whether to print intermediate meta information
punct_post_process: set to True to normalize punctuation
punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
punct_post_process: whether to normalize punctuation
Returns: spoken form
"""
if punct_pre_process:
text = pre_process(text)
text = text.strip()
if not text:
if verbose:
Expand Down Expand Up @@ -222,10 +227,20 @@ def parse_args():
parser.add_argument(
"--punct_post_process", help="set to True to enable punctuation post processing", action="store_true"
)
parser.add_argument(
"--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
normalizer = Normalizer(input_case=args.input_case)
print(normalizer.normalize(args.input_string, verbose=args.verbose, punct_post_process=args.punct_post_process))
print(
normalizer.normalize(
args.input_string,
verbose=args.verbose,
punct_pre_process=args.punct_pre_process,
punct_post_process=args.punct_post_process,
)
)
98 changes: 48 additions & 50 deletions nemo_text_processing/text_normalization/normalize_with_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@

import json
import os
import re
import time
from argparse import ArgumentParser
from typing import List, Tuple

from nemo_text_processing.text_normalization.data_loader_utils import post_process_punctuation
from joblib import Parallel, delayed
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punctuation, pre_process
from nemo_text_processing.text_normalization.normalize import Normalizer
from nemo_text_processing.text_normalization.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.text_normalization.verbalizers.verbalize_final import VerbalizeFinalFst
from tqdm import tqdm

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel
Expand Down Expand Up @@ -79,20 +78,30 @@ def __init__(self, input_case: str):
self.tagger = ClassifyFst(input_case=input_case, deterministic=False)
self.verbalizer = VerbalizeFinalFst(deterministic=False)

def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False) -> str:
def normalize(
self,
text: str,
n_tagged: int,
punct_pre_process: bool = True,
punct_post_process: bool = True,
verbose: bool = False,
) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms
Args:
text: string that may include semiotic classes
n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
punct_post_process: whether to normalize punctuation
verbose: whether to print intermediate meta information
Returns:
normalized text options (usually there are multiple ways of normalizing a given semiotic class)
"""
if punct_pre_process:
text = pre_process(text)
text = text.strip()
if not text:
if verbose:
Expand All @@ -108,7 +117,6 @@ def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, v
normalized_texts = []
for tagged_text in tagged_texts:
self._verbalize(tagged_text, normalized_texts)

if len(normalized_texts) == 0:
raise ValueError()
if punct_post_process:
Expand Down Expand Up @@ -183,36 +191,12 @@ def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=Fal
text_clean = text.replace('-', ' ').lower()
if remove_punct:
for punct in "!?:;,.-()*+-/<=>@^_":
text_clean = text_clean.replace(punct, " ")
text_clean = re.sub(r' +', ' ', text_clean)
text_clean = text_clean.replace(punct, "")
cer = round(word_error_rate([transcript], [text_clean], use_cer=True) * 100, 2)
normalized_options.append((text, cer))
return normalized_options


def pre_process(text: str) -> str:
"""
Adds space around punctuation marks
Args:
text: string that may include semiotic classes
Returns: text with spaces around punctuation marks
"""
text = text.replace('--', '-')
space_right = '!?:;,.-()*+-/<=>@^_'
space_both = '-()*+-/<=>@^_'

for punct in space_right:
text = text.replace(punct, punct + ' ')
for punct in space_both:
text = text.replace(punct, ' ' + punct + ' ')

# remove extra space
text = re.sub(r' +', ' ', text)
return text


def get_asr_model(asr_model: ASRModel):
"""
Returns ASR Model
Expand Down Expand Up @@ -249,12 +233,36 @@ def parse_args():
)
parser.add_argument("--verbose", help="print info for debugging", action="store_true")
parser.add_argument("--remove_punct", help="remove punctuation before calculating cer", action="store_true")
parser.add_argument(
"--no_punct_pre_process", help="set to True to disable punctuation pre processing", action="store_true"
)
parser.add_argument(
"--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
)
return parser.parse_args()


def _normalize_line(normalizer: NormalizerWithAudio, line: str, asr_model: ASRModel = None):
line = json.loads(line)
audio = line['audio_filepath']
if 'transcript' in line:
transcript = line['transcript']
else:
transcript = asr_model.transcribe([audio])[0]

normalized_texts = normalizer.normalize(
text=line['text'],
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_pre_process=not args.no_punct_pre_process,
punct_post_process=not args.no_punct_post_process,
)
normalized_text, cer = normalizer.select_best_match(normalized_texts, transcript, args.verbose, args.remove_punct)
line['nemo_normalized'] = normalized_text
line['CER_nemo_normalized'] = cer
return line


def normalize_manifest(args):
"""
Args:
Expand All @@ -265,26 +273,15 @@ def normalize_manifest(args):
asr_model = None
with open(args.audio_data, 'r') as f:
with open(manifest_out, 'w') as f_out:
for line in tqdm(f):
line = json.loads(line)
audio = line['audio_filepath']
if 'transcript' in line:
transcript = line['transcript']
else:
if asr_model is None:
asr_model = get_asr_model(args.model)
transcript = asr_model.transcribe([audio])[0]
normalized_texts = normalizer.normalize(
text=line['text'],
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_post_process=not args.no_punct_post_process,
)
normalized_text, cer = normalizer.select_best_match(
normalized_texts, transcript, args.verbose, args.remove_punct
)
line['nemo_normalized'] = normalized_text
line['CER_nemo_normalized'] = cer
lines = f.readlines()
first_line = json.loads(lines[0])
if 'transcript' not in first_line:
asr_model = get_asr_model(args.model)
normalized_lines = Parallel(n_jobs=-1)(
delayed(_normalize_line)(normalizer, line, asr_model) for line in lines
)

for line in normalized_lines:
f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
print(f'Normalized version saved at {manifest_out}')

Expand All @@ -302,6 +299,7 @@ def normalize_manifest(args):
text=args.text,
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_pre_process=not args.no_punct_pre_process,
punct_post_process=not args.no_punct_post_process,
)
if args.audio_data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_serial_graph(self):
letter_pronunciation = pynini.string_map(load_labels(get_abs_path("data/letter_pronunciation.tsv")))
alpha |= letter_pronunciation

delimiter = insert_space | pynini.cross("-", " ")
delimiter = insert_space | pynini.cross("-", " ") | pynini.cross("/", " ")
letter_num = pynini.closure(alpha + delimiter, 1) + num_graph
num_letter = pynini.closure(num_graph + delimiter, 1) + alpha
next_alpha_or_num = pynini.closure(delimiter + (alpha | num_graph))
Expand Down
Loading