Signed-off-by: sam1373 <[email protected]>
sam1373 committed Jan 24, 2022
1 parent b3564c1 commit 223e048
Showing 7 changed files with 37 additions and 15 deletions.
5 changes: 3 additions & 2 deletions examples/asr/conf/citrinet/citrinet_384.yaml
@@ -54,6 +54,7 @@ model:
se: true
se_context_size: -1
kernel_size_factor: 1.0
enc_final: 640

tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
@@ -356,7 +357,7 @@ model:
se_context_size: ${model.model_defaults.se_context_size}
kernel_size_factor: ${model.model_defaults.kernel_size_factor}

- filters: &enc_final 640
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -371,7 +372,7 @@ model:

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
vocabulary: [] # filled with vocabulary from tokenizer at runtime

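Editorial note on the pattern above: the YAML anchor/alias pair (`&enc_final` / `*enc_final`) is replaced by a `model_defaults.enc_final` entry that both the final encoder block and the decoder reference through OmegaConf interpolation. A minimal sketch of why this helps, assuming a Hydra-launched training script; the override value 1024 is illustrative, not from this commit:

```yaml
model:
  model_defaults:
    enc_final: 1024   # one knob; previously a YAML anchor defined on the encoder block itself
  encoder:
    jasper:
      # ... earlier blocks omitted ...
      - filters: ${model.model_defaults.enc_final}   # resolves to 1024 at config load time
        repeat: 1
        kernel: [41]
  decoder:
    _target_: nemo.collections.asr.modules.ConvASRDecoder
    feat_in: ${model.model_defaults.enc_final}       # stays in sync with the encoder automatically
```

With Hydra, the same value can also be overridden from the command line as `model.model_defaults.enc_final=1024`, which was not possible with an anchor buried inside the encoder block list.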
5 changes: 3 additions & 2 deletions examples/asr/conf/citrinet/citrinet_512.yaml
@@ -53,6 +53,7 @@ model:
se: true
se_context_size: -1
kernel_size_factor: 1.0
enc_final: 640

tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
@@ -355,7 +356,7 @@ model:
se_context_size: ${model.model_defaults.se_context_size}
kernel_size_factor: ${model.model_defaults.kernel_size_factor}

- filters: &enc_final 640
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -370,7 +371,7 @@ model:

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
vocabulary: [] # filled with vocabulary from tokenizer at runtime

1 change: 1 addition & 0 deletions examples/asr/conf/ssl/citrinet/citrinet_ssl_1024.yaml
@@ -26,6 +26,7 @@ model:
shuffle: true
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
use_start_end_token: false
# bucketing params
bucketing_strategy: "synced_randomized"
16 changes: 12 additions & 4 deletions examples/asr/conf/ssl/conformer/conformer_ssl.yaml
@@ -33,11 +33,18 @@ model:
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
use_start_end_token: false
pin_memory: false
use_start_end_token: true
trim_silence: false
max_duration: 35.0 # it is set for LibriSpeech, you may need to update it for your dataset
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 3.0
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
@@ -98,6 +105,7 @@ model:

# Convolution module's params
conv_kernel_size: 31
conv_norm_type: 'batch_norm' # batch_norm or layer_norm

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
@@ -167,7 +175,7 @@ exp_manager:
# in case of multiple validation sets, first one is used
monitor: "val_loss"
mode: "min"
save_top_k: 1
save_top_k: 5

# you need to set these two to True to continue the training
resume_if_exists: false
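The `is_tarred`, `tarred_audio_filepaths`, `shuffle_n`, and bucketing keys added here bring the SSL configs in line with the tarred-dataset options used in the supervised ASR configs. A rough sketch of how they might be filled in for an actual tarred run follows; the paths, shard range, and buffer size are placeholders rather than values from this commit:

```yaml
model:
  train_ds:
    manifest_filepath: /data/tarred_set/tarred_audio_manifest.json   # placeholder path
    is_tarred: true
    # brace-expansion pattern over the tar shards (placeholder shard range)
    tarred_audio_filepaths: "/data/tarred_set/audio__OP_0..511_CL_.tar"
    shuffle_n: 2048                       # per-worker shuffle buffer over samples read from the shards
    bucketing_strategy: "synced_randomized"
    bucketing_batch_size: null            # leave null to keep the flat batch_size
```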
11 changes: 8 additions & 3 deletions examples/asr/conf/ssl/contextnet/contextnet_ssl.yaml
@@ -26,12 +26,17 @@ model:
trim_silence: false
max_duration: 16.7
shuffle: true
is_tarred: false
tarred_audio_filepaths: null
tarred_shard_strategy: "scatter"
use_start_end_token: false
num_workers: 16
pin_memory: true
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
tarred_shard_strategy: "scatter"
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
12 changes: 9 additions & 3 deletions nemo/core/classes/modelPT.py
@@ -949,7 +949,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
self.load_state_dict(restored_model.state_dict(), strict=False)
logging.info(f'Model checkpoint restored from nemo file with path : `{model_path}`')
del restored_model
else:
elif isinstance(cfg.init_from_nemo_model, dict):
model_load_dict = cfg.init_from_nemo_model
for model_load_cfg in model_load_dict.values():
model_path = model_load_cfg.path
@@ -966,6 +966,8 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
)

del restored_model
else:
raise TypeError("Invalid type: init_from_nemo_model is not a string or a dict!")

if 'init_from_pretrained_model' in cfg and cfg.init_from_pretrained_model is not None:
with open_dict(cfg):
@@ -996,7 +998,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
logging.info(f'Model checkpoint restored from pretrained checkpoint with name : `{model_name}`')

del restored_model
else:
elif isinstance(cfg.init_from_pretrained_model, dict):
model_load_dict = cfg.init_from_pretrained_model
for model_load_cfg in model_load_dict.values():
model_name = model_load_cfg.name
@@ -1016,6 +1018,8 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
)

del restored_model
else:
raise TypeError("Invalid type: init_from_pretrained_model is not a string or a dict!")

if 'init_from_ptl_ckpt' in cfg and cfg.init_from_ptl_ckpt is not None:
with open_dict(cfg):
@@ -1031,7 +1035,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
)

del ckpt
else:
elif isinstance(cfg.init_from_ptl_ckpt, dict):
model_load_dict = cfg.init_from_ptl_ckpt
for model_load_cfg in model_load_dict.values():
ckpt_path = model_load_cfg.path
@@ -1046,6 +1050,8 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
)

del ckpt
else:
raise TypeError("Invalid type: init_from_ptl_ckpt is not a string or a dict!")

def teardown(self, stage: str):
"""
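The `else:` → `elif isinstance(..., dict):` changes make the accepted shapes of the `init_from_*` options explicit and turn anything else into an immediate `TypeError` instead of a later failure. A hedged sketch of the two shapes for `init_from_nemo_model`; the paths and entry names are placeholders, and dict entries may accept further keys not shown here:

```yaml
# Shape 1: a single checkpoint given directly as a path string
init_from_nemo_model: /path/to/pretrained_model.nemo

# Shape 2: multiple checkpoints given as a dict; each entry supplies at least `path`
init_from_nemo_model:
  ssl_encoder:
    path: /path/to/ssl_pretrained.nemo
  other_weights:
    path: /path/to/another_model.nemo
```

As the surrounding branches show, `init_from_pretrained_model` follows the same pattern with a `name` per entry, and `init_from_ptl_ckpt` with a Lightning `.ckpt` path.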
2 changes: 1 addition & 1 deletion tutorials/asr/README.md
@@ -29,7 +29,7 @@ In this repository, you will find several tutorials discussing what is Automatic

10) `ASR_with_Transducers`: In this tutorial, we take a deep dive into Transducer based ASR models, discussing the similarity of setup and config to CTC models and then train a small ContextNet model on the AN4 dataset. We then discuss how to change the decoding strategy of a trained Transducer from greedy search to beam search. Finally, we wrap up this tutorial by extracting the alignment matrix from a trained Transducer model.

11) `Self_Supervised_Pre_Training`: It can often be difficult to obtain labeled data for ASR training. In this tutorial, we demonstrate how to pre-train a small Citrinet model in an unsupervised manner, and then fine-tune with CTC loss.
11) `Self_Supervised_Pre_Training`: It can often be difficult to obtain labeled data for ASR training. In this tutorial, we demonstrate how to pre-train a speech model in an unsupervised manner, and then fine-tune with CTC loss.

----------------

