Skip to content

Commit

Permalink
Fixing RADTTS training - removing view buffer and fixing accuracy issue
Browse files: browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom authored and XuesongYang committed Nov 10, 2022
1 parent b60d7f4 commit 905a550
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 54 deletions.
14 changes: 0 additions & 14 deletions nemo/collections/tts/models/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,17 +407,3 @@ def output_module(self):

def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes):
return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes)

def get_export_subnet(self, subnet=None):
return self.model.get_export_subnet(subnet)

def _prepare_for_export(self, **kwargs):
"""
Override this method to prepare module for export. This is in-place operation.
Base version does common necessary module replacements (Apex etc)
"""
PartialConv1d.forward = PartialConv1d.forward_no_cache
super()._prepare_for_export(**kwargs)

def _export_teardown(self):
PartialConv1d.forward = PartialConv1d.forward_with_cache
27 changes: 12 additions & 15 deletions nemo/collections/tts/modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals
seq = nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
if not torch.jit.is_scripting():
self.bilstm.flatten_parameters()
# if not torch.jit.is_scripting():
# self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
if not torch.jit.is_scripting():
self.bilstm.flatten_parameters()
# if not torch.jit.is_scripting():
# self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

Expand Down Expand Up @@ -164,10 +164,10 @@ def __init__(
):
super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
self.out_dim = out_dim
self.convolutions = nn.ModuleList()

if n_layers > 0:
self.dropout = nn.Dropout(p=p_dropout)
self.convolutions = nn.ModuleList()

for i in range(n_layers):
conv_layer = ConvNorm(
Expand Down Expand Up @@ -219,7 +219,6 @@ def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
return ret

@torch.jit.export
def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
mask = get_mask_from_lengths_and_val(lens, context)
mask = mask.unsqueeze(1)
Expand All @@ -234,16 +233,14 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted:

def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor:
if lens is None:
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context)))
context = context.transpose(1, 2)
context, _ = self.bilstm(context)
my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2]
else:
# borisf : does not match ADLR (values, lengths)
# seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
# borisf : does match ADLR
seq = self.conv_to_sequence(context, lens, enforce_sorted=False)
context, _ = self.lstm_sequence(seq)
my_lens = lens
# borisf : does not match ADLR (values, lengths)
# seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
# borisf : does match ADLR
seq = self.conv_to_sequence(context, my_lens, enforce_sorted=False)
context, _ = self.lstm_sequence(seq)

if self.dense is not None:
context = self.dense(context).permute(0, 2, 1)
Expand Down
21 changes: 16 additions & 5 deletions nemo/collections/tts/modules/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ def __init__(
self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias']
self.ap_pred_log_f0 = ap_pred_log_f0
self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias']
self.prepared_for_export = False

if 'atn' in include_modules or 'dec' in include_modules:
if self.learn_alignments:
self.attention = ConvAttention(n_mel_channels, self.n_speaker_dim, n_text_dim)
Expand Down Expand Up @@ -613,7 +615,7 @@ def infer(
spk_vec_text = self.encode_speaker(speaker_id_text)
spk_vec_attributes = self.encode_speaker(speaker_id_attributes)
txt_enc, _ = self.encode_text(text, in_lens)
print("txt_enc: ", txt_enc.shape)
# print("txt_enc: ", txt_enc.shape)

if dur is None:
# get token durations
Expand Down Expand Up @@ -784,14 +786,23 @@ def _prepare_for_export(self, **kwargs):
PartialConv1d.forward = PartialConv1d.forward_no_cache
self.remove_norms()
super()._prepare_for_export(**kwargs)
if self.prepared_for_export:
return
self.encoder = torch.jit.script(self.encoder)
self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn)
self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn)
self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn)
self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn)
if hasattr(self, 'f0_pred_module'):
self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn)
if hasattr(self, 'energy_pred_module'):
self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn)
if hasattr(self, 'dur_pred_layer'):
self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn)

if self.use_context_lstm:
self.context_lstm = torch.jit.script(self.context_lstm)
self.prepared_for_export = True

def _export_teardown(self):
PartialConv1d.forward = PartialConv1d.forward_with_cache

