Skip to content

Commit

Permalink
E2E TTS fixes (#3508)
Browse files Browse the repository at this point in the history
* E2E TTS fixes: types, keyword args, etc.

Signed-off-by: Jocelyn Huang <[email protected]>

* Fix import for FastPitch_HiFiGAN type

Signed-off-by: Jocelyn Huang <[email protected]>
  • Loading branch information
redoctopus authored Jan 26, 2022
1 parent 3146fca commit 360fa7c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 34 deletions.
4 changes: 2 additions & 2 deletions examples/tts/conf/fastspeech2_hifigan_e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ model:
preprocessor:
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
dither: 0.0
nfilt: ${n_mels}
nfilt: ${model.n_mels}
frame_splicing: 1
highfreq: 8000
log: true
Expand Down Expand Up @@ -133,4 +133,4 @@ exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: True
create_checkpoint_callback: True
create_checkpoint_callback: True
11 changes: 9 additions & 2 deletions nemo/collections/tts/losses/fastspeech2loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@

from nemo.collections.tts.helpers.helpers import get_mask_from_lengths
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types.elements import LengthsType, LossType, MaskType, MelSpectrogramType, TokenDurationType
from nemo.core.neural_types.elements import (
LengthsType,
LossType,
MaskType,
MelSpectrogramType,
TokenDurationType,
TokenLogDurationType,
)
from nemo.core.neural_types.neural_type import NeuralType


Expand All @@ -26,7 +33,7 @@ class DurationLoss(Loss):
@property
def input_types(self):
return {
"log_duration_pred": NeuralType(('B', 'T'), TokenDurationType()),
"log_duration_pred": NeuralType(('B', 'T'), TokenLogDurationType()),
"duration_target": NeuralType(('B', 'T'), TokenDurationType()),
"mask": NeuralType(('B', 'T', 'D'), MaskType()),
}
Expand Down
23 changes: 13 additions & 10 deletions nemo/collections/tts/models/fastpitch_hifigan_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
MelSpectrogramType,
AudioSignal,
RegressionValuesType,
TokenDurationType,
TokenIndex,
Expand Down Expand Up @@ -202,7 +202,7 @@ def configure_optimizers(self):
"splice": NeuralType(optional=True),
},
output_types={
"audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
"audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
"splices": NeuralType(),
"log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
"pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
Expand Down Expand Up @@ -237,10 +237,9 @@ def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):
len_regulated, dec_lens = regulate_len(durs, enc_out, pace)

gen_in = len_regulated
splices = None
splices = []
if splice:
output = []
splices = []
for i, sample in enumerate(len_regulated):
start = np.random.randint(low=0, high=min(int(sample.size(0)), int(dec_lens[i])) - self.splice_length)
# Splice generated spec
Expand All @@ -250,7 +249,7 @@ def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):

output = self.generator(x=gen_in.transpose(1, 2))

return output, splices, log_durs_predicted, pitch_predicted
return output, torch.tensor(splices), log_durs_predicted, pitch_predicted

def training_step(self, batch, batch_idx, optimizer_idx):
audio, _, text, text_lens, durs, pitch, _ = batch
Expand All @@ -267,8 +266,12 @@ def training_step(self, batch, batch_idx, optimizer_idx):
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(real_audio, audio_pred)
real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(real_audio, audio_pred)

