Add support for Hydra multirun to NeMo (#5159)
* Update execution doc and remove old snippet
* Fix types
* Fix defaults
* Fix types for ParallelAdapterConfig
* Add hash for config cache
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add support to delete redundant ckpt files for HP search
* Correct config for IA3
* Fix check to <= 0

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 564f211, commit b47a467. Showing 10 changed files with 617 additions and 21 deletions.
# Config to perform ASR adaptation using any pre-trained model (local nemo model or pre-trained checkpoint).
############################################################################################################
# This config is special in that it is used alongside the scripts in the asr_adapters examples directory,
# but does not directly construct a model itself. Instead, it mimics the usual ASR model configs,
# initializes a pre-trained model (either local or via the network), and overrides its data loaders,
# optimizer, scheduler and other arguments.
#
# **Note**: This config does *not* get stored in the adapted model, since it merely sets up the
# adapter training / inference script. It can be considered a config not for the model, but for the
# script that will adapt the model or run inference with an adapted model.
#
# You can therefore call this script multiple times to add as many adapters as you need to a single model,
# by providing the previously adapted checkpoint as `model.nemo_model`.
#
# **Note**: Any config value in this yaml file *overrides* the equivalent config inside the model!
#
# There are some important parameters of this config that must be updated by the user:
# - model.pretrained_model or model.nemo_model: str name or path to some pretrained model. Only one of the
#       two should be passed. Selects the pre-trained model to be loaded and adapted.
#
# - model.adapter.adapter_name: Globally unique name assigned to the adapter itself. Every adapter of a
#       model must have a unique name.
#
# - model.adapter.in_features: The output dimension of each block of the model. This is model dependent.
#       For example, the Conformer dimension can be found via `model.encoder.d_model` in its config.
#       For Citrinets/ContextNets, the dimension can usually be found in `model.encoder.jasper.0.filters`.
#
# - model.train_ds.manifest_filepath / model.validation_ds.manifest_filepath: Data filepaths used to train
#       the adapter module.
############################################################################################################
# The recommendations for training adapters differ significantly from general ASR training or fine-tuning.
# Below are some recommended configuration values.
#
# - model.adapter.dim: Usually we choose a small bottleneck dim here. 16 to 32 is generally enough.
#
# - model.optim.lr: We generally choose a very small LR, and a very short training schedule of just a few
#       hundred steps - depending on the size of the dataset. Usually just a few epochs over the dataset
#       with a low LR is sufficient for adaptation.
#
# - model.optim.weight_decay: We find that strong weight decay prevents significant degradation of prior
#       training, but also limits the capacity of the model to learn the adapted domain. Usually as a
#       baseline we use 0.0.
#
# - model.optim.sched.warmup_steps: We encourage warmup steps to be modified to suit the smaller training schedule.
#
# - trainer.max_steps: We recommend using trainer.max_steps to limit the training duration to just 10-20 epochs.
#       Adapters converge very fast, and prolonged training may cause overfitting to the new domain and,
#       consequently, catastrophic forgetting of the old domain. You can equivalently use a small number of
#       epochs via trainer.max_epochs.
#
# - trainer.check_val_every_n_epoch: Since the training run is short and usually very fast, it is recommended
#       to reduce the amount of validation to once every few epochs, rather than after every epoch, to speed
#       up training.
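#
# Example invocation (an illustrative sketch only: the script path, config name, model name and values
# below are assumptions; adjust them to your checkout and data):
#
#   python examples/asr/asr_adapters/train_asr_adapter.py \
#       --config-path="../conf/asr_adapters" --config-name="asr_adaptation_hp" \
#       model.pretrained_model="stt_en_conformer_ctc_small" \
#       model.adapter.adapter_name="my_domain_adapter" \
#       model.adapter.in_features=176 \
#       model.train_ds.manifest_filepath="train_manifest.json" \
#       model.validation_ds.manifest_filepath="val_manifest.json"
############################################################################################################
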
name: "ASR-Adapter-hp"

model:
  # One of the below two values must be set!
  pretrained_model: null  # name of a pretrained model
  nemo_model: null  # path to an ASR model file (.nemo)

  log_prediction: false  # enables logging sample predictions in the output during training

  adapter:
    # Config of the adapter training/eval script.
    adapter_name: ???  # Name of the adapter, used by the script
    adapter_module_name: null  # Name of the adapter module. Combine multiple modules with '+' between module names.
    adapter_state_dict_name: "adapters.pt"  # If the individual adapters must be saved, a file name can be provided here. null disables this.

    # Config of the adapter module itself
    _target_: nemo.collections.common.parts.adapter_modules.LinearAdapter
    in_features: ???  # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter.
    dim: 32  # The hidden dimension of the adapter, as chosen by the user. Small values are preferred to reduce the parameter count.
    activation: swish
    norm_position: 'pre'  # Can be `pre` or `post`
    dropout: 0.0  # float, dropout for the adapter

    # Adapter strategy config
    adapter_strategy:
      _target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy
      stochastic_depth: 0.0  # float, setting to > 0 will enable stochastic depth for each adapter block.
      l2_lambda: 0.0  # float, setting to > 0 will enable an l2 norm auxiliary loss on each adapter's output.

    # Optional global config available to all adapters at a global level.
    # A global config is shared across every layer of the adapters, defining global properties rather
    # than properties local to the adapter (as defined above).
    # This can be useful in order to select *which type of adapter* is added, *which adapters to enable*,
    # and further global operations that can decide dynamically how to support the requested adapter.
    global_cfg:
      check_encoder_adapter: True  # ASR adapter key, determines whether to check if encoder adapter modules are supported
      check_decoder_adapter: True  # ASR adapter key, determines whether to check if decoder adapter modules are supported
      check_joint_adapter: True  # ASR adapter key, determines whether to check if joint adapter modules are supported

  # Overrides the model's internal spec augment configuration
  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 0
    time_masks: 0
    freq_width: 27
    time_width: 0.05

  train_ds:
    # train dataset + dataloader config
    # sample_rate will be merged with model config
    # use_start_end_token will be merged with model config
    # trim_silence will be merged with model config
    # max_duration will be merged with model config
    # min_duration will be merged with model config
    manifest_filepath: ???
    batch_size: 16  # you may increase batch_size if your memory allows
    shuffle: true
    num_workers: 8
    pin_memory: true
    # tarred datasets
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    # bucketing params
    bucketing_strategy: "synced_randomized"
    bucketing_batch_size: null

  validation_ds:
    # sample_rate will be merged with model config
    # use_start_end_token will be merged with model config
    manifest_filepath: ???
    batch_size: 16
    shuffle: false
    num_workers: 8
    pin_memory: true

  test_ds:
    # sample_rate will be merged with model config
    # use_start_end_token will be merged with model config
    manifest_filepath: null
    batch_size: 16
    shuffle: false
    num_workers: 8
    pin_memory: true

  optim:
    # optimizer arguments
    name: adamw
    betas: [0.9, 0.98]
    lr: 0.001  # LR depends on the scheduler used by the base model. Noam prefers 0.5, Cosine Annealing prefers 0.02
    weight_decay: 0  # During adaptation, since the training run is short, WD is not required. Can be set if needed.

    # scheduler setup
    sched:
      name: CosineAnnealing

      # scheduler config override
      warmup_steps: null  # Warmup steps should be set, and smaller than the trainer.max_steps set below.
      warmup_ratio: 0.1  # Warmup steps will be 10% of the training steps.
      min_lr: 1e-5
      last_epoch: -1

trainer:
  devices: -1  # number of GPUs, -1 uses all available GPUs
  num_nodes: 1
  max_epochs: 1
  max_steps: -1  # computed at runtime if not set
  val_check_interval: 1.0  # Set to 0.25 to check 4 times per epoch, or an int for the number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32  # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 10  # Interval of logging.
  resume_from_checkpoint: null  # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0  # number of validation steps to run as a sanity check before starting training; setting to 0 disables it
  check_val_every_n_epoch: 1  # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False  # Provided by exp_manager
  logger: false  # Provided by exp_manager
  benchmark: false  # needs to be false for models with variable-length speech input as it slows down training

exp_manager:
  exp_dir: null

  # Add a unique name including all hyperparameter arguments to allow continued training.
  # NOTE: It is necessary to add all hyperparameter arguments to the name!
  # This ensures successful restoration of model runs in case the HP search crashes.
  name: ${name}-lr-${model.optim.lr}-adim-${model.adapter.dim}-sd-${model.adapter.adapter_strategy.stochastic_depth}
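  # With the defaults above, this interpolates to a run name like
  # "ASR-Adapter-hp-lr-0.001-adim-32-sd-0.0" (illustrative), so each sweep combination gets its own experiment directory.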

  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_wer"
    mode: "min"
    save_top_k: 1
    always_save_nemo: True  # saves the checkpoints as nemo files instead of PTL checkpoints

  create_wandb_logger: false
  wandb_logger_kwargs:
    name: ${exp_manager.name}
    project: null
    entity: null
    save_dir: null
    offline: false  # If true, wandb logging will be done offline and would require manual syncing.
    tags: null  # List of tags to assign to the run

  # HP search may crash for various reasons; it is best to attempt continuation in order to
  # resume from where the last failure occurred.
  resume_if_exists: true
  resume_ignore_no_checkpoint: true

# Required for Hydra launch of hyperparameter search
defaults:
  - override hydra/launcher: nemo_launcher

# Hydra arguments necessary for hyperparameter optimization
hydra:
  sweep:
    dir: "."
    subdir: "."

  sweeper:
    params:  # place all the parameters you wish to search over here (corresponding to the rest of the config)
      model.optim.lr: 0.001,0.0001
      model.adapter.dim: 32,64,96,128
      model.adapter.adapter_strategy.stochastic_depth: 0.0,0.5,0.6,0.7,0.8,0.9
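    # The params above define a Cartesian grid of 2 (lr) x 4 (dim) x 6 (stochastic_depth) = 48 runs.
    # Illustrative sketch of launching the sweep (the script path is an assumption; Hydra's -m /
    # --multirun flag triggers the sweep over hydra.sweeper.params):
    #   python examples/asr/asr_adapters/train_asr_adapter.py -m --config-name="asr_adaptation_hp" ...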

  # Arguments to the hyperparameter runner
  launcher:
    num_gpus: -1  # Number of GPUs to use. Each run works on a single GPU.
    jobs_per_gpu: 1  # If each GPU has large memory, you can run multiple jobs on the same GPU for faster results (until OOM).
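    # Illustrative concurrency note (assumes an 8-GPU node): with num_gpus: -1 and jobs_per_gpu: 1,
    # up to 8 sweep jobs run at once; jobs_per_gpu: 2 would allow up to 16 concurrent jobs, memory permitting.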