Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Fastpitch energy condition and refactoring #5218

Merged
merged 23 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
35cb74f
Incorporating Energy conditioning in FastPitch
subhankar-ghosh Sep 19, 2022
4042f99
Minor fixes in Energy conditioning in FastPitch
subhankar-ghosh Sep 20, 2022
da674ef
Add Energy conditioning in FastPitch to infer method
subhankar-ghosh Sep 20, 2022
8d6e48c
adding fn to function names
subhankar-ghosh Oct 15, 2022
4b084bd
Incorporating Energy conditioning in FastPitch
subhankar-ghosh Sep 19, 2022
917c239
Minor fixes in Energy conditioning in FastPitch
subhankar-ghosh Sep 20, 2022
7d29f68
Add Energy conditioning in FastPitch to infer method
subhankar-ghosh Sep 20, 2022
7b35ed8
adding fn to function names
subhankar-ghosh Oct 15, 2022
5862766
rebase main
subhankar-ghosh Oct 15, 2022
e60b97c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2022
718e2ef
remove ifelse from batching, minor refactoring changes in energy code
subhankar-ghosh Oct 20, 2022
0e4b9fc
Merge conflict
subhankar-ghosh Oct 20, 2022
65bd158
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2022
6a59935
Refactor based on PR comments.
subhankar-ghosh Oct 21, 2022
ee57f13
Refactor based on PR comments.
subhankar-ghosh Oct 21, 2022
ebc9ee1
Added support for not learning alignment in energy
subhankar-ghosh Oct 21, 2022
b87f7e3
Fix typo in assert statemetn
subhankar-ghosh Oct 21, 2022
5774084
Renaming average_pitch to average_features
subhankar-ghosh Oct 21, 2022
49e8674
Merge branch 'main' into fastpitch_energy_condition
subhankar-ghosh Oct 21, 2022
82bca3c
Merge branch 'main' into fastpitch_energy_condition
XuesongYang Oct 24, 2022
beda201
Renaming len variable name as it is a keyword
subhankar-ghosh Oct 26, 2022
06dd6b5
Renaming len variable name as it is a keyword
subhankar-ghosh Oct 26, 2022
90a9f73
Merge branch 'main' into fastpitch_energy_condition
subhankar-ghosh Oct 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from pesq import pesq
from pystoi import stoi

from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens
from nemo.utils import logging

HAVE_WANDB = True
Expand Down Expand Up @@ -560,3 +561,16 @@ def split_view(tensor, split_size: int, dim: int = 0):
cur_shape = tensor.shape
new_shape = cur_shape[:dim] + (tensor.shape[dim] // split_size, split_size) + cur_shape[dim + 1 :]
return tensor.reshape(*new_shape)


def process_batch(batch_data, sup_data_types_set):
batch_dict = {}
batch_index = 0
for name, datatype in DATA_STR2DATA_CLASS.items():
if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set:
batch_dict[name] = batch_data[batch_index]
batch_index = batch_index + 1
if issubclass(datatype, WithLens):
batch_dict[name + "_lens"] = batch_data[batch_index]
batch_index = batch_index + 1
return batch_dict
31 changes: 31 additions & 0 deletions nemo/collections/tts/losses/fastpitchloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,37 @@ def forward(self, pitch_predicted, pitch_tgt, len):
return pitch_loss


class EnergyLoss(Loss):
def __init__(self, loss_scale=0.1):
super().__init__()
self.loss_scale = loss_scale

@property
def input_types(self):
return {
"energy_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
"energy_tgt": NeuralType(('B', 'T'), RegressionValuesType()),
"len": NeuralType(('B'), LengthsType()),
}

@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
}

@typecheck()
def forward(self, energy_predicted, energy_tgt, length):
if energy_tgt is None:
return 0.0
subhankar-ghosh marked this conversation as resolved.
Show resolved Hide resolved
dur_mask = mask_from_lens(length, max_len=energy_tgt.size(1))
energy_loss = F.mse_loss(energy_tgt, energy_predicted, reduction='none')
energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum()
energy_loss *= self.loss_scale

return energy_loss


class MelLoss(Loss):
@property
def input_types(self):
Expand Down
113 changes: 76 additions & 37 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger

from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, MelLoss, PitchLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import FastPitchModule
from nemo.collections.tts.torch.tts_data_types import SpeakerID
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
Expand Down Expand Up @@ -119,20 +118,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = loss_scale
pitch_loss_scale = loss_scale
energy_loss_scale = loss_scale
if "dur_loss_scale" in cfg:
dur_loss_scale = cfg.dur_loss_scale
if "pitch_loss_scale" in cfg:
pitch_loss_scale = cfg.pitch_loss_scale
if "energy_loss_scale" in cfg:
energy_loss_scale = cfg.energy_loss_scale

self.mel_loss = MelLoss()
self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)
self.mel_loss_fn = MelLoss()
self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
self.duration_loss_fn = DurationLoss(loss_scale=dur_loss_scale)
self.energy_loss_fn = EnergyLoss(loss_scale=energy_loss_scale)

self.aligner = None
if self.learn_alignment:
self.aligner = instantiate(self._cfg.alignment_module)
self.forward_sum_loss = ForwardSumLoss()
self.bin_loss = BinLoss()
self.forward_sum_loss_fn = ForwardSumLoss()
self.bin_loss_fn = BinLoss()

self.preprocessor = instantiate(self._cfg.preprocessor)
input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
Expand All @@ -142,16 +145,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
energy_predictor = instantiate(self._cfg.get("energy_predictor", None))

