diff --git a/.gitignore b/.gitignore index dd0bbe49a749..d21e46cb2d43 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,7 @@ wandb dump.py docs/sources/source/test_build/ + +# Checkpoints, config files and temporary files created in tutorials. +examples/neural_graphs/*.chkpt +examples/neural_graphs/*.yml \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a00d22c74c6e..9a749af6e06b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,8 @@ To release a new version, please update the changelog as followed: - Online audio augmentation notebook in ASR examples ([PR #605](https://github.com/NVIDIA/NeMo/pull/605)) - @titu1994 - ContextNet Encoder + Decoder Initial Support ([PR #630](https://github.com/NVIDIA/NeMo/pull/630)) - @titu1994 - Added finetuning with Megatron-LM ([PR #601](https://github.com/NVIDIA/NeMo/pull/601)) - @ekmb +- Added documentation for 8 kHz model ([PR #632](https://github.com/NVIDIA/NeMo/pull/632)) - @jbalam-nv + ### Changed - Syncs across workers at each step to check for NaN or inf loss. Terminates all workers if stop\_on\_nan\_loss is set (as before), lets Apex deal with it if apex.amp optimization level is O1 or higher, and skips the step across workers otherwise. 
([PR #637](https://github.com/NVIDIA/NeMo/pull/637)) - @redoctopus diff --git a/Jenkinsfile b/Jenkinsfile index bf6cd48cb448..f76dea063f03 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -2,7 +2,7 @@ pipeline { agent { docker { image 'nvcr.io/nvidia/pytorch:20.01-py3' - args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home:/home -v $HOME/.cache/torch:/root/.cache/torch --shm-size=8g' + args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache/torch:/root/.cache/torch --shm-size=8g' } } options { @@ -193,11 +193,22 @@ pipeline { } stage ('Punctuation and Classification Training/Inference Test') { steps { - sh 'cd examples/nlp/token_classification && CUDA_VISIBLE_DEVICES=1 python punctuation_capitalization.py --data_dir /home/TestData/nlp/token_classification_punctuation/ --work_dir punctuation_output --save_epoch_freq 1 --num_epochs 1 --save_step_freq -1 --batch_size 2' + sh 'cd examples/nlp/token_classification && CUDA_VISIBLE_DEVICES=1 python punctuation_capitalization.py \ + --data_dir /home/TestData/nlp/token_classification_punctuation/ --work_dir punctuation_output --save_epoch_freq 1 \ + --num_epochs 1 --save_step_freq -1 --batch_size 2' sh 'cd examples/nlp/token_classification && DATE_F=$(ls punctuation_output/) && DATA_DIR="/home/TestData/nlp/token_classification_punctuation" && CUDA_VISIBLE_DEVICES=1 python punctuation_capitalization_infer.py --checkpoint_dir punctuation_output/$DATE_F/checkpoints/ --punct_labels_dict $DATA_DIR/punct_label_ids.csv --capit_labels_dict $DATA_DIR/capit_label_ids.csv' sh 'rm -rf examples/nlp/token_classification/punctuation_output' } } + stage('SGD Test') { + steps { + sh 'cd examples/nlp/dialogue_state_tracking && CUDA_VISIBLE_DEVICES=0 python dialogue_state_tracking_sgd.py \ + --data_dir /home/TestData/nlp/sgd/ --schema_embedding_dir /home/TestData/nlp/sgd/embeddings/ --eval_dataset dev \ + --dialogues_example_dir /home/TestData/nlp/sgd/dialogue_example_dir/ --work_dir 
sgd_output --task DEBUG \ + --num_epochs 1 --save_epoch_freq=0' + sh 'rm -rf examples/nlp/dialogue_state_tracking/sgd_output' + } + } } } @@ -355,7 +366,8 @@ pipeline { post { always { + sh "chmod -R 777 ." cleanWs() } } -} +} \ No newline at end of file diff --git a/docs/docs_zh/sources/source/speech_command/tutorial.rst b/docs/docs_zh/sources/source/speech_command/tutorial.rst index ab60fdf33717..0a901d6849c7 100644 --- a/docs/docs_zh/sources/source/speech_command/tutorial.rst +++ b/docs/docs_zh/sources/source/speech_command/tutorial.rst @@ -90,7 +90,7 @@ QuartzNet 模型使用一种固定的模型定义模式: QuartzNet-[BxR], 其 process_classification_evaluation_epoch, ) - logging = nemo.logging + from nemo.utils import logging # Lets define some hyper parameters lr = 0.05 @@ -420,7 +420,7 @@ QuartzNet 模型使用一种固定的模型定义模式: QuartzNet-[BxR], 其 import nemo import nemo.collections.asr as nemo_asr - logging = nemo.logging + from nemo.utils import logging # We add some data_dir = '' diff --git a/docs/sources/source/asr/8kHz_models.rst b/docs/sources/source/asr/8kHz_models.rst new file mode 100644 index 000000000000..053bd23a6dc1 --- /dev/null +++ b/docs/sources/source/asr/8kHz_models.rst @@ -0,0 +1,39 @@ +8kHz Models +=========== + +For applications based on telephony speech, using models trained on narrowband audio data sampled at 8 kHz may perform better than using models built with +audio at a higher frequency (Note that to use models with audio at a different sample rate from your data, you would need to resample your data to match the sampling rate in the +config file of the model). One approach to create large datasets for training a model suitable for your application would be to convert all audio data +to the formats prevalent in your application. Here we detail one such approach that we took to train a model based on 8 kHz data. + +To train a model suitable for recognizing telephony speech we converted some of the datasets to G.711 :cite:`8kHz-mod-itu1988g711`. 
G.711 is a popular speech codec used in VoIP products and encodes speech +at 64 kbps using PCM u-law companding. We converted audio from LibriSpeech, Mozilla Common Voice and WSJ datasets to G.711 format and combined them with the Fisher and Switchboard datasets to +train a :ref:`Quartznet15x5 <Quartznet_model>` model with about 4000 hours of data. To convert your audio to G.711 format you can use the script `convert_wav_to_g711wav.py` found in the `scripts` sub-directory of the nemo base directory. + +Among the experiments that we ran, we got the best accuracy for a model that used our 16 kHz Quartznet15x5 model's weights as pre-trained weights. We then +trained the model for 250 epochs with the five datasets mentioned above. Here are some results for our best model so far (note that all the test sets +were converted to G.711 format for the results below): + +====================== ===================== +Test set WER (%) +====================== ===================== +LibriSpeech dev-clean 4.35 +LibriSpeech dev-other 11.89 +LibriSpeech test-clean 4.45 +LibriSpeech test-other 12.02 +Switchboard test 10.74 +Switchboard dev 10.59 +====================== ===================== + +The model was first pretrained with 8 kHz LibriSpeech data for 134 epochs and then trained for another 250 epochs using G.711 audio from all five datasets listed above. For best accuracy +in your application, you may choose to :ref:`fine-tune ` this model using data collected from your application. + +.. + The pre-trained model is available for download `here `_. + +References +---------- +.. 
bibliography:: asr_all.bib + :style: plain + :labelprefix: 8kHz-mod + :keyprefix: 8kHz-mod- diff --git a/docs/sources/source/asr/asr_all.bib b/docs/sources/source/asr/asr_all.bib index 3cdd9c68f9d2..5eb0704b073f 100644 --- a/docs/sources/source/asr/asr_all.bib +++ b/docs/sources/source/asr/asr_all.bib @@ -60,8 +60,6 @@ @misc{ardila2019common primaryClass={cs.CL} } - - @article{graves2012, title={Sequence Transduction with Recurrent Neural Networks}, author={Graves, Alex}, @@ -927,8 +925,14 @@ @article{novograd2019 } @article{kriman2019quartznet, - title={Quartznet: Deep automatic speech recognition with 1d time-channel separable convolutions}, + title={Quartznet: {Deep} automatic speech recognition with 1d time-channel separable convolutions}, author={Kriman, Samuel and Beliaev, Stanislav and Ginsburg, Boris and Huang, Jocelyn and Kuchaiev, Oleksii and Lavrukhin, Vitaly and Leary, Ryan and Li, Jason and Zhang, Yang}, journal={arXiv preprint arXiv:1910.10261}, year={2019} -} \ No newline at end of file +} + +@misc{itu1988g711, + title={{ITU-T} {G.711} - {Pulse} code modulation ({PCM}) of voice frequencies}, + author={ITU-T Geneva Switzerland}, + year={1988}, +} diff --git a/docs/sources/source/asr/intro.rst b/docs/sources/source/asr/intro.rst index f8aac81833e9..cfa5d1e16919 100644 --- a/docs/sources/source/asr/intro.rst +++ b/docs/sources/source/asr/intro.rst @@ -10,6 +10,8 @@ Speech Recognition tutorial datasets models + 8kHz_models + diff --git a/docs/sources/source/asr/jasper.rst b/docs/sources/source/asr/jasper.rst index dc136cbb5f19..ec98d88ae0f1 100644 --- a/docs/sources/source/asr/jasper.rst +++ b/docs/sources/source/asr/jasper.rst @@ -23,3 +23,10 @@ Jasper10x5dr | Librispeech, `here `__ ============= ======================= ================================================================================= + +References +^^^^^^^^^^ +.. 
bibliography:: asr_all.bib + :style: plain + :labelprefix: ASR-MODELS + :keyprefix: asr-models- \ No newline at end of file diff --git a/docs/sources/source/asr/models.rst b/docs/sources/source/asr/models.rst index 57f529bc5298..66b5249af508 100644 --- a/docs/sources/source/asr/models.rst +++ b/docs/sources/source/asr/models.rst @@ -7,10 +7,3 @@ Models jasper quartznet -References -------------- - -.. bibliography:: asr_all.bib - :style: plain - :labelprefix: ASR-MODELS - :keyprefix: asr-models- \ No newline at end of file diff --git a/docs/sources/source/asr/quartznet.rst b/docs/sources/source/asr/quartznet.rst index 6dbadab71907..58e38ad1cc44 100644 --- a/docs/sources/source/asr/quartznet.rst +++ b/docs/sources/source/asr/quartznet.rst @@ -1,7 +1,9 @@ +.. _Quartznet_model: + QuartzNet --------- -QuartzNet is a version of Jasper :cite:`asr-models-li2019jasper` model with separable convolutions and larger filters. It can achieve performance +QuartzNet :cite:`qtz-models-kriman2019quartznet` is a version of Jasper :cite:`qtz-models-li2019jasper` model with separable convolutions and larger filters. It can achieve performance similar to Jasper but with an order of magnitude less parameters. Similarly to Jasper, QuartzNet family of models are denoted as QuartzNet_[BxR] where B is the number of blocks, and R - the number of convolutional sub-blocks within a block. Each sub-block contains a 1-D *separable* convolution, batch normalization, ReLU, and dropout: @@ -9,11 +11,9 @@ Similarly to Jasper, QuartzNet family of models are denoted as QuartzNet_[BxR] w :align: center :alt: quartznet model - .. note:: This checkpoint was trained on LibriSpeech :cite:`panayotov2015librispeech` and full "validated" part of En Mozilla Common Voice :cite:`ardila2019common` - `QuartzNet paper `_. -Pretrained models can be found, `here `_. 
+Pretrained models can be found at the following links: ============= ===================== ============================================================================== Network Dataset Download Link @@ -24,7 +24,10 @@ QuartzNet15x5 Aishell2 `here `_. These QuartzNet models were trained for 200 epochs using mixed precision on 2 GPUs with a batch size of 128 over 200 epochs. @@ -32,7 +30,7 @@ QuartzNet3x2 (93k params) Speech Commands V2 97.29% Test References ----------- +^^^^^^^^^^ .. bibliography:: speech_recognition_all.bib :style: plain diff --git a/docs/sources/source/speech_command/speech_recognition_all.bib b/docs/sources/source/speech_command/speech_recognition_all.bib index 277e56e7ec9b..a358cf2a70c9 100644 --- a/docs/sources/source/speech_command/speech_recognition_all.bib +++ b/docs/sources/source/speech_command/speech_recognition_all.bib @@ -40,4 +40,11 @@ @article{park2019 year = "2019", eid = {arXiv:1904.08779}, eprint = {1904.08779}, +} + +@article{li2019jasper, + title={Jasper: An End-to-End Convolutional Neural Acoustic Model}, + author={Li, Jason and Lavrukhin, Vitaly and Ginsburg, Boris and Leary, Ryan and Kuchaiev, Oleksii and Cohen, Jonathan M and Nguyen, Huyen and Gadde, Ravi Teja}, + journal={arXiv preprint arXiv:1904.03288}, + year={2019} } \ No newline at end of file diff --git a/docs/sources/source/speech_command/tutorial.rst b/docs/sources/source/speech_command/tutorial.rst index e5c59970b5ed..bd0aa38814bb 100644 --- a/docs/sources/source/speech_command/tutorial.rst +++ b/docs/sources/source/speech_command/tutorial.rst @@ -111,7 +111,7 @@ The script below does both training and evaluation (on V1 dataset) on single GPU process_classification_evaluation_epoch, ) - logging = nemo.logging + from nemo.utils import logging # Lets define some hyper parameters lr = 0.05 @@ -447,7 +447,7 @@ but they can similarly be used for v2 dataset. 
import nemo import nemo.collections.asr as nemo_asr - logging = nemo.logging + from nemo.utils import logging # We add some data_dir = '' diff --git a/docs/sources/source/tutorials/intro.rst b/docs/sources/source/tutorials/intro.rst index 3d4177e4a153..be51895dffc0 100644 --- a/docs/sources/source/tutorials/intro.rst +++ b/docs/sources/source/tutorials/intro.rst @@ -12,3 +12,4 @@ Getting started weightsharing callbacks complex_training + neural_graphs diff --git a/docs/sources/source/tutorials/neural_graphs.rst b/docs/sources/source/tutorials/neural_graphs.rst new file mode 100644 index 000000000000..e8f363457237 --- /dev/null +++ b/docs/sources/source/tutorials/neural_graphs.rst @@ -0,0 +1,54 @@ +Neural Graphs +============= + +The Neural Graph is a high-level abstract concept empowering the user to build graphs consisting of many +interconnected Neural Modules. +Once the user defines a graph, its topology is “frozen”, i.e. connections between modules cannot change. +If a user wants to change the topology - he/she can build another graph, potentially spanned over the same modules. +At the same time, he/she can reuse and nest one graph into another. + + +.. figure:: neural_graphs_general.png + +The import/export/save/restore options combined with the lightweight API make Neural Graphs +a perfect tool for rapid prototyping and experimentation. + +There are two Jupyter Notebook tutorials focusing on different aspects of the Neural Graphs functionality. + +Tutorial I: The basic functionality +----------------------------------- + +In this first part of the Neural Graphs (NGs) tutorial we will focus on a simple example: +training a TaylorNet module to approximate a sine wave function. +We will build a simple "model graph" and show how we can nest it into other graphs. + + +.. 
figure:: neural_graphs_nesting.png + +This part covers the following: + * how to create a Neural Graph object + * how to activate/deactivate graph context (in various ways) + * how to bind NG inputs and outputs (in various ways) + * how to nest one graph (representing our "trainable model") into training and validation graphs + + +Tutorial II: The advanced functionality +--------------------------------------- + +In this second part of the Neural Graphs (NGs) tutorial we will focus on a more complex example: +training of an End-to-End Convolutional Neural Acoustic Model called JASPER. +We will build a "model graph" and show how we can nest it into other graphs, how we can freeze/unfreeze modules, +use graph configuration and save/load graph checkpoints. + +This part covers the following: + * how to nest one graph into another + * how to serialize and deserialize a graph + * how to export and import serialized graph configuration to/from YAML files + * how to save and load graph checkpoints (containing weights of the Trainable NMs) + * how to freeze/unfreeze modules in a graph + +Additionally, we will show how to use `AppState` to list all the modules and graphs we have created in the scope of +our application. + +.. note:: + Both tutorial notebooks can be found in the `nemo/examples/neural_graphs` folder. 
diff --git a/docs/sources/source/tutorials/neural_graphs_general.png b/docs/sources/source/tutorials/neural_graphs_general.png new file mode 100644 index 000000000000..996e3db26e3d Binary files /dev/null and b/docs/sources/source/tutorials/neural_graphs_general.png differ diff --git a/docs/sources/source/tutorials/neural_graphs_nesting.png b/docs/sources/source/tutorials/neural_graphs_nesting.png new file mode 100644 index 000000000000..c411587714b8 Binary files /dev/null and b/docs/sources/source/tutorials/neural_graphs_nesting.png differ diff --git a/examples/applications/asr_service/app/__init__.py b/examples/applications/asr_service/app/__init__.py index a31e50d7ef94..f5da84fa3f61 100644 --- a/examples/applications/asr_service/app/__init__.py +++ b/examples/applications/asr_service/app/__init__.py @@ -7,8 +7,7 @@ import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.utils import logging app = Flask(__name__) # make sure WORK_DIR exists before calling your service diff --git a/examples/applications/asr_service/app/routes.py b/examples/applications/asr_service/app/routes.py index ccc7dbc20cce..7bd636a9b39f 100644 --- a/examples/applications/asr_service/app/routes.py +++ b/examples/applications/asr_service/app/routes.py @@ -17,10 +17,8 @@ from flask import request from werkzeug.utils import secure_filename -import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.utils import logging try: from app import beam_search_with_lm diff --git a/examples/asr/configs/quartznet15x5_8kHz.yaml b/examples/asr/configs/quartznet15x5_8kHz.yaml new file mode 100644 index 000000000000..3bbe1019e460 --- /dev/null +++ b/examples/asr/configs/quartznet15x5_8kHz.yaml @@ -0,0 +1,198 @@ +model: "QuartzNet" +sample_rate: 8000 + +AudioToTextDataLayer: + max_duration: 16.7 + trim_silence: true + + train: + shuffle: true + + eval: + shuffle: false + max_duration: null + +AudioToMelSpectrogramPreprocessor: + window_size: 0.02 + 
window_stride: 0.01 + window: "hann" + normalize: "per_feature" + n_fft: 512 + features: 64 + dither: 0.00001 + pad_to: 16 + stft_conv: true + +SpectrogramAugmentation: + rect_masks: 5 + rect_time: 120 + rect_freq: 50 + +JasperEncoder: + activation: "relu" + conv_mask: true + + jasper: + - filters: 256 + repeat: 1 + kernel: [33] + stride: [2] + dilation: [1] + dropout: 0.0 + residual: false + separable: true + + - filters: 256 + repeat: 5 + kernel: [33] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 256 + repeat: 5 + kernel: [33] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 256 + repeat: 5 + kernel: [33] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 256 + repeat: 5 + kernel: [39] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 256 + repeat: 5 + kernel: [39] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 256 + repeat: 5 + kernel: [39] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [51] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [51] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [51] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [63] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [63] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [63] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [75] + stride: [1] + dilation: [1] 
+ dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [75] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 5 + kernel: [75] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: true + separable: true + + - filters: 512 + repeat: 1 + kernel: [87] + stride: [1] + dilation: [2] + dropout: 0.0 + residual: false + separable: true + + - filters: 1024 + repeat: 1 + kernel: [1] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: false + +labels: [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", + "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] diff --git a/examples/asr/contextnet.py b/examples/asr/contextnet.py index 6e6845142d8f..2857bb7f0b44 100644 --- a/examples/asr/contextnet.py +++ b/examples/asr/contextnet.py @@ -23,10 +23,9 @@ import nemo.collections.asr as nemo_asr import nemo.utils.argparse as nm_argparse from nemo.collections.asr.helpers import monitor_asr_train_progress, process_evaluation_batch, process_evaluation_epoch +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/asr/jasper.py b/examples/asr/jasper.py index 4ac0b05c8f7e..10b4d5d47f5e 100644 --- a/examples/asr/jasper.py +++ b/examples/asr/jasper.py @@ -11,10 +11,9 @@ import nemo.collections.asr as nemo_asr import nemo.utils.argparse as nm_argparse from nemo.collections.asr.helpers import monitor_asr_train_progress, process_evaluation_batch, process_evaluation_epoch +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/asr/jasper_aishell.py b/examples/asr/jasper_aishell.py index baed99786114..0ee584507909 100644 --- a/examples/asr/jasper_aishell.py +++ b/examples/asr/jasper_aishell.py @@ -10,10 +10,9 @@ 
import nemo.collections.asr as nemo_asr import nemo.utils.argparse as nm_argparse from nemo.collections.asr.helpers import monitor_asr_train_progress, process_evaluation_batch, process_evaluation_epoch +from nemo.utils import logging from nemo.utils.lr_policies import SquareAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/asr/jasper_aishell_infer.py b/examples/asr/jasper_aishell_infer.py index 048e6ee160be..1e44b8527e5f 100644 --- a/examples/asr/jasper_aishell_infer.py +++ b/examples/asr/jasper_aishell_infer.py @@ -9,8 +9,7 @@ import nemo import nemo.collections.asr as nemo_asr from nemo.collections.asr.helpers import post_process_predictions, post_process_transcripts, word_error_rate - -logging = nemo.logging +from nemo.utils import logging def load_vocab(vocab_file): diff --git a/examples/asr/jasper_an4.py b/examples/asr/jasper_an4.py index 9ac79f3d1935..6d3f6c82b24f 100644 --- a/examples/asr/jasper_an4.py +++ b/examples/asr/jasper_an4.py @@ -17,10 +17,9 @@ process_evaluation_epoch, word_error_rate, ) +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def create_dags(model_config_file, vocab, args, nf): diff --git a/examples/asr/jasper_eval.py b/examples/asr/jasper_eval.py index 8b37cd974e05..5ef4d4c51149 100644 --- a/examples/asr/jasper_eval.py +++ b/examples/asr/jasper_eval.py @@ -23,8 +23,7 @@ import nemo import nemo.collections.asr as nemo_asr from nemo.collections.asr.helpers import post_process_predictions, post_process_transcripts, word_error_rate - -logging = nemo.logging +from nemo.utils import logging def main(): diff --git a/examples/asr/notebooks/3_Speech_Commands_using_NeMo.ipynb b/examples/asr/notebooks/3_Speech_Commands_using_NeMo.ipynb index b7c2fc8416c1..36a9834ee800 100644 --- a/examples/asr/notebooks/3_Speech_Commands_using_NeMo.ipynb +++ b/examples/asr/notebooks/3_Speech_Commands_using_NeMo.ipynb @@ -254,7 +254,7 @@ ")\n", 
"from nemo.collections.asr.metrics import classification_accuracy\n", "\n", - "logging = nemo.logging" + "from nemo.utils import logging" ] }, { diff --git a/examples/asr/notebooks/4_Online_Data_Augmentation.ipynb b/examples/asr/notebooks/4_Online_Data_Augmentation.ipynb index ddffabfb270c..edbfa3e271b8 100644 --- a/examples/asr/notebooks/4_Online_Data_Augmentation.ipynb +++ b/examples/asr/notebooks/4_Online_Data_Augmentation.ipynb @@ -836,7 +836,7 @@ ")\n", "from nemo.collections.asr.metrics import classification_accuracy\n", "\n", - "logging = nemo.logging" + "from nemo.utils import logging" ] }, { diff --git a/examples/asr/quartznet.py b/examples/asr/quartznet.py index aee030010ae6..9dbea554c78d 100644 --- a/examples/asr/quartznet.py +++ b/examples/asr/quartznet.py @@ -10,10 +10,9 @@ import nemo.collections.asr as nemo_asr import nemo.utils.argparse as nm_argparse from nemo.collections.asr.helpers import monitor_asr_train_progress, process_evaluation_batch, process_evaluation_epoch +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/asr/quartznet_speech_commands.py b/examples/asr/quartznet_speech_commands.py index 13cab6b9951d..7bcb9058974a 100644 --- a/examples/asr/quartznet_speech_commands.py +++ b/examples/asr/quartznet_speech_commands.py @@ -17,10 +17,9 @@ process_classification_evaluation_batch, process_classification_evaluation_epoch, ) +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing, PolynomialDecayAnnealing, PolynomialHoldDecayAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/image/gan.py b/examples/image/gan.py index 28cac4cba43c..42ed3cb0ac90 100644 --- a/examples/image/gan.py +++ b/examples/image/gan.py @@ -9,9 +9,7 @@ import nemo import nemo.collections.simple_gan as nemo_simple_gan from nemo.backends.pytorch.torchvision.helpers 
import compute_accuracy, eval_epochs_done_callback, eval_iter_callback - -logging = nemo.logging - +from nemo.utils import logging parser = argparse.ArgumentParser(description='MNIST') parser.add_argument("--local_rank", default=None, type=int) diff --git a/examples/image/transfer_learning.py b/examples/image/transfer_learning.py index bb3d54fe837c..206104d2404f 100644 --- a/examples/image/transfer_learning.py +++ b/examples/image/transfer_learning.py @@ -8,8 +8,7 @@ import nemo from nemo.backends.pytorch.torchvision.helpers import compute_accuracy, eval_epochs_done_callback, eval_iter_callback - -logging = nemo.logging +from nemo.utils import logging sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) diff --git a/examples/neural_graphs/img/neural_graphs_general.png b/examples/neural_graphs/img/neural_graphs_general.png new file mode 100644 index 000000000000..996e3db26e3d Binary files /dev/null and b/examples/neural_graphs/img/neural_graphs_general.png differ diff --git a/examples/neural_graphs/img/neural_graphs_nesting.png b/examples/neural_graphs/img/neural_graphs_nesting.png new file mode 100644 index 000000000000..c411587714b8 Binary files /dev/null and b/examples/neural_graphs/img/neural_graphs_nesting.png differ diff --git a/examples/neural_graphs/neural_graph_advanced.ipynb b/examples/neural_graphs/neural_graph_advanced.ipynb new file mode 100644 index 000000000000..fd8a0b955dc9 --- /dev/null +++ b/examples/neural_graphs/neural_graph_advanced.ipynb @@ -0,0 +1,379 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Copyright (c) 2020 NVIDIA. 
All Rights Reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "# =============================================================================\n", + "\n", + "from functools import partial\n", + "from os.path import expanduser, join, abspath, dirname, exists\n", + "import tarfile\n", + "\n", + "from ruamel.yaml import YAML\n", + "\n", + "import nemo\n", + "import nemo.collections.asr as nemo_asr\n", + "from nemo.collections.asr.helpers import monitor_asr_train_progress\n", + "from nemo.core import NeuralGraph, OperationMode, DeviceType, SimpleLossLoggerCallback\n", + "from nemo.utils import logging\n", + "from nemo.utils.app_state import AppState\n", + "\n", + "# Create Neural(Module)Factory, use CPU.\n", + "nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tutorial II: The advanced functionality\n", + "\n", + "In this first part of the Neural Graphs (NGs) tutorial we will focus on a more complex example: training of an End-to-End Convolutional Neural Acoustic Model called JASPER. 
We will build a \"model graph\" and show how we can nest it into other graphs, how we can freeze/unfreeze modules, use graph configuration and save/load graph checkpoints.\n", + "\n", + "#### This part covers the following:\n", + " * how to nest one graph into another\n", + " * how to serialize and deserialize a graph\n", + " * how to export and import serialized graph configuration to/from YAML files\n", + " * how to save and load graph checkpoints (containing weights of the Trainable NMs)\n", + " * how to freeze/unfreeze modules in a graph\n", + " \n", + "Additionally, we will show how to use `AppState` to list all the modules and graphs we have created in the scope of our application.\n", + "In order to learn more about graph nesting and input/output binding please refer to the first part of the tutorial.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the samples for training JASPER - we will use the data available in NeMo tests.\n", + "data_folder = abspath(\"../../tests/data/\")\n", + "logging.info(\"Looking up for test ASR data\")\n", + "if not exists(join(data_folder, \"asr\")):\n", + " logging.info(\"Extracting ASR data to: {0}\".format(join(data_folder, \"asr\")))\n", + " tar = tarfile.open(join(data_folder, \"asr.tar.gz\"), \"r:gz\")\n", + " tar.extractall(path=data_folder)\n", + " tar.close()\n", + "else:\n", + " logging.info(\"ASR data found in: {0}\".format(join(data_folder, \"asr\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set paths to model configuration, manifest and sample files.\n", + "model_config_file = abspath(\"../asr/configs/jasper_an4.yaml\")\n", + "manifest_path = join(data_folder, 'asr/tarred_an4/tarred_audio_manifest.json')\n", + "tarpath = join(data_folder, 'asr/tarred_an4/audio_1.tar')\n", + "\n", + "# Open the model config file and get vocabulary.\n", + "yaml = 
YAML(typ=\"safe\")\n", + "with open(expanduser(model_config_file)) as f:\n", + " config = yaml.load(f)\n", + " \n", + "# Get labels (vocabulary).\n", + "vocab = config['labels']\n", + "vocab_len = len(vocab)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate DataLayer that can load the tarred samples.\n", + "data_layer = nemo_asr.TarredAudioToTextDataLayer(\n", + " audio_tar_filepaths=tarpath, manifest_filepath=manifest_path, labels=vocab, batch_size=16)\n", + "logging.info(\"Loaded {} samples that we will use for training\".format(len(data_layer)))\n", + "\n", + "# Create rest of the modules using the Neural Module deserialization feature.\n", + "data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.deserialize(config[\"AudioToMelSpectrogramPreprocessor\"])\n", + "\n", + "jasper_encoder = nemo_asr.JasperEncoder.deserialize(config[\"JasperEncoder\"])\n", + "jasper_decoder = nemo_asr.JasperDecoderForCTC.deserialize(\n", + " config[\"JasperDecoderForCTC\"], overwrite_params={\"num_classes\": vocab_len}\n", + ")\n", + "ctc_loss = nemo_asr.CTCLossNM(num_classes=vocab_len)\n", + "greedy_decoder = nemo_asr.GreedyCTCDecoder()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the Jasper \"model\" graph.\n", + "with NeuralGraph(operation_mode=OperationMode.both, name=\"jasper_model\") as jasper_model:\n", + " # Copy one input port definitions - using \"user\" port names.\n", + " jasper_model.inputs[\"input\"] = data_preprocessor.input_ports[\"input_signal\"]\n", + " # Bind selected inputs - bind other using the default port name.\n", + " i_processed_signal, i_processed_signal_len = data_preprocessor(input_signal=jasper_model.inputs[\"input\"], length=jasper_model)\n", + " i_encoded, i_encoded_len = jasper_encoder(audio_signal=i_processed_signal, length=i_processed_signal_len)\n", + " i_log_probs = 
jasper_decoder(encoder_output=i_encoded)\n", + " # Bind selected outputs - using \"user\" port names.\n", + " jasper_model.outputs[\"log_probs\"] = i_log_probs\n", + " jasper_model.outputs[\"encoded_len\"] = i_encoded_len\n", + "\n", + "# Print the summary.\n", + "logging.info(jasper_model.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Serialize the whole graph.\n", + "serialized_jasper = jasper_model.serialize()\n", + "logging.info(\"Serialized JASPER model:\\n {}\".format(serialized_jasper))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You can also serialize/deserialize a single NeuralModule, e.g. a decoder.\n", + "logging.info(\"Serialized JASPER Decoder:\\n {}\".format(jasper_decoder.serialize()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can also export the serialized configuration to a file.\n", + "jasper_model.export_to_config(\"my_jasper.yml\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display the lists of graph and modules.\n", + "logging.info(AppState().graphs.summary())\n", + "logging.info(AppState().modules.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Deserialize graph - create a copy of the JASPER \"model\".\n", + "# Please note that the modules exist, so we must enable the graph to \"reuse\" them.\n", + "# (Commenting out reuse_existing_modules will raise a KeyError.)\n", + "jasper_copy = NeuralGraph.deserialize(serialized_jasper, reuse_existing_modules=True)\n", + "serialized_jasper_copy = jasper_copy.serialize()\n", + "assert serialized_jasper == serialized_jasper_copy # THE SAME! Please note name of the graph is not exported." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Alternatively, import a copy of the JASPER \"model\" from config.\n", + "jasper_copy = NeuralGraph.import_from_config(\"my_jasper.yml\", reuse_existing_modules=True, name=\"jasper_copy\")\n", + "\n", + "# Print the summary.\n", + "logging.info(jasper_copy.summary())\n", + "\n", + "# Display list of graph and modules\n", + "logging.info(AppState().graphs.summary())\n", + "logging.info(AppState().modules.summary())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that there are two graphs in the \"Graph Registry\", yet the list of modules hasn't changed. This means that both graphs are spanned on the same list of modules." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the \"training\" graph.\n", + "with NeuralGraph(operation_mode=OperationMode.training) as training_graph:\n", + " # Create the \"implicit\" training graph.\n", + " o_audio_signal, o_audio_signal_len, o_transcript, o_transcript_len = data_layer()\n", + " # Use Jasper module as any other neural module.\n", + " o_log_probs, o_encoded_len = jasper_copy(input=o_audio_signal, length=o_audio_signal_len)\n", + " o_predictions = greedy_decoder(log_probs=o_log_probs)\n", + " o_loss = ctc_loss(\n", + " log_probs=o_log_probs, targets=o_transcript, input_length=o_encoded_len, target_length=o_transcript_len\n", + " )\n", + " # Set the graph output.\n", + " training_graph.outputs[\"o_loss\"] = o_loss\n", + "\n", + "# Print the summary.\n", + "logging.info(training_graph.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a simple loss callback.\n", + "loss_callback = nemo.core.SimpleLossLoggerCallback(\n", + " tensors=[training_graph.output_tensors[\"o_loss\"]],\n", + " print_func=lambda x: logging.info(f'Train 
Loss: {str(x[0].item())}'), step_freq=1\n", + ")\n", + "# Train the graph.\n", + "nf.train(\n", + " training_graph=training_graph,\n", + " optimizer=\"novograd\",\n", + " callbacks=[loss_callback],\n", + " optimization_params={\"max_steps\": 5, \"lr\": 0.01},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Please note that the loss is going down. Still, we use only 65 samples, so we cannot really expect the model to be useful;)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Finally, I can save the graph checkpoint!\n", + "# Note that optionally you can indicate the names of the modules to be saved.\n", + "jasper_copy.save_to(\"my_jasper.chkpt\")#, module_names=[\"jasperencoder0\"])\n", + "# Please note only \"trainable\" modules will be saved." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can also save the whole training graph - which in this case will result in the same checkpoint...\n", + "training_graph.export_to_config(\"my_whole_graph.yml\")\n", + "training_graph.save_to(\"my_whole_graph.chkpt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Finally, I can load everything and continue training.\n", + "new_training_graph = NeuralGraph.import_from_config(\"my_whole_graph.yml\", reuse_existing_modules=True)\n", + "\n", + "# Let's restore only the encoder\n", + "new_training_graph.restore_from(\"my_whole_graph.chkpt\", module_names=[\"jasperencoder0\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# So let us freeze the whole graph...\n", + "training_graph.freeze() #we can also freeze a subset, using \"module_names=[]\"\"\n", + "# ... 
and finetune only the decoder.\n", + "training_graph.unfreeze(module_names=[\"jasperdecoderforctc0\"])\n", + "\n", + "# Ok, let us see what the graph looks like now.\n", + "logging.info(training_graph.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Create a new simple callback using graph outputs \"o_loss\".\n", + "loss_callback = nemo.core.SimpleLossLoggerCallback(\n", + " tensors=[new_training_graph.output_tensors[\"o_loss\"]],\n", + " print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), step_freq=1\n", + ")\n", + "\n", + "# And continue training...\n", + "nf.reset_trainer()\n", + "nf.train(\n", + " training_graph=new_training_graph,\n", + " optimizer=\"novograd\",\n", + " callbacks=[loss_callback],\n", + " optimization_params={\"max_steps\": 5, \"lr\": 0.01},\n", + ")\n", + "# Please note that this will throw an error if you will freeze all the trainable modules!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo-env", + "language": "python", + "name": "nemo-env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/neural_graphs/neural_graph_basic.ipynb b/examples/neural_graphs/neural_graph_basic.ipynb new file mode 100644 index 000000000000..8c90654c7723 --- /dev/null +++ b/examples/neural_graphs/neural_graph_basic.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# 
Copyright (c) 2020 NVIDIA. All Rights Reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "# =============================================================================\n", + "\n", + "import torch\n", + "\n", + "from nemo.backends.pytorch.tutorials import MSELoss, RealFunctionDataLayer, TaylorNet\n", + "from nemo.core import (\n", + " DeviceType,\n", + " EvaluatorCallback,\n", + " NeuralGraph,\n", + " NeuralModuleFactory,\n", + " OperationMode,\n", + " SimpleLossLoggerCallback,\n", + ")\n", + "from nemo.utils import logging\n", + "from nemo.utils.app_state import AppState\n", + "\n", + "# Create Neural(Module)Factory, use CPU.\n", + "nf = NeuralModuleFactory(placement=DeviceType.CPU)" + ] + }, + { + "attachments": { + "neural_graphs_general.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Introduction to Neural Graphs (NGs) \n", + "\n", + "The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. 
Once defined, graphs can be trained, exported/saved and imported/restored in other application(s).\n", + "\n", + "![neural_graphs_general.png](attachment:neural_graphs_general.png)\n", + "\n", + "The import/export/save/restore options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation.\n", + "\n", + "\n" + ] + }, + { + "attachments": { + "neural_graphs_nesting.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tutorial I: The basic functionality\n", + "\n", + "In this first part of the Neural Graphs (NGs) tutorial we will focus on a simple example: training TaylorNet module to approximate a sine wave function. We will build a simple \"model graph\" and show how we can nest it into other graphs.\n", + "\n", + "![neural_graphs_nesting.png](attachment:neural_graphs_nesting.png)\n", + "\n", + "#### This part covers the following:\n", + " * how to create a Neural Graph object\n", + " * how to activate/deactivate graph context (in various ways)\n", + " * how to bind NG inputs and outputs (in various ways)\n", + " * how to nest one graph (representing our \"trainable model\") into training and validation graphs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the necessary neural modules.\n", + "dl_training = RealFunctionDataLayer(n=10000, batch_size=32)\n", + "dl_validation = RealFunctionDataLayer(n=10000, batch_size=32)\n", + "tn = TaylorNet(dim=4)\n", + "loss = MSELoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build the \"model\" graph.\n", + "simple_model = NeuralGraph(operation_mode=OperationMode.both)\n", + "\n", + "# Activate the \"graph context\".\n", + "simple_model.activate() \n", + "\n", + "# Create bound input port by copying the definition from input port \"x\" of TaylorNet.\n", + 
"simple_model.inputs[\"input\"] = tn.input_ports[\"x\"]\n", + "# Bind the \"x\" input, so that \"x\" of graph will \"lead\" to input port \"x\" of TaylorNet.\n", + "_ = tn(x=simple_model.inputs[\"input\"])\n", + "# Add the module for the second time, also binding the port.\n", + "_ = tn(x=simple_model.inputs[\"input\"])\n", + "# All outputs will be bound by default.\n", + "\n", + "# Deactivate the graph context.\n", + "simple_model.deactivate()\n", + "\n", + "# Let us see what the graph looks like.\n", + "logging.info(simple_model.summary())\n", + "# Please note that the graph is NOT COMPLETE, as it:\n", + "# * doesn't contain a DataLayer, and\n", + "# * has bound input ports that need to be connected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And how about a \"model graph\" with an arbitrary graph with a loop?\n", + "\n", + "# Create a new graph instance.\n", + "simple_model = NeuralGraph(operation_mode=OperationMode.both)\n", + "\n", + "# Activate the new \"graph context\" using the \"with\" statement.\n", + "with simple_model:\n", + " # As this time we decided to stay with the original port name \"x\", we can use the \"default input binding\".\n", + " embeddings = tn(x=simple_model)\n", + " # Now create a loop and pass them back as inputs to TaylorNet instance.\n", + " prediction = tn(x=embeddings)\n", + " # Moreover, we are interested only in the second output, so we must \"manually bind\" it.\n", + " simple_model.outputs[\"prediction\"] = prediction\n", + "# Ending \"with\" closes the \"graph context\".\n", + " \n", + "# Ok, let us see what the graph looks like now.\n", + "logging.info(simple_model.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Anyway, for the rest of the example let's create a simple \"model graph\" wrapping just one module.\n", + "\n", + "# Create a new graph and open it's context in a 
single line.\n", + "with NeuralGraph(operation_mode=OperationMode.both) as simple_model:\n", + " # As this time we decided to stay with the original port name \"x\", we can use the \"default input binding\".\n", + " prediction = tn(x=simple_model)\n", + " # Moreover, we are interested only in the second output, so we must \"manually bind\" it.\n", + " simple_model.outputs[\"prediction\"] = prediction\n", + " \n", + "# Ok, let us see what the graph looks like now.\n", + "logging.info(simple_model.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let us now compose a COMPLETE training graph.\n", + "# In particular, we will \"nest\" our \"model graph\" into this new training graph.\n", + "with NeuralGraph(operation_mode=OperationMode.training) as training_graph:\n", + " # Take outputs from the training DL.\n", + " x, t = dl_training()\n", + " # Pass them to \"inner\" graph (nest!).\n", + " p = simple_model(x=x)\n", + " # Pass both of them to loss.\n", + " lss = loss(predictions=p, target=t)\n", + " # We will use \"loss\" as output during training, so we must \"manually bind\" it.\n", + " training_graph.outputs[\"loss\"] = lss\n", + " \n", + "# Ok, let us see what the graph looks like now.\n", + "logging.info(training_graph.summary())\n", + "# In the following please note that:\n", + "# * during nesting the graph was flattened - 3 modules, 4 steps\n", + "# * the input passed to \"simple_model\" bound input port were passed to the actual input of TaylorNet\n", + "# * the graph is COMPLETE, i.e. there are no inputs that are bound and there is a single datalayer\n", + "# So in short: we can execute it!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let us compose a COMPLETE validation graph.\n", + "with NeuralGraph(operation_mode=OperationMode.evaluation) as validation_graph:\n", + " # Take outputs from the training DL.\n", + " x_valid, t_valid = dl_validation()\n", + " # Pass them to the trainable module.\n", + " p_valid = simple_model(x=x_valid)\n", + " loss_valid = loss(predictions=p_valid, target=t_valid)\n", + "\n", + "# Ok, let us see what the graph looks like now.\n", + "logging.info(validation_graph.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create training callback logging loss to console.\n", + "train_callback = SimpleLossLoggerCallback(\n", + " tensors=[lss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}')\n", + ")\n", + "\n", + "# Create evaluator callback logging/aggregating the validation loss to console.\n", + "def batch_loss_per_batch_callback(tensors, global_vars):\n", + " if \"batch_loss\" not in global_vars.keys():\n", + " global_vars[\"batch_loss\"] = []\n", + " for key, value in tensors.items():\n", + " if key.startswith(\"loss\"):\n", + " global_vars[\"batch_loss\"].append(torch.mean(torch.stack(value)))\n", + "\n", + "\n", + "def batch_loss_epoch_finished_callback(global_vars):\n", + " epoch_loss = torch.max(torch.tensor(global_vars[\"batch_loss\"]))\n", + " logging.info(\"Evaluation Loss: {0}\".format(epoch_loss))\n", + " return dict({\"Evaluation Loss\": epoch_loss})\n", + "\n", + "\n", + "eval_callback = EvaluatorCallback(\n", + " eval_tensors=[loss_valid],\n", + " user_iter_callback=batch_loss_per_batch_callback,\n", + " user_epochs_done_callback=batch_loss_epoch_finished_callback,\n", + " eval_step=100,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Invoke the \"train\" action.\n", + 
"nf.reset_trainer() # I do not understand why do I have to \"reset the trainer\" when calling train() function again :]\n", + "nf.train(\n", + " training_graph=training_graph,\n", + " callbacks=[train_callback, eval_callback],\n", + " optimization_params={\"num_epochs\": 3, \"lr\": 0.0003},\n", + " optimizer=\"sgd\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo-env", + "language": "python", + "name": "nemo-env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/nlp/dialogue_state_tracking/data/dialogue_augmentation_for_sgd_format.py b/examples/nlp/dialogue_state_tracking/data/dialogue_augmentation_for_sgd_format.py new file mode 100644 index 000000000000..bc6ccff0ec24 --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/data/dialogue_augmentation_for_sgd_format.py @@ -0,0 +1,521 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_ontology(dialogues, schemas):
    """
    Build the slot ontology used to pick replacement values.

    The result maps (service_name, slot_name) to a dict with
        "is_categorical"  -> value copied from the schema slot
        "possible_values" -> set seeded from the schema and extended with every value
                             seen in user states and system actions of `dialogues`

    Slots that are not declared in any schema (e.g. 'count') are ignored, both when they
    show up in system actions and when they show up in a user state.

    Args:
        dialogues: iterable of dialogue dicts ("turns" -> "frames").
        schemas: iterable of schema dicts ("service_name", "slots").
    Returns:
        defaultdict keyed by (service_name, slot_name).
    """
    ontology = defaultdict(dict)
    for schema in schemas:
        service_name = schema['service_name']
        for slot in schema['slots']:
            entry = ontology[(service_name, slot['name'])]
            entry["is_categorical"] = slot['is_categorical']
            entry["possible_values"] = set(slot['possible_values'])

    def add_observed(service_name, slot_name, values):
        # Membership test instead of indexing: indexing a defaultdict would silently
        # create an empty entry and then fail on the missing "possible_values" key.
        key = (service_name, slot_name)
        if key in ontology:
            ontology[key]["possible_values"].update(values)

    for dialogue in dialogues:
        for turn in dialogue["turns"]:
            for frame in turn["frames"]:
                service_name = frame["service"]
                if "state" in frame:
                    for slot_name, values in frame["state"]["slot_values"].items():
                        add_observed(service_name, slot_name, values)
                if "actions" in frame:
                    for action in frame["actions"]:
                        add_observed(service_name, action["slot"], action["values"])
    return ontology


def get_affected_future_frames(dialogue, from_turn_id, slot_name, slot_value, service):
    """
    Find every frame, from turn `from_turn_id` onwards, that mentions the given
    (service, slot_name, slot_value) combination.

    SYSTEM turns are matched against frame["actions"], all other turns against
    frame["state"]["slot_values"].

    Returns:
        List[(turn_id, frame_id, slot_name)]. A system frame is listed once per matching action.
    """
    assert isinstance(from_turn_id, int)
    assert isinstance(slot_name, str)
    assert isinstance(slot_value, str)
    assert isinstance(service, str)
    res = []
    for turn_id, turn in enumerate(dialogue["turns"][from_turn_id:], start=from_turn_id):
        for frame_id, frame in enumerate(turn["frames"]):
            if turn["speaker"] == "SYSTEM":
                if frame["service"] == service:
                    for action in frame["actions"]:
                        if action["slot"] == slot_name and slot_value in action["values"]:
                            res.append((turn_id, frame_id, slot_name))
            elif frame["service"] == service and slot_value in frame["state"]["slot_values"].get(slot_name, []):
                res.append((turn_id, frame_id, slot_name))
    return res
def augment_dialog_by_auxiliary_entries(dialogue):
    """
    Augment every frame of `dialogue` in place with two helper entries:

    slot_to_span (dict): slotname -> value -> [start_idx, end_idx] for the annotated slot spans,
        plus every action/state value that appears exactly once in the utterance.
    state_update (dict): slotname -> [(turn_id, frame_id, slot_name)], the following frames that
        are affected by a newly introduced slot value.
        New for system are all slots in "actions".
        New for user are all slots which did not appear in the previous user state or whose
        (list of) values has changed.

    NOTE(review): the nesting below was reconstructed from a whitespace-collapsed source - confirm
    against the original file. As written, state_update[slot] keeps the result for the last
    processed value of that slot only.
    """

    def record_unique_span(frame, utterance, slot_name, value):
        # Only values occurring exactly once in the utterance get a span, and an already
        # annotated span is never overwritten.
        if value not in utterance:
            return
        known = frame["slot_to_span"]
        if slot_name in known and value in known[slot_name]:
            return
        if len(utterance.split(value)) == 2:
            start_idx = utterance.index(value)
            known[slot_name][value] = [start_idx, start_idx + len(value)]

    prev_service_user = ""
    prev_state_slots_user = {}  # slot name -> list of values of the previous user turn
    for turn_id, turn in enumerate(dialogue["turns"]):
        for frame in turn["frames"]:
            slot_to_spans = defaultdict(dict)
            for slot in frame["slots"]:
                start_idx, end_idx = slot["start"], slot["exclusive_end"]
                slot_to_spans[slot["slot"]][turn["utterance"][start_idx:end_idx]] = [start_idx, end_idx]
            frame["slot_to_span"] = slot_to_spans

        if turn["speaker"] == "SYSTEM":
            for frame in turn["frames"]:
                new_slots = defaultdict(list)
                for action in frame["actions"]:
                    slot = action["slot"]
                    for v in action["values"]:
                        new_slots[slot] = get_affected_future_frames(
                            dialogue, turn_id + 1, slot_name=slot, slot_value=v, service=frame["service"]
                        )
                        record_unique_span(frame, turn["utterance"], slot, v)
                frame["state_update"] = new_slots
        else:
            for frame in turn["frames"]:
                new_slots = defaultdict(list)  # slot name -> List[frames] in future
                for k, vs in frame["state"]["slot_values"].items():
                    for v in vs:
                        record_unique_span(frame, turn["utterance"], k, v)
                        if k not in prev_state_slots_user or v not in prev_state_slots_user[k]:
                            new_slots[k] = get_affected_future_frames(
                                dialogue, turn_id + 1, slot_name=k, slot_value=v, service=frame["service"]
                            )
                frame["state_update"] = new_slots

            # Remember the user state to diff against in the next user turn; with several
            # frames, take the first one whose service differs from the previous one.
            if len(turn["frames"]) == 1:
                use_frame = turn["frames"][0]
            else:
                use_frame = [frame for frame in turn["frames"] if frame["service"] != prev_service_user][0]
            prev_service_user = use_frame["service"]
            prev_state_slots_user = use_frame["state"]["slot_values"]


def _slot_is_consistent(turn, frame, slot):
    """Return True if the span of `slot` lies inside the utterance and matches the frame annotations."""
    st_idx, end_idx, key = slot["start"], slot["exclusive_end"], slot["slot"]
    utterance = turn["utterance"]
    if not 0 <= st_idx < end_idx <= len(utterance):
        return False
    word = utterance[st_idx:end_idx]
    if turn["speaker"] == "SYSTEM":
        return any(action["slot"] == key and word in action["values"] for action in frame["actions"])
    slot_values = frame["state"]["slot_values"]
    return key not in slot_values or word in slot_values[key]


def validate(dialogue):
    """
    Check that `dialogue` is valid with respect to non-categorical slots:
        - span indices are within the utterance length
        - for system turns, the utterance substring (by span) is among the values of an action on that slot
        - for user turns, the utterance substring (by span) is among state->slot_values->key, if the key is present

    The checks are explicit (no `assert`), so they stay active under `python -O`.

    Raises:
        ValueError: "Turn <turn id>, frame <frame id>" for the first inconsistent or malformed slot.
    """
    for turn_id, turn in enumerate(dialogue["turns"]):
        for frame_id, frame in enumerate(turn["frames"]):
            for slot in frame["slots"]:
                try:
                    consistent = _slot_is_consistent(turn, frame, slot)
                except Exception as err:
                    # Malformed annotations (missing keys, wrong types) count as invalid too.
                    raise ValueError(f"Turn {turn_id}, frame {frame_id}") from err
                if not consistent:
                    raise ValueError(f"Turn {turn_id}, frame {frame_id}")


def process_dialogues(final_dialogues, dialogue_count, dialogues, replace_turn_prob, replace_word_prob, new_val_func):
    """
    Iterate through all dialogues, replace words according to `new_val_func` and append the
    results to `final_dialogues`.

    Args:
        final_dialogues: mapping d_id -> list, receives the processed dialogues.
        dialogue_count: mapping d_id -> int, running counter used to renumber dialogue ids.
        dialogues: dialogues to process; they are modified in place and must carry the
            "slot_to_span"/"state_update" entries if replacements touch slot values.
        replace_turn_prob: probability of modifying a given turn.
        replace_word_prob: probability of modifying a given word/span of a selected turn.
        new_val_func: callable(dialogue, turn_id, old_value, start_idx, end_idx) -> new value or falsy.
    """
    replace_success = 0
    replace_failed = 0
    for dialogue in tqdm(dialogues):
        d_id = int(dialogue["dialogue_id"].split("_")[0])
        dialogue["dialogue_id"] = f"{d_id}_{dialogue_count[d_id]:05d}"
        dialogue_count[d_id] += 1
        for turn_id in range(len(dialogue["turns"])):
            if random.random() < replace_turn_prob:
                # Look the turn up by index: after a successful replacement dialogue["turns"]
                # is a new list, and earlier replacements may have changed this utterance.
                spans = get_sentence_components(turn=dialogue["turns"][turn_id])
                # Right-to-left, so that earlier spans stay valid when lengths change.
                for span in reversed(spans):
                    if random.random() < replace_word_prob:
                        old_value = dialogue["turns"][turn_id]["utterance"][span[0] : span[1]]
                        new_value = new_val_func(dialogue, turn_id, old_value, span[0], span[1])
                        if new_value:
                            # Work on a copy and commit only if the result is still valid.
                            tmp_dialogue = copy.deepcopy(dialogue)
                            try:
                                replace(tmp_dialogue, turn_id, span[0], span[1], new_value)
                                validate(tmp_dialogue)
                                for k, v in tmp_dialogue.items():
                                    dialogue[k] = v
                                replace_success += 1
                            except Exception:
                                # Best effort: a failed replacement is counted and dropped.
                                replace_failed += 1

        # Strip the auxiliary entries before the dialogue is emitted.
        for turn in dialogue["turns"]:
            for frame in turn["frames"]:
                frame.pop("state_update", None)
                frame.pop("slot_to_span", None)
        final_dialogues[d_id].append(dialogue)
    print(f"Replacement success {replace_success}, failed {replace_failed}\n")
def update_spans(dialogue, turn_id, frame_id, start_idx, end_idx, old_value, new_value):
    """
    Shift the slot spans and slot_to_span entries of one frame after the text starting at
    `start_idx` changed from `old_value` to `new_value`.

    Every span boundary located after `start_idx` is moved by len(new_value) - len(old_value);
    `end_idx` is accepted for signature symmetry but not used.
    """
    frame = dialogue["turns"][turn_id]["frames"][frame_id]
    offset = len(new_value) - len(old_value)

    for slot in frame['slots']:
        if start_idx < slot['start']:
            slot['start'] += offset
        if start_idx < slot['exclusive_end']:
            slot['exclusive_end'] += offset

    for value_to_span in frame['slot_to_span'].values():
        for span in value_to_span.values():
            if start_idx < span[0]:
                span[0] += offset
            if start_idx < span[1]:
                span[1] += offset


def update_values(dialogue, turn_id, frame_id, key, old_value, new_value):
    """
    Replace `old_value` by `new_value` for slot `key` in one frame.
    Only values are updated: actions, state and the keys of slot_to_span.
    """
    frame = dialogue["turns"][turn_id]["frames"][frame_id]
    if "actions" in frame:
        for action in frame["actions"]:
            if key == action["slot"] and old_value in action["values"]:
                # Note: the new value is appended, it does not keep the old position.
                action["values"].remove(old_value)
                action["values"].append(new_value)
    if "state" in frame:
        values = frame["state"]["slot_values"].get(key, [])
        for v_id, v in enumerate(values):
            if v == old_value:
                values[v_id] = new_value

    value_to_span = frame["slot_to_span"].get(key)
    if value_to_span is not None and old_value in value_to_span:
        value_to_span[new_value] = value_to_span.pop(old_value)


def get_sentence_components(turn):
    """
    Return the list of (start, end) index pairs of the replaceable units of the utterance:
    maximal runs of characters that are alphanumeric or covered by the first occurrence of a
    slot value from the frame states/actions.
    """
    sentence = turn["utterance"]
    # One extra trailing False acts as a sentinel, so a run reaching the end is closed too.
    covered = np.zeros(len(sentence) + 1, dtype=bool)
    for frame in turn["frames"]:
        values = []
        if "state" in frame:
            for vs in frame["state"]["slot_values"].values():
                values.extend(vs)
        if "actions" in frame:
            for action in frame["actions"]:
                values.extend(action["values"])
        for v in values:
            start_idx = sentence.find(v)
            if start_idx >= 0:
                covered[start_idx : start_idx + len(v)] = True

    for i, char in enumerate(sentence):
        if char.isalnum():
            covered[i] = True

    res = []
    run_start = None
    for idx, flag in enumerate(covered):
        if flag:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            res.append((run_start, idx))
            run_start = None
    return res


def find_word_in_turn(dialogue, turn_id, value, start_idx, end_idx):
    """
    Find the non-categorical slots of a turn whose span is exactly [start_idx, end_idx).

    Returns:
        List[(turn_id, frame_id, key)]
    """
    assert isinstance(value, str)
    frames = dialogue["turns"][turn_id]["frames"]
    return [
        (turn_id, frame_id, slot["slot"])
        for frame_id, frame in enumerate(frames)
        for slot in frame["slots"]
        if start_idx == slot["start"] and end_idx == slot["exclusive_end"]
    ]


def get_new_value(dialogue, turn_id, value, start_idx, end_idx):
    """
    Pick a random replacement for the span if it belongs to a non-categorical slot.

    Candidates are the "possible_values" of the module-level `ontology` for every slot
    annotated on exactly this span.

    Returns:
        the new value, or None if the span is not a slot span or no candidates exist.
    """
    candidates = find_word_in_turn(dialogue, turn_id, value, start_idx, end_idx)
    possible_values = set()
    for _, frame_id, k in candidates:
        frame = dialogue["turns"][turn_id]["frames"][frame_id]
        service = frame["service"]
        if "possible_values" in ontology[(service, k)]:
            possible_values.update(ontology[(service, k)]["possible_values"])
    if not possible_values:
        return None
    # Sort first: the iteration order of a set of strings changes between interpreter runs
    # (hash randomisation), which would defeat the fixed random seed.
    return random.choice(sorted(possible_values))
def replace(dialogue, turn_id, start_idx, end_idx, new_value):
    """
    Replace the utterance of turn `turn_id` at start_idx:end_idx with `new_value`.

    If the old value is a non-categorical slot value of that turn, all frames listed in the
    slot's "state_update" are changed as well:
        - update_values on every affected frame
        - update_spans and the utterance text of every affected turn

    The dialogue is modified in place and must have been prepared by
    augment_dialog_by_auxiliary_entries ("state_update"/"slot_to_span" are read).
    An AssertionError is raised if a recorded span no longer contains the old value.

    NOTE(review): the nesting below was reconstructed from a whitespace-collapsed source - confirm
    against the original file.
    """
    assert isinstance(turn_id, int)
    assert isinstance(start_idx, int)
    assert isinstance(end_idx, int)
    turn = dialogue["turns"][turn_id]
    old_value = turn["utterance"][start_idx:end_idx]
    affected_values = find_word_in_turn(
        dialogue=dialogue, turn_id=turn_id, value=old_value, start_idx=start_idx, end_idx=end_idx
    )
    affected_spans = [(turn_id, start_idx, end_idx)]
    # Iterate over a copy: the list is extended with the future frames while looping.
    for _, frame_id, key in affected_values.copy():
        frame = dialogue["turns"][turn_id]["frames"][frame_id]
        new_affected_values = frame["state_update"][key]
        affected_values += new_affected_values
        for a_turn_id, a_frame_id, a_key in new_affected_values:
            assert key == a_key
            spans = (
                dialogue["turns"][a_turn_id]["frames"][a_frame_id]["slot_to_span"].get(a_key, {}).get(old_value, None)
            )
            if spans:
                affected_spans += [(a_turn_id, spans[0], spans[1])]

    for a_turn_id, a_frame_id, a_key in affected_values:
        update_values(dialogue, a_turn_id, a_frame_id, a_key, old_value, new_value)
    for a_turn_id, span_start, span_end in affected_spans:
        a_turn = dialogue["turns"][a_turn_id]
        assert old_value == a_turn["utterance"][span_start:span_end]
        # Spans of all frames of the turn shift, not only those of the affected frame.
        for a_frame_id in range(len(a_turn["frames"])):
            update_spans(dialogue, a_turn_id, a_frame_id, span_start, span_end, old_value, new_value)
        a_turn["utterance"] = a_turn["utterance"][:span_start] + new_value + a_turn["utterance"][span_end:]


def num2str(dialogue, turn_id, old_value, start_idx, end_idx):
    """
    Return "<number in words> <old_value>" if `old_value` is a number that is not annotated as a
    non-categorical slot span, otherwise None.

    `str.isdecimal` is used because it accepts exactly the digit strings `int()` can parse;
    `str.isnumeric` is also true for characters such as fractions or superscripts, on which
    `int()` raises ValueError.
    """
    res = find_word_in_turn(dialogue, turn_id, old_value, start_idx, end_idx)
    if not res and old_value.isdecimal():
        return p.number_to_words(int(old_value)) + " " + old_value
    return None
replace(dialogue, turn_id=turn_id, start_idx=start_idx, end_idx=end_idx, new_value=new_value) + for turn in dialogue["turns"]: + for frame in turn["frames"]: + if "state_update" in frame: + frame.pop("state_update") + + +def test(dialogues, dialogue_id, turn_id, old_value, new_value): + dialogue = copy.deepcopy(dialogues[dialogue_id]) + augment_dialog_by_auxiliary_entries(dialogue) + m = re.search(old_value, dialogue["turns"][turn_id]["utterance"]) + test_helper(dialogue, dialogue_id, turn_id, start_idx=m.start(), end_idx=m.end(), new_value=new_value) + pprint(dialogue) + validate(dialogue) + d_str_new = json.dumps(dialogue, sort_keys=True, indent=2) + d_str_old = json.dumps(dialogues[dialogue_id], sort_keys=True, indent=2) + print(d_str_new == d_str_old) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--concat_orig_dialogue", action="store_true", help="contenate original dialogue to the augmented one" + ) + parser.add_argument( + "--input_dir", + type=str, + default="", + help="data directory. contains one schema.json and multiple dialogue*.json files", + ) + parser.add_argument("--output_dir", type=str, help="output data directory", default=None) + parser.add_argument("--num2string", action="store_true", help="convert digits to string") + parser.add_argument("--repeat", type=int, default=5, help="number of augmentation sweeps over input data") + parser.add_argument("--replace_turn_prob", type=float, default=1.0, help="likelihood to modify an utterance turn") + parser.add_argument( + "--replace_word_prob", type=float, default=1.0, help="likelihood to modify a word in an utterance" + ) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = parse_args() + print(vars(args)) + random.seed(args.seed) + + if not os.path.exists(args.input_dir): + raise ValueError( + "SGD dataset not found. 
Dataset can be downloaded from https://github.com/google-research-datasets/dstc8-schema-guided-dialogue" + ) + + in_file_path = args.input_dir + schema_path = os.path.join(in_file_path, 'schema.json') + dialogue_files = [ + os.path.join(in_file_path, f) + for f in os.listdir(in_file_path) + if os.path.isfile(os.path.join(in_file_path, f)) + if "dialogue" in f + ] + dialogue_files.sort() + orig_dialog = [] + for d_file in dialogue_files: + orig_dialog.extend(json.load(open(d_file, 'r'))) + print(f"len(orig_dialog) = {len(orig_dialog)}") + orig_schema = json.load(open(schema_path, 'r')) + + dialogue_count = defaultdict(int) + final_dialogues = defaultdict(list) + + ontology = get_ontology(dialogues=orig_dialog, schemas=orig_schema) + + for dialogue_id, dialogue in tqdm(enumerate(orig_dialog)): + validate(dialogue) # for test purposes + augment_dialog_by_auxiliary_entries(dialogue) + validate(dialogue) # for test purposes + + if args.num2string: + if args.concat_orig_dialogue: + process_dialogues( + final_dialogues=final_dialogues, + dialogue_count=dialogue_count, + dialogues=orig_dialog, + replace_turn_prob=1.0, + replace_word_prob=1.0, + new_val_func=num2str, + ) + else: + process_dialogues( + final_dialogues=defaultdict(list), + dialogue_count=defaultdict(int), + dialogues=orig_dialog, + replace_turn_prob=1.0, + replace_word_prob=1.0, + new_val_func=num2str, + ) + + for _ in range(args.repeat): + dialogues = copy.deepcopy(orig_dialog) + process_dialogues( + final_dialogues=final_dialogues, + dialogue_count=dialogue_count, + dialogues=dialogues, + replace_turn_prob=args.replace_turn_prob, + replace_word_prob=args.replace_word_prob, + new_val_func=get_new_value, + ) + + if args.concat_orig_dialogue and not args.num2string: + for dialogue_id, dialogue in tqdm(enumerate(orig_dialog)): + d_id, d_count = dialogue["dialogue_id"].split("_") + d_id = int(d_id) + dialogue["dialogue_id"] = f"{d_id}_{dialogue_count[d_id]:05d}" + dialogue_count[d_id] += 1 + 
final_dialogues[d_id].append(dialogue) + + for dir_id, dialogues in final_dialogues.items(): + for dialogue in dialogues: + for turn in dialogue["turns"]: + for frame in turn["frames"]: + if 'state_update' in frame: + frame.pop("state_update") + if 'slot_to_span' in frame: + frame.pop("slot_to_span") + if args.output_dir is None: + output_dir = f"augmented_repeat{args.repeat}_replace_turn_prob{args.replace_turn_prob}_replace_word_prob{args.replace_word_prob}_concatorig{args.concat_orig_dialogue}_num2string{args.num2string}" + else: + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + for dir_id, dialogues in final_dialogues.items(): + with open(os.path.join(output_dir, f"dialogues_{dir_id:03d}.json"), 'w') as outfile: + json.dump(dialogues, outfile, indent=2) + + with open(os.path.join(output_dir, f"schema.json"), 'w') as outfile: + json.dump(orig_schema, outfile, indent=2) diff --git a/examples/nlp/dialogue_state_tracking/data/multiwoz/__init__.py b/examples/nlp/dialogue_state_tracking/data/multiwoz/__init__.py new file mode 100644 index 000000000000..cd24d1f06b22 --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/data/multiwoz/__init__.py @@ -0,0 +1,16 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= diff --git a/examples/nlp/dialogue_state_tracking/data/multiwoz/correct_categorical_state_values.tsv b/examples/nlp/dialogue_state_tracking/data/multiwoz/correct_categorical_state_values.tsv new file mode 100644 index 000000000000..0672290108ea --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/data/multiwoz/correct_categorical_state_values.tsv @@ -0,0 +1,18 @@ +alpha-milton alpha milton +any dontcare +bed and breakfast guesthouse +boating boat +cam cambridge +concert concerthall +concert hall concerthall +guest house guesthouse +guesthouses guesthouse +moderate|cheap cheap|moderate +museum kettles yard museum +mutiple sports multiple sports +nightclub night club +acorn guesthouse acorn guest house +swimmingpool swimming pool +sports multiple sports +pool swimming pool +theater theatre \ No newline at end of file diff --git a/examples/nlp/dialogue_state_tracking/data/multiwoz/create_data_from_multiwoz.py b/examples/nlp/dialogue_state_tracking/data/multiwoz/create_data_from_multiwoz.py new file mode 100644 index 000000000000..cdbc5c4fa989 --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/data/multiwoz/create_data_from_multiwoz.py @@ -0,0 +1,793 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# =============================================================================

"""Converts Multiwoz 2.1 dataset to the data format of SGD."""
import argparse
import collections
import copy
import json
import os
import re

import nemo.collections.nlp.data.datasets.sgd_dataset.schema as schema
from nemo import logging

# Command-line arguments.
# NOTE: the parser is built and evaluated at module import time, so importing this
# module without `--input_data_dir` on the command line makes argparse exit.
parser = argparse.ArgumentParser(description='conversion of multiwoz into sgd')

parser.add_argument('--input_data_dir', type=str, required=True, help='Path of the dataset to convert from.')
parser.add_argument(
    '--output_dir',
    type=str,
    help='Path to output directory. If not specified, generate the dialogues in the same directory as the script.',
)
parser.add_argument(
    '--annotate_copy_slots',
    action='store_true',
    help='Whether to annotate slots whose value is copied from a different slot in '
    'the previous state. If true, add a new key "copy_from" in the slot '
    'annotation dict. Its value is the slot that the value is copied from.',
)

parser.add_argument('--schema_file_name', default='schema.json', type=str, help='Name of the schema file to use.')

args = parser.parse_args()

# Pairs of (data split name, name of the file listing the dialogues of that split).
# The 'train' split has no list file of its own.
# NOTE(review): presumably 'train' is every dialogue not listed for 'test'/'dev' -
# confirm against the code that consumes this mapping.
_PATH_MAPPING = [('test', 'testListFile.json'), ('dev', 'valListFile.json'), ('train', '')]

# Directory that contains this script; bundled resource files are resolved against it.
_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
# File used for correcting categorical slot values. Each line is a tab-separated
# pair: the original slot value in the MultiWOZ 2.1 annotation, followed by the
# corrected slot value.
_CORRECT_FOR_STATE_PATH = os.path.join(_DIR_PATH, 'correct_categorical_state_values.tsv')

# Service name used when the schema file defines a single service that holds
# the slots of all domains.
_DEFAULT_SERVICE_NAME = 'all'
# "Don't care" slot value.
_DONT_CARE = 'dontcare'
# Slot name/value marking "nothing" in the source annotations.
_NONE_VALUE = 'none'
# Intent reported for a service that has no active intent in a turn.
_INACTIVE_INTENT = 'NONE'
# Maximum number of dialogues to write in each output file.
_NUM_DIALS_PER_FILE = 512

# We try to find the span of non-categorical slot values in the dialog history,
# but sometimes there is no exact match and we fall back to the closest value
# found in the utterance. If the found value is one of the generic words below
# while the real slot value has more than one token, the match is not trusted.
# (A duplicated 'house' entry was removed; membership tests are unaffected.)
_FOUND_VALUES_NEED_CHECK = [
    'restaurant',
    'hotel',
    'museum',
    'church',
    'college',
    'cinema',
    'park',
    'guesthouses',
    'guesthouse',
    'great',
    'from',
    'hotels',
    'school',
    'schools',
    'guests',
    'colleges',
    'lodge',
    'theatre',
    'centre',
    'bar',
    'bed and breakfast',
    'train',
    'station',
    'gallery',
    'la',
    'time',
    'house',
    'guest house',
    'old',
    'pool',
    'a',
    'b',
    'the',
    'cafe',
    'cambridge',
    'hospital',
    'restaurant\'s',
]

# Maps a single-word slot value to the alternative spellings or phrasings that
# should be treated as equivalent to it when searching an utterance.
_SIMILAR_WORDS = {
    'portuguese': ['portugese', 'portugeuese'],
    '01:30': ['1 thirty p . m .'],
    '16:30': ['after 16:00'],
    'anatolia': ['anatoilia'],
    'allenbell': ['allenball'],
    'caribbean': ['carribbean'],
    'seafood': ['sea food'],
    'moroccan': ['morrocan'],
    'avalon': ['avaion'],
    'barbeque': ['bbq'],
    'american': ['americas'],
    'italian': ['pizza place'],
    'indian': ['taj tandoori'],
    'british': ['english'],
    'cambride': ['cambridge'],
    'fenditton': ['fen ditton'],
    'cafe': ['caffe'],
    'gonvile': ['gonville'],
    'shaddia': ['shaddai'],
}

# A collection of phrases that are semantically similar to the key value, which
# is a phrase consisting of more than one word.
+_SIMILAR_PHRASES = { + 'alexander bed and breakfast': ['alexander b&b', 'alexander bed and breafast', 'alexander bed & breakfast'], + 'a and b guest house': ['a & b guest house', 'a and b guesthouse', 'a and be guest house'], + 'saint johns chop house': ['saint johns chop shop house'], + 'bridge guest house': ['bridge guesthouse'], + 'finches b and b': ['finches b & b', 'finches b&b'], + 'finches bed and breakfast': ['flinches bed and breakfast', 'finches b&b'], + 'carolina bed and breakfast': ['carolina b&b'], + 'city centre north b and b': ['city centre north b&b', 'city centre north b & b'], + 'lan hong house': ['ian hong house', 'ian hong'], + 'ugly duckling': ['ugly ducking'], + 'sri lankan': ['sri lanken'], + 'cambridge punter': ['cambridge punte'], + 'abc theatre': ['adc theatre'], +} + + +def _locate_boundary(phrase, text): + """Locate the span of the phrase using exact match.""" + + def _locate_token_boundary(pos, text): + """Get the start and end index of a token that covers a certain position.""" + if pos < 0: + raise ValueError('Pos {} should be a positive integer.'.format(pos)) + next_space = text.find(' ', pos) + left_boundary = text.rfind(' ', 0, pos) + 1 + right_boundary = next_space if next_space != -1 else len(text) + return left_boundary, right_boundary + + phrase = phrase.strip() + pos_in_text = text.find(phrase) + if pos_in_text == -1: + return None, None + + tokens = phrase.split() + start_idx, _ = _locate_token_boundary(pos_in_text, text) + last_token = tokens[-1] + find_last_token = text.find(last_token, pos_in_text + len(phrase) - len(last_token)) + if find_last_token == -1: + raise ValueError('Should find the last word for value {}'.format(phrase)) + _, end_idx = _locate_token_boundary(find_last_token, text) + # If it's a number, the value should be exactly the same. + if phrase.isdigit() and text[start_idx:end_idx] != phrase: + return None, None + # If the phrase is short, the value should be exactly the same. + # e.g. 
we don't want to match "theatre" when searching for "the" + if len(phrase) <= 3 and len(phrase) != (end_idx - start_idx): + return None, None + return start_idx, end_idx + + +def _locate_word(word, text, start_pos): + """Get start and end index of a phrase that semantically equals to a word.""" + # If the word to search for contains 3 or 4 digits, correct it into time. + obj = re.match(r'(? 12: + times_to_try.append(':'.join([str(hour - 12), obj.group(2)])) + if minute == 0: + times_to_try.append(str(hour - 12) + ' pm') + times_to_try.append(str(hour - 12) + 'pm') + times_to_try.append(str(hour - 12) + ' p . m .') + times_to_try.append(str(hour - 12) + ' o\'clock p . m .') + times_to_try.append(str(hour - 12) + ' o\'clock') + times_to_try.append(str(hour) + ' o\'clock') + times_to_try.append(str(hour - 12) + ':00') + times_to_try.append(str(hour)) + elif hour == 12 and minute == 0: + times_to_try.extend(['12 pm', '12pm', '12 o\'clock', '12 p . m .', '12', 'noon']) + else: + times_to_try.append(':'.join([str(hour + 12), obj.group(2)])) + if int(minute) == 0: + times_to_try.append(str(hour) + ' am') + times_to_try.append(str(hour) + 'am') + times_to_try.append(str(hour) + ' a . m .') + times_to_try.append(str(hour) + ' o\'clock a . m .') + times_to_try.append(str(hour) + ' o\'clock') + times_to_try.append(str(hour + 12) + ':00') + times_to_try.append(str(hour)) + if minute == 15 or minute == 45 or minute == 30: + times_to_try.append('after ' + str(hour) + ':' + str(minute - 15)) + if hour < 10: + times_to_try.append('after 0' + str(hour) + ':' + str(minute - 15)) + if minute == 0: + times_to_try.append('after ' + str(hour - 1) + ':45') + for time_value in times_to_try: + # Correct time like "08:15" to "8:15" to increase match possibility. 
+ if time_value[0] == '0': + if len(time_value) > 2 and time_value[1] != [':']: + time_value = time_value[1:] + else: + start_idx, end_idx = _locate_boundary(word, text) + if start_idx is not None: + return start_idx + start_pos, end_idx + start_pos + # Try phrases that is similar to the word to find. + for similar_word in _SIMILAR_WORDS.get(word, []): + start_idx, end_idx = _locate_boundary(similar_word, text) + if start_idx is not None: + return start_idx + start_pos, end_idx + start_pos + + # Slot values ended with 's' can be written in different formats. + # e.g. rosas can be written as rosa, rosa's. + if word.endswith('s') and len(word) > 3: + modified_words = [word[:-1] + '\'s', word[:-1]] + for modified_word in modified_words: + start_idx, end_idx = _locate_boundary(modified_word, text) + if start_idx is not None: + return start_idx + start_pos, end_idx + start_pos + return None, None + + +def exists_in_prev_dialog_states(slot_value, converted_turns): + """Whether slot value exists in the previous dialogue states.""" + for user_turn in converted_turns[::2]: + assert user_turn['speaker'] == 'USER' + for frame in user_turn['frames']: + if 'state' in frame and 'slot_values' in frame['state']: + slot_values_dict = frame['state']['slot_values'] + for slot, values_list in slot_values_dict.items(): + new_list = [] + for value in values_list: + new_list.extend(value.split('|')) + if slot_value in new_list: + return frame['service'], slot, values_list + return None, None, None + + +class Processor(object): + """A processor to convert Multiwoz to the data format used in SGD.""" + + def __init__(self, schemas): + self._schemas = schemas + # For statistically evaluating the modifications. + # Number of non-categorical slot values in dialogue state, which needs span + # annotations. + self._slot_spans_num = 0 + # Dict to track the number of non-categorical slot values whose span can not + # be found. 
+ self._unfound_slot_spans_num = collections.Counter() + + # Dict used to correct categorical slot values annotated in MultiWOZ 2.1. + self._slot_value_correction_for_cat_slots = {} + with open(_CORRECT_FOR_STATE_PATH, 'r') as f: + for line in f: + tok_from, tok_to = line.replace('\n', '').split('\t') + self._slot_value_correction_for_cat_slots[tok_from] = tok_to + + @property + def unfound_slot_span_ratio(self): + """Get the ratio of the slot spans that can't be found in the utterances.""" + ratio_dict = {k: float(v) / float(self._slot_spans_num) for k, v in self._unfound_slot_spans_num.items()} + ratio_dict['total'] = float(sum(self._unfound_slot_spans_num.values())) / float(self._slot_spans_num) + return ratio_dict + + def _basic_text_process(self, text, lower=True): + # Remove redundant spaces. + text = re.sub(r'\s+', ' ', text).strip() + if lower: + text = text.lower() + return text + + def _insert_slots_annotations_to_turn(self, turn, slots_annotations_list, service_name): + """Insert slot span annotations to a turn.""" + found_service = False + for frame in turn['frames']: + if frame['service'] == service_name: + frame['slots'].extend(slots_annotations_list) + found_service = True + continue + if not found_service: + turn['frames'].append({'service': service_name, 'slots': slots_annotations_list, 'actions': []}) + return + + def _correct_state_value_for_noncat(self, slot, val): + """Correct slot values for non-categorical slots.""" + val = val.strip() + if ( + (val == 'cam' and slot == 'restaurant-name') + or (val == 'friday' and slot == 'train-leaveat') + or (val == 'bed' and slot == 'attraction-name') + ): + return '' + if val == 'portugese': + val = 'portuguese' + return val + + def _correct_state_value_for_cat(self, _, val): + """Correct slot values for categorical slots.""" + val = val.strip() + return self._slot_value_correction_for_cat_slots.get(val, val) + + def _get_intent_from_actions(self, state_value_dict, sys_actions, user_actions): + 
"""Generate user intent by rules. + + We assume each service has only one active intent which equals to the domain + mentioned in the current user turn. + We use _infer_domains_from_actions to infer the list of possible domains. + Domains that appear in the user actions and dialogue updates are prioritised + over domains mentioned in the previous system actions. + In the provided schema of MultiWOZ 2.1, every service contains one domain, + so the active_intent is either "NONE" or "find_{domain}" for every service. + + Args: + state_value_dict: a dict, key is the slot name, value is a list. + sys_actions: a list of sys actions in the next turn. + user_actions: a list of user actions. + + Returns: + String, intent of the current user turn. + """ + + def _infer_domains_from_actions(state_value_dict, sys_actions, user_actions): + """Infer the domains involved in the current turn from actions.""" + user_mentioned_domains = set() + for user_action in user_actions: + domain = user_action['act'].lower().split('-')[0] + if domain not in ['general', 'booking']: + user_mentioned_domains.add(domain) + sys_mentioned_domains = set() + for sys_action in sys_actions: + domain = sys_action['act'].lower().split('-')[0] + if domain not in ['general', 'booking']: + sys_mentioned_domains.add(domain) + # Compute domains whose slot values get updated in the current turn. + state_change_domains = set() + for slot, _ in state_value_dict.items(): + domain_name = slot.split('-')[0] + state_change_domains.add(domain_name) + # Infer the possible domains involved in the current turn for a certain + # service. 
+ return list(user_mentioned_domains.union(state_change_domains)) or list(sys_mentioned_domains) + + domains = _infer_domains_from_actions(state_value_dict, sys_actions, user_actions) + return 'find_' + domains[0] if domains else _INACTIVE_INTENT + + def _is_filled(self, slot_value): + """Whether a slot value is filled.""" + slot_value = slot_value.lower() + return slot_value and slot_value != 'not mentioned' and slot_value != 'none' + + def _new_service_name(self, domain): + """Get the new service_name decided by the new schema.""" + # If the schema file only contains one service, we summarize all the slots + # into one service, otherwise, keep the domain name as the service name. + return _DEFAULT_SERVICE_NAME if (len(self._schemas.services) == 1) else domain + + def _get_slot_name(self, slot_name, service_name, in_book_field=False): + """Get the slot name that is consistent with the schema file.""" + slot_name = 'book' + slot_name if in_book_field else slot_name + return '-'.join([service_name, slot_name]).lower() + + def _generate_dialog_states(self, frame_dict, overwrite_slot_values): + """Get the dialog states and overwrite some of the slot values.""" + dialog_states = collections.defaultdict(dict) + orig_dialog_states = collections.defaultdict(dict) + for domain_name, values in frame_dict.items(): + dialog_states_of_one_domain = {} + for k, v in values['book'].items(): + if isinstance(v, list): + for item_dict in v: + new_states = { + self._get_slot_name(slot_name, domain_name, in_book_field=True): slot_val + for slot_name, slot_val in item_dict.items() + } + dialog_states_of_one_domain.update(new_states) + if isinstance(v, str) and v: + slot_name = self._get_slot_name(k, domain_name, in_book_field=True) + dialog_states_of_one_domain[slot_name] = v + new_states = { + self._get_slot_name(slot_name, domain_name): slot_val for slot_name, slot_val in values['semi'].items() + } + dialog_states_of_one_domain.update(new_states) + # Get the new service_name that is 
decided by the schema. If the + # schema file only contains one service, we summarize all the slots into + # one service, otherwise, keep the domain name as the service name. + new_service_name = self._new_service_name(domain_name) + # Record the orig state values without any change. + orig_dialog_state_of_one_domain = copy.deepcopy(dialog_states_of_one_domain) + for (key, value) in orig_dialog_state_of_one_domain.items(): + if key in self._schemas.get_service_schema(new_service_name).slots and self._is_filled(value): + orig_dialog_states[new_service_name][key] = value + # Correct the slot values in the dialogue state. + corrected_dialog_states_of_one_domain = {} + for k, v in dialog_states_of_one_domain.items(): + if k in self._schemas.get_service_schema(new_service_name).categorical_slots: + corrected_dialog_states_of_one_domain[k] = self._correct_state_value_for_cat( + k, self._basic_text_process(v) + ) + else: + corrected_dialog_states_of_one_domain[k] = self._correct_state_value_for_noncat( + k, self._basic_text_process(v) + ) + dialog_states_of_one_domain = { + k: v for k, v in corrected_dialog_states_of_one_domain.items() if self._is_filled(v) + } + + # Overwrite some of the slot values and changes the slot value of a slot + # into a list. + for slot, value in dialog_states_of_one_domain.items(): + dialog_states_of_one_domain[slot] = [value] + if slot in overwrite_slot_values[new_service_name]: + if value in overwrite_slot_values[new_service_name][slot]: + dialog_states_of_one_domain[slot] = sorted( + overwrite_slot_values[new_service_name][slot][value] + ) + # Only track the slot values that are listed in the schema file. Slots + # such as reference number, phone number are filtered out. 
+ for (key, value) in dialog_states_of_one_domain.items(): + if key in self._schemas.get_service_schema(new_service_name).slots: + dialog_states[new_service_name][key] = value + return dialog_states, orig_dialog_states + + def _get_update_states(self, prev_ds, cur_ds): + """Get the updated dialogue states between two user turns.""" + updates = collections.defaultdict(dict) + for service, slot_values_dict in cur_ds.items(): + if service not in prev_ds: + updates[service] = slot_values_dict + continue + for slot, values in slot_values_dict.items(): + for value in values: + if slot not in prev_ds[service] or value not in prev_ds[service][slot]: + updates[service][slot] = updates[service].get(slot, []) + [value] + return updates + + def _generate_slot_annotation(self, orig_utt, slot, slot_value): + """Generate the slot span of a slot value from the utterance. + + Args: + orig_utt: Original utterance in string. + slot: Slot name in string. + slot_value: Slot value to be annotated in string. + + Returns: + slot_ann: A dict that denotes the slot name and slot spans. + slot_value: The corrected slot value based on the utterance. It's + unchanged if the slot value can't be found in the utterance. + """ + slot_ann = [] + utt = orig_utt.lower() + start_idx, end_idx = None, None + # Check if the utterance mentions any phrases that are semantically same as + # the slot value. + for alias_slot_value in [slot_value] + _SIMILAR_PHRASES.get(slot_value, []): + start_idx, end_idx = _locate_boundary(alias_slot_value, utt) + if start_idx is not None: + break + if start_idx is None: + # Tokenize the slot value and find each of them. + splitted_slot_values = slot_value.strip().split() + unfound_tokens_idx = [] + search_start_idx = 0 + # Find if each token exists in the utterance. 
+ for i, value_tok in enumerate(splitted_slot_values): + tok_start_idx, tok_end_idx = _locate_word(value_tok, utt, search_start_idx) + if tok_start_idx is not None and tok_end_idx is not None: + # Hard coded rules + # if the value to find is one of ['and', 'of', 'by'] and + # there's no token prior to them having been found, we don't think + # the value as found since they are fairly common words. + if value_tok in ['and', 'of', 'by'] and start_idx is None: + unfound_tokens_idx.append(i) + continue + if start_idx is None: + start_idx = tok_start_idx + search_start_idx = tok_end_idx + else: + unfound_tokens_idx.append(i) + # Record the last index. + if search_start_idx > 0: + end_idx = search_start_idx + if start_idx is None: + return [], slot_value + new_slot_value = utt[start_idx:end_idx] + + if abs(len(slot_value) - len(new_slot_value)) > 20: + return [], slot_value + if len(new_slot_value.split()) > (len(slot_value.strip().split()) + 2) and ( + new_slot_value not in _SIMILAR_PHRASES.get(slot_value, []) + ): + return [], slot_value + # If the value found from the utterance is one of values below and the real + # slot value contains more than one tokens, we don't think it as a + # successful match. + if new_slot_value.strip() in _FOUND_VALUES_NEED_CHECK and len(slot_value.split()) > 1: + return [], slot_value + # If the value based on the utterance ends with any value below, we don't + # annotate span of it. 
+ if new_slot_value.strip().split()[-1] in ['and', 'the', 'of', 'by']: + return [], slot_value + slot_ann.append( + {'slot': slot, 'value': orig_utt[start_idx:end_idx], 'exclusive_end': end_idx, 'start': start_idx,} + ) + return slot_ann, new_slot_value + + def _update_corrected_slot_values( + self, corrected_slot_values_dict, service_name, slot, slot_value, new_slot_value + ): + """Update the dict that keeps track of the modified state values.""" + if slot not in corrected_slot_values_dict[service_name]: + corrected_slot_values_dict[service_name][slot] = collections.defaultdict(set) + corrected_slot_values_dict[service_name][slot][slot_value] = {slot_value} + corrected_slot_values_dict[service_name][slot][slot_value].add(new_slot_value) + return + + def _get_requested_slots_from_action(self, act_list): + """Get user's requested slots from the action.""" + act_request = [] + for act_dict in act_list: + if 'request' in act_dict['act'].lower(): + slot_name = act_dict['slot'] + if slot_name == 'Arrive': + slot_name = 'arriveby' + elif slot_name == 'Leave': + slot_name = 'leaveat' + act_request.append('-'.join([act_dict['act'].split('-')[0], slot_name]).lower()) + return act_request + + def _generate_actions(self, dialog_act): + """Generate user/system actions.""" + converted_actions = collections.defaultdict(list) + for k, pair_list in dialog_act.items(): + k_list = k.lower().strip().split('-') + domain = k_list[0] + service_name = self._new_service_name(domain) + act_slot_values_dict = collections.defaultdict(list) + for pair in pair_list: + slot = pair[0] + slot_value = pair[1] + if slot != _NONE_VALUE: + act_slot_values_dict[slot].append(slot_value) + if not act_slot_values_dict: + converted_actions[service_name].append({'act': k}) + for slot, values in act_slot_values_dict.items(): + converted_actions[service_name].append({'act': k, 'slot': slot, 'values': values}) + return converted_actions + + def _generate_dial_turns(self, turns, dial_id): + """Generate the 
dialog turns and the services mentioned in the dialogue.""" + prev_dialog_states = collections.defaultdict(dict) + corrected_slot_values = collections.defaultdict(dict) + converted_turns = [] + appear_services = set() + if len(turns) % 2 != 0: + raise ValueError('dialog ended by user') + for i in range(len(turns))[::2]: + user_info = turns[i] + sys_info = turns[i + 1] + user_utt = self._basic_text_process(user_info['text'], False) + sys_utt = self._basic_text_process(sys_info['text'], False) + user_actions = collections.defaultdict(list) + sys_actions = collections.defaultdict(list) + if 'dialog_act' in user_info: + user_actions = self._generate_actions(user_info['dialog_act']) + if 'dialog_act' in sys_info: + sys_actions = self._generate_actions(sys_info['dialog_act']) + + sys_turn = {'utterance': sys_utt, 'speaker': 'SYSTEM', 'frames': [], 'turn_id': str(i + 1)} + user_turn = {'utterance': user_utt, 'speaker': 'USER', 'frames': [], 'turn_id': str(i)} + dialog_states, _ = self._generate_dialog_states(sys_info['metadata'], corrected_slot_values) + appear_services.update(dialog_states.keys()) + + # Fill in slot spans in the user turn and the previous system turn for + # the non categorical slots. 
+ user_slots = collections.defaultdict(list) + sys_slots = collections.defaultdict(list) + update_states = self._get_update_states(prev_dialog_states, dialog_states) + prev_sys_utt = converted_turns[-1]['utterance'] if converted_turns else '' + for service_name, slot_values_dict in update_states.items(): + new_service_name = self._new_service_name(service_name) + service_schema = self._schemas.get_service_schema(new_service_name) + for slot, slot_value in slot_values_dict.items(): + assert slot_value, 'slot values shouls not be empty' + slot_value = slot_value[0] + if slot in service_schema.categorical_slots: + if slot_value not in service_schema.get_categorical_slot_values(slot) and slot_value not in [ + _DONT_CARE + ]: + logging.error('Value %s not contained in slot %s, dial_id %s, ', slot_value, slot, dial_id) + dialog_states[service_name][slot] = [slot_value] + else: + self._slot_spans_num += 1 + if slot_value == _DONT_CARE: + continue + user_slot_ann, slot_value_from_user = self._generate_slot_annotation( + user_utt, slot, slot_value + ) + sys_slot_ann, slot_value_from_sys = self._generate_slot_annotation( + prev_sys_utt, slot, slot_value + ) + # Values from user utterance has a higher priority than values from + # sys utterance. We correct the slot value of non-categorical slot + # first based on user utterance, then system utterance. 
+ if user_slot_ann and slot_value_from_user != slot_value: + if sys_slot_ann and (slot_value_from_sys == slot_value): + user_slot_ann = None + else: + self._update_corrected_slot_values( + corrected_slot_values, service_name, slot, slot_value, slot_value_from_user + ) + dialog_states[service_name][slot] = list( + corrected_slot_values[service_name][slot][slot_value] + ) + if not user_slot_ann and sys_slot_ann and slot_value_from_sys != slot_value: + self._update_corrected_slot_values( + corrected_slot_values, service_name, slot, slot_value, slot_value_from_sys + ) + dialog_states[service_name][slot] = list( + corrected_slot_values[service_name][slot][slot_value] + ) + if user_slot_ann: + user_slots[service_name].extend(user_slot_ann) + if sys_slot_ann: + sys_slots[service_name].extend(sys_slot_ann) + if not user_slot_ann and not sys_slot_ann: + # First check if it exists in the previous dialogue states. + from_service_name, from_slot, from_slot_values = exists_in_prev_dialog_states( + slot_value, converted_turns + ) + if from_service_name is not None: + self._unfound_slot_spans_num['copy_from_prev_dialog_state'] += 1 + if args.annotate_copy_slots: + user_slots[service_name].append( + {'slot': slot, 'copy_from': from_slot, 'value': from_slot_values} + ) + continue + # Second, trace back the dialogue history to find the span. + for prev_turn in converted_turns[-2::-1]: + prev_utt = prev_turn['utterance'] + prev_slot_ann, prev_slot_value = self._generate_slot_annotation( + prev_utt, slot, slot_value + ) + if prev_slot_ann: + if prev_slot_value != slot_value: + self._update_corrected_slot_values( + corrected_slot_values, service_name, slot, slot_value, prev_slot_value + ) + dialog_states[service_name][slot] = list( + corrected_slot_values[service_name][slot][slot_value] + ) + self._insert_slots_annotations_to_turn(prev_turn, prev_slot_ann, service_name) + break + self._unfound_slot_spans_num[slot] += 1 + continue + # Fill in slot annotations for the system turn. 
+ for service_name in sys_slots: + if not sys_slots[service_name]: + continue + self._insert_slots_annotations_to_turn(converted_turns[-1], sys_slots[service_name], service_name) + # Generate user frames from dialog_states. + latest_update_states = self._get_update_states(prev_dialog_states, dialog_states) + for service_name, slot_values_dict in dialog_states.items(): + user_intent = self._get_intent_from_actions( + latest_update_states[service_name], sys_actions[service_name], user_actions[service_name] + ) + # Fill in values. + user_turn['frames'].append( + { + 'slots': user_slots[service_name], + 'state': { + 'slot_values': {k: v for k, v in slot_values_dict.items() if v}, + 'requested_slots': self._get_requested_slots_from_action(user_actions[service_name]), + 'active_intent': user_intent, + }, + 'service': service_name, + } + ) + non_active_services = set(self._schemas.services) - appear_services + for service_name in non_active_services: + user_intent = self._get_intent_from_actions({}, sys_actions[service_name], user_actions[service_name]) + user_turn['frames'].append( + { + 'service': service_name, + 'slots': [], + 'state': { + 'active_intent': user_intent, + 'requested_slots': self._get_requested_slots_from_action(user_actions[service_name]), + 'slot_values': {}, + }, + } + ) + converted_turns.extend([user_turn, sys_turn]) + prev_dialog_states = dialog_states + return converted_turns, list(appear_services) + + def convert_to_dstc(self, id_list, dialogs): + """Generate a list of dialogues in the dstc8 data format.""" + converted_dialogs = [] + for dial_id in id_list: + converted_turns, covered_services = self._generate_dial_turns(dialogs[dial_id]['log'], dial_id) + dialog = {'dialogue_id': dial_id, 'services': covered_services, 'turns': converted_turns} + converted_dialogs.append(dialog) + return converted_dialogs + + +def change_to_nemo_id(dialogs_list, file_index): + for i, dialogue in enumerate(dialogs_list): + dialogue['dialogue_id'] = 
f'{file_index}_{i:05d}' + return dialogs_list + + +def main(): + schema_path = os.path.join(_DIR_PATH, args.schema_file_name) + schemas = schema.Schema(schema_path) + processor = Processor(schemas) + data_path = os.path.join(args.input_data_dir, 'data.json') + with open(data_path, 'r') as f: + data = json.load(f) + dev_test_ids = [] + output_dir = args.output_dir or _DIR_PATH + # Generate dev and test set according to the ids listed in the files. Ids not + # included in the dev and test id list files belong to the training set. + for output_dir_name, file_name in _PATH_MAPPING: + output_sub_dir = os.path.join(output_dir, output_dir_name) + if not os.path.exists(output_sub_dir): + os.makedirs(output_sub_dir) + schema_path = os.path.join(output_sub_dir, 'schema.json') + schemas.save_to_file(schema_path) + dial_ids = [] + if file_name: + id_list_path = os.path.join(args.input_data_dir, file_name) + with open(id_list_path) as f: + dial_ids = [id_name.strip() for id_name in f.readlines()] + dev_test_ids.extend(dial_ids) + else: + # Generate the ids for the training set. + dial_ids = list(set(data.keys()) - set(dev_test_ids)) + converted_dials = processor.convert_to_dstc(dial_ids, data) + logging.info('Unfound slot span ratio %s', processor.unfound_slot_span_ratio) + logging.info('Writing %d dialogs to %s', len(converted_dials), output_sub_dir) + for i in range(0, len(converted_dials), _NUM_DIALS_PER_FILE): + file_index = int(i / _NUM_DIALS_PER_FILE) + 1 + # Create a new json file and save the dialogues. 
+ json_file_path = os.path.join(output_sub_dir, 'dialogues_{:03d}.json'.format(file_index)) + dialogs_list = converted_dials[(file_index - 1) * _NUM_DIALS_PER_FILE : file_index * _NUM_DIALS_PER_FILE] + dialogs_list = change_to_nemo_id(dialogs_list, file_index) + with open(json_file_path, 'w') as f: + json.dump(dialogs_list, f, indent=2, separators=(',', ': '), sort_keys=True) + logging.info('Created %s with %d dialogues.', json_file_path, len(dialogs_list)) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/dialogue_state_tracking/data/multiwoz/schema.json b/examples/nlp/dialogue_state_tracking/data/multiwoz/schema.json new file mode 100644 index 000000000000..c130b0fd818b --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/data/multiwoz/schema.json @@ -0,0 +1,636 @@ +[ + { + "service_name": "hotel", + "slots": [ + { + "name": "hotel-pricerange", + "description": "the price range of the hotel", + "possible_values": [ + "$100", + "cheap", + "cheap>moderate", + "cheap|moderate", + "expensive", + "moderate" + ], + "is_categorical": true + }, + { + "name": "hotel-type", + "description": "the type of the hotel", + "possible_values": [ + "guesthouse", + "hotel", + "hotel|guesthouse" + ], + "is_categorical": true + }, + { + "name": "hotel-parking", + "description": "does the hotel have free parking", + "possible_values": [ + "free", + "no", + "yes" + ], + "is_categorical": true + }, + { + "name": "hotel-bookday", + "description": "the day of hotel booking", + "possible_values": [ + "friday", + "friday>tuesday", + "monday", + "mondaymonday", + "thursday", + "tuesday", + "wednesday", + "wednesday|friday" + ], + "is_categorical": true + }, + { + "name": "hotel-bookpeople", + "description": "number of people to book the hotel for", + "possible_values": [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8" + ], + "is_categorical": true + }, + { + "name": "hotel-bookstay", + "description": "the duration of stay or booking", + "possible_values": [ + "1", 
+ "2", + "3", + "3|1", + "4", + "5", + "5|4", + "6", + "7", + "8" + ], + "is_categorical": true + }, + { + "name": "hotel-stars", + "description": "the rating of the hotel", + "possible_values": [ + "0", + "1", + "2", + "3", + "3|4", + "4", + "4|5", + "5" + ], + "is_categorical": true + }, + { + "name": "hotel-internet", + "description": "does it have internet or wifi", + "possible_values": [ + "free", + "no", + "yes" + ], + "is_categorical": true + }, + { + "name": "hotel-name", + "description": "the name of the hotel", + "possible_values": [], + "is_categorical": false + }, + { + "name": "hotel-area", + "description": "the locality of the hotel", + "possible_values": [ + "centre", + "east", + "north", + "south", + "west", + "west|centre" + ], + "is_categorical": true + } + ], + "description": "hotel reservations and vacation stays", + "intents": [ + { + "name": "find_hotel", + "description": "search for a hotel to stay in", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "hotel-pricerange": "dontcare", + "hotel-type": "dontcare", + "hotel-parking": "dontcare", + "hotel-bookday": "dontcare", + "hotel-bookpeople": "dontcare", + "hotel-bookstay": "dontcare", + "hotel-stars": "dontcare", + "hotel-internet": "dontcare", + "hotel-name": "dontcare", + "hotel-area": "dontcare" + } + } + ] + }, + { + "service_name": "train", + "slots": [ + { + "name": "train-destination", + "description": "the city you want to go to", + "possible_values": [ + "birmingham new street", + "bishops stortford", + "bournemouth", + "broxbourne", + "cambridge", + "centre", + "city centre north", + "copper kettle", + "curry prince", + "ely", + "glastonbury", + "gourmet burger kitchen", + "huntingdon marriott hotel", + "huntington marriott", + "kings lynn", + "leicester", + "liverpool", + "liverpool street", + "london", + "london kings cross", + "london liverpool street", + "norway", + "norwich", + "peterborough", + "stansted airport", + "stevenage" + ], + 
"is_categorical": true + }, + { + "name": "train-arriveby", + "description": "when should the train reach your destination", + "possible_values": [], + "is_categorical": false + }, + { + "name": "train-departure", + "description": "the location where you want to catch the train from", + "possible_values": [ + "alpha milton", + "aylesbray lodge guest", + "birmingham new street", + "bishops stortford", + "brookshite", + "broxbourne", + "cafe uno", + "camboats", + "cambridge", + "cineworld", + "city hall", + "duxford", + "east london", + "ely", + "hamilton lodge", + "huntingdon", + "kings lynn", + "leicester", + "liverpool", + "london", + "london kings cross", + "london liverpool", + "london liverpool street", + "norwich", + "panahar", + "peterborough", + "stansted airport", + "stevenage", + "stratford", + "wandlebury country park" + ], + "is_categorical": true + }, + { + "name": "train-day", + "description": "the day of the journey", + "possible_values": [ + "friday", + "monday", + "saturday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "is_categorical": true + }, + { + "name": "train-bookpeople", + "description": "number of tickets to buy", + "possible_values": [ + "1", + "10", + "15", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "is_categorical": true + }, + { + "name": "train-leaveat", + "description": "the departure time of the train", + "possible_values": [], + "is_categorical": false + } + ], + "description": "find trains that take you to places", + "intents": [ + { + "name": "find_train", + "description": "search for trains that take you places", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "train-destination": "dontcare", + "train-arriveby": "dontcare", + "train-departure": "dontcare", + "train-day": "dontcare", + "train-bookpeople": "dontcare", + "train-leaveat": "dontcare" + } + } + ] + }, + { + "service_name": "attraction", + "slots": [ + { + "name": "attraction-area", + "description": "the 
place where you are located", + "possible_values": [ + "centre", + "centre|west", + "east", + "north", + "south", + "west" + ], + "is_categorical": true + }, + { + "name": "attraction-name", + "description": "the name of the site you want to visit", + "possible_values": [], + "is_categorical": false + }, + { + "name": "attraction-type", + "description": "the type of attractions you are interested in", + "possible_values": [ + "architecture", + "boat", + "boating", + "camboats", + "church", + "churchills college", + "cinema", + "college", + "concert", + "concerthall", + "concerthall|boat", + "entertainment", + "entertainment|cinemas|museums|theatres", + "gallery", + "gastropub", + "hiking|historical", + "hotel", + "multiple sports", + "multiple sports|theatre", + "museum", + "museum kettles yard", + "museum|nightclub", + "night club", + "outdoor", + "park", + "park|boat", + "pool", + "special", + "sports", + "swimming pool", + "theater", + "theatre" + ], + "is_categorical": true + } + ], + "description": "find touristy stuff to do around you", + "intents": [ + { + "name": "find_attraction", + "description": "search for places to see for leisure", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "attraction-area": "dontcare", + "attraction-name": "dontcare", + "attraction-type": "dontcare" + } + } + ] + }, + { + "service_name": "restaurant", + "slots": [ + { + "name": "restaurant-pricerange", + "description": "indicates how expensive or cheap the restaurant is", + "possible_values": [ + "cheap", + "cheap|moderate", + "expensive", + "moderate" + ], + "is_categorical": true + }, + { + "name": "restaurant-area", + "description": "the locality of the restaurant", + "possible_values": [ + "centre", + "east", + "east|south", + "north", + "south", + "west" + ], + "is_categorical": true + }, + { + "name": "restaurant-food", + "description": "the cuisine or type of food served", + "possible_values": [], + "is_categorical": false + }, + { + "name": 
"restaurant-name", + "description": "the name of the restaurant", + "possible_values": [], + "is_categorical": false + }, + { + "name": "restaurant-bookday", + "description": "the day of booking at the restaurant", + "possible_values": [ + "friday", + "monday", + "saturday", + "saturday|thursday", + "sunday", + "sunday|thursday", + "thursday", + "tuesday", + "wednesday" + ], + "is_categorical": true + }, + { + "name": "restaurant-bookpeople", + "description": "number of people to reserve the restaurant for", + "possible_values": [ + "1", + "2", + "3", + "4", + "4|7", + "5", + "6", + "7", + "8" + ], + "is_categorical": true + }, + { + "name": "restaurant-booktime", + "description": "the time of the reservation at the restaurant", + "possible_values": [], + "is_categorical": false + } + ], + "description": "find places to dine and whet your appetite", + "intents": [ + { + "name": "find_restaurant", + "description": "search for places to wine and dine", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "restaurant-pricerange": "dontcare", + "restaurant-area": "dontcare", + "restaurant-food": "dontcare", + "restaurant-name": "dontcare", + "restaurant-bookday": "dontcare", + "restaurant-bookpeople": "dontcare", + "restaurant-booktime": "dontcare" + } + } + ] + }, + { + "service_name": "hospital", + "slots": [ + { + "name": "hospital-department", + "description": "the kind of ailment or sickness you want treated", + "possible_values": [ + "acute medical assessment unit", + "acute medicine for the elderly", + "antenatal", + "cambridge eye unit", + "cardiology", + "cardiology and coronary care unit", + "childrens oncology and haematology", + "childrens surgical and medicine", + "clinical decisions unit", + "clinical research facility", + "coronary care unit", + "diabetes and endocrinology", + "emergency department", + "gastroenterology", + "gynaecology", + "haematology", + "haematology and haematological oncology", + "haematology day unit", + 
"hepatobillary and gastrointestinal surgery regional referral centre", + "hepatology", + "infectious diseases", + "infusion services", + "inpatient occupational therapy", + "intermediate dependancy area", + "john farman intensive care unit", + "medical decisions unit", + "medicine for the elderly", + "neonatal unit", + "neurology", + "neurology neurosurgery", + "neurosciences", + "neurosciences critical care unit", + "oncology", + "oral and maxillofacial surgery and ent", + "paediatric clinic", + "paediatric day unit", + "paediatric intensive care unit", + "plastic and vascular surgery plastics", + "psychiatry", + "respiratory medicine", + "surgery", + "teenage cancer trust unit", + "transitional care", + "transplant high dependency unit", + "trauma and orthopaedics", + "trauma high dependency unit", + "urology" + ], + "is_categorical": true + } + ], + "description": "making you feel better when you are ill", + "intents": [ + { + "name": "find_hospital", + "description": "search for a medical facility or a doctor", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "hospital-department": "dontcare" + } + } + ] + }, + { + "service_name": "taxi", + "slots": [ + { + "name": "taxi-leaveat", + "description": "the time you want to depart", + "possible_values": [], + "is_categorical": false + }, + { + "name": "taxi-destination", + "description": "the place you want to get to", + "possible_values": [], + "is_categorical": false + }, + { + "name": "taxi-departure", + "description": "the place you want to board the taxi", + "possible_values": [], + "is_categorical": false + }, + { + "name": "taxi-arriveby", + "description": "the time of your arrival at the destination", + "possible_values": [], + "is_categorical": false + } + ], + "description": "rent cheap cabs to avoid traffic", + "intents": [ + { + "name": "find_taxi", + "description": "search for taxis to avoid traffic", + "is_transactional": false, + "required_slots": [], + "optional_slots": { 
+ "taxi-leaveat": "dontcare", + "taxi-destination": "dontcare", + "taxi-departure": "dontcare", + "taxi-arriveby": "dontcare" + } + } + ] + }, + { + "service_name": "bus", + "slots": [ + { + "name": "bus-departure", + "description": "the departure place of the bus", + "possible_values": [ + "cambridge" + ], + "is_categorical": true + }, + { + "name": "bus-destination", + "description": "the destination of the bus", + "possible_values": [ + "bishops stortford", + "cambridge", + "kohinoor", + "london kings cross" + ], + "is_categorical": true + }, + { + "name": "bus-leaveat", + "description": "the time when bus leaves", + "possible_values": [ + "21:45" + ], + "is_categorical": true + }, + { + "name": "bus-day", + "description": "the day of the bus", + "possible_values": [ + "wednesday" + ], + "is_categorical": true + } + ], + "description": "Bus service for traveling", + "intents": [ + { + "name": "find_bus", + "description": "search for a bus", + "is_transactional": false, + "required_slots": [], + "optional_slots": { + "bus-departure": "dontcare", + "bus-destination": "dontcare", + "bus-day": "dontcare", + "bus-leaveat": "dontcare" + } + } + ] + } +] diff --git a/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_sgd.py b/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_sgd.py new file mode 100644 index 000000000000..b5edb930720d --- /dev/null +++ b/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_sgd.py @@ -0,0 +1,462 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +''' +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/train_and_predict.py +''' + +import argparse +import math +import os + +import nemo.collections.nlp as nemo_nlp +import nemo.collections.nlp.data.datasets.sgd_dataset.data_processor as data_processor +from nemo.collections.nlp.callbacks.sgd_callback import eval_epochs_done_callback, eval_iter_callback +from nemo.collections.nlp.data.datasets.sgd_dataset.schema_processor import SchemaPreprocessor +from nemo.collections.nlp.nm.trainables import SGDDecoderNM, SGDEncoderNM +from nemo.core import Backend, CheckpointCallback, EvaluatorCallback, NeuralModuleFactory, SimpleLossLoggerCallback +from nemo.utils import logging +from nemo.utils.lr_policies import get_lr_policy + +# Parsing arguments +parser = argparse.ArgumentParser(description='Schema_guided_dst') + +# BERT based utterance encoder related arguments +parser.add_argument( + "--max_seq_length", + default=80, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. 
" + "Sequences longer than this will be truncated, and sequences shorter " + "than this will be padded.", +) +parser.add_argument("--dropout", default=0.1, type=float, help="Dropout rate for BERT representations.") +parser.add_argument( + "--pretrained_model_name", + default="bert-base-cased", + type=str, + help="Name of the pre-trained model", + choices=nemo_nlp.nm.trainables.get_pretrained_lm_models_list(), +) +parser.add_argument("--bert_checkpoint", default=None, type=str, help="Path to model checkpoint") +parser.add_argument("--bert_config", default=None, type=str, help="Path to bert config file in json format") +parser.add_argument( + "--tokenizer_model", + default=None, + type=str, + help="Path to pretrained tokenizer model, only used if --tokenizer is sentencepiece", +) +parser.add_argument( + "--tokenizer", + default="nemobert", + type=str, + choices=["nemobert", "sentencepiece"], + help="tokenizer to use, only relevant when using custom pretrained checkpoint.", +) +parser.add_argument("--vocab_file", default=None, help="Path to the vocab file.") +parser.add_argument( + "--do_lower_case", + action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models. " + + "Only applicable when tokenizer is build with vocab file", +) + +# Hyperparameters and optimization related flags. 
+parser.add_argument( + "--checkpoint_dir", + default=None, + type=str, + help="The folder containing the checkpoints for the model to continue training", +) +parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") +parser.add_argument("--eval_batch_size", default=8, type=int, help="Total batch size for eval.") +parser.add_argument("--num_epochs", default=80, type=int, help="Total number of training epochs to perform.") + +parser.add_argument("--optimizer_kind", default="adam_w", type=str) +parser.add_argument("--learning_rate", default=1e-4, type=float, help="The initial learning rate for Adam.") +parser.add_argument("--lr_policy", default="PolynomialDecayAnnealing", type=str) +parser.add_argument("--weight_decay", default=0.01, type=float) +parser.add_argument( + "--lr_warmup_proportion", + default=0.1, + type=float, + help="Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10% of training.", +) +parser.add_argument("--grad_norm_clip", type=float, default=1, help="Gradient clipping") +parser.add_argument("--local_rank", default=None, type=int) +parser.add_argument("--amp_opt_level", default="O0", type=str, choices=["O0", "O1", "O2"]) +parser.add_argument("--num_gpus", default=1, type=int) + +# Input and output paths and other flags. 
+parser.add_argument( + "--task_name", + default="dstc8_single_domain", + type=str, + choices=data_processor.FILE_RANGES.keys(), + help="The name of the task to train.", +) +parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory for the downloaded DSTC8 data, which contains the dialogue files" + " and schema files of all datasets (eg train, dev)", +) +parser.add_argument( + "--work_dir", + type=str, + default="output/SGD", + help="The output directory where the model checkpoints will be written.", +) +parser.add_argument( + "--schema_embedding_dir", + type=str, + default='schema_embedding_dir', + help="Directory where .npy file for embedding of entities (slots, values, intents) in the dataset_split's schema are stored.", +) +parser.add_argument( + "--overwrite_schema_emb_files", + action="store_true", + help="Whether to overwrite existing schema embedding files.", +) +parser.add_argument( + "--joint_acc_across_turn", + action="store_true", + help="Whether to compute joint accuracy across turn instead of across service. Should be set to True when conducting multiwoz style evaluation.", +) +parser.add_argument( + "--no_fuzzy_match", + action="store_true", + help="Disables fuzzy string matching when comparing non-categorical slot values. Fuzzy matching should not be used when conducting multiwoz style evaluation.", +) +parser.add_argument( + "--dialogues_example_dir", + type=str, + default="dialogues_example_dir", + help="Directory where preprocessed DSTC8 dialogues are stored.", +) +parser.add_argument( + "--overwrite_dial_files", action="store_true", help="Whether to generate a new file saving the dialogue examples." 
+) +parser.add_argument("--no_shuffle", action="store_true", help="Disables shuffling of the training data") +parser.add_argument("--no_time_to_log_dir", action="store_true", help="Do not add time to work_dir") +parser.add_argument( + "--eval_dataset", type=str, default="dev", choices=["dev", "test"], help="Dataset split for evaluation." +) +parser.add_argument( + "--save_epoch_freq", + default=1, + type=int, + help="Frequency of saving checkpoint '-1' - epoch checkpoint won't be saved", +) +parser.add_argument( + "--save_step_freq", + default=-1, + type=int, + help="Frequency of saving checkpoint '-1' - step checkpoint won't be saved", +) + +parser.add_argument( + "--loss_log_freq", default=-1, type=int, help="Frequency of logging loss values, '-1' - at the end of the epoch", +) + +parser.add_argument( + "--loss_reduction", + default='mean', + type=str, + help="specifies the reduction to apply to the final loss, choose 'mean' or 'sum'", +) + +parser.add_argument( + "--eval_epoch_freq", default=1, type=int, help="Frequency of evaluation", +) + +parser.add_argument( + "--num_workers", + default=2, + type=int, + help="Number of workers for data loading, -1 means set it automatically to the number of CPU cores", +) + +parser.add_argument( + "--enable_pin_memory", action="store_true", help="Enables the pin_memory feature of PyTorch's DataLoader", +) + +parser.add_argument( + "--state_tracker", + type=str, + default='baseline', + choices=['baseline', 'ret_sys_act'], + help="Specifies the state tracker mode", +) +parser.add_argument( + "--schema_emb_init", + type=str, + default='baseline', + choices=['baseline', 'random', 'last_layer_average'], + help="Specifies how schema embeddings are generated. 
Baseline uses ['CLS'] token", +) +parser.add_argument( + "--train_schema_emb", action="store_true", help="Specifies whether schema embeddings are trainables.", +) +parser.add_argument( + "--head_transform", + default="", + type=str, + choices=["", "Attention"], + help="transformation to use for computing head. Default uses linear projection.", +) +parser.add_argument( + "--debug_mode", action="store_true", help="Enables debug mode with more info on data preprocessing and evaluation", +) + +parser.add_argument( + "--checkpoints_to_keep", default=1, type=int, help="The number of last checkpoints to keep", +) + +args = parser.parse_args() +logging.info(args) + +if args.debug_mode: + logging.setLevel(10) + +if args.task_name == "multiwoz": + schema_config = { + "MAX_NUM_CAT_SLOT": 9, + "MAX_NUM_NONCAT_SLOT": 4, + "MAX_NUM_VALUE_PER_CAT_SLOT": 47, + "MAX_NUM_INTENT": 1, + } +else: + schema_config = { + "MAX_NUM_CAT_SLOT": 6, + "MAX_NUM_NONCAT_SLOT": 12, + "MAX_NUM_VALUE_PER_CAT_SLOT": 12, + "MAX_NUM_INTENT": 4, + } + +if not os.path.exists(args.data_dir): + raise ValueError(f'Data not found at {args.data_dir}') + +nf = NeuralModuleFactory( + backend=Backend.PyTorch, + local_rank=args.local_rank, + optimization_level=args.amp_opt_level, + log_dir=args.work_dir, + create_tb_writer=True, + checkpoint_dir=args.checkpoint_dir, + files_to_copy=[__file__], + add_time_to_log_dir=not args.no_time_to_log_dir, +) + +pretrained_bert_model = nemo_nlp.nm.trainables.get_pretrained_lm_model( + pretrained_model_name=args.pretrained_model_name, + config=args.bert_config, + vocab=args.vocab_file, + checkpoint=args.bert_checkpoint, +) + +schema_config["EMBEDDING_DIMENSION"] = pretrained_bert_model.hidden_size +schema_config["MAX_SEQ_LENGTH"] = args.max_seq_length + +tokenizer = nemo_nlp.data.tokenizers.get_tokenizer( + tokenizer_name=args.tokenizer, + pretrained_model_name=args.pretrained_model_name, + tokenizer_model=args.tokenizer_model, + vocab_file=args.vocab_file, + 
do_lower_case=args.do_lower_case, +) + +hidden_size = pretrained_bert_model.hidden_size + +# Run SGD preprocessor to generate and store schema embeddings +schema_preprocessor = SchemaPreprocessor( + data_dir=args.data_dir, + schema_embedding_dir=args.schema_embedding_dir, + schema_config=schema_config, + tokenizer=tokenizer, + bert_model=pretrained_bert_model, + overwrite_schema_emb_files=args.overwrite_schema_emb_files, + bert_ckpt_dir=args.checkpoint_dir, + nf=nf, + mode=args.schema_emb_init, + is_trainable=args.train_schema_emb, +) + +dialogues_processor = data_processor.Dstc8DataProcessor( + task_name=args.task_name, + dstc8_data_dir=args.data_dir, + dialogues_example_dir=args.dialogues_example_dir, + tokenizer=tokenizer, + schema_emb_processor=schema_preprocessor, + overwrite_dial_files=args.overwrite_dial_files, +) + +# define model pipeline +sgd_encoder = SGDEncoderNM(hidden_size=hidden_size, dropout=args.dropout) +sgd_decoder = SGDDecoderNM( + embedding_dim=hidden_size, schema_emb_processor=schema_preprocessor, head_transform="Logits" + args.head_transform +) +dst_loss = nemo_nlp.nm.losses.SGDDialogueStateLossNM(reduction=args.loss_reduction) + + +def create_pipeline(dataset_split='train'): + datalayer = nemo_nlp.nm.data_layers.SGDDataLayer( + dataset_split=dataset_split, + dialogues_processor=dialogues_processor, + batch_size=args.train_batch_size if dataset_split == 'train' else args.eval_batch_size, + shuffle=not args.no_shuffle if dataset_split == 'train' else False, + num_workers=args.num_workers, + pin_memory=args.enable_pin_memory, + ) + data = datalayer() + + # Encode the utterances using BERT. 
+ token_embeddings = pretrained_bert_model( + input_ids=data.utterance_ids, attention_mask=data.utterance_mask, token_type_ids=data.utterance_segment, + ) + encoded_utterance, token_embeddings = sgd_encoder(hidden_states=token_embeddings) + ( + logit_intent_status, + logit_req_slot_status, + logit_cat_slot_status, + logit_cat_slot_value, + logit_noncat_slot_status, + logit_noncat_slot_start, + logit_noncat_slot_end, + ) = sgd_decoder( + encoded_utterance=encoded_utterance, + token_embeddings=token_embeddings, + utterance_mask=data.utterance_mask, + cat_slot_values_mask=data.cat_slot_values_mask, + intent_status_mask=data.intent_status_mask, + service_ids=data.service_id, + ) + + if dataset_split == 'train': + loss = dst_loss( + logit_intent_status=logit_intent_status, + intent_status_labels=data.intent_status_labels, + logit_req_slot_status=logit_req_slot_status, + requested_slot_status=data.requested_slot_status, + req_slot_mask=data.req_slot_mask, + logit_cat_slot_status=logit_cat_slot_status, + categorical_slot_status=data.categorical_slot_status, + cat_slot_status_mask=data.cat_slot_status_mask, + logit_cat_slot_value=logit_cat_slot_value, + categorical_slot_values=data.categorical_slot_values, + logit_noncat_slot_status=logit_noncat_slot_status, + noncategorical_slot_status=data.noncategorical_slot_status, + noncat_slot_status_mask=data.noncat_slot_status_mask, + logit_noncat_slot_start=logit_noncat_slot_start, + logit_noncat_slot_end=logit_noncat_slot_end, + noncategorical_slot_value_start=data.noncategorical_slot_value_start, + noncategorical_slot_value_end=data.noncategorical_slot_value_end, + ) + tensors = [loss] + else: + tensors = [ + data.example_id_num, + data.service_id, + data.is_real_example, + data.start_char_idx, + data.end_char_idx, + logit_intent_status, + logit_req_slot_status, + logit_cat_slot_status, + logit_cat_slot_value, + logit_noncat_slot_status, + logit_noncat_slot_start, + logit_noncat_slot_end, + data.intent_status_labels, + 
data.requested_slot_status, + data.categorical_slot_status, + data.categorical_slot_values, + data.noncategorical_slot_status, + ] + + steps_per_epoch = math.ceil(len(datalayer) / (args.train_batch_size * args.num_gpus)) + return steps_per_epoch, tensors + + +steps_per_epoch, train_tensors = create_pipeline() +logging.info(f'Steps per epoch: {steps_per_epoch}') +_, eval_tensors = create_pipeline(dataset_split=args.eval_dataset) + +# Create trainer and execute training action +train_callback = SimpleLossLoggerCallback( + tensors=train_tensors, + print_func=lambda x: logging.info("Loss: {:.8f}".format(x[0].item())), + get_tb_values=lambda x: [["loss", x[0]]], + tb_writer=nf.tb_writer, + step_freq=args.loss_log_freq if args.loss_log_freq > 0 else steps_per_epoch, +) + +# we'll write predictions to file in DSTC8 format during evaluation callback +input_json_files = [ + os.path.join(args.data_dir, args.eval_dataset, 'dialogues_{:03d}.json'.format(fid)) + for fid in data_processor.FILE_RANGES[args.task_name][args.eval_dataset] +] +schema_json_file = os.path.join(args.data_dir, args.eval_dataset, 'schema.json') + +# Write predictions to file in DSTC8 format. 
+prediction_dir = os.path.join(nf.work_dir, 'predictions', 'pred_res_{}_{}'.format(args.eval_dataset, args.task_name)) +output_metric_file = os.path.join(nf.work_dir, 'metrics.txt') +os.makedirs(prediction_dir, exist_ok=True) + +eval_callback = EvaluatorCallback( + eval_tensors=eval_tensors, + user_iter_callback=lambda x, y: eval_iter_callback(x, y, schema_preprocessor, args.eval_dataset), + user_epochs_done_callback=lambda x: eval_epochs_done_callback( + x, + input_json_files, + args.eval_dataset, + args.data_dir, + prediction_dir, + output_metric_file, + args.state_tracker, + args.debug_mode, + schema_preprocessor, + args.joint_acc_across_turn, + args.no_fuzzy_match, + ), + tb_writer=nf.tb_writer, + eval_step=args.eval_epoch_freq * steps_per_epoch, +) + +ckpt_callback = CheckpointCallback( + folder=nf.checkpoint_dir, epoch_freq=args.save_epoch_freq, step_freq=args.save_step_freq, checkpoints_to_keep=1 +) + +lr_policy_fn = get_lr_policy( + args.lr_policy, total_steps=args.num_epochs * steps_per_epoch, warmup_ratio=args.lr_warmup_proportion +) + +nf.train( + tensors_to_optimize=train_tensors, + callbacks=[train_callback, eval_callback, ckpt_callback], + lr_policy=lr_policy_fn, + optimizer=args.optimizer_kind, + optimization_params={ + "num_epochs": args.num_epochs, + "lr": args.learning_rate, + "eps": 1e-6, + "weight_decay": args.weight_decay, + "grad_norm_clip": args.grad_norm_clip, + }, +) diff --git a/examples/speaker_recognition/notebooks/Speaker_Recognition_an4.ipynb b/examples/speaker_recognition/notebooks/Speaker_Recognition_an4.ipynb index 75f62f1c050c..118341cd2b30 100644 --- a/examples/speaker_recognition/notebooks/Speaker_Recognition_an4.ipynb +++ b/examples/speaker_recognition/notebooks/Speaker_Recognition_an4.ipynb @@ -310,7 +310,7 @@ }, "outputs": [], "source": [ - "logging = nemo.logging\n", + "from nemo.utils import logging\n", "yaml = YAML(typ=\"safe\")\n", "with open('../configs/quartznet_spkr_3x1x512_xvector.yaml') as f:\n", " spkr_params = 
yaml.load(f)\n", diff --git a/examples/speaker_recognition/notebooks/Speaker_Recognition_hi-mia.ipynb b/examples/speaker_recognition/notebooks/Speaker_Recognition_hi-mia.ipynb index 234bd53cbde0..4a25b4856bc9 100644 --- a/examples/speaker_recognition/notebooks/Speaker_Recognition_hi-mia.ipynb +++ b/examples/speaker_recognition/notebooks/Speaker_Recognition_hi-mia.ipynb @@ -198,7 +198,7 @@ }, "outputs": [], "source": [ - "logging = nemo.logging\n", + "from nemo.utils import logging\n", "yaml = YAML(typ=\"safe\")\n", "with open('examples/speaker_recognition/configs/quartznet_spkr_3x2x512_xvector.yaml') as f:\n", " spkr_params = yaml.load(f)\n", diff --git a/examples/speaker_recognition/speaker_reco.py b/examples/speaker_recognition/speaker_reco.py index 3c6cf1b84985..be85e1863769 100644 --- a/examples/speaker_recognition/speaker_reco.py +++ b/examples/speaker_recognition/speaker_reco.py @@ -27,10 +27,9 @@ process_classification_evaluation_batch, process_classification_evaluation_epoch, ) +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/speaker_recognition/spkr_get_emb.py b/examples/speaker_recognition/spkr_get_emb.py index db93f638979f..7fe5d9848bc0 100644 --- a/examples/speaker_recognition/spkr_get_emb.py +++ b/examples/speaker_recognition/spkr_get_emb.py @@ -23,8 +23,7 @@ import nemo import nemo.collections.asr as nemo_asr import nemo.utils.argparse as nm_argparse - -logging = nemo.logging +from nemo.utils import logging def parse_args(): diff --git a/examples/start_here/chatbot_example.py b/examples/start_here/chatbot_example.py index fb59b243c67f..e0e974ab446c 100644 --- a/examples/start_here/chatbot_example.py +++ b/examples/start_here/chatbot_example.py @@ -3,8 +3,7 @@ import shutil import nemo - -logging = nemo.logging +from nemo.utils import logging data_file = "movie_data.txt" diff --git 
a/examples/start_here/simplest_example.py b/examples/start_here/simplest_example.py index 0bf3fb795dac..1e4bd2de633f 100644 --- a/examples/start_here/simplest_example.py +++ b/examples/start_here/simplest_example.py @@ -1,7 +1,6 @@ # Copyright (c) 2019 NVIDIA Corporation import nemo - -logging = nemo.logging +from nemo.utils import logging nf = nemo.core.NeuralModuleFactory() # To use CPU-only do: diff --git a/examples/tts/fastspeech.py b/examples/tts/fastspeech.py index 94f342a70a8b..15a147fbfd03 100644 --- a/examples/tts/fastspeech.py +++ b/examples/tts/fastspeech.py @@ -23,9 +23,7 @@ from nemo.collections import asr as nemo_asr from nemo.collections import tts as nemo_tts from nemo.utils import argparse as nm_argparse -from nemo.utils import lr_policies - -logging = nemo.logging +from nemo.utils import logging, lr_policies def parse_args(): diff --git a/examples/tts/fastspeech_durations.py b/examples/tts/fastspeech_durations.py index e7b827a9feae..ac692e366cf4 100644 --- a/examples/tts/fastspeech_durations.py +++ b/examples/tts/fastspeech_durations.py @@ -22,8 +22,7 @@ import nemo import nemo.collections.asr as nemo_asr import nemo.collections.tts as nemo_tts - -logging = nemo.logging +from nemo.utils import logging def parse_args(): diff --git a/examples/tts/notebooks/1_Tacotron_inference.ipynb b/examples/tts/notebooks/1_Tacotron_inference.ipynb new file mode 100644 index 000000000000..23b27fecc395 --- /dev/null +++ b/examples/tts/notebooks/1_Tacotron_inference.ipynb @@ -0,0 +1,650 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2020 NVIDIA. 
All Rights Reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. 
Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell.\n", + "!pip install wget\n", + "!pip install nemo_toolkit[tts]\n", + "\n", + "!mkdir configs\n", + "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/master/examples/tts/configs/tacotron2.yaml\n", + "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/master/examples/tts/configs/waveglow.yaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import argparse\n", + "import math\n", + "import os\n", + "import copy\n", + "import shutil\n", + "import librosa\n", + "import matplotlib.pyplot as plt\n", + "from functools import partial\n", + "from scipy.io.wavfile import write\n", + "import numpy as np\n", + "import IPython.display as ipd\n", + "\n", + "from ruamel.yaml import YAML\n", + "\n", + "import torch\n", + "import nemo\n", + "import nemo.collections.asr as nemo_asr\n", + "import nemo.collections.tts as nemo_tts\n", + "import nemo.utils.argparse as nm_argparse\n", + "\n", + "logging = nemo.logging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download config files\n", + "config_path = '../configs/tacotron2.yaml'\n", + "waveglow_config_path = '../configs/waveglow.yaml'\n", + "\n", + "yaml = YAML(typ=\"safe\")\n", + "with open(config_path) as file:\n", + " tacotron2_config = yaml.load(file)\n", + " labels = tacotron2_config[\"labels\"]\n", + " \n", + "with open(waveglow_config_path) as file:\n", + " waveglow_config = yaml.load(file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Download pre-trained checkpoints\n", + "\n", + "Note: The checkpoint for WaveGlow is very large (>1GB), so please ensure you have sufficient storage space." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_checkpoint_path = './checkpoints/'\n", + "WAVEGLOW = os.path.join(base_checkpoint_path, 'WaveGlowNM.pt')\n", + "TACOTRON_ENCODER = os.path.join(base_checkpoint_path, 'Tacotron2Encoder.pt')\n", + "TACOTRON_DECODER = os.path.join(base_checkpoint_path, 'Tacotron2Decoder.pt')\n", + "TACOTRON_POSTNET = os.path.join(base_checkpoint_path, 'Tacotron2Postnet.pt')\n", + "TEXT_EMBEDDING = os.path.join(base_checkpoint_path, 'TextEmbedding.pt')\n", + "\n", + "if not os.path.exists(base_checkpoint_path):\n", + " os.makedirs(base_checkpoint_path)\n", + " \n", + "if not os.path.exists(WAVEGLOW):\n", + " !wget wget https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ljspeech/versions/2/files/WaveGlowNM.pt -P {base_checkpoint_path};\n", + "\n", + "if not os.path.exists(TACOTRON_ENCODER):\n", + " !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Encoder.pt -P {base_checkpoint_path};\n", + " \n", + "if not os.path.exists(TACOTRON_DECODER):\n", + " !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Decoder.pt -P {base_checkpoint_path};\n", + "\n", + "if not os.path.exists(TACOTRON_POSTNET):\n", + " !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Postnet.pt -P {base_checkpoint_path};\n", + "\n", + "if not os.path.exists(TEXT_EMBEDDING):\n", + " !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/TextEmbedding.pt -P {base_checkpoint_path};\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the Neural Factory\n", + "neural_factory = nemo.core.NeuralModuleFactory(\n", + " optimization_level=\"O0\", backend=nemo.core.Backend.PyTorch\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Text Line 
Data Layer\n", + "\n", + "Construct a simple datalayer to load a single line of text (accepted from the user) and pass it to the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.backends.pytorch import DataLayerNM\n", + "from nemo.core.neural_types import *\n", + "from nemo.utils.misc import pad_to\n", + "from nemo.collections.asr.parts.dataset import TranscriptDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SentenceDataLayer(DataLayerNM):\n", + " \"\"\"A simple Neural Module for loading textual transcript data.\n", + " The path, labels, and eos_id arguments are dataset parameters.\n", + "\n", + " Args:\n", + " pad_id (int): Label position of padding symbol\n", + " batch_size (int): Size of batches to generate in data loader\n", + " drop_last (bool): Whether we drop last (possibly) incomplete batch.\n", + " Defaults to False.\n", + " num_workers (int): Number of processes to work on data loading (0 for\n", + " just main process).\n", + " Defaults to 0.\n", + " \"\"\"\n", + "\n", + " @property\n", + " def output_ports(self):\n", + " \"\"\"Returns definitions of module output ports.\n", + "\n", + " texts:\n", + " 0: AxisType(BatchTag)\n", + "\n", + " 1: AxisType(TimeTag)\n", + "\n", + " texts_length:\n", + " 0: AxisType(BatchTag)\n", + "\n", + " \"\"\"\n", + " return {\n", + " 'texts': NeuralType(('B', 'T'), LabelsType()),\n", + " 'texts_length': NeuralType(tuple('B'), LengthsType()),\n", + " }\n", + "\n", + " def __init__(\n", + " self,\n", + " path,\n", + " labels,\n", + " batch_size,\n", + " bos_id=None,\n", + " eos_id=None,\n", + " pad_id=None,\n", + " drop_last=False,\n", + " num_workers=0,\n", + " shuffle=True,\n", + " ):\n", + " super().__init__()\n", + "\n", + " # Set up dataset\n", + " self.dataset_params = {\n", + " 'path': path,\n", + " 'labels': labels,\n", + " 'bos_id': bos_id,\n", + " 
'eos_id': eos_id,\n", + " }\n", + "\n", + " self._dataset = TranscriptDataset(**self.dataset_params)\n", + "\n", + " # Set up data loader\n", + " sampler = None\n", + " pad_id = 0 if pad_id is None else pad_id\n", + " \n", + " def update_dataset(self):\n", + " self._dataset = TranscriptDataset(**self.dataset_params)\n", + " logging.info('Dataset updated.')\n", + "\n", + " def __len__(self):\n", + " return len(self._dataset)\n", + "\n", + " @property\n", + " def dataset(self):\n", + " return self._dataset\n", + "\n", + " @property\n", + " def data_iterator(self):\n", + " return None\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Create the Tacotron 2 + WaveGlow Neural Modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_NMs(tacotron2_config, waveglow_config, labels, decoder_infer=False, waveglow_sigma=0.6):\n", + " data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(\n", + " **tacotron2_config[\"AudioToMelSpectrogramPreprocessor\"][\"init_params\"]\n", + " )\n", + " \n", + " text_embedding_params = copy.deepcopy(tacotron2_config[\"TextEmbedding\"][\"init_params\"])\n", + " text_embedding_params['n_symbols'] = len(labels) + 3\n", + " \n", + " # Load checkpoint for text embedding\n", + " text_embedding = nemo_tts.TextEmbedding(**text_embedding_params)\n", + " text_embedding.restore_from(TEXT_EMBEDDING)\n", + " \n", + " # Load checkpoint for encoder\n", + " t2_enc = nemo_tts.Tacotron2Encoder(**tacotron2_config[\"Tacotron2Encoder\"][\"init_params\"])\n", + " t2_enc.restore_from(TACOTRON_ENCODER)\n", + " \n", + " # Load checkpoint for decoder\n", + " decoder_params = copy.deepcopy(tacotron2_config[\"Tacotron2Decoder\"][\"init_params\"])\n", + " \n", + " t2_dec = nemo_tts.Tacotron2DecoderInfer(**decoder_params) \n", + " t2_dec.restore_from(TACOTRON_DECODER)\n", + " \n", + " # Load checkpoint for PortNet\n", + " t2_postnet = 
nemo_tts.Tacotron2Postnet(**tacotron2_config[\"Tacotron2Postnet\"][\"init_params\"])\n", + " t2_postnet.restore_from(TACOTRON_POSTNET)\n", + " \n", + " t2_loss = nemo_tts.Tacotron2Loss(**tacotron2_config[\"Tacotron2Loss\"][\"init_params\"])\n", + " \n", + " makegatetarget = nemo_tts.MakeGate()\n", + "\n", + " total_weights = text_embedding.num_weights + t2_enc.num_weights + t2_dec.num_weights + t2_postnet.num_weights\n", + "\n", + " logging.info('================================')\n", + " logging.info(f\"Total number of parameters (Tacotron 2): {total_weights}\")\n", + " logging.info('================================')\n", + " \n", + " \n", + " # Load WaveGlow model\n", + " waveglow_args = copy.deepcopy(waveglow_config[\"WaveGlowNM\"][\"init_params\"])\n", + " waveglow_args['sigma'] = waveglow_sigma\n", + " \n", + " waveglow = nemo_tts.WaveGlowInferNM(**waveglow_args)\n", + " waveglow.restore_from(WAVEGLOW)\n", + " \n", + " total_weights = waveglow.num_weights\n", + " \n", + " logging.info('================================')\n", + " logging.info(f\"Total number of parameters (WaveGlow): {total_weights}\")\n", + " logging.info('================================')\n", + "\n", + " return (\n", + " data_preprocessor,\n", + " text_embedding,\n", + " t2_enc,\n", + " t2_dec,\n", + " t2_postnet,\n", + " t2_loss,\n", + " makegatetarget,\n", + " ), waveglow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "neural_modules, waveglow = create_NMs(tacotron2_config, waveglow_config, labels, decoder_infer=True, waveglow_sigma=0.6);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def update_text(text):\n", + " if not os.path.exists('cache/'):\n", + " os.makedirs('cache/')\n", + " \n", + " fp = os.path.join('cache', 'input.txt')\n", + " with open(fp, 'w', 
encoding='utf8') as f:\n", + " f.write('{}\\n'.format(text))\n", + " f.flush()\n", + " \n", + " logging.info(\"Updated input file with value : %s\", text)\n", + " return fp\n", + " \n", + "def cleanup_cachedir():\n", + " if os.path.exists('cache/'):\n", + " shutil.rmtree('cache/')\n", + " logging.info(\"Cleaned up cache directory !\")\n", + " \n", + "def plot_and_save_spec(spectrogram, i, save_dir=None):\n", + " fig, ax = plt.subplots(figsize=(12, 3))\n", + " im = ax.imshow(spectrogram, aspect=\"auto\", origin=\"lower\", interpolation='none')\n", + " plt.colorbar(im, ax=ax)\n", + " plt.xlabel(\"Frames\")\n", + " plt.ylabel(\"Channels\")\n", + " plt.tight_layout()\n", + " save_file = f\"spec_{i}.png\"\n", + " if save_dir:\n", + " save_file = os.path.join(save_dir, save_file)\n", + " plt.savefig(save_file)\n", + " plt.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initializing the inference DAG\n", + "\n", + "To initialize the graph, we accept some text from the user. Later, we will accept the actual text that we want to convert to speech !" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text = input('Please enter some initial text here :')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filepath = update_text(text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create inference DAG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Tacotron 2 DAG\n", + "(_, text_embedding, t2_enc, t2_dec, t2_postnet, _, _) = neural_modules\n", + "\n", + "data_layer = SentenceDataLayer(\n", + " path=filepath,\n", + " labels=labels,\n", + " batch_size=1,\n", + " num_workers=0,\n", + " bos_id=len(labels),\n", + " eos_id=len(labels) + 1,\n", + " pad_id=len(labels) + 2,\n", + " shuffle=False,\n", + ")\n", + "transcript, transcript_len = data_layer()\n", + "\n", + "transcript_embedded = text_embedding(char_phone=transcript)\n", + "\n", + "transcript_encoded = t2_enc(char_phone_embeddings=transcript_embedded, embedding_length=transcript_len,)\n", + "\n", + "mel_decoder, gate, alignments, mel_len = t2_dec(\n", + " char_phone_encoded=transcript_encoded, encoded_length=transcript_len,\n", + ")\n", + "\n", + "mel_postnet = t2_postnet(mel_input=mel_decoder)\n", + "\n", + "# WaveGlow DAG\n", + "audio_pred = waveglow(mel_spectrogram=mel_postnet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Setup inference tensors\n", + "infer_tensors = [mel_postnet, gate, alignments, mel_len]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run inference DAG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_tacotron2():\n", + " logging.info(\"Running Tacotron 2\")\n", + " # Run tacotron 2\n", + " evaluated_tensors = neural_factory.infer(\n", + 
" tensors=infer_tensors, offload_to_cpu=False\n", + " )\n", + " logging.info(\"Done Running Tacotron 2\")\n", + " \n", + " mel_len_val = evaluated_tensors[-1]\n", + " \n", + " filterbank = librosa.filters.mel(\n", + " sr=tacotron2_config[\"sample_rate\"],\n", + " n_fft=tacotron2_config[\"n_fft\"],\n", + " n_mels=tacotron2_config[\"n_mels\"],\n", + " fmax=tacotron2_config[\"fmax\"],\n", + " )\n", + " \n", + " return evaluated_tensors, filterbank, mel_len_val\n", + "\n", + "def run_waveglow(save_dir, waveglow_denoiser_strength=0.0):\n", + " # Run Tacotron 2 and WaveGlow\n", + " evaluated_tensors, filterbank, mel_len_val = run_tacotron2()\n", + " \n", + " logging.info(\"Running Waveglow\")\n", + " evaluated_tensors = neural_factory.infer(\n", + " tensors=[audio_pred],\n", + " )\n", + " logging.info(\"Done Running Waveglow\")\n", + " \n", + " if waveglow_denoiser_strength > 0:\n", + " logging.info(\"Setup WaveGlow denoiser\")\n", + " waveglow.setup_denoiser()\n", + " \n", + " logging.info(\"Saving results to disk\")\n", + " for i, batch in enumerate(evaluated_tensors[0]):\n", + " audio = batch.cpu().numpy()\n", + " for j, sample in enumerate(audio):\n", + " sample_len = mel_len_val[i][j] * tacotron2_config[\"n_stride\"]\n", + " sample = sample[:sample_len]\n", + " save_file = f\"sample_{i * 32 + j}.wav\"\n", + " if save_dir:\n", + " save_file = os.path.join(save_dir, save_file)\n", + " if waveglow_denoiser_strength > 0:\n", + " sample, spec = waveglow.denoise(sample, strength=waveglow_denoiser_strength)\n", + " else:\n", + " spec, _ = librosa.core.magphase(librosa.core.stft(sample, n_fft=waveglow_config[\"n_fft\"]))\n", + " write(save_file, waveglow_config[\"sample_rate\"], sample)\n", + " spec = np.dot(filterbank, spec)\n", + " spec = np.log(np.clip(spec, a_min=1e-5, a_max=None))\n", + " plot_and_save_spec(spec, i * 32 + j, save_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run Tacotron 2 + WaveGlow on input text" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text = input('Please enter some initial text here :')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filepath = update_text(text)\n", + "data_layer.update_dataset()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare directories to save results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "savedir = 'results/'\n", + "saved_audio = os.path.join(savedir, 'sample_0.wav')\n", + "saved_spectrogram = os.path.join(savedir, 'spec_0.png')\n", + "\n", + "if not os.path.exists(savedir):\n", + " os.makedirs(savedir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate the audio\n", + "\n", + "Lets run the Tacotron 2 model and send the results to WaveGlow to generate the audio!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_waveglow(savedir, waveglow_denoiser_strength=0.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Lets hear the generated audio !" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ipd.Audio(saved_audio, rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ipd.Image(saved_spectrogram)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cleanup cachedir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cleanup_cachedir()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.6 64-bit ('NeMo': conda)", + "language": "python", + "name": "python37664bitnemoconda43f94a748a2e4953b0129556ecdf4f62" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/tts/tacotron2.py b/examples/tts/tacotron2.py index 64f034c7b86e..f87ff213a7ba 100644 --- a/examples/tts/tacotron2.py +++ b/examples/tts/tacotron2.py @@ -28,10 +28,9 @@ tacotron2_process_eval_batch, tacotron2_process_final_eval, ) +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/tts/tacotron2_v0p9.py b/examples/tts/tacotron2_v0p9.py index d8e8dd69153e..e6339c5e542a 100644 --- a/examples/tts/tacotron2_v0p9.py +++ b/examples/tts/tacotron2_v0p9.py @@ -33,10 +33,9 @@ tacotron2_process_eval_batch, tacotron2_process_final_eval, ) +from nemo.utils import logging from nemo.utils.lr_policies import CosineAnnealing -logging = nemo.logging - def parse_args(): parser = argparse.ArgumentParser( diff --git a/examples/tts/tts_infer.py b/examples/tts/tts_infer.py index 48a9ffe4623d..df95b31063c0 100644 --- 
a/examples/tts/tts_infer.py +++ b/examples/tts/tts_infer.py @@ -24,8 +24,7 @@ import nemo import nemo.collections.asr as nemo_asr import nemo.collections.tts as nemo_tts - -logging = nemo.logging +from nemo.utils import logging def parse_args(): diff --git a/examples/tts/waveglow.py b/examples/tts/waveglow.py index 674117bd5825..63e262099f8a 100644 --- a/examples/tts/waveglow.py +++ b/examples/tts/waveglow.py @@ -20,8 +20,7 @@ import nemo.collections.tts as nemo_tts import nemo.utils.argparse as nm_argparse from nemo.collections.tts import waveglow_eval_log_to_tb_func, waveglow_log_to_tb_func, waveglow_process_eval_batch - -logging = nemo.logging +from nemo.utils import logging def parse_args(): diff --git a/examples/tts/waveglow_v0p9.py b/examples/tts/waveglow_v0p9.py index 6ad8fc7d14d1..2bc905ea6973 100644 --- a/examples/tts/waveglow_v0p9.py +++ b/examples/tts/waveglow_v0p9.py @@ -27,8 +27,7 @@ import nemo.collections.tts as nemo_tts import nemo.utils.argparse as nm_argparse from nemo.collections.tts import waveglow_eval_log_to_tb_func, waveglow_log_to_tb_func, waveglow_process_eval_batch - -logging = nemo.logging +from nemo.utils import logging def parse_args(): diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index d22fa8bc2c13..47a31c175bdc 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -537,6 +537,7 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False): 'num_workers': dl_nm.num_workers, 'batch_size': dl_nm.batch_size, 'shuffle': False, + 'pin_memory': dl_nm.pin_memory, } if hasattr(dl_nm, 'collate_fn'): dataloader_params['collate_fn'] = dl_nm.collate_fn @@ -555,6 +556,7 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False): 'num_workers': dl_nm.num_workers, 'batch_size': dl_nm.batch_size, 'shuffle': dl_nm.shuffle, + 'pin_memory': dl_nm.pin_memory, } if hasattr(dl_nm, 'collate_fn'): dataloader_params['collate_fn'] = dl_nm.collate_fn @@ -712,6 +714,7 @@ def 
class BCEWithLogitsLossNM(LossNM):
    """
    BCEWithLogitsLoss

    Wraps ``torch.nn.BCEWithLogitsLoss`` as a loss module. Inputs are
    flattened and, when a mask is given, masked-out positions are dropped
    before the criterion is applied.

    Args:
        logits_ndim (int): number of dimensions (or rank) of the logits tensor
        weight (list): list of rescaling weights passed to the criterion
        reduction (str): type of the reduction over the batch
    """

    @property
    @add_port_docs()
    def input_ports(self):
        """Returns definitions of module input ports.
        """
        return {
            "logits": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 1), LogitsType()),
            "labels": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), LabelsType()),
            "loss_mask": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), MaskType(), optional=True),
        }

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.

        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, logits_ndim=2, weight=None, reduction='mean'):
        super().__init__()

        # A missing (or empty) weight list is passed through untouched.
        if weight:
            weight = torch.FloatTensor(weight).to(self._device)
        self._criterion = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)
        self._logits_dim = logits_ndim

    def _loss_function(self, logits, labels, loss_mask=None):
        """
        Args:
            logits (float): output of the classifier
            labels: ground truth labels, one per row of the flattened logits
            loss_mask (bool/float/int): tensor to specify the masking

        Returns:
            The output of the wrapped criterion. When masking leaves no
            element to score, a zero-valued scalar tensor derived from the
            logits is returned instead.
        """
        logits_flatten = torch.flatten(logits, start_dim=0, end_dim=-2)
        labels_flatten = torch.flatten(labels, start_dim=0, end_dim=-1)

        if loss_mask is not None:
            # Non-boolean masks are thresholded at 0.5.
            if loss_mask.dtype is not torch.bool:
                loss_mask = loss_mask > 0.5
            loss_mask_flatten = torch.flatten(loss_mask, start_dim=0, end_dim=-1)
            logits_flatten = logits_flatten[loss_mask_flatten]
            labels_flatten = labels_flatten[loss_mask_flatten]

        if len(labels_flatten) == 0:
            # Return a tensor, not the Python int 0, so callers can still call
            # tensor methods on the result and it stays connected to the logits.
            return logits_flatten.sum() * 0.0

        # The criterion requires input and target of identical shape, but the
        # flattened logits keep their last dimension. Drop it when it is a
        # singleton so that it lines up with the fully flattened labels.
        if logits_flatten.dim() == labels_flatten.dim() + 1 and logits_flatten.size(-1) == 1:
            logits_flatten = logits_flatten.squeeze(-1)
        # Match the logits dtype so integer-typed labels can be fed to the criterion.
        labels_flatten = labels_flatten.to(logits_flatten.dtype)

        loss = self._criterion(logits_flatten, labels_flatten)
        return loss
a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -39,6 +39,9 @@ def __init__(self, pretrained_model_name=None, name=None): nn.Module.__init__(self) # For PyTorch API NeuralModule.__init__(self, name) # For NeuralModule API + # Unfrozen by default. + self._frozen = False + # Set module type. self._type = ModuleType.trainable @@ -115,6 +118,8 @@ def freeze(self, weights=None): for name, param in self.named_parameters(): if weights is None or name in weights: param.requires_grad = False + # Freeze. + self._frozen = True @t.jit.ignore def unfreeze(self, weights=None): @@ -126,6 +131,15 @@ def unfreeze(self, weights=None): for name, param in self.named_parameters(): if weights is None or name in weights: param.requires_grad = True + # Unfreeze. + self._frozen = False + + @t.jit.ignore + def is_frozen(self) -> bool: + """ Returns: + True/False depending whether there are any frozen weights or not. + """ + return self._frozen @property def num_weights(self): @@ -214,6 +228,7 @@ def __init__(self, name=None): self._batch_size = 1 self._num_workers = os.cpu_count() # Use all CPUs by default. self._shuffle = False # Don't shuffle by default. + self._pin_memory = False @property def input_ports(self): @@ -327,6 +342,11 @@ def num_workers(self): # """ Property setting the number of workers. """ # self._num_workers = nw + @property + def pin_memory(self): + """ Property returning the pin memory flag. """ + return self._pin_memory + class LossNM(NeuralModule): """A helper Base class for creating Pytorch-based loss function modules. 
diff --git a/nemo/collections/asr/audio_preprocessing.py b/nemo/collections/asr/audio_preprocessing.py index 98c96b8520d9..54f3df7e8f0c 100644 --- a/nemo/collections/asr/audio_preprocessing.py +++ b/nemo/collections/asr/audio_preprocessing.py @@ -28,19 +28,18 @@ ] import math -import warnings from abc import abstractmethod import numpy as np import torch from packaging import version -import nemo from .parts.features import FilterbankFeatures from .parts.spectr_augment import SpecAugment, SpecCutout from nemo.backends.pytorch import NonTrainableNM from nemo.core import Optimization from nemo.core.neural_types import * +from nemo.utils import logging from nemo.utils.decorators import add_port_docs try: @@ -54,14 +53,12 @@ HAVE_TORCHAUDIO = True except ModuleNotFoundError: HAVE_TORCHAUDIO = False - warnings.warn('Could not import torchaudio. Some features might not work.') + logging.warning('Could not import torchaudio. Some features might not work.') + try: from apex import amp except (AttributeError, ModuleNotFoundError) as e: - warnings.warn("Unable to import APEX. Mixed precision and distributed training will not work.") - - -logging = nemo.logging + logging.warning("Unable to import APEX. 
Mixed precision and distributed training will not work.") class AudioPreprocessor(NonTrainableNM): diff --git a/nemo/collections/asr/contextnet.py b/nemo/collections/asr/contextnet.py index c09be485d67a..145a6d79718a 100644 --- a/nemo/collections/asr/contextnet.py +++ b/nemo/collections/asr/contextnet.py @@ -5,15 +5,13 @@ import torch.nn as nn import torch.nn.functional as F -import nemo from .jasper import JasperEncoder from .parts.jasper import init_weights from nemo.backends.pytorch.nm import TrainableNM from nemo.core.neural_types import * +from nemo.utils import logging from nemo.utils.decorators import add_port_docs -logging = nemo.logging - class ContextNetEncoder(JasperEncoder): """ diff --git a/nemo/collections/asr/data_layer.py b/nemo/collections/asr/data_layer.py index ab94c70d53e4..dbaba86c3190 100644 --- a/nemo/collections/asr/data_layer.py +++ b/nemo/collections/asr/data_layer.py @@ -24,7 +24,6 @@ import torch import webdataset as wd -import nemo from .parts.collections import ASRAudioText from .parts.dataset import ( AudioDataset, @@ -40,6 +39,7 @@ from nemo.backends.pytorch import DataLayerNM from nemo.core import DeviceType from nemo.core.neural_types import * +from nemo.utils import logging from nemo.utils.decorators import add_port_docs from nemo.utils.misc import pad_to @@ -51,8 +51,6 @@ 'AudioToSpeechLabelDataLayer', ] -logging = nemo.logging - def _process_augmentations(augmenter) -> AudioAugmentor: """Process list of online data augmentations. 
@@ -497,7 +495,7 @@ def __init__( self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id) # Check for distributed and partition shards accordingly - if torch.distributed.is_initialized(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): global_rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() diff --git a/nemo/collections/asr/greedy_ctc_decoder.py b/nemo/collections/asr/greedy_ctc_decoder.py index 287db80cd8bf..c4a264f10832 100644 --- a/nemo/collections/asr/greedy_ctc_decoder.py +++ b/nemo/collections/asr/greedy_ctc_decoder.py @@ -1,12 +1,27 @@ -# Copyright (c) 2019 NVIDIA Corporation -import torch +# -*- coding: utf-8 -*- -from nemo.backends.pytorch.nm import TrainableNM -from nemo.core.neural_types import * +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from nemo.backends.pytorch.nm import NonTrainableNM +from nemo.core.neural_types import LogprobsType, NeuralType, PredictionsType from nemo.utils.decorators import add_port_docs -class GreedyCTCDecoder(TrainableNM): +class GreedyCTCDecoder(NonTrainableNM): """ Greedy decoder that computes the argmax over a softmax distribution """ @@ -14,23 +29,22 @@ class GreedyCTCDecoder(TrainableNM): @property @add_port_docs() def input_ports(self): - """Returns definitions of module input ports. + """Returns: + Definitions of module input ports. """ - # return {"log_probs": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})} return {"log_probs": NeuralType(('B', 'T', 'D'), LogprobsType())} @property @add_port_docs() def output_ports(self): - """Returns definitions of module output ports. + """Returns: + Definitions of module output ports. """ - # return {"predictions": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)})} return {"predictions": NeuralType(('B', 'T'), PredictionsType())} def __init__(self): super().__init__() def forward(self, log_probs): - with torch.no_grad(): - argmx = log_probs.argmax(dim=-1, keepdim=False) - return argmx + argmx = log_probs.argmax(dim=-1, keepdim=False) + return argmx diff --git a/nemo/collections/asr/helpers.py b/nemo/collections/asr/helpers.py index 7734b48b9ee7..dd36cd412e1e 100644 --- a/nemo/collections/asr/helpers.py +++ b/nemo/collections/asr/helpers.py @@ -2,10 +2,8 @@ import torch -import nemo from .metrics import classification_accuracy, word_error_rate - -logging = nemo.logging +from nemo.utils import logging def __ctc_decoder_predictions_tensor(tensor, labels): diff --git a/nemo/collections/asr/jasper.py b/nemo/collections/asr/jasper.py index b5de8f6b7af4..c83ed59b5e3b 100644 --- a/nemo/collections/asr/jasper.py +++ b/nemo/collections/asr/jasper.py @@ -5,14 +5,12 @@ import torch.nn as nn import 
torch.nn.functional as F -import nemo from .parts.jasper import JasperBlock, StatsPoolLayer, init_weights, jasper_activations from nemo.backends.pytorch.nm import TrainableNM from nemo.core.neural_types import * +from nemo.utils import logging from nemo.utils.decorators import add_port_docs -logging = nemo.logging - class JasperEncoder(TrainableNM): """ diff --git a/nemo/collections/asr/las/helpers.py b/nemo/collections/asr/las/helpers.py index 92558323d1ef..baa44e48075b 100644 --- a/nemo/collections/asr/las/helpers.py +++ b/nemo/collections/asr/las/helpers.py @@ -3,11 +3,10 @@ import torch -import nemo from nemo.backends.pytorch.common.metrics import char_lm_metrics from nemo.collections.asr.metrics import word_error_rate +from nemo.utils import logging -logging = nemo.logging ENG_MWN = 5.3 diff --git a/nemo/collections/asr/parts/collections.py b/nemo/collections/asr/parts/collections.py index c28f2d9f3e30..de11bfebcf7c 100644 --- a/nemo/collections/asr/parts/collections.py +++ b/nemo/collections/asr/parts/collections.py @@ -6,10 +6,8 @@ import pandas as pd -import nemo from nemo.collections.asr.parts import manifest, parsers - -logging = nemo.logging +from nemo.utils import logging class _Collection(collections.UserList): diff --git a/nemo/collections/nlp/callbacks/sgd_callback.py b/nemo/collections/nlp/callbacks/sgd_callback.py new file mode 100644 index 000000000000..8ba7897952e8 --- /dev/null +++ b/nemo/collections/nlp/callbacks/sgd_callback.py @@ -0,0 +1,225 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst +""" + +import json +import os + +import torch + +import nemo.collections.nlp.data.datasets.sgd_dataset.prediction_utils as pred_utils +from nemo import logging +from nemo.collections.nlp.data.datasets.sgd_dataset.evaluate import ( + ALL_SERVICES, + PER_FRAME_OUTPUT_FILENAME, + SEEN_SERVICES, + UNSEEN_SERVICES, + get_dataset_as_dict, + get_in_domain_services, + get_metrics, +) + +__all__ = ['eval_iter_callback', 'eval_epochs_done_callback'] + + +def tensor2list(tensor): + return tensor.detach().cpu().tolist() + + +def get_str_example_id(eval_dataset, ids_to_service_names_dict, example_id_num): + def format_turn_id(ex_id_num): + dialog_id_1, dialog_id_2, turn_id, service_id = ex_id_num + return "{}-{}_{:05d}-{:02d}-{}".format( + eval_dataset, dialog_id_1, dialog_id_2, turn_id, ids_to_service_names_dict[service_id] + ) + + return list(map(format_turn_id, tensor2list(example_id_num))) + + +def eval_iter_callback(tensors, global_vars, schema_processor, eval_dataset): + if 'predictions' not in global_vars: + global_vars['predictions'] = [] + + output = {} + for k, v in tensors.items(): + ind = k.find('~~~') + if ind != -1: + output[k[:ind]] = torch.cat(v) + + predictions = {} + ids_to_service_names_dict = schema_processor.get_ids_to_service_names_dict() + predictions['example_id'] = get_str_example_id(eval_dataset, 
ids_to_service_names_dict, output['example_id_num']) + + predictions['service_id'] = output['service_id'] + predictions['is_real_example'] = output['is_real_example'] + + # Scores are output for each intent. + # Note that the intent indices are shifted by 1 to account for NONE intent. + predictions['intent_status'] = torch.argmax(output['logit_intent_status'], -1) + + # Scores are output for each requested slot. + predictions['req_slot_status'] = torch.nn.Sigmoid()(output['logit_req_slot_status']) + + # For categorical slots, the status of each slot and the predicted value are output. + cat_slot_status_dist = torch.nn.Softmax(dim=-1)(output['logit_cat_slot_status']) + cat_slot_value_dist = torch.nn.Softmax(dim=-1)(output['logit_cat_slot_value']) + + predictions['cat_slot_status'] = torch.argmax(output['logit_cat_slot_status'], axis=-1) + predictions['cat_slot_status_p'] = torch.max(cat_slot_status_dist, axis=-1)[0] + predictions['cat_slot_value'] = torch.argmax(output['logit_cat_slot_value'], axis=-1) + predictions['cat_slot_value_p'] = torch.max(cat_slot_value_dist, axis=-1)[0] + + # For non-categorical slots, the status of each slot and the indices for spans are output. + noncat_slot_status_dist = torch.nn.Softmax(dim=-1)(output['logit_noncat_slot_status']) + + predictions['noncat_slot_status'] = torch.argmax(output['logit_noncat_slot_status'], axis=-1) + predictions['noncat_slot_status_p'] = torch.max(noncat_slot_status_dist, axis=-1)[0] + + softmax = torch.nn.Softmax(dim=-1) + start_scores = softmax(output['logit_noncat_slot_start']) + end_scores = softmax(output['logit_noncat_slot_end']) + + batch_size, max_num_noncat_slots, max_num_tokens = end_scores.size() + # Find the span with the maximum sum of scores for start and end indices. + total_scores = torch.unsqueeze(start_scores, axis=3) + torch.unsqueeze(end_scores, axis=2) + # Mask out scores where start_index > end_index. 
+ # device = total_scores.device + start_idx = torch.arange(max_num_tokens, device=total_scores.device).view(1, 1, -1, 1) + end_idx = torch.arange(max_num_tokens, device=total_scores.device).view(1, 1, 1, -1) + invalid_index_mask = (start_idx > end_idx).repeat(batch_size, max_num_noncat_slots, 1, 1) + total_scores = torch.where( + invalid_index_mask, + torch.zeros(total_scores.size(), device=total_scores.device, dtype=total_scores.dtype), + total_scores, + ) + max_span_index = torch.argmax(total_scores.view(-1, max_num_noncat_slots, max_num_tokens ** 2), axis=-1) + max_span_p = torch.max(total_scores.view(-1, max_num_noncat_slots, max_num_tokens ** 2), axis=-1)[0] + predictions['noncat_slot_p'] = max_span_p + + span_start_index = torch.div(max_span_index, max_num_tokens) + span_end_index = torch.fmod(max_span_index, max_num_tokens) + + predictions['noncat_slot_start'] = span_start_index + predictions['noncat_slot_end'] = span_end_index + + # Add inverse alignments. + predictions['noncat_alignment_start'] = output['start_char_idx'] + predictions['noncat_alignment_end'] = output['end_char_idx'] + + # added for debugging + predictions['cat_slot_status_GT'] = output['categorical_slot_status'] + predictions['noncat_slot_status_GT'] = output['noncategorical_slot_status'] + + global_vars['predictions'].extend(combine_predictions_in_example(predictions, batch_size)) + + +def combine_predictions_in_example(predictions, batch_size): + ''' + Combines predicted values to a single example. 
+ ''' + examples_preds = [{} for _ in range(batch_size)] + for k, v in predictions.items(): + if k != 'example_id': + v = torch.chunk(v, batch_size) + + for i in range(batch_size): + if k == 'example_id': + examples_preds[i][k] = v[i] + else: + examples_preds[i][k] = v[i].view(-1) + return examples_preds + + +def eval_epochs_done_callback( + global_vars, + input_json_files, + eval_dataset, + data_dir, + prediction_dir, + output_metric_file, + state_tracker, + eval_debug, + schema_emb_preprocessor, + joint_acc_across_turn, + no_fuzzy_match, +): + # added for debugging + in_domain_services = get_in_domain_services( + os.path.join(data_dir, eval_dataset, "schema.json"), os.path.join(data_dir, "train", "schema.json") + ) + ############## + pred_utils.write_predictions_to_file( + global_vars['predictions'], + input_json_files, + prediction_dir, + schemas=schema_emb_preprocessor.schemas, + state_tracker=state_tracker, + eval_debug=eval_debug, + in_domain_services=in_domain_services, + ) + metrics = evaluate( + prediction_dir, + data_dir, + eval_dataset, + output_metric_file, + schema_emb_preprocessor.schemas, + joint_acc_across_turn, + no_fuzzy_match, + ) + return metrics + + +def evaluate( + prediction_dir, data_dir, eval_dataset, output_metric_file, schemas, joint_acc_across_turn, no_fuzzy_match +): + + in_domain_services = get_in_domain_services( + os.path.join(data_dir, eval_dataset, "schema.json"), os.path.join(data_dir, "train", "schema.json") + ) + + with open(os.path.join(data_dir, eval_dataset, "schema.json")) as f: + eval_services = {} + list_services = json.load(f) + for service in list_services: + eval_services[service["service_name"]] = service + f.close() + + dataset_ref = get_dataset_as_dict(os.path.join(data_dir, eval_dataset, "dialogues_*.json")) + dataset_hyp = get_dataset_as_dict(os.path.join(prediction_dir, "*.json")) + + all_metric_aggregate, _ = get_metrics( + dataset_ref, dataset_hyp, eval_services, in_domain_services, joint_acc_across_turn, 
no_fuzzy_match + ) + if SEEN_SERVICES in all_metric_aggregate: + logging.info(f'Dialog metrics for {SEEN_SERVICES} : {sorted(all_metric_aggregate[SEEN_SERVICES].items())}') + if UNSEEN_SERVICES in all_metric_aggregate: + logging.info(f'Dialog metrics for {UNSEEN_SERVICES}: {sorted(all_metric_aggregate[UNSEEN_SERVICES].items())}') + if ALL_SERVICES in all_metric_aggregate: + logging.info(f'Dialog metrics for {ALL_SERVICES} : {sorted(all_metric_aggregate[ALL_SERVICES].items())}') + # Write the aggregated metrics values. + with open(output_metric_file, "w") as f: + json.dump(all_metric_aggregate, f, indent=2, separators=(",", ": "), sort_keys=True) + f.close() + # Write the per-frame metrics values with the corrresponding dialogue frames. + with open(os.path.join(prediction_dir, PER_FRAME_OUTPUT_FILENAME), "w") as f: + json.dump(dataset_hyp, f, indent=2, separators=(",", ": ")) + f.close() + return all_metric_aggregate[ALL_SERVICES] diff --git a/nemo/collections/nlp/data/datasets/__init__.py b/nemo/collections/nlp/data/datasets/__init__.py index 2342e3f25ead..1e31f3e115f4 100644 --- a/nemo/collections/nlp/data/datasets/__init__.py +++ b/nemo/collections/nlp/data/datasets/__init__.py @@ -31,6 +31,8 @@ BertPunctuationCapitalizationInferDataset, ) from nemo.collections.nlp.data.datasets.qa_squad_dataset.qa_squad_dataset import SquadDataset +from nemo.collections.nlp.data.datasets.sgd_dataset.schema_embedding_dataset import SchemaEmbeddingDataset +from nemo.collections.nlp.data.datasets.sgd_dataset.sgd_dataset import SGDDataset from nemo.collections.nlp.data.datasets.text_classification import ( BertTextClassificationDataset, TextClassificationDataDesc, diff --git a/nemo/collections/nlp/data/datasets/sgd_dataset/data_processor.py b/nemo/collections/nlp/data/datasets/sgd_dataset/data_processor.py new file mode 100644 index 000000000000..7a4b7c2bc05f --- /dev/null +++ b/nemo/collections/nlp/data/datasets/sgd_dataset/data_processor.py @@ -0,0 +1,401 @@ +# 
============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/data_utils.py +""" + +import json +import os +import re + +import numpy as np +import torch + +from nemo.collections.nlp.data.datasets.sgd_dataset.input_example import InputExample +from nemo.utils import logging + +__all__ = [ + 'FILE_RANGES', + 'PER_FRAME_OUTPUT_FILENAME', + 'Dstc8DataProcessor', +] + + +FILE_RANGES = { + "dstc8_single_domain": {"train": range(1, 44), "dev": range(1, 8), "test": range(1, 12)}, + "dstc8_multi_domain": {"train": range(44, 128), "dev": range(8, 21), "test": range(12, 35)}, + "dstc8_all": {"train": range(1, 128), "dev": range(1, 21), "test": range(1, 35)}, + "DEBUG": {"train": range(1, 2), "dev": range(1, 2), "test": range(1, 2)}, + "multiwoz": {"train": range(1, 18), "dev": range(1, 3), "test": range(1, 3)}, +} + +# Name of the file containing all predictions and their corresponding frame metrics. 
+PER_FRAME_OUTPUT_FILENAME = "dialogues_and_metrics.json" + + +class Dstc8DataProcessor(object): + """Data generator for dstc8 dialogues.""" + + def __init__( + self, + task_name, + dstc8_data_dir, + dialogues_example_dir, + tokenizer, + schema_emb_processor, + overwrite_dial_files=False, + ): + """ + Constructs Dstc8DataProcessor + Args: + task_name (str): task name, for example, "dstc8_single_domain" + dstc8_data_dir (str): path to data directory + dialogues_example_dir (str): path to store processed dialogue examples + tokenizer (Tokenizer): such as NemoBertTokenizer + schema_emb_processor (Obj): contains information about schemas + overwrite_dial_files (bool): whether to overwite dialogue files + """ + self.dstc8_data_dir = dstc8_data_dir + self.dialogues_examples_dir = dialogues_example_dir + + self._task_name = task_name + self.schema_config = schema_emb_processor.schema_config + + train_file_range = FILE_RANGES[task_name]["train"] + dev_file_range = FILE_RANGES[task_name]["dev"] + test_file_range = FILE_RANGES[task_name]["test"] + + self._file_ranges = { + "train": train_file_range, + "dev": dev_file_range, + "test": test_file_range, + } + + self._tokenizer = tokenizer + self._max_seq_length = self.schema_config["MAX_SEQ_LENGTH"] + + self.dial_files = {} + + for dataset in ["train", "dev", "test"]: + # Process dialogue files + dial_file = f"{task_name}_{dataset}_examples.processed" + dial_file = os.path.join(dialogues_example_dir, dial_file) + self.dial_files[(task_name, dataset)] = dial_file + + if not os.path.exists(dial_file) or overwrite_dial_files: + logging.debug(f"Start generating the dialogue examples for {dataset} dataset.") + master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + if master_device: + if not os.path.exists(dialogues_example_dir): + os.makedirs(dialogues_example_dir) + dial_examples = self._generate_dialog_examples(dataset, schema_emb_processor.schemas) + with open(dial_file, "wb") as f: + 
np.save(f, dial_examples) + f.close() + logging.debug(f"The dialogue examples for {dataset} dataset saved at {dial_file}") + logging.debug(f"Finish generating the dialogue examples for {dataset} dataset.") + + # wait until the master process writes to the dialogue processed file + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + def get_dialog_examples(self, dataset): + """ + Returns a list of `InputExample`s of the data splits' dialogues. + Args: + dataset(str): can be "train", "dev", or "test". + Returns: + examples: a list of `InputExample`s. + """ + if (self._task_name, dataset) not in self.dial_files or not os.path.exists( + self.dial_files[(self._task_name, dataset)] + ): + raise ValueError( + f"{dataset} dialogue examples were not processed for {self._task_name} task. Re-initialize Dstc8DataProcessor and add {dataset} dataset to datasets arg." + ) + + dial_file = self.dial_files[(self._task_name, dataset)] + logging.info(f"Loading dialogue examples from {dial_file}.") + with open(dial_file, "rb") as f: + dial_examples = np.load(f, allow_pickle=True) + f.close() + return dial_examples + + def _generate_dialog_examples(self, dataset, schemas): + """ + Returns a list of `InputExample`s of the data splits' dialogues. + Args: + dataset(str): can be "train", "dev", or "test". + schemas(Schema): for all services and all datasets processed by the schema_processor + Returns: + examples: a list of `InputExample`s. 
+ """ + logging.info(f'Creating examples from the dialogues started...') + dialog_paths = [ + os.path.join(self.dstc8_data_dir, dataset, "dialogues_{:03d}.json".format(i)) + for i in self._file_ranges[dataset] + ] + dialogs = Dstc8DataProcessor.load_dialogues(dialog_paths) + + examples = [] + for dialog_idx, dialog in enumerate(dialogs): + if dialog_idx % 1000 == 0: + logging.info(f'Processed {dialog_idx} dialogs.') + examples.extend(self._create_examples_from_dialog(dialog, schemas, dataset)) + + logging.info(f'Finished creating the examples from {len(dialogs)} dialogues.') + return examples + + def _create_examples_from_dialog(self, dialog, schemas, dataset): + """ + Create examples for every turn in the dialog. + Args: + dialog (dict): dialogue example + schemas(Schema): for all services and all datasets processed by the schema_processor + dataset(str): can be "train", "dev", or "test". + Returns: + examples: a list of `InputExample`s. + """ + dialog_id = dialog["dialogue_id"] + prev_states = {} + examples = [] + for turn_idx, turn in enumerate(dialog["turns"]): + # Generate an example for every frame in every user turn. 
+ if turn["speaker"] == "USER": + user_utterance = turn["utterance"] + user_frames = {f["service"]: f for f in turn["frames"]} + if turn_idx > 0: + system_turn = dialog["turns"][turn_idx - 1] + system_utterance = system_turn["utterance"] + system_frames = {f["service"]: f for f in system_turn["frames"]} + else: + system_utterance = "" + system_frames = {} + + turn_id = "{}-{}-{:02d}".format(dataset, dialog_id, turn_idx) + turn_examples, prev_states = self._create_examples_from_turn( + turn_id, system_utterance, user_utterance, system_frames, user_frames, prev_states, schemas + ) + examples.extend(turn_examples) + return examples + + def _get_state_update(self, current_state, prev_state): + """ + Updates dialogue state + Args: + current_state (dict): dict of slot - slot values pairs for the current dialogue turn + prev_state (dict): dict of slot - slot values pairs for the previous dialogue turns + Returns: + state_update (dict): dict of slot - slot values pairs that very added/updated during the current + dialogue turn + """ + state_update = dict(current_state) + for slot, values in current_state.items(): + if slot in prev_state and prev_state[slot][0] in values: + # Remove the slot from state if its value didn't change. + state_update.pop(slot) + return state_update + + def _create_examples_from_turn( + self, turn_id, system_utterance, user_utterance, system_frames, user_frames, prev_states, schemas + ): + """ + Creates an example for each frame in the user turn. + Args: + turn_id (int): turn number + system_utterance (str): last system utterance + user_utterance (str): lst user utterance + system_frames (dict): all system utterances and slot - slot value pairs + user_frames (dict): all user utterances and slot - slot value pairs + prev_states (dict): slot - slot value pairs from the previous turns + schemas (obj): carries information about the service from the current turn + Returns: + examples: a list of `InputExample`s. 
+ prev_states (dict): updated dialogue state + """ + system_tokens, system_alignments, system_inv_alignments = self._tokenize(system_utterance) + user_tokens, user_alignments, user_inv_alignments = self._tokenize(user_utterance) + states = {} + base_example = InputExample(schema_config=self.schema_config, is_real_example=True, tokenizer=self._tokenizer,) + base_example.example_id = turn_id + + _, dialog_id, turn_id_ = turn_id.split('-') + dialog_id_1, dialog_id_2 = dialog_id.split('_') + base_example.example_id_num = [int(dialog_id_1), int(dialog_id_2), int(turn_id_)] + base_example.add_utterance_features( + system_tokens, system_inv_alignments, user_tokens, user_inv_alignments, system_utterance, user_utterance + ) + examples = [] + for service, user_frame in user_frames.items(): + # Create an example for this service. + example = base_example.make_copy_with_utterance_features() + + example.example_id = "{}-{}".format(turn_id, service) + _, dialog_id, turn_id_ = turn_id.split('-') + dialog_id_1, dialog_id_2 = dialog_id.split('_') + example.example_id_num = [ + int(dialog_id_1), + int(dialog_id_2), + int(turn_id_), + schemas.get_service_id(service), + ] + + example.service_schema = schemas.get_service_schema(service) + system_frame = system_frames.get(service, None) + state = user_frame["state"]["slot_values"] + state_update = self._get_state_update(state, prev_states.get(service, {})) + states[service] = state + # Populate features in the example. + example.add_categorical_slots(state_update) + # The input tokens to bert are in the format [CLS] [S1] [S2] ... [SEP] + # [U1] [U2] ... [SEP] [PAD] ... [PAD]. For system token indices a bias of + # 1 is added for the [CLS] token and for user tokens a bias of 2 + + # len(system_tokens) is added to account for [CLS], system tokens and + # [SEP]. 
+ user_span_boundaries = self._find_subword_indices( + state_update, user_utterance, user_frame["slots"], user_alignments, user_tokens, 2 + len(system_tokens) + ) + if system_frame is not None: + system_span_boundaries = self._find_subword_indices( + state_update, system_utterance, system_frame["slots"], system_alignments, system_tokens, 1 + ) + else: + system_span_boundaries = {} + example.add_noncategorical_slots(state_update, user_span_boundaries, system_span_boundaries) + example.add_requested_slots(user_frame) + example.add_intents(user_frame) + examples.append(example) + return examples, states + + def _find_subword_indices(self, slot_values, utterance, char_slot_spans, alignments, subwords, bias): + """Find indices for subwords corresponding to slot values.""" + span_boundaries = {} + for slot, values in slot_values.items(): + # Get all values present in the utterance for the specified slot. + value_char_spans = {} + for slot_span in char_slot_spans: + if slot_span["slot"] == slot: + value = utterance[slot_span["start"] : slot_span["exclusive_end"]] + start_tok_idx = alignments[slot_span["start"]] + end_tok_idx = alignments[slot_span["exclusive_end"] - 1] + if 0 <= start_tok_idx < len(subwords): + end_tok_idx = min(end_tok_idx, len(subwords) - 1) + value_char_spans[value] = (start_tok_idx + bias, end_tok_idx + bias) + for v in values: + if v in value_char_spans: + span_boundaries[slot] = value_char_spans[v] + break + return span_boundaries + + def _tokenize(self, utterance): + """Tokenize the utterance using word-piece tokenization used by BERT. + + Args: + utterance: A string containing the utterance to be tokenized. + + Returns: + bert_tokens: A list of tokens obtained by word-piece tokenization of the + utterance. + alignments: A dict mapping indices of characters corresponding to start + and end positions of words (not subwords) to corresponding indices in + bert_tokens list. + inverse_alignments: A list of size equal to bert_tokens. 
Each element is a + tuple containing the index of the starting and inclusive ending + character of the word corresponding to the subword. This list is used + during inference to map word-piece indices to spans in the original + utterance. + """ + # utterance = tokenization.convert_to_unicode(utterance) + + # After _naive_tokenize, spaces and punctuation marks are all retained, i.e. + # direct concatenation of all the tokens in the sequence will be the + # original string. + tokens = Dstc8DataProcessor._naive_tokenize(utterance) + # Filter out empty tokens and obtain aligned character index for each token. + alignments = {} + char_index = 0 + bert_tokens = [] + # These lists store inverse alignments to be used during inference. + bert_tokens_start_chars = [] + bert_tokens_end_chars = [] + for token in tokens: + if token.strip(): + subwords = self._tokenizer.text_to_tokens(token) + # Store the alignment for the index of starting character and the + # inclusive ending character of the token. + alignments[char_index] = len(bert_tokens) + bert_tokens_start_chars.extend([char_index] * len(subwords)) + bert_tokens.extend(subwords) + # The inclusive ending character index corresponding to the word. + inclusive_char_end = char_index + len(token) - 1 + alignments[inclusive_char_end] = len(bert_tokens) - 1 + bert_tokens_end_chars.extend([inclusive_char_end] * len(subwords)) + char_index += len(token) + inverse_alignments = list(zip(bert_tokens_start_chars, bert_tokens_end_chars)) + return bert_tokens, alignments, inverse_alignments + + def get_num_dialog_examples(self, dataset): + """ + Gets the number of dilaog examples in the data split. + Args: + dataset: str. can be "train", "dev", or "test". + Returns:from nemo_nlp.data.datasets.sgd import data_utils + example_count: int. number of examples in the specified dataset. 
+ """ + example_count = 0 + dialog_paths = [ + os.path.join(self.dstc8_data_dir, dataset, "dialogues_{:03d}.json".format(i)) + for i in self._file_ranges[dataset] + ] + dst_set = Dstc8DataProcessor.load_dialogues(dialog_paths) + for dialog in dst_set: + for turn in dialog["turns"]: + if turn["speaker"] == "USER": + example_count += len(turn["frames"]) + return example_count + + @classmethod + def _naive_tokenize(cls, s): + """ + Tokenizes a string, separating words, spaces and punctuations. + Args: + s (str): a string + Returns: + seq_tok (list): list of words, spaces and punctuations from the s + """ + # Spaces and punctuation marks are all retained, i.e. direct concatenation + # of all the tokens in the sequence will be the original string. + seq_tok = [tok for tok in re.split(r"([^a-zA-Z0-9])", s) if tok] + return seq_tok + + @classmethod + def load_dialogues(cls, dialog_json_filepaths): + """ + Obtain the list of all dialogues from specified json files. + Args: + dialog_json_filepaths (list): list of json files + Returns: + dialogs (list): the list of all dialogues + """ + dialogs = [] + for dialog_json_filepath in sorted(dialog_json_filepaths): + with open(dialog_json_filepath, 'r') as f: + dialogs.extend(json.load(f)) + f.close() + return dialogs diff --git a/nemo/collections/nlp/data/datasets/sgd_dataset/evaluate.py b/nemo/collections/nlp/data/datasets/sgd_dataset/evaluate.py new file mode 100644 index 000000000000..fb2dba564b78 --- /dev/null +++ b/nemo/collections/nlp/data/datasets/sgd_dataset/evaluate.py @@ -0,0 +1,213 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +Evaluate predictions JSON file, w.r.t. ground truth file. +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/evaluate.py +""" + +import collections +import glob +import json + +import numpy as np + +import nemo +import nemo.collections.nlp.data.datasets.sgd_dataset.metrics as metrics + +__all__ = [ + 'get_in_domain_services', + 'get_dataset_as_dict', + 'ALL_SERVICES', + 'SEEN_SERVICES', + 'UNSEEN_SERVICES', + 'get_metrics', + 'PER_FRAME_OUTPUT_FILENAME', +] + +ALL_SERVICES = "#ALL_SERVICES" +SEEN_SERVICES = "#SEEN_SERVICES" +UNSEEN_SERVICES = "#UNSEEN_SERVICES" + +# Name of the file containing all predictions and their corresponding frame metrics. 
PER_FRAME_OUTPUT_FILENAME = "dialogues_and_metrics.json"


def get_service_set(schema_path):
    """Get the set of all services present in a schema.

    Args:
        schema_path: path to a JSON file holding a list of service dicts, each
            with a "service_name" key.

    Returns:
        The set of service names found in the file.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open(schema_path) as f:
        schema = json.load(f)
    return {service["service_name"] for service in schema}


def get_in_domain_services(schema_path_1, schema_path_2):
    """Get the set of common services between two schemas."""
    return get_service_set(schema_path_1) & get_service_set(schema_path_2)


def get_dataset_as_dict(file_path_patterns):
    """Read the DSTC8 json dialog data as dictionary with dialog ID as keys.

    Args:
        file_path_patterns: either a list of file paths (used as given) or a
            single glob pattern (expanded and sorted).

    Returns:
        A dict of dialogues. Files containing a list contribute one entry per
        dialogue keyed by its "dialogue_id"; files containing a dict are merged
        in as-is. Paths containing `PER_FRAME_OUTPUT_FILENAME` are skipped.
    """
    dataset_dict = {}
    if isinstance(file_path_patterns, list):
        list_fp = file_path_patterns
    else:
        list_fp = sorted(glob.glob(file_path_patterns))
    for fp in list_fp:
        # The per-frame metrics dump is an output of evaluation, not input data.
        if PER_FRAME_OUTPUT_FILENAME in fp:
            continue
        nemo.logging.info("Loading file: %s", fp)
        with open(fp) as f:
            data = json.load(f)
        if isinstance(data, list):
            for dial in data:
                dataset_dict[dial["dialogue_id"]] = dial
        elif isinstance(data, dict):
            dataset_dict.update(data)
    return dataset_dict


def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services, joint_acc_across_turn, no_fuzzy_match):
    """Calculate the DSTC8 metrics.

    Args:
        dataset_ref: The ground truth dataset represented as a dict mapping dialogue
            id to the corresponding dialogue.
        dataset_hyp: The predictions in the same format as `dataset_ref`. Every
            dialogue id in it must also occur in `dataset_ref`. Each hypothesis
            frame is annotated in place with a "metrics" entry.
        service_schemas: A dict mapping service name to the schema for the service.
        in_domain_services: The set of services which are present in the training
            set.
        joint_acc_across_turn: If true, the joint accuracy metrics are multiplied
            across all frames of a turn and recorded once per turn (multiwoz style)
            instead of once per frame.
        no_fuzzy_match: Forwarded to `metrics.get_average_and_joint_goal_accuracy`;
            presumably disables fuzzy matching of non-categorical slot values.

    Returns:
        A tuple `(all_metric_aggregate, per_frame_metric)`:
        all_metric_aggregate: A dict mapping a metric collection name to a dict
            containing the values for various metrics (mean over the collection,
            as a percentage rounded to two decimals). Each metric collection
            aggregates the metrics across a specific set of frames in the dialogues.
        per_frame_metric: A dict mapping a frame id of the form
            "<dialogue id>-<turn index>-<service>" to that frame's metric dict.

    Raises:
        ValueError: if services, speakers or utterances of a dialogue differ
            between reference and hypothesis, or a reference frame has no
            hypothesis frame for the same service.
    """
    # Metrics can be aggregated in various ways, eg over all dialogues, only for
    # dialogues containing unseen services or for dialogues corresponding to a
    # single service. This aggregation is done through metric_collections, which
    # is a dict mapping a collection name to a dict, which maps a metric to a list
    # of values for that metric. Each value in this list is the value taken by
    # the metric on a frame.
    metric_collections = collections.defaultdict(lambda: collections.defaultdict(list))

    # Ensure the dialogs in dataset_hyp also occur in dataset_ref.
    assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))
    nemo.logging.info("len(dataset_hyp)=%d, len(dataset_ref)=%d", len(dataset_hyp), len(dataset_ref))

    # Store metrics for every frame for debugging.
    per_frame_metric = {}
    for dial_id, dial_hyp in dataset_hyp.items():
        dial_ref = dataset_ref[dial_id]

        if set(dial_ref["services"]) != set(dial_hyp["services"]):
            raise ValueError(
                "Set of services present in ground truth and predictions don't match "
                "for dialogue with id {}".format(dial_id)
            )

        joint_metrics = [metrics.JOINT_GOAL_ACCURACY, metrics.JOINT_CAT_ACCURACY, metrics.JOINT_NONCAT_ACCURACY]
        for turn_id, (turn_ref, turn_hyp) in enumerate(zip(dial_ref["turns"], dial_hyp["turns"])):
            # Per-turn running product of the joint metrics; starts at 1.0 so that
            # a single incorrect frame drives the turn-level value towards 0.
            metric_collections_per_turn = collections.defaultdict(lambda: collections.defaultdict(lambda: 1.0))
            if turn_ref["speaker"] != turn_hyp["speaker"]:
                raise ValueError("Speakers don't match in dialogue with id {}".format(dial_id))

            # Skip system turns because metrics are only computed for user turns.
            if turn_ref["speaker"] != "USER":
                continue

            if turn_ref["utterance"] != turn_hyp["utterance"]:
                nemo.logging.info("Ref utt: %s", turn_ref["utterance"])
                nemo.logging.info("Hyp utt: %s", turn_hyp["utterance"])
                raise ValueError("Utterances don't match for dialogue with id {}".format(dial_id))

            hyp_frames_by_service = {frame["service"]: frame for frame in turn_hyp["frames"]}

            # Calculate metrics for each frame in each user turn.
            for frame_ref in turn_ref["frames"]:
                service_name = frame_ref["service"]
                if service_name not in hyp_frames_by_service:
                    raise ValueError(
                        "Frame for service {} not found in dialogue with id {}".format(service_name, dial_id)
                    )
                service = service_schemas[service_name]
                frame_hyp = hyp_frames_by_service[service_name]

                active_intent_acc = metrics.get_active_intent_accuracy(frame_ref, frame_hyp)
                slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
                    frame_ref, frame_hyp, turn_ref["utterance"], service
                )
                requested_slots_f1_scores = metrics.get_requested_slots_f1(frame_ref, frame_hyp)
                goal_accuracy_dict = metrics.get_average_and_joint_goal_accuracy(
                    frame_ref, frame_hyp, service, no_fuzzy_match
                )

                frame_metric = {
                    metrics.ACTIVE_INTENT_ACCURACY: active_intent_acc,
                    metrics.REQUESTED_SLOTS_F1: requested_slots_f1_scores.f1,
                    metrics.REQUESTED_SLOTS_PRECISION: requested_slots_f1_scores.precision,
                    metrics.REQUESTED_SLOTS_RECALL: requested_slots_f1_scores.recall,
                }
                if slot_tagging_f1_scores is not None:
                    frame_metric[metrics.SLOT_TAGGING_F1] = slot_tagging_f1_scores.f1
                    frame_metric[metrics.SLOT_TAGGING_PRECISION] = slot_tagging_f1_scores.precision
                    frame_metric[metrics.SLOT_TAGGING_RECALL] = slot_tagging_f1_scores.recall
                frame_metric.update(goal_accuracy_dict)

                frame_id = "{:s}-{:03d}-{:s}".format(dial_id, turn_id, frame_hyp["service"])
                per_frame_metric[frame_id] = frame_metric
                # Add the frame-level metric result back to dialogues.
                frame_hyp["metrics"] = frame_metric

                # Get the domain name of the service.
                domain_name = frame_hyp["service"].split("_")[0]
                domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
                if frame_hyp["service"] in in_domain_services:
                    domain_keys.append(SEEN_SERVICES)
                else:
                    domain_keys.append(UNSEEN_SERVICES)
                for domain_key in domain_keys:
                    for metric_key, metric_value in frame_metric.items():
                        if metric_value != metrics.NAN_VAL:
                            if joint_acc_across_turn and metric_key in joint_metrics:
                                metric_collections_per_turn[domain_key][metric_key] *= metric_value
                            else:
                                metric_collections[domain_key][metric_key].append(metric_value)
            if joint_acc_across_turn:
                # Conduct multiwoz style evaluation that computes joint goal accuracy
                # across all the slot values of all the domains for each turn.
                for domain_key in metric_collections_per_turn:
                    for metric_key, metric_value in metric_collections_per_turn[domain_key].items():
                        metric_collections[domain_key][metric_key].append(metric_value)

    all_metric_aggregate = {}
    for domain_key, domain_metric_vals in metric_collections.items():
        domain_metric_aggregate = {}
        for metric_key, value_list in domain_metric_vals.items():
            if value_list:
                # Metrics are macro-averaged across all frames.
                domain_metric_aggregate[metric_key] = round(float(np.mean(value_list)) * 100.0, 2)
            else:
                domain_metric_aggregate[metric_key] = metrics.NAN_VAL
        all_metric_aggregate[domain_key] = domain_metric_aggregate
    return all_metric_aggregate, per_frame_metric
STR_DONTCARE = "dontcare"

# These are used to represent the status of slots (off, active, dontcare) and
# intents (off, active) in dialogue state tracking.
STATUS_OFF = 0
STATUS_ACTIVE = 1
STATUS_DONTCARE = 2


class InputExample(object):
    """An example for training/inference."""

    def __init__(
        self,
        schema_config,
        service_schema=None,
        example_id="NONE",
        example_id_num=None,
        is_real_example=False,
        tokenizer=None,
    ):
        """Constructs an InputExample.

        Args:
            schema_config: mapping providing the feature sizes. The keys read here
                are "MAX_SEQ_LENGTH" (sequences longer than this are truncated),
                "MAX_NUM_CAT_SLOT", "MAX_NUM_VALUE_PER_CAT_SLOT",
                "MAX_NUM_NONCAT_SLOT" and "MAX_NUM_INTENT".
            service_schema: A ServiceSchema object wrapping the schema for the service
                corresponding to this example.
            example_id: Unique identifier for the example, like: 'train-1_00000-00-Restaurants_1'
            example_id_num: dialogue_id and turn_id combined and service id combined into a list of ints,
                like: [1, 0, 0, 18]. Defaults to a fresh empty list.
            is_real_example: Indicates if an example is real or used for padding in a
                minibatch.
            tokenizer (Tokenizer): such as NemoBertTokenizer

        Raises:
            ValueError: if `is_real_example` is true and no tokenizer is given.
        """
        self.schema_config = schema_config
        self.service_schema = service_schema
        self.example_id = example_id
        # A None default avoids sharing one list object between all instances.
        self.example_id_num = [] if example_id_num is None else example_id_num

        self.is_real_example = is_real_example
        self._max_seq_length = schema_config["MAX_SEQ_LENGTH"]
        self._tokenizer = tokenizer
        if self.is_real_example and self._tokenizer is None:
            raise ValueError("Must specify tokenizer when input is a real example.")

        self.user_utterance = ''
        self.system_utterance = ''
        # The id of each subword in the vocabulary for BERT.
        self.utterance_ids = [0] * self._max_seq_length
        # Denotes the identity of the sequence. Takes values 0 (system utterance) and 1 (user utterance).
        self.utterance_segment = [0] * self._max_seq_length
        # Mask which takes the value 0 for padded tokens and 1 otherwise.
        self.utterance_mask = [0] * self._max_seq_length
        # Start and inclusive end character indices in the original utterance
        # corresponding to the tokens. This is used to obtain the character indices
        # from the predicted subword indices during inference.
        # NOTE: A positive value indicates the character indices in the user
        # utterance whereas a negative value indicates the character indices in the
        # system utterance. The indices are offset by 1 to prevent ambiguity in the
        # 0 index, which could be in either the user or system utterance by the
        # above convention. Now the 0 index corresponds to padded tokens.
        self.start_char_idx = [0] * self._max_seq_length
        self.end_char_idx = [0] * self._max_seq_length

        # Number of categorical slots present in the service.
        self.num_categorical_slots = 0
        # The status of each categorical slot in the service.
        self.categorical_slot_status = [STATUS_OFF] * schema_config["MAX_NUM_CAT_SLOT"]
        # Masks out categorical status for padded cat slots
        self.cat_slot_status_mask = [0] * len(self.categorical_slot_status)
        # Number of values taken by each categorical slot.
        self.num_categorical_slot_values = [0] * schema_config["MAX_NUM_CAT_SLOT"]
        # The index of the correct value for each categorical slot.
        self.categorical_slot_values = [0] * schema_config["MAX_NUM_CAT_SLOT"]
        # Masks out categorical slots values for slots not used in the service
        self.cat_slot_values_mask = [
            [0] * schema_config["MAX_NUM_VALUE_PER_CAT_SLOT"] for _ in range(schema_config["MAX_NUM_CAT_SLOT"])
        ]

        # Number of non-categorical slots present in the service.
        self.num_noncategorical_slots = 0
        # The status of each non-categorical slot in the service.
        self.noncategorical_slot_status = [STATUS_OFF] * schema_config["MAX_NUM_NONCAT_SLOT"]
        # Masks out non-categorical status for padded non-cat slots
        self.noncat_slot_status_mask = [0] * len(self.noncategorical_slot_status)
        # The index of the starting subword corresponding to the slot span for a
        # non-categorical slot value.
        self.noncategorical_slot_value_start = [0] * schema_config["MAX_NUM_NONCAT_SLOT"]
        # The index of the ending (inclusive) subword corresponding to the slot span
        # for a non-categorical slot value.
        self.noncategorical_slot_value_end = [0] * schema_config["MAX_NUM_NONCAT_SLOT"]

        # Total number of slots present in the service. All slots are included here
        # since every slot can be requested.
        self.num_slots = 0
        # Takes value 1 if the corresponding slot is requested, 0 otherwise.
        self.requested_slot_status = [STATUS_OFF] * (
            schema_config["MAX_NUM_CAT_SLOT"] + schema_config["MAX_NUM_NONCAT_SLOT"]
        )
        # Masks out requested slots that are not used for the service
        self.requested_slot_mask = [0] * len(self.requested_slot_status)

        # Total number of intents present in the service.
        self.num_intents = 0
        # Takes value 1 if the intent is active, 0 otherwise.
        self.intent_status = [STATUS_OFF] * schema_config["MAX_NUM_INTENT"]
        # Masks out intents that are not used for the service, [1] for none intent
        self.intent_status_mask = [1] + [0] * len(self.intent_status)
        # Label for active intent in the turn
        self.intent_status_labels = 0

    @property
    def readable_summary(self):
        """Get a readable dict that summarizes the attributes of an InputExample.

        Raises:
            ValueError: if more than one intent is marked active.
        """
        seq_length = sum(self.utterance_mask)
        utt_toks = self._tokenizer.convert_ids_to_tokens(self.utterance_ids[:seq_length])
        utt_tok_mask_pairs = list(zip(utt_toks, self.utterance_segment[:seq_length]))
        active_intents = [
            self.service_schema.get_intent_from_id(idx)
            for idx, s in enumerate(self.intent_status)
            if s == STATUS_ACTIVE
        ]
        if len(active_intents) > 1:
            raise ValueError("Should not have multiple active intents in a single service.")
        active_intent = active_intents[0] if active_intents else ""
        slot_values_in_state = {}
        for idx, s in enumerate(self.categorical_slot_status):
            if s == STATUS_ACTIVE:
                value_id = self.categorical_slot_values[idx]
                slot_values_in_state[
                    self.service_schema.get_categorical_slot_from_id(idx)
                ] = self.service_schema.get_categorical_slot_value_from_id(idx, value_id)
            elif s == STATUS_DONTCARE:
                slot_values_in_state[self.service_schema.get_categorical_slot_from_id(idx)] = STR_DONTCARE
        for idx, s in enumerate(self.noncategorical_slot_status):
            if s == STATUS_ACTIVE:
                slot = self.service_schema.get_non_categorical_slot_from_id(idx)
                start_id = self.noncategorical_slot_value_start[idx]
                end_id = self.noncategorical_slot_value_end[idx]
                # Token list is consisted of the subwords that may start with "##". We
                # remove "##" to reconstruct the original value. Note that it's not a
                # strict restoration of the original string. It's primarily used for
                # debugging.
                # ex. ["san", "j", "##ose"] --> "san jose"
                readable_value = " ".join(utt_toks[start_id : end_id + 1]).replace(" ##", "")
                slot_values_in_state[slot] = readable_value
            elif s == STATUS_DONTCARE:
                slot = self.service_schema.get_non_categorical_slot_from_id(idx)
                slot_values_in_state[slot] = STR_DONTCARE

        summary_dict = {
            "utt_tok_mask_pairs": utt_tok_mask_pairs,
            "utt_len": seq_length,
            "num_categorical_slots": self.num_categorical_slots,
            "num_categorical_slot_values": self.num_categorical_slot_values,
            "num_noncategorical_slots": self.num_noncategorical_slots,
            "service_name": self.service_schema.service_name,
            "active_intent": active_intent,
            "slot_values_in_state": slot_values_in_state,
        }
        return summary_dict

    def add_utterance_features(
        self, system_tokens, system_inv_alignments, user_tokens, user_inv_alignments, system_utterance, user_utterance
    ):
        """Add utterance related features input to bert.

        Note: this method modifies the system tokens and user_tokens in place to
        make their total length <= the maximum input length for BERT model.

        Args:
            system_tokens: a list of strings which represents system utterance.
            system_inv_alignments: a list of tuples which denotes the start and end
                character of the token that a bert token originates from in the original
                system utterance.
            user_tokens: a list of strings which represents user utterance.
            user_inv_alignments: a list of tuples which denotes the start and end
                character of the token that a bert token originates from in the original
                user utterance.
            system_utterance: the system utterance, stored on the example.
            user_utterance: the user utterance, stored on the example.
        """
        # Make user-system utterance input (in BERT format)
        # Input sequence length for utterance BERT encoder
        max_utt_len = self._max_seq_length

        # Modify lengths of sys & usr utterance so that length of total utt
        # (including cls_token, sep_token, sep_token) is no more than max_utt_len
        is_too_long = truncate_seq_pair(system_tokens, user_tokens, max_utt_len - 3)
        if is_too_long:
            logging.debug(f'Utterance sequence truncated in example id - {self.example_id}.')

        # Construct the tokens, segment mask and valid token mask which will be
        # input to BERT, using the tokens for system utterance (sequence A) and
        # user utterance (sequence B).
        utt_subword = []
        utt_seg = []
        utt_mask = []
        start_char_idx = []
        end_char_idx = []

        utt_subword.append(self._tokenizer.cls_token)
        utt_seg.append(0)
        utt_mask.append(1)
        start_char_idx.append(0)
        end_char_idx.append(0)

        for subword_idx, subword in enumerate(system_tokens):
            utt_subword.append(subword)
            utt_seg.append(0)
            utt_mask.append(1)
            st, en = system_inv_alignments[subword_idx]
            # System character indices are stored negated and offset by one.
            start_char_idx.append(-(st + 1))
            end_char_idx.append(-(en + 1))

        utt_subword.append(self._tokenizer.sep_token)
        utt_seg.append(0)
        utt_mask.append(1)
        start_char_idx.append(0)
        end_char_idx.append(0)

        for subword_idx, subword in enumerate(user_tokens):
            utt_subword.append(subword)
            utt_seg.append(1)
            utt_mask.append(1)
            st, en = user_inv_alignments[subword_idx]
            # User character indices are stored positive and offset by one.
            start_char_idx.append(st + 1)
            end_char_idx.append(en + 1)

        utt_subword.append(self._tokenizer.sep_token)
        utt_seg.append(1)
        utt_mask.append(1)
        start_char_idx.append(0)
        end_char_idx.append(0)

        utterance_ids = self._tokenizer.tokens_to_ids(utt_subword)

        # Zero-pad up to the BERT input sequence length.
        while len(utterance_ids) < max_utt_len:
            utterance_ids.append(0)
            utt_seg.append(0)
            utt_mask.append(0)
            start_char_idx.append(0)
            end_char_idx.append(0)
        self.utterance_ids = utterance_ids
        self.utterance_segment = utt_seg
        self.utterance_mask = utt_mask
        self.start_char_idx = start_char_idx
        self.end_char_idx = end_char_idx

        # `user_utterance` is the attribute initialised in __init__ and copied by
        # make_copy_with_utterance_features; the misspelled `user_utterances` is
        # kept as an alias in case external code already reads it.
        self.user_utterance = user_utterance
        self.user_utterances = user_utterance
        self.system_utterance = system_utterance

    def make_copy_with_utterance_features(self):
        """Make a copy of the current example with utterance features."""
        new_example = InputExample(
            schema_config=self.schema_config,
            service_schema=self.service_schema,
            example_id=self.example_id,
            example_id_num=self.example_id_num,
            is_real_example=self.is_real_example,
            tokenizer=self._tokenizer,
        )
        new_example.utterance_ids = list(self.utterance_ids)
        new_example.utterance_segment = list(self.utterance_segment)
        new_example.utterance_mask = list(self.utterance_mask)
        new_example.start_char_idx = list(self.start_char_idx)
        new_example.end_char_idx = list(self.end_char_idx)
        new_example.user_utterance = self.user_utterance
        new_example.system_utterance = self.system_utterance
        return new_example

    def add_categorical_slots(self, state_update):
        """Add features for categorical slots.

        Args:
            state_update: dict mapping slot name to the list of its values; only
                the first value of each slot is used.
        """
        categorical_slots = self.service_schema.categorical_slots
        self.num_categorical_slots = len(categorical_slots)
        for slot_idx, slot in enumerate(categorical_slots):
            values = state_update.get(slot, [])
            # Add categorical slot value features.
            slot_values = self.service_schema.get_categorical_slot_values(slot)
            self.num_categorical_slot_values[slot_idx] = len(slot_values)
            # set slot mask to 1, i.e. the slot is active in the service
            self.cat_slot_status_mask[slot_idx] = 1
            # set the number of active slot values for this slots in the service
            for slot_value_idx in range(len(self.service_schema._categorical_slot_values[slot])):
                self.cat_slot_values_mask[slot_idx][slot_value_idx] = 1

            if not values:
                self.categorical_slot_status[slot_idx] = STATUS_OFF
            elif values[0] == STR_DONTCARE:
                self.categorical_slot_status[slot_idx] = STATUS_DONTCARE
            else:
                self.categorical_slot_status[slot_idx] = STATUS_ACTIVE
                self.categorical_slot_values[slot_idx] = self.service_schema.get_categorical_slot_value_id(
                    slot, values[0]
                )

    def add_noncategorical_slots(self, state_update, system_span_boundaries, user_span_boundaries):
        """Add features for non-categorical slots.

        Args:
            state_update: dict mapping slot name to the list of its values.
            system_span_boundaries: dict mapping slot name to a (start, end) subword
                span; consulted second.
            user_span_boundaries: dict mapping slot name to a (start, end) subword
                span; consulted first.

        NOTE(review): the data-processor call site passes the user spans as the
        second positional argument and the system spans as the third, i.e. swapped
        relative to this signature, so the prioritisation below may effectively
        favour system spans. Confirm before changing either side.
        """
        noncategorical_slots = self.service_schema.non_categorical_slots
        self.num_noncategorical_slots = len(noncategorical_slots)
        for slot_idx, slot in enumerate(noncategorical_slots):
            values = state_update.get(slot, [])
            self.noncat_slot_status_mask[slot_idx] = 1
            if not values:
                self.noncategorical_slot_status[slot_idx] = STATUS_OFF
            elif values[0] == STR_DONTCARE:
                self.noncategorical_slot_status[slot_idx] = STATUS_DONTCARE
            else:
                self.noncategorical_slot_status[slot_idx] = STATUS_ACTIVE
                # Add indices of the start and end tokens for the first encountered
                # value. Spans in user utterance are prioritized over the system
                # utterance. If a span is not found, the slot value is ignored.
                if slot in user_span_boundaries:
                    start, end = user_span_boundaries[slot]
                elif slot in system_span_boundaries:
                    start, end = system_span_boundaries[slot]
                else:
                    # A span may not be found because the value was cropped out or because
                    # the value was mentioned earlier in the dialogue. Since this model
                    # only makes use of the last two utterances to predict state updates,
                    # it will fail in such cases.
                    logging.debug(
                        f'Slot values {str(values)} not found in user or system utterance in example with id - {self.example_id}.'
                    )

                    continue
                self.noncategorical_slot_value_start[slot_idx] = start
                self.noncategorical_slot_value_end[slot_idx] = end

    def add_requested_slots(self, frame):
        """Mark every slot of the service as maskable and flag the requested ones.

        Args:
            frame: dict whose frame["state"]["requested_slots"] holds the requested slot names.
        """
        all_slots = self.service_schema.slots
        self.num_slots = len(all_slots)
        for slot_idx, slot in enumerate(all_slots):
            self.requested_slot_mask[slot_idx] = 1
            if slot in frame["state"]["requested_slots"]:
                self.requested_slot_status[slot_idx] = STATUS_ACTIVE

    def add_intents(self, frame):
        """Set intent status, label and mask from the frame's active intent.

        Args:
            frame: dict whose frame["state"]["active_intent"] names the active intent.
        """
        all_intents = self.service_schema.intents
        self.num_intents = len(all_intents)
        for intent_idx, intent in enumerate(all_intents):
            if intent == frame["state"]["active_intent"]:
                self.intent_status[intent_idx] = STATUS_ACTIVE
                # adding +1 to take none intent into account
                # supports only 1 active intent in the turn
                self.intent_status_labels = intent_idx + 1
            # Index 0 of the mask is the "none" intent, hence the +1 shift.
            self.intent_status_mask[intent_idx + 1] = 1


# Modified from run_classifier._truncate_seq_pair in the public bert model repo.
# https://github.com/google-research/bert/blob/master/run_classifier.py.
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate a seq pair in place so that their total length <= max_length.

    Returns:
        True if at least one token was removed, False otherwise.
    """
    is_too_long = False
    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        is_too_long = True
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
    return is_too_long
For a turn, if there are no requested slots in both the ground truth + and the prediction, that turn is skipped. The reported number is the average + F1 score for all un-skipped user turns. This metric is optional to report in + the final paper. +(4) Average goal accuracy: For each turn, participants must predict a single + value for each slot present in the dialogue state. The slots which have a + non-empty assignment in the ground truth dialogue state are only considered. + This is the average accuracy of predicting the value of a slot correctly. A + fuzzy matching based score is used for non-categorical slots. +(5) Joint goal accuracy: This is the average accuracy of predicting all slot + assignments for a turn correctly. A fuzzy matching based score is used for + non-categorical slots. This is the primary evaluation metric used for ranking + submissions. More details to follow with the evaluation script. + +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/metrics.py +""" + +import collections + +import numpy as np +from rapidfuzz import fuzz + +F1Scores = collections.namedtuple("F1Scores", ["f1", "precision", "recall"]) + +# Evaluation and other relevant metrics for DSTC8 Schema-guided DST. +# (1) Active intent accuracy. +ACTIVE_INTENT_ACCURACY = "active_intent_accuracy" +# (2) Slot tagging F1. +SLOT_TAGGING_F1 = "slot_tagging_f1" +SLOT_TAGGING_PRECISION = "slot_tagging_precision" +SLOT_TAGGING_RECALL = "slot_tagging_recall" +# (3) Requested slots F1. +REQUESTED_SLOTS_F1 = "requested_slots_f1" +REQUESTED_SLOTS_PRECISION = "requested_slots_precision" +REQUESTED_SLOTS_RECALL = "requested_slots_recall" +# (4) Average goal accuracy. +AVERAGE_GOAL_ACCURACY = "average_goal_accuracy" +AVERAGE_CAT_ACCURACY = "average_cat_accuracy" +AVERAGE_NONCAT_ACCURACY = "average_noncat_accuracy" +# (5) Joint goal accuracy. 
JOINT_GOAL_ACCURACY = "joint_goal_accuracy"
JOINT_CAT_ACCURACY = "joint_cat_accuracy"
JOINT_NONCAT_ACCURACY = "joint_noncat_accuracy"

NAN_VAL = "NA"


def compute_f1(list_ref, list_hyp):
    """Compute F1 score from reference (ground truth) list and hypothesis list.

    Args:
        list_ref: List of true elements.
        list_hyp: List of positive (retrieved) elements.

    Returns:
        A F1Scores object containing F1, precision, and recall scores. Precision
        is 1.0 when `list_hyp` is empty and recall is 1.0 when `list_ref` is empty.
    """

    ref = collections.Counter(list_ref)
    hyp = collections.Counter(list_hyp)
    true = sum(ref.values())
    positive = sum(hyp.values())
    # Multiset intersection: each element counts as often as it occurs in both lists.
    true_positive = sum((ref & hyp).values())
    precision = float(true_positive) / positive if positive else 1.0
    recall = float(true_positive) / true if true else 1.0
    if precision + recall > 0.0:
        f1 = 2.0 * precision * recall / (precision + recall)
    else:  # The F1-score is defined to be 0 if both precision and recall are 0.
        f1 = 0.0

    return F1Scores(f1=f1, precision=precision, recall=recall)


def fuzzy_string_match(str_ref, str_hyp):
    """Returns fuzzy string similarity score in range [0.0, 1.0]."""

    # The higher the score, the higher the similarity between the two strings.
    return fuzz.token_sort_ratio(str_ref, str_hyp) / 100.0


def noncat_slot_value_match(str_ref_list, str_hyp, no_fuzzy_match):
    """Calculate non-categorical slots correctness.

    Args:
        str_ref_list: a list of reference strings.
        str_hyp: the hypothesis string.
        no_fuzzy_match: if true, compare by exact string equality instead of
            fuzzy string matching.

    Returns:
        score: The highest match score between the hypothesis and any of the
            references; 0.0 when `str_ref_list` is empty.
    """
    score = 0.0
    for str_ref in str_ref_list:
        if no_fuzzy_match:
            match_score = float(str_ref == str_hyp)
        else:
            match_score = fuzzy_string_match(str_ref, str_hyp)
        score = max(score, match_score)
    return score


def compare_slot_values(slot_values_ref, slot_values_hyp, service, no_fuzzy_match):
    """Compare and get correctness of goal state's slot_values.

    Args:
        slot_values_ref: goal state slot_values from reference (ground truth).
        slot_values_hyp: goal state slot_values from hypothesis (prediction).
        service: a service data structure in the schema. We use it to obtain the
            list of slots in the service and infer whether a slot is categorical.
        no_fuzzy_match: if true, non-categorical slot values are compared by exact
            string equality instead of fuzzy string matching.

    Returns:
        (list_cor, slot_active, slot_cat)
        list_cor: list of correctness scores, each corresponding to one slot in the
            service. The score is a float either 0.0 or 1.0 for categorical slot,
            and in range [0.0, 1.0] for non-categorical slot.
        slot_active: list indicating whether the element in list_cor corresponds to
            an active ground-truth slot.
        slot_cat: list indicating whether the element in list_cor corresponds to a
            categorical slot.
    """
    list_cor = []
    slot_active = []
    slot_cat = []

    for slot in service["slots"]:
        slot_name = slot["name"]
        slot_cat.append(slot["is_categorical"])

        if slot_name in slot_values_ref:  # REF=active
            slot_active.append(True)
            if slot_name in slot_values_hyp:  # HYP=active, apply matching
                value_ref_list = slot_values_ref[slot_name]
                # Only the first predicted value is scored.
                value_hyp = slot_values_hyp[slot_name][0]
                if slot["is_categorical"]:
                    cor = float(value_ref_list[0] == value_hyp)
                else:
                    cor = noncat_slot_value_match(value_ref_list, value_hyp, no_fuzzy_match)

                list_cor.append(cor)
            else:  # HYP=off
                list_cor.append(0.0)
        else:  # REF=off
            slot_active.append(False)
            if slot_name in slot_values_hyp:  # HYP=active
                list_cor.append(0.0)
            else:  # HYP=off
                list_cor.append(1.0)

    assert len(list_cor) == len(service["slots"])
    assert len(slot_active) == len(service["slots"])
    assert len(slot_cat) == len(service["slots"])
    return list_cor, slot_active, slot_cat


def get_active_intent_accuracy(frame_ref, frame_hyp):
    """Get active intent accuracy of a frame.

    Args:
        frame_ref: single semantic frame from reference (ground truth) file.
        frame_hyp: single semantic frame from hypothesis (prediction) file.

    Returns:
        1.0 if the intent prediction is correct, otherwise 0.0.
    """
    return float(frame_ref["state"]["active_intent"] == frame_hyp["state"]["active_intent"])


def get_slot_tagging_f1(frame_ref, frame_hyp, utt, service):
    """Get slot tagging (non-categorical slots only) F1 scores of a frame.

    Args:
        frame_ref: single semantic frame from reference (ground truth) file.
        frame_hyp: single semantic frame from hypothesis (prediction) file.
        utt: user utterance. Slot tagging annotations are the character positions in
            the utterance.
        service: a service data structure in the schema. We use it to infer whether
            a slot is non-categorical.

    Returns:
        A F1Scores object containing F1, precision, and recall scores, or None
        when the hypothesis frame has no "slots" entry.
    """

    # A set gives constant-time membership tests in the filters below.
    noncat_slots = {s["name"] for s in service["slots"] if not s["is_categorical"]}
    if "slots" not in frame_hyp:
        return None

    list_ref = [
        (s["slot"], utt[s["start"] : s["exclusive_end"]]) for s in frame_ref["slots"] if s["slot"] in noncat_slots
    ]
    list_hyp = [
        (s["slot"], utt[s["start"] : s["exclusive_end"]]) for s in frame_hyp["slots"] if s["slot"] in noncat_slots
    ]
    return compute_f1(list_ref, list_hyp)


def get_requested_slots_f1(frame_ref, frame_hyp):
    """Get requested slots F1 scores of a frame.

    Args:
        frame_ref: single semantic frame from reference (ground truth) file.
        frame_hyp: single semantic frame from hypothesis (prediction) file.

    Returns:
        A F1Scores object containing F1, precision, and recall scores.
    """
    return compute_f1(frame_ref["state"]["requested_slots"], frame_hyp["state"]["requested_slots"])
+ frame_hyp: single semantic frame from hypothesis (prediction) file. + service: a service data structure in the schema. We use it to obtain the + list of slots in the service and infer whether a slot is categorical. + use_fuzzy_match: whether to use fuzzy string matching for comparing + non-categorical slot values. + + Returns: + goal_acc: a dict whose values are average / joint + all-goal / categorical-goal / non-categorical-goal accuracies. + """ + goal_acc = {} + + list_acc, slot_active, slot_cat = compare_slot_values( + frame_ref["state"]["slot_values"], frame_hyp["state"]["slot_values"], service, no_fuzzy_match + ) + + # (4) Average goal accuracy. + active_acc = [acc for acc, active in zip(list_acc, slot_active) if active] + goal_acc[AVERAGE_GOAL_ACCURACY] = np.mean(active_acc) if active_acc else NAN_VAL + # (4-a) categorical. + active_cat_acc = [acc for acc, active, cat in zip(list_acc, slot_active, slot_cat) if active and cat] + goal_acc[AVERAGE_CAT_ACCURACY] = np.mean(active_cat_acc) if active_cat_acc else NAN_VAL + # (4-b) non-categorical. + active_noncat_acc = [acc for acc, active, cat in zip(list_acc, slot_active, slot_cat) if active and not cat] + goal_acc[AVERAGE_NONCAT_ACCURACY] = np.mean(active_noncat_acc) if active_noncat_acc else NAN_VAL + + # (5) Joint goal accuracy. + goal_acc[JOINT_GOAL_ACCURACY] = np.prod(list_acc) if list_acc else NAN_VAL + # (5-a) categorical. + cat_acc = [acc for acc, cat in zip(list_acc, slot_cat) if cat] + goal_acc[JOINT_CAT_ACCURACY] = np.prod(cat_acc) if cat_acc else NAN_VAL + # (5-b) non-categorical. 
+ noncat_acc = [acc for acc, cat in zip(list_acc, slot_cat) if not cat] + goal_acc[JOINT_NONCAT_ACCURACY] = np.prod(noncat_acc) if noncat_acc else NAN_VAL + + return goal_acc diff --git a/nemo/collections/nlp/data/datasets/sgd_dataset/prediction_utils.py b/nemo/collections/nlp/data/datasets/sgd_dataset/prediction_utils.py new file mode 100644 index 000000000000..7207406e3321 --- /dev/null +++ b/nemo/collections/nlp/data/datasets/sgd_dataset/prediction_utils.py @@ -0,0 +1,357 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +Prediction and evaluation-related utility functions. 
+This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/pred_utils.py +""" + +import collections +import json +import os + +from nemo import logging +from nemo.collections.nlp.data.datasets.sgd_dataset.input_example import STATUS_ACTIVE, STATUS_DONTCARE, STR_DONTCARE + +REQ_SLOT_THRESHOLD = 0.5 + +__all__ = ['get_predicted_dialog_baseline', 'write_predictions_to_file'] + + +def get_predicted_dialog_ret_sys_act(dialog, all_predictions, schemas, eval_debug, in_domain_services): + """Update labels in a dialogue based on model predictions. + Args: + dialog: A json object containing dialogue whose labels are to be updated. + all_predictions: A dict mapping prediction name to the predicted value. See + SchemaGuidedDST class for the contents of this dict. + schemas: A Schema object wrapping all the schemas for the dataset. + Returns: + A json object containing the dialogue with labels predicted by the model. + """ + # This approach retreives slot values from the history of system actions if slot is active but it can not find it in user utterance + # Overwrite the labels in the turn with the predictions from the model. For + # test set, these labels are missing from the data and hence they are added. + dialog_id = dialog["dialogue_id"] + # The slot values tracked for each service. 
+ all_slot_values = collections.defaultdict(dict) + sys_prev_slots = collections.defaultdict(dict) + sys_rets = {} + + for turn_idx, turn in enumerate(dialog["turns"]): + if turn["speaker"] == "SYSTEM": + for frame in turn["frames"]: + for action in frame["actions"]: + if action["slot"] and len(action["values"]) > 0: + sys_prev_slots[frame["service"]][action["slot"]] = action["values"][0] + elif turn["speaker"] == "USER": + user_utterance = turn["utterance"] + system_utterance = dialog["turns"][turn_idx - 1]["utterance"] if turn_idx else "" + turn_id = "{:02d}".format(turn_idx) + for frame in turn["frames"]: + cat_slot_status_acc = 0 + cat_slot_status_num = 0 + noncat_slot_status_num = 0 + noncat_slot_status_acc = 0 + + predictions = all_predictions[(dialog_id, turn_id, frame["service"])] + slot_values = all_slot_values[frame["service"]] + service_schema = schemas.get_service_schema(frame["service"]) + + # Remove the slot spans and state if present. + true_slots = frame.pop("slots", None) + true_state = frame.pop("state", None) + + # The baseline model doesn't predict slot spans. Only state predictions + # are added. + state = {} + + # Add prediction for active intent. Offset is subtracted to account for + # NONE intent. + active_intent_id = predictions["intent_status"] + state["active_intent"] = ( + service_schema.get_intent_from_id(active_intent_id - 1) if active_intent_id else "NONE" + ) + + # Add prediction for requested slots. + requested_slots = [] + for slot_idx, slot in enumerate(service_schema.slots): + if predictions["req_slot_status"][slot_idx] > REQ_SLOT_THRESHOLD: + requested_slots.append(slot) + state["requested_slots"] = requested_slots + + # Add prediction for user goal (slot values). + # Categorical slots. 
+ categorical_slots_dict = {} + non_categorical_slots_dict = {} + + predictions["cat_slot_status_p"] = predictions["cat_slot_status_p"].cpu().numpy() + predictions["cat_slot_status"] = predictions["cat_slot_status"].cpu().numpy() + predictions["cat_slot_value"] = predictions["cat_slot_value"].cpu().numpy() + predictions["cat_slot_value_p"] = predictions["cat_slot_value_p"].cpu().numpy() + + predictions["noncat_slot_status_p"] = predictions["noncat_slot_status_p"].cpu().numpy() + predictions["noncat_slot_status"] = predictions["noncat_slot_status"].cpu().numpy() + predictions["noncat_slot_p"] = predictions["noncat_slot_p"].cpu().numpy() + + predictions["noncat_alignment_start"] = predictions["noncat_alignment_start"].cpu().numpy() + predictions["noncat_alignment_end"] = predictions["noncat_alignment_end"].cpu().numpy() + predictions["cat_slot_status_GT"] = predictions["cat_slot_status_GT"].cpu().numpy() + predictions["noncat_slot_status_GT"] = predictions["noncat_slot_status_GT"].cpu().numpy() + + for slot_idx, slot in enumerate(service_schema.categorical_slots): + # debugging info + cat_slot_status_num += 1 + categorical_slots_dict[slot] = ( + predictions["cat_slot_status_GT"][slot_idx], + predictions["cat_slot_status"][slot_idx], + predictions["cat_slot_status_p"][slot_idx], + service_schema.get_categorical_slot_values(slot)[predictions["cat_slot_value"][slot_idx]], + predictions["cat_slot_value_p"][slot_idx], + ) + + if predictions["cat_slot_status_GT"][slot_idx] == predictions["cat_slot_status"][slot_idx]: + cat_slot_status_acc += 1 + #### + slot_status = predictions["cat_slot_status"][slot_idx] + if slot_status == STATUS_DONTCARE: + slot_values[slot] = STR_DONTCARE + elif slot_status == STATUS_ACTIVE: + # print(predictions["cat_slot_status_p"][slot_idx]) + if ( + predictions["cat_slot_status_p"][slot_idx] + predictions["cat_slot_value_p"][slot_idx] + ) / 2 > 0.9: + value_idx = predictions["cat_slot_value"][slot_idx] + slot_values[slot] = 
service_schema.get_categorical_slot_values(slot)[value_idx] + else: + if slot in sys_prev_slots[frame["service"]]: + # debugging info + sys_rets[slot] = sys_prev_slots[frame["service"]][slot] + ## + slot_values[slot] = sys_prev_slots[frame["service"]][slot] + print("pooooy", slot_values[slot]) + else: + value_idx = predictions["cat_slot_value"][slot_idx] + slot_values[slot] = service_schema.get_categorical_slot_values(slot)[value_idx] + + for slot_idx, slot in enumerate(service_schema.non_categorical_slots): + tok_start_idx = predictions["noncat_slot_start"][slot_idx] + tok_end_idx = predictions["noncat_slot_end"][slot_idx] + ch_start_idx = predictions["noncat_alignment_start"][tok_start_idx] + ch_end_idx = predictions["noncat_alignment_end"][tok_end_idx] + + # debugging nfo + noncat_slot_status_num += 1 + + non_categorical_slots_dict[slot] = ( + predictions["noncat_slot_status_GT"][slot_idx], + predictions["noncat_slot_status"][slot_idx], + predictions["noncat_slot_status_p"][slot_idx], + (ch_start_idx, ch_end_idx), + user_utterance[ch_start_idx - 1 : ch_end_idx] + if (ch_start_idx > 0 and ch_end_idx > 0) + else system_utterance[-ch_start_idx - 1 : -ch_end_idx], + predictions["noncat_slot_p"][slot_idx], + ) + if predictions["noncat_slot_status_GT"][slot_idx] == predictions["noncat_slot_status"][slot_idx]: + noncat_slot_status_acc += 1 + + slot_status = predictions["noncat_slot_status"][slot_idx] + if slot_status == STATUS_DONTCARE: + slot_values[slot] = STR_DONTCARE + elif slot_status == STATUS_ACTIVE: + tok_start_idx = predictions["noncat_slot_start"][slot_idx] + tok_end_idx = predictions["noncat_slot_end"][slot_idx] + ch_start_idx = predictions["noncat_alignment_start"][tok_start_idx] + ch_end_idx = predictions["noncat_alignment_end"][tok_end_idx] + # logging.debug(ch_start_idx, ch_end_idx) + # logging.debug(f'Active Slot: {slot}') + # logging.debug(f'{predictions["noncat_slot_p"][slot_idx]}, ({ch_start_idx}, {ch_end_idx}), {user_utterance[ch_start_idx - 1 : 
ch_end_idx]}') + if ch_start_idx > 0 and ch_end_idx > 0: + # Add span from the user utterance. + slot_values[slot] = user_utterance[ch_start_idx - 1 : ch_end_idx] + # elif ch_start_idx < 0 and ch_end_idx < 0: + # Add span from the system utterance. + # slot_values[slot] = system_utterance[-ch_start_idx - 1 : -ch_end_idx] + else: + if slot in sys_prev_slots[frame["service"]]: + # debugging info + sys_rets[slot] = sys_prev_slots[frame["service"]][slot] + ## + slot_values[slot] = sys_prev_slots[frame["service"]][slot] + # elif ch_start_idx < 0 and ch_end_idx < 0: + # slot_values[slot] = system_utterance[-ch_start_idx - 1 : -ch_end_idx] + # print("hoooy", slot_values[slot]) + + if eval_debug and frame["service"] in in_domain_services: + logging.debug("-----------------------------------New Frame------------------------------") + logging.debug(f'SYS : {system_utterance}') + logging.debug(f'USER: {user_utterance}') + + logging.debug("\n") + logging.debug(f"PRED CAT: {categorical_slots_dict}") + logging.debug(f"PRED NON-CAT: {non_categorical_slots_dict}") + + logging.debug("\n") + logging.debug(f"SLOTS - LABEL: {true_slots}") + logging.debug(f"STATE - LABEL: {true_state['slot_values']}") + logging.debug(f"STATE - PRED : {slot_values}") + + logging.debug("\n") + logging.debug(f"SYS PREV SLOT: {sys_prev_slots}") + logging.debug(f"SYS RETS: {sys_rets}") + cat_slot_status_acc = ( + "NAN" if cat_slot_status_num == 0 else cat_slot_status_acc / cat_slot_status_num + ) + logging.debug(f"CAT STATUS ACC: {cat_slot_status_acc}") + noncat_slot_status_acc = ( + "NAN" if noncat_slot_status_num == 0 else noncat_slot_status_acc / noncat_slot_status_num + ) + logging.debug(f"NONCAT STATUS ACC: {noncat_slot_status_acc}") + + # Create a new dict to avoid overwriting the state in previous turns + # because of use of same objects. 
+ state["slot_values"] = {s: [v] for s, v in slot_values.items()} + frame["state"] = state + + return dialog + + +def get_predicted_dialog_baseline(dialog, all_predictions, schemas): + """Update labels in a dialogue based on model predictions. + Args: + dialog: A json object containing dialogue whose labels are to be updated. + all_predictions: A dict mapping prediction name to the predicted value. See + SchemaGuidedDST class for the contents of this dict. + schemas: A Schema object wrapping all the schemas for the dataset. + Returns: + A json object containing the dialogue with labels predicted by the model. + """ + # Overwrite the labels in the turn with the predictions from the model. For + # test set, these labels are missing from the data and hence they are added. + dialog_id = dialog["dialogue_id"] + # The slot values tracked for each service. + all_slot_values = collections.defaultdict(dict) + for turn_idx, turn in enumerate(dialog["turns"]): + if turn["speaker"] == "USER": + user_utterance = turn["utterance"] + system_utterance = dialog["turns"][turn_idx - 1]["utterance"] if turn_idx else "" + turn_id = "{:02d}".format(turn_idx) + for frame in turn["frames"]: + predictions = all_predictions[(dialog_id, turn_id, frame["service"])] + slot_values = all_slot_values[frame["service"]] + service_schema = schemas.get_service_schema(frame["service"]) + # Remove the slot spans and state if present. + frame.pop("slots", None) + frame.pop("state", None) + + # The baseline model doesn't predict slot spans. Only state predictions + # are added. + state = {} + + # Add prediction for active intent. Offset is subtracted to account for + # NONE intent. + active_intent_id = predictions["intent_status"] + state["active_intent"] = ( + service_schema.get_intent_from_id(active_intent_id - 1) if active_intent_id else "NONE" + ) + + # Add prediction for requested slots. 
+ requested_slots = [] + for slot_idx, slot in enumerate(service_schema.slots): + if predictions["req_slot_status"][slot_idx] > REQ_SLOT_THRESHOLD: + requested_slots.append(slot) + state["requested_slots"] = requested_slots + + # Add prediction for user goal (slot values). + # Categorical slots. + for slot_idx, slot in enumerate(service_schema.categorical_slots): + slot_status = predictions["cat_slot_status"][slot_idx] + if slot_status == STATUS_DONTCARE: + slot_values[slot] = STR_DONTCARE + elif slot_status == STATUS_ACTIVE: + value_idx = predictions["cat_slot_value"][slot_idx] + slot_values[slot] = service_schema.get_categorical_slot_values(slot)[value_idx] + # Non-categorical slots. + for slot_idx, slot in enumerate(service_schema.non_categorical_slots): + slot_status = predictions["noncat_slot_status"][slot_idx] + if slot_status == STATUS_DONTCARE: + slot_values[slot] = STR_DONTCARE + elif slot_status == STATUS_ACTIVE: + tok_start_idx = predictions["noncat_slot_start"][slot_idx] + tok_end_idx = predictions["noncat_slot_end"][slot_idx] + ch_start_idx = predictions["noncat_alignment_start"][tok_start_idx] + ch_end_idx = predictions["noncat_alignment_end"][tok_end_idx] + if ch_start_idx < 0 and ch_end_idx < 0: + # Add span from the system utterance. + slot_values[slot] = system_utterance[-ch_start_idx - 1 : -ch_end_idx] + elif ch_start_idx > 0 and ch_end_idx > 0: + # Add span from the user utterance. + slot_values[slot] = user_utterance[ch_start_idx - 1 : ch_end_idx] + # Create a new dict to avoid overwriting the state in previous turns + # because of use of same objects. + state["slot_values"] = {s: [v] for s, v in slot_values.items()} + frame["state"] = state + return dialog + + +def write_predictions_to_file( + predictions, input_json_files, output_dir, schemas, state_tracker, eval_debug, in_domain_services +): + """Write the predicted dialogues as json files. + + Args: + predictions: An iterator containing model predictions. 
This is the output of + the predict method in the estimator. + input_json_files: A list of json paths containing the dialogues to run + inference on. + schemas: Schemas to all services in the dst dataset (train, dev and test splits). + output_dir: The directory where output json files will be created. + """ + logging.info(f"Writing predictions to {output_dir} started.") + + # Index all predictions. + all_predictions = {} + for idx, prediction in enumerate(predictions): + if not prediction["is_real_example"]: + continue + _, dialog_id, turn_id, service_name = prediction['example_id'].split('-') + all_predictions[(dialog_id, turn_id, service_name)] = prediction + logging.info(f'Predictions for {idx} examples are getting processed.') + + # Read each input file and write its predictions. + for input_file_path in input_json_files: + with open(input_file_path) as f: + dialogs = json.load(f) + logging.info(f'{input_file_path} file is loaded') + pred_dialogs = [] + for d in dialogs: + if state_tracker == 'baseline': + pred_dialog = get_predicted_dialog_baseline(d, all_predictions, schemas) + elif state_tracker == 'ret_sys_act': + pred_dialog = get_predicted_dialog_ret_sys_act( + d, all_predictions, schemas, eval_debug, in_domain_services + ) + else: + raise ValueError(f"tracker_mode {state_tracker} is not defined.") + pred_dialogs.append(pred_dialog) + f.close() + input_file_name = os.path.basename(input_file_path) + output_file_path = os.path.join(output_dir, input_file_name) + with open(output_file_path, "w") as f: + json.dump(pred_dialogs, f, indent=2, separators=(",", ": "), sort_keys=True) + f.close() diff --git a/nemo/collections/nlp/data/datasets/sgd_dataset/schema.py b/nemo/collections/nlp/data/datasets/sgd_dataset/schema.py new file mode 100644 index 000000000000..1462c6329892 --- /dev/null +++ b/nemo/collections/nlp/data/datasets/sgd_dataset/schema.py @@ -0,0 +1,182 @@ +# ============================================================================= +# Copyright 
2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +Wrappers for schemas of different services. +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/schema.py +https://github.com/google-research/google-research/blob/master/schema_guided_dst +""" + +import json + +from nemo import logging + +__all__ = ['ServiceSchema', 'Schema'] + + +class ServiceSchema(object): + """A wrapper for schema for a service.""" + + def __init__(self, schema_json, service_id=None): + self._service_name = schema_json["service_name"] + self._description = schema_json["description"] + self._schema_json = schema_json + self._service_id = service_id + + # Construct the vocabulary for intents, slots, categorical slots, + # non-categorical slots and categorical slot values. These vocabs are used + # for generating indices for their embedding matrix. 
+ self._intents = sorted(i["name"] for i in schema_json["intents"]) + self._slots = sorted(s["name"] for s in schema_json["slots"]) + self._categorical_slots = sorted( + s["name"] for s in schema_json["slots"] if s["is_categorical"] and s["name"] in self.state_slots + ) + self._non_categorical_slots = sorted( + s["name"] for s in schema_json["slots"] if not s["is_categorical"] and s["name"] in self.state_slots + ) + slot_schemas = {s["name"]: s for s in schema_json["slots"]} + categorical_slot_values = {} + categorical_slot_value_ids = {} + for slot in self._categorical_slots: + slot_schema = slot_schemas[slot] + values = sorted(slot_schema["possible_values"]) + categorical_slot_values[slot] = values + value_ids = {value: idx for idx, value in enumerate(values)} + categorical_slot_value_ids[slot] = value_ids + self._categorical_slot_values = categorical_slot_values + self._categorical_slot_value_ids = categorical_slot_value_ids + + @property + def schema_json(self): + return self._schema_json + + @property + def state_slots(self): + """Set of slots which are permitted to be in the dialogue state.""" + state_slots = set() + for intent in self._schema_json["intents"]: + state_slots.update(intent["required_slots"]) + state_slots.update(intent["optional_slots"]) + return state_slots + + @property + def service_name(self): + return self._service_name + + @property + def service_id(self): + return self._service_id + + @property + def description(self): + return self._description + + @property + def slots(self): + return self._slots + + @property + def intents(self): + return self._intents + + @property + def categorical_slots(self): + return self._categorical_slots + + @property + def non_categorical_slots(self): + return self._non_categorical_slots + + def get_categorical_slot_values(self, slot): + return self._categorical_slot_values[slot] + + def get_slot_from_id(self, slot_id): + return self._slots[slot_id] + + def get_intent_from_id(self, intent_id): + return 
self._intents[intent_id] + + def get_categorical_slot_from_id(self, slot_id): + return self._categorical_slots[slot_id] + + def get_non_categorical_slot_from_id(self, slot_id): + return self._non_categorical_slots[slot_id] + + def get_categorical_slot_value_from_id(self, slot_id, value_id): + slot = self.categorical_slots[slot_id] + return self._categorical_slot_values[slot][value_id] + + def get_categorical_slot_value_id(self, slot, value): + return self._categorical_slot_value_ids[slot][value] + + +class Schema(object): + """Wrapper for schemas for all services in a dataset.""" + + def __init__(self, schema_json_paths): + """ + TODO fix: + schema_json_paths: list of .json path to schema files of a single str with path to the json file. + """ + # Load the schema from the json file. + if isinstance(schema_json_paths, str): + with open(schema_json_paths, "r") as f: + all_schemas = json.load(f) + f.close() + else: + # load multiple schemas from the list of the json files + all_schemas = [] + completed_services = [] + for schema_json_path in schema_json_paths: + with open(schema_json_path, "r") as f: + schemas = json.load(f) + f.close() + logging.debug("Num of services in %s: %s", schema_json_path, len(schemas)) + + for service in schemas: + if service['service_name'] not in completed_services: + completed_services.append(service['service_name']) + all_schemas.append(service) + + self._services = sorted(schema["service_name"] for schema in all_schemas) + self._services_vocab = {v: k for k, v in enumerate(self._services)} + self._services_id_to_vocab = {v: k for k, v in self._services_vocab.items()} + service_schemas = {} + for schema in all_schemas: + service = schema["service_name"] + service_schemas[service] = ServiceSchema(schema, service_id=self.get_service_id(service)) + + self._service_schemas = service_schemas + self._schemas = all_schemas + + def get_service_id(self, service): + return self._services_vocab[service] + + def get_service_from_id(self, 
service_id): + return self._services[service_id] + + def get_service_schema(self, service): + return self._service_schemas[service] + + @property + def services(self): + return self._services + + def save_to_file(self, file_path): + with open(file_path, "w") as f: + json.dump(self._schemas, f, indent=2) diff --git a/nemo/collections/nlp/data/datasets/sgd_dataset/schema_embedding_dataset.py b/nemo/collections/nlp/data/datasets/sgd_dataset/schema_embedding_dataset.py new file mode 100644 index 000000000000..9bb5ff65148b --- /dev/null +++ b/nemo/collections/nlp/data/datasets/sgd_dataset/schema_embedding_dataset.py @@ -0,0 +1,357 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +Extract BERT embeddings for slots, values, intents in schema. 
class SchemaEmbeddingDataset(Dataset):
    """BERT inputs for the natural-language description of every schema element.

    One item is created per intent, requested slot, non-categorical slot,
    categorical slot and categorical slot value of every service in `schemas`.
    `save_embeddings` turns the BERT hidden states computed for these items
    into per-service embedding matrices and writes them to disk.
    """

    def __init__(self, schema_config, tokenizer, schemas):
        """Build the input features for all services.

        Args:
            schema_config (dict): must contain "MAX_SEQ_LENGTH"; `save_embeddings`
                additionally reads "MAX_NUM_INTENT", "MAX_NUM_CAT_SLOT",
                "MAX_NUM_NONCAT_SLOT", "MAX_NUM_VALUE_PER_CAT_SLOT" and
                "EMBEDDING_DIMENSION".
            tokenizer: object providing `text_to_tokens`, `tokens_to_ids`,
                `cls_token` and `sep_token` (such as NemoBertTokenizer).
            schemas: wrapper exposing `services`, `get_service_schema` and
                `get_service_from_id` for all services in the datasets.
        """
        self._tokenizer = tokenizer
        self.schema_config = schema_config
        self.schemas = schemas

        # Column-wise storage: every key holds one entry per example.
        field_names = (
            "input_ids",
            "input_mask",
            "input_type_ids",
            "embedding_tensor_name",
            "service_id",
            "intent_or_slot_id",
            "value_id",
        )
        self.features = collections.defaultdict(list)
        for feature in self._get_input_features():
            for name in field_names:
                self.features[name].append(getattr(feature, name))

    def __len__(self):
        """Return the number of schema-element examples."""
        return len(self.features['input_ids'])

    def __getitem__(self, idx):
        """Return (input_ids, input_mask, input_type_ids) of example `idx` as numpy arrays."""
        return (
            np.array(self.features['input_ids'][idx]),
            # `np.long` was removed in NumPy 1.24; int64 is the explicit equivalent
            # of what it resolved to on 64-bit Linux/macOS.
            np.array(self.features['input_mask'][idx], dtype=np.int64),
            np.array(self.features['input_type_ids'][idx]),
        )

    def _create_feature(self, line, embedding_tensor_name, service_id, intent_or_slot_id, value_id=-1):
        """Tokenize one schema description and wrap it into an InputFeatures instance.

        Args:
            line (str): either a single description or "<text a> ||| <text b>".
            embedding_tensor_name (str): key of the per-service embedding matrix
                this example belongs to (e.g. "intent_emb").
            service_id: id of the service the element belongs to.
            intent_or_slot_id: row of the element inside its embedding matrix.
            value_id: index of the categorical value; -1 when not applicable.

        Returns:
            InputFeatures whose input_ids, input_mask and input_type_ids are
            all padded/truncated to MAX_SEQ_LENGTH.
        """
        seq_length = self.schema_config["MAX_SEQ_LENGTH"]
        line = line.strip()

        # Split "a ||| b" into the two BERT segments; no separator means one segment.
        text_a, text_b = line, None
        match = re.match(r"^(.*) \|\|\| (.*)$", line)
        if match is not None:
            text_a, text_b = match.group(1), match.group(2)

        tokens_a = self._tokenizer.text_to_tokens(text_a)
        tokens_b = self._tokenizer.text_to_tokens(text_b) if text_b else None

        if tokens_b:
            # Truncates `tokens_a` / `tokens_b` in place.
            # "- 3" leaves room for [CLS], [SEP], [SEP].
            truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        elif len(tokens_a) > seq_length - 2:
            # "- 2" leaves room for [CLS] and [SEP].
            tokens_a = tokens_a[: seq_length - 2]

        # BERT convention:
        #   pair:   [CLS] a a a [SEP] b b [SEP]   type_ids: 0 0 0 0 0 1 1 1
        #   single: [CLS] a a a [SEP]             type_ids: 0 0 0 0 0
        tokens = [self._tokenizer.cls_token]
        tokens.extend(tokens_a)
        tokens.append(self._tokenizer.sep_token)
        input_type_ids = [0] * len(tokens)

        if tokens_b:
            tokens.extend(tokens_b)
            tokens.append(self._tokenizer.sep_token)
            input_type_ids.extend([1] * (len(tokens_b) + 1))

        input_ids = self._tokenizer.tokens_to_ids(tokens)

        # 1 marks real tokens, 0 marks padding; only real tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad all three sequences up to the sequence length.
        padding = [0] * (seq_length - len(input_ids))
        input_ids.extend(padding)
        input_mask.extend(padding)
        input_type_ids.extend(padding)
        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length

        return InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            input_type_ids=input_type_ids,
            embedding_tensor_name=embedding_tensor_name,
            service_id=service_id,
            intent_or_slot_id=intent_or_slot_id,
            value_id=value_id,
        )

    def _get_intents_input_features(self, service_schema):
        """Create features for BERT inference for all intents of a service.

        An intent is described as
        "[service description] ||| [intent name] [intent description]".

        Args:
            service_schema: ServiceSchema object of the corresponding service.

        Returns:
            A list of InputFeatures, one per intent, tagged "intent_emb".
        """
        service_des = service_schema.description
        intent_descriptions = {i["name"]: i["description"] for i in service_schema.schema_json["intents"]}

        features = []
        for intent_id, intent in enumerate(service_schema.intents):
            nl_seq = " ".join([service_des, _NL_SEPARATOR, intent, intent_descriptions[intent]])
            features.append(self._create_feature(nl_seq, "intent_emb", service_schema.service_id, intent_id))
        return features

    def _get_req_slots_input_features(self, service_schema):
        """Create features for BERT inference for all requested slots of a service.

        A slot is described as
        "[service description] ||| [slot name] [slot description]".

        Args:
            service_schema: ServiceSchema object of the corresponding service.

        Returns:
            A list of InputFeatures, one per slot, tagged "req_slot_emb".
        """
        service_des = service_schema.description
        slot_descriptions = {s["name"]: s["description"] for s in service_schema.schema_json["slots"]}

        features = []
        for slot_id, slot in enumerate(service_schema.slots):
            nl_seq = " ".join([service_des, _NL_SEPARATOR, slot, slot_descriptions[slot]])
            features.append(self._create_feature(nl_seq, "req_slot_emb", service_schema.service_id, slot_id))
        return features

    def _get_goal_slots_and_values_input_features(self, service_schema):
        """Create features for all goal slots and categorical slot values of a service.

        A slot is described as
        "[service description] ||| [slot name] [slot description]";
        a categorical value as
        "[slot name] [slot description] ||| [value name]".

        Args:
            service_schema: ServiceSchema object of the corresponding service.

        Returns:
            A list of InputFeatures tagged "noncat_slot_emb", "cat_slot_emb"
            or "cat_slot_value_emb".
        """
        service_des = service_schema.description
        slot_descriptions = {s["name"]: s["description"] for s in service_schema.schema_json["slots"]}

        features = []
        for slot_id, slot in enumerate(service_schema.non_categorical_slots):
            nl_seq = " ".join([service_des, _NL_SEPARATOR, slot, slot_descriptions[slot]])
            features.append(self._create_feature(nl_seq, "noncat_slot_emb", service_schema.service_id, slot_id))

        for slot_id, slot in enumerate(service_schema.categorical_slots):
            nl_seq = " ".join([service_des, _NL_SEPARATOR, slot, slot_descriptions[slot]])
            features.append(self._create_feature(nl_seq, "cat_slot_emb", service_schema.service_id, slot_id))
            # Each value of a categorical slot gets its own example, keyed by (slot_id, value_id).
            for value_id, value in enumerate(service_schema.get_categorical_slot_values(slot)):
                nl_seq = " ".join([slot, slot_descriptions[slot], _NL_SEPARATOR, value])
                features.append(
                    self._create_feature(nl_seq, "cat_slot_value_emb", service_schema.service_id, slot_id, value_id)
                )
        return features

    def _get_input_features(self):
        """Collect the InputFeatures of every element of every service.

        Returns:
            A flat list of InputFeatures: for each service its intents, then
            its requested slots, then its goal slots and categorical values.
        """
        features = []
        for service in self.schemas.services:
            service_schema = self.schemas.get_service_schema(service)
            features.extend(self._get_intents_input_features(service_schema))
            features.extend(self._get_req_slots_input_features(service_schema))
            features.extend(self._get_goal_slots_and_values_input_features(service_schema))
        return features

    def _populate_schema_embeddings(self, schema_embeddings, hidden_states, mode):
        """Fill `schema_embeddings` in place with BERT embeddings.

        Args:
            schema_embeddings (list): one dict of embedding matrices per
                service, indexed by service id (see `save_embeddings`).
            hidden_states: sequence whose first element is indexed as
                [example, token, hidden]; example order must match this dataset.
            mode (str): 'baseline' uses the first ([CLS]) token, 'random' the
                hidden state of a randomly chosen token position,
                'last_layer_average' the mean over all token positions.

        Raises:
            ValueError: if `mode` is not one of the supported values.
        """
        completed_services = set()
        # Only the sequence length is used; unpacking three values also
        # requires the array to be 3-dimensional, as before.
        _, seq_len, _ = hidden_states[0].shape

        for idx in range(len(self)):
            service_id = self.features['service_id'][idx]
            service = self.schemas.get_service_from_id(service_id)

            # Log each service once, when its first example is reached.
            if service not in completed_services:
                logging.debug(f"Generating embeddings for service {service}.")
                completed_services.add(service)
            tensor_name = self.features["embedding_tensor_name"][idx]
            emb_mat = schema_embeddings[service_id][tensor_name]

            if mode == 'random':
                # Use the hidden state of one randomly chosen token position.
                random_token = random.randint(0, seq_len - 1)
                embedding = [round(float(x), 6) for x in hidden_states[0][idx, random_token, :].flat]
            elif mode == 'last_layer_average':
                # Average the hidden states over all token positions.
                embedding = [round(float(x), 6) for x in np.mean(hidden_states[0][idx, :], 0).flat]
            elif mode == 'baseline':
                # Obtain the encoding of the [CLS] token.
                embedding = [round(float(x), 6) for x in hidden_states[0][idx, 0, :].flat]
            else:
                raise ValueError(f'Mode {mode} for generation schema embeddings is not supported')
            intent_or_slot_id = self.features['intent_or_slot_id'][idx]
            value_id = self.features['value_id'][idx]

            if tensor_name == "cat_slot_value_emb":
                emb_mat[intent_or_slot_id, value_id] = embedding
            else:
                emb_mat[intent_or_slot_id] = embedding

    def save_embeddings(self, bert_hidden_states, output_file, mode):
        """Generate the schema element embeddings and save them as a numpy file.

        Args:
            bert_hidden_states: BERT outputs, see `_populate_schema_embeddings`.
            output_file (str): path of the file written with `np.save`.
            mode (str): embedding mode, see `_populate_schema_embeddings`.
        """
        max_num_intent = self.schema_config["MAX_NUM_INTENT"]
        max_num_cat_slot = self.schema_config["MAX_NUM_CAT_SLOT"]
        max_num_noncat_slot = self.schema_config["MAX_NUM_NONCAT_SLOT"]
        max_num_slot = max_num_cat_slot + max_num_noncat_slot
        max_num_value = self.schema_config["MAX_NUM_VALUE_PER_CAT_SLOT"]
        embedding_dim = self.schema_config["EMBEDDING_DIMENSION"]

        # One dict of zero-initialised matrices per service.
        schema_embeddings = [
            {
                "intent_emb": np.zeros([max_num_intent, embedding_dim]),
                "req_slot_emb": np.zeros([max_num_slot, embedding_dim]),
                "cat_slot_emb": np.zeros([max_num_cat_slot, embedding_dim]),
                "noncat_slot_emb": np.zeros([max_num_noncat_slot, embedding_dim]),
                "cat_slot_value_emb": np.zeros([max_num_cat_slot, max_num_value, embedding_dim]),
            }
            for _ in self.schemas.services
        ]

        # Populate the embeddings based on bert inference results and save them.
        self._populate_schema_embeddings(schema_embeddings, bert_hidden_states, mode)

        # Only rank 0 (or a non-distributed run) writes the file.
        master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        if master_device:
            # The `with` block closes the file; no explicit close() is needed.
            with open(output_file, "wb") as f_s:
                np.save(f_s, schema_embeddings)
                logging.info(f"The schema embeddings saved at {output_file}")


class InputFeatures(object):
    """A single set of features for BERT inference."""

    def __init__(
        self, input_ids, input_mask, input_type_ids, embedding_tensor_name, service_id, intent_or_slot_id, value_id
    ):
        # The ids in the vocabulary for input tokens.
        self.input_ids = input_ids
        # A boolean mask indicating which tokens in the input_ids are valid.
        self.input_mask = input_mask
        # Denotes the sequence each input token belongs to.
        self.input_type_ids = input_type_ids
        # The name of the embedding tensor corresponding to this example.
        self.embedding_tensor_name = embedding_tensor_name
        # The id of the service corresponding to this example.
        self.service_id = service_id
        # The id of the intent (for intent embeddings) or slot (for slot or slot
        # value embeddings) corresponding to this example.
        self.intent_or_slot_id = intent_or_slot_id
        # The id of the value corresponding to this example. Only set if slot value
        # embeddings are being calculated.
        self.value_id = value_id
class SchemaPreprocessor:
    """
    Loads the schemas of the requested dataset splits and provides their
    pretrained BERT embeddings, generating and caching them on disk if needed.

    Args:
        data_dir (str): directory containing one sub-directory per dataset
            split, each holding a "schema.json" file
        schema_embedding_dir (str): directory where the .npy file with the
            schema embeddings is stored; created if it does not exist
        schema_config (dict): schema size limits and sequence length, passed
            on to SchemaEmbeddingDataset; a shallow copy is kept on the instance
        tokenizer: tokenizer wrapper; the type name of `tokenizer.tokenizer`
            and `tokenizer.vocab_size` become part of the embedding file name
        bert_model: pretrained BERT module, called with input_ids,
            token_type_ids and attention_mask
        overwrite_schema_emb_files (bool): regenerate the embedding file even
            if it already exists
        bert_ckpt_dir (str): passed to `nf.infer` as `checkpoint_dir`
        nf: NeuralModuleFactory used to run inference
        datasets (list): dataset splits to process; defaults to
            ['train', 'test', 'dev']
        mode (str): schema embeddings initialization mode, baseline is ['CLS']
            token embeddings from the last BERT layer
        is_trainable (bool): stored on the instance; not used in this class
    """

    def __init__(
        self,
        data_dir,
        schema_embedding_dir,
        schema_config,
        tokenizer,
        bert_model,
        overwrite_schema_emb_files,
        bert_ckpt_dir,
        nf,
        datasets=None,
        mode='baseline',
        is_trainable=False,
    ):
        self.schema_config = schema_config.copy()

        self.is_trainable = is_trainable
        # A fresh list per instance: the former list default was a single
        # object shared by every instance that relied on it.
        self.datasets = ['train', 'test', 'dev'] if datasets is None else datasets

        for dataset_split in ['train', 'test', 'dev']:
            if dataset_split not in self.datasets:
                logging.warning(
                    'WARNING: %s set was not included and won\'t be processed. Services from this dataset split '
                    'won\'t be supported',
                    dataset_split,
                )
        os.makedirs(schema_embedding_dir, exist_ok=True)

        # The cache file name encodes everything the embeddings depend on.
        tokenizer_type = type(tokenizer.tokenizer).__name__
        vocab_size = getattr(tokenizer, "vocab_size", 0)
        self.schema_embedding_file = os.path.join(
            schema_embedding_dir,
            "{}_{}_{}_{}_pretrained_schema_embedding.npy".format(
                '_'.join(self.datasets), mode, tokenizer_type, vocab_size
            ),
        )
        all_schema_json_paths = [
            os.path.join(data_dir, dataset_split, "schema.json") for dataset_split in self.datasets
        ]
        self.schemas = schema.Schema(all_schema_json_paths)

        if not os.path.exists(self.schema_embedding_file) or overwrite_schema_emb_files:
            # Generate the schema embeddings if needed or specified
            logging.info("Start generating the schema embeddings.")
            dataset_params = {
                "schema_config": schema_config,
                "tokenizer": tokenizer,
                "schemas": self.schemas,
            }
            emb_datalayer = BertInferDataLayer(
                dataset_type=SchemaEmbeddingDataset, dataset_params=dataset_params, batch_size=1, shuffle=False,
            )

            input_ids, input_mask, input_type_ids = emb_datalayer()

            hidden_states = bert_model(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
            evaluated_tensors = nf.infer(tensors=[hidden_states], checkpoint_dir=bert_ckpt_dir)

            # Only rank 0 (or a non-distributed run) writes the file.
            master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
            if master_device:
                hidden_states = [concatenate(tensors) for tensors in evaluated_tensors]
                emb_datalayer.dataset.save_embeddings(hidden_states, self.schema_embedding_file, mode)
                logging.info("Finish generating the schema embeddings.")

        # wait until the master process writes to the schema embedding file
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        # NOTE(security): the file holds a pickled list of dicts, so
        # allow_pickle=True is required; only point schema_embedding_dir at
        # trusted locations. The `with` block closes the file on exit.
        with open(self.schema_embedding_file, "rb") as f:
            self.schema_embeddings = np.load(f, allow_pickle=True)

    def get_schema_embeddings(self):
        """Return the loaded embeddings as a dict mapping tensor name to a per-service list."""
        # Convert from list of dict to dict of list
        schema_data_dict = collections.defaultdict(list)
        tensor_names = ("cat_slot_emb", "cat_slot_value_emb", "noncat_slot_emb", "req_slot_emb", "intent_emb")
        for service in self.schema_embeddings:
            for tensor_name in tensor_names:
                schema_data_dict[tensor_name].append(service[tensor_name])
        return schema_data_dict

    def _get_schema_embedding_file_name(self):
        """Return the path of the cached schema embedding file."""
        return self.schema_embedding_file

    def get_service_names_to_id_dict(self):
        """Return the service-name-to-id mapping held by the schemas object."""
        return self.schemas._services_vocab

    def get_ids_to_service_names_dict(self):
        """Return the id-to-service-name mapping held by the schemas object."""
        return self.schemas._services_id_to_vocab
class SGDDataset(Dataset):
    """
    Processes SGD dataset
    Args:
        dataset_split (str): train/dev/test
        dialogues_processor (obj): Data generator for dstc8 dialogues; must
            provide `get_dialog_examples(dataset_split)`
    """

    def __init__(self, dataset_split, dialogues_processor):
        self.features = dialogues_processor.get_dialog_examples(dataset_split)

    def __len__(self):
        """Return the number of dialogue examples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return example `idx` as a tuple of 21 numpy arrays in the order listed below."""
        ex = self.features[idx]
        service_id = ex.service_schema.service_id

        return (
            np.array(ex.example_id_num),
            np.array(service_id),
            np.array(ex.is_real_example, dtype=int),
            np.array(ex.utterance_ids),
            np.array(ex.utterance_segment),
            # `np.long` was removed in NumPy 1.24; int64 is the explicit equivalent
            # of what it resolved to on 64-bit Linux/macOS.
            np.array(ex.utterance_mask, dtype=np.int64),
            np.array(ex.categorical_slot_status),
            np.array(ex.cat_slot_status_mask),
            np.array(ex.categorical_slot_values),
            np.array(ex.cat_slot_values_mask),
            np.array(ex.noncategorical_slot_status),
            np.array(ex.noncat_slot_status_mask),
            np.array(ex.noncategorical_slot_value_start),
            np.array(ex.noncategorical_slot_value_end),
            np.array(ex.start_char_idx),  # noncat_alignment_start
            np.array(ex.end_char_idx),  # noncat_alignment_end
            np.array(ex.num_slots),  # num_requested_slots
            np.array(ex.requested_slot_status, dtype=np.float32),
            np.array(ex.requested_slot_mask),
            np.array(ex.intent_status_mask),
            np.array(ex.intent_status_labels),
        )
+ if __megatron_utils_satisfied: + if 'megatron' in pretrained_model_name: + do_lower_case = is_lower_cased_megatron(pretrained_model_name) + vocab_file = get_megatron_vocab_file(pretrained_model_name) + return nemo.collections.nlp.data.tokenizers.NemoBertTokenizer( + vocab_file=vocab_file, do_lower_case=do_lower_case + ) if tokenizer_name == 'nemobert': tokenizer = nemo.collections.nlp.data.tokenizers.NemoBertTokenizer( diff --git a/nemo/collections/nlp/nm/data_layers/__init__.py b/nemo/collections/nlp/nm/data_layers/__init__.py index 5b5d3dde539f..0c42605631ac 100644 --- a/nemo/collections/nlp/nm/data_layers/__init__.py +++ b/nemo/collections/nlp/nm/data_layers/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================= +from nemo.collections.nlp.nm.data_layers.bert_inference_datalayer import * from nemo.collections.nlp.nm.data_layers.glue_benchmark_datalayer import * from nemo.collections.nlp.nm.data_layers.joint_intent_slot_datalayer import * from nemo.collections.nlp.nm.data_layers.lm_bert_datalayer import * @@ -21,6 +22,7 @@ from nemo.collections.nlp.nm.data_layers.machine_translation_datalayer import * from nemo.collections.nlp.nm.data_layers.punctuation_capitalization_datalayer import * from nemo.collections.nlp.nm.data_layers.qa_squad_datalayer import * +from nemo.collections.nlp.nm.data_layers.state_tracking_sgd_datalayer import * from nemo.collections.nlp.nm.data_layers.state_tracking_trade_datalayer import * from nemo.collections.nlp.nm.data_layers.text_classification_datalayer import * from nemo.collections.nlp.nm.data_layers.text_datalayer import * diff --git a/nemo/collections/nlp/nm/data_layers/bert_inference_datalayer.py b/nemo/collections/nlp/nm/data_layers/bert_inference_datalayer.py new file mode 100644 index 000000000000..2da78552084d --- /dev/null +++ b/nemo/collections/nlp/nm/data_layers/bert_inference_datalayer.py @@ -0,0 +1,68 @@ +# 
class BertInferDataLayer(TextDataLayer):
    """
    Data layer to run inference with BERT (get final hidden layer).

    All arguments are forwarded unchanged to TextDataLayer.

    Args:
        dataset_type: dataset class to instantiate
        dataset_params (dict): keyword arguments for `dataset_type`
        batch_size (int): batch size, defaults to 1
        shuffle (bool): whether to shuffle data, defaults to False
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.

        input_ids: indices of tokens which constitute batches of text segments
            axes: ('B', 'T')

        input_type_ids: indices of token types (e.g., sentences A & B in BERT)
            axes: ('B', 'T')

        input_mask: tensor with 0s in place of tokens to be masked
            axes: ('B', 'T')

        """
        # NOTE(review): the port order here is (input_ids, input_type_ids,
        # input_mask), while SchemaEmbeddingDataset.__getitem__ returns
        # (input_ids, input_mask, input_type_ids) and SchemaPreprocessor unpacks
        # the layer outputs in that same dataset order. If ports are bound to
        # dataset outputs by position, the last two port names are swapped
        # relative to their contents -- confirm against DataLayerNM.
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "input_type_ids": NeuralType(('B', 'T'), ChannelType()),
            "input_mask": NeuralType(('B', 'T'), ChannelType()),
        }

    def __init__(self, dataset_type, dataset_params, batch_size=1, shuffle=False):

        super().__init__(dataset_type, dataset_params, batch_size=batch_size, shuffle=shuffle)
class SGDDataLayer(DataLayerNM):
    """
    Data layer for Schema Guided Dialogue State Tracking Dataset.
    Args:
        dataset_split (str): train/ dev/ test,
        dialogues_processor (obj): contains dialogue data,
        dataset_type (Dataset): Dataset Type, instantiated with
            `dataset_split` and `dialogues_processor`
        shuffle (bool): enables shuffling, default=False
        batch_size (int): batch size
        num_workers (int): number of workers; only stored when it is >= 0
        pin_memory (bool): enables copying Tensors into CUDA pinned memory before returning them
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.
        example_id_num (int): example ids
        service_id (int): service ids
        is_real_example (bool): flag to determine if the example is valid
        utterance_ids (int): utterance ids
        utterance_segment (int): Denotes the identity of the sequence. Takes values 0 (system utterance) and 1 (user utterance)
        utterance_mask (int): Mask which takes the value 0 for padded tokens and 1 otherwise
        categorical_slot_status (int): The status of each categorical slot in the service
        cat_slot_status_mask (int): Masks out categorical status for padded cat slots, takes values 0 and 1
        categorical_slot_values (int): The index of the correct value for each categorical slot
        cat_slot_values_mask (int): Masks out categorical slots values for slots not used in the service, takes values 0 and 1
        noncategorical_slot_status (int): The status of each non-categorical slot in the service
        noncat_slot_status_mask (int): Masks out non-categorical status for padded non-categorical slots, takes values 0 and 1
        noncategorical_slot_value_start (int): The index of the starting subword corresponding to the slot span for a non-categorical slot value
        noncategorical_slot_value_end (int): The index of the ending (inclusive) subword corresponding to the slot span for a non-categorical slot value
        start_char_idx (int): Start character indices in the original utterance corresponding to the tokens
        end_char_idx (int): Inclusive end character indices in the original utterance corresponding to the tokens
        num_slots (int): Total number of slots present in the service
        requested_slot_status (float): Takes value 1 if the corresponding slot is requested, 0 otherwise
        req_slot_mask (int): Masks requested slots not used for the particular service
        intent_status_mask (int): Masks out padded intents in the service, takes values 0 and 1
        intent_status_labels (int): Intent labels

        """
        # The entries are listed in the same order as the tuple returned by
        # SGDDataset.__getitem__; keep the two in sync when adding fields.
        return {
            "example_id_num": NeuralType(('B'), ChannelType()),
            "service_id": NeuralType(('B'), ChannelType()),
            "is_real_example": NeuralType(('B'), ChannelType()),
            "utterance_ids": NeuralType(('B', 'T'), ChannelType()),
            "utterance_segment": NeuralType(('B', 'T'), ChannelType()),
            "utterance_mask": NeuralType(('B', 'T'), ChannelType()),
            "categorical_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "cat_slot_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "categorical_slot_values": NeuralType(('B', 'T'), LabelsType()),
            "cat_slot_values_mask": NeuralType(('B', 'T', 'C'), ChannelType()),
            "noncategorical_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "noncat_slot_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "noncategorical_slot_value_start": NeuralType(('B', 'T'), LabelsType()),
            "noncategorical_slot_value_end": NeuralType(('B', 'T'), LabelsType()),
            "start_char_idx": NeuralType(('B', 'T'), LabelsType()),
            "end_char_idx": NeuralType(('B', 'T'), LabelsType()),
            "num_slots": NeuralType(('B'), LengthsType()),
            "requested_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "req_slot_mask": NeuralType(('B', 'T'), ChannelType()),
            "intent_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "intent_status_labels": NeuralType(('B'), LabelsType()),
        }

    def __init__(
        self,
        dataset_split,
        dialogues_processor,
        dataset_type=SGDDataset,
        shuffle=False,
        batch_size=1,
        num_workers=-1,
        pin_memory=False,
    ):
        super().__init__()
        dataset_params = {
            'dataset_split': dataset_split,
            'dialogues_processor': dialogues_processor,
        }
        self._dataset = dataset_type(**dataset_params)
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._pin_memory = pin_memory
        # A negative value leaves `_num_workers` untouched, i.e. whatever the
        # base class set up stays in effect.
        if num_workers >= 0:
            self._num_workers = num_workers

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self._dataset)

    @property
    def dataset(self):
        """The dataset instance created in `__init__`."""
        return self._dataset

    @property
    def data_iterator(self):
        """Always None: this layer exposes `dataset` instead of a custom iterator."""
        return None
dataset_type, dataset_params, batch_size, shuffle=False): + def __init__(self, dataset_type, dataset_params, batch_size, shuffle=False, num_workers=-1, pin_memory=False): super().__init__() self._dataset = dataset_type(**dataset_params) self._batch_size = batch_size self._shuffle = shuffle + self._pin_memory = pin_memory + if num_workers >= 0: + self._num_workers = num_workers def __len__(self): return len(self._dataset) diff --git a/nemo/collections/nlp/nm/losses/__init__.py b/nemo/collections/nlp/nm/losses/__init__.py index ee7b74199e13..357839adb61a 100644 --- a/nemo/collections/nlp/nm/losses/__init__.py +++ b/nemo/collections/nlp/nm/losses/__init__.py @@ -15,5 +15,6 @@ # ============================================================================= from nemo.collections.nlp.nm.losses.masked_xentropy_loss import * +from nemo.collections.nlp.nm.losses.sgd_loss import * from nemo.collections.nlp.nm.losses.smoothed_cross_entropy_loss import * from nemo.collections.nlp.nm.losses.spanning_loss import * diff --git a/nemo/collections/nlp/nm/losses/sgd_loss.py b/nemo/collections/nlp/nm/losses/sgd_loss.py new file mode 100644 index 000000000000..e51912bb7b48 --- /dev/null +++ b/nemo/collections/nlp/nm/losses/sgd_loss.py @@ -0,0 +1,227 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# =============================================================================

'''
This file contains code artifacts adapted from the original implementation:
https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/train_and_predict.py
'''

import torch

from nemo import logging
from nemo.backends.pytorch import LossNM
from nemo.collections.nlp.data.datasets.sgd_dataset.input_example import STATUS_ACTIVE
from nemo.core import ChannelType, LabelsType, LogitsType, NeuralType
from nemo.utils.decorators import add_port_docs

__all__ = ['SGDDialogueStateLossNM']


class SGDDialogueStateLossNM(LossNM):
    """
    Neural module which implements loss for SGD model.

    The total loss combines seven sub-task losses: intent, requested slots,
    categorical slot status, categorical slot value, non-categorical slot
    status, and non-categorical span start / span end.
    """

    @property
    @add_port_docs
    def input_ports(self):
        """Returns definitions of module input ports.
            logit_intent_status (float): Output of SGD model
            intent_status_labels (int): Intent labels
            logit_req_slot_status (float): Output of SGD model
            requested_slot_status (float): Takes value 1 if the corresponding slot is requested, 0 otherwise
            req_slot_mask (bool): Masks requested slots not used for the particular service
            logit_cat_slot_status (float): Output of SGD model
            categorical_slot_status (int): The status of each categorical slot in the service
            cat_slot_status_mask (bool): Masks categorical slots not used for the particular service
            logit_cat_slot_value (float): Output of SGD model
            categorical_slot_values (int): The index of the correct value for each categorical slot
            logit_noncat_slot_status (float): Output of SGD model
            noncategorical_slot_status (int): The status of each noncategorical slot in the service
            noncat_slot_status_mask (bool): Masks noncategorical slots not used for the particular service
            logit_noncat_slot_start (float): Output of SGD model
            logit_noncat_slot_end (float): Output of SGD model
            noncategorical_slot_value_start (int): The index of the starting subword corresponding to the slot span
                for a non-categorical slot value
            noncategorical_slot_value_end (int): The index of the ending (inclusive) subword corresponding to the
                slot span for a non-categorical slot value
        """
        return {
            "logit_intent_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "intent_status_labels": NeuralType(('B'), LabelsType()),
            "logit_req_slot_status": NeuralType(('B', 'T'), LogitsType()),
            "requested_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "req_slot_mask": NeuralType(('B', 'T'), ChannelType()),
            "logit_cat_slot_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "categorical_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "cat_slot_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "logit_cat_slot_value": NeuralType(('B', 'T', 'C'), LogitsType()),
            "categorical_slot_values": NeuralType(('B', 'T'), LabelsType()),
            "logit_noncat_slot_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "noncategorical_slot_status": NeuralType(('B', 'T'), LabelsType()),
            "noncat_slot_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "logit_noncat_slot_start": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_noncat_slot_end": NeuralType(('B', 'T', 'C'), LogitsType()),
            "noncategorical_slot_value_start": NeuralType(('B', 'T'), LabelsType()),
            "noncategorical_slot_value_end": NeuralType(('B', 'T'), LabelsType()),
        }

    @property
    def output_ports(self):
        """
        Returns definitions of module output ports.
            loss:
                NeuralType(None)
        """
        return {"loss": NeuralType(None)}

    def __init__(self, reduction='mean'):
        """
        Args:
            reduction (str): specifies the reduction to apply to the final loss, choose 'mean' or 'sum'.
                Any other value falls back to 'mean' with a warning.
        """
        super().__init__()

        if reduction not in ['mean', 'sum']:
            logging.warning(f'{reduction} reduction is not supported. Setting reduction to "mean"')
            reduction = 'mean'

        self.reduction = reduction
        self._cross_entropy = torch.nn.CrossEntropyLoss(reduction=self.reduction)
        self._criterion_req_slots = torch.nn.BCEWithLogitsLoss(reduction=self.reduction)

    def _masked_cross_entropy(self, logits, labels, mask, empty_warning):
        """Cross entropy restricted to the rows selected by ``mask``.

        Args:
            logits: 2-D tensor, one row of class scores per element.
            labels: 1-D tensor of target class indices, aligned with ``logits`` rows.
            mask: 1-D boolean tensor, True for rows that contribute to the loss.
            empty_warning (str): message logged when ``mask`` selects nothing.

        Returns:
            The cross-entropy loss over the selected rows. When no row is selected, the
            loss of all logits against their own argmax is returned instead, so the
            result stays finite and connected to the graph.
        """
        if not mask.any():
            logging.warning(empty_warning)
            return self._cross_entropy(logits, torch.argmax(logits, dim=-1))
        return self._cross_entropy(logits[mask], labels[mask])

    def _loss_function(
        self,
        logit_intent_status,
        intent_status_labels,
        logit_req_slot_status,
        requested_slot_status,
        req_slot_mask,
        logit_cat_slot_status,
        categorical_slot_status,
        cat_slot_status_mask,
        logit_cat_slot_value,
        categorical_slot_values,
        logit_noncat_slot_status,
        noncategorical_slot_status,
        noncat_slot_status_mask,
        logit_noncat_slot_start,
        logit_noncat_slot_end,
        noncategorical_slot_value_start,
        noncategorical_slot_value_end,
    ):
        """Compute the combined SGD loss; see ``input_ports`` for the meaning of each argument.

        Returns:
            The sum of the seven sub-task losses, divided by the number of sub-tasks when
            ``reduction == 'mean'`` and by the batch size otherwise.
        """
        # Intent loss
        intent_loss = self._cross_entropy(logit_intent_status, intent_status_labels)

        # Requested slots.
        # Shape: (batch_size, max_num_slots)
        # Sigmoid cross entropy is used because more than one slot can be requested in a single utterance.
        # The mask is flattened and made boolean like every other mask below, so that it selects
        # elements of the flattened tensors instead of being interpreted as integer indices.
        req_slot_mask = req_slot_mask.view(-1) > 0.5
        requested_slot_loss = self._criterion_req_slots(
            logit_req_slot_status.view(-1)[req_slot_mask], requested_slot_status.view(-1)[req_slot_mask]
        )

        # Categorical slot status
        # Shape of logit_cat_slot_status: (batch_size, max_num_cat_slots, 3)
        cat_slot_status_loss = self._masked_cross_entropy(
            logit_cat_slot_status.view(-1, 3),
            categorical_slot_status.view(-1),
            cat_slot_status_mask.view(-1) > 0.5,
            'No active categorical slots in the batch',
        )

        # Categorical slot values.
        # Shape: (batch_size, max_num_cat_slots, max_num_slot_values).
        max_num_slot_values = logit_cat_slot_value.size()[-1]
        # Only slots whose status is active contribute to the value loss.
        cat_slot_value_loss = self._masked_cross_entropy(
            logit_cat_slot_value.view(-1, max_num_slot_values),
            categorical_slot_values.view(-1),
            (categorical_slot_status == STATUS_ACTIVE).view(-1),
            'No active values for categorical slots in the batch.',
        )

        # Non-categorical slot status.
        # Shape: (batch_size, max_num_noncat_slots, 3).
        noncat_slot_status_loss = self._masked_cross_entropy(
            logit_noncat_slot_status.view(-1, 3),
            noncategorical_slot_status.view(-1),
            noncat_slot_status_mask.view(-1) > 0.5,
            'No active non-categorical slots in the batch.',
        )

        # Non-categorical slot spans.
        # Shape: (batch_size, max_num_noncat_slots, max_num_tokens).
        max_num_tokens = logit_noncat_slot_start.size()[-1]
        # Only slots whose status is active contribute to the span losses.
        non_cat_slot_value_mask = (noncategorical_slot_status == STATUS_ACTIVE).view(-1)
        if not non_cat_slot_value_mask.any():
            # Warn once here; the helper is then called with the active path for neither span,
            # so the fallback is computed explicitly to avoid logging the same message twice.
            logging.warning('No active values for non-categorical slots in the batch.')
            flat_start_logits = logit_noncat_slot_start.view(-1, max_num_tokens)
            flat_end_logits = logit_noncat_slot_end.view(-1, max_num_tokens)
            span_start_loss = self._cross_entropy(flat_start_logits, torch.argmax(flat_start_logits, dim=-1))
            span_end_loss = self._cross_entropy(flat_end_logits, torch.argmax(flat_end_logits, dim=-1))
        else:
            span_start_loss = self._cross_entropy(
                logit_noncat_slot_start.view(-1, max_num_tokens)[non_cat_slot_value_mask],
                noncategorical_slot_value_start.view(-1)[non_cat_slot_value_mask],
            )
            span_end_loss = self._cross_entropy(
                logit_noncat_slot_end.view(-1, max_num_tokens)[non_cat_slot_value_mask],
                noncategorical_slot_value_end.view(-1)[non_cat_slot_value_mask],
            )

        losses = {
            "intent_loss": intent_loss,
            "requested_slot_loss": requested_slot_loss,
            "cat_slot_status_loss": cat_slot_status_loss,
            "cat_slot_value_loss": cat_slot_value_loss,
            "noncat_slot_status_loss": noncat_slot_status_loss,
            "span_start_loss": span_start_loss,
            "span_end_loss": span_end_loss,
        }

        total_loss = sum(losses.values())
        if self.reduction == 'mean':
            total_loss = total_loss / len(losses)
        else:
            batch_size = logit_intent_status.shape[0]
            total_loss = total_loss / batch_size
        return total_loss
b/nemo/collections/nlp/nm/trainables/common/__init__.py index 7ac5338dfe4a..0061462d13fe 100644 --- a/nemo/collections/nlp/nm/trainables/common/__init__.py +++ b/nemo/collections/nlp/nm/trainables/common/__init__.py @@ -16,8 +16,14 @@ from nemo.collections.nlp.nm.trainables.common.common_utils import * from nemo.collections.nlp.nm.trainables.common.huggingface import * -from nemo.collections.nlp.nm.trainables.common.megatron import * from nemo.collections.nlp.nm.trainables.common.sequence_classification_nm import * from nemo.collections.nlp.nm.trainables.common.sequence_regression_nm import * from nemo.collections.nlp.nm.trainables.common.token_classification_nm import * from nemo.collections.nlp.nm.trainables.common.transformer import * +from nemo.utils import logging + +try: + from nemo.collections.nlp.nm.trainables.common.megatron.megatron_utils import * + +except Exception as e: + logging.error('Failed to import Megatron utils: `{}` ({})'.format(str(e), type(e))) diff --git a/nemo/collections/nlp/nm/trainables/common/common_utils.py b/nemo/collections/nlp/nm/trainables/common/common_utils.py index 318038c50aa4..4964269f65ed 100644 --- a/nemo/collections/nlp/nm/trainables/common/common_utils.py +++ b/nemo/collections/nlp/nm/trainables/common/common_utils.py @@ -18,8 +18,16 @@ from nemo import logging from nemo.collections.nlp.nm.trainables.common.huggingface.huggingface_utils import * -from nemo.collections.nlp.nm.trainables.common.megatron.megatron_bert_nm import MegatronBERT -from nemo.collections.nlp.nm.trainables.common.megatron.megatron_utils import * + +try: + __megatron_utils_satisfied = True + from nemo.collections.nlp.nm.trainables.common.megatron.megatron_bert_nm import MegatronBERT + from nemo.collections.nlp.nm.trainables.common.megatron.megatron_utils import * + +except Exception as e: + logging.error('Failed to import Megatron Neural Module and utils: `{}` ({})'.format(str(e), type(e))) + __megatron_utils_satisfied = False + __all__ = 
['get_pretrained_lm_models_list', 'get_pretrained_lm_model'] @@ -28,7 +36,10 @@ def get_pretrained_lm_models_list(): ''' Returns the list of support pretrained models ''' - return get_megatron_lm_models_list() + get_huggingface_lm_models_list() + if __megatron_utils_satisfied: + return get_megatron_lm_models_list() + get_huggingface_lm_models_list() + else: + return get_huggingface_lm_models_list() def get_pretrained_lm_model(pretrained_model_name, config=None, vocab=None, checkpoint=None): @@ -45,7 +56,7 @@ def get_pretrained_lm_model(pretrained_model_name, config=None, vocab=None, chec ''' if pretrained_model_name in get_huggingface_lm_models_list(): model = get_huggingface_lm_model(bert_config=config, pretrained_model_name=pretrained_model_name) - elif pretrained_model_name in get_megatron_lm_models_list(): + elif __megatron_utils_satisfied and pretrained_model_name in get_megatron_lm_models_list(): if pretrained_model_name == 'megatron-bert-cased' or pretrained_model_name == 'megatron-bert-uncased': if not (config and checkpoint): raise ValueError(f'Config file and pretrained checkpoint required for {pretrained_model_name}') diff --git a/nemo/collections/nlp/nm/trainables/common/megatron/__init__.py b/nemo/collections/nlp/nm/trainables/common/megatron/__init__.py index d82f20067425..34bb64c10941 100644 --- a/nemo/collections/nlp/nm/trainables/common/megatron/__init__.py +++ b/nemo/collections/nlp/nm/trainables/common/megatron/__init__.py @@ -14,4 +14,10 @@ # limitations under the License. 
# ============================================================================= -from nemo.collections.nlp.nm.trainables.common.megatron.megatron_bert_nm import * +from nemo.utils import logging + +try: + from nemo.collections.nlp.nm.trainables.common.megatron.megatron_bert_nm import * + +except Exception as e: + logging.error('Failed to import Megatron Neural Module: `{}` ({})'.format(str(e), type(e))) diff --git a/nemo/collections/nlp/nm/trainables/common/megatron/megatron_utils.py b/nemo/collections/nlp/nm/trainables/common/megatron/megatron_utils.py index 13ec1894eb05..558831967207 100644 --- a/nemo/collections/nlp/nm/trainables/common/megatron/megatron_utils.py +++ b/nemo/collections/nlp/nm/trainables/common/megatron/megatron_utils.py @@ -30,7 +30,7 @@ 'get_megatron_checkpoint', ] -MEGATRON_CACHE = os.path.join(os.path.dirname(TRANSFORMERS_CACHE), 'megatron') +MEGATRON_CACHE = os.path.join(os.path.dirname(str(TRANSFORMERS_CACHE)), 'megatron') CONFIGS = {'345m': {"hidden-size": 1024, "num-attention-heads": 16, "num-layers": 24, "max-seq-length": 512}} diff --git a/nemo/collections/nlp/nm/trainables/common/transformer/transformer_modules.py b/nemo/collections/nlp/nm/trainables/common/transformer/transformer_modules.py index cca8dc6002ba..50e66346e017 100644 --- a/nemo/collections/nlp/nm/trainables/common/transformer/transformer_modules.py +++ b/nemo/collections/nlp/nm/trainables/common/transformer/transformer_modules.py @@ -28,11 +28,12 @@ try: from apex.normalization import FusedLayerNorm -except (AttributeError, ModuleNotFoundError): - # this is lie - it isn't fused in this case - logging.warning( - "Unable to import APEX. Mixed precision, distributed training and FusedLayerNorm are not available." - ) + + # Try to use FusedLayerNorm from Apex - this will trigger an error. + _ = FusedLayerNorm(8, eps=1e-5) + +except Exception as e: + logging.warning("Unable to import FusedLayerNorm from APEX. 
Using regular LayerNorm instead.") from torch.nn import LayerNorm as FusedLayerNorm diff --git a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/__init__.py b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/__init__.py index 0eed31eb4973..05f3cde4c1ce 100644 --- a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/__init__.py +++ b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/__init__.py @@ -13,5 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= - +from nemo.collections.nlp.nm.trainables.dialogue_state_tracking.sgd import * from nemo.collections.nlp.nm.trainables.dialogue_state_tracking.trade_generator_nm import * diff --git a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/__init__.py b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/__init__.py new file mode 100644 index 000000000000..7f7de4a67ec5 --- /dev/null +++ b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/__init__.py @@ -0,0 +1,18 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from nemo.collections.nlp.nm.trainables.dialogue_state_tracking.sgd.sgd_decoder_nm import * +from nemo.collections.nlp.nm.trainables.dialogue_state_tracking.sgd.sgd_encoder_nm import * diff --git a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/sgd_decoder_nm.py b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/sgd_decoder_nm.py new file mode 100644 index 000000000000..30f1d5c8758c --- /dev/null +++ b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/sgd/sgd_decoder_nm.py @@ -0,0 +1,404 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# Copyright 2019 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# =============================================================================

'''
This file contains code artifacts adapted from the original implementation:
https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/train_and_predict.py
'''

import math
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from nemo.backends.pytorch.nm import TrainableNM
from nemo.core import ChannelType, EmbeddedTextType, LogitsType, NeuralType
from nemo.utils.decorators import add_port_docs

__all__ = ['SGDDecoderNM']


class LogitsAttention(nn.Module):
    def __init__(self, num_classes, embedding_dim, num_attention_heads=16):
        """Get logits for elements by using attention on token embedding.
        Args:
            num_classes (int): An int containing the number of classes for which logits are to be generated.
            embedding_dim (int): hidden size of the BERT
            num_attention_heads (int): number of attention heads; must divide embedding_dim evenly
        """
        super().__init__()
        if embedding_dim % num_attention_heads != 0:
            raise ValueError(
                f'embedding_dim ({embedding_dim}) must be divisible by '
                f'num_attention_heads ({num_attention_heads})'
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = embedding_dim // self.num_attention_heads
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.dropout = nn.Dropout(0.1)

        self.key = nn.Linear(embedding_dim, embedding_dim)
        self.query = nn.Linear(embedding_dim, embedding_dim)
        self.value = nn.Linear(embedding_dim, embedding_dim)
        self.layer = nn.Linear(embedding_dim, num_classes)

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move the head axis in front of the sequence axis."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, encoded_utterance, token_embeddings, element_embeddings):
        """
        Args:
            encoded_utterance - [CLS] token hidden state from BERT encoding of the utterance
                (unused here; accepted so that all head modules share one call signature)
            token_embeddings - token hidden states from BERT encoding of the utterance
            element_embeddings - A tensor of shape (batch_size, num_elements, embedding_dim).

        Returns:
            A tensor of shape (batch_size, num_elements, num_classes) containing the logits.
        """
        # Elements (intents / slots / values) are the queries; utterance tokens are keys and values.
        query_layer = self.transpose_for_scores(self.query(element_embeddings))
        key_layer = self.transpose_for_scores(self.key(token_embeddings))
        value_layer = self.transpose_for_scores(self.value(token_embeddings))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Dot products are taken over attention_head_size features per head, so that is the
        # dimension the scaled dot-product normalisation must use (not the full embedding_dim).
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.dropout(F.softmax(attention_scores, dim=-1))

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back into a single embedding_dim-sized feature axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.embedding_dim,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return self.layer(context_layer)


class Logits(nn.Module):
    def __init__(self, num_classes, embedding_dim):
        """Get logits for elements by conditioning on utterance embedding.
        Args:
            num_classes (int): An int containing the number of classes for which logits are to be generated.
            embedding_dim (int): hidden size of the BERT
        """
        super().__init__()
        self.num_classes = num_classes
        self.utterance_proj = nn.Linear(embedding_dim, embedding_dim)
        self.activation = F.gelu

        self.layer1 = nn.Linear(2 * embedding_dim, embedding_dim)
        self.layer2 = nn.Linear(embedding_dim, num_classes)

    def forward(self, encoded_utterance, token_embeddings, element_embeddings):
        """
        Args:
            encoded_utterance - [CLS] token hidden state from BERT encoding of the utterance
            token_embeddings - token hidden states from BERT encoding of the utterance
                (unused here; accepted so that all head modules share one call signature)
            element_embeddings - A tensor of shape (batch_size, num_elements, embedding_dim).

        Returns:
            A tensor of shape (batch_size, num_elements, num_classes) containing the logits.
        """
        _, num_elements, _ = element_embeddings.size()

        # Project the utterance embeddings.
        utterance_embedding = self.activation(self.utterance_proj(encoded_utterance))

        # Pair the (shared) utterance embedding with every element embedding.
        # expand() creates a view; torch.cat materialises the result.
        repeated_utterance_embedding = utterance_embedding.unsqueeze(1).expand(-1, num_elements, -1)
        utterance_element_emb = torch.cat([repeated_utterance_embedding, element_embeddings], dim=2)

        logits = self.activation(self.layer1(utterance_element_emb))
        return self.layer2(logits)


class SGDDecoderNM(TrainableNM):
    """
    Baseline model for schema guided dialogue state tracking with option to make schema embeddings learnable
    """

    @property
    @add_port_docs()
    def input_ports(self):
        """Returns definitions of module input ports.
            encoded_utterance (float): [CLS] token hidden state from BERT encoding of the utterance
            token_embeddings (float): BERT encoding of utterance (all tokens)
            utterance_mask (bool): Mask which takes the value 0 for padded tokens and 1 otherwise
            cat_slot_values_mask (int): Masks out categorical slots values for slots not used in the service,
                takes values 0 and 1
            intent_status_mask (int): Masks out padded intents in the service, takes values 0 and 1
            service_ids (int): service ids
        """
        return {
            "encoded_utterance": NeuralType(('B', 'T'), EmbeddedTextType()),
            "token_embeddings": NeuralType(('B', 'T', 'C'), ChannelType()),
            "utterance_mask": NeuralType(('B', 'T'), ChannelType()),
            "cat_slot_values_mask": NeuralType(('B', 'T', 'C'), ChannelType()),
            "intent_status_mask": NeuralType(('B', 'T'), ChannelType()),
            "service_ids": NeuralType(('B'), ChannelType()),
        }

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.
            logit_intent_status (float): output for intent status
            logit_req_slot_status (float): output for requested slots status
            logit_cat_slot_status (float): output for categorical slots status
            logit_cat_slot_value (float): output for categorical slots values
            logit_noncat_slot_status (float): output for non categorical slots status
            logit_noncat_slot_start (float): output for non categorical slots values start
            logit_noncat_slot_end (float): output for non categorical slots values end
        """
        return {
            "logit_intent_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_req_slot_status": NeuralType(('B', 'T'), LogitsType()),
            "logit_cat_slot_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_cat_slot_value": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_noncat_slot_status": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_noncat_slot_start": NeuralType(('B', 'T', 'C'), LogitsType()),
            "logit_noncat_slot_end": NeuralType(('B', 'T', 'C'), LogitsType()),
        }

    def __init__(self, embedding_dim, schema_emb_processor, head_transform):
        """Build the prediction heads and the per-service schema embedding tables.

        Args:
            embedding_dim (int): hidden size of the BERT
            schema_emb_processor (obj): contains schema embeddings for services and config file
            head_transform (str): name of the head class defined in this module to use for
                computing logits, e.g. "Logits" or "LogitsAttention"

        Raises:
            ValueError: if head_transform does not name a torch module class in this file.
        """
        super().__init__()

        # Resolve the head class once and fail early with a readable message.
        head_cls = getattr(sys.modules[__name__], head_transform, None)
        if not (isinstance(head_cls, type) and issubclass(head_cls, nn.Module)):
            raise ValueError(
                f'Unknown head_transform "{head_transform}": expected the name of a head class '
                f'defined in this module, e.g. "Logits" or "LogitsAttention"'
            )

        # Add a trainable vector for the NONE intent
        none_intent_vector = torch.empty((1, 1, embedding_dim), device=self._device)
        # TODO truncated norm init
        nn.init.normal_(none_intent_vector, std=0.02)
        self.none_intent_vector = nn.Parameter(none_intent_vector)

        self.intent_layer = head_cls(1, embedding_dim)
        self.requested_slots_layer = head_cls(1, embedding_dim)
        self.cat_slot_value_layer = head_cls(1, embedding_dim)

        # Slot status values: none, dontcare, active.
        self.cat_slot_status_layer = head_cls(3, embedding_dim)
        self.noncat_slot_layer = head_cls(3, embedding_dim)

        # dim 2 for non_categorical slot - to represent start and end position
        self.noncat_layer1 = nn.Linear(2 * embedding_dim, embedding_dim)
        self.noncat_activation = F.gelu
        self.noncat_layer2 = nn.Linear(embedding_dim, 2)

        # One row per service; each row is the flattened stack of that service's schema embeddings.
        config = schema_emb_processor.schema_config
        num_services = len(schema_emb_processor.schemas.services)
        self.intents_emb = nn.Embedding(num_services, config["MAX_NUM_INTENT"] * embedding_dim)
        self.cat_slot_emb = nn.Embedding(num_services, config["MAX_NUM_CAT_SLOT"] * embedding_dim)
        self.cat_slot_value_emb = nn.Embedding(
            num_services, config["MAX_NUM_CAT_SLOT"] * config["MAX_NUM_VALUE_PER_CAT_SLOT"] * embedding_dim
        )
        self.noncat_slot_emb = nn.Embedding(num_services, config["MAX_NUM_NONCAT_SLOT"] * embedding_dim)
        self.req_slot_emb = nn.Embedding(
            num_services, (config["MAX_NUM_CAT_SLOT"] + config["MAX_NUM_NONCAT_SLOT"]) * embedding_dim
        )

        # initialize schema embeddings from the BERT generated embeddings
        schema_embeddings = schema_emb_processor.get_schema_embeddings()
        freeze_schema_emb = not schema_emb_processor.is_trainable
        emb_tables = (
            (self.intents_emb, 'intent_emb'),
            (self.cat_slot_emb, 'cat_slot_emb'),
            (self.cat_slot_value_emb, 'cat_slot_value_emb'),
            (self.noncat_slot_emb, 'noncat_slot_emb'),
            (self.req_slot_emb, 'req_slot_emb'),
        )
        for emb_table, emb_key in emb_tables:
            pretrained = np.stack(schema_embeddings[emb_key]).reshape(num_services, -1)
            emb_table.weight.data.copy_(torch.from_numpy(pretrained))
            if freeze_schema_emb:
                emb_table.weight.requires_grad = False

        # Moves every submodule and parameter registered above in one go.
        self.to(self._device)

    def forward(
        self,
        encoded_utterance,
        token_embeddings,
        utterance_mask,
        cat_slot_values_mask,
        service_ids,
        intent_status_mask,
    ):
        """Look up the schema embeddings of each example's service and compute all logits.

        Returns:
            Tuple of (logit_intent_status, logit_req_slot_status, logit_cat_slot_status,
            logit_cat_slot_value, logit_noncat_slot_status, logit_noncat_slot_start,
            logit_noncat_slot_end).
        """
        batch_size, emb_dim = encoded_utterance.size()
        # Un-flatten the per-service rows back into (batch, num_elements, emb_dim).
        intent_embeddings = self.intents_emb(service_ids).view(batch_size, -1, emb_dim)
        cat_slot_emb = self.cat_slot_emb(service_ids).view(batch_size, -1, emb_dim)
        max_number_cat_slots = cat_slot_emb.shape[1]
        cat_slot_value_emb = self.cat_slot_value_emb(service_ids).view(batch_size, max_number_cat_slots, -1, emb_dim)
        noncat_slot_emb = self.noncat_slot_emb(service_ids).view(batch_size, -1, emb_dim)
        req_slot_emb = self.req_slot_emb(service_ids).view(batch_size, -1, emb_dim)

        logit_intent_status = self._get_intents(
            encoded_utterance, intent_embeddings, intent_status_mask, token_embeddings
        )

        logit_req_slot_status = self._get_requested_slots(encoded_utterance, req_slot_emb, token_embeddings)

        logit_cat_slot_status, logit_cat_slot_value = self._get_categorical_slot_goals(
            encoded_utterance, cat_slot_emb, cat_slot_value_emb, cat_slot_values_mask, token_embeddings
        )

        (
            logit_noncat_slot_status,
            logit_noncat_slot_start,
            logit_noncat_slot_end,
        ) = self._get_noncategorical_slot_goals(encoded_utterance, utterance_mask, noncat_slot_emb, token_embeddings)

        return (
            logit_intent_status,
            logit_req_slot_status,
            logit_cat_slot_status,
            logit_cat_slot_value,
            logit_noncat_slot_status,
            logit_noncat_slot_start,
            logit_noncat_slot_end,
        )

    def _get_intents(self, encoded_utterance, intent_embeddings, intent_status_mask, token_embeddings):
        """
        Obtain logits for intents.

        Args:
            encoded_utterance - representation of utterance
            intent_embeddings - BERT schema embeddings
            intent_status_mask - masks out intent not used for the service
            token_embeddings - token hidden states from BERT encoding of the utterance
        """
        batch_size = intent_embeddings.size()[0]

        # Add a trainable vector for the NONE intent.
        repeated_none_intent_vector = self.none_intent_vector.expand(batch_size, -1, -1)
        intent_embeddings = torch.cat([repeated_none_intent_vector, intent_embeddings], dim=1)
        logits = self.intent_layer(
            encoded_utterance=encoded_utterance,
            token_embeddings=token_embeddings,
            element_embeddings=intent_embeddings,
        )
        logits = logits.squeeze(dim=-1)  # Shape: (batch_size, max_intents + 1)

        # Mask out logits for padded intents
        negative_logits = self._get_negative_logits(logits)
        return torch.where(intent_status_mask.to(dtype=torch.bool), logits, negative_logits)

    def _get_requested_slots(self, encoded_utterance, requested_slot_emb, token_embeddings):
        """Obtain logits for requested slots."""

        logits = self.requested_slots_layer(
            encoded_utterance=encoded_utterance,
            token_embeddings=token_embeddings,
            element_embeddings=requested_slot_emb,
        )

        # Drop only the trailing class axis of size 1. Squeezing a second time would also
        # remove the slot axis whenever there is exactly one slot.
        # logits shape: (batch_size, max_num_slots)
        return logits.squeeze(dim=-1)

    def _get_categorical_slot_goals(
        self, encoded_utterance, cat_slot_emb, cat_slot_value_emb, cat_slot_values_mask, token_embeddings
    ):
        """
        Obtain logits for status and values for categorical slots
        Slot status values: none, dontcare, active
        """

        # Predict the status of all categorical slots.
        status_logits = self.cat_slot_status_layer(
            encoded_utterance=encoded_utterance, token_embeddings=token_embeddings, element_embeddings=cat_slot_emb
        )

        # Predict the goal value.
        # Shape: (batch_size, max_categorical_slots, max_categorical_values, embedding_dim).
        _, max_num_slots, max_num_values, embedding_dim = cat_slot_value_emb.size()
        # Flatten (slot, value) pairs into one element axis so a single head call scores them all.
        cat_slot_value_emb_reshaped = cat_slot_value_emb.view(-1, max_num_slots * max_num_values, embedding_dim)

        value_logits = self.cat_slot_value_layer(
            encoded_utterance=encoded_utterance,
            token_embeddings=token_embeddings,
            element_embeddings=cat_slot_value_emb_reshaped,
        )

        # Reshape to obtain the logits for all slots.
        value_logits = value_logits.view(-1, max_num_slots, max_num_values)

        # Mask out logits for padded slots and values because they will be softmaxed
        negative_value_logits = self._get_negative_logits(value_logits)
        value_logits = torch.where(cat_slot_values_mask.to(dtype=torch.bool), value_logits, negative_value_logits)
        return status_logits, value_logits

    def _get_noncategorical_slot_goals(self, encoded_utterance, utterance_mask, noncat_slot_emb, token_embeddings):
        """
        Obtain logits for status and slot spans for non-categorical slots.
        Slot status values: none, dontcare, active
        """
        # Predict the status of all non-categorical slots.
        max_num_slots = noncat_slot_emb.size()[1]
        status_logits = self.noncat_slot_layer(
            encoded_utterance=encoded_utterance, token_embeddings=token_embeddings, element_embeddings=noncat_slot_emb
        )

        # Predict the distribution for span indices.
        max_num_tokens = token_embeddings.size()[1]

        # Pair every slot with every token; expand() avoids copying before the concatenation.
        repeated_token_embeddings = token_embeddings.unsqueeze(1).expand(-1, max_num_slots, -1, -1)
        repeated_slot_embeddings = noncat_slot_emb.unsqueeze(2).expand(-1, -1, max_num_tokens, -1)

        # Shape: (batch_size, max_num_slots, max_num_tokens, 2 * embedding_dim).
        slot_token_embeddings = torch.cat([repeated_slot_embeddings, repeated_token_embeddings], dim=3)

        # Project the combined embeddings to obtain logits, Shape: (batch_size, max_num_slots, max_num_tokens, 2)
        span_logits = self.noncat_layer1(slot_token_embeddings)
        span_logits = self.noncat_activation(span_logits)
        span_logits = self.noncat_layer2(span_logits)

        # Mask out invalid logits for padded tokens.
        # (batch_size, max_num_tokens) -> (batch_size, 1, max_num_tokens, 1); torch.where broadcasts
        # the mask over the slot axis and over the start/end axis.
        token_mask = utterance_mask.to(dtype=torch.bool).unsqueeze(1).unsqueeze(3)
        span_logits = torch.where(token_mask, span_logits, self._get_negative_logits(span_logits))

        # Shape of both tensors: (batch_size, max_num_slots, max_num_tokens).
        span_start_logits, span_end_logits = torch.unbind(span_logits, dim=3)
        return status_logits, span_start_logits, span_end_logits

    def _get_negative_logits(self, logits):
        """Return a tensor shaped like ``logits`` filled with a large negative value.

        It is used to mask out unused values for a particular service. The fill value is
        0.7 of the most negative finite number of the logits' dtype, and the tensor is
        created on the same device and with the same dtype as ``logits``.
        """
        return torch.full_like(logits, torch.finfo(logits.dtype).max * -0.7)
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +''' +This file contains code artifacts adapted from the original implementation: +https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/train_and_predict.py +''' + +from torch import nn + +from nemo.backends.pytorch.nm import TrainableNM +from nemo.collections.nlp.utils.transformer_utils import transformer_weights_init +from nemo.core import ChannelType, EmbeddedTextType, NeuralType +from nemo.utils.decorators import add_port_docs + +__all__ = ['SGDEncoderNM'] + +ACTIVATIONS_F = { + "tanh": nn.Tanh, + "relu": nn.ReLU, +} + + +class SGDEncoderNM(TrainableNM): + """ + Neural module which extracts the first token from the BERT representation of the utterance + followed by a fully connected layer. + + Args: + hidden_size (int): hidden size of the BERT model + activation (str): activation function applied + dropout (float): dropout ratio + """ + + @property + @add_port_docs + def input_ports(self): + """ + Returns definitions of module input ports. + hidden_states (float): BERT representation of the utterance + """ + return {"hidden_states": NeuralType(('B', 'T', 'C'), ChannelType())} + + @property + @add_port_docs + def output_ports(self): + """Returns definitions of module output ports. 
+ logits (float): First token of the BERT representation of the utterance followed by fc and dropout + hidden_states (float) : BERT representation of the utterance with applied dropout + """ + return { + "logits": NeuralType(('B', 'T'), EmbeddedTextType()), + "hidden_states": NeuralType(('B', 'T', 'C'), ChannelType()), + } + + def __init__(self, hidden_size, activation='tanh', dropout=0.0, use_transformer_pretrained=True): + super().__init__() + self.fc = nn.Linear(hidden_size, hidden_size).to(self._device) + + if activation not in ACTIVATIONS_F: + raise ValueError(f'{activation} is not in supported ' + '{ACTIVATIONS_F.keys()}') + + self.activation = ACTIVATIONS_F[activation]() + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + if use_transformer_pretrained: + self.apply(lambda module: transformer_weights_init(module, xavier=False)) + # self.to(self._device) # sometimes this is necessary + + def forward(self, hidden_states): + first_token_hidden_states = hidden_states[:, 0] + logits = self.fc(first_token_hidden_states) + logits = self.activation(logits) + logits = self.dropout1(logits) + return logits, self.dropout2(hidden_states) diff --git a/nemo/collections/nlp/utils/data_utils.py b/nemo/collections/nlp/utils/data_utils.py index d57c782fedca..1c3dcf5f0db8 100644 --- a/nemo/collections/nlp/utils/data_utils.py +++ b/nemo/collections/nlp/utils/data_utils.py @@ -17,7 +17,9 @@ import re import string -__all__ = ['get_vocab', 'get_tokens', 'normalize_answer', 'mask_padded_tokens'] +import numpy as np + +__all__ = ['get_vocab', 'get_tokens', 'normalize_answer', 'mask_padded_tokens', 'concatenate'] def get_vocab(file): @@ -55,3 +57,10 @@ def get_tokens(s): def mask_padded_tokens(tokens, pad_id): mask = tokens != pad_id return mask + + +def concatenate(lists): + """ + Helper function for inference + """ + return np.concatenate([t.cpu() for t in lists]) diff --git a/nemo/collections/tts/data_layers.py b/nemo/collections/tts/data_layers.py 
index d57da99187e3..6d29b4504cc9 100644 --- a/nemo/collections/tts/data_layers.py +++ b/nemo/collections/tts/data_layers.py @@ -1,15 +1,13 @@ # Copyright (c) 2019 NVIDIA Corporation import torch -import nemo from .parts.datasets import AudioOnlyDataset from nemo.backends.pytorch.nm import DataLayerNM from nemo.core import DeviceType from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.utils import logging from nemo.utils.decorators import add_port_docs -logging = nemo.logging - class AudioDataLayer(DataLayerNM): """ diff --git a/nemo/collections/tts/parts/helpers.py b/nemo/collections/tts/parts/helpers.py index fb5935ba8d84..0819b6398b45 100644 --- a/nemo/collections/tts/parts/helpers.py +++ b/nemo/collections/tts/parts/helpers.py @@ -4,9 +4,7 @@ import numpy as np import torch -import nemo - -logging = nemo.logging +from nemo.utils import logging __all__ = [ "waveglow_log_to_tb_func", diff --git a/nemo/collections/tts/parts/tacotron2.py b/nemo/collections/tts/parts/tacotron2.py index a201c1cfdbe3..925251f19f44 100644 --- a/nemo/collections/tts/parts/tacotron2.py +++ b/nemo/collections/tts/parts/tacotron2.py @@ -1,15 +1,12 @@ # Copyright (c) 2019 NVIDIA Corporation -from math import sqrt import torch from torch import nn from torch.autograd import Variable from torch.nn import functional as F -import nemo from nemo.collections.tts.parts.layers import ConvNorm, LinearNorm, get_mask_from_lengths - -logging = nemo.logging +from nemo.utils import logging class LocationLayer(nn.Module): diff --git a/nemo/collections/tts/tacotron2_modules.py b/nemo/collections/tts/tacotron2_modules.py index 2e4b235b1be1..5485728cd015 100644 --- a/nemo/collections/tts/tacotron2_modules.py +++ b/nemo/collections/tts/tacotron2_modules.py @@ -279,6 +279,42 @@ class Tacotron2DecoderInfer(Tacotron2Decoder): Defaults to 31. 
""" + def __init__( + self, + n_mel_channels: int, + n_frames_per_step: int = 1, + encoder_embedding_dim: int = 512, + gate_threshold: float = 0.5, + prenet_dim: int = 256, + max_decoder_steps: int = 1000, + decoder_rnn_dim: int = 1024, + p_decoder_dropout: float = 0.1, + p_attention_dropout: float = 0.1, + attention_rnn_dim: int = 1024, + attention_dim: int = 128, + attention_location_n_filters: int = 32, + attention_location_kernel_size: int = 31, + prenet_p_dropout: float = 0.5, + force: bool = False, + ): + super().__init__( + n_mel_channels=n_mel_channels, + n_frames_per_step=n_frames_per_step, + encoder_embedding_dim=encoder_embedding_dim, + gate_threshold=gate_threshold, + prenet_dim=prenet_dim, + max_decoder_steps=max_decoder_steps, + decoder_rnn_dim=decoder_rnn_dim, + p_decoder_dropout=p_decoder_dropout, + p_attention_dropout=p_attention_dropout, + attention_rnn_dim=attention_rnn_dim, + attention_dim=attention_dim, + attention_location_n_filters=attention_location_n_filters, + attention_location_kernel_size=attention_location_kernel_size, + prenet_p_dropout=prenet_p_dropout, + force=force, + ) + @property @add_port_docs() def input_ports(self): @@ -483,6 +519,9 @@ class MakeGate(NonTrainableNM): """MakeGate is a helper Neural Module that makes the target stop value. 
""" + def __init__(self): + super().__init__() + @property @add_port_docs() def input_ports(self): diff --git a/nemo/constants.py b/nemo/constants.py index 6cd3a1f60ff8..9d6793d7630a 100644 --- a/nemo/constants.py +++ b/nemo/constants.py @@ -47,4 +47,5 @@ # NEMO_ENV_VARNAME_DEBUG_VERBOSITY = "NEMO_DEBUG_VERBOSITY" NEMO_ENV_VARNAME_ENABLE_COLORING = "NEMO_ENABLE_COLORING" NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR = "NEMO_REDIRECT_LOGS_TO_STDERR" +NEMO_ENV_VARNAME_TESTING = "NEMO_TESTING" # NEMO_ENV_VARNAME_SAVE_LOGS_TO_DIR = "NEMO_SAVE_LOGS_TO_DIR" diff --git a/nemo/core/callbacks.py b/nemo/core/callbacks.py index e465bf5bf95a..12e39a99c25f 100644 --- a/nemo/core/callbacks.py +++ b/nemo/core/callbacks.py @@ -25,7 +25,7 @@ from collections import namedtuple import nemo -from nemo.utils import get_checkpoint_from_dir +from nemo.utils import get_checkpoint_from_dir, logging try: import wandb @@ -34,8 +34,6 @@ except (ImportError, ModuleNotFoundError): _WANDB_AVAILABLE = False -logging = nemo.logging - class ActionCallback(ABC): """Abstract interface for callbacks. 
diff --git a/nemo/core/neural_factory.py b/nemo/core/neural_factory.py index b5c1930d0e06..59b266965669 100644 --- a/nemo/core/neural_factory.py +++ b/nemo/core/neural_factory.py @@ -36,10 +36,9 @@ from ..utils import ExpManager from .callbacks import ActionCallback, EvaluatorCallback from .neural_types import * +from nemo.utils import logging from nemo.utils.decorators import deprecated -logging = nemo.logging - class DeploymentFormat(Enum): """Which format to use when exporting a Neural Module for deployment""" diff --git a/nemo/core/neural_graph.py b/nemo/core/neural_graph.py index 917af83b117c..4e8e486eb1be 100644 --- a/nemo/core/neural_graph.py +++ b/nemo/core/neural_graph.py @@ -477,7 +477,7 @@ def export_to_config(self, config_file: str): YAML.dump(to_export, outfile) logging.info( - "Configuration of graph `{}` ({}) exported to {}".format(self.name, type(self).__name__, abs_path_file) + "Configuration of graph `{}` ({}) exported to '{}'".format(self.name, type(self).__name__, abs_path_file) ) def serialize(self) -> Dict[str, Any]: @@ -874,14 +874,21 @@ def summary(self) -> str: A nice, full graph summary. """ # Line "decorator". - desc = "\n" + 120 * '=' + "\n" + desc = "\n" + 113 * '=' + "\n" # 1. general information. - desc += "The `{}` Neural Graph:\n".format(self.name) + desc += "The `{}` Neural Graph [{}]".format(self.name, self.operation_mode) + if self.is_complete(): + desc += " [COMPLETE]:\n" + else: + desc += " [INCOMPLETE]:\n" # 2. modules. desc += " * Modules ({}):\n".format(len(self._modules)) for key, module in self._modules.items(): - desc += " * `{}` ({})\n".format(key, type(module).__name__) + if module.type == ModuleType.trainable and module.is_frozen(): + desc += " * `{}` ({}) [FROZEN]\n".format(key, type(module).__name__) + else: + desc += " * `{}` ({})\n".format(key, type(module).__name__) # 3. steps. 
desc += " * Steps ({}):\n".format(len(self._steps)) @@ -912,7 +919,7 @@ def summary(self) -> str: for output in outputs["mappings"]: desc += " * {}\n".format(output) # Line "decorator". - desc += 120 * '=' + desc += 113 * '=' # Return the result. return desc @@ -1030,9 +1037,10 @@ def restore_from(self, filename: str, module_names: Optional[List[str]] = None): try: # Get module. module = self._modules[name] - # Restore module weights - set_state_dict(module, chkpt["modules"][name]) - log_str += " * Module '{}' ({}) params loaded\n".format(module.name, type(module).__name__) + if module.type == ModuleType.trainable: + # Restore module weights + set_state_dict(module, chkpt["modules"][name]) + log_str += " * Module '{}' ({}) params loaded\n".format(module.name, type(module).__name__) except KeyError: log_str += " ! Module '{}' params not found in checkpoint\n".format(name) warning = True @@ -1042,3 +1050,32 @@ def restore_from(self, filename: str, module_names: Optional[List[str]] = None): logging.warning(log_str) else: logging.info(log_str) + + def is_complete(self) -> bool: + """ + Method checks if graph is "complete". In here the "complete" means that the graph has: + * exactly one DataLayer + * zero bound input ports + + In short it means that the graph can be complete. + + Returns: + True or false. + """ + has_datalayer = False + # Iterate through the modules one by one. + for module in self._modules.values(): + # Get module. + if module.type == ModuleType.datalayer: + if has_datalayer: + # More than one DL is not acceptable. + return False + else: + has_datalayer = True + + # Now check the ports. + if len(self._inputs) != 0: + return False + + # Else: + return True diff --git a/nemo/core/neural_modules.py b/nemo/core/neural_modules.py index 163db3ea3513..decb2a0acd35 100644 --- a/nemo/core/neural_modules.py +++ b/nemo/core/neural_modules.py @@ -121,10 +121,10 @@ def __extract_init_params(self) -> Dict[str, Any]: # Get the frame "call context". 
for frame in stack()[1:]: - # Get the call arguments. + # Get the current call arguments. localvars = getargvalues(frame[0]) - # Fill the parameters with call_args. + # Fill the parameters with call arguments. for key in to_set_params: if key in localvars.args: init_params[key] = localvars.locals[key] @@ -142,7 +142,7 @@ def __extract_init_params(self) -> Dict[str, Any]: if len(to_set_params) != 0: raise ValueError( "Could not collect all the signature params! " - F"Please file a bug on GitHub with the current stacktrace so that it can be resolved." + f"Please file a bug on GitHub with the current stack trace so that it can be reproduced." ) # print("! init_params of {}: {}\n".format(type(self).__name__, init_params)) @@ -228,7 +228,7 @@ def export_to_config(self, config_file: str): YAML.dump(to_export, outfile) logging.info( - "Configuration of module `{}` ({}) exported to {}".format(self.name, type(self).__name__, abs_path_file) + "Configuration of module `{}` ({}) exported to '{}'".format(self.name, type(self).__name__, abs_path_file) ) def serialize(self) -> Dict[str, Any]: diff --git a/nemo/utils/decorators/deprecated.py b/nemo/utils/decorators/deprecated.py index 80e330c4be56..d738c8a18031 100644 --- a/nemo/utils/decorators/deprecated.py +++ b/nemo/utils/decorators/deprecated.py @@ -22,7 +22,7 @@ from nemo.utils import logging -# logging = nemo.logging +# from nemo.utils import logging # Remember which deprecation warnings have been printed already. 
_PRINTED_WARNING = {} diff --git a/nemo/utils/exp_logging.py b/nemo/utils/exp_logging.py index 7f9a3ad9b0fa..fd3a0540ffe2 100644 --- a/nemo/utils/exp_logging.py +++ b/nemo/utils/exp_logging.py @@ -9,11 +9,11 @@ from nemo.utils.decorators import deprecated -# logging = nemo.logging +# from nemo.utils import logging @deprecated( version=0.11, explanation=( - "Please use nemo.logging instead by using logging = nemo.logging and logging.info(), " + "Please use nemo.logging instead by using from nemo.utils import logging and logging.info(), " "logging.warning() , etc." ), ) diff --git a/nemo/utils/formatters/base.py b/nemo/utils/formatters/base.py index 6b844877b185..12500477b9c8 100644 --- a/nemo/utils/formatters/base.py +++ b/nemo/utils/formatters/base.py @@ -126,3 +126,9 @@ def format(self, record): class BaseNeMoFormatter(BaseFormatter): DEFAULT_FORMAT = "%(color)s[NeMo %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" + + +class DebugNeMoFormatter(BaseFormatter): + DEFAULT_FORMAT = ( + "%(color)s[NeMo %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d rank:%(rank)d]%(end_color)s %(message)s" + ) diff --git a/nemo/utils/helpers.py b/nemo/utils/helpers.py index aa2cca686aea..b21b7200b58f 100644 --- a/nemo/utils/helpers.py +++ b/nemo/utils/helpers.py @@ -13,7 +13,7 @@ import nemo from nemo.utils import logging -# logging = nemo.logging +# from nemo.utils import logging def rgetattr(obj, attr, *args): diff --git a/nemo/utils/nemo_logging.py b/nemo/utils/nemo_logging.py index 1551acf84839..8a2bd06040d6 100644 --- a/nemo/utils/nemo_logging.py +++ b/nemo/utils/nemo_logging.py @@ -20,9 +20,9 @@ from contextlib import contextmanager # from nemo.constants import NEMO_ENV_VARNAME_SAVE_LOGS_TO_DIR -from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR +from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, NEMO_ENV_VARNAME_TESTING from nemo.utils.env_var_parsing import get_envbool, get_envint -from 
nemo.utils.formatters.base import BaseNeMoFormatter +from nemo.utils.formatters.base import BaseNeMoFormatter, DebugNeMoFormatter from nemo.utils.metaclasses import Singleton __all__ = ["Logger", "LogMode"] @@ -88,7 +88,17 @@ def _define_logger(self): self._logger = _logging.getLogger("nemo_logger") # By default, silence all loggers except the logger for rank 0 self.remove_stream_handlers() - if get_envint("RANK", 0) == 0: + if get_envbool(NEMO_ENV_VARNAME_TESTING, False): + old_factory = _logging.getLogRecordFactory() + + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.rank = get_envint("RANK", 0) + return record + + _logging.setLogRecordFactory(record_factory) + self.add_stream_handlers(formatter=DebugNeMoFormatter) + elif get_envint("RANK", 0) == 0: self.add_stream_handlers() finally: @@ -112,7 +122,7 @@ def remove_stream_handlers(self): except KeyError: pass - def add_stream_handlers(self): + def add_stream_handlers(self, formatter=BaseNeMoFormatter): if self._logger is None: raise RuntimeError("Impossible to set handlers if the Logger is not predefined") @@ -127,8 +137,6 @@ def add_stream_handlers(self): self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr) self._handlers["stream_stderr"].addFilter(lambda record: record.levelno > _logging.INFO) - formatter = BaseNeMoFormatter - self._handlers["stream_stdout"].setFormatter(formatter()) self._logger.addHandler(self._handlers["stream_stdout"]) @@ -138,9 +146,9 @@ def add_stream_handlers(self): except KeyError: pass - def reset_stream_handler(self): + def reset_stream_handler(self, formatter=BaseNeMoFormatter): self.remove_stream_handlers() - self.add_stream_handlers() + self.add_stream_handlers(formatter=formatter) def add_file_handler(self, log_file): if self._logger is None: diff --git a/nemo/utils/neural_graph/graph_outputs.py b/nemo/utils/neural_graph/graph_outputs.py index 6c494d986b7f..6f14c6848cb8 100644 --- a/nemo/utils/neural_graph/graph_outputs.py +++ 
b/nemo/utils/neural_graph/graph_outputs.py @@ -75,12 +75,12 @@ def __init__(self, tensors_ref): # Tensors[step][output_port_name] passed from the external neural graph object. self._tensors_ref = tensors_ref - # This dictionary stores the output tensors collected during the "default" tensor recording. + # This dictionary stores the bound outputs collected during the "default" recording of produced tensors. # As they are using the default port names, the second/next tensor published on the same port # will generate a new unique name following the (step_number.module.port_name) pattern. self._default_outputs = {} - # This dictionary stores list of output tensors of module "manually" indicated by the user. + # This dictionary stores list of outputs of modules "manually" bound by the user. # In this case tring to overwriting the existing ports with new tensors will be forbidden (Exception). self._manual_outputs = {} diff --git a/nemo/utils/neural_graph/neural_graph_manager.py b/nemo/utils/neural_graph/neural_graph_manager.py index b016b57dc3b8..b8b2e1deeb1f 100644 --- a/nemo/utils/neural_graph/neural_graph_manager.py +++ b/nemo/utils/neural_graph/neural_graph_manager.py @@ -45,11 +45,14 @@ def summary(self) -> str: Returns: A summary of the graphs on the list. """ - # TODO: a nicer summary. ;) - desc = "List of graphs:" + # Line "decorator". + summary = "\n" + 113 * '=' + "\n" + summary += "Registry of {}s:\n".format(self._base_type_name) for graph in self: - desc = desc + "`{}`: {}\n".format(graph.name, graph) - return desc + summary += " * {} ({}) [{}]\n".format(graph.name, len(graph), graph.operation_mode) + # Line "decorator". 
+ summary += 113 * '=' + return summary @property def active_graph(self) -> "NeuralGraph": diff --git a/nemo/utils/neural_graph/object_registry.py b/nemo/utils/neural_graph/object_registry.py index 8e861e529944..464cda92f219 100644 --- a/nemo/utils/neural_graph/object_registry.py +++ b/nemo/utils/neural_graph/object_registry.py @@ -137,7 +137,11 @@ def summary(self) -> str: Returns: A summary of the objects on the list. """ - summary = "Registry of {}s:\n".format(self._base_type_name) + # Line "decorator". + summary = "\n" + 113 * '=' + "\n" + summary += "Registry of {}s:\n".format(self._base_type_name) for obj in self: summary += " * {} ({})\n".format(obj.name, type(obj).__name__) + # Line "decorator". + summary += 113 * '=' return summary diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 204ed8dbee7f..5d46fee518a4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,4 @@ wget wrapt ruamel.yaml sklearn +scipy diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index f923e8869ea9..f3cbda69415c 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -8,5 +8,8 @@ unidecode youtokentome numpy tqdm +sklearn +rapidfuzz gdown megatron-lm +inflect diff --git a/requirements/requirements_simple_gan.txt b/requirements/requirements_simple_gan.txt index 8f59cf99bbac..6ccafc3f904b 100644 --- a/requirements/requirements_simple_gan.txt +++ b/requirements/requirements_simple_gan.txt @@ -1,2 +1 @@ matplotlib -torchvision \ No newline at end of file diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index 61ff985cc778..3d5ac563c873 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -1,5 +1,3 @@ -librosa matplotlib pypinyin -scipy -attrdict \ No newline at end of file +attrdict diff --git a/scripts/convert_wav_to_g711wav.py b/scripts/convert_wav_to_g711wav.py new file mode 100644 
index 000000000000..f882e5fc64cc --- /dev/null +++ b/scripts/convert_wav_to_g711wav.py @@ -0,0 +1,93 @@ +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# USAGE: +# python convert_wav_to_g711wav.py \ +# --data_dir= \ +# --dest_dir= +# +# Converts all wav audio files to PCM u-law wav files (8kHz, 8-bit). +# Requires sox to be installed. +import argparse +import concurrent.futures +import glob +import logging +import os +import subprocess + +from tqdm import tqdm + +parser = argparse.ArgumentParser(description='Convert wav audio to pcm mulaw wav') +parser.add_argument( + "--data_dir", default=None, type=str, required=True, help="The path to the input directory with .wav files.", +) +parser.add_argument( + "--dest_dir", default=None, type=str, required=True, help="Path to the destination directory.", +) +args = parser.parse_args() + + +def __convert_audio(in_path, out_path): + """ + Helper function that's called per thread, converts wav to G.711 wav. + Args: + in_path: source wav file to convert + out_path: destination for G.711 wav file + """ + cmd = ["sox", in_path, "-r", "8000", "-c", "1", "-e", "u-law", out_path] + subprocess.run(cmd) + + +def __process_set(data_dir, dst_root): + """ + Finds and converts all wav audio files in the given directory to pcm_mulaw. 
+ Args: + data_dir: source directory with wav files to convert + dst_root: where G.711 (pcm_mulaw) wav files will be stored + """ + wav_list = glob.glob(data_dir) + + if not os.path.exists(dst_root): + os.makedirs(dst_root) + + # Set up and execute concurrent audio conversion + tp = concurrent.futures.ProcessPoolExecutor(max_workers=64) + futures = [] + + for wav_path in tqdm(wav_list, desc="Submitting wav futures", unit="file"): + audio_id = os.path.basename(wav_path) + out_path = os.path.join(dst_root, audio_id) + futures.append(tp.submit(__convert_audio, wav_path, out_path)) + + pbar = tqdm(total=len(wav_list), desc="Converting wav files", unit="file") + count = 0 + for f in concurrent.futures.as_completed(futures): + count += 1 + pbar.update() + tp.shutdown() + pbar.close() + + +def main(): + data_dir = args.data_dir + dest_dir = args.dest_dir + + logging.info("\n\nConverting audio in {}", data_dir) + __process_set( + os.path.join(data_dir, "*.wav",), os.path.join(dest_dir), + ) + + +if __name__ == '__main__': + main() diff --git a/scripts/export_jasper_to_onnx.py b/scripts/export_jasper_to_onnx.py index daa9459394a0..6df997e11c86 100644 --- a/scripts/export_jasper_to_onnx.py +++ b/scripts/export_jasper_to_onnx.py @@ -7,8 +7,7 @@ import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.utils import logging def get_parser(): diff --git a/tests/integration/test_asr_gradient_step_and_eval.py b/tests/integration/test_asr_gradient_step_and_eval.py index d68898c076b6..9b883c4dfa4a 100644 --- a/tests/integration/test_asr_gradient_step_and_eval.py +++ b/tests/integration/test_asr_gradient_step_and_eval.py @@ -24,14 +24,13 @@ import pytest from ruamel.yaml import YAML -import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.core import EvaluatorCallback, SimpleLossLoggerCallback +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") -class TestASRPytorch(TestCase): +class 
TestASRIntegrationPytorch(TestCase): labels = [ " ", "a", @@ -148,7 +147,7 @@ def test_jasper_training(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) @@ -200,7 +199,7 @@ def test_quartznet_training(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) @@ -258,7 +257,7 @@ def test_contextnet_ctc_training(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) @@ -315,7 +314,7 @@ def test_stft_conv_training(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) @@ -373,7 +372,7 @@ def test_jasper_evaluation(self): process_evaluation_epoch, ) - eval_callback = nemo.core.EvaluatorCallback( + eval_callback = EvaluatorCallback( eval_tensors=[loss, predictions, transcript, transcript_len], user_iter_callback=lambda x, y: process_evaluation_batch(x, y, labels=self.labels), user_epochs_done_callback=process_evaluation_epoch, diff --git a/tests/integration/test_integration_multidataset.py b/tests/integration/test_integration_multidataset.py index 892d3e08bcb4..4eee92058e8b 100644 --- a/tests/integration/test_integration_multidataset.py +++ b/tests/integration/test_integration_multidataset.py @@ -26,8 +26,7 @@ import nemo from nemo.backends.pytorch.common import DataCombination from nemo.core import ChannelType, NeuralType - -logging = nemo.logging +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") diff --git 
a/tests/integration/test_speaker_recognition_gradient_step.py b/tests/integration/test_speaker_recognition_gradient_step.py index ab062ddbad81..cf2535e9c9af 100644 --- a/tests/integration/test_speaker_recognition_gradient_step.py +++ b/tests/integration/test_speaker_recognition_gradient_step.py @@ -25,8 +25,7 @@ import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") diff --git a/tests/integration/test_speechcommands_gradient_step_and_eval.py b/tests/integration/test_speechcommands_gradient_step_and_eval.py index 2f6dcf3b2be2..c997ca98ad94 100644 --- a/tests/integration/test_speechcommands_gradient_step_and_eval.py +++ b/tests/integration/test_speechcommands_gradient_step_and_eval.py @@ -26,8 +26,7 @@ import nemo import nemo.collections.asr as nemo_asr - -logging = nemo.logging +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") diff --git a/tests/integration/test_tts_gradient_step.py b/tests/integration/test_tts_gradient_step.py index 8ffa18fb6269..8b8e500d81f4 100644 --- a/tests/integration/test_tts_gradient_step.py +++ b/tests/integration/test_tts_gradient_step.py @@ -25,11 +25,11 @@ import numpy as np import pytest -import nemo import nemo.collections.asr as nemo_asr import nemo.collections.tts as nemo_tts - -logging = nemo.logging +from nemo.backends.pytorch.actions import PtActions +from nemo.core import SimpleLossLoggerCallback +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") @@ -158,11 +158,11 @@ def test_tacotron2_training(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss_t], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) # Instantiate an optimizer to perform `train` action - optimizer = nemo.backends.pytorch.actions.PtActions() + optimizer = PtActions() optimizer.train( [loss_t], callbacks=[callback], 
optimizer="sgd", optimization_params={"max_steps": 3, "lr": 0.01} ) @@ -212,11 +212,11 @@ def test_waveglow_training(self): loss_t = waveglow_loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss_t], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) # Instantiate an optimizer to perform `train` action - optimizer = nemo.backends.pytorch.actions.PtActions() + optimizer = PtActions() optimizer.train( [loss_t], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 3, "lr": 0.01} ) @@ -314,11 +314,11 @@ def test_fastspeech(self): ) loss_list = [] - callback = nemo.core.SimpleLossLoggerCallback( + callback = SimpleLossLoggerCallback( tensors=[loss_t], print_func=partial(self.print_and_log_loss, loss_log_list=loss_list), step_freq=1 ) # Instantiate an optimizer to perform `train` action - optimizer = nemo.backends.pytorch.actions.PtActions() + optimizer = PtActions() optimizer.train( [loss_t], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 3, "lr": 0.0003} ) diff --git a/tests/unit/core/test_weight_share.py b/tests/unit/core/test_weight_share.py index 53c6dad81356..165db51f923b 100644 --- a/tests/unit/core/test_weight_share.py +++ b/tests/unit/core/test_weight_share.py @@ -34,8 +34,7 @@ from nemo.collections.nlp.nm.trainables.common import TokenClassifier from nemo.core import WeightShareTransform from nemo.core.neural_types import * - -logging = nemo.logging +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") diff --git a/tests/unit/test_unit_asr.py b/tests/unit/test_unit_asr.py index a664ac03fd23..ff6cc6985878 100644 --- a/tests/unit/test_unit_asr.py +++ b/tests/unit/test_unit_asr.py @@ -29,15 +29,13 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.parts import AudioDataset, WaveformFeaturizer, collections, parsers from 
nemo.core import DeviceType - -logging = nemo.logging - +from nemo.utils import logging freq = 16000 @pytest.mark.usefixtures("neural_factory") -class TestASRPytorch(TestCase): +class TestUnitASRPytorch(TestCase): labels = [ " ", "a", diff --git a/tests/unit/test_unit_multidataset.py b/tests/unit/test_unit_multidataset.py index 9d8384df8ac4..1ef74caeadaf 100644 --- a/tests/unit/test_unit_multidataset.py +++ b/tests/unit/test_unit_multidataset.py @@ -26,8 +26,7 @@ import nemo from nemo.backends.pytorch.common import DataCombination from nemo.core import ChannelType, NeuralType - -logging = nemo.logging +from nemo.utils import logging @pytest.mark.usefixtures("neural_factory") diff --git a/tests/unit/test_unit_speech_commands.py b/tests/unit/test_unit_speech_commands.py index d8563ceafd3a..3077c08708b1 100644 --- a/tests/unit/test_unit_speech_commands.py +++ b/tests/unit/test_unit_speech_commands.py @@ -29,9 +29,7 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.parts import AudioLabelDataset, WaveformFeaturizer, collections, parsers, perturb from nemo.core import DeviceType - -logging = nemo.logging - +from nemo.utils import logging freq = 16000 diff --git a/tests/unit/utils/test_deprecated.py b/tests/unit/utils/test_deprecated.py index 2ae3e5cb156f..4f1c9490e60f 100644 --- a/tests/unit/utils/test_deprecated.py +++ b/tests/unit/utils/test_deprecated.py @@ -30,7 +30,7 @@ class DeprecatedTest(TestCase): NEMO_ERR_MSG_FORMAT = re.compile( - r"\[NeMo W [0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} deprecated:[0-9]*\] " + r"\[NeMo W [0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} deprecated:[0-9]+( rank:[0-9]+)?\] " ) @pytest.mark.unit