diff --git a/nemo/collections/tts/helpers/helpers.py b/nemo/collections/tts/helpers/helpers.py
index b77ca82da6b8..a742264fc1c0 100644
--- a/nemo/collections/tts/helpers/helpers.py
+++ b/nemo/collections/tts/helpers/helpers.py
@@ -54,6 +54,7 @@
 from pesq import pesq
 from pystoi import stoi
 
+from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens
 from nemo.utils import logging
 
 HAVE_WANDB = True
@@ -560,3 +561,16 @@ def split_view(tensor, split_size: int, dim: int = 0):
     cur_shape = tensor.shape
     new_shape = cur_shape[:dim] + (tensor.shape[dim] // split_size, split_size) + cur_shape[dim + 1 :]
     return tensor.reshape(*new_shape)
+
+
+def process_batch(batch_data, sup_data_types_set):
+    batch_dict = {}
+    batch_index = 0
+    for name, datatype in DATA_STR2DATA_CLASS.items():
+        if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set:
+            batch_dict[name] = batch_data[batch_index]
+            batch_index = batch_index + 1
+            if issubclass(datatype, WithLens):
+                batch_dict[name + "_lens"] = batch_data[batch_index]
+                batch_index = batch_index + 1
+    return batch_dict
diff --git a/nemo/collections/tts/losses/fastpitchloss.py b/nemo/collections/tts/losses/fastpitchloss.py
index f9e9b3926e64..313c0223c4c5 100644
--- a/nemo/collections/tts/losses/fastpitchloss.py
+++ b/nemo/collections/tts/losses/fastpitchloss.py
@@ -120,6 +120,37 @@ def forward(self, pitch_predicted, pitch_tgt, len):
         return pitch_loss
 
 
+class EnergyLoss(Loss):
+    def __init__(self, loss_scale=0.1):
+        super().__init__()
+        self.loss_scale = loss_scale
+
+    @property
+    def input_types(self):
+        return {
+            "energy_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
+            "energy_tgt": NeuralType(('B', 'T'), RegressionValuesType()),
+            "length": NeuralType(('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "loss": NeuralType(elements_type=LossType()),
+        }
+
+    @typecheck()
+    def forward(self, energy_predicted, energy_tgt, length):
+        if energy_tgt is None:
+            return 0.0
+        dur_mask = mask_from_lens(length, max_len=energy_tgt.size(1))
+        energy_loss = F.mse_loss(energy_tgt, energy_predicted, reduction='none')
+        energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum()
+        energy_loss *= self.loss_scale
+
+        return energy_loss
+
+
 class MelLoss(Loss):
     @property
     def input_types(self):
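Note on `EnergyLoss` above: it mirrors `PitchLoss`, a per-token MSE masked so that padded positions neither add error nor inflate the denominator. A minimal standalone sketch of that computation follows; `mask_from_lens` here is a local stand-in for the helper that `fastpitchloss.py` imports, so details may differ from NeMo's version.

```python
# Standalone sketch of EnergyLoss's masked MSE (stand-in mask helper).
import torch
import torch.nn.functional as F


def mask_from_lens(lens, max_len=None):
    # (B,) lengths -> (B, T) 0/1 mask of valid positions
    max_len = max_len or int(lens.max())
    ids = torch.arange(max_len, device=lens.device)
    return (ids < lens.unsqueeze(1)).float()


def energy_loss(energy_predicted, energy_tgt, length, loss_scale=0.1):
    dur_mask = mask_from_lens(length, max_len=energy_tgt.size(1))
    loss = F.mse_loss(energy_tgt, energy_predicted, reduction='none')
    # Sum only over valid positions, then normalize by their count.
    return loss_scale * (loss * dur_mask).sum() / dur_mask.sum()


pred, tgt = torch.randn(2, 5), torch.randn(2, 5)
lens = torch.tensor([5, 3])  # second item has two padded frames
print(energy_loss(pred, tgt, lens))
```

Dividing by `dur_mask.sum()` rather than the tensor's size keeps the loss comparable across batches with different amounts of padding.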
diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py
index 3a1cc9ffdce1..e13092454691 100644
--- a/nemo/collections/tts/models/fastpitch.py
+++ b/nemo/collections/tts/models/fastpitch.py
@@ -22,12 +22,11 @@
 from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 
 from nemo.collections.common.parts.preprocessing import parsers
-from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy
+from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
 from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
-from nemo.collections.tts.losses.fastpitchloss import DurationLoss, MelLoss, PitchLoss
+from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
 from nemo.collections.tts.models.base import SpectrogramGenerator
 from nemo.collections.tts.modules.fastpitch import FastPitchModule
-from nemo.collections.tts.torch.tts_data_types import SpeakerID
 from nemo.core.classes import Exportable
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.neural_types.elements import (
@@ -119,20 +118,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         loss_scale = 0.1 if self.learn_alignment else 1.0
         dur_loss_scale = loss_scale
         pitch_loss_scale = loss_scale
+        energy_loss_scale = loss_scale
         if "dur_loss_scale" in cfg:
             dur_loss_scale = cfg.dur_loss_scale
         if "pitch_loss_scale" in cfg:
             pitch_loss_scale = cfg.pitch_loss_scale
+        if "energy_loss_scale" in cfg:
+            energy_loss_scale = cfg.energy_loss_scale
 
-        self.mel_loss = MelLoss()
-        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
-        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)
+        self.mel_loss_fn = MelLoss()
+        self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
+        self.duration_loss_fn = DurationLoss(loss_scale=dur_loss_scale)
+        self.energy_loss_fn = EnergyLoss(loss_scale=energy_loss_scale)
 
         self.aligner = None
         if self.learn_alignment:
             self.aligner = instantiate(self._cfg.alignment_module)
-            self.forward_sum_loss = ForwardSumLoss()
-            self.bin_loss = BinLoss()
+            self.forward_sum_loss_fn = ForwardSumLoss()
+            self.bin_loss_fn = BinLoss()
 
         self.preprocessor = instantiate(self._cfg.preprocessor)
         input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
@@ -142,16 +145,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
         speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
         speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
+        energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
+        energy_predictor = instantiate(self._cfg.get("energy_predictor", None))
 
         self.fastpitch = FastPitchModule(
             input_fft,
             output_fft,
             duration_predictor,
             pitch_predictor,
+            energy_predictor,
             self.aligner,
             cfg.n_speakers,
             cfg.symbols_embedding_dim,
             cfg.pitch_embedding_kernel_size,
+            energy_embedding_kernel_size,
             cfg.n_mel_channels,
             cfg.max_token_duration,
             speaker_emb_condition_prosody,
@@ -284,6 +291,7 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor:
             "text": NeuralType(('B', 'T_text'), TokenIndex()),
             "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
             "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
+            "energy": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
             "speaker": NeuralType(('B'), Index(), optional=True),
             "pace": NeuralType(optional=True),
             "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
@@ -298,6 +306,7 @@ def forward(
         self,
         text,
         durs=None,
         pitch=None,
+        energy=None,
         speaker=None,
         pace=1.0,
         spec=None,
@@ -309,6 +318,7 @@
             text=text,
             durs=durs,
             pitch=pitch,
+            energy=energy,
             speaker=speaker,
             pace=pace,
             spec=spec,
@@ -329,24 +339,41 @@ def generate_spectrogram(
         return spect
 
     def training_step(self, batch, batch_idx):
-        attn_prior, durs, speaker = None, None, None
+        attn_prior, durs, speaker, energy = None, None, None, None
         if self.learn_alignment:
-            if self.ds_class_name == "TTSDataset":
-                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
-                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
-                else:
-                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
-            else:
-                raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
+            assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
+            batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
+            audio = batch_dict.get("audio")
+            audio_lens = batch_dict.get("audio_lens")
+            text = batch_dict.get("text")
+            text_lens = batch_dict.get("text_lens")
+            attn_prior = batch_dict.get("align_prior_matrix", None)
+            pitch = batch_dict.get("pitch", None)
+            energy = batch_dict.get("energy", None)
+            speaker = batch_dict.get("speaker_id", None)
         else:
             audio, audio_lens, text, text_lens, durs, pitch, speaker = batch
 
         mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)
 
-        mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
+        (
+            mels_pred,
+            _,
+            _,
+            log_durs_pred,
+            pitch_pred,
+            attn_soft,
+            attn_logprob,
+            attn_hard,
+            attn_hard_dur,
+            pitch,
+            energy_pred,
+            energy_tgt,
+        ) = self(
             text=text,
             durs=durs,
             pitch=pitch,
+            energy=energy,
             speaker=speaker,
             pace=1.0,
             spec=mels if self.learn_alignment else None,
@@ -357,22 +384,25 @@ def training_step(self, batch, batch_idx):
         if durs is None:
             durs = attn_hard_dur
 
-        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
-        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
+        mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
+        dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
         loss = mel_loss + dur_loss
         if self.learn_alignment:
-            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
+            ctc_loss = self.forward_sum_loss_fn(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
             bin_loss_weight = min(self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
-            bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
+            bin_loss = self.bin_loss_fn(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
             loss += ctc_loss + bin_loss
 
-        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
-        loss += pitch_loss
+        pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
+        energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
+        loss += pitch_loss + energy_loss
 
         self.log("t_loss", loss)
         self.log("t_mel_loss", mel_loss)
         self.log("t_dur_loss", dur_loss)
         self.log("t_pitch_loss", pitch_loss)
+        if energy_tgt is not None:
+            self.log("t_energy_loss", energy_loss)
         if self.learn_alignment:
             self.log("t_ctc_loss", ctc_loss)
             self.log("t_bin_loss", bin_loss)
@@ -404,25 +434,29 @@ def training_step(self, batch, batch_idx):
         return loss
 
     def validation_step(self, batch, batch_idx):
-        attn_prior, durs, speaker = None, None, None
+        attn_prior, durs, speaker, energy = None, None, None, None
         if self.learn_alignment:
-            if self.ds_class_name == "TTSDataset":
-                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
-                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
-                else:
-                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
-            else:
-                raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
+            assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
+            batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
+            audio = batch_dict.get("audio")
+            audio_lens = batch_dict.get("audio_lens")
+            text = batch_dict.get("text")
+            text_lens = batch_dict.get("text_lens")
+            attn_prior = batch_dict.get("align_prior_matrix", None)
+            pitch = batch_dict.get("pitch", None)
+            energy = batch_dict.get("energy", None)
+            speaker = batch_dict.get("speaker_id", None)
         else:
             audio, audio_lens, text, text_lens, durs, pitch, speaker = batch
 
         mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)
 
         # Calculate val loss on ground truth durations to better align L2 loss in time
-        mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
+        (mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
             text=text,
             durs=durs,
             pitch=pitch,
+            energy=energy,
             speaker=speaker,
             pace=1.0,
             spec=mels if self.learn_alignment else None,
@@ -433,16 +467,18 @@ def validation_step(self, batch, batch_idx):
         if durs is None:
             durs = attn_hard_dur
 
-        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
-        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
-        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
-        loss = mel_loss + dur_loss + pitch_loss
+        mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
+        dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
+        pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
+        energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
+        loss = mel_loss + dur_loss + pitch_loss + energy_loss
 
         return {
             "val_loss": loss,
             "mel_loss": mel_loss,
             "dur_loss": dur_loss,
             "pitch_loss": pitch_loss,
+            "energy_loss": energy_loss if energy_tgt is not None else None,
             "mel_target": mels if batch_idx == 0 else None,
             "mel_pred": mels_pred if batch_idx == 0 else None,
         }
@@ -457,8 +493,11 @@ def validation_epoch_end(self, outputs):
         self.log("val_mel_loss", mel_loss)
         self.log("val_dur_loss", dur_loss)
         self.log("val_pitch_loss", pitch_loss)
+        if outputs[0]["energy_loss"] is not None:
+            energy_loss = collect("energy_loss")
+            self.log("val_energy_loss", energy_loss)
 
-        _, _, _, _, spec_target, spec_predict = outputs[0].values()
+        _, _, _, _, _, spec_target, spec_predict = outputs[0].values()
         if isinstance(self.logger, TensorBoardLogger):
             self.tb_logger.add_image(
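The energy branch is opt-in: the constructor only passes a predictor to `FastPitchModule` when the config defines `energy_predictor`, and `energy_embedding_kernel_size` defaults to 0, so it must be set alongside the predictor. A hypothetical config fragment showing the keys read in `__init__`; the `_target_` and predictor hyperparameters are assumptions modeled on the usual pitch-predictor setup, not taken from this diff:

```python
# Hypothetical OmegaConf fragment enabling the energy branch.
from omegaconf import OmegaConf

energy_cfg = OmegaConf.create(
    {
        "energy_loss_scale": 0.1,           # read in __init__, else loss_scale
        "energy_embedding_kernel_size": 3,  # used for the Conv1d energy embedding
        "energy_predictor": {               # instantiated via hydra
            # Assumed target/params, mirroring the pitch predictor convention:
            "_target_": "nemo.collections.tts.modules.fastpitch.TemporalPredictor",
            "input_size": 384,  # symbols_embedding_dim
            "kernel_size": 3,
            "filter_size": 256,
            "dropout": 0.1,
            "n_layers": 2,
        },
    }
)
```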
diff --git a/nemo/collections/tts/models/mixer_tts.py b/nemo/collections/tts/models/mixer_tts.py
index 842d08cc321c..fffaf82d5a06 100644
--- a/nemo/collections/tts/models/mixer_tts.py
+++ b/nemo/collections/tts/models/mixer_tts.py
@@ -39,7 +39,7 @@
 )
 from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
 from nemo.collections.tts.models.base import SpectrogramGenerator
-from nemo.collections.tts.modules.fastpitch import average_pitch, regulate_len
+from nemo.collections.tts.modules.fastpitch import average_features, regulate_len
 from nemo.core import Exportable
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.neural_types.elements import (
@@ -242,7 +242,7 @@ def _metrics(
         if self.add_bin_loss:
             bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
             loss = loss + self.bin_loss_scale * bin_loss
-        true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+        true_avg_pitch = average_features(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
 
         # Pitch loss
         pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
@@ -315,12 +315,12 @@ def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_p
         # Avg pitch, add pitch_emb
         if not self.training:
             if pitch is not None:
-                pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+                pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
                 pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
             else:
                 pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
         else:
-            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+            pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
             pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
 
         enc_out = enc_out + pitch_emb.transpose(1, 2)
@@ -381,7 +381,7 @@ def infer(
 
         # Avg pitch, pitch predictor
         if use_gt_durs and pitch is not None:
-            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+            pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
             pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
         else:
             pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
@@ -565,7 +565,7 @@ def validation_step(self, batch, batch_idx):
                 pitches += [
                     wandb.Image(
                         plot_pitch_to_numpy(
-                            average_pitch(pitch.unsqueeze(1), attn_hard_dur)
+                            average_features(pitch.unsqueeze(1), attn_hard_dur)
                             .squeeze(1)[i, : text_len[i]]
                             .data.cpu()
                             .numpy(),
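`average_pitch` becomes `average_features` below because the identical duration-bucketed averaging now serves both pitch and energy. Restated as a naive reference loop (the real implementation is vectorized with cumulative sums and, for pitch, averages only over nonzero frames):

```python
# Naive reference for average_features: average a frame-level contour over
# each token's duration span, (B, 1, T_frames) -> (B, 1, T_tokens).
import torch


def average_features_ref(feats, durs):
    ends = torch.cumsum(durs, dim=1)
    starts = ends - durs
    out = torch.zeros(feats.size(0), 1, durs.size(1))
    for b in range(durs.size(0)):
        for t in range(durs.size(1)):
            s, e = int(starts[b, t]), int(ends[b, t])
            if e > s:
                out[b, 0, t] = feats[b, 0, s:e].mean()
    return out


pitch = torch.arange(6.0).view(1, 1, 6)   # six frames: 0, 1, ..., 5
durs = torch.tensor([[2, 1, 3]])          # three tokens
print(average_features_ref(pitch, durs))  # tensor([[[0.5000, 2.0000, 4.0000]]])
```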
diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py
index 24c394375176..fbee3631c609 100644
--- a/nemo/collections/tts/modules/fastpitch.py
+++ b/nemo/collections/tts/modules/fastpitch.py
@@ -61,7 +61,7 @@
 from nemo.core.neural_types.neural_type import NeuralType
 
 
-def average_pitch(pitch, durs):
+def average_features(pitch, durs):
     durs_cums_ends = torch.cumsum(durs, dim=1).long()
     durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
     pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
@@ -135,10 +135,12 @@ def __init__(
         decoder_module: NeuralModule,
         duration_predictor: NeuralModule,
         pitch_predictor: NeuralModule,
+        energy_predictor: NeuralModule,
         aligner: NeuralModule,
         n_speakers: int,
         symbols_embedding_dim: int,
         pitch_embedding_kernel_size: int,
+        energy_embedding_kernel_size: int,
         n_mel_channels: int = 80,
         max_token_duration: int = 75,
         speaker_emb_condition_prosody: bool = False,
@@ -151,6 +153,7 @@ def __init__(
         self.decoder = decoder_module
         self.duration_predictor = duration_predictor
         self.pitch_predictor = pitch_predictor
+        self.energy_predictor = energy_predictor
         self.aligner = aligner
         self.learn_alignment = aligner is not None
         self.use_duration_predictor = True
@@ -174,6 +177,14 @@ def __init__(
                 padding=int((pitch_embedding_kernel_size - 1) / 2),
             )
 
+        if self.energy_predictor is not None:
+            self.energy_emb = torch.nn.Conv1d(
+                1,
+                symbols_embedding_dim,
+                kernel_size=energy_embedding_kernel_size,
+                padding=int((energy_embedding_kernel_size - 1) / 2),
+            )
+
         # Store values precomputed from training data for convenience
         self.register_buffer('pitch_mean', torch.zeros(1))
         self.register_buffer('pitch_std', torch.zeros(1))
@@ -186,6 +197,7 @@ def input_types(self):
             "text": NeuralType(('B', 'T_text'), TokenIndex()),
             "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
             "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
+            "energy": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
             "speaker": NeuralType(('B'), Index(), optional=True),
             "pace": NeuralType(optional=True),
             "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
@@ -207,6 +219,8 @@ def output_types(self):
             "attn_hard": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
             "attn_hard_dur": NeuralType(('B', 'T_text'), TokenDurationType()),
             "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
+            "energy_pred": NeuralType(('B', 'T_text'), RegressionValuesType()),
+            "energy_tgt": NeuralType(('B', 'T_audio'), RegressionValuesType()),
         }
 
     @typecheck()
@@ -216,6 +230,7 @@ def forward(
         self,
         text,
         durs=None,
         pitch=None,
+        energy=None,
         speaker=None,
         pace=1.0,
         spec=None,
@@ -260,13 +275,38 @@ def forward(
         if pitch is not None:
             if self.learn_alignment and pitch.shape[-1] != pitch_predicted.shape[-1]:
                 # Pitch during training is per spectrogram frame, but during inference, it should be per character
-                pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+                pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
+            elif not self.learn_alignment:
+                # If alignment is not learnt attn_hard_dur is None, hence durs_predicted
+                pitch = average_features(pitch.unsqueeze(1), durs_predicted).squeeze(1)
             pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
         else:
             pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
 
         enc_out = enc_out + pitch_emb.transpose(1, 2)
 
+        # Predict energy
+        if self.energy_predictor is not None:
+            energy_pred = self.energy_predictor(enc_out, enc_mask).squeeze(-1)
+
+            if energy is not None:
+                # Average energy over characters
+                if self.learn_alignment:
+                    energy_tgt = average_features(energy.unsqueeze(1), attn_hard_dur)
+                else:
+                    energy_tgt = average_features(energy.unsqueeze(1), durs_predicted)
+                energy_tgt = torch.log(1.0 + energy_tgt)
+                energy_emb = self.energy_emb(energy_tgt)
+                energy_tgt = energy_tgt.squeeze(1)
+            else:
+                energy_emb = self.energy_emb(energy_pred.unsqueeze(1))
+                energy_tgt = None
+
+            enc_out = enc_out + energy_emb.transpose(1, 2)
+        else:
+            energy_pred = None
+            energy_tgt = None
+
         if self.learn_alignment and spec is not None:
             len_regulated, dec_lens = regulate_len(attn_hard_dur, enc_out, pace)
         elif spec is None and durs is not None:
@@ -292,9 +332,11 @@ def forward(
             attn_hard,
             attn_hard_dur,
             pitch,
+            energy_pred,
+            energy_tgt,
         )
 
-    def infer(self, *, text, pitch=None, speaker=None, pace=1.0, volume=None):
+    def infer(self, *, text, pitch=None, speaker=None, energy=None, pace=1.0, volume=None):
         # Calculate speaker embedding
         if self.speaker_emb is None or speaker is None:
             spk_emb = 0
@@ -315,6 +357,14 @@ def infer(self, *, text, pitch=None, speaker=None, pace=1.0, volume=None):
             )
             pitch_predicted = self.pitch_predictor(prosody_input, enc_mask) + pitch
         pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
+        if self.energy_predictor is not None:
+            if energy is not None:
+                assert energy.shape[-1] == text.shape[-1], f"energy.shape[-1]: {energy.shape[-1]} != len(text)"
+                energy_emb = self.energy_emb(energy)
+            else:
+                energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1)
+                energy_emb = self.energy_emb(energy_pred.unsqueeze(1))
+            enc_out = enc_out + energy_emb.transpose(1, 2)
         enc_out = enc_out + pitch_emb.transpose(1, 2)
 
         # Expand to decoder time dimension
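To summarize the module changes: at train time the ground-truth energy is averaged per token, log-compressed, embedded, and added to the encoder output (and returned as `energy_tgt` for the loss); at inference the predictor's own output is embedded instead, so `energy_tgt` comes back as `None`. A shape-level sketch of the training-side conditioning step, with toy dimensions assumed:

```python
# Shape sketch of the energy-conditioning step in FastPitchModule.forward.
import torch

B, T_tokens, d = 2, 7, 16
enc_out = torch.randn(B, T_tokens, d)            # encoder output
energy_tgt = torch.rand(B, 1, T_tokens)          # per-token averaged energy
energy_emb_layer = torch.nn.Conv1d(1, d, kernel_size=3, padding=1)

energy_tgt = torch.log(1.0 + energy_tgt)         # compress dynamic range
energy_emb = energy_emb_layer(energy_tgt)        # (B, d, T_tokens)
enc_out = enc_out + energy_emb.transpose(1, 2)   # back to (B, T_tokens, d)
print(enc_out.shape)  # torch.Size([2, 7, 16])
```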