diff --git a/Jenkinsfile b/Jenkinsfile
index 98f50c3fa465..1cc106cecd20 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -2210,14 +2210,13 @@ pipeline {
             ~trainer.check_val_every_n_epoch'
           }
         }
-        // TODO(Oktai15): update it in 1.8.0 version
         stage('FastPitch') {
          steps {
            sh 'python examples/tts/fastpitch.py \
-            --config-name fastpitch_align \
+            --config-name fastpitch_align_v1.05 \
            train_dataset=/home/TestData/an4_dataset/an4_train.json \
            validation_datasets=/home/TestData/an4_dataset/an4_val.json \
-            prior_folder=/home/TestData/an4_dataset/beta_priors \
+            sup_data_path=/home/TestData/an4_dataset/beta_priors \
            trainer.devices="[0]" \
            +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
            trainer.strategy=null \
diff --git a/examples/tts/conf/fastpitch_align_v1.05.yaml b/examples/tts/conf/fastpitch_align_v1.05.yaml
index f96403db6be9..53d401a07682 100644
--- a/examples/tts/conf/fastpitch_align_v1.05.yaml
+++ b/examples/tts/conf/fastpitch_align_v1.05.yaml
@@ -77,6 +77,7 @@ model:
       _target_: nemo.collections.tts.torch.g2ps.EnglishG2p
       phoneme_dict: ${phoneme_dict_path}
       heteronyms: ${heteronyms_path}
+      phoneme_probability: 0.5

   train_ds:
     dataset:
@@ -101,6 +102,7 @@ model:
       pitch_norm: true
       pitch_mean: ${model.pitch_mean}
       pitch_std: ${model.pitch_std}
+      use_beta_binomial_interpolator: true

     dataloader_params:
       drop_last: false
@@ -131,6 +133,7 @@ model:
       pitch_norm: true
       pitch_mean: ${model.pitch_mean}
       pitch_std: ${model.pitch_std}
+      use_beta_binomial_interpolator: true

     dataloader_params:
       drop_last: false
diff --git a/examples/tts/conf/hifigan/hifigan_44100.yaml b/examples/tts/conf/hifigan/hifigan_44100.yaml
index 98470bcb7005..6fbd0187c265 100644
--- a/examples/tts/conf/hifigan/hifigan_44100.yaml
+++ b/examples/tts/conf/hifigan/hifigan_44100.yaml
@@ -21,7 +21,7 @@
 train_n_segments: 16384
 train_max_duration: null
 train_min_duration: 0.75

-val_n_segments: 132096
+val_n_segments: 131072
 val_max_duration: null
 val_min_duration: 3
diff --git a/nemo/collections/common/data/vocabs.py b/nemo/collections/common/data/vocabs.py
index 45bb04b37e83..ba509fa6709d 100644
--- a/nemo/collections/common/data/vocabs.py
+++ b/nemo/collections/common/data/vocabs.py
@@ -19,6 +19,7 @@
 import time
 import unicodedata
 from builtins import str as unicode
+from contextlib import contextmanager
 from typing import List

 import nltk
@@ -375,3 +376,8 @@ def encode(self, text):
             ps = [space] + ps + [space]

         return [self._label2id[p] for p in ps]
+
+    @contextmanager
+    def set_phone_prob(self, prob=None):
+        # Add do nothing since this class doesn't support mixed g2p
+        yield
diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py
index f1602be9f9ff..02d688044232 100644
--- a/nemo/collections/tts/models/fastpitch.py
+++ b/nemo/collections/tts/models/fastpitch.py
@@ -195,13 +195,21 @@ def parser(self):
         return self._parser

     def parse(self, str_input: str, normalize=True) -> torch.tensor:
+        if self.training:
+            logging.warning("parse() is meant to be called in eval mode.")
         if str_input[-1] not in [".", "!", "?"]:
             str_input = str_input + "."

         if normalize and self.text_normalizer_call is not None:
             str_input = self.text_normalizer_call(str_input, **self.text_normalizer_call_kwargs)

-        tokens = self.parser(str_input)
+        if self.learn_alignment:
+            # Disable mixed g2p representation
+            with self.vocab.set_phone_prob(prob=1.0):
+                tokens = self.parser(str_input)
+        else:
+            # TODO(Oktai15): remove it in 1.8.0 version
+            tokens = self.parser(str_input)
         x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)

         return x
@@ -246,8 +254,8 @@ def forward(

     @typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
     def generate_spectrogram(self, tokens: 'torch.tensor', speaker: int = 0, pace: float = 1.0) -> torch.tensor:
-        # FIXME: return masks as well?
-        self.eval()
+        if self.training:
+            logging.warning("generate_spectrogram() is meant to be called in eval mode.")
         if isinstance(speaker, int):
             speaker = torch.tensor([speaker]).to(self.device)
         spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
@@ -312,20 +320,20 @@ def training_step(self, batch, batch_idx):

             self.tb_logger.add_image(
                 "train_mel_target",
-                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
+                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                 self.global_step,
                 dataformats="HWC",
             )
-            spec_predict = mels_pred[0].data.cpu().numpy()
+            spec_predict = mels_pred[0].data.cpu().float().numpy()
             self.tb_logger.add_image(
                 "train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
             )
             if self.learn_alignment:
-                attn = attn_hard[0].data.cpu().numpy().squeeze()
+                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                 self.tb_logger.add_image(
                     "train_attn", plot_alignment_to_numpy(attn.T), self.global_step, dataformats="HWC",
                 )
-                soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
+                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                 self.tb_logger.add_image(
                     "train_soft_attn", plot_alignment_to_numpy(soft_attn.T), self.global_step, dataformats="HWC",
                 )
@@ -396,11 +404,11 @@ def validation_epoch_end(self, outputs):
         if isinstance(self.logger, TensorBoardLogger):
             self.tb_logger.add_image(
                 "val_mel_target",
-                plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
+                plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
                 self.global_step,
                 dataformats="HWC",
             )
-            spec_predict = spec_predict[0].data.cpu().numpy()
+            spec_predict = spec_predict[0].data.cpu().float().numpy()
             self.tb_logger.add_image(
                 "val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
             )
@@ -428,12 +436,13 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na
         if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
             dataset = instantiate(cfg.dataset, parser=self.parser)
         elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
-            dataset = instantiate(
-                cfg.dataset,
-                text_normalizer=self.normalizer,
-                text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
-                text_tokenizer=self.vocab,
-            )
+            with self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability):
+                dataset = instantiate(
+                    cfg.dataset,
+                    text_normalizer=self.normalizer,
+                    text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
+                    text_tokenizer=self.vocab,
+                )
         else:
             # TODO(Oktai15): remove it in 1.8.0 version
             dataset = instantiate(cfg.dataset)
diff --git a/nemo/collections/tts/models/mixer_tts.py b/nemo/collections/tts/models/mixer_tts.py
index a39c994a9eaf..ce155ab1a97b 100644
--- a/nemo/collections/tts/models/mixer_tts.py
+++ b/nemo/collections/tts/models/mixer_tts.py
@@ -644,9 +644,13 @@ def generate_spectrogram(
         return pred_spect

     def parse(self, text: str, normalize=True) -> torch.Tensor:
+        if self.training:
+            logging.warning("parse() is meant to be called in eval mode.")
         if normalize and self.text_normalizer_call is not None:
             text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)
-        return torch.tensor(self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)
+        with self.tokenizer.set_phone_prob(prob=1.0):
+            tokens = self.tokenizer.encode(text)
+        return torch.tensor(tokens).long().unsqueeze(0).to(self.device)

     def _loader(self, cfg):
         try:
diff --git a/nemo/collections/tts/torch/data.py b/nemo/collections/tts/torch/data.py
index 2d89025f69fa..6fe6d58efb6e 100644
--- a/nemo/collections/tts/torch/data.py
+++ b/nemo/collections/tts/torch/data.py
@@ -21,6 +21,7 @@ from typing import Callable, Dict, List, Optional, Union

 import librosa
+import numpy as np
 import torch
 from nemo_text_processing.text_normalization.normalize import Normalizer
 from tqdm import tqdm
@@ -134,9 +135,12 @@ def __init__(

         # Initialize text tokenizer
         self.text_tokenizer = text_tokenizer
+
+        self.phoneme_probability = None
         if isinstance(self.text_tokenizer, BaseTokenizer):
             self.text_tokenizer_pad_id = text_tokenizer.pad
             self.tokens = text_tokenizer.tokens
+            self.phoneme_probability = self.text_tokenizer.phoneme_probability
         else:
             if text_tokenizer_pad_id is None:
                 raise ValueError(f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")
@@ -146,6 +150,7 @@ def __init__(

             self.text_tokenizer_pad_id = text_tokenizer_pad_id
             self.tokens = tokens
+        self.cache_text = True if self.phoneme_probability is None else False

         # Initialize text normalizer is specified
         self.text_normalizer = text_normalizer
@@ -179,15 +184,14 @@ def __init__(

                 if "normalized_text" not in item:
                     text = item["text"]
-
                     if self.text_normalizer is not None:
                         text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)
-
                     file_info["normalized_text"] = text
-                    file_info["text_tokens"] = self.text_tokenizer(text)
                 else:
                     file_info["normalized_text"] = item["normalized_text"]
-                    file_info["text_tokens"] = self.text_tokenizer(item["normalized_text"])
+
+                if self.cache_text:
+                    file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"])

                 data.append(file_info)

@@ -241,6 +245,7 @@ def __init__(
             hop_length=self.hop_len,
             win_length=self.win_length,
             window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
+            return_complex=True,
         )

         # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
@@ -331,6 +336,13 @@ def add_align_prior_matrix(self, **kwargs):
         self.align_prior_matrix_folder.mkdir(exist_ok=True, parents=True)

         self.use_beta_binomial_interpolator = kwargs.pop('use_beta_binomial_interpolator', False)
+        if not self.cache_text:
+            if 'use_beta_binomial_interpolator' in kwargs and not self.use_beta_binomial_interpolator:
+                logging.warning(
+                    "phoneme_probability is not None, but use_beta_binomial_interpolator=False, we"
+                    " set use_beta_binomial_interpolator=True manually to use phoneme_probability."
+                )
+            self.use_beta_binomial_interpolator = True

         if self.use_beta_binomial_interpolator:
             self.beta_binomial_interpolator = BetaBinomialInterpolator()
@@ -386,9 +398,13 @@ def __getitem__(self, index):
         features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
         audio, audio_length = features, torch.tensor(features.shape[0]).long()

-        # Load text
-        text = torch.tensor(sample["text_tokens"]).long()
-        text_length = torch.tensor(len(sample["text_tokens"])).long()
+        if "text_tokens" in sample:
+            text = torch.tensor(sample["text_tokens"]).long()
+            text_length = torch.tensor(len(sample["text_tokens"])).long()
+        else:
+            tokenized = self.text_tokenizer(sample["normalized_text"])
+            text = torch.tensor(tokenized).long()
+            text_length = torch.tensor(len(tokenized)).long()

         # Load mel if needed
         log_mel, log_mel_length = None, None
@@ -417,6 +433,7 @@
         # Load alignment prior matrix if needed
         align_prior_matrix = None
         if AlignPriorMatrix in self.sup_data_types_set:
+            align_prior_matrix = None
             if self.use_beta_binomial_interpolator:
                 mel_len = self.get_log_mel(audio).shape[2]
                 align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item()))
@@ -823,7 +840,10 @@ def __getitem__(self, index):
         features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
         audio, audio_length = features, torch.tensor(features.shape[0]).long()

-        mel = torch.load(sample["mel_filepath"])
+        if Path(sample["mel_filepath"]).suffix == ".npy":
+            mel = np.load(sample["mel_filepath"])
+        else:
+            mel = torch.load(sample["mel_filepath"])
         frames = math.ceil(self.n_segments / self.hop_length)

         if len(audio) > self.n_segments:
diff --git a/nemo/collections/tts/torch/g2ps.py b/nemo/collections/tts/torch/g2ps.py
index ac9e59d09f87..a71fbdf505d3 100644
--- a/nemo/collections/tts/torch/g2ps.py
+++ b/nemo/collections/tts/torch/g2ps.py
@@ -14,8 +14,10 @@

 import abc
 import pathlib
+import random
 import re
 import time
+from typing import Optional

 import nltk
 import torch
@@ -53,6 +55,7 @@ def __init__(
         ignore_ambiguous_words=True,
         heteronyms=None,
         encoding='latin-1',
+        phoneme_probability: Optional[float] = None,
     ):
         """English G2P module. This module converts words from grapheme to phoneme representation using phoneme_dict in CMU dict format.
         Optionally, it can ignore words which are heteronyms, ambiguous or marked as unchangeable by word_tokenize_func (see code for details).
@@ -67,6 +70,9 @@ def __init__(
             ignore_ambiguous_words: Whether to not handle word via phoneme_dict with ambiguous phoneme sequences. Defaults to True.
             heteronyms (str, Path, List): Path to file with heteronyms (every line is new word) or list of words.
             encoding: Encoding type.
+            phoneme_probability (Optional[float]): The probability (0.<1.0) that each word is phonemized. Defaults to None which is the same as 1.
@@ ... @@ def parse_one_word(self, word):
+        if self.phoneme_probability is not None and random.random() > self.phoneme_probability:
+            return word, True

         # punctuation
         if re.search("[a-zA-Z]", word) is None:
             return list(word), True
diff --git a/nemo/collections/tts/torch/tts_tokenizers.py b/nemo/collections/tts/torch/tts_tokenizers.py
index 3d30dbfc6e77..d6c9d887430a 100644
--- a/nemo/collections/tts/torch/tts_tokenizers.py
+++ b/nemo/collections/tts/torch/tts_tokenizers.py
@@ -15,6 +15,7 @@
 import abc
 import itertools
 import string
+from contextlib import contextmanager
 from typing import List

 from nemo.collections.tts.torch.de_utils import german_text_preprocessing
@@ -282,6 +283,9 @@ def __init__(

                 Note that lower() function shouldn't applied here, because text can contains phonemes (it will be handled by g2p).
""" + self.phoneme_probability = None + if hasattr(g2p, "phoneme_probability"): + self.phoneme_probability = g2p.phoneme_probability tokens = [] self.space, tokens = len(tokens), tokens + [space] # Space @@ -295,7 +299,12 @@ def __init__( vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))] tokens.extend(vowels) - if chars: + if chars or self.phoneme_probability is not None: + if not chars: + logging.warning( + "phoneme_probability was not None, characters will be enabled even though " + "chars was set to False." + ) tokens.extend(string.ascii_lowercase) if apostrophe: @@ -308,7 +317,7 @@ def __init__( super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at) - self.chars = chars + self.chars = chars if self.phoneme_probability is None else True self.punct = punct self.stresses = stresses self.pad_with_space = pad_with_space @@ -321,7 +330,7 @@ def encode(self, text): ps, space, tokens = [], self.tokens[self.space], set(self.tokens) text = self.text_preprocessing_func(text) - g2p_text = self.g2p(text) + g2p_text = self.g2p(text) # TODO: handle infer for p in g2p_text: # noqa # Remove stress @@ -351,3 +360,13 @@ def encode(self, text): ps = [space] + ps + [space] return [self._token2id[p] for p in ps] + + @contextmanager + def set_phone_prob(self, prob): + if hasattr(self.g2p, "phoneme_probability"): + self.g2p.phoneme_probability = prob + try: + yield + finally: + if hasattr(self.g2p, "phoneme_probability"): + self.g2p.phoneme_probability = self.phoneme_probability