Skip to content

Commit

Permalink
Update FastPitch Export (#2355)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason <[email protected]>
  • Loading branch information
blisc authored Jun 14, 2021
1 parent fdb3797 commit fbfdc1b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
1 change: 1 addition & 0 deletions nemo/collections/tts/losses/fastpitchloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def output_types(self):
def forward(self, spect_predicted, spect_tgt):
spect_tgt.requires_grad = False
spect_tgt = spect_tgt.transpose(1, 2) # (B, T, H)
spect_predicted = spect_predicted.transpose(1, 2) # (B, T, H)

ldiff = spect_tgt.size(1) - spect_predicted.size(1)
spect_predicted = F.pad(spect_predicted, (0, 0, 0, ldiff, 0, 0), value=0.0)
Expand Down
10 changes: 6 additions & 4 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ def forward(

@typecheck(output_types={"spect": NeuralType(('B', 'C', 'T'), MelSpectrogramType())})
def generate_spectrogram(self, tokens: 'torch.tensor', speaker: int = 0, pace: float = 1.0) -> torch.tensor:
    """Run the model in inference mode and return a mel spectrogram.

    Args:
        tokens: batch of encoded text tokens, shape (B, T_text).
        speaker: speaker index for multi-speaker models (default 0).
        pace: speaking-rate multiplier; 1.0 is normal speed.

    Returns:
        Predicted mel spectrogram; the FastPitch module already emits it in
        (B, C, T) layout, so no transpose is needed here.
    """
    # FIXME: return masks as well?
    self.eval()
    spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
    # Diff residue removed: the stale `return spect.transpose(1, 2)` made this
    # line unreachable; the module now returns (B, C, T) directly.
    return spect

def training_step(self, batch, batch_idx):
attn_prior, durs, speakers = None, None, None
Expand All @@ -206,7 +207,7 @@ def training_step(self, batch, batch_idx):
audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)

mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
text=text,
durs=durs,
pitch=pitch,
Expand Down Expand Up @@ -275,7 +276,7 @@ def validation_step(self, batch, batch_idx):
mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)

# Calculate val loss on ground truth durations to better align L2 loss in time
mels_pred, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
text=text,
durs=durs,
pitch=pitch,
Expand Down Expand Up @@ -390,6 +391,7 @@ def output_module(self):
def forward_for_export(self, text):
    """Export-friendly forward pass used when tracing the model (ONNX/TorchScript).

    Unpacks the full FastPitch module output and forwards only the tensors the
    export consumer needs, now including `num_frames` (decoder lengths) so the
    exported graph exposes the spectrogram length per batch item.

    Args:
        text: batch of encoded text tokens.

    Returns:
        Tuple of (spect, num_frames, durs_predicted, log_durs_predicted,
        pitch_predicted).
    """
    (
        spect,
        num_frames,
        durs_predicted,
        log_durs_predicted,
        pitch_predicted,
        attn_soft,
        attn_logprob,
        attn_hard,
        attn_hard_dur,
        pitch,
    ) = self.fastpitch(text=text)
    # Diff residue removed: the stale 4-value return was superseded by this
    # 5-value return that also exposes num_frames.
    return spect, num_frames, durs_predicted, log_durs_predicted, pitch_predicted

@property
def disabled_deployment_input_names(self):
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def input_types(self):
def output_types(self):
return {
"spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"num_frames": NeuralType(('B'), TokenDurationType()),
"durs_predicted": NeuralType(('B', 'T'), TokenDurationType()),
"log_durs_predicted": NeuralType(('B', 'T'), TokenLogDurationType()),
"pitch_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
Expand Down Expand Up @@ -282,9 +283,10 @@ def forward(

# Output FFT
dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
spect = self.proj(dec_out)
spect = self.proj(dec_out).transpose(1, 2)
return (
spect,
dec_lens,
durs_predicted,
log_durs_predicted,
pitch_predicted,
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def mask_from_lens(lens, max_len: Optional[int] = None):
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lens: 1-D integer tensor of shape (B,) with the valid length of each
            sequence in the batch.
        max_len: width of the mask; defaults to ``lens.max()``.

    Returns:
        Boolean tensor of shape (B, max_len) where ``mask[b, t]`` is True for
        ``t < lens[b]`` and False for padding positions.
    """
    if max_len is None:
        # Diff residue removed: keep max_len as a 0-dim tensor (no .item()) so
        # the op stays traceable for export instead of baking in a constant.
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask
Expand Down

0 comments on commit fbfdc1b

Please sign in to comment.