From bdfb9504212bc20b1ac4c99de466f9b153513282 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Fri, 28 Apr 2023 09:39:24 -0700 Subject: [PATCH 01/10] [BugFix] Force _get_batch_preds() to keep logits in decoder timestamps generator (#6499) * [BugFix] _get_batch_preds() is forced to keep logits in decoder timestamps generators Signed-off-by: Taejin Park * Ingnore keep_logits boolean in FrameASRBatchLogits Signed-off-by: Taejin Park --------- Signed-off-by: Taejin Park Co-authored-by: Jagadeesh Balam <4916480+jbalam-nv@users.noreply.github.com> --- .../asr/parts/utils/decoder_timestamps_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py index 8e81d49939cb..f26b0c6b701a 100644 --- a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py +++ b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py @@ -232,7 +232,7 @@ def get_wer_feat_logit(audio_file_path, asr, frame_len, tokens_per_chunk, delay, return hyp, tokens, log_prob -class FrameBatchASR_Logits(FrameBatchASR): +class FrameBatchASRLogits(FrameBatchASR): """ A class for streaming frame-based ASR. Inherits from FrameBatchASR and adds new capability of returning the logit output. @@ -260,10 +260,9 @@ def read_audio_file_and_return(self, audio_filepath: str, delay: float, model_st self.set_frame_reader(frame_reader) @torch.no_grad() - def _get_batch_preds(self): + def _get_batch_preds(self, keep_logits): device = self.asr_model.device for batch in iter(self.data_loader): - feat_signal, feat_signal_len = batch feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device) log_probs, encoded_len, predictions = self.asr_model( @@ -272,9 +271,12 @@ def _get_batch_preds(self): preds = torch.unbind(predictions) for pred in preds: self.all_preds.append(pred.cpu().numpy()) + # Always keep logits in FrameBatchASRLogits + _ = keep_logits log_probs_tup = torch.unbind(log_probs) for log_prob in log_probs_tup: self.all_logprobs.append(log_prob) + del log_probs, log_probs_tup del encoded_len del predictions @@ -635,7 +637,7 @@ def run_ASR_BPE_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dic log_prediction=asr_model._cfg.get("log_prediction", False), ) - frame_asr = FrameBatchASR_Logits( + frame_asr = FrameBatchASRLogits( asr_model=asr_model, frame_len=self.chunk_len_in_sec, total_buffer=self.total_buffer_in_secs, From 92bb5c023921968219436700b250cf858155e3ef Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Fri, 28 Apr 2023 09:40:17 -0700 Subject: [PATCH 02/10] [TTS] Fix FastPitch energy code (#6511) Signed-off-by: Ryan --- nemo/collections/tts/modules/fastpitch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index e2da672cf9c7..5f2227a999db 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -317,7 +317,7 @@ def forward( # Predict energy if self.energy_predictor is not None: - energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1) + energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1) if energy is not None: # Average energy over characters @@ -402,7 +402,7 @@ def infer( assert energy.shape[-1] == text.shape[-1], f"energy.shape[-1]: {energy.shape[-1]} != len(text)" energy_emb = self.energy_emb(energy) else: - energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1) + energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1) energy_emb = self.energy_emb(energy_pred.unsqueeze(1)) enc_out = enc_out + energy_emb.transpose(1, 2) From 62b3bb580ca70412be61c390a48eb505b5264488 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 28 Apr 2023 10:41:30 -0700 Subject: [PATCH 03/10] fix custom forward_torch_softmax (#6512) (#6517) Signed-off-by: Abhinav Khattar Co-authored-by: Abhinav Khattar --- .../nlp/modules/common/megatron/fused_softmax.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/fused_softmax.py b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py index 3dc0a00c55bd..2c914a67dd12 100644 --- a/nemo/collections/nlp/modules/common/megatron/fused_softmax.py +++ b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py @@ -51,9 +51,10 @@ def forward_torch_softmax(self, input, mask): input = input * self.scale mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) - all_k_masked = mask.all(axis=-1) - zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None] - probs = probs * zero_attention_mask + if mask is not None: + all_k_masked = mask.all(axis=-1) + zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None] + probs = probs * zero_attention_mask if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: From 9a4aa1131375b014152da27357501a4fa4d8a57c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 28 Apr 2023 14:08:42 -0700 Subject: [PATCH 04/10] [TTS] fixed broken path. (#6514) (#6518) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- tutorials/tts/Vits_Training.ipynb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tutorials/tts/Vits_Training.ipynb b/tutorials/tts/Vits_Training.ipynb index e4c088d66a5e..37e55e0d7572 100644 --- a/tutorials/tts/Vits_Training.ipynb +++ b/tutorials/tts/Vits_Training.ipynb @@ -305,6 +305,8 @@ " model.sample_rate=22050 \\\n", " train_dataset=tests/data/asr/an4_train.json \\\n", " validation_datasets=tests/data/asr/an4_val.json \\\n", + " phoneme_dict_path=tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt \\\n", + " heteronyms_path=tts_dataset_files/heteronyms-052722 \\\n", " trainer.max_epochs=3 \\\n", " trainer.accelerator=null \\\n", " trainer.check_val_every_n_epoch=1 \\\n", From 1589a6a7b9b8dcdda307e049c0a288be9b020e13 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Fri, 28 Apr 2023 14:33:59 -0700 Subject: [PATCH 05/10] Fix normalization of impulse response in ImpulsePerturbation (#6505) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- .../asr/parts/preprocessing/perturb.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/nemo/collections/asr/parts/preprocessing/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py index 801305d90b7a..d4b1944ec6a2 100644 --- a/nemo/collections/asr/parts/preprocessing/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -344,16 +344,24 @@ class ImpulsePerturbation(Perturbation): manifest_path (list): Manifest file for RIRs audio_tar_filepaths (list): Tar files, if RIR audio files are tarred shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files + normalize_impulse (bool): Normalize impulse response to zero mean and amplitude 1 shift_impulse (bool): Shift impulse response to adjust for delay at the beginning rng (int): Random seed. Default is None """ def __init__( - self, manifest_path=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False, rng=None, + self, + manifest_path=None, + audio_tar_filepaths=None, + shuffle_n=128, + normalize_impulse=False, + shift_impulse=False, + rng=None, ): self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) self._audiodataset = None self._tarred_audio = False + self._normalize_impulse = normalize_impulse self._shift_impulse = shift_impulse self._data_iterator = None @@ -373,23 +381,32 @@ def perturb(self, data): tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator, ) - if not self._shift_impulse: - impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples)) - data._samples = signal.fftconvolve(data._samples, impulse_norm, "same") - data._samples = data._samples / max( - abs(data._samples) - ) # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training + + # normalize if necessary + if self._normalize_impulse: + # normalize the impulse response to zero mean and amplitude 1 + impulse_norm = impulse.samples - np.mean(impulse.samples) + impulse_norm /= max(abs(impulse_norm)) else: - # Find peak and shift peak to left - impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples)) + impulse_norm = impulse.samples + + # len of input data samples + len_data = len(data._samples) + + # convolve with the full impulse response + data._samples = signal.fftconvolve(data._samples, impulse_norm, "full") + + # compensate the dominant path propagation delay + if self._shift_impulse: + # Find the peak of the IR and shift the output to the left max_ind = np.argmax(np.abs(impulse_norm)) + data._samples = data._samples[max_ind:] + + # trim to match the input data length + data._samples = data._samples[:len_data] - impulse_resp = impulse_norm[max_ind:] - delay_after = len(impulse_resp) - data._samples = signal.fftconvolve(data._samples, impulse_resp, "full")[:-delay_after] - data._samples = data._samples / max( - abs(data._samples) - ) # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training + # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training + data._samples = data._samples / max(abs(data._samples)) class ShiftPerturbation(Perturbation): From 892987169ef277f328e15b71a5a0c9bd961c8ee7 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 28 Apr 2023 16:30:42 -0700 Subject: [PATCH 06/10] Add interleaved pp support (#6498) * Add support for Virtual Pipeline Parallel conversion Signed-off-by: smajumdar * Add support for Virtual Pipeline Parallel conversion Signed-off-by: smajumdar * Switch to megatron core Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../megatron_change_num_partitions.py | 385 ++++++++++++++---- 1 file changed, 313 insertions(+), 72 deletions(-) diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py index 944565d8bd43..a4b28fa4d761 100644 --- a/examples/nlp/language_modeling/megatron_change_num_partitions.py +++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py @@ -13,11 +13,13 @@ # limitations under the License. import os +import tempfile from argparse import ArgumentParser from typing import Dict, List import torch -from omegaconf import open_dict +import torch.nn as nn +from omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer from nemo.collections.nlp.parts.nlp_overrides import ( @@ -54,6 +56,20 @@ --target_pipeline_model_parallel_size=1 \ --target_pipeline_model_parallel_split_rank=0 \ --precision=bf16 + +# Megatron GPT + Virtual Pipeline parallelism + +python megatron_change_num_partitions.py \ + --model_extracted_dir="" \ + --target_file="" \ + --ckpt_name="" \ + --tensor_model_parallel_size= \ + --target_tensor_model_parallel_size= \ + --pipeline_model_parallel_size= \ + --target_pipeline_model_parallel_size= \ + --virtual_pipeline_model_parallel_size= \ + --hparams_file="" \ + --precision=bf16 ### Only Tensor Parallelism conversion ### @@ -100,6 +116,43 @@ """ +def set_virtual_parallel_rank_safely(rank: int): + AppState().virtual_pipeline_model_parallel_rank = rank + + try: + from megatron.core import parallel_state + + parallel_state.set_virtual_pipeline_model_parallel_rank(rank) + + if rank is None: + parallel_state.set_virtual_pipeline_model_parallel_world_size(0) + + except (ImportError, ModuleNotFoundError): + logging.warning("`megatron-core` not installed, cannot set virtual parallel rank !") + + +################# +### Utilities ### +################# + + +def force_cpu_model(cfg): + with open_dict(cfg): + # temporarily + original_cpu_init = cfg.get('use_cpu_initialization', False) + original_amp_o2 = cfg.get('megatron_amp_O2', False) + cfg.use_cpu_initialization = True + cfg.megatron_amp_O2 = False + return cfg, {'original_cpu_init': original_cpu_init, 'original_amp_o2': original_amp_o2} + + +def restore_model_config(cfg, original_dict): + with open_dict(cfg): + for key, val in original_dict.items(): + cfg[key] = val + return cfg + + ################# ### Utilities ### ################# @@ -732,6 +785,12 @@ def main(): parser.add_argument( '--target_pipeline_model_parallel_split_rank', type=int, default=0, help='PP rank to split for Enc-Dec models' ) + parser.add_argument( + '--virtual_pipeline_model_parallel_size', type=int, default=None, help='Virtual Pipeline parallelism size' + ) + parser.add_argument( + '--ckpt_name', type=str, default=None, help='Checkpoint name to load from for Virtual Parallel' + ) parser.add_argument( "--model_class", type=str, @@ -759,6 +818,7 @@ def main(): default=None, help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. This is needed if your model uses a sentencepiece tokenizer.", ) + parser.add_argument('--hparams_file', type=str, default=None, help='Path to hparams file from PTL training') parser.add_argument('--tp_conversion_only', action='store_true', help='Only convert TP model to TP model') parser.add_argument('--model_extracted_dir', type=str, default=None, help='Path to pre-extracted model directory') @@ -795,6 +855,25 @@ def main(): pp_size = args.pipeline_model_parallel_size tgt_pp_size = args.target_pipeline_model_parallel_size pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank + vp_size = args.virtual_pipeline_model_parallel_size + if vp_size is None: + vp_size = 1 + + convert_vp = vp_size > 1 + if convert_vp: + hparams_filepath = args.hparams_file + if hparams_filepath is None: + logging.warning( + '\n\n\n!!!!!!!!!\n' + 'You are converting a model with virtual pipeline parallelism enabled, \n' + 'but have not passed `hparams_file` argument. \n' + 'This will cause each ckpt file to be temporarily laoded onto GPU memory!\n\n' + 'It is highly recommended to pass `hparams_file` argument to avoid this.\n' + ) + else: + hparams_filepath = None + + # Import the class of the model cls = model_utils.import_class_by_path(args.model_class) if args.model_file is None and args.model_extracted_dir is None: @@ -830,10 +909,16 @@ def main(): tgt_pp_size = 1 pipeline_model_parallel_split_rank = 0 + if vp_size is None or vp_size < 0: + vp_size = 1 + app_state = AppState() app_state.data_parallel_rank = 0 app_state.pipeline_model_parallel_size = pp_size app_state.tensor_model_parallel_size = tp_size + + if vp_size > 1: + app_state.virtual_pipeline_model_parallel_size = vp_size app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size world_size = pp_size * tp_size # pseudo world size for simulating load of a specific rank on a single gpu @@ -841,87 +926,198 @@ def main(): app_state.tensor_model_parallel_rank = 0 app_state.pipeline_model_parallel_rank = 0 + if vp_size > 1: + set_virtual_parallel_rank_safely(0) + # If input model has TP > 1 or PP > 1 # Reconstruct the model to have TP = 1 and PP = 1 # Note that this is a forward loop that will process PP [0..N] TP [0..M] in sequential order. if tp_size > 1 or pp_size > 1: - partitions = {} + partitions = {} # 3d list of VP x PP x TP model = None - for pp_rank in range(pp_size): - app_state.pipeline_model_parallel_rank = pp_rank - partitions[pp_rank] = [] - - for tp_rank in range(tp_size): - app_state.tensor_model_parallel_rank = tp_rank - - logging.info(f"Loading ------------ PP Rank: {pp_rank} TP Rank: {tp_rank}") - - # Override flag that forces Model to use AppState instead of Trainer - # to determine the world size, global and local rank - # Used for simulating load of a specific rank on a single gpu - os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true" - - # Compute the global rank to load the correct subset of parameters - global_rank = pp_rank * tp_size + tp_rank - - # Update AppState - app_state.world_size = world_size - app_state.global_rank = global_rank - app_state.local_rank = global_rank % num_gpu_per_node - app_state.pipeline_model_parallel_size = pp_size - app_state.tensor_model_parallel_size = tp_size - app_state.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank - app_state.model_parallel_size = ( - app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size - ) - save_restore_connector = NLPSaveRestoreConnector() + # Build partitions structure + for vp_idx in range(vp_size): + partitions[vp_idx] = [] # Build first layer - VP - if args.model_extracted_dir is not None: - logging.info(f"Using extracted model directory: {args.model_extracted_dir}") - save_restore_connector.model_extracted_dir = args.model_extracted_dir + for pp_idx in range(pp_size): + # For each VP, build PP x TP holder + partitions[vp_idx].append({}) + partitions[vp_idx][pp_idx] = [] - if args.model_file is not None: - model_filepath = args.model_file - else: - model_filepath = args.model_extracted_dir + for vp_rank in range(vp_size): + if vp_size > 1: + set_virtual_parallel_rank_safely(vp_rank) - model = cls.restore_from( - restore_path=model_filepath, - trainer=trainer, - map_location=torch.device("cpu"), - save_restore_connector=save_restore_connector, - ) - model.to(dtype=dtype) + for pp_rank in range(pp_size): + app_state.pipeline_model_parallel_rank = pp_rank - # Reset env flag - os.environ.pop(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, None) + for tp_rank in range(tp_size): + app_state.tensor_model_parallel_rank = tp_rank - logging.info( - f"<<<<<<<< LOADED MODEL PP={pp_rank + 1} TP={tp_rank + 1} | " - f"GLOBAL RANK = {global_rank} >>>>>>>>>" - ) - params = [p for _, p in model.named_parameters()] - partitions[pp_rank].append(params) + logging.info(f"Loading ------------ PP Rank: {pp_rank} TP Rank: {tp_rank}") - # app_state is being updated incorrectly during restore - app_state.data_parallel_rank = 0 - app_state.pipeline_model_parallel_rank = pp_rank - app_state.tensor_model_parallel_rank = tp_rank - app_state.pipeline_model_parallel_size = pp_size - app_state.tensor_model_parallel_size = tp_size - app_state.model_parallel_size = ( - app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size - ) + # Override flag that forces Model to use AppState instead of Trainer + # to determine the world size, global and local rank + # Used for simulating load of a specific rank on a single gpu + os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true" + + # Compute the global rank to load the correct subset of parameters + global_rank = pp_rank * tp_size + tp_rank + + # Update AppState + app_state.world_size = world_size + app_state.global_rank = global_rank + app_state.local_rank = global_rank % num_gpu_per_node + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + app_state.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + if vp_size > 1: + set_virtual_parallel_rank_safely(vp_rank) + + if vp_rank == 0: + save_restore_connector = NLPSaveRestoreConnector() + + if args.model_extracted_dir is not None: + logging.info(f"Using extracted model directory: {args.model_extracted_dir}") + save_restore_connector.model_extracted_dir = args.model_extracted_dir + + if args.model_file is not None: + model_filepath = args.model_file + else: + model_filepath = args.model_extracted_dir + + if vp_size == 1: + + # Get model config + tmp_cfg = cls.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + # Force model onto CPU + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + + # Restore model + model = cls.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + model.freeze() + + # Restore model config + restore_model_config(model.cfg, restore_dict) + + else: + if args.ckpt_name is None: + raise ValueError( + "For Virtual Parallel, ckpt name is required.\n" + "Please provide `--ckpt_name` argument." + ) + + # inject model parallel rank + checkpoint_path = model_utils.inject_model_parallel_rank( + os.path.join(model_filepath, args.ckpt_name) + ) + + if hparams_filepath is not None: + # Force the model onto CPU + tmp_cfg = OmegaConf.load(hparams_filepath) + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.yml') as tmp: + OmegaConf.save(tmp_cfg, tmp, resolve=True) + tmp.seek(0) + + model = cls.load_from_checkpoint( + checkpoint_path=checkpoint_path, + trainer=trainer, + map_location=torch.device("cpu"), + hparams_file=tmp.name, + ) + model.freeze() + + restore_model_config(model.cfg, restore_dict) + + else: + model = cls.load_from_checkpoint( + checkpoint_path=checkpoint_path, trainer=trainer, map_location=torch.device("cpu"), + ) + model.freeze() + + model.to(dtype=dtype) + + # Reset env flag + os.environ.pop(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, None) + + logging.info( + f"<<<<<<<< LOADED MODEL PP={pp_rank + 1} TP={tp_rank + 1} | " + f"GLOBAL RANK = {global_rank} >>>>>>>>>" + ) + + # Save the parameters + if vp_size == 1: + params = [p for p in model.parameters()] + partitions[vp_rank][pp_rank].append(params) # vp_rank = 0 + + else: + vp_params_tmp = [] + for vp_idx in range(vp_size): + set_virtual_parallel_rank_safely(vp_idx) + params = [p for p in model.model[vp_idx].parameters()] + # params = model.model[vp_idx].module.state_dict_for_save_checkpoint() + # params = [p for p in params.values()] + vp_params_tmp.append(params) + # partitions[pp_rank][vp_idx].append(params) + + for vp_idx in range(vp_size): + partitions[vp_idx][pp_rank].append(vp_params_tmp[vp_idx]) + + del vp_params_tmp + set_virtual_parallel_rank_safely(0) + + # app_state is being updated incorrectly during restore + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = pp_rank + app_state.tensor_model_parallel_rank = tp_rank + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + if vp_size > 1: + app_state.virtual_pipeline_model_parallel_size = vp_size + set_virtual_parallel_rank_safely(vp_rank) # Build a unified model with PP 1 TP 1 with open_dict(model.cfg): model.cfg.tensor_model_parallel_size = 1 model.cfg.pipeline_model_parallel_size = 1 + model.cfg.virtual_pipeline_model_parallel_size = None + + app_state.global_rank = 0 + app_state.local_rank = 0 + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = 0 app_state.tensor_model_parallel_rank = 0 - app_state.pipeline_model_parallel_size = 0 + app_state.pipeline_model_parallel_size = 1 + app_state.tensor_model_parallel_size = 1 app_state.model_parallel_size = 1 + if vp_size > 1: + set_virtual_parallel_rank_safely(None) + trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision) with open_dict(model.cfg): @@ -930,25 +1126,52 @@ def main(): if args.tokenizer_vocab_file is not None: model.cfg.tokenizer.vocab_file = args.tokenizer_vocab_file - # temporarily - original_cpu_init = model.cfg.get('use_cpu_initialization', False) - original_amp_o2 = model.cfg.get('megatron_amp_O2', False) - model.cfg.use_cpu_initialization = True - model.cfg.megatron_amp_O2 = False + model.cfg, restore_dict = force_cpu_model(model.cfg) - model = cls(model.cfg, trainer) + # Remove Virtual Parallelism + model.cfg.virtual_pipeline_model_parallel_size = None + + logging.info(f"<<<<<<<< Building TP 1 PP 1 base model >>>>>>>>>") + model = cls(model.cfg, trainer) # type: nn.Module + model.freeze() model = model.to('cpu') model._save_restore_connector = NLPSaveRestoreConnector() + vp_param_count = 0 + for vp in range(vp_size): + for pp in range(pp_size): + for tp in range(tp_size): + vp_param_count += len(partitions[vp][pp][tp]) + + if vp_size > 1: + logging.debug(f"Total params in TP PP VP = 1 : {len(list(model.parameters()))}") + logging.debug(f"Total params in VP PP TP (og): {vp_param_count}") + + # Flatten Virtual Pipeline + if vp_size == 1: + # unpack vp container, pack pp tp container + partitions = partitions[0] + partitions = {idx: val for idx, val in enumerate(partitions)} + else: + flat_partitions = {idx: [] for idx in range(pp_size)} + + for pp in range(pp_size): + for tp in range(tp_size): + vp_cache = [] + for vp in range(vp_size): + vp_cache.extend(partitions[vp][pp][tp]) + + flat_partitions[pp].append(vp_cache) + + partitions = flat_partitions + if tgt_tp_size > 1 or tgt_pp_size > 1: merge_partition(model, partitions) else: # Write out the PP 1 TP 1 model to disk merge_partition(model, partitions, args.target_file) - with open_dict(model.cfg): - model.cfg.use_cpu_initialization = original_cpu_init - model.cfg.megatron_amp_O2 = original_amp_o2 + restore_model_config(model.cfg, restore_dict) # Empty cache memory of all parameters from all PP TP partitions partitions.clear() @@ -968,6 +1191,16 @@ def main(): else: model_filepath = args.model_extracted_dir + tmp_cfg = cls.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + model = cls.restore_from( restore_path=model_filepath, trainer=trainer, @@ -976,6 +1209,8 @@ def main(): ) model.to(dtype=dtype) + restore_model_config(model.cfg, restore_dict) + # If target model has TP > 1 or PP > 1 if tgt_pp_size > 1 or tgt_tp_size > 1: @@ -1046,10 +1281,16 @@ def main(): with open_dict(model.cfg): model.cfg.tokenizer.model = args.tokenizer_model_path - model = cls(model.cfg, trainer).to('cpu') + model.cfg, restore_dict = force_cpu_model(model.cfg) + + model = cls(model.cfg, trainer) + model = model.to('cpu') model._save_restore_connector = NLPSaveRestoreConnector() + model.freeze() model.to(dtype=dtype) + restore_model_config(model.cfg, restore_dict) + # Update global batch size if old_global_batch_size % new_global_batch_size != 0 or old_global_batch_size < new_global_batch_size: logging.info( From 5468077f5127be1a4c88065de2544f4268b9a6e4 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 1 May 2023 09:49:25 -0700 Subject: [PATCH 07/10] Fix typos (#6523) * Fix typos Signed-off-by: smajumdar * Fix typos Signed-off-by: smajumdar --------- Signed-off-by: smajumdar --- tutorials/asr/ASR_CTC_Language_Finetuning.ipynb | 10 +++++----- tutorials/asr/ASR_with_Subword_Tokenization.ipynb | 4 ++-- tutorials/asr/Buffered_Transducer_Inference.ipynb | 2 +- .../asr/Online_Offline_Speech_Commands_Demo.ipynb | 2 +- tutorials/asr/Streaming_ASR.ipynb | 2 +- tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb index 7541ff33db5a..b9c0db866f9c 100644 --- a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb +++ b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb @@ -540,8 +540,8 @@ "import matplotlib.pyplot as plt\n", "\n", "plt.bar(x=TOKEN_COUNT_X, height=NUM_TOKENS_Y)\n", - "plt.title(\"Occurance of unique tokens in train+dev set\")\n", - "plt.xlabel(\"# of occurances\")\n", + "plt.title(\"Occurrences of unique tokens in train+dev set\")\n", + "plt.xlabel(\"# of occurrences\")\n", "plt.ylabel(\"# of tokens\")\n", "plt.xlim(0, MAX_COUNT);" ], @@ -565,13 +565,13 @@ "source": [ "UNCOMMON_TOKENS_COUNT = 5\n", "\n", - "chars_with_infrequent_occurance = set()\n", + "chars_with_infrequent_occurrence = set()\n", "for count in range(1, UNCOMMON_TOKENS_COUNT + 1):\n", " if count in train_counts:\n", " token_list = train_counts[count]\n", - " chars_with_infrequent_occurance.update(set(token_list))\n", + " chars_with_infrequent_occurrence.update(set(token_list))\n", "\n", - "print(f\"Number of tokens with <= {UNCOMMON_TOKENS_COUNT} occurances : {len(chars_with_infrequent_occurance)}\")" + "print(f\"Number of tokens with <= {UNCOMMON_TOKENS_COUNT} occurrences : {len(chars_with_infrequent_occurrence)}\")" ], "execution_count": null, "outputs": [] diff --git a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb index f398fc16ce1b..b932916f2bc5 100644 --- a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb +++ b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb @@ -312,7 +312,7 @@ "\r\n", " - Sophisticated subword tokenization algorithms build their vocabularies based on large text corpora. To accurately tokenize such large volumes of text with minimal vocabulary size, the subwords that are learned inherently model the interdependency between tokens of that language to some degree. \r\n", " \r\n", - "Looking at the previous example, the token `hel##` is a single token that represents the relationship `h` => `e` => `l`. When the model predicts the singe token `hel##`, it implicitly predicts this relationship - even though the subsequent token can be either `l` (for `hell`) or `##lo` (for `hello`) and is predicted independently of the previous token!\r\n", + "Looking at the previous example, the token `hel##` is a single token that represents the relationship `h` => `e` => `l`. When the model predicts the single token `hel##`, it implicitly predicts this relationship - even though the subsequent token can be either `l` (for `hell`) or `##lo` (for `hello`) and is predicted independently of the previous token!\r\n", "\r\n", " - By reducing the target sentence length by subword tokenization (target sentence here being the characters/subwords transcribed from the audio signal), we entirely sidestep the sequence length limitation of CTC loss!\r\n", "\r\n", @@ -554,7 +554,7 @@ "\r\n", " - `--spe_sample_size`: If the dataset is too large, consider using a sampled dataset indicated by a positive integer. By default, any negative value (default = -1) will use the entire dataset.\r\n", "\r\n", - " - `--spe_train_extremely_large_corpus`: When training a sentencepiece tokenizer on very large amounts of text, sometimes the tokenizer will run out of memory or wont be able to process so much data on RAM. At some point you might receive the following error - \"Input corpus too large, try with train_extremely_large_corpus=true\". If your machine has large amounts of RAM, it might still be possible to build the tokenizer using the above flag. Will silently fail if it runs out of RAM.\r\n", + " - `--spe_train_extremely_large_corpus`: When training a sentencepiece tokenizer on very large amounts of text, sometimes the tokenizer will run out of memory or won't be able to process so much data on RAM. At some point you might receive the following error - \"Input corpus too large, try with train_extremely_large_corpus=true\". If your machine has large amounts of RAM, it might still be possible to build the tokenizer using the above flag. Will silently fail if it runs out of RAM.\r\n", "\r\n", " - `--log`: Whether the script should display log messages" ] diff --git a/tutorials/asr/Buffered_Transducer_Inference.ipynb b/tutorials/asr/Buffered_Transducer_Inference.ipynb index 045e36e45ae8..c23398dca46a 100644 --- a/tutorials/asr/Buffered_Transducer_Inference.ipynb +++ b/tutorials/asr/Buffered_Transducer_Inference.ipynb @@ -806,7 +806,7 @@ " print(\"\\nGreedy labels collected from this buffer\")\n", " print(tok[len(tok) - 1 - delay:len(tok) - 1 - delay + tokens_per_chunk]) \n", " self.toks_unmerged += tok[len(tok) - 1 - delay:len(tok) - 1 - delay + tokens_per_chunk]\n", - " print(\"\\nTokens collected from succesive buffers before RNNT merge\")\n", + " print(\"\\nTokens collected from successive buffers before RNNT merge\")\n", " print(self.toks_unmerged)\n", "\n", " output = []\n", diff --git a/tutorials/asr/Online_Offline_Speech_Commands_Demo.ipynb b/tutorials/asr/Online_Offline_Speech_Commands_Demo.ipynb index 2248ddac7417..c704ee1145c3 100644 --- a/tutorials/asr/Online_Offline_Speech_Commands_Demo.ipynb +++ b/tutorials/asr/Online_Offline_Speech_Commands_Demo.ipynb @@ -440,7 +440,7 @@ " Arg:\n", " wav_file: wave file to be performed inference on.\n", " STEP: infer every STEP seconds \n", - " WINDOW_SIZE : lenght of audio to be sent to NN.\n", + " WINDOW_SIZE : length of audio to be sent to NN.\n", " \"\"\"\n", " \n", " FRAME_LEN = STEP \n", diff --git a/tutorials/asr/Streaming_ASR.ipynb b/tutorials/asr/Streaming_ASR.ipynb index da44d33f68f0..a4701dc025d8 100644 --- a/tutorials/asr/Streaming_ASR.ipynb +++ b/tutorials/asr/Streaming_ASR.ipynb @@ -538,7 +538,7 @@ " print(\"\\nGreedy labels collected from this buffer\")\n", " print(tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk]) \n", " self.toks_unmerged += tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk]\n", - " print(\"\\nTokens collected from succesive buffers before CTC merge\")\n", + " print(\"\\nTokens collected from successive buffers before CTC merge\")\n", " print(self.toks_unmerged)\n", "\n", "\n", diff --git a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb index 6a3980045c81..62481c3762d2 100644 --- a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb @@ -665,7 +665,7 @@ "\n", "For this experiment we will continue to use the original spec augmentation config in the base model, however you may find better results by modifying the strength of this augmentation.\n", "\n", - "**Note**: The script inside ASR examples **disables spec augment entirely**. This is done in order to provide a stable default to measure the best possible adaptation case, but may severely degrade the performance on general speech. Please be careful when copying the hyper parameters from the tutorial to the script for large scale experimentatin." + "**Note**: The script inside ASR examples **disables spec augment entirely**. This is done in order to provide a stable default to measure the best possible adaptation case, but may severely degrade the performance on general speech. Please be careful when copying the hyper parameters from the tutorial to the script for large scale experimentation." ], "metadata": { "id": "T3VuqcGTNuIJ" @@ -804,7 +804,7 @@ "source": [ "-----\n", "\n", - "As you can see, a single component of the model may support one or more adapter types (or none at all)! Below, we will experiment with the simple Linear Adapters, but as an excercise, you might try to use other adapter types present here." + "As you can see, a single component of the model may support one or more adapter types (or none at all)! Below, we will experiment with the simple Linear Adapters, but as an exercise, you might try to use other adapter types present here." ], "metadata": { "id": "YXTC4LiSnB2O" From a7bf6cf0d4a27544d4a9b5d36eb3ac47f689972d Mon Sep 17 00:00:00 2001 From: trias702 <25867060+trias702@users.noreply.github.com> Date: Tue, 2 May 2023 11:06:59 -0500 Subject: [PATCH 08/10] New noise_norm perturbation based on Riva work (#6445) * Initial commit for new noise_norm perturbation Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix to random seed in perturb Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updated code to reflect feedback Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updates for feedback given by code reviewers Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updates in response to PR feedback Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added comment about ref_mic being None Signed-off-by: Daniel Egert * Updated perturb to use inspect module Signed-off-by: Daniel Egert --------- Signed-off-by: Daniel Egert Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nemo/collections/asr/data/audio_to_text.py | 10 +- .../asr/data/audio_to_text_dataset.py | 4 +- .../asr/parts/preprocessing/__init__.py | 1 + .../asr/parts/preprocessing/perturb.py | 344 +++++++++++++++--- nemo/utils/model_utils.py | 2 +- 5 files changed, 304 insertions(+), 57 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index d61f0e1f69ef..3b2e2a767a97 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -191,8 +191,8 @@ def expand_sharded_filepaths(sharded_filepaths, shard_strategy: str, world_size: sharded_filepaths = sharded_filepaths.replace(bkey, "}") if isinstance(sharded_filepaths, str): - # Brace expand - sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths)) + # Brace expand, set escape=False for Windows compatibility + sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths, escape=False)) # Expand store paths into WebDataset URLs sharded_filepaths = [ @@ -1359,5 +1359,9 @@ def __iter__(self): for dataset_idx in shuffled_order: d = self.datasets[dataset_idx] assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" - for x in d: + for idx, x in enumerate(d): yield x + # in case d is an infinite dataset, we want to break the loop + # so that the other datasets get a chance to yield too + if idx >= len(d) - 1: + break diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 325857e81323..14e8dea19651 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -512,7 +512,7 @@ def get_audio_to_text_char_dataset_from_config( constructed dataset or None if dataset config is invalid or nothing to load """ if 'augmentor' in config: - augmentor = process_augmentations(config['augmentor']) + augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size) else: augmentor = None @@ -609,7 +609,7 @@ def get_audio_to_text_bpe_dataset_from_config( constructed dataset or None if dataset config is invalid or nothing to load """ if 'augmentor' in config: - augmentor = process_augmentations(config['augmentor']) + augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size) else: augmentor = None diff --git a/nemo/collections/asr/parts/preprocessing/__init__.py b/nemo/collections/asr/parts/preprocessing/__init__.py index b25f0ff25e42..a0785c56bf2a 100644 --- a/nemo/collections/asr/parts/preprocessing/__init__.py +++ b/nemo/collections/asr/parts/preprocessing/__init__.py @@ -20,6 +20,7 @@ GainPerturbation, ImpulsePerturbation, NoisePerturbation, + NoisePerturbationWithNormalization, Perturbation, RirAndNoisePerturbation, ShiftPerturbation, diff --git a/nemo/collections/asr/parts/preprocessing/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py index d4b1944ec6a2..d882bc83772b 100644 --- a/nemo/collections/asr/parts/preprocessing/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -33,6 +33,7 @@ # SOFTWARE. # This file contains code artifacts adapted from https://github.com/ryanleary/patter import copy +import inspect import io import os import random @@ -44,10 +45,10 @@ import numpy as np import soundfile as sf from scipy import signal -from torch.utils.data import IterableDataset from nemo.collections.asr.parts.preprocessing.segment import AudioSegment from nemo.collections.common.parts.preprocessing import collections, parsers +from nemo.core.classes import IterableDataset from nemo.utils import logging # TODO @blisc: Perhaps refactor instead of import guarding @@ -69,16 +70,11 @@ HAVE_NUMBA = False -def read_one_audiosegment(manifest, target_sr, rng=None, tarred_audio=False, audio_dataset=None): - - random.seed(rng) if rng else None - +def read_one_audiosegment(manifest, target_sr, tarred_audio=False, audio_dataset=None): if tarred_audio: if audio_dataset is None: raise TypeError("Expected augmentation dataset but got None") - audio_file, file_id = next(audio_dataset) - manifest_idx = manifest.mapping[file_id] - manifest_entry = manifest[manifest_idx] + audio_file, file_id, manifest_entry = next(audio_dataset) offset = 0 if manifest_entry.offset is None else manifest_entry.offset duration = 0 if manifest_entry.duration is None else manifest_entry.duration @@ -375,11 +371,7 @@ def __init__( def perturb(self, data): impulse = read_one_audiosegment( - self._manifest, - data.sample_rate, - self._rng, - tarred_audio=self._tarred_audio, - audio_dataset=self._data_iterator, + self._manifest, data.sample_rate, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator, ) # normalize if necessary @@ -491,7 +483,7 @@ def orig_sr(self): def get_one_noise_sample(self, target_sr): return read_one_audiosegment( - self._manifest, target_sr, self._rng, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator + self._manifest, target_sr, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator ) def perturb(self, data, ref_mic=0): @@ -501,11 +493,7 @@ def perturb(self, data, ref_mic=0): ref_mic (int): reference mic index for scaling multi-channel audios """ noise = read_one_audiosegment( - self._manifest, - data.sample_rate, - self._rng, - tarred_audio=self._tarred_audio, - audio_dataset=self._data_iterator, + self._manifest, data.sample_rate, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator, ) self.perturb_with_input_noise(data, noise, ref_mic=ref_mic) @@ -600,6 +588,223 @@ def perturb_with_foreground_noise(self, data, noise, data_rms=None, max_noise_du data._samples[noise_idx : noise_idx + noise_samples.shape[0]] += noise_samples +class NoisePerturbationWithNormalization(Perturbation): + """ + Perturbation that adds noise to input audio, with normalisation to specific decibel level. + Also tiles shorter noise samples up to their corresponding clean audio length. + + Args: + manifest_path (str or list): Manifest file with paths to noise files, can be list if using multiple noise sources + min_snr_db (float): Minimum SNR of audio after noise is added + max_snr_db (float): Maximum SNR of audio after noise is added + snr_samples (list): A discrete list of SNRs DBs to sample from when mixing, will be used instead of [min_snr_db,max_snr_db] + norm_to_db (float): Will normalise clean, noise, and mixed samples to this DB + audio_tar_filepaths (str or list) : Tar files, if noise audio files are tarred, can be list for multiple sources + shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files + orig_sr (int): Original sampling rate of the noise files + rng (int): Random seed. Default is None + shard_strategy (str): if you're using tarred audio and wish to scatter instead of replicate, set this to 'scatter' + epsilon (float): minimum value for RMS DB normalisation to avoid divide by zero + """ + + def __init__( + self, + manifest_path=None, + min_snr_db=10, + max_snr_db=50, + snr_samples=None, + norm_to_db=None, + rng=None, + audio_tar_filepaths=None, + shuffle_n=128, + orig_sr=16000, + global_rank=0, + world_size=1, + shard_strategy='replicate', + epsilon=0.01, + ): + # import here to avoid circular import error + from nemo.collections.asr.data.audio_to_text import RandomizedChainDataset + + self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) + self._audiodataset = None + self._tarred_audio = False + self._orig_sr = orig_sr + self._data_iterator = None + + random.seed(rng) if rng else None + self._rng = rng + + if audio_tar_filepaths: + self._tarred_audio = True + if isinstance(manifest_path, str): + manifest_path = [manifest_path] + if isinstance(audio_tar_filepaths, str): + audio_tar_filepaths = [audio_tar_filepaths] + datasets = [] + for tarred_audio_filepath, manifest_filepath in zip(audio_tar_filepaths, manifest_path): + dataset = AugmentationDataset( + manifest_filepath, + tarred_audio_filepath, + shuffle_n, + rank=global_rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + datasets.append(dataset) + self._audiodataset = RandomizedChainDataset( + datasets, rnd_seed=(rng if rng else random.randint(0, 30000)) + global_rank + ) + if len(self._audiodataset) == 0: + raise RuntimeError( + "NoisePerturbationWithNormalization detected a zero length RandomizedChainDataset, should never happen" + ) + self._data_iterator = iter(self._audiodataset) + + self._min_snr_db = min_snr_db + self._max_snr_db = max_snr_db + self._norm_to_db = norm_to_db + self._snr_samples = snr_samples if isinstance(snr_samples, list) and len(snr_samples) > 0 else None + self._epsilon = epsilon + + @property + def orig_sr(self): + return self._orig_sr + + def read_one_audiosegment(self, target_sr): + if self._tarred_audio: + if self._data_iterator is None: + raise TypeError("Expected valid iterator but got None") + try: + audio_file, file_id, manifest_entry = next(self._data_iterator) + except StopIteration: + self._data_iterator = iter(self._audiodataset) + audio_file, file_id, manifest_entry = next(self._data_iterator) + + offset = 0 if manifest_entry.offset is None else manifest_entry.offset + duration = 0 if manifest_entry.duration is None else manifest_entry.duration + + else: + audio_record = random.sample(self._manifest.data, 1)[0] + audio_file = audio_record.audio_file + offset = 0 if audio_record.offset is None else audio_record.offset + duration = 0 if audio_record.duration is None else audio_record.duration + + return AudioSegment.from_file(audio_file, target_sr=target_sr, offset=offset, duration=duration) + + def perturb(self, data, ref_mic=0): + """ + Args: + data (AudioSegment): audio data + ref_mic (int): reference mic index for scaling multi-channel audios + """ + + noise = self.read_one_audiosegment(data.sample_rate) + + # noise samples need to be at least 1 second long to avoid strange oddities + # in the RMS SNR mixing, so we have a fail-safe here to ensure at least 1 sec duration + while noise.duration < 1: + noise = self.read_one_audiosegment(data.sample_rate) + + self.perturb_with_input_noise(data, noise, ref_mic=ref_mic, norm_to_db=self._norm_to_db) + + def snr_mixer(self, clean, noise, snr, norm_to_db=-25.0): + """ + Mixes the clean audio with the noise + Args: + clean (numpy array): the clean audio data + noise (numpy array): the noise audio data + snr (float): the SNR value for the mixing + norm_to_db (float): the DB value to normalise to before mixing + """ + clean = self.norm_audio_to_db(clean, norm_to_db) + noise = self.norm_audio_to_db(noise, norm_to_db) + + # Set the noise level for a given SNR + # note that if your noise doesn't overlap with your audio then your target SNR + # may not be achievable. Consider using an rms-threshold in the future + noisescalar = 10 ** (-snr / 20.0) + noisenewlevel = noise * noisescalar + noisyspeech = clean + noisenewlevel + + return clean, noisenewlevel, noisyspeech + + def norm_audio_to_db(self, x, norm_to_db): + """ + Normalises audio signal to particular db, with some epsilon in-case of divide by zero + Args: + x (numpy array): input audio signal + norm_to_db (float): the db to normalise to + """ + rms = (x ** 2).mean(axis=0) ** 0.5 + rms = np.where(np.isclose(rms, 0), self._epsilon, rms) + scalar = 10 ** (norm_to_db / 20.0) / rms + return x * scalar + + def concatenate_noise_sample(self, clean, noise, fs, silence_length=0.25): + """ + Tiles the noise array to match the clean audio array, with small silence between the joins + Args: + clean (numpy array): clean audio data + noise (numpy array): noise audio data + fs (int): sample rate used by both clean and noise audio data + silence_length (float): the amount of silence (in secs) to insert before tiling + """ + while len(noise) < len(clean): + if noise.ndim > 1: + zeros = np.zeros((int(fs * silence_length), noise.shape[-1])) + else: + zeros = np.zeros((int(fs * silence_length),)) + noiseconcat = np.append(noise, zeros, axis=0) + noise = np.append(noiseconcat, noise, axis=0) + + return noise + + def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0, norm_to_db=-25.0): + """ + Args: + data (AudioSegment): audio data + noise (AudioSegment): noise data + data_rms (Union[float, List[float]): rms_db for data input + ref_mic (int): reference mic index for scaling multi-channel audio, if set to None then + each channel will be scaled independently + norm_to_db (float): will normalise all audio to this DB + """ + if data.num_channels != noise.num_channels: + raise ValueError( + f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})." + ) + + if not (0 <= ref_mic < data.num_channels): + raise ValueError( + f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead." + ) + + if self._snr_samples: + snr_db = random.sample(self._snr_samples, 1)[0] + else: + snr_db = random.uniform(self._min_snr_db, self._max_snr_db) + if data_rms is None: + data_rms = data.rms_db if ref_mic is None else data.rms_db[ref_mic] + + if norm_to_db is None: + norm_to_db = data_rms + + data_norm = data._samples + noise_norm = noise._samples + + if len(data_norm) == 0: + return + + if len(noise_norm) < len(data_norm): + noise_norm = self.concatenate_noise_sample(data_norm, noise_norm, data.sample_rate) + noise_norm = noise_norm[0 : len(data_norm)] + + _, _, noisy_snr = self.snr_mixer(clean=data_norm, noise=noise_norm, snr=snr_db, norm_to_db=norm_to_db) + + data._samples = noisy_snr + + class WhiteNoisePerturbation(Perturbation): """ Perturbation that adds white noise to an audio file in the training dataset. @@ -857,6 +1062,7 @@ def perturb(self, data): "impulse": ImpulsePerturbation, "shift": ShiftPerturbation, "noise": NoisePerturbation, + "noise_norm": NoisePerturbationWithNormalization, "white_noise": WhiteNoisePerturbation, "rir_noise_aug": RirAndNoisePerturbation, "transcode_aug": TranscodePerturbation, @@ -902,7 +1108,7 @@ def from_config(cls, config): return cls(perturbations=ptbs) -def process_augmentations(augmenter) -> Optional[AudioAugmentor]: +def process_augmentations(augmenter, global_rank=0, world_size=1) -> Optional[AudioAugmentor]: """Process list of online data augmentations. Accepts either an AudioAugmentor object with pre-defined augmentations, or a dictionary that points to augmentations that have been defined. @@ -1016,7 +1222,12 @@ class CustomPerturbation(perturb.Perturbation): raise ValueError("`prob` must be a float value between 0 and 1.") try: - augmentation = perturbation_types[augment_name](**augment_kwargs) + augmentation_class = perturbation_types[augment_name] + if 'global_rank' in inspect.signature(augmentation_class).parameters: + augment_kwargs['global_rank'] = global_rank + if 'world_size' in inspect.signature(augmentation_class).parameters: + augment_kwargs['world_size'] = world_size + augmentation = augmentation_class(**augment_kwargs) augmentations.append([prob, augmentation]) except KeyError: raise KeyError(f"Invalid perturbation name. Allowed values : {perturbation_types.keys()}") @@ -1028,40 +1239,38 @@ class CustomPerturbation(perturb.Perturbation): class AugmentationDataset(IterableDataset): """ A class that loads tarred audio files and cycles over the files in the dataset. - Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset), as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should contain the information for one audio file, including at least the transcript and name of the audio file within the tarball. - Valid formats for the audio_tar_filepaths argument include: (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. - Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. Supported opening braces - { <=> (, [, < and the special tag _OP_. Supported closing braces - } <=> ), ], > and the special tag _CL_. For SLURM based tasks, we suggest the use of the special tags for ease of use. - See the WebDataset documentation for more information about accepted data and input formats. """ - def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128): - self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) + def __init__( + self, + manifest_path: str, + tar_filepaths: Union[str, List[str]], + shuffle_n: int = 128, + rank: int = 0, + world_size: int = 1, + shard_strategy: str = "replicate", + ): + # import here to avoid circular import error + from nemo.collections.asr.data.audio_to_text import expand_sharded_filepaths - if isinstance(tar_filepaths, str): - # Replace '(' and '[' with '{' - brace_keys_open = ['(', '[', '<', '_OP_'] - for bkey in brace_keys_open: - if bkey in tar_filepaths: - tar_filepaths = tar_filepaths.replace(bkey, "{") + self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) - # Replace ')' and ']' with '}' - brace_keys_close = [')', ']', '>', '_CL_'] - for bkey in brace_keys_close: - if bkey in tar_filepaths: - tar_filepaths = tar_filepaths.replace(bkey, "}") + tar_filepaths = expand_sharded_filepaths( + tar_filepaths, shard_strategy=shard_strategy, world_size=world_size, global_rank=rank + ) if not HAVE_OMEGACONG_WEBDATASET: raise LightningNotInstalledException(self) @@ -1072,25 +1281,58 @@ def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shu else: logging.info("WebDataset will not shuffle files within the tar files.") - self.audio_dataset = self.audio_dataset.rename(audio='wav', key='__key__').to_tuple('audio', 'key') - self.audio_iter = iter(self.audio_dataset) + self.audio_dataset = ( + self.audio_dataset.rename(audio='wav;ogg;flac', key='__key__') + .to_tuple('audio', 'key') + .pipe(self._loop_offsets) + ) def __len__(self): return len(self._manifest) + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file. + """ + + class TarredAudioLoopOffsets: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_bytes, self.current_fn, self.offset_id + + return TarredAudioLoopOffsets(self._manifest) + def __iter__(self): - return self + audio_iter = iter(self.audio_dataset) - def __next__(self): while True: try: - audio_bytes, audio_filename = next(self.audio_iter) - + audio_bytes, audio_filename, offset_id = next(audio_iter) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self._manifest.mapping[file_id][offset_id] + manifest_entry = self._manifest[manifest_idx] + + # Convert audio bytes to IO stream for processing (for SoundFile to read) + audio_file = io.BytesIO(audio_bytes) + yield audio_file, file_id, manifest_entry except StopIteration: - self.audio_iter = iter(self.audio_dataset) - audio_bytes, audio_filename = next(self.audio_iter) - file_id, _ = os.path.splitext(os.path.basename(audio_filename)) - - # Convert audio bytes to IO stream for processing (for SoundFile to read) - audio_file = io.BytesIO(audio_bytes) - return audio_file, file_id + audio_iter = iter(self.audio_dataset) diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 45fabceb4a91..211ffdcdf11e 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -256,7 +256,7 @@ def resolve_validation_dataloaders(model: 'ModelPT'): ds_key = resolve_dataset_name_from_cfg(cfg.validation_ds) - if ds_key is None: + if ds_key is None or val_dl_idx < 0: logging.debug( "Could not resolve file path from provided config - {}. " "Disabling support for multi-dataloaders.".format(cfg.validation_ds) From 4942dcfe86e0597944f342dd2718a104f094c477 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Tue, 2 May 2023 09:10:03 -0700 Subject: [PATCH 09/10] [TTS] Add script for computing feature stats (#6508) * [TTS] Add script for computing feature stats Signed-off-by: Ryan * [TTS] Add overwrite config Signed-off-by: Ryan --------- Signed-off-by: Ryan --- .../tts/compute_feature_stats.py | 196 ++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 scripts/dataset_processing/tts/compute_feature_stats.py diff --git a/scripts/dataset_processing/tts/compute_feature_stats.py b/scripts/dataset_processing/tts/compute_feature_stats.py new file mode 100644 index 000000000000..6774563810d9 --- /dev/null +++ b/scripts/dataset_processing/tts/compute_feature_stats.py @@ -0,0 +1,196 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is to compute global and speaker-level feature statistics for a given TTS training manifest. + +This script should be run after compute_features.py as it loads the precomputed feature data. + +$ python /scripts/dataset_processing/tts/compute_feature_stats.py \ + --feature_config_path=/examples/tts/conf/features/feature_22050.yaml + --manifest_path=/manifest.json \ + --audio_dir=/audio \ + --feature_dir=/features \ + --stats_path=/feature_stats.json + +The output dictionary will contain the feature statistics for every speaker, as well as a "default" entry +with the global statistics. + +For example: + +{ + "default": { + "pitch_mean": 100.0, + "pitch_std": 50.0, + "energy_mean": 7.5, + "energy_std": 4.5 + }, + "speaker1": { + "pitch_mean": 105.0, + "pitch_std": 45.0, + "energy_mean": 7.0, + "energy_std": 5.0 + }, + "speaker2": { + "pitch_mean": 110.0, + "pitch_std": 30.0, + "energy_mean": 5.0, + "energy_std": 2.5 + } +} + +""" + +import argparse +import json +from collections import defaultdict +from pathlib import Path +from typing import List, Tuple + +import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm import tqdm + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute TTS feature statistics.", + ) + parser.add_argument( + "--feature_config_path", required=True, type=Path, help="Path to feature config file.", + ) + parser.add_argument( + "--manifest_path", required=True, type=Path, help="Path to training manifest.", + ) + parser.add_argument( + "--audio_dir", required=True, type=Path, help="Path to base directory with audio data.", + ) + parser.add_argument( + "--feature_dir", required=True, type=Path, help="Path to directory where feature data was stored.", + ) + parser.add_argument( + "--feature_names", default="pitch,energy", type=str, help="Comma separated list of features to process.", + ) + parser.add_argument( + "--mask_field", + default="voiced_mask", + type=str, + help="If provided, stat computation will ignore non-masked frames.", + ) + parser.add_argument( + "--stats_path", + default=Path("feature_stats.json"), + type=Path, + help="Path to output JSON file with dataset feature statistics.", + ) + parser.add_argument( + "--overwrite", default=False, type=bool, help="Whether to overwrite the output stats file if it exists.", + ) + + args = parser.parse_args() + return args + + +def _compute_stats(values: List[torch.Tensor]) -> Tuple[float, float]: + values_tensor = torch.cat(values, dim=0) + mean = values_tensor.mean().item() + std = values_tensor.std(dim=0).item() + return mean, std + + +def main(): + args = get_args() + + feature_config_path = args.feature_config_path + manifest_path = args.manifest_path + audio_dir = args.audio_dir + feature_dir = args.feature_dir + feature_name_str = args.feature_names + mask_field = args.mask_field + stats_path = args.stats_path + overwrite = args.overwrite + + if not manifest_path.exists(): + raise ValueError(f"Manifest {manifest_path} does not exist.") + + if not audio_dir.exists(): + raise ValueError(f"Audio directory {audio_dir} does not exist.") + + if not feature_dir.exists(): + raise ValueError( + f"Feature directory {audio_dir} does not exist. " + f"Please check that the path is correct and that you ran compute_features.py" + ) + + if stats_path.exists(): + if overwrite: + print(f"Will overwrite existing stats path: {stats_path}") + else: + raise ValueError(f"Stats path already exists: {stats_path}") + + feature_config = OmegaConf.load(feature_config_path) + feature_config = instantiate(feature_config) + featurizer_dict = feature_config.featurizers + + print(f"Found featurizers for {list(featurizer_dict.keys())}.") + featurizers = featurizer_dict.values() + + feature_names = feature_name_str.split(",") + # For each feature, we have a dictionary mapping speaker IDs to a list containing all features + # for that speaker + feature_stats = {name: defaultdict(list) for name in feature_names} + + entries = read_manifest(manifest_path) + + for entry in tqdm(entries): + speaker = entry["speaker"] + + entry_dict = {} + for featurizer in featurizers: + feature_dict = featurizer.load(manifest_entry=entry, audio_dir=audio_dir, feature_dir=feature_dir) + entry_dict.update(feature_dict) + + if mask_field: + mask = entry_dict[mask_field] + else: + mask = None + + for feature_name in feature_names: + values = entry_dict[feature_name] + if mask is not None: + values = values[mask] + + feature_stat_dict = feature_stats[feature_name] + feature_stat_dict["default"].append(values) + feature_stat_dict[speaker].append(values) + + stat_dict = defaultdict(dict) + for feature_name in feature_names: + mean_key = f"{feature_name}_mean" + std_key = f"{feature_name}_std" + feature_stat_dict = feature_stats[feature_name] + for speaker_id, values in feature_stat_dict.items(): + speaker_mean, speaker_std = _compute_stats(values) + stat_dict[speaker_id][mean_key] = speaker_mean + stat_dict[speaker_id][std_key] = speaker_std + + with open(stats_path, 'w', encoding="utf-8") as stats_f: + json.dump(stat_dict, stats_f, indent=4) + + +if __name__ == "__main__": + main() From 908aa67db3434f0dcedceb9625fb34899a91888d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 May 2023 19:04:55 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/tts/modules/submodules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 72d853f3d3e7..eb44fb25fcf6 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -742,7 +742,7 @@ def output_types(self): return { "embs": NeuralType(('B', 'D'), EncodedRepresentation()), } - + def overwrite_precomputed_emb(self, emb): self.precomputed_emb = torch.nn.Parameter(emb)