[TTS] Create EnCodec training recipe (#6852)

Squashed commits:
* [TTS] Create EnCodec training recipe
* [TTS] Update encodec recipe
* [TTS] Rename EnCodec to AudioCodec
* [TTS] Add EnCodec unit tests
* [TTS] Add copyright header to distributed.py

Signed-off-by: Ryan <[email protected]>
Signed-off-by: jubick1337 <[email protected]>
Commit eea78b2 (1 parent: 0b48771). Showing 15 changed files with 2,128 additions and 1 deletion. Two of the new files are reproduced below: the training script and its audio_codec config.
@@ -0,0 +1,31 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.tts.models import AudioCodecModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
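The script is driven entirely by the Hydra config reproduced below; the required "???" fields are typically supplied as command-line overrides or by editing the config. As a rough, untested sketch of the same flow without the @hydra_runner decorator (the YAML path and the placeholder override values here are assumptions for illustration, not part of the commit):

# Sketch of what the recipe does once Hydra has composed the config.
# The config location and the values filled in for the required "???"
# fields are illustrative placeholders, not values from the commit.
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.tts.models import AudioCodecModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("conf/audio_codec/audio_codec.yaml")  # assumed path to the config below
cfg.max_epochs = 100                                       # required field, placeholder value
cfg.train_ds_meta = OmegaConf.load("train_ds_meta.yaml")   # dataset metadata; schema is described in vocoder_dataset.py
cfg.val_ds_meta = OmegaConf.load("val_ds_meta.yaml")
cfg.log_ds_meta = OmegaConf.load("log_ds_meta.yaml")
cfg.log_dir = "logs/audio_codec_samples"

# Same four steps as main() above.
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)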
@@ -0,0 +1,160 @@
# This config contains the default values for training 24khz EnCodec model
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: EnCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Modify these values based on your sample rate
sample_rate: 24000
train_n_samples: 24000
down_sample_rates: [2, 4, 5, 8]
up_sample_rates: [8, 5, 4, 2]
# The number of samples per encoded audio frame. Should be the product of the down_sample_rates.
# For example 2 * 4 * 5 * 8 = 320.
samples_per_frame: 320

model:

  max_epochs: ${max_epochs}
  steps_per_epoch: ${weighted_sampling_steps_per_epoch}

  sample_rate: ${sample_rate}
  samples_per_frame: ${samples_per_frame}
  time_domain_loss_scale: 0.1
  # Probability of updating the discriminator during each training step
  disc_update_prob: 0.67

  # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
  mel_loss_resolutions: [
    [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
  ]

  train_ds:
    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
      sample_rate: ${sample_rate}
      n_samples: ${train_n_samples}
      min_duration: 1.01
      max_duration: null
      dataset_meta: ${train_ds_meta}

    dataloader_params:
      batch_size: ${batch_size}
      drop_last: true
      num_workers: 4

  validation_ds:
    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      sample_rate: ${sample_rate}
      n_samples: null
      min_duration: null
      max_duration: null
      trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
      dataset_meta: ${val_ds_meta}

    dataloader_params:
      batch_size: 8
      num_workers: 2

  # Configures how audio samples are generated and saved during training.
  # Remove this section to disable logging.
  log_config:
    log_dir: ${log_dir}
    log_epochs: [10, 50]
    epoch_frequency: 100
    log_tensorboard: false
    log_wandb: false

    generators:
      - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
        log_audio: true
        log_encoding: true
        log_quantized: true

    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      sample_rate: ${sample_rate}
      n_samples: null
      min_duration: null
      max_duration: null
      trunc_duration: 15.0 # Only log the first 15 seconds of generated audio.
      dataset_meta: ${log_ds_meta}

    dataloader_params:
      batch_size: 4
      num_workers: 2

  audio_encoder:
    _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetEncoder
    down_sample_rates: ${down_sample_rates}

  audio_decoder:
    _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetDecoder
    up_sample_rates: ${up_sample_rates}

  vector_quantizer:
    _target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer
    num_codebooks: 8

  discriminator:
    _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT
    resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

  # The original EnCodec uses hinged loss, but squared-GAN loss is more stable
  # and reduces the need to tune the loss weights or use a gradient balancer.
  generator_loss:
    _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

  discriminator_loss:
    _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

  optim:
    _target_: torch.optim.Adam
    lr: 3e-4
    betas: [0.5, 0.9]

    sched:
      name: ExponentialLR
      gamma: 0.999

trainer:
  num_nodes: 1
  devices: 1
  accelerator: gpu
  strategy: ddp
  precision: 32 # Vector quantization only works with 32-bit precision.
  max_epochs: ${max_epochs}
  accumulate_grad_batches: 1
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 5
  benchmark: false

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  create_wandb_logger: false
  checkpoint_callback_params:
    monitor: val_loss
  resume_if_exists: false
  resume_ignore_no_checkpoint: false
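The head-of-file comments encode arithmetic worth rechecking whenever the recipe is adapted to another sample rate: samples_per_frame must equal the product of down_sample_rates (2 * 4 * 5 * 8 = 320), which at 24 kHz gives 24000 / 320 = 75 codec frames per second, and train_n_samples (24000 here, i.e. exactly 75 frames) is typically kept at a whole number of frames. A small sanity-check sketch, not part of the commit; the config path is an assumption:

# Verify the stride / frame-size relationships stated in the config comments.
import math

from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/audio_codec/audio_codec.yaml")  # assumed path

frame = math.prod(cfg.down_sample_rates)                   # 2 * 4 * 5 * 8 = 320
assert frame == cfg.samples_per_frame
assert list(cfg.up_sample_rates) == list(reversed(cfg.down_sample_rates))
assert cfg.train_n_samples % cfg.samples_per_frame == 0    # whole number of frames per training window

print(cfg.sample_rate / frame)                             # 75.0 codec frames per second at 24 kHz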
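The loss section of the config notes that the recipe swaps EnCodec's hinge GAN loss for a squared (least-squares) GAN loss. For reference only, the sketch below shows the textbook LSGAN formulation that "squared" refers to; it is a generic illustration, not a copy of NeMo's GeneratorSquaredLoss / DiscriminatorSquaredLoss implementation, which may differ in details such as averaging over multiple discriminator outputs:

# Generic least-squares ("squared") GAN objectives, shown for reference.
import torch


def generator_squared_loss(disc_scores_fake: torch.Tensor) -> torch.Tensor:
    # Push discriminator scores on generated audio toward 1.
    return torch.mean((1.0 - disc_scores_fake) ** 2)


def discriminator_squared_loss(disc_scores_real: torch.Tensor, disc_scores_fake: torch.Tensor) -> torch.Tensor:
    # Score real audio toward 1 and generated audio toward 0.
    return torch.mean((1.0 - disc_scores_real) ** 2) + torch.mean(disc_scores_fake ** 2)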