self.fastpitch = FastPitchModule(
input_fft,
output_fft,
duration_predictor,
pitch_predictor,
energy_predictor,
self.aligner,
cfg.n_speakers,
cfg.symbols_embedding_dim,
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
cfg.n_mel_channels,
cfg.max_token_duration,
speaker_emb_condition_prosody,
Expand Down Expand Up @@ -284,6 +291,7 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor:
"text": NeuralType(('B', 'T_text'), TokenIndex()),
"durs": NeuralType(('B', 'T_text'), TokenDurationType()),
"pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
"energy": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
"speaker": NeuralType(('B'), Index(), optional=True),
"pace": NeuralType(optional=True),
"spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
Expand All @@ -298,6 +306,7 @@ def forward(
text,
durs=None,
pitch=None,
energy=None,
speaker=None,
pace=1.0,
spec=None,
Expand All @@ -309,6 +318,7 @@ def forward(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=pace,
spec=spec,
Expand All @@ -329,24 +339,41 @@ def generate_spectrogram(
return spect

def training_step(self, batch, batch_idx):
attn_prior, durs, speaker = None, None, None
attn_prior, durs, speaker, energy = None, None, None, None
if self.learn_alignment:
if self.ds_class_name == "TTSDataset":
if SpeakerID in self._train_dl.dataset.sup_data_types_set:
audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
else:
audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
else:
raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
text_lens = batch_dict.get("text_lens")
attn_prior = batch_dict.get("align_prior_matrix", None)
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)

mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
(
mels_pred,
_,
_,
log_durs_pred,
pitch_pred,
attn_soft,
attn_logprob,
attn_hard,
attn_hard_dur,
pitch,
energy_pred,
energy_tgt,
) = self(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
Expand All @@ -357,22 +384,25 @@ def training_step(self, batch, batch_idx):
if durs is None:
durs = attn_hard_dur

mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
loss = mel_loss + dur_loss
if self.learn_alignment:
ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
ctc_loss = self.forward_sum_loss_fn(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
bin_loss_weight = min(self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
bin_loss = self.bin_loss_fn(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
loss += ctc_loss + bin_loss

pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
loss += pitch_loss
pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, len=text_lens)
loss += pitch_loss + energy_loss

self.log("t_loss", loss)
self.log("t_mel_loss", mel_loss)
self.log("t_dur_loss", dur_loss)
self.log("t_pitch_loss", pitch_loss)
if energy_tgt is not None:
self.log("t_energy_loss", energy_loss)
if self.learn_alignment:
self.log("t_ctc_loss", ctc_loss)
self.log("t_bin_loss", bin_loss)
Expand Down Expand Up @@ -404,25 +434,29 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
attn_prior, durs, speaker = None, None, None
attn_prior, durs, speaker, energy = None, None, None, None
if self.learn_alignment:
if self.ds_class_name == "TTSDataset":
if SpeakerID in self._train_dl.dataset.sup_data_types_set:
audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
else:
audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
else:
raise ValueError(f"Unknown vocab class: {self.vocab.__class__.__name__}")
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
text_lens = batch_dict.get("text_lens")
attn_prior = batch_dict.get("align_prior_matrix", None)
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)

# Calculate val loss on ground truth durations to better align L2 loss in time
mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
(mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
text=text,
durs=durs,
pitch=pitch,
energy=energy,
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
Expand All @@ -433,16 +467,18 @@ def validation_step(self, batch, batch_idx):
if durs is None:
durs = attn_hard_dur

mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
loss = mel_loss + dur_loss + pitch_loss
mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, len=text_lens)
loss = mel_loss + dur_loss + pitch_loss + energy_loss

return {
"val_loss": loss,
"mel_loss": mel_loss,
"dur_loss": dur_loss,
"pitch_loss": pitch_loss,
"energy_loss": energy_loss if energy_tgt is not None else None,
"mel_target": mels if batch_idx == 0 else None,
"mel_pred": mels_pred if batch_idx == 0 else None,
}
Expand All @@ -457,8 +493,11 @@ def validation_epoch_end(self, outputs):
self.log("val_mel_loss", mel_loss)
self.log("val_dur_loss", dur_loss)
self.log("val_pitch_loss", pitch_loss)
if outputs[0]["energy_loss"] is not None:
energy_loss = collect("energy_loss")
self.log("val_energy_loss", energy_loss)

_, _, _, _, spec_target, spec_predict = outputs[0].values()
_, _, _, _, _, spec_target, spec_predict = outputs[0].values()

if isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
Expand Down
12 changes: 6 additions & 6 deletions nemo/collections/tts/models/mixer_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import average_pitch, regulate_len
from nemo.collections.tts.modules.fastpitch import average_features, regulate_len
from nemo.core import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
Expand Down Expand Up @@ -242,7 +242,7 @@ def _metrics(
if self.add_bin_loss:
bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
loss = loss + self.bin_loss_scale * bin_loss
true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
true_avg_pitch = average_features(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

# Pitch loss
pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none') # noqa
Expand Down Expand Up @@ -315,12 +315,12 @@ def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_p
# Avg pitch, add pitch_emb
if not self.training:
if pitch is not None:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
else:
pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
else:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

enc_out = enc_out + pitch_emb.transpose(1, 2)
Expand Down Expand Up @@ -381,7 +381,7 @@ def infer(

# Avg pitch, pitch predictor
if use_gt_durs and pitch is not None:
pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
else:
pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
Expand Down Expand Up @@ -565,7 +565,7 @@ def validation_step(self, batch, batch_idx):
pitches += [
wandb.Image(
plot_pitch_to_numpy(
average_pitch(pitch.unsqueeze(1), attn_hard_dur)
average_features(pitch.unsqueeze(1), attn_hard_dur)
.squeeze(1)[i, : text_len[i]]
.data.cpu()
.numpy(),
Expand Down
Loading