Skip to content

Commit

Permalink
[TTS] Fastpitch energy condition and refactoring (NVIDIA#5218)
Browse files Browse the repository at this point in the history
* Incorporating Energy conditioning in FastPitch

Signed-off-by: subhankar-ghosh <[email protected]>

* Minor fixes in Energy conditioning in FastPitch

Signed-off-by: subhankar-ghosh <[email protected]>

* Add Energy conditioning in FastPitch to infer method

Signed-off-by: subhankar-ghosh <[email protected]>

* adding fn to function names

Signed-off-by: subhankar-ghosh <[email protected]>

* Incorporating Energy conditioning in FastPitch

Signed-off-by: subhankar-ghosh <[email protected]>

* Minor fixes in Energy conditioning in FastPitch

Signed-off-by: subhankar-ghosh <[email protected]>

* Add Energy conditioning in FastPitch to infer method

Signed-off-by: subhankar-ghosh <[email protected]>

* adding fn to function names

Signed-off-by: subhankar-ghosh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove ifelse from batching, minor refactoring changes in energy code

Signed-off-by: subhankar-ghosh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor based on PR comments.

Signed-off-by: subhankar-ghosh <[email protected]>

* Added support for not learning alignment in energy

Signed-off-by: subhankar-ghosh <[email protected]>

* Fix typo in assert statement

Signed-off-by: subhankar-ghosh <[email protected]>

* Renaming average_pitch to average_features

Signed-off-by: subhankar-ghosh <[email protected]>

* Renaming the len variable, as it shadows the built-in len function

Signed-off-by: subhankar-ghosh <[email protected]>

* Renaming the len variable, as it shadows the built-in len function

Signed-off-by: subhankar-ghosh <[email protected]>

Signed-off-by: subhankar-ghosh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xuesong Yang <[email protected]>
Signed-off-by: 1-800-bad-code <[email protected]>
  • Loading branch information
3 people authored and 1-800-BAD-CODE committed Nov 13, 2022
1 parent a5b89d8 commit 14b3bf5
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 46 deletions.
14 changes: 14 additions & 0 deletions nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from pesq import pesq
from pystoi import stoi

from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens
from nemo.utils import logging

HAVE_WANDB = True
Expand Down Expand Up @@ -560,3 +561,16 @@ def split_view(tensor, split_size: int, dim: int = 0):
cur_shape = tensor.shape
new_shape = cur_shape[:dim] + (tensor.shape[dim] // split_size, split_size) + cur_shape[dim + 1 :]
return tensor.reshape(*new_shape)


def process_batch(batch_data, sup_data_types_set):
    """Turn a positional batch into a dictionary keyed by data-type name.

    Walks ``DATA_STR2DATA_CLASS`` in its iteration order and, for every data
    type that is either one of ``MAIN_DATA_TYPES`` or a member of
    ``sup_data_types_set``, takes the next element of ``batch_data`` and stores
    it under that type's name. When the data type is a subclass of
    ``WithLens``, the element after it is taken as well and stored under
    ``"<name>_lens"``.

    Args:
        batch_data: Indexable batch whose elements are read positionally.
            NOTE(review): assumes the elements are laid out in the same order
            as ``DATA_STR2DATA_CLASS`` iterates -- confirm against the
            dataset's collate function.
        sup_data_types_set: Collection of supplementary data-type classes that
            are present in the batch.

    Returns:
        dict mapping data-type names (and ``"<name>_lens"`` keys) to the
        corresponding elements of ``batch_data``.
    """
    fields = {}
    cursor = 0
    for key, data_cls in DATA_STR2DATA_CLASS.items():
        selected = data_cls in MAIN_DATA_TYPES or data_cls in sup_data_types_set
        if not selected:
            # This data type is not part of the batch; it occupies no slot.
            continue

        fields[key] = batch_data[cursor]
        cursor += 1

        if issubclass(data_cls, WithLens):
            # Types carrying lengths take a second slot right after the data.
            fields[key + "_lens"] = batch_data[cursor]
            cursor += 1
    return fields
31 changes: 31 additions & 0 deletions nemo/collections/tts/losses/fastpitchloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,37 @@ def forward(self, pitch_predicted, pitch_tgt, len):
return pitch_loss


class EnergyLoss(Loss):
    """Masked mean-squared-error loss between predicted and target energy.

    The loss is skipped (``0.0`` is returned) when no energy target is
    supplied, so both energy inputs are declared optional in ``input_types``.
    """

    def __init__(self, loss_scale=0.1):
        """
        Args:
            loss_scale: Factor the masked mean loss is multiplied by.
        """
        super().__init__()
        self.loss_scale = loss_scale

    @property
    def input_types(self):
        # The energy tensors are optional because ``forward`` accepts
        # ``energy_tgt=None`` (energy conditioning disabled); ``length`` stays
        # mandatory.
        return {
            "energy_predicted": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
            "energy_tgt": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
            "length": NeuralType(('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, energy_predicted, energy_tgt, length):
        """Compute the scaled, length-masked MSE between the energy tensors.

        Args:
            energy_predicted: Predicted energy values; only read when
                ``energy_tgt`` is not ``None``.
            energy_tgt: Target energy values, or ``None`` to skip the loss.
            length: Lengths passed to ``mask_from_lens`` to build the mask.

        Returns:
            ``0.0`` when ``energy_tgt`` is ``None``; otherwise the sum of the
            masked element-wise squared errors divided by the sum of the mask,
            multiplied by ``self.loss_scale``.
        """
        if energy_tgt is None:
            return 0.0
        # presumably marks the valid (non-padded) positions of each sequence;
        # verify against mask_from_lens
        dur_mask = mask_from_lens(length, max_len=energy_tgt.size(1))
        # reduction='none' keeps per-element errors so the mask can be applied
        # before averaging.
        energy_loss = F.mse_loss(energy_tgt, energy_predicted, reduction='none')
        energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum()
        energy_loss *= self.loss_scale

        return energy_loss


class MelLoss(Loss):
@property
def input_types(self):
Expand Down
113 changes: 76 additions & 37 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger

from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, MelLoss, PitchLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import FastPitchModule
from nemo.collections.tts.torch.tts_data_types import SpeakerID
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
Expand Down Expand Up @@ -119,20 +118,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = loss_scale
pitch_loss_scale = loss_scale
energy_loss_scale = loss_scale
if "dur_loss_scale" in cfg:
dur_loss_scale = cfg.dur_loss_scale
if "pitch_loss_scale" in cfg:
pitch_loss_scale = cfg.pitch_loss_scale
if "energy_loss_scale" in cfg:
energy_loss_scale = cfg.energy_loss_scale

self.mel_loss = MelLoss()
self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)
self.mel_loss_fn = MelLoss()
self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
self.duration_loss_fn = DurationLoss(loss_scale=dur_loss_scale)
self.energy_loss_fn = EnergyLoss(loss_scale=energy_loss_scale)

self.aligner = None
if self.learn_alignment:
self.aligner = instantiate(self._cfg.alignment_module)
self.forward_sum_loss = ForwardSumLoss()
self.bin_loss = BinLoss()
self.forward_sum_loss_fn = ForwardSumLoss()
self.bin_loss_fn = BinLoss()

self.preprocessor = instantiate(self._cfg.preprocessor)
input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
Expand All @@ -142,16 +145,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
energy_predictor = instantiate(self._cfg.get("energy_predictor", None))

self.fastpitch = FastPitchModule(
input_fft,
output_fft,
duration_predictor,
pitch_predictor,
energy_predictor,
self.aligner,
cfg.n_speakers,
cfg.symbols_embedding_dim,
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
cfg.n_mel_channels,
cfg.max_token_duration,
speaker_emb_condition_prosody,
Expand Down Expand Up @@ -284,6 +291,7 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor:
"text": NeuralType(('B', 'T_text'), TokenIndex()),
"durs": NeuralType(('B', 'T_text'), TokenDurationType()),
"pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
"energy": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
"speaker": NeuralType(('B'), Index(), optional=True),
"pace": NeuralType(optional=True),
"spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
Expand All @@ -298,6 +306,7 @@ def forward(
text,
durs=None,
pitch=None,
energy=None,
speaker=None,
pace=1.0,
spec=None,
Expand All @@ -309,6 +318,7 @@ def forward(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=pace,
spec=spec,
Expand All @@ -329,24 +339,41 @@ def generate_spectrogram(
return spect

def training_step(self, batch, batch_idx):
attn_prior, durs, speaker = None, None, None
attn_prior, durs, speaker, energy = None, None, None, None
if self.learn_alignment:
if self.ds_class_name == "TTSDataset":
if SpeakerID in self._train_dl.dataset.sup_data_types_set:
audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
else:
audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
else:
raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
text_lens = batch_dict.get("text_lens")
attn_prior = batch_dict.get("align_prior_matrix", None)
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)

mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
(
mels_pred,
_,
_,
log_durs_pred,
pitch_pred,
attn_soft,
attn_logprob,
attn_hard,
attn_hard_dur,
pitch,
energy_pred,
energy_tgt,
) = self(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
Expand All @@ -357,22 +384,25 @@ def training_step(self, batch, batch_idx):
if durs is None:
durs = attn_hard_dur

mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
loss = mel_loss + dur_loss
if self.learn_alignment:
ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
ctc_loss = self.forward_sum_loss_fn(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
bin_loss_weight = min(self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
bin_loss = self.bin_loss_fn(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
loss += ctc_loss + bin_loss

pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
loss += pitch_loss
pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
loss += pitch_loss + energy_loss

self.log("t_loss", loss)
self.log("t_mel_loss", mel_loss)
self.log("t_dur_loss", dur_loss)
self.log("t_pitch_loss", pitch_loss)
if energy_tgt is not None:
self.log("t_energy_loss", energy_loss)
if self.learn_alignment:
self.log("t_ctc_loss", ctc_loss)
self.log("t_bin_loss", bin_loss)
Expand Down Expand Up @@ -404,25 +434,29 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
attn_prior, durs, speaker = None, None, None
attn_prior, durs, speaker, energy = None, None, None, None
if self.learn_alignment:
if self.ds_class_name == "TTSDataset":
if SpeakerID in self._train_dl.dataset.sup_data_types_set:
audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
else:
audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
else:
raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
text_lens = batch_dict.get("text_lens")
attn_prior = batch_dict.get("align_prior_matrix", None)
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)

# Calculate val loss on ground truth durations to better align L2 loss in time
mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
(mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
Expand All @@ -433,16 +467,18 @@ def validation_step(self, batch, batch_idx):
if durs is None:
durs = attn_hard_dur

mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
loss = mel_loss + dur_loss + pitch_loss
mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
loss = mel_loss + dur_loss + pitch_loss + energy_loss

return {
"val_loss": loss,
"mel_loss": mel_loss,
"dur_loss": dur_loss,
"pitch_loss": pitch_loss,
"energy_loss": energy_loss if energy_tgt is not None else None,
"mel_target": mels if batch_idx == 0 else None,
"mel_pred": mels_pred if batch_idx == 0 else None,
}
Expand All @@ -457,8 +493,11 @@ def validation_epoch_end(self, outputs):
self.log("val_mel_loss", mel_loss)
self.log("val_dur_loss", dur_loss)
self.log("val_pitch_loss", pitch_loss)
if outputs[0]["energy_loss"] is not None:
energy_loss = collect("energy_loss")
self.log("val_energy_loss", energy_loss)

_, _, _, _, spec_target, spec_predict = outputs[0].values()
_, _, _, _, _, spec_target, spec_predict = outputs[0].values()

if isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
Expand Down
12 changes: 6 additions & 6 deletions nemo/collections/tts/models/mixer_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import average_pitch, regulate_len
from nemo.collections.tts.modules.fastpitch import average_features, regulate_len
from nemo.core import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
Expand Down Expand Up @@ -242,7 +242,7 @@ def _metrics(
if self.add_bin_loss:
bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
loss = loss + self.bin_loss_scale * bin_loss
true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
true_avg_pitch = average_features(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

# Pitch loss
pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none') # noqa
Expand Down Expand Up @@ -315,12 +315,12 @@ def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_p
# Avg pitch, add pitch_emb
if not self.training:
if pitch is not None:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
else:
pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
else:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

enc_out = enc_out + pitch_emb.transpose(1, 2)
Expand Down Expand Up @@ -381,7 +381,7 @@ def infer(

# Avg pitch, pitch predictor
if use_gt_durs and pitch is not None:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
else:
pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
Expand Down Expand Up @@ -565,7 +565,7 @@ def validation_step(self, batch, batch_idx):
pitches += [
wandb.Image(
plot_pitch_to_numpy(
average_pitch(pitch.unsqueeze(1), attn_hard_dur)
average_features(pitch.unsqueeze(1), attn_hard_dur)
.squeeze(1)[i, : text_len[i]]
.data.cpu()
.numpy(),
Expand Down
Loading

0 comments on commit 14b3bf5

Please sign in to comment.