DPO notes and some script updates #38

Merged
75 changes: 62 additions & 13 deletions examples/tts/conf/t5tts/t5tts_inference.yaml
@@ -42,19 +42,68 @@ model:

sample_rate: ${sample_rate}

  text_tokenizer:
    _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
    punct: true
    apostrophe: true
    pad_with_space: false
    g2p:
      _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
      phoneme_dict: ${phoneme_dict_path}
      heteronyms: ${heteronyms_path}
      phoneme_probability: 0.8
      ignore_ambiguous_words: false
      use_chars: true
      use_stresses: true
  text_tokenizers: # Add more languages for multi-lingual TTS
    english_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      punct: true
      apostrophe: true
      pad_with_space: false
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        heteronyms: "scripts/tts_dataset_files/heteronyms-052722"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    spanish_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: es-ES
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: es-ES
        phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    german_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: de-DE
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: 'de-DE'
        phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict"
        heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
        grapheme_case: mixed
        grapheme_prefix: '#'
    mandarin_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p
        phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt"
        word_segmenter: "jieba"
        phoneme_prefix: ""
        phoneme_case: "lower"
        tone_prefix: "#"
        ascii_letter_prefix: ""
        ascii_letter_case: "upper"
    multilingual_sentencepiece:
      _target_: AutoTokenizer
      pretrained_model: "bert-base-multilingual-uncased"

  test_ds:
    dataset:
128 changes: 121 additions & 7 deletions examples/tts/t5tts_commands.md
@@ -234,9 +234,9 @@ model.prior_scaling_factor=null \

| Model Type | Cluster | Training Sub File |
|------------|---------|--------|
| multi_encoder_context_tts | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/multiencoder_t5tts.sub |
| decoder_context_tts | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/decodercontext_t5tts.sub |
| single_encoder_sv_tts | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/singleencoder_svt5tts.sub |
| multi_encoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalized_me.sub |
| decoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_decoder.sub |
| single_encoder_sv_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_singleencoder.sub |
| decoder_pretrain_synthesizer | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/newt5_pretrain.sub |

## Pretrained Models and Results
@@ -277,13 +277,127 @@ python scripts/t5tts/infer_and_evaluate.py \
--datasets "vctk,libri_val" \
--out_dir /datap/misc/Evals \
--temperature 0.6 \
--topk 80 \
--use_cfg \
--cfg_scale 1.8 ;
--topk 80
```

Ignore the other params in the file; I also use this script for evaluating ongoing experiments on the cluster by copying over the checkpoints and hparams.

### Inference Notebook

Inference Notebook: `t5tts_inference.ipynb`, for quickly trying custom texts/contexts.

### DPO Preference Alignment

Preference alignment with DPO involves the following steps:
1) Create a list of text-context pairs for which we will generate preference data.
2) For each text-context pair, generate multiple audios from a base T5-TTS checkpoint and calculate metrics (CER/SSIM) for each generation.
3) Create chosen-rejected pairs from the generated audio.
4) Finetune the base T5-TTS checkpoint on the chosen-rejected pairs.

#### 1. Create text-context pairs
We pair a list of challenging texts with context audios from the Riva and LibriTTS datasets. We add a similar number of regular texts from LibriTTS and Riva (paired with random context audios), and we also include examples that use text contexts.

```
python scripts/t5tts/dpo/create_text_contextpairs.py \
--challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \
--regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \
--regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \
--audio_contexts /Data/DPOPairsInputData/audio_context_list.json \
--text_contexts /Data/DPOPairsInputData/text_context_list.txt \
--output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \
--nsamples_perpair 6 ;
```
Each pair is repeated `nsamples_perpair` times, which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step.
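
To make that concrete, here is a minimal sketch of the repetition step (the helper and field names below are hypothetical and not taken from `create_text_contextpairs.py`):

```
# Hypothetical sketch: duplicate each (text, context) pair nsamples_perpair
# times so that the generation step produces that many candidate audios per pair.
import json

def write_pairs(pairs, nsamples_perpair, output_manifest):
    # pairs: list of dicts, e.g. {"text": ..., "context_audio_filepath": ...}
    with open(output_manifest, "w") as f:
        for pair in pairs:
            for _ in range(nsamples_perpair):
                f.write(json.dumps(pair) + "\n")
```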

We can also explore other options for generating these text-context pairs, depending on the task.

#### 2. Generate audios for each text-context pair

Next, we can generate audios from a base T5-TTS checkpoint using the following command. We pass `audio_dir` as "/" since our text-context pairs contain absolute paths. Model config arguments should be modified to match the base checkpoint architecture. We can run the command below on the cluster to generate audios across multiple nodes. This command saves the generated audios, along with the metrics for each generation, in the `exp_dir`. Each generated audio file is accompanied by a `.json` file that has the CER/SSIM metrics (a sketch of reading these metric files follows the command below).

Sample sub file on EOS: `/lustre/fsw/llmservice_nemo_speechlm/users/shehzeenh/launchscripts/newdatagendpo_decoder.sub`

```
python examples/tts/t5tts.py \
--config-name=t5tts_inference \
batch_size=64 \
+init_from_ptl_ckpt="/mountdir/checkpoints/continuouscheckpoints_ks1_ks3/decodercontext_small_282.ckpt" \
exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282" \
+test_ds_meta.textcontextpairs.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json" \
+test_ds_meta.textcontextpairs.audio_dir="/" \
+test_ds_meta.textcontextpairs.feature_dir="/" \
model.model_type="decoder_context_tts" \
model.t5_encoder.kernel_size=3 \
model.t5_decoder.kernel_size=1 \
model.context_duration_min=5.0 \
model.context_duration_max=5.0 \
model.use_text_conditioning_encoder=true \
model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \
model.alignment_loss_scale=0.002 \
model.prior_scaling_factor=null \
model.load_cached_codes_if_available=false \
+model.use_kv_cache_for_inference=true \
trainer.num_nodes=${SLURM_JOB_NUM_NODES}
```
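
As a rough illustration of the per-generation metric files mentioned above, the sketch below collects them into a list of records. It assumes each generated `<name>.wav` has a sibling `<name>.json` with CER/SSIM fields; the exact file layout and key names may differ from what the scripts actually produce.

```
# Hypothetical sketch: gather per-generation CER/SSIM metrics written next to
# each generated audio file.
import glob
import json
import os

def load_generation_metrics(audio_dir):
    records = []
    for metric_path in sorted(glob.glob(os.path.join(audio_dir, "*.json"))):
        with open(metric_path) as f:
            metrics = json.load(f)  # e.g. {"cer": ..., "ssim": ...} (assumed keys)
        records.append({"audio": metric_path[:-len(".json")] + ".wav", **metrics})
    return records
```
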
#### 3. Create chosen-rejected pairs from the generations

Next, we go through the generated audio directory and create chosen-rejected pairs.

```
python scripts/t5tts/dpo/create_preference_pairs.py \
--input_manifest /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json \
--generated_audio_dir /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/audios \
--group_size 6 \
--cer_threshold 0.01 \
--val_size 256 ;
```

`cer_threshold=0.01` means that we filter out pairs in which the chosen CER is greater than 0.01.

This command saves train and val manifests for DPO finetuning in the base directory of `generated_audio_dir`, that is, `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/`.
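
For intuition, one plausible way to form a chosen-rejected pair within each group of `group_size` generations is sketched below; the field names (`cer`, `ssim`) and the exact selection heuristic used by `create_preference_pairs.py` may differ.

```
# Hypothetical sketch: within one group of generations for the same
# text-context pair, pick the lowest-CER sample (highest SSIM as tie-break)
# as "chosen" and the worst sample as "rejected", then apply cer_threshold.
def make_preference_pair(group, cer_threshold=0.01):
    ranked = sorted(group, key=lambda r: (r["cer"], -r["ssim"]))
    chosen, rejected = ranked[0], ranked[-1]
    if chosen["cer"] > cer_threshold:
        return None  # drop pairs whose best generation is still inaccurate
    return {"chosen": chosen, "rejected": rejected}
```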

#### 4. DPO Finetuning Command

Finally, we perform DPO finetuning using the following command:

```
python examples/tts/t5tts.py \
batch_size=4 \
+init_from_ptl_ckpt="/mountdir/checkpoints/decoder_21_epoch_2.ckpt" \
+mode="dpo_train" \
max_epochs=10 \
exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/TrainingsICML/decodercontext_small_282" \
exp_manager.checkpoint_callback_params.always_save_nemo=false \
model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \
model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \
+train_ds_meta.dpopreftrain.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_train_manifest.json" \
+train_ds_meta.dpopreftrain.audio_dir="/" \
+train_ds_meta.dpopreftrain.feature_dir="/" \
+val_ds_meta.dpoprefval.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_val_manifest.json" \
+val_ds_meta.dpoprefval.audio_dir="/" \
+val_ds_meta.dpoprefval.feature_dir="/" \
+model.dpo_beta=0.01 \
+model.dpo_sft_loss_weight=0.0 \
model.model_type="decoder_context_tts" \
model.context_duration_min=5.0 \
model.context_duration_max=5.0 \
model.use_text_conditioning_encoder=true \
model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \
model.alignment_loss_scale=0.001 \
model.prior_scaling_factor=null \
trainer.val_check_interval=200 \
trainer.log_every_n_steps=10 \
model.optim.lr=2e-7 \
~model.optim.sched \
trainer.num_nodes=${SLURM_JOB_NUM_NODES}
```

Note the following overrides in the above command:

```
+mode="dpo_train" \
model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \
model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \
```

Again, our manifests contain absolute paths, so we specify `audio_dir="/"`.
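
For reference, the sketch below shows the standard DPO objective that `dpo_beta` and `dpo_sft_loss_weight` control. It is a generic formulation, not NeMo's exact implementation; the inputs are summed log-probabilities of the chosen/rejected codec sequences under the policy being finetuned and under the frozen reference (the base checkpoint).

```
# Generic DPO loss sketch (not the exact NeMo implementation).
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp,
             beta=0.01, sft_loss_weight=0.0):
    # Preference term: margin between chosen and rejected log-ratio rewards.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    pref_loss = -F.logsigmoid(chosen_reward - rejected_reward).mean()
    # Optional SFT regularizer keeping the policy close to the chosen samples.
    sft_loss = -policy_chosen_logp.mean()
    return pref_loss + sft_loss_weight * sft_loss
```
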
3 changes: 2 additions & 1 deletion nemo/collections/tts/models/t5tts.py
@@ -746,7 +746,7 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us
        if self.use_kv_cache_for_inference:
            assert self.cfg.t5_decoder.use_flash_self_attention is False, "KV cache is not supported with flash self attention"
            assert self.cfg.t5_decoder.use_flash_x_attention is False, "KV cache is not supported with flash cross attention"
            assert self.cfg.t5_decoder.pos_emb.name == "learnable", "KV cache is not tested with Rope, Alibi yet. Disable this assert, if you still want to use it."
            assert self.cfg.t5_decoder.pos_emb.name in ["learnable", "learnable_v2"], "KV cache is not tested with Rope, Alibi yet. Disable this assert, if you still want to use it."

        self.t5_decoder.reset_cache(use_cache=self.use_kv_cache_for_inference)

@@ -841,6 +841,7 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us

        predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens)

        torch.cuda.empty_cache()
        return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens

    def test_step(self, batch, batch_idx):