
Add finetuning scripts (#7263)
* Add new script for finetuning asr models

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Update config for PTL 2.0

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* style fix

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* update jenkins run

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* add doc strings

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* improve code to support all decoder types

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* add doc strings and support for char models

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* typo fix

Signed-off-by: Nithin Rao Koluguri <nithinraok>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
nithinraok authored Aug 29, 2023
1 parent 0dcc3c7 commit 68fea1a
Showing 5 changed files with 377 additions and 27 deletions.
20 changes: 18 additions & 2 deletions Jenkinsfile
@@ -134,7 +134,7 @@ pipeline {
}
}

stage('Speech to Text WPE - CitriNet') {
steps {
sh 'python examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \
--config-path="../conf/citrinet/" --config-name="config_bpe" \
@@ -150,7 +150,7 @@
}
}

stage('Speech Pre-training - CitriNet') {
steps {
sh 'python examples/asr/speech_pretraining/speech_pre_training.py \
--config-path="../conf/ssl/citrinet/" --config-name="citrinet_ssl_ci" \
@@ -164,6 +164,22 @@
}
}

stage('Speech To Text Finetuning') {
steps {
sh 'python examples/asr/speech_to_text_finetune.py \
--config-path="conf" --config-name="speech_to_text_finetune" \
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \
model.tokenizer.update_tokenizer=False \
trainer.devices=[1] \
trainer.accelerator="gpu" \
+trainer.fast_dev_run=True \
exp_manager.exp_dir=examples/asr/speech_finetuning_results'
sh 'rm -rf examples/asr/speech_finetuning_results'
}
}

// TODO: Please Fix Me
// Error locating target 'nemo.collections.asr.modules.wav2vec_modules.ConvFeatureEncoder', see chained exception above.
// stage('L2: Speech Pre-training - Wav2Vec') {
63 changes: 39 additions & 24 deletions docs/source/asr/configs.rst
@@ -984,8 +984,8 @@ Main parts of the config:
batch_size: 16 # you may increase batch_size if your memory allows
# other params
Finetuning with Text-Only Data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To finetune an existing ASR model using text-only data, use the ``<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py`` script with the corresponding config ``<NeMo_git_root>/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml``.
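
For example, the launch looks like the following sketch (only the standard Hydra arguments are shown; the dataset manifests and model paths required by the config must be supplied as additional Hydra overrides or edited in ``hybrid_asr_tts.yaml``):

.. code-block:: sh

    python examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py \
        --config-path=../conf/asr_tts \
        --config-name=hybrid_asr_tts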

@@ -1030,47 +1030,53 @@ Fine-tuning Configurations
All ASR scripts support easy fine-tuning by partially/fully loading the pretrained weights from a checkpoint into the **currently instantiated model**. Note that the currently instantiated model should have parameters that match the pre-trained checkpoint (such that weights may load properly). In order to directly fine-tune a pre-existing checkpoint, please follow the tutorial `ASR Language Fine-tuning. <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb>`_
Models can be fine-tuned in two ways:
* By updating or retaining current tokenizer alone
* By updating model architecture and tokenizer
Fine-tuning by updating or retaining current tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In this case, the model architecture is not updated. The model is initialized with the pre-trained weights in one of two ways:
1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
2) Providing a name of a pretrained NeMo model (which will be downloaded via the cloud) (via ``init_from_pretrained_model``)
Users can then either keep the existing tokenizer or update it with a new vocabulary. This is useful when users do not want to change the model architecture but want to update the tokenizer with a new vocabulary.
The same script can be used to fine-tune CTC, RNNT, or Hybrid models.
The ``<NeMo_repo>/examples/asr/speech_to_text_finetune.py`` script supports this type of fine-tuning with the following arguments:
.. code-block:: sh
python examples/asr/speech_to_text_finetune.py \
--config-path=<path to dir of configs> \
--config-name=<name of config without .yaml> \
model.train_ds.manifest_filepath="<path to manifest file>" \
model.validation_ds.manifest_filepath="<path to manifest file>" \
model.tokenizer.update_tokenizer=<True/False> \ # True to update tokenizer, False to retain existing tokenizer
model.tokenizer.dir=<path to tokenizer dir> \ # Path to tokenizer dir when update_tokenizer=True
model.tokenizer.type=<tokenizer type> \ # tokenizer type when update_tokenizer=True
trainer.devices=-1 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
+init_from_nemo_model="<path to .nemo model file>"
+init_from_nemo_model="<path to .nemo model file>"  # or +init_from_pretrained_model="<name of pretrained checkpoint>"
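
For char-based (non-BPE) models, the analogous overrides live under ``model.char_labels`` in the new ``examples/asr/conf/speech_to_text_finetune.yaml`` added by this commit. A minimal sketch, assuming a char-based ``.nemo`` checkpoint and an illustrative label list that should be replaced with the full target vocabulary:

.. code-block:: sh

    python examples/asr/speech_to_text_finetune.py \
        --config-path=conf --config-name=speech_to_text_finetune \
        model.train_ds.manifest_filepath="<path to manifest file>" \
        model.validation_ds.manifest_filepath="<path to manifest file>" \
        model.char_labels.update_labels=True \
        model.char_labels.labels="[' ', 'a', 'b', 'c']" \
        init_from_nemo_model="<path to char-based .nemo model file>"
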
Fine-tuning by changing model architecture and tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If users want to update the model architecture as well, they can use the following scripts.

Pre-trained weights can be provided in multiple ways:

1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
2) Providing a name of a pretrained NeMo model (which will be downloaded via the cloud) (via ``init_from_pretrained_model``)
3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)

There are multiple ASR subtasks inside the ``examples/asr/`` directory; you can substitute the ``<subtask>`` tag below.

.. code-block:: sh

python examples/asr/<subtask>/script_to_<script_name>.py \
--config-path=<path to dir of configs> \
--config-name=<name of config without .yaml> \
model.train_ds.manifest_filepath="<path to manifest file>" \
model.validation_ds.manifest_filepath="<path to manifest file>" \
trainer.devices=-1 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
+init_from_nemo_model="<path to .nemo model file>"  # or +init_from_pretrained_model, or +init_from_ptl_ckpt
To reinitialize part of the model (so that it differs from the pretrained model), users can specify which modules to include or exclude through the config:
.. code-block:: yaml
init_from_nemo_model: "<path to .nemo model file>"
asr_model:
include: ["preprocessor","encoder"]
exclude: ["decoder"]
Fine-tuning Execution Flow Diagram
----------------------------------
@@ -254,7 +254,6 @@ trainer:
precision: 32 # 16, 32, or bf16
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
118 changes: 118 additions & 0 deletions examples/asr/conf/speech_to_text_finetune.yaml
@@ -0,0 +1,118 @@
name: "Speech_To_Text_Finetuning"

# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model
# We do not currently support `init_from_ptl_ckpt` to create a single script for all types of models.
init_from_nemo_model: null # path to nemo model

model:
sample_rate: 16000
compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
log_prediction: true # enables logging sample predictions in the output during training
rnnt_reduction: 'mean_volume'
skip_nan_grad: false

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
max_duration: 20
min_duration: 0.1
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "fully_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
use_start_end_token: false
num_workers: 8
pin_memory: true

test_ds:
manifest_filepath: null
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
use_start_end_token: false
num_workers: 8
pin_memory: true

char_labels: # use for char based models
update_labels: false
labels: null # example list config: \[' ', 'a', 'b', 'c'\]

tokenizer: # use for spe/bpe based tokenizer models
update_tokenizer: false
dir: null # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

optim:
name: adamw
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 5000
warmup_ratio: null
min_lr: 5e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 50
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # 16, 32, or bf16
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
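
For reference, a minimal launch command for this config mirrors the CI stage added to the Jenkinsfile above (the manifest and ``.nemo`` paths below are placeholders):

    python examples/asr/speech_to_text_finetune.py \
        --config-path="conf" --config-name="speech_to_text_finetune" \
        model.train_ds.manifest_filepath="<path to train manifest>" \
        model.validation_ds.manifest_filepath="<path to validation manifest>" \
        init_from_nemo_model="<path to .nemo model file>"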