def input_example(self, max_batch=1, max_dim=256):
"""
Expand All @@ -802,7 +813,7 @@ def input_example(self, max_batch=1, max_dim=256):
par = next(self.parameters())
sz = (max_batch, max_dim)
inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int)
lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
inputs = {
'text': inp,
Expand Down
29 changes: 20 additions & 9 deletions nemo/collections/tts/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@ def __init__(self, *args, **kwargs):
super(PartialConv1d, self).__init__(*args, **kwargs)
weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False)
slide_winsize = torch.tensor(self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2])
slide_winsize = torch.tensor(
self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2], requires_grad=False
)
self.register_buffer("slide_winsize", slide_winsize, persistent=False)

if self.bias is not None:
bias_view = self.bias.view(1, self.out_channels, 1)
self.register_buffer('bias_view', bias_view, persistent=False)
# caching part
self.last_size = (-1, -1, -1)

if self.bias is not None:
bias_view = self.bias.clone().detach().reshape(1, self.out_channels, 1)
self.register_buffer("bias_view", bias_view, persistent=False)

update_mask = torch.ones(1, 1, 1)
self.register_buffer('update_mask', update_mask, persistent=False)
mask_ratio = torch.ones(1, 1, 1)
Expand All @@ -51,6 +54,8 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]):
mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device)
else:
mask = mask_in
input = torch.mul(input, mask)

update_mask = F.conv1d(
mask,
self.weight_maskUpdater,
Expand All @@ -60,19 +65,25 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]):
dilation=self.dilation,
groups=1,
)
# for mixed precision training, change 1e-8 to 1e-6
mask_ratio = self.slide_winsize / (update_mask + 1e-6)

update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize)
mask_ratio = self.slide_winsize / (update_mask_filled)

update_mask = torch.clamp(update_mask, 0, 1)
mask_ratio = torch.mul(mask_ratio.to(update_mask), update_mask)
return torch.mul(input, mask), mask_ratio, update_mask
mask_ratio = torch.mul(mask_ratio, update_mask)
return input, mask_ratio, update_mask

def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor:
assert len(input.shape) == 3

raw_out = self._conv_forward(input, self.weight, self.bias)

if self.bias is not None:
output = torch.mul(raw_out - self.bias_view, mask_ratio) + self.bias_view
if torch.jit.is_scripting():
bias_view = self.bias_view
else:
bias_view = self.bias.view(1, self.out_channels, 1)
output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
output = torch.mul(output, update_mask)
else:
output = torch.mul(raw_out, mask_ratio)
Expand Down
8 changes: 0 additions & 8 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,6 @@ def _export(
jitted_model.save(output)
assert os.path.exists(output)

if check_trace:
if isinstance(check_trace, bool):
check_trace_input = [input_example]
else:
check_trace_input = check_trace

verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance)

elif format == ExportFormat.ONNX:
# dynamic axis is a mapping from input/output_name => list of "dynamic" indices
if dynamic_axes is None:
Expand Down
3 changes: 2 additions & 1 deletion scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,11 @@ def nemo_export(argv):
input_example = model.input_module.input_example(**in_args)
if check_trace and len(in_args) > 0:
check_trace = [input_example]
for key, arg in in_args:
for key, arg in in_args.items():
in_args[key] = (arg + 1) // 2
input_example2 = model.input_module.input_example(**in_args)
check_trace.append(input_example2)
logging.info(f"Using additional check args: {in_args}")

_, descriptions = model.export(
out,
Expand Down
32 changes: 30 additions & 2 deletions tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import tempfile

import pytest
from omegaconf import DictConfig, OmegaConf

from nemo.collections.tts.models import FastPitchModel, HifiGanModel
from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
from nemo.utils.app_state import AppState


@pytest.fixture()
Expand All @@ -31,6 +33,24 @@ def hifigan_model():
return model


@pytest.fixture()
def radtts_model():
    """Build an eval-mode ``RadTTSModel`` from the example feature-prediction config.

    The config is read from ``examples/tts/conf/rad-tts_feature_pred.yaml``
    (resolved relative to this test file). The checkpoint-init option is
    cleared and the train/validation manifest and supplementary-data paths
    are overwritten with the placeholder ``"dummy.json"``.

    Returns:
        The constructed model, switched to evaluation mode.
    """
    this_test_dir = os.path.dirname(os.path.abspath(__file__))

    cfg = OmegaConf.load(os.path.join(this_test_dir, '../../../examples/tts/conf/rad-tts_feature_pred.yaml'))
    cfg.model.init_from_ptl_ckpt = None
    # Placeholder paths; presumably never opened while the "being restored"
    # flag below is set -- TODO confirm against the model constructor.
    cfg.model.train_ds.dataset.manifest_filepath = "dummy.json"
    cfg.model.train_ds.dataset.sup_data_path = "dummy.json"
    cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json"
    cfg.model.validation_ds.dataset.sup_data_path = "dummy.json"

    app_state = AppState()
    app_state.is_model_being_restored = True
    try:
        model = RadTTSModel(cfg=cfg.model)
    finally:
        # Always clear the flag, even if construction raises, so a failure
        # here cannot leave it set for whatever else uses AppState.
        app_state.is_model_being_restored = False
    model.eval()
    return model


class TestExportable:
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -50,7 +70,15 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model):
filename = os.path.join(tmpdir, 'hfg.pt')
model.export(output=filename, verbose=True, check_trace=True)

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
def test_RadTTSModel_export_to_torchscript(self, radtts_model):
    """Move the RadTTS fixture model to GPU and export it to ``rad.ts`` in a temp dir.

    The export is invoked with ``verbose=True`` and ``check_trace=True``; the
    temporary directory (and the exported file) is removed when the test ends.
    """
    gpu_model = radtts_model.cuda()
    with tempfile.TemporaryDirectory() as export_dir:
        export_path = os.path.join(export_dir, 'rad.ts')
        gpu_model.export(output=export_path, verbose=True, check_trace=True)


if __name__ == "__main__":
t = TestExportable()
t.test_FastPitchModel_export_to_onnx(fastpitch_model())
t.test_RadTTSModel_export_to_torchscript(radtts_model())

0 comments on commit 905a550

Please sign in to comment.