Add support for Hydra multirun to NeMo (#5159)
* Update execution doc and remove old snippet

Signed-off-by: smajumdar <[email protected]>

* Fix types

Signed-off-by: smajumdar <[email protected]>

* Fix defaults

Signed-off-by: smajumdar <[email protected]>

* Fix types for ParallelAdapterConfig

Signed-off-by: smajumdar <[email protected]>

* Add hash for config cache

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to delete redundant ckpt files for HP search

Signed-off-by: smajumdar <[email protected]>

* Correct config for IA3

Signed-off-by: smajumdar <[email protected]>

* Fix check to <= 0

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
titu1994 and pre-commit-ci[bot] authored Nov 4, 2022
1 parent 564f211 commit b47a467
Showing 10 changed files with 617 additions and 21 deletions.
34 changes: 32 additions & 2 deletions examples/asr/asr_adapters/train_asr_adapter.py
@@ -29,14 +29,34 @@
model.train_ds.batch_size=16 \
model.validation_ds.manifest_filepath=<Path to manifest> \
model.validation_ds.batch_size=16 \
-model.optim.lr=0.5 \
+model.optim.lr=0.001 \
model.optim.weight_decay=0.0 \
model.optim.sched.warmup_steps=100 \
trainer.max_steps=300 \
trainer.devices=1 \
trainer.precision=32 \
exp_manager.exp_dir=<Some directory for experiment manager>
# Hyperparameter Search
python train_asr_adapter.py \
--config-path="../conf/asr_adapters" \
--config-name="asr_adaptation_hp.yaml" \
-m \
model.pretrained_model=null \
model.nemo_model=null \
model.adapter.adapter_name=<Unique adapter name> \
model.adapter.adapter_module_name=<null, or module name(s): encoder, decoder, joint, or multiple joined with '+'> \
model.adapter.in_features=<dimension of the layer outputs of the model> \
model.train_ds.manifest_filepath=<Path to manifest> \
model.train_ds.batch_size=16 \
model.validation_ds.manifest_filepath=<Path to manifest> \
model.validation_ds.batch_size=16 \
exp_manager.exp_dir="<some directory>" \
exp_manager.create_wandb_logger=true \
exp_manager.wandb_logger_kwargs.project="<Project Name>" \
++delete_ckpt_after_train=True
# Fine-tune a model
While adaptation is very efficient for low-resource datasets, it imposes several restrictions -
@@ -59,7 +79,7 @@
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html
"""

import glob
import os
from dataclasses import is_dataclass

@@ -211,6 +231,16 @@ def main(cfg):
# Save the adapter modules in a separate file
model.save_adapters(str(state_path))

if 'delete_ckpt_after_train' in cfg:
delete_ckpt_after_train = cfg.delete_ckpt_after_train
if delete_ckpt_after_train:
logging.info("Deleting *.ckpt files after training to preserve storage space...")

ckpt_files = glob.glob(os.path.join(exp_log_dir, "checkpoints", "*.ckpt"))
for filepath in ckpt_files:
logging.info(f"Deleting file : {filepath}")
os.remove(filepath)


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
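
The new `delete_ckpt_after_train` flag is passed in the usage example above as `++delete_ckpt_after_train=True`, using Hydra's append/override syntax, which suggests the key is not declared in the base adaptation config. Below is a minimal sketch, using only OmegaConf and a stand-in config (not NeMo code), of how the script can then probe for such an optional key:

from omegaconf import OmegaConf

# Stand-in for the merged script config; the real one comes from the adaptation yaml.
cfg = OmegaConf.create({"model": {"optim": {"lr": 0.001}}})

# Roughly what a ++delete_ckpt_after_train=True override adds to the config.
cfg = OmegaConf.merge(cfg, OmegaConf.create({"delete_ckpt_after_train": True}))

# A DictConfig supports membership tests, so the key can be checked before use.
if "delete_ckpt_after_train" in cfg and cfg.delete_ckpt_after_train:
    print("would delete *.ckpt files under <exp_log_dir>/checkpoints after training")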
226 changes: 226 additions & 0 deletions examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml
@@ -0,0 +1,226 @@
# Config to perform ASR adaptation using any pre-trained model (local nemo model or pre-trained checkpoint).
############################################################################################################
# This config is special in that it is used alongside the scripts in the asr_adapters examples directory,
# but does not directly construct a model itself. Instead, it mimics the usual ASR model configs, initializes
# a pre-trained model (either local or downloaded over the network), and overrides many of its arguments
# (data loaders, optimizer, scheduler, and others).
#
# **Note**: This config does *not* get stored in the adapted model, since it merely sets up the
# adapter training / inference script. This file can be considered a config not for the model, but for the
# script that will adapt the model or run inference with an adapted model.
#
# You can therefore call this script multiple times to add as many adapters as you need to a single model,
# by providing the previously adapted checkpoint as `model.nemo_model`.
#
# **Note**: Any config value in this yaml file *overrides* the equivalent config inside the model !
#
# There are some important parameters of this config that must be updated by the user:
# - model.pretrained_model or model.nemo_model: str name or path to some pretrained model. Only one of the
# two should be passed. Selects the pre-trained model to be loaded and adapted.
#
# - model.adapter.adapter_name: Globally unique name, assigned to the adapter itself. Every adapter of a
# model must have a unique name.
#
# - model.adapter.in_features: The output dimension of each block of the model. This is model dependent.
# For example, Conformer dimension can be found via `model.encoder.d_model` in its config.
# For Citrinets/ContextNets, the dimension can usually be found in `model.encoder.jasper.0.filters`.
#
# - model.train_ds.manifest_filepath / model.validation_ds.manifest_filepath: Data filepaths to train the
# adapter module.
############################################################################################################
# The recommendations for training adapters are significantly different from general ASR training or
# fine-tuning. Below are some recommended configuration values.
#
# - model.adapter.dim: Usually we choose a small bottleneck dim here. 16 to 32 is generally enough.
#
# - model.optim.lr: We generally choose a very small LR, and a very short training schedule of just a few hundred
# steps - depending on the size of the dataset. Usually just a few epochs over the dataset with a low LR is
# sufficient for adaptation.
#
# - model.optim.weight_decay: We find that strong weight decay prevents significant degradation of prior training,
# but also limits the capacity of the model to learn the adapted domain. Usually as a baseline we use 0.0
#
# - model.optim.sched.warmup_steps: We encourage warmup steps to be modified to suit the smaller training schedule.
#
# - trainer.max_steps: We recommend using trainer.max_steps to limit the training duration to just 10-20 epochs.
# Adapters converge very fast, and prolonged training may cause overfitting to the new domain, consequently
# leading to catastrophic forgetting of the old domain. You can equivalently use a small number of epochs via
# trainer.max_epochs.
#
# - trainer.check_val_every_n_epoch: Since the training run is usually short and very fast, it is recommended to
# reduce the amount of validation to once every few epochs, rather than after every epoch, to speed up training.

name: "ASR-Adapter-hp"

model:
# One of the below two values must be set !
pretrained_model: null # name of a pretrained model
nemo_model: null # path to a ASR model file (.nemo)

log_prediction: false # enables logging sample predictions in the output during training

adapter:
# Config of the adapter training/eval script.
adapter_name: ??? # Name of the adapter, used by the script
adapter_module_name: null # Name of the adapter module. Combine multiple modules with '+' between module names.
adapter_state_dict_name: "adapters.pt" # If the individual adapters must be saved, a file name can be provided here. null disables this.

# Config of the adapter module itself
_target_: nemo.collections.common.parts.adapter_modules.LinearAdapter
in_features: ??? # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter.
dim: 32 # The hidden dimension of the adapter, as chosen by user, but small values are preferred to reduce param count.
activation: swish
norm_position: 'pre' # Can be `pre` or `post`
dropout: 0.0 # float, dropout for the adapter

# Adapter strategy config
adapter_strategy:
_target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy
stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block.
l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output.

# Optional global config available to all adapters at a global level.
# A global config is shared across every layer of the adapters, defining global properties rather
# than properties local to the adapter (as defined above).
# This can be useful in order to select *which type of adapter* is added, *which adapters to enable*,
# and other global operations that decide dynamically how to support the requested adapter.
global_cfg:
check_encoder_adapter: True # ASR adapter key, determines whether to check if encoder adapter modules are supported
check_decoder_adapter: True # ASR adapter key, determines whether to check if decoder adapter modules are supported
check_joint_adapter: True # ASR adapter key, determines whether to check if joint adapter modules are supported

# Overrides the model's internal spec augment configuration
spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 0
time_masks: 0
freq_width: 27
time_width: 0.05

train_ds:
# train dataset + dataloader config
# sample_rate will be merged with model config
# use_start_end_token will be merged with model config
# trim_silence will be merged with model config
# max_duration will be merged with model config
# min_duration will be merged with model config
manifest_filepath: ???
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
# sample_rate will be merged with model config
# use_start_end_token will be merged with model config
manifest_filepath: ???
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true

test_ds:
# sample_rate will be merged with model config
# use_start_end_token will be merged with model config
manifest_filepath: null
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true

optim:
# optimizer arguments
name: adamw
betas: [0.9, 0.98]
lr: 0.001 # LR depends on the scheduler used by the base model. Noam prefers 0.5, Cosine Annealing prefers 0.02
weight_decay: 0 # During adaptation, since training run is short, WD is not required. Can be set if needed.

# scheduler setup
sched:
name: CosineAnnealing

# scheduler config override
warmup_steps: null # Warmup steps should be set, and smaller than the trainer.max_steps set below.
warmup_ratio: 0.1 # Warmup steps will be 10% of the training steps.
min_lr: 1e-5
last_epoch: -1

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: null
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting it to 0 disables the check
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training


exp_manager:
exp_dir: null

# Build a unique name from all hyperparameter arguments to allow continued training.
# NOTE: It is necessary to include every swept hyperparameter in the name!
# This ensures model runs can be restored successfully if the HP search crashes.
name: ${name}-lr-${model.optim.lr}-adim-${model.adapter.dim}-sd-${model.adapter.adapter_strategy.stochastic_depth}

create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 1
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

create_wandb_logger: false
wandb_logger_kwargs:
name: ${exp_manager.name}
project: null
entity: null
save_dir: null
offline: false # If true, wandb logging will be done offline and would require manual syncing.
tags: null # List of tags to assign to the run

# HP search may crash for various reasons; it is best to attempt continuation in order to
# resume from where the last failure occurred.
resume_if_exists: true
resume_ignore_no_checkpoint: true

# Required for Hydra launch of hyperparameter search
defaults:
- override hydra/launcher: nemo_launcher

# Hydra arguments necessary for hyperparameter optimization
hydra:
sweep:
dir: "."
subdir: "."

sweeper:
params: # place all the parameters you wish to search over here (corresponding to the rest of the config)
model.optim.lr: 0.001,0.0001
model.adapter.dim: 32,64,96,128
model.adapter.adapter_strategy.stochastic_depth: 0.0,0.5,0.6,0.7,0.8,0.9

# Arguments to the hyperparameter runner
launcher:
num_gpus: -1 # Number of gpus to use. Each run works on a single GPU.
jobs_per_gpu: 1 # If each GPU has large memory, you can run multiple jobs on the same GPU for faster results (until OOM).
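
As a rough illustration (not part of the commit), Hydra's sweeper expands the comma-separated `sweeper.params` above into the Cartesian product of values, and the `exp_manager` name template keeps each combination distinguishable. A small Python sketch of the resulting grid:

from itertools import product

lrs = [0.001, 0.0001]
dims = [32, 64, 96, 128]
stochastic_depths = [0.0, 0.5, 0.6, 0.7, 0.8, 0.9]

jobs = list(product(lrs, dims, stochastic_depths))
print(len(jobs))  # 2 * 4 * 6 = 48 adaptation runs in the sweep

for lr, dim, sd in jobs[:2]:
    # Each combination becomes one set of Hydra overrides, and the exp_manager name
    # would resolve to something like the following, per the ${...} interpolations above:
    print(f"ASR-Adapter-hp-lr-{lr}-adim-{dim}-sd-{sd}")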
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/conv_asr.py
@@ -934,7 +934,7 @@ class JasperEncoderConfig:
@dataclass
class ConvASREncoderConfig:
_target_: str = 'nemo.collections.asr.modules.ConvASREncoder'
-jasper: Optional[JasperEncoderConfig] = field(default_factory=list)
+jasper: Optional[List[JasperEncoderConfig]] = field(default_factory=list)
activation: str = MISSING
feat_in: int = MISSING
normalization_mode: str = "batch"
11 changes: 6 additions & 5 deletions nemo/collections/common/parts/adapter_modules.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, is_dataclass
-from typing import Optional
+from typing import Any, Optional

from hydra.utils import instantiate
from omegaconf import OmegaConf
@@ -57,6 +57,7 @@ def setup_adapter_strategy(self, adapter_strategy: Optional[adapter_mixin_strate


class LinearAdapter(AbstractAdapterModule):

"""
Simple Linear Feedforward Adapter module with LayerNorm and a single hidden layer with an activation function.
Note: The adapter explicitly initializes its final layer with all zeros in order to avoid affecting the
@@ -66,7 +67,7 @@ class LinearAdapter(AbstractAdapterModule):
in_features: Input dimension of the module. Note that for adapters, input_dim == output_dim.
dim: Hidden dimension of the feed forward network.
activation: Str name for an activation function.
-norm_position: Str, can be `pre` or `post`. Defaults to `post`. Determines whether the normalization
+norm_position: Str, can be `pre` or `post`. Defaults to `pre`. Determines whether the normalization
will occur in the first layer or the last layer. Certain architectures may prefer one over the other.
dropout: float value, whether to perform dropout on the output of the last layer of the adapter.
adapter_strategy: By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.
@@ -77,7 +78,7 @@ def __init__(
in_features: int,
dim: int,
activation: str = 'swish',
norm_position: str = "post",
norm_position: str = 'pre',
dropout: float = 0.0,
adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
):
@@ -142,7 +143,7 @@ class LinearAdapterConfig:
in_features: int
dim: int
activation: str = 'swish'
-norm_position: str = 'post'
+norm_position: str = 'pre'
dropout: float = 0.0
-adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(LinearAdapter.__module__, LinearAdapter.__name__)
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
import enum
import logging
from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional

import torch
import torch.nn as nn
@@ -76,7 +76,7 @@ class MLPInfusedAdapter(InfusedAdapter):
@dataclass
class InfusedAdapterConfig:
in_features: int
-adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(InfusedAdapter.__module__, InfusedAdapter.__name__)


@@ -167,5 +167,5 @@ class ParallelLinearAdapterConfig:
column_init_method: str = 'xavier'
row_init_method: str = 'zero'
dropout: float = 0.0
-adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)
15 changes: 8 additions & 7 deletions nemo/core/config/hydra_runner.py
@@ -91,18 +91,19 @@ def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:

# Wrap a callable object with name `parse_args`
# This is to mimic the ArgParser.parse_args() API.
-class _argparse_wrapper:
-    def __init__(self, arg_parser):
-        self.arg_parser = arg_parser
-        self._actions = arg_parser._actions
+def parse_args(self, args=None, namespace=None):
+    return parsed_args

-    def parse_args(self, args=None, namespace=None):
-        return parsed_args
+parsed_args.parse_args = parse_args

# no return value from run_hydra() as it may sometime actually run the task_function
# multiple times (--multirun)
+# argparse_wrapper = _argparse_wrapper(args)
+argparse_wrapper = parsed_args

_run_hydra(
-    args_parser=_argparse_wrapper(args),
+    args=argparse_wrapper,
+    args_parser=args,
task_function=task_function,
config_path=config_path,
config_name=config_name,
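
The hydra_runner change above drops the small `_argparse_wrapper` class and instead attaches a `parse_args` function directly to the parsed namespace, so the same object can be handed to Hydra both as the parsed arguments and as something that exposes `parse_args()`. A standalone sketch of that shim pattern (hypothetical flags; not the NeMo code path):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--multirun", "-m", action="store_true")
parsed_args = parser.parse_args(["-m"])

def parse_args(self=None, args=None, namespace=None):
    # Ignore the inputs and return the arguments that were parsed once, up front.
    return parsed_args

# argparse.Namespace accepts arbitrary attributes, so the callable can be attached
# and the namespace passed where a parser-like object is expected.
parsed_args.parse_args = parse_args

print(parsed_args.parse_args().multirun)  # True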