loss_mp, loss_mp_real, _ = self.disc_loss(real_score_mp, gen_score_mp)
loss_ms, loss_ms_real, _ = self.disc_loss(real_score_ms, gen_score_ms)
loss_mp, loss_mp_real, _ = self.disc_loss(
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
)
loss_ms, loss_ms_real, _ = self.disc_loss(
disc_real_outputs=real_score_ms, disc_generated_outputs=gen_score_ms
)
loss_mp /= len(loss_mp_real)
loss_ms /= len(loss_ms_real)
loss_disc = loss_mp + loss_ms
Expand Down Expand Up @@ -298,9 +301,9 @@ def training_step(self, batch, batch_idx, optimizer_idx):
loss_mel = torch.nn.functional.l1_loss(real_spliced_spec, pred_spliced_spec)
loss_mel *= self.mel_loss_coeff
_, gen_score_mp, _, _ = self.multiperioddisc(real_audio, audio_pred)
_, gen_score_ms, _, _ = self.multiscaledisc(real_audio, audio_pred)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(gen_score_mp)
loss_gen_ms, list_loss_gen_ms = self.gen_loss(gen_score_ms)
_, gen_score_ms, _, _ = self.multiscaledisc(y=real_audio, y_hat=audio_pred)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
loss_gen_ms, list_loss_gen_ms = self.gen_loss(disc_outputs=gen_score_ms)
loss_gen_mp /= len(list_loss_gen_mp)
loss_gen_ms /= len(list_loss_gen_ms)
total_loss = loss_gen_mp + loss_gen_ms + loss_mel
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/fastspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
},
output_types={
"mel_spec": NeuralType(('B', 'T', 'C'), MelSpectrogramType()),
"log_dur_preds": NeuralType(('B', 'T'), TokenDurationType(), optional=True),
"log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType(), optional=True),
"pitch_preds": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
"energy_preds": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
"encoded_text_mask": NeuralType(('B', 'T', 'D'), MaskType()),
Expand Down
39 changes: 20 additions & 19 deletions nemo/collections/tts/models/fastspeech2_hifigan_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
AudioSignal,
LengthsType,
MaskType,
MelSpectrogramType,
RegressionValuesType,
TokenDurationType,
TokenIndex,
Expand All @@ -52,7 +52,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
cfg = OmegaConf.create(cfg)
super().__init__(cfg=cfg, trainer=trainer)

self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
self.audio_to_melspec_preprocessor = instantiate(cfg.preprocessor)
self.encoder = instantiate(cfg.encoder)
self.variance_adapter = instantiate(cfg.variance_adaptor)

Expand Down Expand Up @@ -145,7 +145,7 @@ def configure_optimizers(self):
"energies": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
},
output_types={
"audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
"audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
"splices": NeuralType(),
"log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
"pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
Expand All @@ -166,11 +166,10 @@ def forward(self, *, text, text_length, splice=True, durations=None, pitch=None,
)

gen_in = context
splices = None
splices = []
if splice:
# Splice generated spec
output = []
splices = []
for i, sample in enumerate(context):
start = np.random.randint(low=0, high=min(int(sample.size(0)), int(spec_len[i])) - self.splice_length)
output.append(sample[start : start + self.splice_length, :])
Expand All @@ -179,17 +178,16 @@ def forward(self, *, text, text_length, splice=True, durations=None, pitch=None,

output = self.generator(x=gen_in.transpose(1, 2))

return output, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask
return output, torch.tensor(splices), log_dur_preds, pitch_preds, energy_preds, encoded_text_mask

def training_step(self, batch, batch_idx, optimizer_idx):
f, fl, t, tl, durations, pitch, energies = batch
spec, spec_len = self.audio_to_melspec_precessor(f, fl)
_, spec_len = self.audio_to_melspec_preprocessor(f, fl)

# train discriminator
if optimizer_idx == 0:
with torch.no_grad():
audio_pred, splices, _, _, _, _ = self(
spec=spec,
spec_len=spec_len,
text=t,
text_length=tl,
Expand All @@ -203,10 +201,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
real_audio = torch.stack(real_audio).unsqueeze(1)

real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(real_audio, audio_pred)
real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(real_audio, audio_pred)
real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(y=real_audio, y_hat=audio_pred)

loss_mp, loss_mp_real, _ = self.disc_loss(real_score_mp, gen_score_mp)
loss_ms, loss_ms_real, _ = self.disc_loss(real_score_ms, gen_score_ms)
loss_mp, loss_mp_real, _ = self.disc_loss(
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
)
loss_ms, loss_ms_real, _ = self.disc_loss(
disc_real_outputs=real_score_ms, disc_generated_outputs=gen_score_ms
)
loss_mp /= len(loss_mp_real)
loss_ms /= len(loss_ms_real)
loss_disc = loss_mp + loss_ms
Expand All @@ -219,7 +221,6 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# train generator
elif optimizer_idx == 1:
audio_pred, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask = self(
spec=spec,
spec_len=spec_len,
text=t,
text_length=tl,
Expand All @@ -241,14 +242,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
loss_mel = torch.nn.functional.l1_loss(real_spliced_spec, pred_spliced_spec)
loss_mel *= self.mel_loss_coeff
_, gen_score_mp, real_feat_mp, gen_feat_mp = self.multiperioddisc(real_audio, audio_pred)
_, gen_score_ms, real_feat_ms, gen_feat_ms = self.multiscaledisc(real_audio, audio_pred)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(gen_score_mp)
loss_gen_ms, list_loss_gen_ms = self.gen_loss(gen_score_ms)
_, gen_score_ms, real_feat_ms, gen_feat_ms = self.multiscaledisc(y=real_audio, y_hat=audio_pred)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
loss_gen_ms, list_loss_gen_ms = self.gen_loss(disc_outputs=gen_score_ms)
loss_gen_mp /= len(list_loss_gen_mp)
loss_gen_ms /= len(list_loss_gen_ms)
total_loss = loss_gen_mp + loss_gen_ms + loss_mel
loss_feat_mp = self.feat_matching_loss(real_feat_mp, gen_feat_mp)
loss_feat_ms = self.feat_matching_loss(real_feat_ms, gen_feat_ms)
loss_feat_mp = self.feat_matching_loss(fmap_r=real_feat_mp, fmap_g=gen_feat_mp)
loss_feat_ms = self.feat_matching_loss(fmap_r=real_feat_ms, fmap_g=gen_feat_ms)
total_loss += loss_feat_mp + loss_feat_ms
self.log(name="loss_gen_disc_feat", value=loss_feat_mp + loss_feat_ms)
self.log(name="loss_gen_disc_feat_ms", value=loss_feat_ms)
Expand Down Expand Up @@ -295,8 +296,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

def validation_step(self, batch, batch_idx):
f, fl, t, tl, _, _, _ = batch
spec, spec_len = self.audio_to_melspec_precessor(f, fl)
audio_pred, _, _, _, _, _ = self(spec=spec, spec_len=spec_len, text=t, text_length=tl, splice=False)
spec, spec_len = self.audio_to_melspec_preprocessor(f, fl)
audio_pred, _, _, _, _, _ = self(spec_len=spec_len, text=t, text_length=tl, splice=False)
audio_pred.squeeze_()
pred_spec, _ = self.melspec_fn(audio_pred, seq_len=spec_len)
loss = self.mel_val_loss(spec_pred=pred_spec, spec_target=spec, spec_target_len=spec_len, pad_value=-11.52)
Expand Down

0 comments on commit 360fa7c

Please sign in to comment.