Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pre_training_5' into pre_training_5
Browse files Browse the repository at this point in the history
  • Loading branch information
sam1373 committed Dec 21, 2021
2 parents ca0ca7f + 62f4bc9 commit c899811
Show file tree
Hide file tree
Showing 94 changed files with 1,768 additions and 922 deletions.
25 changes: 15 additions & 10 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ pipeline {
parallel {
stage('En TN grammars') {
steps {
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py "1" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12'
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py "1" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15'
}
}
stage('En ITN grammars') {
steps {
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/inverse_text_normalization/inverse_normalize.py --language en "twenty" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12'
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/inverse_text_normalization/inverse_normalize.py --language en "twenty" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15'
}
}
stage('German ITN and non-deterministic TN') {
Expand All @@ -131,8 +131,8 @@ pipeline {
}
stage('Test En non-deterministic TN & Run all En TN/ITN tests (restore grammars from cache)') {
steps {
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize_with_audio.py --text "\$.01" --n_tagged 2 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12'
sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/en/ -m "not pleasefixme" --cpu --tn_cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12'
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize_with_audio.py --text "\$.01" --n_tagged 2 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15'
sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/en/ -m "not pleasefixme" --cpu --tn_cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15'
}
}
stage('Run Ru ITN and non-deterministic TN & Run all Ru ITN tests') {
Expand All @@ -156,7 +156,7 @@ pipeline {
parallel {
stage('L2: Eng TN') {
steps {
sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_norm/output/ --grammars=tn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12 --language=en && ls -R /home/TestData/nlp/text_norm/output/ && echo ".far files created "|| exit 1'
sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_norm/output/ --grammars=tn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15 --language=en && ls -R /home/TestData/nlp/text_norm/output/ && echo ".far files created "|| exit 1'
sh 'cd nemo_text_processing/text_normalization/ && python run_predict.py --input=/home/TestData/nlp/text_norm/ci/test.txt --input_case="lower_cased" --language=en --output=/home/TestData/nlp/text_norm/output/test.pynini.txt --verbose'
sh 'cat /home/TestData/nlp/text_norm/output/test.pynini.txt'
sh 'cmp --silent /home/TestData/nlp/text_norm/output/test.pynini.txt /home/TestData/nlp/text_norm/ci/test_goal_py_12-10.txt || exit 1'
Expand All @@ -166,7 +166,7 @@ pipeline {

stage('L2: Eng ITN export') {
steps {
sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_denorm/output/ --grammars=itn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12 --language=en && ls -R /home/TestData/nlp/text_denorm/output/ && echo ".far files created "|| exit 1'
sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_denorm/output/ --grammars=itn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15 --language=en && ls -R /home/TestData/nlp/text_denorm/output/ && echo ".far files created "|| exit 1'
sh 'cd nemo_text_processing/inverse_text_normalization/ && python run_predict.py --input=/home/TestData/nlp/text_denorm/ci/test.txt --language=en --output=/home/TestData/nlp/text_denorm/output/test.pynini.txt --verbose'
sh 'cmp --silent /home/TestData/nlp/text_denorm/output/test.pynini.txt /home/TestData/nlp/text_denorm/ci/test_goal_py.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_denorm/output/*'
Expand All @@ -175,23 +175,23 @@ pipeline {
stage('L2: TN with Audio (audio and raw text)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12 --text "The total amounts to \\$4.76." \
python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15 --text "The total amounts to \\$4.76." \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /tmp/out_raw.txt 2>&1 && \
cmp --silent /tmp/out_raw.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
}
}
stage('L2: TN with Audio (audio and text file)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12 --text /home/TestData/nlp/text_norm/audio_based/text.txt \
python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15 --text /home/TestData/nlp/text_norm/audio_based/text.txt \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /tmp/out_file.txt 2>&1 && \
cmp --silent /tmp/out_file.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
}
}
stage('L2: TN with Audio (manifest)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --language=en --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-12'
python normalize_with_audio.py --language=en --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/12-15'
}
}
}
Expand Down Expand Up @@ -2054,7 +2054,12 @@ pipeline {
}
failFast true
steps {
sh 'CUDA_VISIBLE_DEVICES=0 python examples/asr/speech_to_text_infer.py --asr_model QuartzNet15x5Base-En --dataset /home/TestData/librispeech/librivox-dev-other.json --wer_tolerance 0.1012 --batch_size 64'
sh 'CUDA_VISIBLE_DEVICES=0 python examples/asr/speech_to_text_eval.py \
pretrained_name=QuartzNet15x5Base-En \
dataset_manifest=/home/TestData/librispeech/librivox-dev-other.json \
batch_size=64 \
tolerance=0.1012'
sh 'rm -f examples/asr/evaluation_transcripts.json'
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,15 @@ Megatron GPT training requires NVIDIA Apex to be installed.
git clone https://github.com/NVIDIA/apex
cd apex
git checkout 14ccf5986401104121d0ef286a29386904af3bb7
git checkout 05f2d96baf9387c271134e292c811c3d94ed5fd2
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Docker containers:
~~~~~~~~~~~~~~~~~~
To build a nemo container with Dockerfile from a branch, please run

.. code-block:: bash
DOCKER_BUILDKIT=1 docker build -f Dockerfile -t nemo:latest .
Expand Down
2 changes: 1 addition & 1 deletion examples/asr/experimental/sclite/speech_to_text_sclite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""
This script is based on speech_to_text_infer.py and allows you to score the hypotheses
This script is based on speech_to_text_eval.py and allows you to score the hypotheses
with sclite. A local installation from https://github.com/usnistgov/SCTK is required.
Hypotheses and references are first saved in trn format and are scored after applying a glm
file (if provided).
Expand Down
156 changes: 156 additions & 0 deletions examples/asr/speech_to_text_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to compute the Word or Character Error Rate of a given ASR model for a given manifest file for some dataset.
The manifest file must conform to standard ASR definition - containing `audio_filepath` and `text` as the ground truth.
Note: This script depends on the `transcribe_speech.py` script, and therefore both scripts should be located in the
same directory during execution.
# Arguments
<< All arguments of `transcribe_speech.py` are inherited by this script, so please refer to `transcribe_speech.py`
for full list of arguments >>
dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format)
output_filename: Optional - output filename where the transcriptions will be written.
use_cer: Bool, whether to compute CER or WER
tolerance: Float, minimum WER/CER required to pass some arbitrary tolerance.
only_score_manifest: Bool, when set will skip audio transcription and just calculate WER of provided manifest.
# Usage
## To score a dataset with a manifest file that does not contain previously transcribed `pred_text`.
python speech_to_text_eval.py \
model_path=null \
pretrained_name=null \
dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \
output_filename=<Optional: Some output filename which will hold the transcribed text as a manifest> \
batch_size=32 \
amp=True \
use_cer=False
## To score a manifest file which has been previously augmented with transcribed text as `pred_text`
This is useful when one uses `transcribe_speech_parallel.py` to transcribe larger datasets, and results are written
to a manifest which has the two keys `text` (for ground truth) and `pred_text` (for model's transcription)
python speech_to_text_eval.py \
dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \
use_cer=False \
only_score_manifest=True
"""

import json
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import torch
import transcribe_speech
from omegaconf import MISSING, OmegaConf, open_dict

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class EvaluationConfig(transcribe_speech.TranscriptionConfig):
dataset_manifest: str = MISSING
output_filename: Optional[str] = "evaluation_transcripts.json"

use_cer: bool = False
tolerance: Optional[float] = None

only_score_manifest: bool = False


@hydra_runner(config_name="EvaluationConfig", schema=EvaluationConfig)
def main(cfg: EvaluationConfig):
torch.set_grad_enabled(False)

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.audio_dir is not None:
raise RuntimeError(
"Evaluation script requires ground truth labels to be passed via a manifest file. "
"If manifest file is available, submit it via `dataset_manifest` argument."
)

if not os.path.exists(cfg.dataset_manifest):
raise FileNotFoundError(f"The dataset manifest file could not be found at path : {cfg.dataset_manifest}")

if not cfg.only_score_manifest:
# Transcribe speech into an output directory
transcription_cfg = transcribe_speech.main(cfg) # type: EvaluationConfig

# Release GPU memory if it was used during transcription
if torch.cuda.is_available():
torch.cuda.empty_cache()

logging.info("Finished transcribing speech dataset. Computing ASR metrics..")

else:
cfg.output_filename = cfg.dataset_manifest
transcription_cfg = cfg

ground_truth_text = []
predicted_text = []
invalid_manifest = False
with open(transcription_cfg.output_filename, 'r') as f:
for line in f:
data = json.loads(line)

if 'pred_text' not in data:
invalid_manifest = True
break

ground_truth_text.append(data['text'])
predicted_text.append(data['pred_text'])

# Test for invalid manifest supplied
if invalid_manifest:
raise ValueError(
f"Invalid manifest provided: {transcription_cfg.output_filename} does not "
f"contain value for `pred_text`."
)

# Compute the WER
metric_name = 'CER' if cfg.use_cer else 'WER'
metric_value = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=cfg.use_cer)

if cfg.tolerance is not None:
if metric_value > cfg.tolerance:
raise ValueError(f"Got {metric_name} of {metric_value}, which was higher than tolerance={cfg.tolerance}")

logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}')
else:
logging.info(f'Got {metric_name} of {metric_value}')

# Inject the metric name and score into the config, and return the entire config
with open_dict(cfg):
cfg.metric_name = metric_name
cfg.metric_value = metric_value

return cfg


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
115 changes: 0 additions & 115 deletions examples/asr/speech_to_text_infer.py

This file was deleted.

Loading

0 comments on commit c899811

Please sign in to comment.