diff --git a/examples/tts/conf/tacotron2.yaml b/examples/tts/conf/tacotron2.yaml index 0a2555d33306..ae84da7ee36c 100644 --- a/examples/tts/conf/tacotron2.yaml +++ b/examples/tts/conf/tacotron2.yaml @@ -1,4 +1,4 @@ -name: &name "Tacotron 2" +name: &name Tacotron2 sample_rate: &sr 22050 # , , will be added by the tacotron2.py script labels: &labels [' ', '!', '"', "'", '(', ')', ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', diff --git a/examples/tts/conf/waveglow.yaml b/examples/tts/conf/waveglow.yaml index d6be608b107e..dba752d9bc22 100644 --- a/examples/tts/conf/waveglow.yaml +++ b/examples/tts/conf/waveglow.yaml @@ -8,6 +8,7 @@ train_dataset: ??? validation_datasets: ??? model: + sigma: 1.0 train_ds: dataset: cls: "nemo.collections.tts.data.datalayers.AudioDataset" diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 9b437e416fb2..b18d0ea4c1f8 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -135,6 +135,7 @@ def __init__( load_audio: bool = True, add_misc: bool = False, ): + self.parser = parser self.collection = collections.ASRAudioText( manifests_files=manifest_filepath.split(','), diff --git a/nemo/collections/tts/losses/glow_tts_loss.py b/nemo/collections/tts/losses/glow_tts_loss.py index 56bd01134888..727483b07ae0 100644 --- a/nemo/collections/tts/losses/glow_tts_loss.py +++ b/nemo/collections/tts/losses/glow_tts_loss.py @@ -78,9 +78,6 @@ def output_types(self): "logdet": NeuralType(elements_type=VoidType()), } - def __init__(self): - super().__init__() - @typecheck() def forward(self, z, y_m, y_logs, logdet, logw, logw_, x_lengths, y_lengths): diff --git a/nemo/collections/tts/losses/tacotron2loss.py b/nemo/collections/tts/losses/tacotron2loss.py index d6a3c26add16..2645aa7cb82a 100644 --- a/nemo/collections/tts/losses/tacotron2loss.py +++ b/nemo/collections/tts/losses/tacotron2loss.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from torch import nn -from torch.nn.functional import pad +import torch from nemo.collections.tts.helpers.helpers import get_mask_from_lengths from nemo.core.classes import Loss, typecheck @@ -22,18 +21,16 @@ class Tacotron2Loss(Loss): - """ A Loss module that computes loss for Tacotron2 - """ + """A Loss module that computes loss for Tacotron2""" @property def input_types(self): return { - "mel_out": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "mel_out_postnet": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "gate_out": NeuralType(('B', 'T'), LogitsType()), - "mel_target": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "gate_target": NeuralType(('B', 'T'), LogitsType()), - "target_len": NeuralType(('B'), LengthsType()), + "spec_pred_dec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec_pred_postnet": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "gate_pred": NeuralType(('B', 'T'), LogitsType()), + "spec_target": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec_target_len": NeuralType(('B'), LengthsType()), "pad_value": NeuralType(), } @@ -41,38 +38,47 @@ def input_types(self): def output_types(self): return { "loss": NeuralType(elements_type=LossType()), + "gate_target": NeuralType(('B', 'T'), LogitsType()), # Used for evaluation } @typecheck() - def forward(self, *, mel_out, mel_out_postnet, gate_out, mel_target, gate_target, target_len, pad_value): - mel_target.requires_grad = False + def forward(self, *, spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, pad_value): + # Make the gate target + max_len = spec_target.shape[2] + gate_target = torch.zeros(spec_target_len.shape[0], max_len) + gate_target = gate_target.type_as(gate_pred) + for i, length in enumerate(spec_target_len): + gate_target[i, length.data - 1 :] = 1 + + spec_target.requires_grad = False gate_target.requires_grad = False gate_target = gate_target.view(-1, 1) - max_len = mel_target.shape[2] + max_len = spec_target.shape[2] - if max_len < mel_out.shape[2]: + if max_len < spec_pred_dec.shape[2]: # Predicted len is larger than reference # Need to slice - mel_out = mel_out.narrow(2, 0, max_len) - mel_out_postnet = mel_out_postnet.narrow(2, 0, max_len) - gate_out = gate_out.narrow(1, 0, max_len).contiguous() - elif max_len > mel_out.shape[2]: + spec_pred_dec = spec_pred_dec.narrow(2, 0, max_len) + spec_pred_postnet = spec_pred_postnet.narrow(2, 0, max_len) + gate_pred = gate_pred.narrow(1, 0, max_len).contiguous() + elif max_len > spec_pred_dec.shape[2]: # Need to do padding - pad_amount = max_len - mel_out.shape[2] - mel_out = pad(mel_out, (0, pad_amount), value=pad_value) - mel_out_postnet = pad(mel_out_postnet, (0, pad_amount), value=pad_value) - gate_out = pad(gate_out, (0, pad_amount), value=1e3) - max_len = mel_out.shape[2] + pad_amount = max_len - spec_pred_dec.shape[2] + spec_pred_dec = torch.nn.functional.pad(spec_pred_dec, (0, pad_amount), value=pad_value) + spec_pred_postnet = torch.nn.functional.pad(spec_pred_postnet, (0, pad_amount), value=pad_value) + gate_pred = torch.nn.functional.pad(gate_pred, (0, pad_amount), value=1e3) + max_len = spec_pred_dec.shape[2] - mask = ~get_mask_from_lengths(target_len, max_len=max_len) - mask = mask.expand(mel_target.shape[1], mask.size(0), mask.size(1)) + mask = ~get_mask_from_lengths(spec_target_len, max_len=max_len) + mask = mask.expand(spec_target.shape[1], mask.size(0), mask.size(1)) mask = mask.permute(1, 0, 2) - mel_out.data.masked_fill_(mask, pad_value) - mel_out_postnet.data.masked_fill_(mask, 
pad_value) - gate_out.data.masked_fill_(mask[:, 0, :], 1e3) + spec_pred_dec.data.masked_fill_(mask, pad_value) + spec_pred_postnet.data.masked_fill_(mask, pad_value) + gate_pred.data.masked_fill_(mask[:, 0, :], 1e3) - gate_out = gate_out.view(-1, 1) - mel_loss = nn.MSELoss()(mel_out, mel_target) + nn.MSELoss()(mel_out_postnet, mel_target) - gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) - return mel_loss + gate_loss + gate_pred = gate_pred.view(-1, 1) + rnn_mel_loss = torch.nn.functional.mse_loss(spec_pred_dec, spec_target) + postnet_mel_loss = torch.nn.functional.mse_loss(spec_pred_postnet, spec_target) + gate_loss = torch.nn.functional.binary_cross_entropy_with_logits(gate_pred, gate_target) + return rnn_mel_loss + postnet_mel_loss + gate_loss, gate_target diff --git a/nemo/collections/tts/models/base.py b/nemo/collections/tts/models/base.py new file mode 100644 index 000000000000..868117747413 --- /dev/null +++ b/nemo/collections/tts/models/base.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from nemo.core.classes import ModelPT + + +class SpectrogramGenerator(ModelPT, ABC): + """ Base class for all TTS models that turn text into a spectrogram """ + + @abstractmethod + def parse(self, str_input: str, **kwargs) -> 'torch.tensor': + """ + A helper function that accepts a raw Python string and turns it into a tensor. The tensor should have 2 + dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor + should represent either tokenized or embedded text, depending on the model.
+ """ + + @abstractmethod + def generate_spectrogram(self, tokens: 'torch.tensor', **kwargs) -> 'torch.tensor': + """ + Accepts a batch of text or text_tokens and returns a batch of spectrograms + + Args: + tokens: A torch tensor representing the text to be generated + + Returns: + spectrograms + """ + + +class Vocoder(ModelPT, ABC): + """ Base class for all TTS models that generate audio conditioned on a spectrogram """ + + @abstractmethod + def convert_spectrogram_to_audio(self, spec: 'torch.tensor', **kwargs) -> 'torch.tensor': + """ + Accepts a batch of spectrograms and returns a batch of audio + + Args: + spec: A torch tensor representing the spectrograms to be vocoded + + Returns: + audio + """ diff --git a/nemo/collections/tts/models/glow_tts.py b/nemo/collections/tts/models/glow_tts.py index 172ac9d30cda..e98234a26893 --- a/nemo/collections/tts/models/glow_tts.py +++ b/nemo/collections/tts/models/glow_tts.py @@ -25,10 +25,12 @@ from nemo.collections.asr.parts.perturb import process_augmentations from nemo.collections.tts.helpers.helpers import log_audio_to_tb, plot_alignment_to_numpy, plot_spectrogram_to_numpy from nemo.collections.tts.losses.glow_tts_loss import GlowTTSLoss +from nemo.collections.tts.models.base import SpectrogramGenerator from nemo.collections.tts.modules.glow_tts import GlowTTSModule -from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types.elements import LengthsType, MelSpectrogramType, TokenIndex +from nemo.core.neural_types.neural_type import NeuralType from nemo.utils import logging -from nemo.utils.decorators import experimental @dataclass @@ -53,8 +55,7 @@ class GlowTTSConfig: test_ds: Optional[Dict[Any, Any]] = None -@experimental -class GlowTTSModel(ModelPT): +class GlowTTSModel(SpectrogramGenerator): """ GlowTTS model used to generate spectrograms from text Consists of a text encoder and an invertible spectrogram decoder @@ -82,25 +83,29 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): decoder = instantiate(self._cfg.decoder) self.glow_tts = GlowTTSModule(encoder, decoder, n_speakers=cfg.n_speakers, gin_channels=cfg.gin_channels) - - self.setup_optimization() - self.loss = GlowTTSLoss() - def train_dataloader(self): - return self._train_dl - - def val_dataloader(self): - return self._val_dl - - def test_dataloader(self): - return self._test_dl + def parse(self, str_input: str) -> torch.tensor: + if str_input[-1] not in [".", "!", "?"]: + str_input = str_input + "."
- def get_parser(self): - return self.parser + tokens = self.parser(str_input) - def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0.3, length_scale=1.0): + x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device) + return x + @typecheck( + input_types={ + "x": NeuralType(('B', 'T'), TokenIndex()), + "x_lengths": NeuralType(('B'), LengthsType()), + "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True), + "y_lengths": NeuralType(('B'), LengthsType(), optional=True), + "gen": NeuralType(optional=True), + "noise_scale": NeuralType(optional=True), + "length_scale": NeuralType(optional=True), + } + ) + def forward(self, *, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0.3, length_scale=1.0): if gen: return self.glow_tts.generate_spect( text=x, text_lengths=x_lengths, noise_scale=noise_scale, length_scale=length_scale @@ -109,8 +114,9 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0 return self.glow_tts(text=x, text_lengths=x_lengths, spect=y, spect_lengths=y_lengths) def step(self, y, y_lengths, x, x_lengths): - - z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn = self(x, x_lengths, y, y_lengths, gen=False) + z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn = self( + x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths, gen=False + ) l_mle, l_length, logdet = self.loss( z=z, @@ -128,7 +134,6 @@ def step(self, y, y_lengths, x, x_lengths): return l_mle, l_length, logdet, loss, attn def training_step(self, batch, batch_idx): - y, y_lengths, x, x_lengths = batch y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) @@ -144,14 +149,13 @@ def training_step(self, batch, batch_idx): return output def validation_step(self, batch, batch_idx): - y, y_lengths, x, x_lengths = batch y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) l_mle, l_length, logdet, loss, attn = self.step(y, y_lengths, x, x_lengths) - y_gen, attn_gen = self(x, x_lengths, gen=True) + y_gen, attn_gen = self(x=x, x_lengths=x_lengths, gen=True) return { "loss": loss, @@ -178,8 +182,7 @@ def validation_epoch_end(self, outputs): 'val_logdet': avg_logdet, } if self.logger is not None and self.logger.experiment is not None: - parser = self.get_parser() - separated_phonemes = "|".join([parser.symbols[c] for c in outputs[0]['x'][0]]) + separated_phonemes = "|".join([self.parser.symbols[c] for c in outputs[0]['x'][0]]) self.logger.experiment.add_text("separated phonemes", separated_phonemes, self.global_step) self.logger.experiment.add_image( "real_spectrogram", @@ -247,33 +250,37 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) def setup_validation_data(self, val_data_config: Optional[DictConfig]): - self._val_dl = self._setup_dataloader_from_config(cfg=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config) - def generate_spectrogram(self, text: str, noise_scale: float = 0.0, length_scale: float = 1.0) -> torch.Tensor: + def generate_spectrogram( + self, tokens: 'torch.tensor', noise_scale: float = 0.0, length_scale: float = 1.0 + ) -> torch.tensor: self.eval() - text_parser = self.get_parser() - - if text[-1] != ".": - text = text + "." 
- - text_seq = text_parser(text) - - x = torch.Tensor(text_seq).unsqueeze(0).cuda().long() - x_lengths = torch.tensor([x.shape[1]]).cuda() + token_len = torch.tensor([tokens.shape[1]]).to(self.device) + spect, _ = self(x=tokens, x_lengths=token_len, gen=True, noise_scale=noise_scale, length_scale=length_scale) - with torch.no_grad(): - spect, _ = self(x, x_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale) - - return spect[0] + return spect @classmethod - def list_available_models(cls) -> Optional[Dict[str, str]]: - pass - - def export(self, **kwargs): - pass + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + """ + This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + list_of_models = [] + model = PretrainedModelInfo( + pretrained_model_name="GlowTTS-22050Hz", + location="https://nemo-public.s3.us-east-2.amazonaws.com/nemo-1.0.0alpha-tests/glow_tts.nemo", + description=( + "The model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female " + "English voices with an American accent." + ), + ) + list_of_models.append(model) + return list_of_models diff --git a/nemo/collections/tts/models/tacotron2.py b/nemo/collections/tts/models/tacotron2.py index 636094c7eb26..c313850d7ccb 100644 --- a/nemo/collections/tts/models/tacotron2.py +++ b/nemo/collections/tts/models/tacotron2.py @@ -18,11 +18,14 @@ import torch from hydra.utils import instantiate from omegaconf import MISSING, DictConfig, OmegaConf, open_dict +from omegaconf.errors import ConfigAttributeError from torch import nn +from nemo.collections.asr.parts import parsers from nemo.collections.tts.helpers.helpers import tacotron2_log_to_tb_func from nemo.collections.tts.losses.tacotron2loss import Tacotron2Loss -from nemo.core.classes import ModelPT, typecheck +from nemo.collections.tts.models.base import SpectrogramGenerator +from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types.elements import ( AudioSignal, EmbeddedTextType, @@ -33,7 +36,6 @@ ) from nemo.core.neural_types.neural_type import NeuralType from nemo.utils import logging -from nemo.utils.decorators import experimental @dataclass @@ -58,10 +60,8 @@ class Tacotron2Config: validation_ds: Optional[Dict[Any, Any]] = None -@experimental # TODO: Need to implement abstract methods: list_available_models -class Tacotron2Model(ModelPT): - """ Tacotron 2 Model that is used to generate mel spectrograms from text - """ +class Tacotron2Model(SpectrogramGenerator): + """Tacotron 2 Model that is used to generate mel spectrograms from text""" # TODO: tensorboard for training def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): @@ -79,68 +79,138 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): OmegaConf.merge(cfg, schema) self.pad_value = self._cfg.preprocessor.params.pad_value + self._parser = None self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor) self.text_embedding = nn.Embedding(len(cfg.labels) + 3, 512) self.encoder = instantiate(self._cfg.encoder) self.decoder = instantiate(self._cfg.decoder) self.postnet = instantiate(self._cfg.postnet) self.loss = Tacotron2Loss() + self.calculate_loss = True + + @property + def parser(self): + if self._parser is not None: + return self._parser + if self._validation_dl is not None: + return self._validation_dl.dataset.parser + if self._test_dl is not None: + return self._test_dl.dataset.parser + 
if self._train_dl is not None: + return self._train_dl.dataset.parser + + # Else construct a parser + # Try to get params from validation, test, and then train + params = {} + try: + params = self._cfg.validation_ds.dataset.params + except ConfigAttributeError: + pass + if params == {}: + try: + params = self._cfg.test_ds.dataset.params + except ConfigAttributeError: + pass + if params == {}: + try: + params = self._cfg.train_ds.dataset.params + except ConfigAttributeError: + pass + + name = params.get('parser', None) or params.get('parser', None) or 'en' + unk_id = params.get('unk_index', None) or params.get('unk_index', None) or -1 + blank_id = params.get('blank_index', None) or params.get('blank_index', None) or -1 + do_normalize = params.get('normalize', None) or params.get('normalize', None) or False + self._parser = parsers.make_parser( + labels=self._cfg.labels, name=name, unk_id=unk_id, blank_id=blank_id, do_normalize=do_normalize, + ) + return self._parser + + def parse(self, str_input: str) -> torch.tensor: + tokens = self.parser(str_input) + # Parser doesn't add bos and eos ids, so manually add them + tokens = [len(self._cfg.labels)] + tokens + [len(self._cfg.labels) + 1] + tokens_tensor = torch.tensor(tokens).unsqueeze_(0).to(self.device) + + return tokens_tensor @property def input_types(self): - return { - "audio": NeuralType(('B', 'T'), AudioSignal()), - "audio_len": NeuralType(('B'), LengthsType()), - "tokens": NeuralType(('B', 'T'), EmbeddedTextType()), - "token_len": NeuralType(('B'), LengthsType()), - } + if self.training: + return { + "tokens": NeuralType(('B', 'T'), EmbeddedTextType()), + "token_len": NeuralType(('B'), LengthsType()), + "audio": NeuralType(('B', 'T'), AudioSignal()), + "audio_len": NeuralType(('B'), LengthsType()), + } + else: + return { + "tokens": NeuralType(('B', 'T'), EmbeddedTextType()), + "token_len": NeuralType(('B'), LengthsType()), + "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True), + "audio_len": NeuralType(('B'), LengthsType(), optional=True), + } @property def output_types(self): + if not self.calculate_loss and not self.training: + return { + "spec_pred_dec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec_pred_postnet": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "gate_pred": NeuralType(('B', 'T'), LogitsType()), + "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()), + } return { - "mel_out": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "mel_out_postnet": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "gate_out": NeuralType(('B', 'T'), LogitsType()), - "mel_target": NeuralType(('B', 'T', 'D'), MelSpectrogramType()), - "gate_target": NeuralType(('B', 'T'), LogitsType()), - "target_len": NeuralType(('B'), LengthsType()), + "spec_pred_dec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec_pred_postnet": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "gate_pred": NeuralType(('B', 'T'), LogitsType()), + "spec_target": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec_target_len": NeuralType(('B'), LengthsType()), + "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()), } @typecheck() - def forward(self, *, audio, audio_len, tokens, token_len): - spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len) + def forward(self, *, tokens, token_len, audio=None, audio_len=None): + if audio is not None and audio_len is not None: + spec_target, spec_target_len = self.audio_to_melspec_precessor(audio, audio_len) token_embedding =
self.text_embedding(tokens).transpose(1, 2) encoder_embedding = self.encoder(token_embedding=token_embedding, token_len=token_len) if self.training: - spec_dec, gate, alignments = self.decoder( - memory=encoder_embedding, decoder_inputs=spec, memory_lengths=token_len + spec_pred_dec, gate_pred, alignments = self.decoder( + memory=encoder_embedding, decoder_inputs=spec_target, memory_lengths=token_len ) else: - spec_dec, gate, alignments, _ = self.decoder(memory=encoder_embedding, memory_lengths=token_len) - spec_postnet = self.postnet(mel_spec=spec_dec) + spec_pred_dec, gate_pred, alignments, _ = self.decoder(memory=encoder_embedding, memory_lengths=token_len) + spec_pred_postnet = self.postnet(mel_spec=spec_pred_dec) - max_len = spec.shape[2] - gate_padded = torch.zeros(spec_len.shape[0], max_len) - gate_padded = gate_padded.type_as(gate) - for i, length in enumerate(spec_len): - gate_padded[i, length.data - 1 :] = 1 + if not self.calculate_loss: + return spec_pred_dec, spec_pred_postnet, gate_pred, alignments + return spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, alignments - return spec_dec, spec_postnet, gate, spec, gate_padded, spec_len, alignments + @typecheck( + input_types={"tokens": NeuralType(('B', 'T'), EmbeddedTextType())}, + output_types={"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType())}, + ) + def generate_spectrogram(self, *, tokens): + self.eval() + self.calculate_loss = False + token_len = torch.tensor([len(i) for i in tokens]).to(self.device) + tensors = self(tokens=tokens, token_len=token_len) + spectrogram_pred = tensors[1] + return spectrogram_pred def training_step(self, batch, batch_idx): audio, audio_len, tokens, token_len = batch - spec_dec, spec_postnet, gate, spec, gate_padded, spec_len, _ = self.forward( + spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, _ = self.forward( audio=audio, audio_len=audio_len, tokens=tokens, token_len=token_len ) - loss = self.loss( - mel_out=spec_dec, - mel_out_postnet=spec_postnet, - gate_out=gate, - mel_target=spec, - gate_target=gate_padded, - target_len=spec_len, + loss, _ = self.loss( + spec_pred_dec=spec_pred_dec, + spec_pred_postnet=spec_pred_postnet, + gate_pred=gate_pred, + spec_target=spec_target, + spec_target_len=spec_target_len, pad_value=self.pad_value, ) @@ -153,25 +223,24 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): audio, audio_len, tokens, token_len = batch - spec_dec, spec_postnet, gate, spec, gate_padded, spec_len, alignments = self.forward( + spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, alignments = self.forward( audio=audio, audio_len=audio_len, tokens=tokens, token_len=token_len ) - loss = self.loss( - mel_out=spec_dec, - mel_out_postnet=spec_postnet, - gate_out=gate, - mel_target=spec, - gate_target=gate_padded, - target_len=spec_len, + loss, gate_target = self.loss( + spec_pred_dec=spec_pred_dec, + spec_pred_postnet=spec_pred_postnet, + gate_pred=gate_pred, + spec_target=spec_target, + spec_target_len=spec_target_len, pad_value=self.pad_value, ) return { "loss": loss, - "mel_target": spec, - "mel_postnet": spec_postnet, - "gate": gate, - "gate_target": gate_padded, + "mel_target": spec_target, + "mel_postnet": spec_pred_postnet, + "gate": gate_pred, + "gate_target": gate_target, "alignments": alignments, } @@ -219,6 +288,20 @@ def setup_validation_data(self, cfg): self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation") 
@classmethod - def list_available_models(cls) -> 'Optional[Dict[str, str]]': - """TODO: Implement me!""" - pass + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + """ + This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + list_of_models = [] + model = PretrainedModelInfo( + pretrained_model_name="Tacotron2-22050Hz", + location="https://nemo-public.s3.us-east-2.amazonaws.com/nemo-1.0.0alpha-tests/tacotron2.nemo", + description=( + "The model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female " + "English voices with an American accent." + ), + ) + list_of_models.append(model) + return list_of_models diff --git a/nemo/collections/tts/models/waveglow.py b/nemo/collections/tts/models/waveglow.py index 2d1183627696..7a68d1fa7b9d 100644 --- a/nemo/collections/tts/models/waveglow.py +++ b/nemo/collections/tts/models/waveglow.py @@ -21,8 +21,9 @@ from nemo.collections.tts.helpers.helpers import waveglow_log_to_tb_func from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss +from nemo.collections.tts.models.base import Vocoder from nemo.collections.tts.modules.waveglow import OperationMode -from nemo.core.classes import ModelPT, typecheck +from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types.elements import ( AudioSignal, LengthsType, @@ -32,7 +33,6 @@ ) from nemo.core.neural_types.neural_type import NeuralType from nemo.utils import logging -from nemo.utils.decorators import experimental @dataclass @@ -50,14 +50,13 @@ class Preprocessor: class WaveglowConfig: waveglow: Dict[Any, Any] = MISSING preprocessor: Preprocessor = Preprocessor() + sigma: int = MISSING train_ds: Optional[Dict[Any, Any]] = None validation_ds: Optional[Dict[Any, Any]] = None -@experimental # TODO: Need to implement abstract methods: list_available_models -class WaveGlowModel(ModelPT): - """ Waveglow model used to convert betweeen spectrograms and audio - """ +class WaveGlowModel(Vocoder): + """Waveglow model used to convert between spectrograms and audio""" def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if isinstance(cfg, dict): @@ -74,7 +73,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): OmegaConf.merge(cfg, schema) self.pad_value = self._cfg.preprocessor.params.pad_value - self.sigma = 1.0 + self.sigma = self._cfg.sigma self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor) self.waveglow = instantiate(self._cfg.waveglow) self.mode = OperationMode.infer @@ -112,7 +111,7 @@ def forward(self, *, audio, audio_len, run_inverse=True): f"WaveGlowModel's mode {self.mode} does not match WaveGlowModule's mode {self.waveglow.mode}" ) spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len) - tensors = self.waveglow(spect=spec, audio=audio, run_inverse=run_inverse) + tensors = self.waveglow(spec=spec, audio=audio, run_inverse=run_inverse, sigma=self.sigma) if self.mode == OperationMode.training: return tensors[:-1] # z, log_s_list, log_det_W_list elif self.mode == OperationMode.validation: @@ -120,11 +119,25 @@ def forward(self, *, audio, audio_len, run_inverse=True): return z, log_s_list, log_det_W_list, audio_pred, spec, spec_len return tensors # audio_pred + @typecheck( + input_types={"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), "sigma": NeuralType(optional=True)}, + output_types={"audio": NeuralType(('B', 'T'), AudioSignal())}, + ) + def
convert_spectrogram_to_audio(self, spec: torch.Tensor, sigma: float = 1.0) -> torch.Tensor: + self.eval() + self.mode = OperationMode.infer + self.waveglow.mode = OperationMode.infer + + with torch.no_grad(): + audio = self.waveglow(spec=spec, run_inverse=True, audio=None, sigma=sigma) + + return audio + def training_step(self, batch, batch_idx): self.mode = OperationMode.training self.waveglow.mode = OperationMode.training audio, audio_len = batch - z, log_s_list, log_det_W_list = self.forward(audio=audio, audio_len=audio_len) + z, log_s_list, log_det_W_list = self(audio=audio, audio_len=audio_len, run_inverse=False) loss = self.loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=self.sigma) output = { @@ -138,7 +151,7 @@ def validation_step(self, batch, batch_idx): self.mode = OperationMode.validation self.waveglow.mode = OperationMode.validation audio, audio_len = batch - z, log_s_list, log_det_W_list, audio_pred, spec, spec_len = self.forward( + z, log_s_list, log_det_W_list, audio_pred, spec, spec_len = self( audio=audio, audio_len=audio_len, run_inverse=(batch_idx == 0) ) loss = self.loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=self.sigma) @@ -189,17 +202,18 @@ def setup_training_data(self, cfg): def setup_validation_data(self, cfg): self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation") - def convert_spectrogram_to_audio(self, spect: torch.Tensor) -> torch.Tensor: - self.eval() - self.mode = OperationMode.infer - self.waveglow.mode = OperationMode.infer - - with torch.no_grad(): - audio = self.waveglow(spect=spect.unsqueeze(0), run_inverse=True, audio=None) - - return audio.squeeze(0) - @classmethod - def list_available_models(cls) -> 'Optional[Dict[str, str]]': - """TODO: Implement me!""" - pass + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. 
+ """ + list_of_models = [] + model = PretrainedModelInfo( + pretrained_model_name="WaveGlow-22050Hz", + location="https://nemo-public.s3.us-east-2.amazonaws.com/nemo-1.0.0alpha-tests/waveglow.nemo", + description="The model is trained on LJSpeech sampled at 22050Hz, and can be used as a universal vocoder", + ) + list_of_models.append(model) + return list_of_models diff --git a/nemo/collections/tts/modules/glow_tts.py b/nemo/collections/tts/modules/glow_tts.py index 7912a5ee2d3c..969412a16003 --- a/nemo/collections/tts/modules/glow_tts.py +++ b/nemo/collections/tts/modules/glow_tts.py @@ -119,7 +119,7 @@ def __init__( self.ffn_layers = nn.ModuleList() self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): + for _ in range(self.n_layers): self.attn_layers.append( glow_tts_submodules.AttentionBlock( hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout, @@ -158,7 +158,7 @@ def output_types(self): } @typecheck() - def forward(self, text, text_lengths, speaker_embeddings=None): + def forward(self, *, text, text_lengths, speaker_embeddings=None): x = self.emb(text) * math.sqrt(self.hidden_channels) # [b, t, h] @@ -198,12 +198,10 @@ def forward(self, text, text_lengths, speaker_embeddings=None): def save_to(self, save_path: str): """TODO: Implement""" - pass @classmethod def restore_from(cls, restore_path: str): """TODO: Implement""" - pass @experimental @@ -241,7 +239,7 @@ def __init__( self.n_sqz = n_sqz self.flows = nn.ModuleList() - for b in range(n_blocks): + for _ in range(n_blocks): self.flows.append(glow_tts_submodules.ActNorm(channels=in_channels * n_sqz)) self.flows.append(glow_tts_submodules.InvConvNear(channels=in_channels * n_sqz, n_split=n_split)) self.flows.append( @@ -274,7 +272,7 @@ def output_types(self): } @typecheck() - def forward(self, spect, spect_mask, speaker_embeddings=None, reverse=False): + def forward(self, *, spect, spect_mask, speaker_embeddings=None, reverse=False): if not reverse: flows = self.flows logdet_tot = 0 @@ -329,12 +327,10 @@ def store_inverse(self): def save_to(self, save_path: str): """TODO: Implement""" - pass @classmethod def restore_from(cls, restore_path: str): """TODO: Implement""" - pass class GlowTTSModule(NeuralModule): @@ -362,10 +358,10 @@ def __init__( def input_types(self): return { "text": NeuralType(('B', 'T'), TokenIndex()), - "text_lengths": NeuralType(('B',), LengthsType()), + "text_lengths": NeuralType(('B'), LengthsType()), "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), - "spect_lengths": NeuralType(('B',), LengthsType()), - "speaker": NeuralType(('B',), IntType(), optional=True), + "spect_lengths": NeuralType(('B'), LengthsType()), + "speaker": NeuralType(('B'), IntType(), optional=True), } @property @@ -374,21 +370,21 @@ def output_types(self): "z": NeuralType(('B', 'D', 'T'), NormalDistributionSamplesType()), "y_m": NeuralType(('B', 'D', 'T'), NormalDistributionMeanType()), "y_logs": NeuralType(('B', 'D', 'T'), NormalDistributionLogVarianceType()), - "logdet": NeuralType(('B',), LogDeterminantType()), + "logdet": NeuralType(('B'), LogDeterminantType()), "log_durs_predicted": NeuralType(('B', 'T'), TokenLogDurationType()), "log_durs_extracted": NeuralType(('B', 'T'), TokenLogDurationType()), - "spect_lengths": NeuralType(('B',), LengthsType()), + "spect_lengths": NeuralType(('B'), LengthsType()), "attn": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()), } @typecheck() - def forward(self, text, text_lengths, spect, spect_lengths, g=None): + def
forward(self, *, text, text_lengths, spect, spect_lengths, speaker=None): - if g is not None: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + if speaker is not None: + speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1) # [b, h] x_m, x_logs, log_durs_predicted, x_mask = self.encoder( - text=text, text_lengths=text_lengths, speaker_embeddings=g + text=text, text_lengths=text_lengths, speaker_embeddings=speaker ) y_max_length = spect.size(2) @@ -400,7 +396,7 @@ def forward(self, text, text_lengths, spect, spect_lengths, g=None): y_mask = torch.unsqueeze(glow_tts_submodules.sequence_mask(spect_lengths, y_max_length), 1).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - z, logdet = self.decoder(spect=spect, spect_mask=y_mask, speaker_embeddings=g, reverse=False) + z, logdet = self.decoder(spect=spect, spect_mask=y_mask, speaker_embeddings=speaker, reverse=False) with torch.no_grad(): x_s_sq_r = torch.exp(-2 * x_logs) @@ -419,13 +415,26 @@ def forward(self, text, text_lengths, spect, spect_lengths, g=None): return z, y_m, y_logs, logdet, log_durs_predicted, log_durs_extracted, spect_lengths, attn - def generate_spect(self, text, text_lengths, g=None, noise_scale=0.3, length_scale=1.0): + @typecheck( + input_types={ + "text": NeuralType(('B', 'T'), TokenIndex()), + "text_lengths": NeuralType(('B',), LengthsType()), + "speaker": NeuralType(('B'), IntType(), optional=True), + "noise_scale": NeuralType(optional=True), + "length_scale": NeuralType(optional=True), + }, + output_types={ + "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "attn": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()), + }, + ) + def generate_spect(self, *, text, text_lengths, speaker=None, noise_scale=0.3, length_scale=1.0): - if g is not None: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + if speaker is not None: + speaker = F.normalize(self.emb_g(speaker)).unsqueeze(-1) # [b, h] x_m, x_logs, log_durs_predicted, x_mask = self.encoder( - text=text, text_lengths=text_lengths, speaker_embeddings=g + text=text, text_lengths=text_lengths, speaker_embeddings=speaker ) w = torch.exp(log_durs_predicted) * x_mask.squeeze() * length_scale @@ -444,15 +453,13 @@ def generate_spect(self, text, text_lengths, g=None, noise_scale=0.3, length_sca y_logs = torch.matmul(x_logs, attn) z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask - y, _ = self.decoder(spect=z, spect_mask=y_mask, speaker_embeddings=g, reverse=True) + y, _ = self.decoder(spect=z, spect_mask=y_mask, speaker_embeddings=speaker, reverse=True) return y, attn def save_to(self, save_path: str): """TODO: Implement""" - pass @classmethod def restore_from(cls, restore_path: str): """TODO: Implement""" - pass diff --git a/nemo/collections/tts/modules/tacotron2.py b/nemo/collections/tts/modules/tacotron2.py index 525475cf533a..11a74bb54574 100644 --- a/nemo/collections/tts/modules/tacotron2.py +++ b/nemo/collections/tts/modules/tacotron2.py @@ -113,12 +113,12 @@ def forward(self, *, token_embedding, token_len): return outputs def save_to(self, save_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass @classmethod def restore_from(cls, restore_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass @@ -385,12 +385,12 @@ def infer(self, *, memory, memory_lengths): return mel_outputs, gate_outputs, alignments, mel_lengths def save_to(self, save_path: str): - # TODO: Implement me!!! + # TODO: Implement me! 
pass @classmethod def restore_from(cls, restore_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass @@ -489,10 +489,10 @@ def forward(self, *, mel_spec): return mel_spec + mel_spec_out def save_to(self, save_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass @classmethod def restore_from(cls, restore_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass diff --git a/nemo/collections/tts/modules/waveglow.py b/nemo/collections/tts/modules/waveglow.py index 22904d1b5901..5aadb66a0e34 100644 --- a/nemo/collections/tts/modules/waveglow.py +++ b/nemo/collections/tts/modules/waveglow.py @@ -97,7 +97,7 @@ def __init__( self.n_remaining_channels = n_remaining_channels @typecheck() - def forward(self, spect, audio=None, run_inverse=True): + def forward(self, spec, audio=None, run_inverse=True, sigma=1.0): """ TODO """ if self.training and self.mode != OperationMode.training: @@ -108,11 +108,11 @@ def forward(self, spect, audio=None, run_inverse=True): audio_pred = torch.zeros((1, 1)) if audio is not None and self.mode != OperationMode.infer: # audio_to_normal_dist is used to calculate loss so only run this in train or val model - z, log_s_list, log_det_W_list = self.audio_to_normal_dist(spect=spect, audio=audio) + z, log_s_list, log_det_W_list = self.audio_to_normal_dist(spec=spec, audio=audio) if run_inverse: # norm_dist_to_audio is used to predict audio from spectrogram so only used in val or infer mode # Could also log train audio but currently not done - audio_pred = self.norm_dist_to_audio(spect=spect) + audio_pred = self.norm_dist_to_audio(spec=spec, sigma=sigma) # Return the necessary tensors if self.mode == OperationMode.training or self.mode == OperationMode.validation: @@ -122,10 +122,10 @@ def forward(self, spect, audio=None, run_inverse=True): @property def input_types(self): return { - "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True), "run_inverse": NeuralType(elements_type=IntType(), optional=True), - # "sigma": NeuralType(elements_type=BoolType(), optional=True), # TODO: Add to forward + "sigma": NeuralType(optional=True), } @property @@ -152,16 +152,16 @@ def input_example(self): mel = torch.randn((1, self.n_mel_channels, 96), device=par.device, dtype=par.dtype) return tuple([mel]) - def audio_to_normal_dist(self, *, spect: torch.Tensor, audio: torch.Tensor) -> (torch.Tensor, list, list): + def audio_to_normal_dist(self, *, spec: torch.Tensor, audio: torch.Tensor) -> (torch.Tensor, list, list): # Upsample spectrogram to size of audio - spect = self.upsample(spect) - assert spect.size(2) >= audio.size(1) - if spect.size(2) > audio.size(1): - spect = spect[:, :, : audio.size(1)] + spec = self.upsample(spec) + assert spec.size(2) >= audio.size(1) + if spec.size(2) > audio.size(1): + spec = spec[:, :, : audio.size(1)] - spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) - spect = spect.contiguous().view(spect.size(0), spect.size(1), -1) - spect = spect.permute(0, 2, 1) + spec = spec.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spec = spec.contiguous().view(spec.size(0), spec.size(1), -1) + spec = spec.permute(0, 2, 1) audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) output_audio = [] @@ -180,7 +180,7 @@ def audio_to_normal_dist(self, *, spect: torch.Tensor, audio: torch.Tensor) -> ( audio_0 = audio[:, :n_half, :] audio_1 = audio[:, n_half:, :] - output = 
self.wavenet[k]((audio_0, spect)) + output = self.wavenet[k]((audio_0, spec)) log_s = output[:, n_half:, :] b = output[:, :n_half, :] audio_1 = torch.exp(log_s) * audio_1 + b @@ -191,18 +191,18 @@ def audio_to_normal_dist(self, *, spect: torch.Tensor, audio: torch.Tensor) -> ( output_audio.append(audio) return torch.cat(output_audio, 1), log_s_list, log_det_W_list - def norm_dist_to_audio(self, *, spect, sigma: float = 1.0): - spect = self.upsample(spect) + def norm_dist_to_audio(self, *, spec, sigma: float = 1.0): + spec = self.upsample(spec) # trim conv artifacts. maybe pad spec to kernel multiple time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] - spect = spect[:, :, :-time_cutoff] + spec = spec[:, :, :-time_cutoff] - spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) - spect = spect.contiguous().view(spect.size(0), spect.size(1), -1) - spect = spect.permute(0, 2, 1) + spec = spec.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spec = spec.contiguous().view(spec.size(0), spec.size(1), -1) + spec = spec.permute(0, 2, 1) - audio = sigma * torch.randn(spect.size(0), self.n_remaining_channels, spect.size(2), device=spect.device).to( - spect.dtype + audio = sigma * torch.randn(spec.size(0), self.n_remaining_channels, spec.size(2), device=spec.device).to( + spec.dtype ) for k in reversed(range(self.n_flows)): @@ -210,7 +210,7 @@ def norm_dist_to_audio(self, *, spect, sigma: float = 1.0): audio_0 = audio[:, :n_half, :] audio_1 = audio[:, n_half:, :] - output = self.wavenet[k]((audio_0, spect)) + output = self.wavenet[k]((audio_0, spec)) s = output[:, n_half:, :] b = output[:, :n_half, :] audio_1 = (audio_1 - b) / torch.exp(s) @@ -218,17 +218,17 @@ def norm_dist_to_audio(self, *, spect, sigma: float = 1.0): audio = self.convinv[k](audio, reverse=True) if k % self.n_early_every == 0 and k > 0: - z = sigma * torch.randn(spect.size(0), self.n_early_size, spect.size(2), device=spect.device).to( - spect.dtype + z = sigma * torch.randn(spec.size(0), self.n_early_size, spec.size(2), device=spec.device).to( + spec.dtype ) audio = torch.cat((z, audio), 1) return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1) def save_to(self, save_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass @classmethod def restore_from(cls, restore_path: str): - # TODO: Implement me!!! + # TODO: Implement me! pass diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index f343d75c7407..3dc569836c96 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -317,7 +317,7 @@ def to_config_file(self, path2yaml_file: str): raise NotImplementedError() -PretrainedModelInfo = namedtuple("PretrainedModelInfo", ("pretrained_model_name", "description", "location"),) +PretrainedModelInfo = namedtuple("PretrainedModelInfo", ("pretrained_model_name", "description", "location")) class Model(Typing, Serialization, FileIO): diff --git a/tutorials/tts/1_TTS_inference.ipynb b/tutorials/tts/1_TTS_inference.ipynb new file mode 100644 index 000000000000..e14fc0328afc --- /dev/null +++ b/tutorials/tts/1_TTS_inference.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TTS Inference\n", + "\n", + "This notebook can be used to generate audio samples using either NeMo's pretrained models or checkpoints obtained after training NeMo TTS models. This notebook currently uses a two-step inference procedure. First, a model is used to generate a mel spectrogram from text.
Second, a model is used to generate audio from a mel spectrogram.\n", + "\n", + "Currently supported models are:\n", + "Mel Spectrogram Generators:\n", + "- Tacotron 2\n", + "- Glow-TTS\n", + "\n", + "Audio Generators:\n", + "- Griffin-Lim\n", + "- WaveGlow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Licence\n", + "\n", + "> Copyright 2020 NVIDIA. All Rights Reserved.\n", + "> \n", + "> Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "> you may not use this file except in compliance with the License.\n", + "> You may obtain a copy of the License at\n", + "> \n", + "> http://www.apache.org/licenses/LICENSE-2.0\n", + "> \n", + "> Unless required by applicable law or agreed to in writing, software\n", + "> distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "> See the License for the specific language governing permissions and\n", + "> limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# # If you're using Google Colab and not running locally, uncomment and run this cell.\n", + "# !apt-get install sox libsndfile1 ffmpeg\n", + "# !pip install wget unidecode\n", + "# !pip install git+git://github.com/nvidia/NeMo.git@main#egg=nemo_toolkit[tts]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "supported_spec_gen = [\"tacotron2\", \"glow_tts\"]\n", + "supported_audio_gen = [\"griffin_lim\", \"waveglow\"]\n", + "\n", + "print(\"Choose one of the following spectrogram generators:\")\n", + "print([model for model in supported_spec_gen])\n", + "spectrogram_generator = input()\n", + "print(\"Choose one of the following audio generators:\")\n", + "print([model for model in supported_audio_gen])\n", + "audio_generator = input()\n", + "\n", + "assert spectrogram_generator in supported_spec_gen\n", + "assert audio_generator in supported_audio_gen" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load model checkpoints\n", + "\n", + "Note: For best quality with Glow TTS, please update the glow tts yaml file with the path to cmudict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf, open_dict\n", + "import torch\n", + "from ruamel.yaml import YAML\n", + "from nemo.collections.asr.parts import parsers\n", + "\n", + "SAMPLE_RATE = 22050\n", + "NFFT = 1024\n", + "NMEL = 80\n", + "FMAX = None\n", + "\n", + "def load_spectrogram_model():\n", + " if spectrogram_generator == \"tacotron2\":\n", + " from nemo.collections.tts.models import Tacotron2Model as SpecModel\n", + " pretrained_model = \"Tacotron2-22050Hz\"\n", + " elif spectrogram_generator == \"glow_tts\":\n", + " from nemo.collections.tts.models import GlowTTSModel as SpecModel\n", + " 
pretrained_model = \"GlowTTS-22050Hz\"\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + " model = SpecModel.from_pretrained(pretrained_model)\n", + " with open_dict(model._cfg):\n", + " global SAMPLE_RATE\n", + " global NFFT\n", + " global NMEL\n", + " global FMAX\n", + " SAMPLE_RATE = model._cfg.sample_rate or SAMPLE_RATE\n", + " NFFT = model._cfg.n_fft or NFFT\n", + " NMEL = model._cfg.n_mels or NMEL\n", + " FMAX = model._cfg.fmax or FMAX\n", + " return model\n", + "\n", + "def load_vocoder_model():\n", + " if audio_generator == \"waveglow\":\n", + " from nemo.collections.tts.models import WaveGlowModel as VocoderModel\n", + " pretrained_model = \"WaveGlow-22050Hz\"\n", + " elif audio_generator == \"griffin_lim\":\n", + " from nemo.collections.tts.helpers.helpers import griffin_lim\n", + " import numpy as np\n", + " import librosa\n", + " class GL:\n", + " def __init__(self):\n", + " pass\n", + " def convert_spectrogram_to_audio(self, spec):\n", + " log_mel_spec = spec.squeeze().to('cpu').numpy().T\n", + " mel_spec = np.exp(log_mel_spec)\n", + " mel_pseudo_inverse = librosa.filters.mel(SAMPLE_RATE, NFFT, NMEL, fmax=FMAX)\n", + " return griffin_lim(np.dot(mel_spec, mel_pseudo_inverse).T ** 1.2)\n", + " def load_state_dict(self, *args, **kwargs):\n", + " pass\n", + " def cuda(self, *args, **kwargs):\n", + " return self\n", + " return GL()\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + " model = VocoderModel.from_pretrained(pretrained_model)\n", + " with open_dict(model._cfg):\n", + " global SAMPLE_RATE\n", + " global NFFT\n", + " global NMEL\n", + " global FMAX\n", + " if model._cfg.sample_rate is not None and SAMPLE_RATE is not None:\n", + " assert model._cfg.sample_rate == SAMPLE_RATE\n", + " if model._cfg.n_fft is not None and NFFT is not None:\n", + " assert model._cfg.n_fft == NFFT\n", + " if model._cfg.n_mels is not None and NMEL is not None:\n", + " assert model._cfg.n_mels == NMEL\n", + " if model._cfg.fmax is not None and FMAX is not None:\n", + " assert model._cfg.fmax == FMAX\n", + " return model\n", + "\n", + "spec_gen = load_spectrogram_model().cuda()\n", + "vocoder = load_vocoder_model().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def infer(spec_gen_model, vocoder_model, str_input):\n", + " with torch.no_grad():\n", + " parsed = spec_gen_model.parse(str_input)\n", + " spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed)\n", + " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", + " if isinstance(spectrogram, torch.Tensor):\n", + " spectrogram = spectrogram.to('cpu').numpy()\n", + " if len(spectrogram.shape) == 3:\n", + " spectrogram = spectrogram[0]\n", + " if isinstance(audio, torch.Tensor):\n", + " audio = audio.to('cpu').numpy()\n", + " return spectrogram, audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_to_generate = input(\"Input what you want the model to say: \")\n", + "spec, audio = infer(spec_gen, vocoder, text_to_generate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Show Audio and Spectrogram" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import IPython.display as ipd\n", + "import numpy as np\n", + "from PIL import Image\n", + "from matplotlib.pyplot import imshow\n", + "from matplotlib import pyplot as plt\n", + "\n", + "\n", + "ipd.Audio(audio, rate=SAMPLE_RATE)"
] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "imshow(spec, origin=\"lower\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file