From 4a7c3e3df40a9af6af12e957ca704ed651ed8ff6 Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Fri, 30 Jul 2021 11:53:07 -0600 Subject: [PATCH] Merge 1.2 bugfixes into main (#2588) * update jenkinsfile Signed-off-by: ericharper * update BRANCH Signed-off-by: ericharper * Fix onnx for ASR notebook (#2542) * Update onnx version Signed-off-by: smajumdar * Fix onnx Signed-off-by: smajumdar * Fix onnx Signed-off-by: smajumdar * Fix typos and MeCab import (#2541) Signed-off-by: MaximumEntropy * Fix branch for ASR notebooks (#2549) Signed-off-by: smajumdar * rmtok (#2559) Signed-off-by: Abhinav Khattar * Add xxhash dependency (#2564) Signed-off-by: MaximumEntropy * fix (#2566) * fix Signed-off-by: nithinraok * doc add Signed-off-by: nithinraok * style fix Signed-off-by: nithinraok * Fix moses path issue (#2573) Signed-off-by: MaximumEntropy * More moses data path fixes (#2575) Signed-off-by: MaximumEntropy * Path fixes (#2580) Signed-off-by: MaximumEntropy * Upper bound transformers for 1.2 (#2584) * upper bound transformers and name change jarvis to riva Signed-off-by: ericharper * upper bound transformers and name change jarvis to riva Signed-off-by: ericharper * update jenkinsfile Signed-off-by: ericharper * update notebooks branch Signed-off-by: ericharper * update notebooks branch Signed-off-by: ericharper * update notebooks branch Signed-off-by: ericharper Co-authored-by: Somshubra Majumdar Co-authored-by: Sandeep Subramanian Co-authored-by: Abhinav Khattar Co-authored-by: Nithin Rao --- .../asr/parts/utils/nmse_clustering.py | 86 +++++---- .../machine_translation/mt_enc_dec_model.py | 8 +- tutorials/AudioTranslationSample.ipynb | 166 +++++++++--------- tutorials/VoiceSwapSample.ipynb | 2 +- tutorials/asr/01_ASR_with_NeMo.ipynb | 1 + .../08_ASR_with_Subword_Tokenization.ipynb | 2 +- .../asr/10_ASR_CTC_Language_Finetuning.ipynb | 2 +- ...a_Preprocessing_and_Cleaning_for_NMT.ipynb | 38 ++-- 8 files changed, 169 insertions(+), 136 deletions(-) diff --git a/nemo/collections/asr/parts/utils/nmse_clustering.py b/nemo/collections/asr/parts/utils/nmse_clustering.py index 958654146ea3..e65c62ffa60e 100644 --- a/nemo/collections/asr/parts/utils/nmse_clustering.py +++ b/nemo/collections/asr/parts/utils/nmse_clustering.py @@ -40,6 +40,16 @@ scaler = MinMaxScaler(feature_range=(0, 1)) +try: + from torch.linalg import eigh as eigh + + TORCH_EIGN = True +except ImportError: + TORCH_EIGN = False + from scipy.linalg import eigh as eigh + + logging.warning("Using eigen decomposition from scipy, upgrade torch to 1.9 or higher for faster clustering") + def isGraphFullyConnected(affinity_mat): return getTheLargestComponent(affinity_mat, 0).sum() == affinity_mat.shape[0] @@ -120,7 +130,7 @@ def getCosAffinityMatrix(emb): def getLaplacian(X): """ - Calculates a Laplacian matrix from an affinity matrix X. + Calculates a laplacian matrix from an affinity matrix X. 
""" X[np.diag_indices(X.shape[0])] = 0 A = X @@ -130,19 +140,46 @@ def getLaplacian(X): return L -def eigDecompose(Laplacian, cuda, device=None): - if cuda: - if device == None: - device = torch.cuda.current_device() - laplacian_torch = torch.from_numpy(Laplacian).float().to(device) +def eigDecompose(laplacian, cuda, device=None): + if TORCH_EIGN: + if cuda: + if device is None: + device = torch.cuda.current_device() + laplacian = torch.from_numpy(laplacian).float().to(device) + else: + laplacian = torch.from_numpy(laplacian).float() + lambdas, diffusion_map = eigh(laplacian) + lambdas = lambdas.cpu().numpy() + diffusion_map = diffusion_map.cpu().numpy() else: - laplacian_torch = torch.from_numpy(Laplacian).float() - lambdas_torch, diffusion_map_torch = torch.linalg.eigh(laplacian_torch) - lambdas = lambdas_torch.cpu().numpy() - diffusion_map = diffusion_map_torch.cpu().numpy() + lambdas, diffusion_map = eigh(laplacian) + return lambdas, diffusion_map +def getLamdaGaplist(lambdas): + lambdas = np.real(lambdas) + return list(lambdas[1:] - lambdas[:-1]) + + +def estimateNumofSpeakers(affinity_mat, max_num_speaker, is_cuda=False): + """ + Estimates the number of speakers using eigen decompose on laplacian Matrix. + affinity_mat: (array) + NxN affitnity matrix + max_num_speaker: (int) + Maximum number of clusters to consider for each session + is_cuda: (bool) + if cuda availble eigh decomposition would be computed on GPUs + """ + laplacian = getLaplacian(affinity_mat) + lambdas, _ = eigDecompose(laplacian, is_cuda) + lambdas = np.sort(lambdas) + lambda_gap_list = getLamdaGaplist(lambdas) + num_of_spk = np.argmax(lambda_gap_list[: min(max_num_speaker, len(lambda_gap_list))]) + 1 + return num_of_spk, lambdas, lambda_gap_list + + class _SpectralClustering: def __init__(self, n_clusters=8, random_state=0, n_init=10, p_value=10, n_jobs=None, cuda=False): self.n_clusters = n_clusters @@ -170,8 +207,8 @@ def getSpectralEmbeddings(self, affinity_mat, n_spks=8, drop_first=True, cuda=Fa if not isGraphFullyConnected(affinity_mat): logging.warning("Graph is not fully connected and the clustering result might not be accurate.") - Laplacian = getLaplacian(affinity_mat) - lambdas_, diffusion_map_ = eigDecompose(Laplacian, cuda) + laplacian = getLaplacian(affinity_mat) + lambdas_, diffusion_map_ = eigDecompose(laplacian, cuda) lambdas = lambdas_[:n_spks] diffusion_map = diffusion_map_[:, :n_spks] embedding = diffusion_map.T[n_spks::-1] @@ -363,7 +400,7 @@ def getEigRatio(self, p_neighbors): """ affinity_mat = getAffinityGraphMat(self.mat, p_neighbors) - est_num_of_spk, lambdas, lambda_gap_list = self.estimateNumofSpeakers(affinity_mat) + est_num_of_spk, lambdas, lambda_gap_list = estimateNumofSpeakers(affinity_mat, self.max_num_speaker, self.cuda) arg_sorted_idx = np.argsort(lambda_gap_list[: self.max_num_speaker])[::-1] max_key = arg_sorted_idx[0] max_eig_gap = lambda_gap_list[max_key] / (max(lambdas) + self.eps) @@ -388,21 +425,6 @@ def getPvalueList(self): return p_value_list - def getLamdaGaplist(self, lambdas): - lambdas = np.real(lambdas) - return list(lambdas[1:] - lambdas[:-1]) - - def estimateNumofSpeakers(self, affinity_mat): - """ - Estimates the number of speakers using eigen decompose on Laplacian Matrix. 
- """ - Laplacian = getLaplacian(affinity_mat) - lambdas, _ = eigDecompose(Laplacian, self.cuda) - lambdas = np.sort(lambdas) - lambda_gap_list = self.getLamdaGaplist(lambdas) - num_of_spk = np.argmax(lambda_gap_list[: min(self.max_num_speaker, len(lambda_gap_list))]) + 1 - return num_of_spk, lambdas, lambda_gap_list - def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_samples=6, fixed_thres=None, cuda=False): """ @@ -423,8 +445,9 @@ def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_sam min_samples: (int) Minimum number of samples required for NME clustering, this avoids - zero p_neighbour_lists. Default of 6 is selected since (1/rp_threshold) >= 4. - + zero p_neighbour_lists. Default of 6 is selected since (1/rp_threshold) >= 4 + when max_rp_threshold = 0.25. Thus, NME analysis is skipped for matrices + smaller than (min_samples)x(min_samples). Returns: Y: (List[int]) Speaker label for each segment. @@ -443,12 +466,13 @@ def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_sam NME_mat_size=300, cuda=cuda, ) - est_num_of_spk, p_hat_value = nmesc.NMEanalysis() if emb.shape[0] > min_samples: + est_num_of_spk, p_hat_value = nmesc.NMEanalysis() affinity_mat = getAffinityGraphMat(mat, p_hat_value) else: affinity_mat = mat + est_num_of_spk, _, _ = estimateNumofSpeakers(affinity_mat, max_num_speaker, cuda) if oracle_num_speakers: est_num_of_spk = oracle_num_speakers diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py index b9601b1aa0d9..c49cd7dcb1fd 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py @@ -103,7 +103,7 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): ) elif isinstance(self.src_language, ListConfig): for lng in self.src_language: - self.multilingual_ids.append(self.encoder_tokenizer.token_to_id("<" + lng + ">")) + self.multilingual_ids.append(None) elif isinstance(self.tgt_language, ListConfig): for lng in self.tgt_language: self.multilingual_ids.append(self.encoder_tokenizer.token_to_id("<" + lng + ">")) @@ -773,7 +773,11 @@ def translate( raise ValueError("Expect source_lang and target_lang to infer for multilingual model.") src_symbol = self.encoder_tokenizer.token_to_id('<' + source_lang + '>') tgt_symbol = self.encoder_tokenizer.token_to_id('<' + target_lang + '>') - prepend_ids = [src_symbol if src_symbol in self.multilingual_ids else tgt_symbol] + if src_symbol in self.multilingual_ids: + prepend_ids = [src_symbol] + elif tgt_symbol in self.multilingual_ids: + prepend_ids = [tgt_symbol] + try: self.eval() src, src_mask = self.prepare_inference_batch(text, prepend_ids) diff --git a/tutorials/AudioTranslationSample.ipynb b/tutorials/AudioTranslationSample.ipynb index 58ed9a1a444c..94b03c581497 100644 --- a/tutorials/AudioTranslationSample.ipynb +++ b/tutorials/AudioTranslationSample.ipynb @@ -2,9 +2,6 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "RYGnI-EZp_nK" - }, "source": [ "# Getting Started: Sample Conversational AI application\n", "This notebook shows how to use NVIDIA NeMo (https://github.com/NVIDIA/NeMo) to construct a toy demo which translate Mandarin audio file into English one.\n", @@ -15,49 +12,48 @@ "* Transcribe audio with (Mandarin) speech recognition model.\n", "* Translate text with machine translation model.\n", "* Generate audio with 
text-to-speech models." - ] + ], + "metadata": { + "id": "RYGnI-EZp_nK" + } }, { "cell_type": "markdown", - "metadata": { - "id": "V72HXYuQ_p9a" - }, "source": [ "## Installation\n", "NeMo can be installed via simple pip command.\n", "This will take about 4 minutes.\n", "\n", "(The installation method below should work inside your new Conda environment or in an NVIDIA docker container.)" - ] + ], + "metadata": { + "id": "V72HXYuQ_p9a" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "efDmTWf1_iYK" - }, - "outputs": [], "source": [ "BRANCH = 'main'\n", "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" - ] + ], + "outputs": [], + "metadata": { + "id": "efDmTWf1_iYK" + } }, { "cell_type": "markdown", - "metadata": { - "id": "EyJ5HiiPrPKA" - }, "source": [ "## Import all necessary packages" - ] + ], + "metadata": { + "id": "EyJ5HiiPrPKA" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "tdUqxeUEA8nw" - }, - "outputs": [], "source": [ "# Import NeMo and it's ASR, NLP and TTS collections\n", "import nemo\n", @@ -69,13 +65,14 @@ "import nemo.collections.tts as nemo_tts\n", "# We'll use this to listen to audio\n", "import IPython" - ] + ], + "outputs": [], + "metadata": { + "id": "tdUqxeUEA8nw" + } }, { "cell_type": "markdown", - "metadata": { - "id": "bt2EZyU3A1aq" - }, "source": [ "## Instantiate pre-trained NeMo models\n", "\n", @@ -84,30 +81,28 @@ "* ``list_available_models()`` - it will list all models currently available on NGC and their names.\n", "\n", "* ``from_pretrained(...)`` API downloads and initialized model directly from the NGC using model name.\n" - ] + ], + "metadata": { + "id": "bt2EZyU3A1aq" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "YNNHs5Xjr8ox", - "scrolled": true - }, - "outputs": [], "source": [ "# Here is an example of all CTC-based models:\n", "nemo_asr.models.EncDecCTCModel.list_available_models()\n", "# More ASR Models are available - see: nemo_asr.models.ASRModel.list_available_models()" - ] + ], + "outputs": [], + "metadata": { + "id": "YNNHs5Xjr8ox", + "scrolled": true + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "1h9nhICjA5Dk", - "scrolled": true - }, - "outputs": [], "source": [ "# Speech Recognition model - Citrinet initially trained on Multilingual LibriSpeech English corpus, and fine-tuned on the open source Aishell-2\n", "asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=\"stt_zh_citrinet_1024_gamma_0_25\").cuda()\n", @@ -117,24 +112,25 @@ "spectrogram_generator = nemo_tts.models.FastPitchModel.from_pretrained(model_name=\"tts_en_fastpitch\").cuda()\n", "# Vocoder model which takes spectrogram and produces actual audio\n", "vocoder = nemo_tts.models.HifiGanModel.from_pretrained(model_name=\"tts_hifigan\").cuda()" - ] + ], + "outputs": [], + "metadata": { + "id": "1h9nhICjA5Dk", + "scrolled": true + } }, { "cell_type": "markdown", - "metadata": { - "id": "KPota-JtsqSY" - }, "source": [ "## Get an audio sample in Mandarin" - ] + ], + "metadata": { + "id": "KPota-JtsqSY" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "7cGCEKkcLr52" - }, - "outputs": [], "source": [ "# Download audio sample which we'll try\n", "# This is a sample from MCV 6.1 Dev dataset - the model hasn't seen it before\n", @@ -143,71 +139,71 @@ "!wget 'https://nemo-public.s3.us-east-2.amazonaws.com/zh-samples/common_voice_zh-CN_21347786.mp3'\n", "# To listen it, 
click on the play button below\n", "IPython.display.Audio(audio_sample)" - ] + ], + "outputs": [], + "metadata": { + "id": "7cGCEKkcLr52" + } }, { "cell_type": "markdown", - "metadata": { - "id": "BaCdNJhhtBfM" - }, "source": [ "## Transcribe audio file\n", "We will use speech recognition model to convert audio into text.\n" - ] + ], + "metadata": { + "id": "BaCdNJhhtBfM" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "KTA7jM6sL6yC" - }, - "outputs": [], "source": [ "transcribed_text = asr_model.transcribe([audio_sample])\n", "print(transcribed_text)" - ] + ], + "outputs": [], + "metadata": { + "id": "KTA7jM6sL6yC" + } }, { "cell_type": "markdown", - "metadata": { - "id": "BjYb2TMtttCc" - }, "source": [ "## Translate Chinese text into English\n", "NeMo's NMT models have a handy ``.translate()`` method." - ] + ], + "metadata": { + "id": "BjYb2TMtttCc" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "kQTdE4b9Nm9O" - }, - "outputs": [], "source": [ "english_text = nmt_model.translate(transcribed_text)\n", "print(english_text)" - ] + ], + "outputs": [], + "metadata": { + "id": "kQTdE4b9Nm9O" + } }, { "cell_type": "markdown", - "metadata": { - "id": "9Rppc59Ut7uy" - }, "source": [ "## Generate English audio from text\n", "Speech generation from text typically has two steps:\n", "* Generate spectrogram from the text. In this example we will use FastPitch model for this.\n", "* Generate actual audio from the spectrogram. In this example we will use HifiGan model for this.\n" - ] + ], + "metadata": { + "id": "9Rppc59Ut7uy" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "wpMYfufgNt15" - }, - "outputs": [], "source": [ "# A helper function which combines FastPitch and HifiGan to go directly from \n", "# text to audio\n", @@ -216,26 +212,27 @@ " spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)\n", " audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)\n", " return audio.to('cpu').detach().numpy()" - ] + ], + "outputs": [], + "metadata": { + "id": "wpMYfufgNt15" + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "# Listen to generated audio in English\n", "IPython.display.Audio(text_to_audio(english_text[0]), rate=22050)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": { - "id": "LiQ_GQpcBYUs" - }, "source": [ "## Next steps\n", - "A demo like this is great for prototyping and experimentation. However, for real production deployment, you would want to use a service like [NVIDIA Jarvis](https://developer.nvidia.com/nvidia-jarvis).\n", + "A demo like this is great for prototyping and experimentation. However, for real production deployment, you would want to use a service like [NVIDIA Riva](https://developer.nvidia.com/riva).\n", "\n", "**NeMo is built for training.** You can fine-tune, or train from scratch on your data all models used in this example. We recommend you checkout the following, more in-depth, tutorials next:\n", "\n", @@ -247,7 +244,10 @@ "\n", "\n", "You can find scripts for training and fine-tuning ASR, NLP and TTS models [here](https://github.com/NVIDIA/NeMo/tree/main/examples). 
" - ] + ], + "metadata": { + "id": "LiQ_GQpcBYUs" + } } ], "metadata": { @@ -277,4 +277,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/tutorials/VoiceSwapSample.ipynb b/tutorials/VoiceSwapSample.ipynb index 14dc05902cfb..42591a1bafdb 100644 --- a/tutorials/VoiceSwapSample.ipynb +++ b/tutorials/VoiceSwapSample.ipynb @@ -270,7 +270,7 @@ "cell_type": "markdown", "source": [ "## Next steps\n", - "A demo like this is great for prototyping and experimentation. However, for real production deployment, you would want to use a service like [NVIDIA Jarvis](https://developer.nvidia.com/nvidia-jarvis).\n", + "A demo like this is great for prototyping and experimentation. However, for real production deployment, you would want to use a service like [NVIDIA Riva](https://developer.nvidia.com/riva).\n", "\n", "**NeMo is built for training.** You can fine-tune, or train from scratch on your data all models used in this example. We recommend you checkout the following, more in-depth, tutorials next:\n", "\n", diff --git a/tutorials/asr/01_ASR_with_NeMo.ipynb b/tutorials/asr/01_ASR_with_NeMo.ipynb index 06c4b03d393e..33d77e15ebd3 100644 --- a/tutorials/asr/01_ASR_with_NeMo.ipynb +++ b/tutorials/asr/01_ASR_with_NeMo.ipynb @@ -997,6 +997,7 @@ "id": "I4WRcmakjQnj" }, "source": [ + "!pip install --upgrade onnxruntime onnxruntime-gpu\n", "#!mkdir -p ort\n", "#%cd ort\n", "#!git clean -xfd\n", diff --git a/tutorials/asr/08_ASR_with_Subword_Tokenization.ipynb b/tutorials/asr/08_ASR_with_Subword_Tokenization.ipynb index c103b2907d79..784baa6a8691 100644 --- a/tutorials/asr/08_ASR_with_Subword_Tokenization.ipynb +++ b/tutorials/asr/08_ASR_with_Subword_Tokenization.ipynb @@ -40,7 +40,7 @@ "!pip install matplotlib>=3.3.2\n", "\n", "## Install NeMo\n", - "BRANCH = \"v1.0.2\"\n", + "BRANCH = \"main\"\n", "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "## Grab the config we'll use in this example\n", diff --git a/tutorials/asr/10_ASR_CTC_Language_Finetuning.ipynb b/tutorials/asr/10_ASR_CTC_Language_Finetuning.ipynb index 65e76e1832f3..f803698e9306 100644 --- a/tutorials/asr/10_ASR_CTC_Language_Finetuning.ipynb +++ b/tutorials/asr/10_ASR_CTC_Language_Finetuning.ipynb @@ -39,7 +39,7 @@ "!pip install matplotlib>=3.3.2\n", "\n", "## Install NeMo\n", - "BRANCH = \"r1.1.0\"\n", + "BRANCH = \"main\"\n", "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "\"\"\"\n", diff --git a/tutorials/nlp/Data_Preprocessing_and_Cleaning_for_NMT.ipynb b/tutorials/nlp/Data_Preprocessing_and_Cleaning_for_NMT.ipynb index cfa4a79aa2a5..5b5fdcd777b0 100644 --- a/tutorials/nlp/Data_Preprocessing_and_Cleaning_for_NMT.ipynb +++ b/tutorials/nlp/Data_Preprocessing_and_Cleaning_for_NMT.ipynb @@ -26,6 +26,10 @@ "BRANCH = 'main'\n", "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", + "!pip uninstall -y sacrebleu\n", + "!pip install sacrebleu[ja]\n", + "!pip install xxhash\n", + "\n", "## Install kenlm with 7-gram support\n", "!mkdir -p data\n", "!rm -rf data/kenlm\n", @@ -74,7 +78,7 @@ "1. Downloading and filtering publicly available datasets based on confidence thresholds (if available). For example, [WikiMatrix](https://arxiv.org/abs/1907.05791) filtering based on [LASER](https://arxiv.org/abs/1812.10464) confidence scores.\n", "2. 
Language ID filtering using a pre-trained [fastText classifier](https://fasttext.cc/docs/en/language-identification.html). This step will remove all sentences from the parallel corpus that our classifier predicts as not being in the appropriate language (ex: sentences in the English column that aren't in English or sentences in Russian column that aren't in Russian).\n",
 "3. Length and Length-ratio filtering. This steps removes all sentences that are 1) too long 2) too short or 3) have a ratio between their lengths greater than a certain factor (this typically removes partial translations).\n",
-"4. [Bicleaner](https://github.com/bitextor/bicleaner) classifier-based cleaning. Bicleaner identifies noisy parallel senteces using a classifier that leverages multiple features such as n-gram language model likelihood scores, word alignment scores and other heuristics.\n",
+"4. [Bicleaner](https://github.com/bitextor/bicleaner) classifier-based cleaning. Bicleaner identifies noisy parallel sentences using a classifier that leverages multiple features such as n-gram language model likelihood scores, word alignment scores and other heuristics.\n",
 "\n",
 "## Pre-processing\n",
 "5. [Moses Punctuation Normalization](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl). This step standardizes punctuation. For example the less common way to write apostrophes Tiffany`s will be standardized to Tiffany's.\n",
@@ -84,7 +88,7 @@
 "9. Shuffling - This step shuffles the order of occurrence of translation pairs.\n",
 "\n",
 "## Tarred Datasets for Large Corpora\n",
-"10. Large datasts with over 50M sentence pairs when batched and pickled can be upto 60GB in size. Loading them entirely into CPU memory when using say 8 or 16 workers with DistributedDataParallel training uses 480-960GB of RAM which is often impractical and inefficient. Instead, we use [Webdataset](https://github.com/webdataset/webdataset) to allow training while keeping datasets on disk and let webddataset handle pre-loading and fetching of data into CPU RAM.\n",
+"10. Large datasets with over 50M sentence pairs when batched and pickled can be up to 60GB in size. Loading them entirely into CPU memory when using, say, 8 or 16 workers with DistributedDataParallel training uses 480-960GB of RAM which is often impractical and inefficient. Instead, we use [Webdataset](https://github.com/webdataset/webdataset) to allow training while keeping datasets on disk and let webdataset handle pre-loading and fetching of data into CPU RAM.\n",
 "\n",
 "\n",
 "## Disclaimer\n",
@@ -318,7 +322,7 @@
 "\n",
 "1. Pre-filtering based on 37 rules.\n",
 "2. Language model fluency scores based on n-gram language models trained with kenlm.\n",
-"3. Random forest clasifier that uses all examples filtered out in steps 1 & 2 as \"negative\" examples."
+"3. Random forest classifier that uses all examples filtered out in steps 1 & 2 as \"negative\" examples.\n",
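+"\n",
+"As a minimal sketch of step 2 (assuming the kenlm Python bindings are installed; the model path here is hypothetical, not one produced earlier in this notebook), fluency scoring looks like:\n",
+"```python\n",
+"import kenlm\n",
+"\n",
+"model = kenlm.Model('data/lm.en.arpa')  # hypothetical path to a trained n-gram model\n",
+"# Total log10 probability of the sentence; very low scores flag disfluent pairs for filtering\n",
+"print(model.score('this is a sample sentence', bos=True, eos=True))\n",
+"```"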
] }, { @@ -432,12 +436,12 @@ "outputs": [], "source": [ "print('Normalizing English ...')\n", - "!perl mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l en \\\n", + "!perl data/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l en \\\n", " < data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.en > \\\n", " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.en\n", "\n", "print('Normalizing Russian ...')\n", - "!perl mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l ru \\\n", + "!perl data/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l ru \\\n", " < data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.ru > \\\n", " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.ru\n" ] @@ -500,12 +504,12 @@ "outputs": [], "source": [ "print('Tokenizing English ...')\n", - "!perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l en -no-escape -threads 4 \\\n", + "!perl data/mosesdecoder/scripts/tokenizer/tokenizer.perl -l en -no-escape -threads 4 \\\n", " < data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.en > \\\n", " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.en\n", "\n", "print('Tokenizing Russian ...')\n", - "!perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l ru -no-escape -threads 4 \\\n", + "!perl data/mosesdecoder/scripts/tokenizer/tokenizer.perl -l ru -no-escape -threads 4 \\\n", " < data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.ru > \\\n", " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.ru\n" ] @@ -691,17 +695,17 @@ "metadata": {}, "outputs": [], "source": [ - "!shuf --random-source=data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.en \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.en > \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.en\n", + "!shuf --random-source=data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.en \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.en > \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.en\n", "\n", - "!shuf --random-source=data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.en \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.ru > \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.ru\n", + "!shuf --random-source=data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.en \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.ru > \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.ru\n", "\n", "!paste -d \"\\t\" \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.en \\\n", - " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.ru \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.en \\\n", + " data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.ru \\\n", " | head -10" ] }, @@ -734,8 +738,8 @@ " -O create_tarred_parallel_dataset.py\n", "\n", "!python create_tarred_parallel_dataset.py \\\n", - " 
--src_fname data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.en \\\n", - " --tgt_fname data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.sacremoses.norm.tok.dedup.shuf.ru \\\n", + " --src_fname data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.en \\\n", + " --tgt_fname data/WikiMatrix.en-ru.langidfilter.lengthratio.bicleaner.60.moses.norm.tok.dedup.shuf.ru \\\n", " --out_dir data/tarred_dataset_en_ru_8k_tokens \\\n", " --clean \\\n", " --encoder_tokenizer_name yttm \\\n",