[TTS] Add Mixed Representation Training (#3473)
* Update CMUdict with ADLR version pronunciations

Signed-off-by: Jocelyn Huang <[email protected]>

* minor updates for finetuning

Signed-off-by: Jason <[email protected]>

* update conf

Signed-off-by: Jason <[email protected]>

* merge

Signed-off-by: Jason <[email protected]>

* update

Signed-off-by: Jason <[email protected]>

* update

Signed-off-by: Jason <[email protected]>

* bug fixes

Signed-off-by: Jason <[email protected]>

* update config

Signed-off-by: Jason <[email protected]>

* bf16 support

Signed-off-by: Jason <[email protected]>

* bf16 support

Signed-off-by: Jason <[email protected]>

* bugfix

Signed-off-by: Jason <[email protected]>

* update

Signed-off-by: Jason <[email protected]>

* finalize changes

Signed-off-by: Jason <[email protected]>

* undo notebook 1.6.0 pins

Signed-off-by: Jason <[email protected]>

* more 1.6.0 undos

Signed-off-by: Jason <[email protected]>

* wip

Signed-off-by: Jason <[email protected]>

* update num_workers

Signed-off-by: Jason <[email protected]>

* update hypers

Signed-off-by: Jason <[email protected]>

* revert to main _align yamls

Signed-off-by: Jason <[email protected]>

* update yamls

Signed-off-by: Jason <[email protected]>

* cleanup

Signed-off-by: Jason <[email protected]>

* remove unnecessary line

Signed-off-by: Jason <[email protected]>

* address comments

Signed-off-by: Jason <[email protected]>

* update vocoder mel uploading; add contextmanager to mixed g2p

Signed-off-by: Jason <[email protected]>

* update comments; make prob required argument

Signed-off-by: Jason <[email protected]>

* added val check

Signed-off-by: Jason <[email protected]>

* update message

Signed-off-by: Jason <[email protected]>

* update

Signed-off-by: Jason <[email protected]>

* revert num workers

Signed-off-by: Jason <[email protected]>

Co-authored-by: Jocelyn Huang <[email protected]>
blisc and redoctopus authored Feb 11, 2022
1 parent 058fa38 commit c645c4c
Showing 9 changed files with 102 additions and 31 deletions.
Jenkinsfile: 5 changes (2 additions, 3 deletions)
@@ -2210,14 +2210,13 @@ pipeline {
~trainer.check_val_every_n_epoch'
}
}
- // TODO(Oktai15): update it in 1.8.0 version
stage('FastPitch') {
steps {
sh 'python examples/tts/fastpitch.py \
- --config-name fastpitch_align \
+ --config-name fastpitch_align_v1.05 \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
- prior_folder=/home/TestData/an4_dataset/beta_priors \
+ sup_data_path=/home/TestData/an4_dataset/beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
trainer.strategy=null \
examples/tts/conf/fastpitch_align_v1.05.yaml: 3 changes (3 additions, 0 deletions)
@@ -77,6 +77,7 @@ model:
_target_: nemo.collections.tts.torch.g2ps.EnglishG2p
phoneme_dict: ${phoneme_dict_path}
heteronyms: ${heteronyms_path}
+ phoneme_probability: 0.5

train_ds:
dataset:
@@ -101,6 +102,7 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
+ use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
@@ -131,6 +133,7 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
+ use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
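
The new phoneme_probability: 0.5 key enables mixed representation training: during training, each word is phonemized with probability 0.5 and otherwise left as graphemes. As a sketch of how to vary it (assuming the g2p block above sits under model.text_tokenizer.g2p, which this excerpt does not show), a standard Hydra command-line override would look like:

python examples/tts/fastpitch.py \
    --config-name fastpitch_align_v1.05 \
    train_dataset=<train_manifest> \
    validation_datasets=<val_manifest> \
    model.text_tokenizer.g2p.phoneme_probability=0.8
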
examples/tts/conf/hifigan/hifigan_44100.yaml: 2 changes (1 addition, 1 deletion)
@@ -21,7 +21,7 @@ train_n_segments: 16384
train_max_duration: null
train_min_duration: 0.75

- val_n_segments: 132096
+ val_n_segments: 131072
val_max_duration: null
val_min_duration: 3

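
The new value is a power of two: 131072 = 2^17 = 8 × train_n_segments (8 × 16384), whereas the old 132096 = 131072 + 1024 was not. The commit message does not state the rationale; the alignment to a power-of-two multiple of the training segment length is an inference from the values alone.
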
nemo/collections/common/data/vocabs.py: 6 changes (6 additions, 0 deletions)
@@ -19,6 +19,7 @@
import time
import unicodedata
from builtins import str as unicode
+ from contextlib import contextmanager
from typing import List

import nltk
@@ -375,3 +376,8 @@ def encode(self, text):
ps = [space] + ps + [space]

return [self._label2id[p] for p in ps]

+ @contextmanager
+ def set_phone_prob(self, prob=None):
+ # Do nothing; this class does not support mixed g2p
+ yield
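
This no-op context manager gives the legacy vocab class the same interface as the new mixed-representation tokenizer below, so calling code can wrap any tokenizer uniformly instead of type-checking. A minimal sketch of the calling pattern (the class name and constructor arguments are assumptions, not shown in this excerpt):

vocab = Phonemes(...)  # hypothetical instantiation of the class patched above
with vocab.set_phone_prob(prob=1.0):
    tokens = vocab.encode("Hello world.")  # identical behavior inside or outside the block
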
nemo/collections/tts/models/fastpitch.py: 39 changes (24 additions, 15 deletions)
@@ -195,13 +195,21 @@ def parser(self):
return self._parser

def parse(self, str_input: str, normalize=True) -> torch.tensor:
+ if self.training:
+ logging.warning("parse() is meant to be called in eval mode.")
if str_input[-1] not in [".", "!", "?"]:
str_input = str_input + "."

if normalize and self.text_normalizer_call is not None:
str_input = self.text_normalizer_call(str_input, **self.text_normalizer_call_kwargs)

- tokens = self.parser(str_input)
+ if self.learn_alignment:
+ # Disable mixed g2p representation
+ with self.vocab.set_phone_prob(prob=1.0):
+ tokens = self.parser(str_input)
+ else:
+ # TODO(Oktai15): remove it in 1.8.0 version
+ tokens = self.parser(str_input)

x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
return x
@@ -246,8 +254,8 @@ def forward(

@typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
def generate_spectrogram(self, tokens: 'torch.tensor', speaker: int = 0, pace: float = 1.0) -> torch.tensor:
- # FIXME: return masks as well?
- self.eval()
+ if self.training:
+ logging.warning("generate_spectrogram() is meant to be called in eval mode.")
if isinstance(speaker, int):
speaker = torch.tensor([speaker]).to(self.device)
spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
@@ -312,20 +320,20 @@ def training_step(self, batch, batch_idx):

self.tb_logger.add_image(
"train_mel_target",
- plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
+ plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
- spec_predict = mels_pred[0].data.cpu().numpy()
+ spec_predict = mels_pred[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
if self.learn_alignment:
- attn = attn_hard[0].data.cpu().numpy().squeeze()
+ attn = attn_hard[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_attn", plot_alignment_to_numpy(attn.T), self.global_step, dataformats="HWC",
)
- soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
+ soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_soft_attn", plot_alignment_to_numpy(soft_attn.T), self.global_step, dataformats="HWC",
)
@@ -396,11 +404,11 @@ def validation_epoch_end(self, outputs):
if isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
"val_mel_target",
- plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
+ plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
- spec_predict = spec_predict[0].data.cpu().numpy()
+ spec_predict = spec_predict[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
@@ -428,12 +436,13 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na
if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
dataset = instantiate(cfg.dataset, parser=self.parser)
elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
- dataset = instantiate(
- cfg.dataset,
- text_normalizer=self.normalizer,
- text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
- text_tokenizer=self.vocab,
- )
+ with self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability):
+ dataset = instantiate(
+ cfg.dataset,
+ text_normalizer=self.normalizer,
+ text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
+ text_tokenizer=self.vocab,
+ )
else:
# TODO(Oktai15): remove it in 1.8.0 version
dataset = instantiate(cfg.dataset)
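
Taken together, the FastPitch changes apply one policy: training data is tokenized under the configured phoneme_probability, while validation (prob=None) and parse() (prob=1.0) both phonemize every word, since the random draw is either skipped entirely or can never exceed 1.0. A standalone sketch of that per-word Bernoulli policy (the toy dictionary and seed are illustrative, not from the commit):

import random

_toy_g2p = {"hello": ["HH", "AH0", "L", "OW1"], "world": ["W", "ER1", "L", "D"]}
_rng = random.Random(0)

def tokenize_word(word, phoneme_probability=None):
    # None (or 1.0) means "always phonemize": the val/inference setting
    if phoneme_probability is not None and _rng.random() > phoneme_probability:
        return list(word)  # keep graphemes
    return _toy_g2p.get(word, list(word))  # phonemize when the dict allows

print([tokenize_word(w, phoneme_probability=0.5) for w in ["hello", "world"]])

With phoneme_probability=0.5, roughly half the words in each utterance come out as phonemes and half as characters, so training batches mix both representations; that mixing is what lets the trained model accept either representation at inference time.
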
nemo/collections/tts/models/mixer_tts.py: 6 changes (5 additions, 1 deletion)
@@ -644,9 +644,13 @@ def generate_spectrogram(
return pred_spect

def parse(self, text: str, normalize=True) -> torch.Tensor:
+ if self.training:
+ logging.warning("parse() is meant to be called in eval mode.")
if normalize and self.text_normalizer_call is not None:
text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)
- return torch.tensor(self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)
+ with self.tokenizer.set_phone_prob(prob=1.0):
+ tokens = self.tokenizer.encode(text)
+ return torch.tensor(tokens).long().unsqueeze(0).to(self.device)

def _loader(self, cfg):
try:
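
Both models now wrap inference-time tokenization in set_phone_prob(prob=1.0), so user-supplied text is fully phonemized regardless of the training-time mixing. A hedged usage sketch (the checkpoint name is illustrative; any FastPitch checkpoint compatible with this branch would do):

import torch
from nemo.collections.tts.models import FastPitchModel

model = FastPitchModel.from_pretrained("tts_en_fastpitch")  # illustrative checkpoint name
model.eval()  # avoids the new training-mode warnings in parse()/generate_spectrogram()
with torch.no_grad():
    tokens = model.parse("Mixed representation training in action.")
    spec = model.generate_spectrogram(tokens=tokens)
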
nemo/collections/tts/torch/data.py: 36 changes (28 additions, 8 deletions)
@@ -21,6 +21,7 @@
from typing import Callable, Dict, List, Optional, Union

import librosa
+ import numpy as np
import torch
from nemo_text_processing.text_normalization.normalize import Normalizer
from tqdm import tqdm
@@ -134,9 +135,12 @@ def __init__(

# Initialize text tokenizer
self.text_tokenizer = text_tokenizer

+ self.phoneme_probability = None
if isinstance(self.text_tokenizer, BaseTokenizer):
self.text_tokenizer_pad_id = text_tokenizer.pad
self.tokens = text_tokenizer.tokens
+ self.phoneme_probability = self.text_tokenizer.phoneme_probability
else:
if text_tokenizer_pad_id is None:
raise ValueError(f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")
@@ -146,6 +150,7 @@ def __init__(

self.text_tokenizer_pad_id = text_tokenizer_pad_id
self.tokens = tokens
+ self.cache_text = True if self.phoneme_probability is None else False

# Initialize text normalizer if specified
self.text_normalizer = text_normalizer
@@ -179,15 +184,14 @@ def __init__(

if "normalized_text" not in item:
text = item["text"]

if self.text_normalizer is not None:
text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)

file_info["normalized_text"] = text
file_info["text_tokens"] = self.text_tokenizer(text)
else:
file_info["normalized_text"] = item["normalized_text"]
file_info["text_tokens"] = self.text_tokenizer(item["normalized_text"])

+ if self.cache_text:
+ file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"])

data.append(file_info)

@@ -241,6 +245,7 @@ def __init__(
hop_length=self.hop_len,
win_length=self.win_length,
window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
+ return_complex=True,
)

# Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
@@ -331,6 +336,13 @@ def add_align_prior_matrix(self, **kwargs):
self.align_prior_matrix_folder.mkdir(exist_ok=True, parents=True)

self.use_beta_binomial_interpolator = kwargs.pop('use_beta_binomial_interpolator', False)
+ if not self.cache_text:
+ if 'use_beta_binomial_interpolator' in kwargs and not self.use_beta_binomial_interpolator:
+ logging.warning(
+ "phoneme_probability is not None but use_beta_binomial_interpolator=False; "
+ "setting use_beta_binomial_interpolator=True so that phoneme_probability can be used."
+ )
+ self.use_beta_binomial_interpolator = True

if self.use_beta_binomial_interpolator:
self.beta_binomial_interpolator = BetaBinomialInterpolator()
@@ -386,9 +398,13 @@ def __getitem__(self, index):
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

- # Load text
- text = torch.tensor(sample["text_tokens"]).long()
- text_length = torch.tensor(len(sample["text_tokens"])).long()
+ if "text_tokens" in sample:
+ text = torch.tensor(sample["text_tokens"]).long()
+ text_length = torch.tensor(len(sample["text_tokens"])).long()
+ else:
+ tokenized = self.text_tokenizer(sample["normalized_text"])
+ text = torch.tensor(tokenized).long()
+ text_length = torch.tensor(len(tokenized)).long()

# Load mel if needed
log_mel, log_mel_length = None, None
@@ -417,6 +433,7 @@ def __getitem__(self, index):
# Load alignment prior matrix if needed
align_prior_matrix = None
if AlignPriorMatrix in self.sup_data_types_set:
align_prior_matrix = None
if self.use_beta_binomial_interpolator:
mel_len = self.get_log_mel(audio).shape[2]
align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item()))
@@ -823,7 +840,10 @@ def __getitem__(self, index):
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

- mel = torch.load(sample["mel_filepath"])
+ if Path(sample["mel_filepath"]).suffix == ".npy":
+ mel = np.load(sample["mel_filepath"])
+ else:
+ mel = torch.load(sample["mel_filepath"])
frames = math.ceil(self.n_segments / self.hop_length)

if len(audio) > self.n_segments:
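
The thread running through these data.py changes: tokenized text may only be cached when tokenization is deterministic. When phoneme_probability is set, cache_text is False and __getitem__ re-tokenizes the normalized text on every access, so each epoch sees a fresh grapheme/phoneme mix (presumably also why the on-disk alignment priors give way to the beta-binomial interpolator: a prior matrix keyed to one fixed token length would go stale). A minimal sketch of the caching rule, with names simplified from the diff:

class MixedTextDataset:
    def __init__(self, texts, tokenizer, phoneme_probability=None):
        self.texts = texts
        self.tokenizer = tokenizer
        # Cache only when the tokenizer is deterministic
        self.cache_text = phoneme_probability is None
        self._cache = [tokenizer(t) for t in texts] if self.cache_text else None

    def __getitem__(self, index):
        if self.cache_text:
            return self._cache[index]
        return self.tokenizer(self.texts[index])  # fresh sample every epoch
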
nemo/collections/tts/torch/g2ps.py: 11 changes (11 additions, 0 deletions)
@@ -14,8 +14,10 @@

import abc
import pathlib
+ import random
import re
import time
+ from typing import Optional

import nltk
import torch
@@ -53,6 +55,7 @@ def __init__(
ignore_ambiguous_words=True,
heteronyms=None,
encoding='latin-1',
+ phoneme_probability: Optional[float] = None,
):
"""English G2P module. This module converts words from grapheme to phoneme representation using phoneme_dict in CMU dict format.
Optionally, it can ignore words which are heteronyms, ambiguous or marked as unchangeable by word_tokenize_func (see code for details).
@@ -67,6 +70,9 @@ def __init__(
ignore_ambiguous_words: Whether to leave words with ambiguous phoneme_dict entries unhandled (returned as characters). Defaults to True.
heteronyms (str, Path, List): Path to file with heteronyms (every line is new word) or list of words.
encoding: Encoding type.
+ phoneme_probability (Optional[float]): The probability (0 < prob < 1) that each word is phonemized. Defaults to None, which is equivalent to 1.
+ Note that this code path is only run if the word can be phonemized. For example: if the word does not have an entry in the g2p dict, it is returned
+ as characters; if the word has multiple entries and ignore_ambiguous_words is True, it is returned as characters.
"""
phoneme_dict = (
self._parse_as_cmu_dict(phoneme_dict, encoding)
@@ -91,6 +97,8 @@ def __init__(
if isinstance(heteronyms, str) or isinstance(heteronyms, pathlib.Path)
else heteronyms
)
+ self.phoneme_probability = phoneme_probability
+ self._rng = random.Random()

@staticmethod
def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
@@ -163,6 +171,9 @@ def parse_one_word(self, word: str):
`status` will be `False` if word wasn't handled, `True` otherwise.
"""

+ if self.phoneme_probability is not None and self._rng.random() > self.phoneme_probability:
+ return word, True

# punctuation
if re.search("[a-zA-Z]", word) is None:
return list(word), True
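
Note the ordering in parse_one_word: the random draw happens before any dictionary lookup, and returning (word, True) reports the word as handled even though it stays in graphemes; repeated calls on the same word can therefore disagree. A hedged sketch of exercising this path (constructor defaults are assumed to resolve the bundled CMU dictionary; pass phoneme_dict explicitly otherwise):

from nemo.collections.tts.torch.g2ps import EnglishG2p

g2p = EnglishG2p(phoneme_probability=0.5)  # probability value is illustrative
for _ in range(3):
    # May alternate between phonemes and raw graphemes across calls
    print(g2p.parse_one_word("hello"))
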
nemo/collections/tts/torch/tts_tokenizers.py: 25 changes (22 additions, 3 deletions)
@@ -15,6 +15,7 @@
import abc
import itertools
import string
+ from contextlib import contextmanager
from typing import List

from nemo.collections.tts.torch.de_utils import german_text_preprocessing
@@ -282,6 +283,9 @@ def __init__(
Note that the lower() function shouldn't be applied here, because the text can contain phonemes (they are handled by g2p).
"""

+ self.phoneme_probability = None
+ if hasattr(g2p, "phoneme_probability"):
+ self.phoneme_probability = g2p.phoneme_probability
tokens = []
self.space, tokens = len(tokens), tokens + [space] # Space

@@ -295,7 +299,12 @@ def __init__(
vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
tokens.extend(vowels)

- if chars:
+ if chars or self.phoneme_probability is not None:
+ if not chars:
+ logging.warning(
+ "phoneme_probability was not None, characters will be enabled even though "
+ "chars was set to False."
+ )
tokens.extend(string.ascii_lowercase)

if apostrophe:
@@ -308,7 +317,7 @@ def __init__(

super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)

- self.chars = chars
+ self.chars = chars if self.phoneme_probability is None else True
self.punct = punct
self.stresses = stresses
self.pad_with_space = pad_with_space
@@ -321,7 +330,7 @@ def encode(self, text):
ps, space, tokens = [], self.tokens[self.space], set(self.tokens)

text = self.text_preprocessing_func(text)
- g2p_text = self.g2p(text)
+ g2p_text = self.g2p(text)  # TODO: handle infer

for p in g2p_text: # noqa
# Remove stress
@@ -351,3 +360,13 @@ def encode(self, text):
ps = [space] + ps + [space]

return [self._token2id[p] for p in ps]

+ @contextmanager
+ def set_phone_prob(self, prob):
+ if hasattr(self.g2p, "phoneme_probability"):
+ self.g2p.phoneme_probability = prob
+ try:
+ yield
+ finally:
+ if hasattr(self.g2p, "phoneme_probability"):
+ self.g2p.phoneme_probability = self.phoneme_probability
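
The try/finally restores the tokenizer's configured probability even if the body raises, and the hasattr guard keeps the manager safe for g2p modules without the attribute. A sketch of the intended call sites (tokenizer construction elided; the prob values mirror the FastPitch changes above):

# Training: stochastic mixing at the configured rate
with tokenizer.set_phone_prob(tokenizer.phoneme_probability):
    train_tokens = tokenizer.encode(train_text)

# Validation and inference: always phonemize
with tokenizer.set_phone_prob(None):  # prob=1.0 has the same effect
    val_tokens = tokenizer.encode(val_text)
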
