[TTS] Fix bugs in HiFi-GAN (scheduler, optimizers) and add input_example() in Mixer-TTS/Mixer-TTS-X #3564

Merged · 10 commits · Feb 2, 2022
8 changes: 4 additions & 4 deletions examples/tts/conf/hifigan/hifigan.yaml
@@ -36,10 +36,10 @@ model:
lr: 0.0002
betas: [0.8, 0.99]

-  sched:
-    name: CosineAnnealing
-    min_lr: 1e-5
-    warmup_ratio: 0.02
+    sched:
+      name: CosineAnnealing
+      min_lr: 1e-5
+      warmup_ratio: 0.02

max_steps: 25000000
l1_loss_factor: 45
8 changes: 4 additions & 4 deletions examples/tts/conf/hifigan/hifigan_44100.yaml
@@ -36,10 +36,10 @@ model:
lr: 0.0002
betas: [0.8, 0.99]

-  sched:
-    name: CosineAnnealing
-    min_lr: 1e-5
-    warmup_ratio: 0.02
+    sched:
+      name: CosineAnnealing
+      min_lr: 1e-5
+      warmup_ratio: 0.02

max_steps: 25000000
l1_loss_factor: 45
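Both config files make the same change: the sched block moves from the model: level to sit under optim:, matching the new configure_optimizers() further below, which pops the scheduler settings out of the optimizer config before instantiating the optimizers. A minimal sketch of that split, illustrative only and not code from this PR:

# Illustrative sketch: splitting a sched block nested under optim off the optimizer
# config (mirrors the configure_optimizers() change in hifigan.py below).
from omegaconf import OmegaConf

optim_cfg = OmegaConf.create(
    {
        "lr": 0.0002,
        "betas": [0.8, 0.99],
        "sched": {"name": "CosineAnnealing", "min_lr": 1e-5, "warmup_ratio": 0.02},
    }
)

optim_config = optim_cfg.copy()                 # keep the original config untouched
sched_config = optim_config.pop("sched", None)  # scheduler settings, or None if absent

print(OmegaConf.to_yaml(optim_config))          # only lr/betas remain for the optimizer
print(sched_config.min_lr, sched_config.warmup_ratio)  # 1e-05 0.02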
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/fastpitch.py
@@ -488,7 +488,7 @@ def input_example(self, max_batch=1, max_dim=256):

if self.fastpitch.speaker_emb is not None:
inputs['speaker'] = torch.randint(
- 0, self.fastpitch.speaker_emb.num_embeddings, (maz_batch,), device=par.device, dtype=torch.int64
+ 0, self.fastpitch.speaker_emb.num_embeddings, (max_batch,), device=par.device, dtype=torch.int64
)

return (inputs,)
64 changes: 36 additions & 28 deletions nemo/collections/tts/models/hifigan.py
@@ -17,7 +17,7 @@
import torch
import torch.nn.functional as F
from hydra.utils import instantiate
- from omegaconf import DictConfig, OmegaConf, open_dict
+ from omegaconf import DictConfig, open_dict
from pytorch_lightning.loggers.wandb import WandbLogger

from nemo.collections.tts.data.datalayers import MelAudioDataset
@@ -30,7 +30,7 @@
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import CosineAnnealing, compute_max_steps
- from nemo.utils import logging
+ from nemo.utils import logging, model_utils

HAVE_WANDB = True
try:
@@ -41,8 +41,10 @@

class HifiGanModel(Vocoder, Exportable):
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
- if isinstance(cfg, dict):
-     cfg = OmegaConf.create(cfg)
+ # Convert to Hydra 1.0 compatible DictConfig
+ cfg = model_utils.convert_model_config_to_dict_config(cfg)
+ cfg = model_utils.maybe_update_config_version(cfg)

super().__init__(cfg=cfg, trainer=trainer)

self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
@@ -80,10 +82,7 @@ def _get_max_steps(self):
drop_last=self._train_dl.drop_last,
)

- def _get_warmup_steps(self, max_steps):
-     warmup_steps = self._cfg.sched.get("warmup_steps", None)
-     warmup_ratio = self._cfg.sched.get("warmup_ratio", None)
-
+ def _get_warmup_steps(self, max_steps, warmup_steps, warmup_ratio):
if warmup_steps is not None and warmup_ratio is not None:
raise ValueError(f'Either use warmup_steps or warmup_ratio for scheduler')

@@ -96,37 +95,44 @@ def _get_warmup_steps(self, max_steps):
raise ValueError(f'Specify warmup_steps or warmup_ratio for scheduler')

def configure_optimizers(self):
- self.optim_g = instantiate(self._cfg.optim, params=self.generator.parameters(),)
- self.optim_d = instantiate(
-     self._cfg.optim, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),
- )
+ optim_config = self._cfg.optim.copy()
+ sched_config = optim_config.pop("sched", None)
+
+ optim_g = instantiate(optim_config, params=self.generator.parameters(),)
+ optim_d = instantiate(optim_config, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),)

- if hasattr(self._cfg, 'sched'):
+ # backward compatibility
+ if sched_config is None and 'sched' in self._cfg:
+     sched_config = self._cfg.sched
+
+ if sched_config is not None:
max_steps = self._cfg.get("max_steps", None)
if max_steps is None or max_steps < 0:
max_steps = self._get_max_steps()

- warmup_steps = self._get_warmup_steps(max_steps)
+ warmup_steps = self._get_warmup_steps(
+     max_steps=max_steps,
+     warmup_steps=sched_config.get("warmup_steps", None),
+     warmup_ratio=sched_config.get("warmup_ratio", None),
+ )

- self.scheduler_g = CosineAnnealing(
-     optimizer=self.optim_g, max_steps=max_steps, min_lr=self._cfg.sched.min_lr, warmup_steps=warmup_steps,
+ scheduler_g = CosineAnnealing(
+     optimizer=optim_g, max_steps=max_steps, min_lr=sched_config.min_lr, warmup_steps=warmup_steps,
) # Use warmup to delay start
sch1_dict = {
- 'scheduler': self.scheduler_g,
+ 'scheduler': scheduler_g,
'interval': 'step',
}

- self.scheduler_d = CosineAnnealing(
-     optimizer=self.optim_d, max_steps=max_steps, min_lr=self._cfg.sched.min_lr,
- )
+ scheduler_d = CosineAnnealing(optimizer=optim_d, max_steps=max_steps, min_lr=sched_config.min_lr,)
sch2_dict = {
- 'scheduler': self.scheduler_d,
+ 'scheduler': scheduler_d,
'interval': 'step',
}

- return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]
+ return [optim_g, optim_d], [sch1_dict, sch2_dict]
else:
- return [self.optim_g, self.optim_d]
+ return [optim_g, optim_d]

@property
def input_types(self):
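_get_warmup_steps now receives warmup_steps and warmup_ratio as arguments, read from the sched config by configure_optimizers(), instead of reaching into self._cfg.sched itself. For context, a small standalone sketch of the resolution it performs, inferred from the visible checks and the yaml defaults rather than copied from the PR:

def resolve_warmup_steps(max_steps, warmup_steps=None, warmup_ratio=None):
    # Sketch only: assumed resolution logic behind _get_warmup_steps().
    if warmup_steps is not None and warmup_ratio is not None:
        raise ValueError('Either use warmup_steps or warmup_ratio for scheduler')
    if warmup_steps is not None:
        return warmup_steps
    if warmup_ratio is not None:
        return warmup_ratio * max_steps  # e.g. 0.02 * 25_000_000 = 500_000 warmup steps
    raise ValueError('Specify warmup_steps or warmup_ratio for scheduler')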
@@ -172,8 +178,10 @@ def training_step(self, batch, batch_idx):
audio_pred = self.generator(x=audio_mel)
audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)

+ optim_g, optim_d = self.optimizers()
+
# train discriminator
- self.optim_d.zero_grad()
+ optim_d.zero_grad()
mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
loss_disc_mpd, _, _ = self.discriminator_loss(
disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
@@ -184,10 +192,10 @@
)
loss_d = loss_disc_msd + loss_disc_mpd
self.manual_backward(loss_d)
- self.optim_d.step()
+ optim_d.step()

# train generator
- self.optim_g.zero_grad()
+ optim_g.zero_grad()
loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
_, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
_, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
@@ -197,7 +205,7 @@
loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
self.manual_backward(loss_g)
- self.optim_g.step()
+ optim_g.step()

# run schedulers
schedulers = self.lr_schedulers()
@@ -216,7 +224,7 @@
"d_loss_msd": loss_disc_msd,
"d_loss": loss_d,
"global_step": self.global_step,
"lr": self.optim_g.param_groups[0]['lr'],
"lr": optim_g.param_groups[0]['lr'],
}
self.log_dict(metrics, on_step=True, sync_dist=True)
self.log("g_l1_loss", loss_mel, prog_bar=True, logger=False, sync_dist=True)
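With configure_optimizers() returning local optimizer objects instead of storing them on self, training_step() now fetches them through Lightning's manual-optimization API. A stripped-down sketch of that pattern (standard PyTorch Lightning usage, not code from this PR; module names and losses are placeholders):

import torch
import pytorch_lightning as pl

class TwoOptimizerSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # HiFi-GAN also trains with manual optimization
        self.gen = torch.nn.Linear(8, 8)
        self.disc = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        opt_g = torch.optim.AdamW(self.gen.parameters(), lr=2e-4)
        opt_d = torch.optim.AdamW(self.disc.parameters(), lr=2e-4)
        return [opt_g, opt_d]

    def training_step(self, batch, batch_idx):
        # Lightning wraps whatever configure_optimizers() returned; fetch handles per step.
        opt_g, opt_d = self.optimizers()

        opt_d.zero_grad()
        loss_d = self.disc(self.gen(batch).detach()).pow(2).mean()
        self.manual_backward(loss_d)
        opt_d.step()

        opt_g.zero_grad()
        loss_g = (1.0 - self.disc(self.gen(batch))).pow(2).mean()
        self.manual_backward(loss_g)
        opt_g.step()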
18 changes: 18 additions & 0 deletions nemo/collections/tts/models/mixer_tts.py
@@ -714,6 +714,24 @@ def output_types(self):
"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
}

def input_example(self, max_text_len=10, max_lm_tokens_len=10):
text = torch.randint(
low=0, high=len(self.tokenizer.tokens), size=(1, max_text_len), device=self.device, dtype=torch.long,
)

inputs = {'text': text}

if self.cond_on_lm_embeddings:
inputs['lm_tokens'] = torch.randint(
low=0,
high=self.lm_embeddings.weight.shape[0],
size=(1, max_lm_tokens_len),
device=self.device,
dtype=torch.long,
)

return (inputs,)

def forward_for_export(self, text, lm_tokens=None):
text_mask = (text != self.tokenizer_pad).unsqueeze(2)
spect = self.infer(text=text, text_mask=text_mask, lm_tokens=lm_tokens).transpose(1, 2)
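The new input_example() supplies the dummy inputs that NeMo's Exportable machinery uses when tracing Mixer-TTS/Mixer-TTS-X for export. A usage sketch; the checkpoint name and export call are assumptions for illustration, not part of this PR:

from nemo.collections.tts.models import MixerTTSModel

# Assumed pretrained checkpoint name, for illustration only.
model = MixerTTSModel.from_pretrained("tts_en_lj_mixertts")
model.eval()

# The newly added method builds random token ids (plus lm_tokens for Mixer-TTS-X).
inputs, = model.input_example(max_text_len=16)
print(inputs.keys())

# Export is assumed to use input_example() under the hood to trace forward_for_export().
model.export("mixer_tts.onnx")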