From c733fc59483cf8f4e82364bedd81ecfba91f8b1f Mon Sep 17 00:00:00 2001
From: sam1373
Date: Wed, 23 Feb 2022 16:20:57 -0800
Subject: [PATCH 1/6] ssl update

Signed-off-by: sam1373
---
 .../conf/ssl/citrinet/citrinet_ssl_1024.yaml  |  35 +++---
 .../conf/ssl/citrinet/citrinet_ssl_ci.yaml    |  27 ++--
 .../asr/conf/ssl/conformer/conformer_ssl.yaml |  32 ++---
 .../conf/ssl/contextnet/contextnet_ssl.yaml   |  26 ++--
 .../asr/losses/pt_losses/contrastive.py       | 119 ++++++++++++------
 nemo/collections/asr/modules/__init__.py      |   1 +
 .../asr/modules/audio_preprocessing.py        |  68 ++++++++++
 7 files changed, 218 insertions(+), 90 deletions(-)

diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml
index 93800c9455a7..1e8289842561 100644
--- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml
+++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml
@@ -21,8 +21,8 @@ model:
     sample_rate: 16000
     batch_size: 32
     trim_silence: false
-    max_duration: 35.0
-    min_duration: 3.0
+    max_duration: 16.7
+    min_duration: 8.0
     shuffle: true
     is_tarred: false
     tarred_audio_filepaths: null
@@ -39,7 +39,7 @@ model:
     shuffle: false
     use_start_end_token: false
     max_duration: 35.0
-    min_duration: 3.0
+    min_duration: 8.0

   model_defaults:
     repeat: 5
@@ -68,11 +68,11 @@ model:
     stft_conv: false

   spec_augment:
-    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
-    freq_masks: 4
-    time_masks: 10
-    freq_width: 27
-    time_width: 0.05
+    _target_: nemo.collections.asr.modules.MaskedPatchAugmentation
+    freq_masks: 3
+    freq_width: 20
+    patch_size: 48
+    mask_patches: 10

   encoder:
     _target_: nemo.collections.asr.modules.ConvASREncoder
@@ -402,22 +402,24 @@ model:
         se_context_size: ${model.model_defaults.se_context_size}
         kernel_size_factor: ${model.model_defaults.kernel_size_factor}

-
   decoder:
     _target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
     feat_in: ${model.model_defaults.enc_final}
     feat_hidden: 128
     feat_out: ${model.model_defaults.decoder_out_channels}
     stride_layers: 1
-
+    #if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
+    #be added to the decoder (here stride is 8 and combine_time_steps is 4, so 1 stride layer is added)
+    non_stride_layers: 0

   loss:
     _target_: nemo.collections.asr.losses.ContrastiveLoss
     in_dim: *n_mels
     proj_dim: ${model.model_defaults.decoder_out_channels}
-    combine_time_steps: 4
-    codebook_size: 1200
-
+    combine_time_steps: 4 #how many spectrogram time steps are used for one target/representation for contrastive task
+    quantized_targets: false #should quantizer or linear layer be used
+    sample_from_same_utterance_only: true #should negatives be sampled only from the same utterance
+    sample_from_non_masked: false #should negatives be sampled from non-masked steps

   optim:
     name: novograd
@@ -438,13 +440,14 @@ model:
     last_epoch: -1

 trainer:
-  gpus: 0 # number of gpus
+  devices: 1 # number of gpus
   max_epochs: 100
   max_steps: null # computed at runtime if not set
   num_nodes: 1
-  accelerator: ddp
+  accelerator: gpu
+  strategy: ddp
   accumulate_grad_batches: 1
-  checkpoint_callback: false # Provided by exp_manager
+  enable_checkpointing: False # Provided by exp_manager
   logger: false # Provided by exp_manager
   log_every_n_steps: 100 # Interval of logging.
   val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml
index 205cf80c6c15..4c2a55ac1113 100644
--- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml
+++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml
@@ -11,7 +11,7 @@ model:
     batch_size: 32
     trim_silence: false
     max_duration: 35.0
-    min_duration: 3.0
+    min_duration: 4.0
     shuffle: true
     is_tarred: false
     tarred_audio_filepaths: null
@@ -27,7 +27,7 @@ model:
     shuffle: false
     use_start_end_token: false
     max_duration: 35.0
-    min_duration: 3.0
+    min_duration: 4.0

   model_defaults:
     repeat: 1
@@ -55,11 +55,11 @@ model:
     stft_conv: false

   spec_augment:
-    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
-    freq_masks: 4
-    time_masks: 10
-    freq_width: 27
-    time_width: 0.05
+    _target_: nemo.collections.asr.modules.MaskedPatchAugmentation
+    freq_masks: 3
+    freq_width: 20
+    patch_size: 16
+    mask_patches: 10

   encoder:
     _target_: nemo.collections.asr.modules.ConvASREncoder
@@ -403,8 +403,10 @@ model:
     in_dim: *n_mels
     proj_dim: ${model.model_defaults.decoder_out_channels}
     combine_time_steps: 4
-    codebook_size: 300
-
+    sample_from_non_masked: false
+    num_negatives: 30
+    quantized_targets: false
+    sample_from_same_utterance_only: true

   optim:
     name: novograd
@@ -425,13 +427,14 @@ model:
     last_epoch: -1

 trainer:
-  gpus: 0 # number of gpus
+  devices: 1 # number of gpus
   max_epochs: 100
   max_steps: null # computed at runtime if not set
   num_nodes: 1
-  accelerator: ddp
+  accelerator: gpu
+  strategy: ddp
   accumulate_grad_batches: 1
-  checkpoint_callback: false # Provided by exp_manager
+  enable_checkpointing: False # Provided by exp_manager
   logger: false # Provided by exp_manager
   log_every_n_steps: 100 # Interval of logging.
   val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
diff --git a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml
index 7952163346db..a31781763954 100644
--- a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml
+++ b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml
@@ -36,8 +36,8 @@ model:
     pin_memory: false
     use_start_end_token: true
     trim_silence: false
-    max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
-    min_duration: 3.0
+    max_duration: 16.7
+    min_duration: 8.0
     # tarred datasets
     is_tarred: false
     tarred_audio_filepaths: null
@@ -54,7 +54,7 @@ model:
     num_workers: 8
     pin_memory: true
     use_start_end_token: false
-
+    min_duration: 8.0


   preprocessor:
@@ -73,11 +73,11 @@ model:
     pad_value: 0.0

   spec_augment:
-    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
-    freq_masks: 4
-    time_masks: 10
-    freq_width: 27
-    time_width: 0.05
+    _target_: nemo.collections.asr.modules.MaskedPatchAugmentation
+    freq_masks: 3
+    freq_width: 20
+    patch_size: 48
+    mask_patches: 10

   encoder:
     _target_: nemo.collections.asr.modules.ConformerEncoder
@@ -120,15 +120,18 @@ model:
     feat_hidden: 128
     feat_out: *dec_out
     stride_layers: 0
-    non_stride_layers: 2
-
+    #if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
+    #be added to the decoder (here stride and combine_time_steps are both 4)
+    non_stride_layers: 0

   loss:
     _target_: nemo.collections.asr.losses.ContrastiveLoss
     in_dim: *n_mels
     proj_dim: *dec_out
-    combine_time_steps: 4
-    codebook_size: 1200
+    combine_time_steps: 4 #how many spectrogram time steps are used for one target/representation for contrastive task
+    quantized_targets: false #should quantizer or linear layer be used
+    sample_from_same_utterance_only: true #should negatives be sampled only from the same utterance
+    sample_from_non_masked: false #should negatives be sampled from non-masked steps

   optim:
     name: adamw
@@ -147,12 +150,13 @@ model:
     min_lr: 1e-6

 trainer:
-  gpus: -1 # number of GPUs, -1 would use all available GPUs
+  devices: -1 # number of GPUs, -1 would use all available GPUs
   num_nodes: 1
   max_epochs: 1000
   max_steps: null # computed at runtime if not set
   val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
-  accelerator: ddp
+  accelerator: gpu
+  strategy: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
   amp_level: O0 # O1/O2 for mixed precision
diff --git a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml
index d39ef24c87a6..6c74fc6036a0 100644
--- a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml
+++ b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml
@@ -25,6 +25,7 @@ model:
     batch_size: 16 # Can be increased if memory allows or when using smaller model
     trim_silence: false
     max_duration: 16.7
+    min_duration: 8.0
     shuffle: true
     use_start_end_token: false
     num_workers: 16
@@ -46,6 +47,7 @@ model:
     use_start_end_token: false
     num_workers: 16
     pin_memory: true
+    min_duration: 8.0

   model_defaults:
     filters: 1024
@@ -73,11 +75,11 @@ model:
     stft_conv: false

   spec_augment:
-    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
-    freq_masks: 4
-    time_masks: 10 # can be 5 for small-med models, 10 for larger models.
-    freq_width: 27
-    time_width: 0.05
+    _target_: nemo.collections.asr.modules.MaskedPatchAugmentation
+    freq_masks: 3
+    freq_width: 20
+    patch_size: 48
+    mask_patches: 10

   encoder:
     _target_: nemo.collections.asr.modules.ConvASREncoder
@@ -374,13 +376,18 @@ model:
     feat_hidden: 128
     feat_out: ${model.model_defaults.decoder_out_channels}
     stride_layers: 1
+    #if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
+    #be added to the decoder (here stride is 8 and combine_time_steps is 4, so 1 stride layer is added)
+    non_stride_layers: 0

   loss:
     _target_: nemo.collections.asr.losses.ContrastiveLoss
     in_dim: *n_mels
     proj_dim: ${model.model_defaults.decoder_out_channels}
-    combine_time_steps: 4
-    codebook_size: 1200
+    combine_time_steps: 4 #how many spectrogram time steps are used for one target/representation for contrastive task
+    quantized_targets: false #should quantizer or linear layer be used
+    sample_from_same_utterance_only: true #should negatives be sampled only from the same utterance
+    sample_from_non_masked: false #should negatives be sampled from non-masked steps

   optim:
     name: novograd
@@ -401,11 +408,12 @@ model:
     last_epoch: -1

 trainer:
-  gpus: 0 # number of gpus
+  devices: 1 # number of gpus
   max_epochs: 100
   max_steps: null # computed at runtime if not set
   num_nodes: 1 # Should be set via SLURM variable `SLURM_JOB_NUM_NODES`
-  accelerator: ddp
+  accelerator: gpu
+  strategy: ddp
   accumulate_grad_batches: 1
   checkpoint_callback: false # Provided by exp_manager
   logger: false # Provided by exp_manager
diff --git a/nemo/collections/asr/losses/pt_losses/contrastive.py b/nemo/collections/asr/losses/pt_losses/contrastive.py
index 26aaacaee114..ec5429e98936 100644
--- a/nemo/collections/asr/losses/pt_losses/contrastive.py
+++ b/nemo/collections/asr/losses/pt_losses/contrastive.py
@@ -48,12 +48,13 @@ def __init__(
         proj_dim: int = 128,
         combine_time_steps: int = 1,
         num_negatives: int = 100,
-        quantized_targets: bool = True,
+        quantized_targets: bool = False,
         codebook_size: int = 320,
         prob_ppl_weight: float = 0.1,
         logit_temp: float = 0.1,
         reduce: str = "sum",
-        sample_from_non_masked: bool = True,
+        sample_from_same_utterance_only: bool = True,
+        sample_from_non_masked: bool = False,
         sample_from_codebook: bool = False,
         group_loss: bool = False,
         num_groups: int = 2,
@@ -76,6 +77,7 @@
             prob_ppl_weight: Float multiplier on the perplexity loss for target quantization.
             logit_temp: Float temperature for normalizing logits.
             reduce: String representing the type of reduction used for cross entropy.
+            sample_from_same_utterance_only: Bool that determines if negatives should be sampled only from same utterance.
             sample_from_non_masked: Bool that determines if negatives should be sampled from non-masked steps of the spectrogram.
             sample_from_codebook: Bool that determines if negatives should be sampled from entire codebook.
             group_loss: Bool that determines if loss should be computed separately for each group in the quantizer codebook.
@@ -107,6 +109,7 @@ def __init__(
         self.logit_temp = logit_temp
         self.reduce = reduce
         self.combine_time_steps = combine_time_steps
+        self.sample_from_same_utterance_only = sample_from_same_utterance_only
         self.sample_from_non_masked = sample_from_non_masked
         self.sample_from_codebook = sample_from_codebook
         self.group_loss = group_loss
@@ -116,13 +119,16 @@ def __init__(
             self.target_proj = nn.Linear(in_dim * combine_time_steps, proj_dim)

     def sample_negatives(self, y, num):
+        # y - T'xBxC or T'xC

         high = y.shape[0]
-        with torch.no_grad():
-            neg_idxs = torch.randint(low=0, high=high - 1, size=(self.num_negatives * num,))
+        neg_idxs = torch.multinomial(torch.ones((num, high), device=y.device), self.num_negatives)

         negs = y[neg_idxs.view(-1)]
-        negs = negs.view(num, self.num_negatives, y.shape[-1]).permute(1, 0, 2)  # to NxTxC
+        negs = negs.view((num, self.num_negatives) + y.shape[1:])
+        negs = negs.transpose(0, 1)
+        # negs - NxT'xBxC or NxT'xC
+
         return negs, neg_idxs

     @typecheck()
@@ -133,46 +139,79 @@ def forward(self, spectrograms, spec_masks, decoder_outputs):

         # BxTxC
         targets = targets.reshape(targets.shape[0], targets.shape[1] // self.combine_time_steps, -1)
-        masks = masks.reshape(targets.shape)
+        masks = masks.reshape(targets.shape[0], targets.shape[1], -1)

         if self.quantized_targets:
             targets, prob_ppl_loss, cur_codebook_temp = self.quantizer(targets)
         else:
             targets = self.target_proj(targets)

-        masks = masks.mean(-1) > self.mask_threshold
-        out_masked_only = decoder_outputs[masks]
-        targets_masked_only = targets[masks]
-
-        # T'xC
-        # number of masked time steps to predict (T')
-
-        if self.group_loss:
-            num_groups = self.quantizer.groups
-            negatives = self.quantizer.vars.reshape(num_groups, self.quantizer.num_vars, -1)
-            # GxNx(C//G)
-            negatives = negatives.transpose(0, 1)
-            # NxGx(C//G)
-            negatives = negatives.unsqueeze(1).expand(-1, out_masked_only.shape[0], -1, -1)
-            # NxT'xGx(C//G)
-            negatives = negatives.reshape(negatives.shape[0], -1, negatives.shape[-1])
-            # NxT'Gx(C//G)
-
-            out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1] // num_groups)
-            targets_masked_only = targets_masked_only.reshape(-1, targets_masked_only.shape[-1] // num_groups)
-            # T'Gx(C//G)
-        elif self.sample_from_codebook:
-            # sample from the full codebook
-            negatives = self.quantizer.sample_from_codebook(self.num_negatives, targets_masked_only.size(0))
-        elif self.sample_from_non_masked:
-            # sample from all steps in batch
-            negatives, _ = self.sample_negatives(
-                targets.reshape(targets.shape[0] * targets.shape[1], -1), targets_masked_only.size(0),  # BTxC
-            )  # T'
+        if self.sample_from_same_utterance_only:
+            bs = decoder_outputs.shape[0]
+            masks = masks.mean(-1) > self.mask_threshold
+            out_masked_only = decoder_outputs[masks]
+            targets_masked_only = targets[masks]
+            out_masked_only = out_masked_only.reshape(bs, -1, out_masked_only.shape[-1])
+            targets_masked_only = targets_masked_only.reshape(bs, -1, targets_masked_only.shape[-1])
+
+            # BxT'xC
+            # number of masked time steps to predict (T')
+            # -> T'xBxC
+
+            out_masked_only = out_masked_only.transpose(0, 1)
+            targets_masked_only = targets_masked_only.transpose(0, 1)
+            # -> T'xBxC
+
+            if self.sample_from_non_masked:
+                # sample from all steps in utterance
+                negatives, _ = self.sample_negatives(
+                    targets.transpose(0, 1), targets_masked_only.size(0),  # TxBxC  # T'
+                )
+            else:
+                # only sample from masked steps in utterance
+                negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xBxC  # T'
+            # NxT'xBxC
+
+            out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1])
+            targets_masked_only = targets_masked_only.reshape(-1, targets_masked_only.shape[-1])
+            negatives = negatives.reshape(self.num_negatives, -1, negatives.shape[-1])
+
+            # T'BxC and NxT'BxC
+
         else:
-            # only sample from masked steps
-            negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xC  # T'
-            # NxT'xC
+            masks = masks.mean(-1) > self.mask_threshold
+            out_masked_only = decoder_outputs[masks]
+            targets_masked_only = targets[masks]
+
+            # T'xC
+            # number of masked time steps to predict (T')
+
+            if self.group_loss:
+                num_groups = self.quantizer.groups
+                negatives = self.quantizer.vars.reshape(num_groups, self.quantizer.num_vars, -1)
+                # GxNx(C//G)
+                negatives = negatives.transpose(0, 1)
+                # NxGx(C//G)
+                negatives = negatives.unsqueeze(1).expand(-1, out_masked_only.shape[0], -1, -1)
+                # NxT'xGx(C//G)
+                negatives = negatives.reshape(negatives.shape[0], -1, negatives.shape[-1])
+                # NxT'Gx(C//G)
+
+                out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1] // num_groups)
+                targets_masked_only = targets_masked_only.reshape(-1, targets_masked_only.shape[-1] // num_groups)
+                # T'Gx(C//G)
+            elif self.sample_from_codebook:
+                # sample from the full codebook
+                negatives = self.quantizer.sample_from_codebook(self.num_negatives, targets_masked_only.size(0))
+            elif self.sample_from_non_masked:
+                # sample from all steps in batch
+                negatives, _ = self.sample_negatives(
+                    targets.reshape(targets.shape[0] * targets.shape[1], -1), targets_masked_only.size(0),  # BTxC
+                )  # T'
+            else:
+                # only sample from masked steps
+                negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xC  # T'
+            # NxT'xC

         # Calculate similarity between logits and all targets
         similarity_scores = self._calculate_similarity(out_masked_only, negatives, targets_masked_only)
@@ -198,6 +237,7 @@ def forward(self, spectrograms, spec_masks, decoder_outputs):

         if not isinstance(loss, torch.Tensor):
             loss = torch.Tensor([0]).to(device=decoder_outputs.device)
+
         return loss

     def _calculate_similarity(self, logits, negatives, targets):
@@ -217,4 +257,5 @@
         return logits

     def set_num_updates(self, num_updates):
-        self.quantizer.set_num_updates(num_updates)
+        if self.quantized_targets:
+            self.quantizer.set_num_updates(num_updates)
diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py
index e76ab9f7eb8f..93b23bec9246 100644
--- a/nemo/collections/asr/modules/__init__.py
+++ b/nemo/collections/asr/modules/__init__.py
@@ -16,6 +16,7 @@
     AudioToMelSpectrogramPreprocessor,
     AudioToMFCCPreprocessor,
     CropOrPadSpectrogramAugmentation,
+    MaskedPatchAugmentation,
     SpectrogramAugmentation,
 )
 from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM
diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py
index a6f6622a3c24..c5007fda8d9a 100644
--- a/nemo/collections/asr/modules/audio_preprocessing.py
+++ b/nemo/collections/asr/modules/audio_preprocessing.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+import random
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Optional
@@ -52,6 +53,7 @@
     'AudioToMelSpectrogramPreprocessor',
     'AudioToMFCCPreprocessor',
     'SpectrogramAugmentation',
+    'MaskedPatchAugmentation',
     'CropOrPadSpectrogramAugmentation',
 ]

@@ -516,6 +518,72 @@ def forward(self, input_spec, length):
         return augmented_spec


+class MaskedPatchAugmentation(NeuralModule):
+    """
+    Zeroes out fixed size time patches of the spectrogram.
+    All samples in batch are guaranteed to have the same amount of masked time steps.
+    Optionally also performs frequency masking in the same way as SpecAugment.
+    Args:
+        patch_size (int): up to how many time steps does one patch consist of.
+            Defaults to 48.
+        mask_patches (int): how many patches should be masked in each sample.
+            Defaults to 10.
+        freq_masks (int): how many frequency segments should be cut.
+            Defaults to 0.
+        freq_width (int): maximum number of frequencies to be cut in a segment.
+            Defaults to 0.
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input types
+        """
+        return {
+            "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
+            "length": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output types
+        """
+        return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
+
+    def __init__(
+        self, patch_size: int = 48, mask_patches: int = 10, freq_masks: int = 0, freq_width: int = 0,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.mask_patches = mask_patches
+
+        if freq_masks > 0:
+            self.spec_augment = SpecAugment(freq_masks=freq_masks, time_masks=0, freq_width=freq_width, time_width=0,)
+        else:
+            self.spec_augment = None
+
+    @typecheck()
+    def forward(self, input_spec, length):
+        augmented_spec = input_spec
+
+        min_len = torch.min(length)
+        mask_patches = self.mask_patches
+        if min_len < self.patch_size * self.mask_patches:
+            mask_patches = min_len // self.patch_size
+
+        for idx in range(input_spec.shape[0]):
+            cur_len = length[idx]
+            patches = range(cur_len // self.patch_size - 1)
+            masked_patches = random.sample(patches, mask_patches)
+
+            for mp in masked_patches:
+                augmented_spec[idx, :, mp * self.patch_size : (mp + 1) * self.patch_size] = 0.0
+
+        if self.spec_augment is not None:
+            augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
+
+        return augmented_spec
+
+
 class CropOrPadSpectrogramAugmentation(NeuralModule):
     """
     Pad or Crop the incoming Spectrogram to a certain shape.

From 308a8189bf762a7850978df35266227c0df04dd7 Mon Sep 17 00:00:00 2001
From: sam1373
Date: Wed, 23 Feb 2022 16:25:09 -0800
Subject: [PATCH 2/6] tutorial update

Signed-off-by: sam1373
---
 .../asr/Self_Supervised_Pre_Training.ipynb    | 93 ++++++++++++-------
 1 file changed, 58 insertions(+), 35 deletions(-)

diff --git a/tutorials/asr/Self_Supervised_Pre_Training.ipynb b/tutorials/asr/Self_Supervised_Pre_Training.ipynb
index 0da15cb50b7a..101ceabb7dc4 100644
--- a/tutorials/asr/Self_Supervised_Pre_Training.ipynb
+++ b/tutorials/asr/Self_Supervised_Pre_Training.ipynb
@@ -28,7 +28,7 @@
     "\n",
     "## Install NeMo\n",
     "BRANCH = 'r1.7.0'\n",
-    "!python -m pip install git+https://github.com/sam1373/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
+    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
     "\n",
     "\"\"\"\n",
     "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. 
matplotlib)!\n", @@ -109,7 +109,7 @@ "# Download the dataset. This will take a few moments...\n", "print(\"******\")\n", "if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n", - " an4_url = 'http://www.speech.cs.cmu.edu/databases/an4/an4_sphere.tar.gz'\n", + " an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'\n", " an4_path = wget.download(an4_url, data_dir)\n", " print(f\"Dataset downloaded at: {an4_path}\")\n", "else:\n", @@ -271,8 +271,8 @@ "source": [ "## Grab the configs we'll use in this example\n", "!mkdir configs\n", - "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml\n", - "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/asr/conf/citrinet/citrinet_1024.yaml\n" + "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/r1.7.0/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml\n", + "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/r1.7.0/examples/asr/conf/citrinet/citrinet_1024.yaml\n" ] }, { @@ -281,7 +281,7 @@ "id": "RLzjCgmHuJ_j" }, "source": [ - "Since this config is for a very large model, we will modify it to make a much smaller version for the purpose of this tutorial by reducing the number of channels and the number of sub-blocks in each block, as well as reducing augmentation." + "Since this config is for a very large model, we will modify it to make a much smaller version for the purpose of this tutorial by reducing the number of channels and the number of sub-blocks in each block." ] }, { @@ -303,28 +303,18 @@ "cfg.model.model_defaults.repeat = 1\n", "cfg.model.model_defaults.enc_final = 256\n", "\n", - "cfg.model.spec_augment.freq_masks = 2\n", - "cfg.model.spec_augment.time_masks = 5\n", - "\n", - "cfg.model.optim.weight_decay = 0\n", - "cfg.model.optim.sched.warmup_steps = 2000\n", - "\n", "cfg.model.train_ds.manifest_filepath = \"/content/an4/train_manifest.json\"\n", "cfg.model.train_ds.batch_size = 16\n", "\n", "cfg.model.validation_ds.manifest_filepath = \"/content/an4/test_manifest.json\"\n", "cfg.model.validation_ds.batch_size = 16\n", "\n", - "cfg.trainer.max_epochs = None\n", - "cfg.trainer.max_steps = 10000 \n", - "#in practice you will usually want a much larger amount of pre-training steps\n", - "cfg.trainer.log_every_n_steps = 100\n", - "\n", "if torch.cuda.is_available():\n", - " cfg.trainer.gpus = 1\n", - " cfg.trainer.accelerator = 'dp'\n", + " cfg.trainer.accelerator = 'gpu'\n", + " cfg.trainer.strategy = 'dp'\n", "else:\n", - " cfg.trainer.gpus = 0\n", + " cfg.trainer.accelerator = 'cpu'\n", + " cfg.trainer.strategy = None\n", "\n", "cfg.exp_manager.exp_dir = \"/content/exp\"\n", "cfg.exp_manager.name = \"pre_trained\"\n", @@ -332,7 +322,30 @@ "cfg.exp_manager.create_tensorboard_logger = False\n", "cfg.exp_manager.resume_if_exists = True\n", "cfg.exp_manager.resume_ignore_no_checkpoint = True\n", - "cfg.exp_manager.checkpoint_callback_params.save_best_model = True" + "cfg.exp_manager.checkpoint_callback_params.save_best_model = True\n", + "\n", + "cfg.trainer.check_val_every_n_epoch = 1\n", + "\n", + "cfg.model.optim.sched.name = \"CosineAnnealing\"\n", + "cfg.model.optim.sched.warmup_steps = 1000\n", + "cfg.model.optim.sched.max_steps = 5000\n", + "#in practice you will usually want a much larger amount of pre-training steps\n", + "cfg.model.optim.sched.min_lr = 0\n", + "cfg.model.optim.lr = 0.015\n", + "cfg.model.optim.weight_decay = 0\n", + "\n", + "cfg.trainer.max_steps = 
cfg.model.optim.sched.max_steps\n", + "\n", + "cfg.model.spec_augment.patch_size = 16\n", + "\n", + "cfg.model.train_ds.min_duration = 2\n", + "cfg.model.validation_ds.min_duration = 2\n", + "#with patch_size set to 16 and 10 patches, \n", + "#we need to be able to mask 160 time steps;\n", + "#at preprocessor stride 0.01 this means we need minimum duration of 1.6 seconds \n", + "#or more to sample only from masked steps in the same utterance\n", + "\n", + "cfg.model.loss.num_negatives = 40" ] }, { @@ -341,7 +354,7 @@ "id": "TrfAb1DjWzpL" }, "source": [ - "The parameters that are relevant to the self-supervised decoder and loss can be found in cfg.model.decoder and cfg.model loss. The default parameters for them tend to work well, so we will keep them as is for this tutorial." + "The following parameters will be used for decoder, loss, and masking:" ] }, { @@ -353,13 +366,18 @@ "outputs": [], "source": [ "print(OmegaConf.to_yaml(cfg.model.decoder))\n", - "print(OmegaConf.to_yaml(cfg.model.loss))" + "print(OmegaConf.to_yaml(cfg.model.loss))\n", + "print(OmegaConf.to_yaml(cfg.model.spec_augment))" ] }, { "cell_type": "markdown", "source": [ - "Note that for this loss the outputs must match the inputs, so since we are using Citrinet architecture with 8x stride, we would need to either set \"cfg.model.loss.combine_time_steps\" to 8, or put additional stride layers in the decoder. By default for Citrinet with 8x stride we use \"cfg.model.loss.combine_time_steps=4\" and \"cfg.model.decoder.stride_layers=1\" to match the 8x stride." + "Note that for this loss the outputs must match the inputs, so since we are using Citrinet architecture with 8x stride, we would need to either set \"cfg.model.loss.combine_time_steps\" to 8, or put additional stride layers in the decoder. By default for Citrinet with 8x stride we use \"cfg.model.loss.combine_time_steps=4\" and \"cfg.model.decoder.stride_layers=1\" to match the 8x stride.\n", + "\n", + "Since in MaskedPatchAugmentation we set patch_size to 16 and mask_patches is set to 10, this will result in 160 total masked steps in the spectrogram. Since combine_time_steps is set to 4, this means that 160 / 4 = 40 total potential negative can be used, so we set loss.num_negatives to 40 (unless you set sample_from_same_utterance_only to false or sample_from_non_masked to true, but this tends to make results worse).\n", + "\n", + "In the default configs we assume that min_duration for samples is higher (8 seconds by default), so there we can set patch_size to 48 for a total of 480 masked steps, and use 100 sampled negatives. If the min_duration of samples that you are training on allows, the amount of masked steps as well as negatives can be increased further (masking around 50% of the sample duration tends to work well)." ], "metadata": { "id": "4JnepitBZ3ta" @@ -371,7 +389,7 @@ "id": "yoUIMS7mgrUs" }, "source": [ - "Now we will can create the config object." + "Now we can create the config object." 
] }, { @@ -494,8 +512,6 @@ "cfg.model.model_defaults.filters = 256\n", "cfg.model.model_defaults.repeat = 1\n", "cfg.model.model_defaults.enc_final = 256\n", - "cfg.model.optim.weight_decay = 0\n", - "cfg.model.optim.sched.warmup_steps = 500\n", "\n", "cfg.model.spec_augment.freq_masks = 2\n", "cfg.model.spec_augment.time_masks = 5\n", @@ -508,14 +524,12 @@ "\n", "cfg.model.log_prediction = False\n", "\n", - "cfg.trainer.max_epochs = None\n", - "cfg.trainer.max_steps = 2000 \n", - "\n", "if torch.cuda.is_available():\n", - " cfg.trainer.gpus = 1\n", - " cfg.trainer.accelerator = 'dp'\n", + " cfg.trainer.accelerator = 'gpu'\n", + " cfg.trainer.strategy = 'dp'\n", "else:\n", - " cfg.trainer.gpus = 0\n", + " cfg.trainer.accelerator = 'cpu'\n", + " cfg.trainer.strategy = None\n", "\n", "cfg.model.tokenizer.dir = data_dir + \"/tokenizers/an4/tokenizer_spe_unigram_v128/\" # note this is a directory, not a path to a vocabulary file\n", "cfg.model.tokenizer.type = \"bpe\"\n", @@ -526,7 +540,16 @@ "cfg.exp_manager.create_tensorboard_logger = False\n", "cfg.exp_manager.resume_if_exists = True\n", "cfg.exp_manager.resume_ignore_no_checkpoint = True\n", - "cfg.exp_manager.checkpoint_callback_params.save_best_model = True" + "cfg.exp_manager.checkpoint_callback_params.save_best_model = True\n", + "\n", + "cfg.model.optim.sched.name = \"CosineAnnealing\"\n", + "cfg.model.optim.sched.warmup_steps = 500\n", + "cfg.model.optim.sched.max_steps = 2000\n", + "cfg.model.optim.sched.min_lr = 0\n", + "cfg.model.optim.lr = 0.015 #if encoder is frozen, lr can be much higher\n", + "cfg.model.optim.weight_decay = 0\n", + "\n", + "cfg.trainer.max_steps = cfg.model.optim.sched.max_steps" ] }, { @@ -618,7 +641,7 @@ { "cell_type": "markdown", "source": [ - "We can optionally freeze the encoder and only fine-tune the decoder during traning. This can be done to lower the computational requirements of fine-tuning, but will likely result in a higher word error rate." + "We can optionally freeze the encoder and only fine-tune the decoder during training. This can be done to lower the memory and time requirements of fine-tuning, but will likely result in a higher word error rate." ], "metadata": { "id": "S5aVb2F8WuAR" @@ -658,7 +681,7 @@ { "cell_type": "markdown", "source": [ - "With the default parameters in this notebook, this pre-training and fine-tuning should achieve around 0.2-0.3 word error rate on the an4 validation set." + "With the default parameters in this notebook, this pre-training and fine-tuning should achieve around 0.2-0.3 word error rate on the an4 validation set. With frozen encoder and lr increased to 0.15, you will get around 0.6 WER." 
], "metadata": { "id": "UfnbNZ-AmmD1" From 936c20fb5555d4358f112ca3c5883c9afabb1770 Mon Sep 17 00:00:00 2001 From: sam1373 Date: Wed, 23 Feb 2022 17:08:32 -0800 Subject: [PATCH 3/6] revert configs Signed-off-by: sam1373 --- examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml | 7 +++---- examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml | 7 +++---- examples/asr/conf/ssl/conformer/conformer_ssl.yaml | 5 ++--- examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml | 5 ++--- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml index 1e8289842561..05bb2f916250 100644 --- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml +++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml @@ -440,14 +440,13 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager + checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 # Interval of logging. val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml index 4c2a55ac1113..98151e6b3810 100644 --- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml +++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml @@ -427,14 +427,13 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager + checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 # Interval of logging. 
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations diff --git a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml index a31781763954..a43a91bef83d 100644 --- a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml +++ b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml @@ -150,13 +150,12 @@ model: min_lr: 1e-6 trainer: - devices: -1 # number of GPUs, -1 would use all available GPUs + gpus: -1 # number of GPUs, -1 would use all available GPUs num_nodes: 1 max_epochs: 1000 max_steps: null # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 amp_level: O0 # O1/O2 for mixed precision diff --git a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml index 6c74fc6036a0..7933bcf397f5 100644 --- a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml +++ b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml @@ -408,12 +408,11 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 # Should be set via SLURM variable `SLURM_JOB_NUM_NODES` - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager From 47b1615c3e6afda9c2bbbcf255f4dbbdebae49e0 Mon Sep 17 00:00:00 2001 From: sam1373 Date: Wed, 23 Feb 2022 17:14:19 -0800 Subject: [PATCH 4/6] revert configs Signed-off-by: sam1373 --- examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml | 7 +++---- examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml | 7 +++---- examples/asr/conf/ssl/conformer/conformer_ssl.yaml | 5 ++--- examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml | 5 ++--- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml index 1e8289842561..05bb2f916250 100644 --- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml +++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml @@ -440,14 +440,13 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager + checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 # Interval of logging. 
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations diff --git a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml index 4c2a55ac1113..98151e6b3810 100644 --- a/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml +++ b/examples/asr/conf/ssl/citrinet/citrinet_ssl_ci.yaml @@ -427,14 +427,13 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager + checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 # Interval of logging. val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations diff --git a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml index a31781763954..a43a91bef83d 100644 --- a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml +++ b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml @@ -150,13 +150,12 @@ model: min_lr: 1e-6 trainer: - devices: -1 # number of GPUs, -1 would use all available GPUs + gpus: -1 # number of GPUs, -1 would use all available GPUs num_nodes: 1 max_epochs: 1000 max_steps: null # computed at runtime if not set val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 amp_level: O0 # O1/O2 for mixed precision diff --git a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml index 6c74fc6036a0..7933bcf397f5 100644 --- a/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml +++ b/examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml @@ -408,12 +408,11 @@ model: last_epoch: -1 trainer: - devices: 1 # number of gpus + gpus: 0 # number of gpus max_epochs: 100 max_steps: null # computed at runtime if not set num_nodes: 1 # Should be set via SLURM variable `SLURM_JOB_NUM_NODES` - accelerator: gpu - strategy: ddp + accelerator: ddp accumulate_grad_batches: 1 checkpoint_callback: false # Provided by exp_manager logger: false # Provided by exp_manager From 704951d11a91e812ca20fb28256f4aa6b3d92223 Mon Sep 17 00:00:00 2001 From: sam1373 Date: Thu, 24 Feb 2022 14:21:56 -0800 Subject: [PATCH 5/6] specify gpus Signed-off-by: sam1373 --- tutorials/asr/Self_Supervised_Pre_Training.ipynb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tutorials/asr/Self_Supervised_Pre_Training.ipynb b/tutorials/asr/Self_Supervised_Pre_Training.ipynb index 101ceabb7dc4..ff284f4f50a8 100644 --- a/tutorials/asr/Self_Supervised_Pre_Training.ipynb +++ b/tutorials/asr/Self_Supervised_Pre_Training.ipynb @@ -312,9 +312,11 @@ "if torch.cuda.is_available():\n", " cfg.trainer.accelerator = 'gpu'\n", " cfg.trainer.strategy = 'dp'\n", + " cfg.trainer.gpus = 1\n", "else:\n", " cfg.trainer.accelerator = 'cpu'\n", " cfg.trainer.strategy = None\n", + " cfg.trainer.gpus = 0\n", "\n", "cfg.exp_manager.exp_dir = \"/content/exp\"\n", "cfg.exp_manager.name = \"pre_trained\"\n", @@ -527,9 +529,11 @@ "if torch.cuda.is_available():\n", " cfg.trainer.accelerator = 'gpu'\n", " cfg.trainer.strategy = 'dp'\n", + " cfg.trainer.gpus = 1\n", "else:\n", " cfg.trainer.accelerator = 'cpu'\n", " cfg.trainer.strategy = 
None\n", + " cfg.trainer.gpus = 0\n", "\n", "cfg.model.tokenizer.dir = data_dir + \"/tokenizers/an4/tokenizer_spe_unigram_v128/\" # note this is a directory, not a path to a vocabulary file\n", "cfg.model.tokenizer.type = \"bpe\"\n", From 44aec4aaa21799f073137408d929c89c1d25a365 Mon Sep 17 00:00:00 2001 From: sam1373 Date: Thu, 24 Feb 2022 22:44:55 -0800 Subject: [PATCH 6/6] update dirs Signed-off-by: sam1373 --- .../asr/Self_Supervised_Pre_Training.ipynb | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tutorials/asr/Self_Supervised_Pre_Training.ipynb b/tutorials/asr/Self_Supervised_Pre_Training.ipynb index ff284f4f50a8..244ab7f43d09 100644 --- a/tutorials/asr/Self_Supervised_Pre_Training.ipynb +++ b/tutorials/asr/Self_Supervised_Pre_Training.ipynb @@ -295,7 +295,7 @@ "from omegaconf import OmegaConf\n", "import torch\n", "\n", - "config_path = './configs/citrinet_ssl_1024.yaml'\n", + "config_path = data_dir + '/configs/citrinet_ssl_1024.yaml'\n", "\n", "cfg = OmegaConf.load(config_path)\n", "\n", @@ -303,10 +303,10 @@ "cfg.model.model_defaults.repeat = 1\n", "cfg.model.model_defaults.enc_final = 256\n", "\n", - "cfg.model.train_ds.manifest_filepath = \"/content/an4/train_manifest.json\"\n", + "cfg.model.train_ds.manifest_filepath = train_manifest\n", "cfg.model.train_ds.batch_size = 16\n", "\n", - "cfg.model.validation_ds.manifest_filepath = \"/content/an4/test_manifest.json\"\n", + "cfg.model.validation_ds.manifest_filepath = test_manifest\n", "cfg.model.validation_ds.batch_size = 16\n", "\n", "if torch.cuda.is_available():\n", @@ -318,7 +318,7 @@ " cfg.trainer.strategy = None\n", " cfg.trainer.gpus = 0\n", "\n", - "cfg.exp_manager.exp_dir = \"/content/exp\"\n", + "cfg.exp_manager.exp_dir = data_dir + \"/content/exp\"\n", "cfg.exp_manager.name = \"pre_trained\"\n", "cfg.exp_manager.use_datetime_version = False\n", "cfg.exp_manager.create_tensorboard_logger = False\n", @@ -476,7 +476,7 @@ "outputs": [], "source": [ "!mkdir scripts\n", - "!wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/tokenizers/process_asr_text_tokenizer.py\n", + "!wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/r1.7.0/scripts/tokenizers/process_asr_text_tokenizer.py\n", "\n", "!python ./scripts/process_asr_text_tokenizer.py \\\n", " --manifest=\"{data_dir}/an4/train_manifest.json\" \\\n", @@ -507,7 +507,7 @@ }, "outputs": [], "source": [ - "config_path = './configs/citrinet_1024.yaml'\n", + "config_path = data_dir + '/configs/citrinet_1024.yaml'\n", "\n", "cfg = OmegaConf.load(config_path)\n", "\n", @@ -518,10 +518,10 @@ "cfg.model.spec_augment.freq_masks = 2\n", "cfg.model.spec_augment.time_masks = 5\n", "\n", - "cfg.model.train_ds.manifest_filepath = \"/content/an4/train_manifest.json\"\n", + "cfg.model.train_ds.manifest_filepath = train_manifest\n", "cfg.model.train_ds.batch_size = 16\n", "\n", - "cfg.model.validation_ds.manifest_filepath = \"/content/an4/test_manifest.json\"\n", + "cfg.model.validation_ds.manifest_filepath = test_manifest\n", "cfg.model.validation_ds.batch_size = 16\n", "\n", "cfg.model.log_prediction = False\n", @@ -538,7 +538,7 @@ "cfg.model.tokenizer.dir = data_dir + \"/tokenizers/an4/tokenizer_spe_unigram_v128/\" # note this is a directory, not a path to a vocabulary file\n", "cfg.model.tokenizer.type = \"bpe\"\n", "\n", - "cfg.exp_manager.exp_dir = \"/content/exp\"\n", + "cfg.exp_manager.exp_dir = data_dir + \"/content/exp\"\n", "cfg.exp_manager.name = \"fine_tuned\"\n", "cfg.exp_manager.use_datetime_version = 
False\n", "cfg.exp_manager.create_tensorboard_logger = False\n", @@ -573,7 +573,7 @@ }, "outputs": [], "source": [ - "cfg.init_from_nemo_model=\"/content/exp/pre_trained/checkpoints/pre_trained.nemo\"\n", + "cfg.init_from_nemo_model=data_dir + \"/content/exp/pre_trained/checkpoints/pre_trained.nemo\"\n", "cfg.init_strict = False" ] },
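
Note on the masking budget set in these patches: the SSL configs and the tutorial cells tie MaskedPatchAugmentation's patch_size and mask_patches to ContrastiveLoss.combine_time_steps, num_negatives, and the dataset min_duration. A minimal sketch of that arithmetic (the function name is illustrative, and the 0.01 s frame hop is the preprocessor stride referenced in the notebook comments):

    def masking_budget(patch_size, mask_patches, combine_time_steps, window_stride=0.01):
        # total spectrogram frames zeroed out per sample
        masked_steps = patch_size * mask_patches
        # candidate negatives when sampling only from masked steps of one utterance
        candidate_negatives = masked_steps // combine_time_steps
        # shortest utterance that can supply that many masked frames
        min_duration_sec = masked_steps * window_stride
        return masked_steps, candidate_negatives, min_duration_sec

    print(masking_budget(16, 10, 4))  # (160, 40, 1.6)  -> tutorial setting, so loss.num_negatives = 40
    print(masking_budget(48, 10, 4))  # (480, 120, 4.8) -> default configs, where 100 negatives fit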
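For a quick standalone check of the new augmentation module, a sketch along these lines should work against a NeMo build that contains PATCH 1 (shapes follow the declared input_types: input_spec is BxDxT, length is B; the tensors here are dummy data):

    import torch
    from nemo.collections.asr.modules import MaskedPatchAugmentation

    aug = MaskedPatchAugmentation(patch_size=48, mask_patches=10, freq_masks=3, freq_width=20)

    spec = torch.randn(4, 80, 800)                # 4 samples, 80 mel bins, 800 frames
    lengths = torch.tensor([800, 780, 720, 640])  # valid frames per sample, all >= 48 * 10

    masked = aug(input_spec=spec, length=lengths)
    print(masked.shape)  # same BxDxT shape; 10 patches of 48 frames are zeroed in each sample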