Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Radtts 1.13 plus #5457

Merged
merged 7 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/tts/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def prepare_model_weights(model, unfreeze_modules):
def main(cfg):
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get('exp_manager', None))
model = RadTTSModel(cfg=cfg.model, trainer=trainer)
model = RadTTSModel(cfg=cfg.model, trainer=trainer).cuda()
if cfg.model.load_from_checkpoint:
model.maybe_init_from_pretrained_checkpoint(cfg=cfg.model)
prepare_model_weights(model, cfg.model.trainerConfig.unfreeze_modules)
lr_logger = pl.callbacks.LearningRateMonitor()
epoch_time_logger = LogEpochTimeCallback()
trainer.callbacks.extend([lr_logger, epoch_time_logger])
trainer.fit(model)
trainer.fit(model.cuda())


if __name__ == '__main__':
Expand Down
50 changes: 15 additions & 35 deletions nemo/collections/tts/modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,39 +119,29 @@ def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral
lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse')
self.bilstm.flatten_parameters()

@torch.jit.export
def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> Tuple[Tensor, Tensor]:
seq = nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
else:
ret, _ = self.bilstm.forward_1(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
return self.lstm_sequence(seq)

@torch.jit.export
def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
elif hasattr(self.bilstm, 'forward_1'):
ret, _ = self.bilstm.forward_1(seq)
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
context, lens_sorted, unsort_ids = sort_tensor(context, lens)
seq = nn.utils.rnn.pack_padded_sequence(
context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
)
return self.lstm_sequence(seq)[0][unsort_ids]
dtype = context.dtype
# this is only needed for Torchscript to run in Triton
# (https://github.com/pytorch/pytorch/issues/89241)
with torch.cuda.amp.autocast(enabled=False):
ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True)
return ret[0].to(dtype=dtype)[unsort_ids]


class ConvLSTMLinear(BiLSTM):
class ConvLSTMLinear(nn.Module):
def __init__(
self,
in_dim=None,
Expand All @@ -163,7 +153,8 @@ def __init__(
use_partial_padding=False,
norm_fn=None,
):
super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
super(ConvLSTMLinear, self).__init__()
self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
self.convolutions = nn.ModuleList()

if n_layers > 0:
Expand Down Expand Up @@ -194,27 +185,16 @@ def __init__(
if out_dim is not None:
self.dense = nn.Linear(n_channels, out_dim)

def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
mask = get_mask_from_lengths_and_val(lens, context)
mask = mask.to(dtype=context.dtype).unsqueeze(1)
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context, mask)))

context = context.transpose(1, 2)
seq = torch.nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
return seq

def forward(self, context: Tensor, lens: Tensor) -> Tensor:
context, lens, unsort_ids = sort_tensor(context, lens)
seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
context, _ = self.lstm_sequence(seq)
context = context[unsort_ids]

# Apply Bidirectional LSTM
context = self.bilstm(context, lens)
if self.dense is not None:
context = self.dense(context).permute(0, 2, 1)

return context


Expand Down
8 changes: 3 additions & 5 deletions nemo/collections/tts/modules/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg):
context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)

unfolded_out_lens = out_lens // self.n_group_size
context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor(
context_w_spkvec.transpose(1, 2), unfolded_out_lens
)
context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens)
context_w_spkvec = context_lstm_padded_output.transpose(1, 2)

if not self.context_lstm_w_f0_and_energy:
Expand Down Expand Up @@ -773,8 +771,8 @@ def input_example(self, max_batch=1, max_dim=256):
"""
par = next(self.parameters())
sz = (max_batch, max_dim)
inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
inputs = {
'text': inp,
Expand Down
36 changes: 24 additions & 12 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
from typing import Callable, Dict, Optional, Type

import onnx
import torch
Expand Down Expand Up @@ -154,14 +154,16 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):


def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01):
ts_model = torch.jit.load(output)

all_good = True
for input_example in input_examples:
input_list, input_dict = parse_input_example(input_example)
output_example = model.forward(*input_list, **input_dict)

all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
# We disable autocast here to make sure exported TS will run under Triton or other C++ env
with torch.cuda.amp.autocast(enabled=False):
ts_model = torch.jit.load(output)
all_good = all_good and run_ts_and_compare(
ts_model, input_list, input_dict, output_example, check_tolerance
)
status = "SUCCESS" if all_good else "FAIL"
logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
return all_good
Expand Down Expand Up @@ -204,9 +206,15 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
if torch.is_tensor(expected):
tout = out.to('cpu')
logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
this_good = True
try:
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
this_good = False
except Exception: # there may ne size mismatch and it may be OK
this_good = False
if not this_good:
logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
all_good = False
logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
return all_good


Expand All @@ -220,9 +228,15 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
if torch.is_tensor(expected):
tout = torch.from_numpy(out)
logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
all_good = False
this_good = True
try:
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
this_good = False
except Exception: # there may ne size mismatch and it may be OK
this_good = False
if not this_good:
logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
all_good = False
return all_good


Expand Down Expand Up @@ -419,8 +433,7 @@ def replace_modules(


def script_module(m: nn.Module):
m1 = torch.jit.script(m)
return m1
return torch.jit.script(m)


default_replacements = {
Expand All @@ -432,7 +445,6 @@ def script_module(m: nn.Module):

script_replacements = {
"BiLSTM": script_module,
"ConvLSTMLinear": script_module,
}


Expand Down
4 changes: 3 additions & 1 deletion tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tempfile

import pytest
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
Expand Down Expand Up @@ -79,4 +80,5 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model):
model = radtts_model.cuda()
with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, 'rad.ts')
model.export(output=filename, verbose=True, check_trace=True)
with torch.cuda.amp.autocast(enabled=True):
model.export(output=filename, verbose=True, check_trace=True)