Tacotron2 retrain (#4103)
* fix yaml

Signed-off-by: treacker <[email protected]>

* Fix for new TTSDataset class

Signed-off-by: treacker <[email protected]>

* added wandb logging

Signed-off-by: treacker <[email protected]>

* added wandb logging

Signed-off-by: treacker <[email protected]>

* fix numpy version

Signed-off-by: treacker <[email protected]>

* fix numpy version

Signed-off-by: treacker <[email protected]>

* inference fix

Signed-off-by: treacker <[email protected]>

* removed old code

Signed-off-by: treacker <[email protected]>

* updated parser logic

Signed-off-by: treacker <[email protected]>

* reverted version update

Signed-off-by: treacker <[email protected]>

* refactored parser logic

Signed-off-by: treacker <[email protected]>

* Updated Jenkinsfile

Signed-off-by: treacker <[email protected]>

* Refactored tutorial for Tacotron2

Signed-off-by: treacker <[email protected]>

* Added backward compatibility

Signed-off-by: treacker <[email protected]>

* Added backward compatibility

Signed-off-by: treacker <[email protected]>

* Update Jenkinsfile

Signed-off-by: treacker <[email protected]>

* Update tacotron.yaml

Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

* cleaned up TN/ITN doc (#4119)

* cleaned up TN/ITN doc

Signed-off-by: Yang Zhang <[email protected]>

* fix typo

Signed-off-by: Yang Zhang <[email protected]>

* fix image

Signed-off-by: Yang Zhang <[email protected]>

* fix image

Signed-off-by: Yang Zhang <[email protected]>
Signed-off-by: treacker <[email protected]>

* Check implicit grad acc in GLUE dataset building (#4123)

* Check implicit grad acc in GLUE dataset building

Signed-off-by: MaximumEntropy <[email protected]>

* Fix jenkins test for GLUE/XNLI

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

* Fixed jenkins

Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

* Refactoring

Signed-off-by: treacker <[email protected]>

Co-authored-by: Yang Zhang <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
3 people authored May 11, 2022
1 parent df33239 commit 30db4d4
Showing 5 changed files with 640 additions and 455 deletions.
12 changes: 8 additions & 4 deletions Jenkinsfile
@@ -2983,7 +2983,6 @@ pipeline {
}
}
parallel {
// TODO(Oktai15): update it in 1.8.0 version
stage('Tacotron 2') {
steps {
sh 'python examples/tts/tacotron2.py \
@@ -2993,13 +2992,18 @@
trainer.accelerator="gpu" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
trainer.strategy=null \
model.train_ds.dataloader_params.batch_size=4 \
model.validation_ds.dataloader_params.batch_size=4 \
model.decoder.decoder_rnn_dim=256 \
model.decoder.attention_rnn_dim=1024 \
model.decoder.prenet_dim=128 \
model.postnet.postnet_n_convolutions=3 \
~trainer.check_val_every_n_epoch'
model.train_ds.dataloader_params.batch_size=4 \
model.train_ds.dataloader_params.num_workers=1 \
model.validation_ds.dataloader_params.batch_size=4 \
model.validation_ds.dataloader_params.num_workers=1 \
~model.text_normalizer \
~model.text_normalizer_call_kwargs \
~trainer.check_val_every_n_epoch \
'
}
}
stage('WaveGlow') {
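
The refreshed smoke test drives the new TTSDataset code path entirely through Hydra overrides: key=value sets an existing value, +key=value adds a new one, and ~key deletes a key from the composed config (here the text normalizer blocks, so CI can run without the normalization data files). A rough sketch of the same config surgery done directly in Python with OmegaConf, assuming the repository layout used above:

# Sketch only: mimic the ~model.text_normalizer-style deletions with OmegaConf.
# The path and key names are taken from the config shown in this PR.
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.load("examples/tts/conf/tacotron2.yaml")
with open_dict(cfg.model):  # allow structural changes to the config node
    cfg.model.pop("text_normalizer", None)
    cfg.model.pop("text_normalizer_call_kwargs", None)

assert "text_normalizer" not in cfg.model
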
163 changes: 111 additions & 52 deletions examples/tts/conf/tacotron2.yaml
@@ -1,81 +1,136 @@
# TODO(Oktai15): update this config in 1.8.0 version
# This config contains the default values for training Tacotron2 model on LJSpeech dataset.
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: Tacotron2
sample_rate: 22050
# <PAD>, <BOS>, <EOS> will be added by the tacotron2.py script
labels: [' ', '!', '"', "'", '(', ')', ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
'u', 'v', 'w', 'x', 'y', 'z']
n_fft: 1024
n_mels: 80
fmax: 8000
n_stride: 256
pad_value: -11.52

train_dataset: ???
validation_datasets: ???
sup_data_path: null
sup_data_types: null

phoneme_dict_path: "scripts/tts_dataset_files/cmudict-0.7b_nv22.01"
heteronyms_path: "scripts/tts_dataset_files/heteronyms-030921"
whitelist_path: "nemo_text_processing/text_normalization/en/data/whitelist_lj_speech.tsv"



model:
labels: ${labels}
pitch_fmin: 65.40639132514966
pitch_fmax: 2093.004522404789

sample_rate: 22050
n_mel_channels: 80
n_window_size: 1024
n_window_stride: 256
n_fft: 1024
lowfreq: 0
highfreq: 8000
window: hann
pad_value: -11.52


text_normalizer:
_target_: nemo_text_processing.text_normalization.normalize.Normalizer
lang: en
input_case: cased
whitelist: ${whitelist_path}

text_normalizer_call_kwargs:
verbose: false
punct_pre_process: true
punct_post_process: true

text_tokenizer:
_target_: nemo.collections.tts.torch.tts_tokenizers.EnglishPhonemesTokenizer
punct: true
stresses: true
chars: true
apostrophe: true
pad_with_space: true
g2p:
_target_: nemo.collections.tts.torch.g2ps.EnglishG2p
phoneme_dict: ${phoneme_dict_path}
heteronyms: ${heteronyms_path}

train_ds:
dataset:
_target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
_target_: "nemo.collections.tts.torch.data.TTSDataset"
manifest_filepath: ${train_dataset}
sample_rate: ${model.sample_rate}
sup_data_path: ${sup_data_path}
sup_data_types: ${sup_data_types}
n_fft: ${model.n_fft}
win_length: ${model.n_window_size}
hop_length: ${model.n_window_stride}
window: ${model.window}
n_mels: ${model.n_mel_channels}
lowfreq: ${model.lowfreq}
highfreq: ${model.highfreq}
max_duration: null
min_duration: 0.1
trim: false
int_values: false
normalize: true
sample_rate: ${sample_rate}
# bos_id: 66
# eos_id: 67
# pad_id: 68 These parameters are added automatically in Tacotron2
ignore_file: null
trim: False
pitch_fmin: ${model.pitch_fmin}
pitch_fmax: ${model.pitch_fmax}
dataloader_params:
drop_last: false
shuffle: true
batch_size: 48
num_workers: 4


pin_memory: false
validation_ds:
dataset:
_target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
manifest_filepath: ${validation_datasets}
_target_: "nemo.collections.tts.torch.data.TTSDataset"
manifest_filepath: ${validation_datasets}
sample_rate: ${model.sample_rate}
sup_data_path: ${sup_data_path}
sup_data_types: ${sup_data_types}
n_fft: ${model.n_fft}
win_length: ${model.n_window_size}
hop_length: ${model.n_window_stride}
window: ${model.window}
n_mels: ${model.n_mel_channels}
lowfreq: ${model.lowfreq}
highfreq: ${model.highfreq}
max_duration: null
min_duration: 0.1
int_values: false
normalize: true
sample_rate: ${sample_rate}
trim: false
# bos_id: 66
# eos_id: 67
# pad_id: 68 These parameters are added automatically in Tacotron2
ignore_file: null
trim: False
pitch_fmin: ${model.pitch_fmin}
pitch_fmax: ${model.pitch_fmax}
dataloader_params:
drop_last: false
shuffle: false
batch_size: 48
batch_size: 24
num_workers: 8
pin_memory: false

preprocessor:
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
dither: 0.0
nfilt: ${n_mels}
frame_splicing: 1
highfreq: ${fmax}
nfilt: ${model.n_mel_channels}
highfreq: ${model.highfreq}
log: true
log_zero_guard_type: clamp
log_zero_guard_value: 1e-05
lowfreq: 0
mag_power: 1.0
n_fft: ${n_fft}
n_window_size: 1024
n_window_stride: ${n_stride}
normalize: null
lowfreq: ${model.lowfreq}
n_fft: ${model.n_fft}
n_window_size: ${model.n_window_size}
n_window_stride: ${model.n_window_stride}
pad_to: 16
pad_value: ${pad_value}
pad_value: ${model.pad_value}
sample_rate: ${model.sample_rate}
window: ${model.window}
normalize: null
preemph: null
sample_rate: ${sample_rate}
window: hann
dither: 0.0
frame_splicing: 1
stft_conv: false
nb_augmentation_prob : 0
mag_power: 1.0
exact_pad: true
use_grads: false

encoder:
_target_: nemo.collections.tts.modules.tacotron2.Encoder
@@ -90,7 +145,7 @@ model:
gate_threshold: 0.5
max_decoder_steps: 1000
n_frames_per_step: 1 # currently only 1 is supported
n_mel_channels: ${n_mels}
n_mel_channels: ${model.n_mel_channels}
p_attention_dropout: 0.1
p_decoder_dropout: 0.1
prenet_dim: 256
@@ -105,7 +160,7 @@ model:

postnet:
_target_: nemo.collections.tts.modules.tacotron2.Postnet
n_mel_channels: ${n_mels}
n_mel_channels: ${model.n_mel_channels}
p_dropout: 0.5
postnet_embedding_dim: 512
postnet_kernel_size: 5
@@ -132,11 +187,15 @@ trainer:
enable_checkpointing: False # Provided by exp_manager
logger: False # Provided by exp_manager
gradient_clip_val: 1.0
log_every_n_steps: 200
check_val_every_n_epoch: 25
log_every_n_steps: 60
check_val_every_n_epoch: 2


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: True
create_checkpoint_callback: True
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
monitor: val_loss
mode: min
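
The rewritten config builds text preprocessing from _target_ classes instead of a fixed character list. A minimal sketch of instantiating the text_tokenizer block above with Hydra, assuming a repo checkout so the phoneme dictionary and heteronyms paths resolve (calling the tokenizer on a raw string mirrors how TTSDataset uses it):

from hydra.utils import instantiate
from omegaconf import OmegaConf

tok_cfg = OmegaConf.create(
    {
        "_target_": "nemo.collections.tts.torch.tts_tokenizers.EnglishPhonemesTokenizer",
        "punct": True,
        "stresses": True,
        "chars": True,
        "apostrophe": True,
        "pad_with_space": True,
        "g2p": {
            "_target_": "nemo.collections.tts.torch.g2ps.EnglishG2p",
            "phoneme_dict": "scripts/tts_dataset_files/cmudict-0.7b_nv22.01",
            "heteronyms": "scripts/tts_dataset_files/heteronyms-030921",
        },
    }
)
tokenizer = instantiate(tok_cfg)  # recursively instantiates the nested EnglishG2p
token_ids = tokenizer("Mister Quilter is the apostle.")  # phoneme token ids
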
68 changes: 68 additions & 0 deletions nemo/collections/tts/helpers/helpers.py
@@ -56,6 +56,12 @@

from nemo.utils import logging

HAVE_WANDB = True
try:
import wandb
except ModuleNotFoundError:
HAVE_WANDB = False

try:
from pytorch_lightning.utilities import rank_zero_only
except ModuleNotFoundError:
@@ -284,6 +290,7 @@ def tacotron2_log_to_tb_func(
step,
dataformats="HWC",
)

if add_audio:
filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
log_mel = mel_postnet[0].data.cpu().numpy().T
@@ -299,6 +306,67 @@
swriter.add_audio(f"audio/{tag}_target", audio / max(np.abs(audio)), step, sample_rate=sr)


def tacotron2_log_to_wandb_func(
swriter,
tensors,
step,
tag="train",
log_images=False,
log_images_freq=1,
add_audio=True,
griffin_lim_mag_scale=1024,
griffin_lim_power=1.2,
sr=22050,
n_fft=1024,
n_mels=80,
fmax=8000,
):
_, spec_target, mel_postnet, gate, gate_target, alignments = tensors
if not HAVE_WANDB:
return
if log_images and step % log_images_freq == 0:
# Accumulate image panels in fresh lists, keeping their names distinct from
# the unpacked batch tensors so `alignments` still refers to the tensor below.
specs = []
gates = []
alignment_imgs = [
wandb.Image(plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",)
]
specs += [
wandb.Image(plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), caption=f"{tag}_mel_target",),
wandb.Image(plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), caption=f"{tag}_mel_predicted",),
]
gates += [
wandb.Image(
plot_gate_outputs_to_numpy(
gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),
),
caption=f"{tag}_gate",
)
]

swriter.log({"specs": specs, "alignments": alignment_imgs, "gates": gates})

if add_audio:
audios = []
filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
log_mel = mel_postnet[0].data.cpu().numpy().T
mel = np.exp(log_mel)
magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
audio_pred = griffin_lim(magnitude.T ** griffin_lim_power)

log_mel = spec_target[0].data.cpu().numpy().T
mel = np.exp(log_mel)
magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
audio_true = griffin_lim(magnitude.T ** griffin_lim_power)

audios += [
wandb.Audio(audio_true / max(np.abs(audio_true)), caption=f"{tag}_wav_target", sample_rate=sr,),
wandb.Audio(audio_pred / max(np.abs(audio_pred)), caption=f"{tag}_wav_predicted", sample_rate=sr,),
]

swriter.log({"audios": audios})


def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
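
For reference, a minimal offline smoke test of the new W&B logging path. The tensor shapes are illustrative stand-ins for what Tacotron2 emits per batch, and the wandb module itself stands in for the swriter handle (the trainer normally passes its logger's experiment object):

import torch
import wandb

from nemo.collections.tts.helpers.helpers import tacotron2_log_to_wandb_func

run = wandb.init(project="tacotron2-logging-demo", mode="offline")  # offline: no account needed

batch, n_mels, t_mel, t_text = 1, 80, 120, 40
tensors = (
    None,                               # loss slot, unused by this logger
    torch.randn(batch, n_mels, t_mel),  # spec_target (log-mel ground truth)
    torch.randn(batch, n_mels, t_mel),  # mel_postnet (log-mel prediction)
    torch.randn(batch, t_mel),          # gate logits
    torch.zeros(batch, t_mel),          # gate targets
    torch.rand(batch, t_mel, t_text),   # attention alignments
)
tacotron2_log_to_wandb_func(wandb, tensors, step=0, tag="val", log_images=True, add_audio=True)
run.finish()
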