From 905a5503f9d5066a71fb34f6ed9cda75a499edd1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 8 Nov 2022 12:59:27 -0800 Subject: [PATCH] Fixing RADTTS training - removing view buffer and fixing accuracy issue Signed-off-by: Boris Fomitchev --- nemo/collections/tts/models/radtts.py | 14 -------- nemo/collections/tts/modules/common.py | 27 +++++++--------- nemo/collections/tts/modules/radtts.py | 21 +++++++++--- nemo/collections/tts/modules/submodules.py | 29 +++++++++++------ nemo/core/classes/exportable.py | 8 ----- scripts/export.py | 3 +- tests/collections/tts/test_tts_exportables.py | 32 +++++++++++++++++-- 7 files changed, 80 insertions(+), 54 deletions(-) diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 47251b4a3f617..73e2adb5f8bb5 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -407,17 +407,3 @@ def output_module(self): def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes): return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes) - - def get_export_subnet(self, subnet=None): - return self.model.get_export_subnet(subnet) - - def _prepare_for_export(self, **kwargs): - """ - Override this method to prepare module for export. This is in-place operation. - Base version does common necessary module replacements (Apex etc) - """ - PartialConv1d.forward = PartialConv1d.forward_no_cache - super()._prepare_for_export(**kwargs) - - def _export_teardown(self): - PartialConv1d.forward = PartialConv1d.forward_with_cache diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 4c566789c5636..091623327ace0 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @@ -164,10 +164,10 @@ def __init__( ): super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) self.out_dim = out_dim + self.convolutions = nn.ModuleList() if n_layers > 0: self.dropout = nn.Dropout(p=p_dropout) - self.convolutions = nn.ModuleList() for i in range(n_layers): conv_layer = ConvNorm( @@ -219,7 +219,6 @@ def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor: ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) return ret - @torch.jit.export def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) mask = mask.unsqueeze(1) @@ -234,16 +233,14 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: if lens is None: - for conv in self.convolutions: - context = self.dropout(F.relu(conv(context))) - context = context.transpose(1, 2) - context, _ = self.bilstm(context) + my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] else: - # borisf : does not match ADLR (values, lengths) - # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) - # borisf : does match ADLR - seq = self.conv_to_sequence(context, lens, enforce_sorted=False) - context, _ = self.lstm_sequence(seq) + my_lens = lens + # borisf : does not match ADLR (values, lengths) + # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) + # borisf : does match ADLR + seq = self.conv_to_sequence(context, my_lens, enforce_sorted=False) + context, _ = self.lstm_sequence(seq) if self.dense is not None: context = self.dense(context).permute(0, 2, 1) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 8de98dc4d1fc4..e4a1cc6ab7320 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -199,6 +199,8 @@ def __init__( self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias'] self.ap_pred_log_f0 = ap_pred_log_f0 self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias'] + self.prepared_for_export = False + if 'atn' in include_modules or 'dec' in include_modules: if self.learn_alignments: self.attention = ConvAttention(n_mel_channels, self.n_speaker_dim, n_text_dim) @@ -613,7 +615,7 @@ def infer( spk_vec_text = self.encode_speaker(speaker_id_text) spk_vec_attributes = self.encode_speaker(speaker_id_attributes) txt_enc, _ = self.encode_text(text, in_lens) - print("txt_enc: ", txt_enc.shape) + # print("txt_enc: ", txt_enc.shape) if dur is None: # get token durations @@ -784,14 +786,23 @@ def _prepare_for_export(self, **kwargs): PartialConv1d.forward = PartialConv1d.forward_no_cache self.remove_norms() super()._prepare_for_export(**kwargs) + if self.prepared_for_export: + return self.encoder = torch.jit.script(self.encoder) self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) - self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) - self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) - self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) + if hasattr(self, 'f0_pred_module'): + self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) + if hasattr(self, 'energy_pred_module'): + self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) + if hasattr(self, 'dur_pred_layer'): + self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) if self.use_context_lstm: self.context_lstm = torch.jit.script(self.context_lstm) + self.prepared_for_export = True + + def _export_teardown(self): + PartialConv1d.forward = PartialConv1d.forward_with_cache def input_example(self, max_batch=1, max_dim=256): """ @@ -802,7 +813,7 @@ def input_example(self, max_batch=1, max_dim=256): par = next(self.parameters()) sz = (max_batch, max_dim) inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64) - lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int) + lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int) speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64) inputs = { 'text': inp, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index b7574ed9ddf42..e293a3d5fb3a8 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -30,15 +30,18 @@ def __init__(self, *args, **kwargs): super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - slide_winsize = torch.tensor(self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]) + slide_winsize = torch.tensor( + self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2], requires_grad=False + ) self.register_buffer("slide_winsize", slide_winsize, persistent=False) - if self.bias is not None: - bias_view = self.bias.view(1, self.out_channels, 1) - self.register_buffer('bias_view', bias_view, persistent=False) # caching part self.last_size = (-1, -1, -1) + if self.bias is not None: + bias_view = self.bias.clone().detach().reshape(1, self.out_channels, 1) + self.register_buffer("bias_view", bias_view, persistent=False) + update_mask = torch.ones(1, 1, 1) self.register_buffer('update_mask', update_mask, persistent=False) mask_ratio = torch.ones(1, 1, 1) @@ -51,6 +54,8 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) else: mask = mask_in + input = torch.mul(input, mask) + update_mask = F.conv1d( mask, self.weight_maskUpdater, @@ -60,11 +65,13 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): dilation=self.dilation, groups=1, ) - # for mixed precision training, change 1e-8 to 1e-6 - mask_ratio = self.slide_winsize / (update_mask + 1e-6) + + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / (update_mask_filled) + update_mask = torch.clamp(update_mask, 0, 1) - mask_ratio = torch.mul(mask_ratio.to(update_mask), update_mask) - return torch.mul(input, mask), mask_ratio, update_mask + mask_ratio = torch.mul(mask_ratio, update_mask) + return input, mask_ratio, update_mask def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor: assert len(input.shape) == 3 @@ -72,7 +79,11 @@ def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask raw_out = self._conv_forward(input, self.weight, self.bias) if self.bias is not None: - output = torch.mul(raw_out - self.bias_view, mask_ratio) + self.bias_view + if torch.jit.is_scripting(): + bias_view = self.bias_view + else: + bias_view = self.bias.view(1, self.out_channels, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view output = torch.mul(output, update_mask) else: output = torch.mul(raw_out, mask_ratio) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5a9ab55a4ee7d..9204eb357b783 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -167,14 +167,6 @@ def _export( jitted_model.save(output) assert os.path.exists(output) - if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace - - verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) - elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: diff --git a/scripts/export.py b/scripts/export.py index a9108a6a7436e..a8c8ad6605de5 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -142,10 +142,11 @@ def nemo_export(argv): input_example = model.input_module.input_example(**in_args) if check_trace and len(in_args) > 0: check_trace = [input_example] - for key, arg in in_args: + for key, arg in in_args.items(): in_args[key] = (arg + 1) // 2 input_example2 = model.input_module.input_example(**in_args) check_trace.append(input_example2) + logging.info(f"Using additional check args: {in_args}") _, descriptions = model.export( out, diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 3c3f13a028a64..f39d4f0bb7793 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -15,8 +15,10 @@ import tempfile import pytest +from omegaconf import DictConfig, OmegaConf -from nemo.collections.tts.models import FastPitchModel, HifiGanModel +from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel +from nemo.utils.app_state import AppState @pytest.fixture() @@ -31,6 +33,24 @@ def hifigan_model(): return model +@pytest.fixture() +def radtts_model(): + this_test_dir = os.path.dirname(os.path.abspath(__file__)) + + cfg = OmegaConf.load(os.path.join(this_test_dir, '../../../examples/tts/conf/rad-tts_feature_pred.yaml')) + cfg.model.init_from_ptl_ckpt = None + cfg.model.train_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.train_ds.dataset.sup_data_path = "dummy.json" + cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.validation_ds.dataset.sup_data_path = "dummy.json" + app_state = AppState() + app_state.is_model_being_restored = True + model = RadTTSModel(cfg=cfg.model) + app_state.is_model_being_restored = False + model.eval() + return model + + class TestExportable: @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -50,7 +70,15 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model): filename = os.path.join(tmpdir, 'hfg.pt') model.export(output=filename, verbose=True, check_trace=True) + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + def test_RadTTSModel_export_to_torchscript(self, radtts_model): + model = radtts_model.cuda() + with tempfile.TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, 'rad.ts') + model.export(output=filename, verbose=True, check_trace=True) + if __name__ == "__main__": t = TestExportable() - t.test_FastPitchModel_export_to_onnx(fastpitch_model()) + t.test_RadTTSModel_export_to_torchscript(radtts_model())