Self-supervised tutorial & update #3344

Merged
merged 50 commits into from
Jan 26, 2022
Changes from 48 commits
Commits (50)
60ab6d4
update
sam1373 Dec 15, 2021
93c0e28
update
sam1373 Dec 15, 2021
fe3926f
version test
sam1373 Dec 15, 2021
ddab8dd
version test
sam1373 Dec 15, 2021
3f66999
image for tutorial
sam1373 Dec 15, 2021
9355f94
enc_final in model_defaults
sam1373 Dec 15, 2021
0698b4e
enc_final in model_defaults
sam1373 Dec 15, 2021
7bc8f88
self-supervised tutorial
sam1373 Dec 15, 2021
690b8ec
Merge branch 'main' into pre_training_5
sam1373 Dec 15, 2021
3de5fbe
fix
sam1373 Dec 16, 2021
93c7b11
Merge remote-tracking branch 'origin/pre_training_5' into pre_training_5
sam1373 Dec 16, 2021
1b81e44
fix
sam1373 Dec 16, 2021
b8e5bf0
contextnet ssl config
sam1373 Dec 20, 2021
d5ded21
remove test_ds from config
sam1373 Dec 20, 2021
b42c747
Merge branch 'main' into pre_training_5
sam1373 Dec 20, 2021
62f4bc9
Merge branch 'main' into pre_training_5
sam1373 Dec 20, 2021
ca0ca7f
update recon decoder
sam1373 Dec 21, 2021
c899811
Merge remote-tracking branch 'origin/pre_training_5' into pre_training_5
sam1373 Dec 21, 2021
9632e0d
don't save -last if val_loss is nan
sam1373 Dec 21, 2021
ecfbd66
check if val_loss is there
sam1373 Dec 21, 2021
eda9e40
keep entries from same file together when tarring
sam1373 Dec 22, 2021
aeab8ef
keep entries from same file together when tarring
sam1373 Dec 22, 2021
99e4deb
print num of files in shard
sam1373 Dec 22, 2021
c894aef
update
sam1373 Dec 30, 2021
4310e51
style
sam1373 Dec 30, 2021
268a89a
Merge branch 'main' into pre_training_5
sam1373 Jan 3, 2022
e7cae50
Merge branch 'main' into pre_training_5
sam1373 Jan 12, 2022
567399a
Merge branch 'main' into pre_training_5
sam1373 Jan 12, 2022
916a191
moving configs, add docstrings
sam1373 Jan 19, 2022
4583ac4
Merge branch 'main' into pre_training_5
sam1373 Jan 19, 2022
b649643
Merge remote-tracking branch 'origin/pre_training_5' into pre_training_5
sam1373 Jan 19, 2022
b0585d4
tutorial updates
sam1373 Jan 19, 2022
ba2bb63
update test
sam1373 Jan 19, 2022
e43bb36
update loading
sam1373 Jan 19, 2022
1ff9a88
update loading
sam1373 Jan 19, 2022
b5f2c6d
update loading
sam1373 Jan 19, 2022
7702b34
update loading
sam1373 Jan 19, 2022
14be631
fix
sam1373 Jan 20, 2022
2cd0d7f
fix
sam1373 Jan 20, 2022
6106d6c
citrinet configs
sam1373 Jan 21, 2022
74c9352
Merge branch 'main' into pre_training_5
sam1373 Jan 21, 2022
a00c399
citrinet configs update
sam1373 Jan 21, 2022
da1b142
Merge remote-tracking branch 'origin/pre_training_5' into pre_training_5
sam1373 Jan 21, 2022
6d38032
update
sam1373 Jan 24, 2022
b3564c1
default include all
sam1373 Jan 24, 2022
223e048
comments
sam1373 Jan 24, 2022
3767be4
docstring hydra example
sam1373 Jan 24, 2022
78869b9
Merge branch 'main' into pre_training_5
sam1373 Jan 24, 2022
40823c1
Merge branch 'main' into pre_training_5
sam1373 Jan 26, 2022
e18caf6
Merge branch 'main' into pre_training_5
sam1373 Jan 26, 2022
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -295,7 +295,7 @@ pipeline {
stage('L2: Speech Pre-training - CitriNet') {
steps {
sh 'python examples/asr/speech_pretraining/speech_pre_training.py \
--config-path="../conf/citrinet_ssl/" --config-name="citrinet_ssl_ci" \
--config-path="../conf/ssl/citrinet/" --config-name="citrinet_ssl_ci" \
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
trainer.gpus=[1] \
5 changes: 3 additions & 2 deletions examples/asr/conf/citrinet/citrinet_1024.yaml
@@ -55,6 +55,7 @@ model:
se_context_size: -1
kernel_size_factor: 0.25
filters: 1024
enc_final: 1024

tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
@@ -400,7 +401,7 @@ model:



- filters: &enc_final 1024
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -415,7 +416,7 @@

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
vocabulary: [] # filled with vocabulary from tokenizer at runtime

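The change above replaces the &enc_final / *enc_final YAML anchor with an OmegaConf interpolation, so the final encoder width lives in model.model_defaults.enc_final and propagates to both the last encoder block and the decoder feat_in, even when it is overridden after the file is parsed (for example from the command line). Below is a minimal Python sketch of that behaviour, assuming only the omegaconf package; the snippet is illustrative and not part of the diff.

from omegaconf import OmegaConf

# Toy config mirroring the pattern used in the Citrinet files in this PR.
cfg = OmegaConf.create(
    """
model:
  model_defaults:
    enc_final: 1024
  encoder:
    jasper:
      - filters: ${model.model_defaults.enc_final}
  decoder:
    feat_in: ${model.model_defaults.enc_final}
"""
)

# A single override now reaches every field that references the default;
# with a plain YAML anchor this would require editing the file itself.
cfg.model.model_defaults.enc_final = 640
print(cfg.model.encoder.jasper[0].filters)  # 640
print(cfg.model.decoder.feat_in)            # 640
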
5 changes: 3 additions & 2 deletions examples/asr/conf/citrinet/citrinet_384.yaml
@@ -54,6 +54,7 @@ model:
se: true
se_context_size: -1
kernel_size_factor: 1.0
enc_final: 640

tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
@@ -356,7 +357,7 @@ model:
se_context_size: ${model.model_defaults.se_context_size}
kernel_size_factor: ${model.model_defaults.kernel_size_factor}

- filters: &enc_final 640
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -371,7 +372,7 @@

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
vocabulary: [] # filled with vocabulary from tokenizer at runtime

5 changes: 3 additions & 2 deletions examples/asr/conf/citrinet/citrinet_512.yaml
@@ -53,6 +53,7 @@ model:
se: true
se_context_size: -1
kernel_size_factor: 1.0
enc_final: 640

tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
@@ -355,7 +356,7 @@ model:
se_context_size: ${model.model_defaults.se_context_size}
kernel_size_factor: ${model.model_defaults.kernel_size_factor}

- filters: &enc_final 640
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -370,7 +371,7 @@

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
vocabulary: [] # filled with vocabulary from tokenizer at runtime

Original file line number Diff line number Diff line change
@@ -24,10 +24,10 @@ model:
max_duration: 35.0
min_duration: 3.0
shuffle: true
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
use_start_end_token: false
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null
@@ -50,6 +50,7 @@ model:
kernel_size_factor: 0.25
filters: 1024
decoder_out_channels: 128
enc_final: 1024


preprocessor:
@@ -389,7 +390,7 @@ model:



- filters: &enc_final 1024
- filters: ${model.model_defaults.enc_final}
repeat: 1
kernel: [41]
stride: [1]
@@ -404,7 +405,7 @@

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: *enc_final
feat_in: ${model.model_defaults.enc_final}
feat_hidden: 128
feat_out: ${model.model_defaults.decoder_out_channels}
stride_layers: 1
Original file line number Diff line number Diff line change
@@ -13,11 +13,9 @@ model:
max_duration: 35.0
min_duration: 3.0
shuffle: true
use_start_end_token: false
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
use_start_end_token: false
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null
188 changes: 188 additions & 0 deletions examples/asr/conf/ssl/conformer/conformer_ssl.yaml
@@ -0,0 +1,188 @@
# This config contains the default values for self-supervised pre-training of a Conformer ASR model, large size (~120M).

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of Conformer-CTC; other parameters are the same as in this config file.
# One extra layer (compared to the original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one.
#
# +--------------+---------+---------+----------+------------+-----+
# | Model        | d_model | n_heads | n_layers | time_masks | lr  |
# +==============+=========+=========+==========+============+=====+
# | Small (13M)  | 176     | 4       | 16       | 5          | 5.0 |
# +--------------+---------+---------+----------+------------+-----+
# | Medium (30M) | 256     | 4       | 18       | 5          | 5.0 |
# +--------------+---------+---------+----------+------------+-----+
# | Large (121M) | 512     | 8       | 18       | 10         | 2.0 |
# +--------------+---------+---------+----------+------------+-----+
#
# If you do not want to train with AMP, you may use a weight decay of 0.0 or reduce the number of time masks to 2
# with time_width=100. This may help when you want to train for fewer epochs and need faster convergence.
# With weight_decay=0.0, the learning rate may need to be reduced to 2.0.
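#
# Illustrative launch command for the Medium variant (the script path and config location follow the
# examples in this PR, and the overrides follow the table above; adjust them to your setup):
#   python examples/asr/speech_pretraining/speech_pre_training.py \
#     --config-path="../conf/ssl/conformer/" --config-name="conformer_ssl" \
#     model.encoder.d_model=256 model.encoder.n_heads=4 model.spec_augment.time_masks=5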

name: "Conformer-SSL"

model:
sample_rate: 16000
log_prediction: true # enables logging sample predictions in the output during training
ctc_reduction: 'mean_batch'

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: false
use_start_end_token: true
trim_silence: false
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 3.0
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false



preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: &n_mels 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 16
pad_value: 0.0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 4
time_masks: 10
freq_width: 27
time_width: 0.05

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need an output size different from the default d_model
n_layers: 18
d_model: 512

# Sub-sampling params
subsampling: striding # vggnet or striding; vggnet may give better results but needs more memory
subsampling_factor: 4 # must be a power of 2
subsampling_conv_channels: -1 # -1 sets it to d_model

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 31
conv_norm_type: 'batch_norm' # batch_norm or layer_norm

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

dec_out: &dec_out 128

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
feat_out: *dec_out
stride_layers: 0
non_stride_layers: 2


loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: *n_mels
proj_dim: *dec_out
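# combine_time_steps groups 4 input frames per contrastive target, matching the encoder's
# 4x subsampling so targets and encoder outputs stay frame-aligned; codebook_size sets the
# number of quantized target entries.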
combine_time_steps: 4
codebook_size: 1200

optim:
name: adamw
lr: 5.0
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: NoamAnnealing
d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 10000
warmup_ratio: null
min_lr: 1e-6

trainer:
gpus: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 1000
max_steps: null # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10
resume_from_checkpoint: null # The path to a checkpoint file to continue training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting it to 0 disables the check
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
checkpoint_callback: false # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_loss"
mode: "min"
save_top_k: 5

# you need to set these two to True to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
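
For reference, these SSL configs are consumed by the speech pre-training example referenced in the Jenkinsfile change above. The following is a hedged Python sketch of that flow, assuming the class and helper names used by examples/asr/speech_pretraining/speech_pre_training.py (SpeechEncDecSelfSupervisedModel, hydra_runner, exp_manager); it is an approximation of the script, not a copy of it.

import pytorch_lightning as pl

from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/ssl/conformer/", config_name="conformer_ssl")
def main(cfg):
    # Build the Lightning trainer and experiment manager from the trainer / exp_manager
    # sections of the config, then the SSL model from the model section.
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    ssl_model = SpeechEncDecSelfSupervisedModel(cfg=cfg.model, trainer=trainer)

    # Pre-trains the encoder with the contrastive loss defined in the config; the resulting
    # checkpoint can later initialize the encoder of a supervised ASR model.
    trainer.fit(ssl_model)


if __name__ == "__main__":
    main()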