
Commit

[TTS] Fix bugs in HiFi-GAN (scheduler, optimizers) and add input_example() in Mixer-TTS/Mixer-TTS-X (#3564)

* update hifigan and mixer tts

Signed-off-by: Oktai Tatanov <[email protected]>

* remove unused import

Signed-off-by: Oktai Tatanov <[email protected]>

* fix bug with sched_config

Signed-off-by: Oktai Tatanov <[email protected]>

* fix bug with set_struct

Signed-off-by: Oktai Tatanov <[email protected]>

Co-authored-by: Jason <[email protected]>
Oktai15 and blisc authored Feb 2, 2022
1 parent de6b46a commit 862429a
Showing 5 changed files with 65 additions and 36 deletions.
8 changes: 4 additions & 4 deletions examples/tts/conf/hifigan/hifigan.yaml
@@ -36,10 +36,10 @@ model:
    lr: 0.0002
    betas: [0.8, 0.99]

-  sched:
-    name: CosineAnnealing
-    min_lr: 1e-5
-    warmup_ratio: 0.02
+    sched:
+      name: CosineAnnealing
+      min_lr: 1e-5
+      warmup_ratio: 0.02

  max_steps: 25000000
  l1_loss_factor: 45
8 changes: 4 additions & 4 deletions examples/tts/conf/hifigan/hifigan_44100.yaml
@@ -36,10 +36,10 @@ model:
    lr: 0.0002
    betas: [0.8, 0.99]

-  sched:
-    name: CosineAnnealing
-    min_lr: 1e-5
-    warmup_ratio: 0.02
+    sched:
+      name: CosineAnnealing
+      min_lr: 1e-5
+      warmup_ratio: 0.02

  max_steps: 25000000
  l1_loss_factor: 45
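The two YAML hunks above nest sched under optim, which is where the updated HifiGanModel.configure_optimizers() now looks for it. A minimal OmegaConf sketch (illustrative, not from this commit) of how the nested block resolves; the optimizer name is a placeholder, only the values shown in the diff come from the config:

# Illustrative only: "name": "adam" is a placeholder; lr/betas/sched values
# mirror the diff above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optim": {
            "name": "adam",
            "lr": 0.0002,
            "betas": [0.8, 0.99],
            "sched": {"name": "CosineAnnealing", "min_lr": 1e-5, "warmup_ratio": 0.02},
        }
    }
)

# With the corrected nesting, the scheduler block is reachable from the
# optimizer sub-config ...
assert "sched" in cfg.optim
print(cfg.optim.sched.min_lr)  # 1e-05

# ... and can be popped off, leaving a plain optimizer config that can be
# handed to hydra.utils.instantiate() for the generator and discriminators.
optim_config = cfg.optim.copy()
OmegaConf.set_struct(optim_config, False)
sched_config = optim_config.pop("sched", None)
OmegaConf.set_struct(optim_config, True)
print(OmegaConf.to_yaml(optim_config))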
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/fastpitch.py
@@ -488,7 +488,7 @@ def input_example(self, max_batch=1, max_dim=256):

        if self.fastpitch.speaker_emb is not None:
            inputs['speaker'] = torch.randint(
-                0, self.fastpitch.speaker_emb.num_embeddings, (maz_batch,), device=par.device, dtype=torch.int64
+                0, self.fastpitch.speaker_emb.num_embeddings, (max_batch,), device=par.device, dtype=torch.int64
            )

        return (inputs,)
65 changes: 38 additions & 27 deletions nemo/collections/tts/models/hifigan.py
@@ -30,7 +30,7 @@
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import CosineAnnealing, compute_max_steps
-from nemo.utils import logging
+from nemo.utils import logging, model_utils

HAVE_WANDB = True
try:
@@ -41,8 +41,10 @@

class HifiGanModel(Vocoder, Exportable):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
-        if isinstance(cfg, dict):
-            cfg = OmegaConf.create(cfg)
+        # Convert to Hydra 1.0 compatible DictConfig
+        cfg = model_utils.convert_model_config_to_dict_config(cfg)
+        cfg = model_utils.maybe_update_config_version(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
@@ -80,10 +82,7 @@ def _get_max_steps(self):
            drop_last=self._train_dl.drop_last,
        )

-    def _get_warmup_steps(self, max_steps):
-        warmup_steps = self._cfg.sched.get("warmup_steps", None)
-        warmup_ratio = self._cfg.sched.get("warmup_ratio", None)
-
+    def _get_warmup_steps(self, max_steps, warmup_steps, warmup_ratio):
        if warmup_steps is not None and warmup_ratio is not None:
            raise ValueError(f'Either use warmup_steps or warmup_ratio for scheduler')
@@ -96,37 +95,47 @@ def _get_warmup_steps(self, max_steps):
        raise ValueError(f'Specify warmup_steps or warmup_ratio for scheduler')

    def configure_optimizers(self):
-        self.optim_g = instantiate(self._cfg.optim, params=self.generator.parameters(),)
-        self.optim_d = instantiate(
-            self._cfg.optim, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),
-        )
+        optim_config = self._cfg.optim.copy()
+
+        OmegaConf.set_struct(optim_config, False)
+        sched_config = optim_config.pop("sched", None)
+        OmegaConf.set_struct(optim_config, True)
+
+        optim_g = instantiate(optim_config, params=self.generator.parameters(),)
+        optim_d = instantiate(optim_config, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),)

-        if hasattr(self._cfg, 'sched'):
+        # backward compatibility
+        if sched_config is None and 'sched' in self._cfg:
+            sched_config = self._cfg.sched
+
+        if sched_config is not None:
            max_steps = self._cfg.get("max_steps", None)
            if max_steps is None or max_steps < 0:
                max_steps = self._get_max_steps()

-            warmup_steps = self._get_warmup_steps(max_steps)
+            warmup_steps = self._get_warmup_steps(
+                max_steps=max_steps,
+                warmup_steps=sched_config.get("warmup_steps", None),
+                warmup_ratio=sched_config.get("warmup_ratio", None),
+            )

-            self.scheduler_g = CosineAnnealing(
-                optimizer=self.optim_g, max_steps=max_steps, min_lr=self._cfg.sched.min_lr, warmup_steps=warmup_steps,
+            scheduler_g = CosineAnnealing(
+                optimizer=optim_g, max_steps=max_steps, min_lr=sched_config.min_lr, warmup_steps=warmup_steps,
            ) # Use warmup to delay start
            sch1_dict = {
-                'scheduler': self.scheduler_g,
+                'scheduler': scheduler_g,
                'interval': 'step',
            }

-            self.scheduler_d = CosineAnnealing(
-                optimizer=self.optim_d, max_steps=max_steps, min_lr=self._cfg.sched.min_lr,
-            )
+            scheduler_d = CosineAnnealing(optimizer=optim_d, max_steps=max_steps, min_lr=sched_config.min_lr,)
            sch2_dict = {
-                'scheduler': self.scheduler_d,
+                'scheduler': scheduler_d,
                'interval': 'step',
            }

-            return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]
+            return [optim_g, optim_d], [sch1_dict, sch2_dict]
        else:
-            return [self.optim_g, self.optim_d]
+            return [optim_g, optim_d]

    @property
    def input_types(self):
@@ -172,8 +181,10 @@ def training_step(self, batch, batch_idx):
        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)

+        optim_g, optim_d = self.optimizers()
+
        # train discriminator
-        self.optim_d.zero_grad()
+        optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
@@ -184,10 +195,10 @@
        )
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d)
-        self.optim_d.step()
+        optim_d.step()

        # train generator
-        self.optim_g.zero_grad()
+        optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
@@ -197,7 +208,7 @@
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
        self.manual_backward(loss_g)
-        self.optim_g.step()
+        optim_g.step()

        # run schedulers
        schedulers = self.lr_schedulers()
@@ -216,7 +227,7 @@ def training_step(self, batch, batch_idx):
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
-            "lr": self.optim_g.param_groups[0]['lr'],
+            "lr": optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("g_l1_loss", loss_mel, prog_bar=True, logger=False, sync_dist=True)
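The hifigan.py hunks above drop the self.optim_g / self.optim_d attributes and instead fetch the optimizers inside training_step via self.optimizers(). A minimal, self-contained PyTorch Lightning sketch of that manual-optimization pattern (a toy module with dummy losses, not the real HiFi-GAN code):

import torch
import pytorch_lightning as pl


class TwoOptimizerSketch(pl.LightningModule):
    """Toy module mirroring the HiFi-GAN optimizer handling, not the real model."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # HiFi-GAN steps G and D manually
        self.generator = torch.nn.Linear(8, 8)
        self.discriminator = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        # Local variables only; nothing is stored as self.optim_g / self.optim_d.
        optim_g = torch.optim.AdamW(self.generator.parameters(), lr=2e-4)
        optim_d = torch.optim.AdamW(self.discriminator.parameters(), lr=2e-4)
        sched_g = torch.optim.lr_scheduler.CosineAnnealingLR(optim_g, T_max=100)
        sched_d = torch.optim.lr_scheduler.CosineAnnealingLR(optim_d, T_max=100)
        # NeMo wraps these in {'scheduler': ..., 'interval': 'step'} dicts;
        # plain schedulers keep the sketch minimal.
        return [optim_g, optim_d], [sched_g, sched_d]

    def training_step(self, batch, batch_idx):
        # Lightning hands back the optimizers created above.
        optim_g, optim_d = self.optimizers()

        # Discriminator update (dummy loss standing in for the real GAN loss).
        optim_d.zero_grad()
        loss_d = self.discriminator(batch).mean()
        self.manual_backward(loss_d)
        optim_d.step()

        # Generator update (dummy loss again).
        optim_g.zero_grad()
        loss_g = self.generator(batch).pow(2).mean()
        self.manual_backward(loss_g)
        optim_g.step()

        # Step both schedulers manually, as the updated training_step does.
        sched_g, sched_d = self.lr_schedulers()
        sched_g.step()
        sched_d.step()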
18 changes: 18 additions & 0 deletions nemo/collections/tts/models/mixer_tts.py
@@ -714,6 +714,24 @@ def output_types(self):
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        }

+    def input_example(self, max_text_len=10, max_lm_tokens_len=10):
+        text = torch.randint(
+            low=0, high=len(self.tokenizer.tokens), size=(1, max_text_len), device=self.device, dtype=torch.long,
+        )
+
+        inputs = {'text': text}
+
+        if self.cond_on_lm_embeddings:
+            inputs['lm_tokens'] = torch.randint(
+                low=0,
+                high=self.lm_embeddings.weight.shape[0],
+                size=(1, max_lm_tokens_len),
+                device=self.device,
+                dtype=torch.long,
+            )
+
+        return (inputs,)
+
    def forward_for_export(self, text, lm_tokens=None):
        text_mask = (text != self.tokenizer_pad).unsqueeze(2)
        spect = self.infer(text=text, text_mask=text_mask, lm_tokens=lm_tokens).transpose(1, 2)
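The new input_example() is what NeMo's Exportable mixin uses to build dummy tracing inputs, which is what makes export of Mixer-TTS/Mixer-TTS-X practical. A hedged usage sketch, assuming a locally available checkpoint (the file names below are hypothetical):

import torch
from nemo.collections.tts.models import MixerTTSModel

# Hypothetical local checkpoint path.
model = MixerTTSModel.restore_from("mixer_tts.nemo")
model.eval()

# input_example() returns a one-element tuple holding a dict of dummy tensors;
# the same dict drives forward_for_export() during tracing.
(example_inputs,) = model.input_example(max_text_len=10, max_lm_tokens_len=10)
with torch.no_grad():
    spect = model.forward_for_export(**example_inputs)
print(spect.shape)

# Export (format inferred from the extension); the Exportable machinery builds
# the same kind of example inputs internally.
model.export("mixer_tts.onnx")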
