From 73f5be2abc02e009269715b24979173fabc8f9f2 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 8 Nov 2022 12:59:27 -0800 Subject: [PATCH 01/43] Fixing RADTTS training - removing view buffer and fixing accuracy issue Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 47 +++++++++++++--------- nemo/collections/tts/modules/radtts.py | 10 ++--- nemo/collections/tts/modules/submodules.py | 22 +++++----- 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 4c566789c563..938739f58214 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @@ -164,10 +164,10 @@ def __init__( ): super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) self.out_dim = out_dim + self.convolutions = nn.ModuleList() if n_layers > 0: self.dropout = nn.Dropout(p=p_dropout) - self.convolutions = nn.ModuleList() for i in range(n_layers): conv_layer = ConvNorm( @@ -219,7 +219,7 @@ def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor: ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) return ret - @torch.jit.export + @torch.jit.ignore def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) mask = mask.unsqueeze(1) @@ -232,24 +232,35 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: ) return seq - def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: - if lens is None: - for conv in self.convolutions: - context = self.dropout(F.relu(conv(context))) - context = context.transpose(1, 2) - context, _ = self.bilstm(context) - else: - # borisf : does not match ADLR (values, lengths) - # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) - # borisf : does match ADLR - seq = self.conv_to_sequence(context, lens, enforce_sorted=False) - context, _ = self.lstm_sequence(seq) + def forward(self, context: Tensor, lens: Tensor) -> Tensor: + # borisf : does not match ADLR (values, lengths) + # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) + # borisf : does match ADLR + seq = self.conv_to_sequence(context, lens, enforce_sorted=False) + context, _ = self.lstm_sequence(seq) if self.dense is not None: context = self.dense(context).permute(0, 2, 1) return context + def script(self): + traced = nn.ModuleList() + for conv in self.convolutions: + if hasattr(conv, 'conv'): + w = conv.conv.weight + else: + print('------\n', self) + print(conv) + w = conv.weight + s = w.shape + rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device) + # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device)) + traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) + print('============\n') + self.convolutions = traced + return torch.jit.script(self) + def getRadTTSEncoder( encoder_n_convolutions=3, diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 8de98dc4d1fc..c977aaeccb19 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -784,11 +784,11 @@ def _prepare_for_export(self, **kwargs): PartialConv1d.forward = PartialConv1d.forward_no_cache self.remove_norms() super()._prepare_for_export(**kwargs) - self.encoder = torch.jit.script(self.encoder) - self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) - self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) - self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) - self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) + self.encoder = self.encoder.script() + self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script() + self.f0_pred_module.feat_pred_fn = self.f0_pred_module.feat_pred_fn.script() + self.energy_pred_module.feat_pred_fn = self.energy_pred_module.feat_pred_fn.script() + self.dur_pred_layer.feat_pred_fn = self.dur_pred_layer.feat_pred_fn.script() if self.use_context_lstm: self.context_lstm = torch.jit.script(self.context_lstm) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index b7574ed9ddf4..f3a33d57937a 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -30,12 +30,11 @@ def __init__(self, *args, **kwargs): super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - slide_winsize = torch.tensor(self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]) + slide_winsize = torch.tensor( + self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2], requires_grad=False + ) self.register_buffer("slide_winsize", slide_winsize, persistent=False) - if self.bias is not None: - bias_view = self.bias.view(1, self.out_channels, 1) - self.register_buffer('bias_view', bias_view, persistent=False) # caching part self.last_size = (-1, -1, -1) @@ -51,6 +50,8 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) else: mask = mask_in + input = torch.mul(input, mask) + update_mask = F.conv1d( mask, self.weight_maskUpdater, @@ -60,11 +61,13 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): dilation=self.dilation, groups=1, ) - # for mixed precision training, change 1e-8 to 1e-6 - mask_ratio = self.slide_winsize / (update_mask + 1e-6) + + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / (update_mask_filled) + update_mask = torch.clamp(update_mask, 0, 1) - mask_ratio = torch.mul(mask_ratio.to(update_mask), update_mask) - return torch.mul(input, mask), mask_ratio, update_mask + mask_ratio = torch.mul(mask_ratio, update_mask) + return input, mask_ratio, update_mask def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor: assert len(input.shape) == 3 @@ -72,7 +75,8 @@ def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask raw_out = self._conv_forward(input, self.weight, self.bias) if self.bias is not None: - output = torch.mul(raw_out - self.bias_view, mask_ratio) + self.bias_view + bias_view = self.bias.view(1, self.out_channels, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view output = torch.mul(output, update_mask) else: output = torch.mul(raw_out, mask_ratio) From 44a735c2070a4b2921ac46545fa1267778806b7c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 9 Nov 2022 14:17:44 -0800 Subject: [PATCH 02/43] Addressing code review Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 938739f58214..7419a7b86632 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -219,7 +219,6 @@ def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor: ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) return ret - @torch.jit.ignore def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) mask = mask.unsqueeze(1) @@ -232,7 +231,11 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: ) return seq - def forward(self, context: Tensor, lens: Tensor) -> Tensor: + def forward(self, context: Tensor, in_lens: Optional[Tensor]) -> Tensor: + if in_lens is None: + lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] + else: + lens = in_lens # borisf : does not match ADLR (values, lengths) # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) # borisf : does match ADLR @@ -250,14 +253,11 @@ def script(self): if hasattr(conv, 'conv'): w = conv.conv.weight else: - print('------\n', self) - print(conv) w = conv.weight s = w.shape rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device) # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device)) traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) - print('============\n') self.convolutions = traced return torch.jit.script(self) From f6d2f2cdaff1fae5a5cc6cfd060537fb254f409a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 9 Nov 2022 14:57:48 -0800 Subject: [PATCH 03/43] Addressing code review 2 Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 7419a7b86632..4189faf47ac2 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -231,15 +231,15 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: ) return seq - def forward(self, context: Tensor, in_lens: Optional[Tensor]) -> Tensor: - if in_lens is None: - lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] + def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: + if lens is None: + my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] else: - lens = in_lens + my_lens = in_lens # borisf : does not match ADLR (values, lengths) # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) # borisf : does match ADLR - seq = self.conv_to_sequence(context, lens, enforce_sorted=False) + seq = self.conv_to_sequence(context, my_lens, enforce_sorted=False) context, _ = self.lstm_sequence(seq) if self.dense is not None: From 835dd4d5896215af601e1d4bdb400dde3307e49a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 9 Nov 2022 16:12:46 -0800 Subject: [PATCH 04/43] Fixed assignment Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 4189faf47ac2..3e569d0f8f14 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -235,7 +235,7 @@ def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: if lens is None: my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] else: - my_lens = in_lens + my_lens = lens # borisf : does not match ADLR (values, lengths) # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) # borisf : does match ADLR From da91270ceaf4ef8947be812eaadc0eadf9416285 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 13:30:03 -0800 Subject: [PATCH 05/43] Working script Signed-off-by: Boris Fomitchev --- nemo/collections/tts/models/radtts.py | 14 -------- nemo/collections/tts/modules/common.py | 14 -------- nemo/collections/tts/modules/radtts.py | 25 +++++++++++---- nemo/collections/tts/modules/submodules.py | 9 +++++- nemo/core/classes/exportable.py | 8 ----- scripts/export.py | 3 +- tests/collections/tts/test_tts_exportables.py | 32 +++++++++++++++++-- 7 files changed, 58 insertions(+), 47 deletions(-) diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 47251b4a3f61..73e2adb5f8bb 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -407,17 +407,3 @@ def output_module(self): def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes): return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes) - - def get_export_subnet(self, subnet=None): - return self.model.get_export_subnet(subnet) - - def _prepare_for_export(self, **kwargs): - """ - Override this method to prepare module for export. This is in-place operation. - Base version does common necessary module replacements (Apex etc) - """ - PartialConv1d.forward = PartialConv1d.forward_no_cache - super()._prepare_for_export(**kwargs) - - def _export_teardown(self): - PartialConv1d.forward = PartialConv1d.forward_with_cache diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 3e569d0f8f14..091623327ace 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -247,20 +247,6 @@ def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: return context - def script(self): - traced = nn.ModuleList() - for conv in self.convolutions: - if hasattr(conv, 'conv'): - w = conv.conv.weight - else: - w = conv.weight - s = w.shape - rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device) - # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device)) - traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) - self.convolutions = traced - return torch.jit.script(self) - def getRadTTSEncoder( encoder_n_convolutions=3, diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index c977aaeccb19..e4a1cc6ab732 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -199,6 +199,8 @@ def __init__( self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias'] self.ap_pred_log_f0 = ap_pred_log_f0 self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias'] + self.prepared_for_export = False + if 'atn' in include_modules or 'dec' in include_modules: if self.learn_alignments: self.attention = ConvAttention(n_mel_channels, self.n_speaker_dim, n_text_dim) @@ -613,7 +615,7 @@ def infer( spk_vec_text = self.encode_speaker(speaker_id_text) spk_vec_attributes = self.encode_speaker(speaker_id_attributes) txt_enc, _ = self.encode_text(text, in_lens) - print("txt_enc: ", txt_enc.shape) + # print("txt_enc: ", txt_enc.shape) if dur is None: # get token durations @@ -784,14 +786,23 @@ def _prepare_for_export(self, **kwargs): PartialConv1d.forward = PartialConv1d.forward_no_cache self.remove_norms() super()._prepare_for_export(**kwargs) - self.encoder = self.encoder.script() - self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script() - self.f0_pred_module.feat_pred_fn = self.f0_pred_module.feat_pred_fn.script() - self.energy_pred_module.feat_pred_fn = self.energy_pred_module.feat_pred_fn.script() - self.dur_pred_layer.feat_pred_fn = self.dur_pred_layer.feat_pred_fn.script() + if self.prepared_for_export: + return + self.encoder = torch.jit.script(self.encoder) + self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) + if hasattr(self, 'f0_pred_module'): + self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) + if hasattr(self, 'energy_pred_module'): + self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) + if hasattr(self, 'dur_pred_layer'): + self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) if self.use_context_lstm: self.context_lstm = torch.jit.script(self.context_lstm) + self.prepared_for_export = True + + def _export_teardown(self): + PartialConv1d.forward = PartialConv1d.forward_with_cache def input_example(self, max_batch=1, max_dim=256): """ @@ -802,7 +813,7 @@ def input_example(self, max_batch=1, max_dim=256): par = next(self.parameters()) sz = (max_batch, max_dim) inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64) - lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int) + lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int) speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64) inputs = { 'text': inp, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index f3a33d57937a..e293a3d5fb3a 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -38,6 +38,10 @@ def __init__(self, *args, **kwargs): # caching part self.last_size = (-1, -1, -1) + if self.bias is not None: + bias_view = self.bias.clone().detach().reshape(1, self.out_channels, 1) + self.register_buffer("bias_view", bias_view, persistent=False) + update_mask = torch.ones(1, 1, 1) self.register_buffer('update_mask', update_mask, persistent=False) mask_ratio = torch.ones(1, 1, 1) @@ -75,7 +79,10 @@ def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask raw_out = self._conv_forward(input, self.weight, self.bias) if self.bias is not None: - bias_view = self.bias.view(1, self.out_channels, 1) + if torch.jit.is_scripting(): + bias_view = self.bias_view + else: + bias_view = self.bias.view(1, self.out_channels, 1) output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view output = torch.mul(output, update_mask) else: diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5a9ab55a4ee7..9204eb357b78 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -167,14 +167,6 @@ def _export( jitted_model.save(output) assert os.path.exists(output) - if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace - - verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) - elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: diff --git a/scripts/export.py b/scripts/export.py index a9108a6a7436..a8c8ad6605de 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -142,10 +142,11 @@ def nemo_export(argv): input_example = model.input_module.input_example(**in_args) if check_trace and len(in_args) > 0: check_trace = [input_example] - for key, arg in in_args: + for key, arg in in_args.items(): in_args[key] = (arg + 1) // 2 input_example2 = model.input_module.input_example(**in_args) check_trace.append(input_example2) + logging.info(f"Using additional check args: {in_args}") _, descriptions = model.export( out, diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 3c3f13a028a6..f39d4f0bb779 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -15,8 +15,10 @@ import tempfile import pytest +from omegaconf import DictConfig, OmegaConf -from nemo.collections.tts.models import FastPitchModel, HifiGanModel +from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel +from nemo.utils.app_state import AppState @pytest.fixture() @@ -31,6 +33,24 @@ def hifigan_model(): return model +@pytest.fixture() +def radtts_model(): + this_test_dir = os.path.dirname(os.path.abspath(__file__)) + + cfg = OmegaConf.load(os.path.join(this_test_dir, '../../../examples/tts/conf/rad-tts_feature_pred.yaml')) + cfg.model.init_from_ptl_ckpt = None + cfg.model.train_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.train_ds.dataset.sup_data_path = "dummy.json" + cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.validation_ds.dataset.sup_data_path = "dummy.json" + app_state = AppState() + app_state.is_model_being_restored = True + model = RadTTSModel(cfg=cfg.model) + app_state.is_model_being_restored = False + model.eval() + return model + + class TestExportable: @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -50,7 +70,15 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model): filename = os.path.join(tmpdir, 'hfg.pt') model.export(output=filename, verbose=True, check_trace=True) + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + def test_RadTTSModel_export_to_torchscript(self, radtts_model): + model = radtts_model.cuda() + with tempfile.TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, 'rad.ts') + model.export(output=filename, verbose=True, check_trace=True) + if __name__ == "__main__": t = TestExportable() - t.test_FastPitchModel_export_to_onnx(fastpitch_model()) + t.test_RadTTSModel_export_to_torchscript(radtts_model()) From d48393c1c55b0f88c09ca03c29e95e32a0bbb3d1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 13:56:47 -0800 Subject: [PATCH 06/43] restored flatten_parameters Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 091623327ace..0613a5343dc0 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) From 378c3da0d5e81c5768bc79192b0272dd25e889d6 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 17:35:44 -0800 Subject: [PATCH 07/43] Working bias alias for export Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/radtts.py | 17 ++++---- nemo/collections/tts/modules/submodules.py | 49 ++++++++++++++-------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index e4a1cc6ab732..a30de902160b 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -669,9 +669,8 @@ def infer( # replication pad, because ungrouping with different group sizes # may lead to mismatched lengths # FIXME: use replication pad - print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) (energy_avg, f0) = pad_energy_avg_and_f0(energy_avg, f0, max_out_len) - print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) + # print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) context_w_spkvec = self.preprocess_context( txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask + f0_bias, energy_avg @@ -783,11 +782,16 @@ def output_types(self): # Methods for model exportability def _prepare_for_export(self, **kwargs): - PartialConv1d.forward = PartialConv1d.forward_no_cache + print(kwargs) + PartialConv1d.forward = PartialConv1d.forward_for_script self.remove_norms() super()._prepare_for_export(**kwargs) if self.prepared_for_export: return + for m in self.modules(): + if isinstance(m, PartialConv1d): + PartialConv1d.update_bias_view(m) + self.encoder = torch.jit.script(self.encoder) self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) if hasattr(self, 'f0_pred_module'): @@ -801,9 +805,6 @@ def _prepare_for_export(self, **kwargs): self.context_lstm = torch.jit.script(self.context_lstm) self.prepared_for_export = True - def _export_teardown(self): - PartialConv1d.forward = PartialConv1d.forward_with_cache - def input_example(self, max_batch=1, max_dim=256): """ Generates input examples for tracing etc. @@ -833,8 +834,8 @@ def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id text, speaker_id_text=speaker_id_text, speaker_id_attributes=speaker_id_attributes, - sigma=0.0, - sigma_txt=0.0, + sigma=0.7, + sigma_txt=0.7, sigma_f0=1.0, sigma_energy=1.0, f0_mean=145.0, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index e293a3d5fb3a..b43a9ad32dbe 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -30,24 +30,27 @@ def __init__(self, *args, **kwargs): super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - slide_winsize = torch.tensor( - self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2], requires_grad=False - ) - self.register_buffer("slide_winsize", slide_winsize, persistent=False) + self.slide_winsize = self.kernel_size[0] # caching part self.last_size = (-1, -1, -1) - if self.bias is not None: - bias_view = self.bias.clone().detach().reshape(1, self.out_channels, 1) - self.register_buffer("bias_view", bias_view, persistent=False) - update_mask = torch.ones(1, 1, 1) self.register_buffer('update_mask', update_mask, persistent=False) mask_ratio = torch.ones(1, 1, 1) self.register_buffer('mask_ratio', mask_ratio, persistent=False) self.partial: bool = True + def update_bias_view(self): + """ + To be called before jit.script(), to set up an alias for self.bias + to work around TorchScript bug resolving Optional[Tensor] + """ + if self.bias is not None: + self.bias_view = self.bias.view(1, self.out_channels, 1) + else: + self.bias_view = None + def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): with torch.no_grad(): if mask_in is None: @@ -73,16 +76,29 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): mask_ratio = torch.mul(mask_ratio, update_mask) return input, mask_ratio, update_mask + def forward_aux_script( + self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor + ) -> torch.Tensor: + assert len(input.shape) == 3 + + raw_out = self._conv_forward(input, self.weight, self.bias) + + if self.bias is not None: + bias_view = self.bias_view + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view + output = torch.mul(output, update_mask) + else: + output = torch.mul(raw_out, mask_ratio) + + return output + def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor: assert len(input.shape) == 3 raw_out = self._conv_forward(input, self.weight, self.bias) if self.bias is not None: - if torch.jit.is_scripting(): - bias_view = self.bias_view - else: - bias_view = self.bias.view(1, self.out_channels, 1) + bias_view = self.bias.view(1, self.out_channels, 1) output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view output = torch.mul(output, update_mask) else: @@ -90,9 +106,8 @@ def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask return output - @torch.jit.ignore def forward_with_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - use_cache = not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) + use_cache = self.partial and not torch.jit.is_tracing() cache_hit = use_cache and mask_in is None and self.last_size == input.shape if cache_hit: mask_ratio = self.mask_ratio @@ -101,15 +116,15 @@ def forward_with_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) if use_cache: # if a mask is input, or tensor shape changed, update mask ratio - self.last_size = tuple(input.shape) + self.last_size = (input.shape[0], input.shape[1], input.shape[2]) self.update_mask = update_mask self.mask_ratio = mask_ratio return self.forward_aux(input, mask_ratio, update_mask) - def forward_no_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward_for_script(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: if self.partial: input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) - return self.forward_aux(input, mask_ratio, update_mask) + return self.forward_aux_script(input, mask_ratio, update_mask) else: if mask_in is not None: input = torch.mul(input, mask_in) From 4007c2a0b34b69c7fb310cde9dfcaf7c9db67a4f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 17:39:18 -0800 Subject: [PATCH 08/43] Removing unused import Signed-off-by: Boris Fomitchev --- tests/collections/tts/test_tts_exportables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index f39d4f0bb779..4b64b5096bc3 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -15,7 +15,7 @@ import tempfile import pytest -from omegaconf import DictConfig, OmegaConf +from omegaconf import OmegaConf from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel from nemo.utils.app_state import AppState From 45f2d1a3ea6e64ec841d107a7f1374d7d89e74c9 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 18:33:13 -0800 Subject: [PATCH 09/43] Reverting PartialConv Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 1 - nemo/collections/tts/modules/radtts.py | 8 +- nemo/collections/tts/modules/submodules.py | 179 ++++++++------------- 3 files changed, 68 insertions(+), 120 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 0613a5343dc0..83ad6de6489e 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -180,7 +180,6 @@ def __init__( w_init_gain='relu', use_weight_norm=False, use_partial_padding=use_partial_padding, - norm_fn=norm_fn, ) if norm_fn is not None: print("Applying {} norm to {}".format(norm_fn, conv_layer)) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index a30de902160b..c84d4979bea6 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -783,14 +783,14 @@ def output_types(self): # Methods for model exportability def _prepare_for_export(self, **kwargs): print(kwargs) - PartialConv1d.forward = PartialConv1d.forward_for_script + # PartialConv1d.forward = PartialConv1d.forward_for_script self.remove_norms() super()._prepare_for_export(**kwargs) if self.prepared_for_export: return - for m in self.modules(): - if isinstance(m, PartialConv1d): - PartialConv1d.update_bias_view(m) + # for m in self.modules(): + # if isinstance(m, PartialConv1d): + # PartialConv1d.update_bias_view(m) self.encoder = torch.jit.script(self.encoder) self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index b43a9ad32dbe..03e21013b5bb 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Tuple import torch from torch.autograd import Variable @@ -27,116 +27,65 @@ class PartialConv1d(torch.nn.Conv1d): """ def __init__(self, *args, **kwargs): - super(PartialConv1d, self).__init__(*args, **kwargs) - weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) - self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - self.slide_winsize = self.kernel_size[0] - - # caching part - self.last_size = (-1, -1, -1) - - update_mask = torch.ones(1, 1, 1) - self.register_buffer('update_mask', update_mask, persistent=False) - mask_ratio = torch.ones(1, 1, 1) - self.register_buffer('mask_ratio', mask_ratio, persistent=False) - self.partial: bool = True - - def update_bias_view(self): - """ - To be called before jit.script(), to set up an alias for self.bias - to work around TorchScript bug resolving Optional[Tensor] - """ - if self.bias is not None: - self.bias_view = self.bias.view(1, self.out_channels, 1) - else: - self.bias_view = None - def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): - with torch.no_grad(): - if mask_in is None: - mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) - else: - mask = mask_in - input = torch.mul(input, mask) - - update_mask = F.conv1d( - mask, - self.weight_maskUpdater, - bias=None, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=1, - ) + self.multi_channel = False + self.return_mask = False + super(PartialConv1d, self).__init__(*args, **kwargs) - update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) - mask_ratio = self.slide_winsize / (update_mask_filled) + self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) + self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] - update_mask = torch.clamp(update_mask, 0, 1) - mask_ratio = torch.mul(mask_ratio, update_mask) - return input, mask_ratio, update_mask + self.last_size = (None, None, None) + self.update_mask = None + self.mask_ratio = None - def forward_aux_script( - self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, input: torch.Tensor, mask_in: Tuple[int, int, int] = None): assert len(input.shape) == 3 - - raw_out = self._conv_forward(input, self.weight, self.bias) - - if self.bias is not None: - bias_view = self.bias_view - output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view - output = torch.mul(output, update_mask) - else: - output = torch.mul(raw_out, mask_ratio) - - return output - - def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor: - assert len(input.shape) == 3 - - raw_out = self._conv_forward(input, self.weight, self.bias) - - if self.bias is not None: - bias_view = self.bias.view(1, self.out_channels, 1) - output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view - output = torch.mul(output, update_mask) - else: - output = torch.mul(raw_out, mask_ratio) - - return output - - def forward_with_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - use_cache = self.partial and not torch.jit.is_tracing() - cache_hit = use_cache and mask_in is None and self.last_size == input.shape + # borisf: disabled cache for export + use_cache = not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) + cache_hit = use_cache and mask_in is None and self.last_size == tuple(input.shape) if cache_hit: mask_ratio = self.mask_ratio update_mask = self.update_mask + # if a mask is input, or tensor shape changed, update mask ratio else: - input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) + with torch.no_grad(): + if self.weight_maskUpdater.type() != input.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(input) + if mask_in is None: + mask = torch.ones(1, 1, input.shape[2]).to(input) + else: + mask = mask_in + update_mask = F.conv1d( + mask, + self.weight_maskUpdater, + bias=None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=1, + ) + # for mixed precision training, change 1e-8 to 1e-6 + mask_ratio = self.slide_winsize / (update_mask + 1e-6) + update_mask = torch.clamp(update_mask, 0, 1) + mask_ratio = torch.mul(mask_ratio, update_mask) if use_cache: - # if a mask is input, or tensor shape changed, update mask ratio - self.last_size = (input.shape[0], input.shape[1], input.shape[2]) + self.last_size = tuple(input.shape) self.update_mask = update_mask self.mask_ratio = mask_ratio - return self.forward_aux(input, mask_ratio, update_mask) - def forward_for_script(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.partial: - input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) - return self.forward_aux_script(input, mask_ratio, update_mask) + raw_out = super(PartialConv1d, self).forward(torch.mul(input, mask) if mask_in is not None else input) + if self.bias is not None: + bias_view = self.bias.view(1, self.out_channels, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view + output = torch.mul(output, update_mask) else: - if mask_in is not None: - input = torch.mul(input, mask_in) - return self._conv_forward(input, self.weight, self.bias) + output = torch.mul(raw_out, mask_ratio) - def forward(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.partial: - return self.forward_with_cache(input, mask_in) + if self.return_mask: + return output, update_mask else: - if mask_in is not None: - input = torch.mul(input, mask_in) - return self._conv_forward(input, self.weight, self.bias) + return output class LinearNorm(torch.nn.Module): @@ -161,16 +110,21 @@ def __init__( dilation=1, bias=True, w_init_gain='linear', - use_partial_padding: bool = False, - use_weight_norm: bool = False, - norm_fn=None, + use_partial_padding=False, + use_weight_norm=False, ): super(ConvNorm, self).__init__() if padding is None: assert kernel_size % 2 == 1 padding = int(dilation * (kernel_size - 1) / 2) - self.use_partial_padding: bool = use_partial_padding - conv = PartialConv1d( + self.kernel_size = kernel_size + self.dilation = dilation + self.use_partial_padding = use_partial_padding + self.use_weight_norm = use_weight_norm + conv_fn = torch.nn.Conv1d + if self.use_partial_padding: + conv_fn = PartialConv1d + self.conv = conv_fn( in_channels, out_channels, kernel_size=kernel_size, @@ -179,21 +133,16 @@ def __init__( dilation=dilation, bias=bias, ) - conv.partial = use_partial_padding - torch.nn.init.xavier_uniform_(conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - if use_weight_norm: - conv = torch.nn.utils.weight_norm(conv) - if norm_fn is not None: - self.norm = norm_fn(out_channels, affine=True) - else: - self.norm = None - self.conv = conv + torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + if self.use_weight_norm: + self.conv = torch.nn.utils.weight_norm(self.conv) - def forward(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - ret = self.conv(input, mask_in) - if self.norm is not None: - ret = self.norm(ret) - return ret + def forward(self, signal, mask=None): + if self.use_partial_padding: + conv_signal = self.conv(signal, mask) + else: + conv_signal = self.conv(signal) + return conv_signal class LocationLayer(torch.nn.Module): @@ -322,7 +271,7 @@ def __init__(self, c): self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False) # Sample a random orthonormal matrix to initialize weights - W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0] + W = torch.qr(torch.FloatTensor(c, c).normal_())[0] # Ensure determinant is 1.0 not -1.0 if torch.det(W) < 0: From 406a0bacef571e8c0e4703acb6c4e66be0c53c53 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 19:03:01 -0800 Subject: [PATCH 10/43] Removing flatten_parameters Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 83ad6de6489e..afb5ba847589 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not torch.jit.is_scripting(): - self.bilstm.flatten_parameters() + # if not torch.jit.is_scripting(): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) From 7bbc05aa18d4c1147758fdb498dfb43f387db5fa Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 20:41:59 -0800 Subject: [PATCH 11/43] Moving mask updater to GPU Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 14 ++++++++++++++ nemo/collections/tts/modules/radtts.py | 10 +++++----- nemo/collections/tts/modules/submodules.py | 10 +++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index afb5ba847589..a65df8be2547 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -246,6 +246,20 @@ def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: return context + def script(self): + traced = nn.ModuleList() + for conv in self.convolutions: + if hasattr(conv, 'conv'): + w = conv.conv.weight + else: + w = conv.weight + s = w.shape + rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device) + # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device)) + traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) + self.convolutions = traced + return torch.jit.script(self) + def getRadTTSEncoder( encoder_n_convolutions=3, diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index c84d4979bea6..756c65f05c1e 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -792,14 +792,14 @@ def _prepare_for_export(self, **kwargs): # if isinstance(m, PartialConv1d): # PartialConv1d.update_bias_view(m) - self.encoder = torch.jit.script(self.encoder) - self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) + self.encoder = self.encoder.script() + self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script() if hasattr(self, 'f0_pred_module'): - self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) + self.f0_pred_module.feat_pred_fn = self.f0_pred_module.feat_pred_fn.script() if hasattr(self, 'energy_pred_module'): - self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) + self.energy_pred_module.feat_pred_fn = self.energy_pred_module.feat_pred_fn.script() if hasattr(self, 'dur_pred_layer'): - self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) + self.dur_pred_layer.feat_pred_fn = self.dur_pred_layer.feat_pred_fn.script() if self.use_context_lstm: self.context_lstm = torch.jit.script(self.context_lstm) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 03e21013b5bb..70e1b030e4e9 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -31,15 +31,17 @@ def __init__(self, *args, **kwargs): self.multi_channel = False self.return_mask = False super(PartialConv1d, self).__init__(*args, **kwargs) + weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) + self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) + # self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] self.last_size = (None, None, None) self.update_mask = None self.mask_ratio = None - def forward(self, input: torch.Tensor, mask_in: Tuple[int, int, int] = None): + def forward(self, input, mask_in=None): assert len(input.shape) == 3 # borisf: disabled cache for export use_cache = not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) @@ -50,10 +52,8 @@ def forward(self, input: torch.Tensor, mask_in: Tuple[int, int, int] = None): # if a mask is input, or tensor shape changed, update mask ratio else: with torch.no_grad(): - if self.weight_maskUpdater.type() != input.type(): - self.weight_maskUpdater = self.weight_maskUpdater.to(input) if mask_in is None: - mask = torch.ones(1, 1, input.shape[2]).to(input) + mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) else: mask = mask_in update_mask = F.conv1d( From 7ce38817d4860339ec025937b37107b9bc61fafb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 23:03:32 -0800 Subject: [PATCH 12/43] Restored norms Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 1 + nemo/collections/tts/modules/radtts.py | 4 ++-- nemo/collections/tts/modules/submodules.py | 13 ++++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index a65df8be2547..3e569d0f8f14 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -180,6 +180,7 @@ def __init__( w_init_gain='relu', use_weight_norm=False, use_partial_padding=use_partial_padding, + norm_fn=norm_fn, ) if norm_fn is not None: print("Applying {} norm to {}".format(norm_fn, conv_layer)) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 756c65f05c1e..8fcd20427fba 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -834,8 +834,8 @@ def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id text, speaker_id_text=speaker_id_text, speaker_id_attributes=speaker_id_attributes, - sigma=0.7, - sigma_txt=0.7, + sigma=0.0, + sigma_txt=0.0, sigma_f0=1.0, sigma_energy=1.0, f0_mean=145.0, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 70e1b030e4e9..311046729d1d 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -112,6 +112,7 @@ def __init__( w_init_gain='linear', use_partial_padding=False, use_weight_norm=False, + norm_fn=None, ): super(ConvNorm, self).__init__() if padding is None: @@ -136,13 +137,19 @@ def __init__( torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) if self.use_weight_norm: self.conv = torch.nn.utils.weight_norm(self.conv) + if norm_fn is not None: + self.norm = norm_fn(out_channels, affine=True) + else: + self.norm = None def forward(self, signal, mask=None): if self.use_partial_padding: - conv_signal = self.conv(signal, mask) + ret = self.conv(signal, mask) else: - conv_signal = self.conv(signal) - return conv_signal + ret = self.conv(signal) + if self.norm is not None: + ret = self.norm(ret) + return ret class LocationLayer(torch.nn.Module): From 864f9cd6e40dcbb91c8749c31914ba1ba7a66729 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 23:15:48 -0800 Subject: [PATCH 13/43] Restored flatten Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 3e569d0f8f14..262b1a7e5e1b 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) From 9c91f92c6fccf707755a8d0fff91f13e379079e7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 23:49:32 -0800 Subject: [PATCH 14/43] Moved to sort/unsort Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 262b1a7e5e1b..504becc9cf64 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not (torch.jit.is_tracing() or torch.jit.is_scripting()): - self.bilstm.flatten_parameters() + # if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not (torch.jit.is_tracing() or torch.jit.is_scripting()): - self.bilstm.flatten_parameters() + # if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @@ -237,10 +237,11 @@ def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: else: my_lens = lens # borisf : does not match ADLR (values, lengths) - # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) - # borisf : does match ADLR - seq = self.conv_to_sequence(context, my_lens, enforce_sorted=False) + context, my_lens, unsort_ids = sort_tensor(context, my_lens) + # seq = self.masked_conv_to_sequence(context, my_lens, enforce_sorted=True) + seq = self.conv_to_sequence(context, my_lens, enforce_sorted=True) context, _ = self.lstm_sequence(seq) + context = context[unsort_ids] if self.dense is not None: context = self.dense(context).permute(0, 2, 1) From 89d47d84f6d26da3b7ac9f015f85532892fdec3c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 11 Nov 2022 17:24:22 -0800 Subject: [PATCH 15/43] Moved to masked norm Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 57 ++++------- nemo/collections/tts/modules/radtts.py | 14 +-- nemo/collections/tts/modules/submodules.py | 113 ++++++++++++++++++++- 3 files changed, 130 insertions(+), 54 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 504becc9cf64..302765edd617 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -29,7 +29,7 @@ piecewise_linear_transform, unbounded_piecewise_quadratic_transform, ) -from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm +from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d @torch.jit.script @@ -169,6 +169,8 @@ def __init__( if n_layers > 0: self.dropout = nn.Dropout(p=p_dropout) + use_weight_norm = norm_fn is None + for i in range(n_layers): conv_layer = ConvNorm( in_dim if i == 0 else n_channels, @@ -178,14 +180,14 @@ def __init__( padding=int((kernel_size - 1) / 2), dilation=1, w_init_gain='relu', - use_weight_norm=False, + use_weight_norm=use_weight_norm, use_partial_padding=use_partial_padding, norm_fn=norm_fn, ) if norm_fn is not None: print("Applying {} norm to {}".format(norm_fn, conv_layer)) else: - conv_layer = torch.nn.utils.weight_norm(conv_layer.conv) + # conv_layer = torch.nn.utils.weight_norm(conv_layer.conv) print("Applying weight norm to {}".format(conv_layer)) self.convolutions.append(conv_layer) @@ -193,32 +195,6 @@ def __init__( if out_dim is not None: self.dense = nn.Linear(n_channels, out_dim) - @torch.jit.export - def conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: - context_embedded = [] - bs: int = context.shape[0] - b_ind: int = 0 - for b_ind in range(bs): # TODO: speed up - curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() - for conv in self.convolutions: - curr_context = self.dropout(F.relu(conv(curr_context))) - context_embedded.append(curr_context[0].transpose(0, 1)) - seq = torch.nn.utils.rnn.pack_sequence(context_embedded, enforce_sorted=enforce_sorted) - return seq - - @torch.jit.export - def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor: - context_embedded = [] - bs: int = context.shape[0] - b_ind: int = 0 - for b_ind in range(bs): # TODO: speed up - curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() - for conv in self.convolutions: - curr_context = self.dropout(F.relu(conv(curr_context))) - context_embedded.append(curr_context[0].transpose(0, 1)) - ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) - return ret - def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) mask = mask.unsqueeze(1) @@ -231,15 +207,14 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: ) return seq - def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: - if lens is None: - my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] - else: - my_lens = lens - # borisf : does not match ADLR (values, lengths) + def forward(self, context: Tensor, lens: Tensor) -> Tensor: + # if lens is None: + # my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] + # else: + my_lens = lens context, my_lens, unsort_ids = sort_tensor(context, my_lens) - # seq = self.masked_conv_to_sequence(context, my_lens, enforce_sorted=True) - seq = self.conv_to_sequence(context, my_lens, enforce_sorted=True) + seq = self.masked_conv_to_sequence(context, my_lens, enforce_sorted=True) + # seq = self.conv_to_sequence(context, my_lens, enforce_sorted=True) context, _ = self.lstm_sequence(seq) context = context[unsort_ids] @@ -256,8 +231,10 @@ def script(self): else: w = conv.weight s = w.shape - rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device) - # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device)) + rand_in = ( + torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device), + torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device), + ) traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) self.convolutions = traced return torch.jit.script(self) @@ -267,7 +244,7 @@ def getRadTTSEncoder( encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, - norm_fn=nn.BatchNorm1d, + norm_fn=MaskedInstanceNorm1d, lstm_norm_fn=None, ): return ConvLSTMLinear( diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 8fcd20427fba..6f298aeaf9d7 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -183,9 +183,7 @@ def __init__( self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim) self.embedding = torch.nn.Embedding(n_text, n_text_dim) self.flows = torch.nn.ModuleList() - self.encoder = getRadTTSEncoder( - encoder_embedding_dim=n_text_dim, norm_fn=nn.InstanceNorm1d, lstm_norm_fn=text_encoder_lstm_norm - ) + self.encoder = getRadTTSEncoder(encoder_embedding_dim=n_text_dim, lstm_norm_fn=text_encoder_lstm_norm) self.dummy_speaker_embedding = dummy_speaker_embedding self.learn_alignments = learn_alignments self.affine_activation = affine_activation @@ -782,16 +780,10 @@ def output_types(self): # Methods for model exportability def _prepare_for_export(self, **kwargs): - print(kwargs) - # PartialConv1d.forward = PartialConv1d.forward_for_script self.remove_norms() super()._prepare_for_export(**kwargs) if self.prepared_for_export: return - # for m in self.modules(): - # if isinstance(m, PartialConv1d): - # PartialConv1d.update_bias_view(m) - self.encoder = self.encoder.script() self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script() if hasattr(self, 'f0_pred_module'): @@ -834,8 +826,8 @@ def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id text, speaker_id_text=speaker_id_text, speaker_id_attributes=speaker_id_attributes, - sigma=0.0, - sigma_txt=0.0, + sigma=0.7, + sigma_txt=0.7, sigma_f0=1.0, sigma_energy=1.0, f0_mean=145.0, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 311046729d1d..c193f3403a80 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -12,13 +12,116 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import torch +from torch import Tensor from torch.autograd import Variable from torch.nn import functional as F +def masked_instance_norm( + input: Tensor, + mask: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor], + bias: Optional[Tensor], + use_input_stats: bool, + momentum: float, + eps: float = 1e-5, +) -> Tensor: + r"""Applies Masked Instance Normalization for each channel in each data sample in a batch. + + See :class:`~MaskedInstanceNorm1d` for details. + """ + if not use_input_stats and (running_mean is None or running_var is None): + raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False') + + shape = input.shape + b, c = shape[:2] + + num_dims = len(shape[2:]) + _dims = tuple(range(-num_dims, 0)) + _slice = (...,) + (None,) * num_dims + + running_mean_ = running_mean[None, :].repeat(b, 1) if running_mean is not None else None + running_var_ = running_var[None, :].repeat(b, 1) if running_mean is not None else None + + if use_input_stats: + lengths = mask.sum(_dims) + mean = (input * mask).sum(_dims) / lengths # (N, C) + var = (((input - mean[_slice]) * mask) ** 2).sum(_dims) / lengths # (N, C) + + if running_mean is not None: + running_mean_.mul_(1 - momentum).add_(momentum * mean.detach()) + running_mean.copy_(running_mean_.view(b, c).mean(0, keepdim=False)) + if running_var is not None: + running_var_.mul_(1 - momentum).add_(momentum * var.detach()) + running_var.copy_(running_var_.view(b, c).mean(0, keepdim=False)) + else: + mean, var = running_mean_.view(b, c), running_var_.view(b, c) + + out = (input - mean[_slice]) / torch.sqrt(var[_slice] + eps) # (N, C, ...) + + if weight is not None and bias is not None: + out = out * weight[None, :][_slice] + bias[None, :][_slice] + + return out + + +class MaskedInstanceNorm1d(torch.nn.InstanceNorm1d): + r"""Applies Instance Normalization over a masked 3D input + (a mini-batch of 1D inputs with additional channel dimension).. + + See documentation of :class:`~torch.nn.InstanceNorm1d` for details. + + Shape: + - Input: :math:`(N, C, L)` + - Mask: :math:`(N, 1, L)` + - Output: :math:`(N, C, L)` (same shape as input) + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + ) -> None: + super(MaskedInstanceNorm1d, self).__init__(num_features, eps, momentum, affine, track_running_stats) + + def forward(self, input: Tensor, mask: Tensor = None) -> Tensor: + self._check_input_dim(input) + if mask is not None: + self._check_input_dim(mask) + + if mask is None: + return F.instance_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum, + self.eps, + ) + else: + return masked_instance_norm( + input, + mask, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum, + self.eps, + ) + + class PartialConv1d(torch.nn.Conv1d): """ Zero padding creates a unique identifier for where the edge of the data is, such that the model can almost always identify @@ -145,10 +248,14 @@ def __init__( def forward(self, signal, mask=None): if self.use_partial_padding: ret = self.conv(signal, mask) + if self.norm is not None: + ret = self.norm(ret, mask) else: + if mask is not None: + signal = signal * mask ret = self.conv(signal) - if self.norm is not None: - ret = self.norm(ret) + if self.norm is not None: + ret = self.norm(ret) return ret From f6c6c0c256515a1410cdfc5636973416c802853b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 12 Nov 2022 01:18:43 -0800 Subject: [PATCH 16/43] Turned off cache Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index c193f3403a80..49a9df3e9311 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -147,8 +147,8 @@ def __init__(self, *args, **kwargs): def forward(self, input, mask_in=None): assert len(input.shape) == 3 # borisf: disabled cache for export - use_cache = not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) - cache_hit = use_cache and mask_in is None and self.last_size == tuple(input.shape) + use_cache = False # not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) + cache_hit = False # use_cache and mask_in is None and self.last_size == tuple(input.shape) if cache_hit: mask_ratio = self.mask_ratio update_mask = self.update_mask From 78902273e3325add848502ede3050f384b3158ee Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 12 Nov 2022 01:51:18 -0800 Subject: [PATCH 17/43] cleanup Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 49a9df3e9311..02d6520f68eb 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -131,8 +131,6 @@ class PartialConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): - self.multi_channel = False - self.return_mask = False super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) @@ -147,8 +145,8 @@ def __init__(self, *args, **kwargs): def forward(self, input, mask_in=None): assert len(input.shape) == 3 # borisf: disabled cache for export - use_cache = False # not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) - cache_hit = False # use_cache and mask_in is None and self.last_size == tuple(input.shape) + use_cache = mask_in is None and not torch.jit.is_tracing() + cache_hit = use_cache and self.last_size == tuple(input.shape) if cache_hit: mask_ratio = self.mask_ratio update_mask = self.update_mask @@ -159,6 +157,7 @@ def forward(self, input, mask_in=None): mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) else: mask = mask_in + input = torch.mul(input, mask) update_mask = F.conv1d( mask, self.weight_maskUpdater, @@ -168,8 +167,8 @@ def forward(self, input, mask_in=None): dilation=self.dilation, groups=1, ) - # for mixed precision training, change 1e-8 to 1e-6 - mask_ratio = self.slide_winsize / (update_mask + 1e-6) + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / update_mask_filled update_mask = torch.clamp(update_mask, 0, 1) mask_ratio = torch.mul(mask_ratio, update_mask) if use_cache: @@ -177,7 +176,7 @@ def forward(self, input, mask_in=None): self.update_mask = update_mask self.mask_ratio = mask_ratio - raw_out = super(PartialConv1d, self).forward(torch.mul(input, mask) if mask_in is not None else input) + raw_out = super(PartialConv1d, self).forward(input) if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1) output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view @@ -185,10 +184,7 @@ def forward(self, input, mask_in=None): else: output = torch.mul(raw_out, mask_ratio) - if self.return_mask: - return output, update_mask - else: - return output + return output class LinearNorm(torch.nn.Module): From 0c6c6fbe6fa90ada8a8650d1277d3e5d823cb521 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 12 Nov 2022 02:24:52 -0800 Subject: [PATCH 18/43] Verifying cache not used Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 02d6520f68eb..67224021e77a 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -172,6 +172,7 @@ def forward(self, input, mask_in=None): update_mask = torch.clamp(update_mask, 0, 1) mask_ratio = torch.mul(mask_ratio, update_mask) if use_cache: + print("Cache use") self.last_size = tuple(input.shape) self.update_mask = update_mask self.mask_ratio = mask_ratio From f8f9699a9d78ce63df66f3c853e5d84c65611770 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 12 Nov 2022 02:31:42 -0800 Subject: [PATCH 19/43] Removing cache Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 59 +++++++--------------- 1 file changed, 19 insertions(+), 40 deletions(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 67224021e77a..95289380037f 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -130,52 +130,31 @@ class PartialConv1d(torch.nn.Conv1d): """ def __init__(self, *args, **kwargs): - super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - - # self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] - self.last_size = (None, None, None) - self.update_mask = None - self.mask_ratio = None - def forward(self, input, mask_in=None): - assert len(input.shape) == 3 - # borisf: disabled cache for export - use_cache = mask_in is None and not torch.jit.is_tracing() - cache_hit = use_cache and self.last_size == tuple(input.shape) - if cache_hit: - mask_ratio = self.mask_ratio - update_mask = self.update_mask - # if a mask is input, or tensor shape changed, update mask ratio - else: - with torch.no_grad(): - if mask_in is None: - mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) - else: - mask = mask_in - input = torch.mul(input, mask) - update_mask = F.conv1d( - mask, - self.weight_maskUpdater, - bias=None, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=1, - ) - update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) - mask_ratio = self.slide_winsize / update_mask_filled - update_mask = torch.clamp(update_mask, 0, 1) - mask_ratio = torch.mul(mask_ratio, update_mask) - if use_cache: - print("Cache use") - self.last_size = tuple(input.shape) - self.update_mask = update_mask - self.mask_ratio = mask_ratio + with torch.no_grad(): + if mask_in is None: + mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) + else: + mask = mask_in + input = torch.mul(input, mask) + update_mask = F.conv1d( + mask, + self.weight_maskUpdater, + bias=None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=1, + ) + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / update_mask_filled + update_mask = torch.clamp(update_mask, 0, 1) + mask_ratio = torch.mul(mask_ratio, update_mask) raw_out = super(PartialConv1d, self).forward(input) if self.bias is not None: From 495f573621120b84a94fedc3ab5467bc52e42df8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 12 Nov 2022 14:51:59 -0800 Subject: [PATCH 20/43] Working autocast export Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 15 +-- nemo/collections/tts/modules/radtts.py | 11 +-- nemo/collections/tts/modules/submodules.py | 107 ++++++++++++--------- 3 files changed, 70 insertions(+), 63 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 302765edd617..5d0d9adaa665 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -163,7 +163,6 @@ def __init__( lstm_norm_fn="spectral", ): super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) - self.out_dim = out_dim self.convolutions = nn.ModuleList() if n_layers > 0: @@ -187,7 +186,6 @@ def __init__( if norm_fn is not None: print("Applying {} norm to {}".format(norm_fn, conv_layer)) else: - # conv_layer = torch.nn.utils.weight_norm(conv_layer.conv) print("Applying weight norm to {}".format(conv_layer)) self.convolutions.append(conv_layer) @@ -214,7 +212,6 @@ def forward(self, context: Tensor, lens: Tensor) -> Tensor: my_lens = lens context, my_lens, unsort_ids = sort_tensor(context, my_lens) seq = self.masked_conv_to_sequence(context, my_lens, enforce_sorted=True) - # seq = self.conv_to_sequence(context, my_lens, enforce_sorted=True) context, _ = self.lstm_sequence(seq) context = context[unsort_ids] @@ -225,19 +222,17 @@ def forward(self, context: Tensor, lens: Tensor) -> Tensor: def script(self): traced = nn.ModuleList() + return torch.jit.script(self) + for conv in self.convolutions: - if hasattr(conv, 'conv'): - w = conv.conv.weight - else: - w = conv.weight + w = conv.conv.weight s = w.shape rand_in = ( - torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device), - torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device), + torch.ones(4, s[1], s[2], dtype=w.dtype, device=w.device), + torch.ones(4, 1, s[2], dtype=w.dtype, device=w.device), ) traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) self.convolutions = traced - return torch.jit.script(self) def getRadTTSEncoder( diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 6f298aeaf9d7..16b529cde12c 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -61,9 +61,9 @@ def pad_energy_avg_and_f0(energy_avg, f0, max_out_len): def adjust_f0(f0, f0_mean, f0_std, vmask_bool): if f0_mean > 0.0: f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std() - f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma + f0[vmask_bool] = ((f0[vmask_bool] - f0_mu) / f0_sigma).to(dtype=f0.dtype) f0_std = f0_std if f0_std > 0 else f0_sigma - f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean + f0[vmask_bool] = (f0[vmask_bool] * f0_std + f0_mean).to(dtype=f0.dtype) return f0 @@ -722,8 +722,7 @@ def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, l voiced_mask = voiced_mask[:, :, : f0.shape[-1]] if self.ap_pred_log_f0: # if variable is set, decoder sees linear f0 - # mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool() - f0[voiced_mask] = torch.exp(f0[voiced_mask]).to(f0) + f0 = torch.exp(f0).to(dtype=f0.dtype) f0[~voiced_mask] = 0.0 return f0 @@ -826,8 +825,8 @@ def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id text, speaker_id_text=speaker_id_text, speaker_id_attributes=speaker_id_attributes, - sigma=0.7, - sigma_txt=0.7, + sigma=0.0, + sigma_txt=0.0, sigma_f0=1.0, sigma_energy=1.0, f0_mean=145.0, diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 95289380037f..206a6182565c 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -20,7 +20,7 @@ from torch.nn import functional as F -def masked_instance_norm( +def masked_instance_norm0( input: Tensor, mask: Tensor, running_mean: Optional[Tensor], @@ -38,20 +38,29 @@ def masked_instance_norm( if not use_input_stats and (running_mean is None or running_var is None): raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False') + print(running_mean, running_var, weight, bias, use_input_stats, momentum) + shape = input.shape b, c = shape[:2] - num_dims = len(shape[2:]) - _dims = tuple(range(-num_dims, 0)) - _slice = (...,) + (None,) * num_dims + num_dims = 1 # len(shape[2:]) + # _dims = tuple(range(-num_dims, 0)) - running_mean_ = running_mean[None, :].repeat(b, 1) if running_mean is not None else None - running_var_ = running_var[None, :].repeat(b, 1) if running_mean is not None else None + # _slice = (...,) + (None,) * num_dims + # print (_dims, _slice) + if running_mean is not None: + running_mean_ = running_mean[None, :].repeat(b, 1) + else: + running_mean_ = None + if running_var is not None: + running_var_ = running_var[None, :].repeat(b, 1) + else: + running_var_ = None if use_input_stats: - lengths = mask.sum(_dims) - mean = (input * mask).sum(_dims) / lengths # (N, C) - var = (((input - mean[_slice]) * mask) ** 2).sum(_dims) / lengths # (N, C) + lengths = mask.sum((-1,)) + mean = (input * mask).sum((-1,)) / lengths # (N, C) + var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths # (N, C) if running_mean is not None: running_mean_.mul_(1 - momentum).add_(momentum * mean.detach()) @@ -62,10 +71,36 @@ def masked_instance_norm( else: mean, var = running_mean_.view(b, c), running_var_.view(b, c) - out = (input - mean[_slice]) / torch.sqrt(var[_slice] + eps) # (N, C, ...) + out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps) # (N, C, ...) + + out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)] + + return out + + +def masked_instance_norm( + input: Tensor, mask: Tensor, weight: Tensor, bias: Tensor, momentum: float, eps: float = 1e-5, +) -> Tensor: + r"""Applies Masked Instance Normalization for each channel in each data sample in a batch. + + See :class:`~MaskedInstanceNorm1d` for details. + """ + shape = input.shape + b, c = shape[:2] + + num_dims = 1 # len(shape[2:]) + # _dims = tuple(range(-num_dims, 0)) + + # _slice = (...,) + (None,) * num_dims + # print (_dims, _slice) - if weight is not None and bias is not None: - out = out * weight[None, :][_slice] + bias[None, :][_slice] + lengths = mask.sum((-1,)) + mean = (input * mask).sum((-1,)) / lengths # (N, C) + var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths # (N, C) + + out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps) # (N, C, ...) + + out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)] return out @@ -92,34 +127,8 @@ def __init__( ) -> None: super(MaskedInstanceNorm1d, self).__init__(num_features, eps, momentum, affine, track_running_stats) - def forward(self, input: Tensor, mask: Tensor = None) -> Tensor: - self._check_input_dim(input) - if mask is not None: - self._check_input_dim(mask) - - if mask is None: - return F.instance_norm( - input, - self.running_mean, - self.running_var, - self.weight, - self.bias, - self.training or not self.track_running_stats, - self.momentum, - self.eps, - ) - else: - return masked_instance_norm( - input, - mask, - self.running_mean, - self.running_var, - self.weight, - self.bias, - self.training or not self.track_running_stats, - self.momentum, - self.eps, - ) + def forward(self, input: Tensor, mask: Tensor) -> Tensor: + return masked_instance_norm(input, mask, self.weight, self.bias, self.momentum, self.eps,) class PartialConv1d(torch.nn.Conv1d): @@ -129,13 +138,16 @@ class PartialConv1d(torch.nn.Conv1d): this affect. """ + __constants__ = ['slide_winsize'] + slide_winsize: float + def __init__(self, *args, **kwargs): super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] - def forward(self, input, mask_in=None): + def forward(self, input, mask_in): with torch.no_grad(): if mask_in is None: mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) @@ -156,7 +168,8 @@ def forward(self, input, mask_in=None): update_mask = torch.clamp(update_mask, 0, 1) mask_ratio = torch.mul(mask_ratio, update_mask) - raw_out = super(PartialConv1d, self).forward(input) + raw_out = self._conv_forward(input, self.weight, self.bias) + if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1) output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view @@ -179,6 +192,9 @@ def forward(self, x): class ConvNorm(torch.nn.Module): + __constants__ = ['use_partial_padding'] + use_partial_padding: bool + def __init__( self, in_channels, @@ -197,12 +213,9 @@ def __init__( if padding is None: assert kernel_size % 2 == 1 padding = int(dilation * (kernel_size - 1) / 2) - self.kernel_size = kernel_size - self.dilation = dilation self.use_partial_padding = use_partial_padding - self.use_weight_norm = use_weight_norm conv_fn = torch.nn.Conv1d - if self.use_partial_padding: + if use_partial_padding: conv_fn = PartialConv1d self.conv = conv_fn( in_channels, @@ -214,7 +227,7 @@ def __init__( bias=bias, ) torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - if self.use_weight_norm: + if use_weight_norm: self.conv = torch.nn.utils.weight_norm(self.conv) if norm_fn is not None: self.norm = norm_fn(out_channels, affine=True) From 5765d3d6fdbda58088beb1bc122abaf2a3bb5ccb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 13 Nov 2022 14:41:32 -0800 Subject: [PATCH 21/43] restored e-6 Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 14 -------------- nemo/collections/tts/modules/radtts.py | 10 +++++----- nemo/collections/tts/modules/submodules.py | 18 +++++------------- 3 files changed, 10 insertions(+), 32 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 5d0d9adaa665..1892e2a3c67b 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -220,20 +220,6 @@ def forward(self, context: Tensor, lens: Tensor) -> Tensor: return context - def script(self): - traced = nn.ModuleList() - return torch.jit.script(self) - - for conv in self.convolutions: - w = conv.conv.weight - s = w.shape - rand_in = ( - torch.ones(4, s[1], s[2], dtype=w.dtype, device=w.device), - torch.ones(4, 1, s[2], dtype=w.dtype, device=w.device), - ) - traced.append(torch.jit.trace_module(conv, {"forward": rand_in})) - self.convolutions = traced - def getRadTTSEncoder( encoder_n_convolutions=3, diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 16b529cde12c..b51146961c25 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -783,14 +783,14 @@ def _prepare_for_export(self, **kwargs): super()._prepare_for_export(**kwargs) if self.prepared_for_export: return - self.encoder = self.encoder.script() - self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script() + self.encoder = torch.jit.script(self.encoder) + self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) if hasattr(self, 'f0_pred_module'): - self.f0_pred_module.feat_pred_fn = self.f0_pred_module.feat_pred_fn.script() + self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) if hasattr(self, 'energy_pred_module'): - self.energy_pred_module.feat_pred_fn = self.energy_pred_module.feat_pred_fn.script() + self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) if hasattr(self, 'dur_pred_layer'): - self.dur_pred_layer.feat_pred_fn = self.dur_pred_layer.feat_pred_fn.script() + self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) if self.use_context_lstm: self.context_lstm = torch.jit.script(self.context_lstm) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 206a6182565c..a316a2478811 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -42,12 +42,7 @@ def masked_instance_norm0( shape = input.shape b, c = shape[:2] - - num_dims = 1 # len(shape[2:]) - # _dims = tuple(range(-num_dims, 0)) - - # _slice = (...,) + (None,) * num_dims - # print (_dims, _slice) + num_dims = 1 if running_mean is not None: running_mean_ = running_mean[None, :].repeat(b, 1) else: @@ -88,11 +83,7 @@ def masked_instance_norm( shape = input.shape b, c = shape[:2] - num_dims = 1 # len(shape[2:]) - # _dims = tuple(range(-num_dims, 0)) - - # _slice = (...,) + (None,) * num_dims - # print (_dims, _slice) + num_dims = 1 lengths = mask.sum((-1,)) mean = (input * mask).sum((-1,)) / lengths # (N, C) @@ -163,8 +154,9 @@ def forward(self, input, mask_in): dilation=self.dilation, groups=1, ) - update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) - mask_ratio = self.slide_winsize / update_mask_filled + # update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / (update_mask + 1e-6) + # mask_ratio = self.slide_winsize / update_mask_filled update_mask = torch.clamp(update_mask, 0, 1) mask_ratio = torch.mul(mask_ratio, update_mask) From 65427fcd4f1dddfba90b08fbd41a357456856aff Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 13 Nov 2022 17:36:42 -0800 Subject: [PATCH 22/43] Removed some casts around masks, etc Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/radtts.py | 42 ++++++++++++++------------ 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index b51146961c25..fe3fbdb2f047 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -60,8 +60,9 @@ def pad_energy_avg_and_f0(energy_avg, f0, max_out_len): def adjust_f0(f0, f0_mean, f0_std, vmask_bool): if f0_mean > 0.0: - f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std() - f0[vmask_bool] = ((f0[vmask_bool] - f0_mu) / f0_sigma).to(dtype=f0.dtype) + masked_f0 = f0[vmask_bool] + f0_mu, f0_sigma = torch.std_mean(masked_f0) + f0[vmask_bool] = ((masked_f0 - f0_mu) / f0_sigma).to(dtype=f0.dtype) f0_std = f0_std if f0_std > 0 else f0_sigma f0[vmask_bool] = (f0[vmask_bool] * f0_std + f0_mean).to(dtype=f0.dtype) return f0 @@ -411,18 +412,18 @@ def get_first_order_features(self, feats, out_lens, dilation=1): return (dfeats_R + dfeats_L) * 0.5 - def apply_voice_mask_to_text(self, text_enc, voiced_mask): + def apply_voice_mask_to_text(self, text_enc, voiced_mask_bool): """ text_enc: b x C x N - voiced_mask: b x N + voiced_mask_bool: b x N """ - voiced_mask = voiced_mask.unsqueeze(1) + voiced_mask_bool = voiced_mask_bool.unsqueeze(1) voiced_embedding_s = self.v_embeddings.weight[0:1, :, None] unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None] voiced_embedding_b = self.v_embeddings.weight[2:3, :, None] unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None] - scale = torch.sigmoid(voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask)) - bias = 0.1 * torch.tanh(voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask)) + scale = torch.sigmoid(torch.where(voiced_mask_bool, voiced_embedding_s, unvoiced_embedding_s)) + bias = 0.1 * torch.tanh(torch.where(voiced_mask_bool, voiced_embedding_b, unvoiced_embedding_b)) return text_enc * scale + bias def forward( @@ -464,10 +465,11 @@ def forward( f0_bias = 0 # unvoiced bias forward pass + voiced_mask_bool = voiced_mask.bool() if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias = f0_bias * (~voiced_mask.bool()).float() + f0_bias.mask_fill(voiced_mask_bool, 0.0) # mel decoder forward pass if 'dec' in self.include_modules: @@ -534,13 +536,13 @@ def forward( # affine transform context using voiced mask if self.ap_use_voiced_embeddings: - text_enc_time_expanded = self.apply_voice_mask_to_text(text_enc_time_expanded, voiced_mask) + text_enc_time_expanded = self.apply_voice_mask_to_text(text_enc_time_expanded, voiced_mask_bool) if self.ap_use_unvoiced_bias: # whether to use the unvoiced bias in the attribute predictor f0_target = torch.detach(f0 * voiced_mask + f0_bias) else: f0_target = torch.detach(f0) # fit to log f0 in f0 predictor - f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()]) + f0_target[voiced_mask_bool] = torch.log(f0_target[voiced_mask_bool]) f0_target = f0_target / 6 # scale to ~ [0, 1] in log space energy_avg = energy_avg * 2 - 1 # scale to ~ [-1, 1] @@ -636,7 +638,8 @@ def infer( # get logits voiced_mask = self.v_pred_module.infer(None, txt_enc_time_expanded, spk_vec_attributes, lens=out_lens) voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5 - voiced_mask = voiced_mask.float() + else: + voiced_mask = voiced_mask.bool() ap_txt_enc_time_expanded = txt_enc_time_expanded # voice mask augmentation only used for attribute prediction @@ -648,14 +651,14 @@ def infer( if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(txt_enc_time_expanded.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias = f0_bias * (~voiced_mask.bool()).float() + f0_bias.masked_fill_(voiced_mask, 0.0) if f0 is None: n_f0_feature_channels = 2 if self.use_first_order_features else 1 z_f0 = torch.normal(txt_enc.new_zeros(batch_size, n_f0_feature_channels, max_out_len)) * sigma_f0 f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec_attributes, voiced_mask, out_lens)[:, 0] - f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask.to(dtype=bool)) + f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask) if energy_avg is None: n_energy_feature_channels = 2 if self.use_first_order_features else 1 @@ -671,7 +674,7 @@ def infer( # print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) context_w_spkvec = self.preprocess_context( - txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask + f0_bias, energy_avg + txt_enc_time_expanded, spk_vec, out_lens, f0.masked_fill(~voiced_mask, 0.0) + f0_bias, energy_avg ) residual = torch.normal(txt_enc.new_zeros(batch_size, 80 * self.n_group_size, torch.max(n_groups))) * sigma @@ -701,8 +704,6 @@ def infer( def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None): f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens) - if voiced_mask is not None and len(voiced_mask.shape) == 2: - voiced_mask = voiced_mask[:, None] # constants if self.ap_pred_log_f0: if self.use_first_order_features: @@ -718,12 +719,15 @@ def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, l voiced_mask = f0 > 0.0 else: voiced_mask = voiced_mask.bool() - # due to grouping, f0 might be 1 frame short - voiced_mask = voiced_mask[:, :, : f0.shape[-1]] + if len(voiced_mask.shape) == 2: + voiced_mask = voiced_mask[:, None] + # due to grouping, f0 might be 1 frame short + voiced_mask = voiced_mask[:, :, : f0.shape[-1]] + if self.ap_pred_log_f0: # if variable is set, decoder sees linear f0 f0 = torch.exp(f0).to(dtype=f0.dtype) - f0[~voiced_mask] = 0.0 + f0.masked_fill(~voiced_mask, 0.0) return f0 def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens): From 8e547fd0a034230649e11455678562f888496ebd Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 13 Nov 2022 19:40:55 -0800 Subject: [PATCH 23/43] Fixing some casts Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 14 ++---- nemo/collections/tts/modules/radtts.py | 2 +- nemo/collections/tts/modules/submodules.py | 53 ---------------------- 3 files changed, 6 insertions(+), 63 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 1892e2a3c67b..3b333a12d19e 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -44,7 +44,7 @@ def get_mask_from_lengths_and_val(lengths, val): max_len = val.shape[-1] ids = torch.arange(0, max_len, device=lengths.device) mask = ids < lengths.unsqueeze(1) - return mask.float() + return mask @torch.jit.script @@ -195,10 +195,10 @@ def __init__( def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) - mask = mask.unsqueeze(1) + mask = mask.to(dtype=context.dtype).unsqueeze(1) for conv in self.convolutions: context = self.dropout(F.relu(conv(context, mask))) - context = torch.mul(context, mask) + context = context.transpose(1, 2) seq = torch.nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted @@ -206,12 +206,8 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: return seq def forward(self, context: Tensor, lens: Tensor) -> Tensor: - # if lens is None: - # my_lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2] - # else: - my_lens = lens - context, my_lens, unsort_ids = sort_tensor(context, my_lens) - seq = self.masked_conv_to_sequence(context, my_lens, enforce_sorted=True) + context, lens, unsort_ids = sort_tensor(context, lens) + seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True) context, _ = self.lstm_sequence(seq) context = context[unsort_ids] diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index fe3fbdb2f047..a5a73e34119e 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -61,7 +61,7 @@ def pad_energy_avg_and_f0(energy_avg, f0, max_out_len): def adjust_f0(f0, f0_mean, f0_std, vmask_bool): if f0_mean > 0.0: masked_f0 = f0[vmask_bool] - f0_mu, f0_sigma = torch.std_mean(masked_f0) + f0_sigma, f0_mu = torch.std_mean(masked_f0) f0[vmask_bool] = ((masked_f0 - f0_mu) / f0_sigma).to(dtype=f0.dtype) f0_std = f0_std if f0_std > 0 else f0_sigma f0[vmask_bool] = (f0[vmask_bool] * f0_std + f0_mean).to(dtype=f0.dtype) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index a316a2478811..1e07ec42525a 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -20,59 +20,6 @@ from torch.nn import functional as F -def masked_instance_norm0( - input: Tensor, - mask: Tensor, - running_mean: Optional[Tensor], - running_var: Optional[Tensor], - weight: Optional[Tensor], - bias: Optional[Tensor], - use_input_stats: bool, - momentum: float, - eps: float = 1e-5, -) -> Tensor: - r"""Applies Masked Instance Normalization for each channel in each data sample in a batch. - - See :class:`~MaskedInstanceNorm1d` for details. - """ - if not use_input_stats and (running_mean is None or running_var is None): - raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False') - - print(running_mean, running_var, weight, bias, use_input_stats, momentum) - - shape = input.shape - b, c = shape[:2] - num_dims = 1 - if running_mean is not None: - running_mean_ = running_mean[None, :].repeat(b, 1) - else: - running_mean_ = None - if running_var is not None: - running_var_ = running_var[None, :].repeat(b, 1) - else: - running_var_ = None - - if use_input_stats: - lengths = mask.sum((-1,)) - mean = (input * mask).sum((-1,)) / lengths # (N, C) - var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths # (N, C) - - if running_mean is not None: - running_mean_.mul_(1 - momentum).add_(momentum * mean.detach()) - running_mean.copy_(running_mean_.view(b, c).mean(0, keepdim=False)) - if running_var is not None: - running_var_.mul_(1 - momentum).add_(momentum * var.detach()) - running_var.copy_(running_var_.view(b, c).mean(0, keepdim=False)) - else: - mean, var = running_mean_.view(b, c), running_var_.view(b, c) - - out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps) # (N, C, ...) - - out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)] - - return out - - def masked_instance_norm( input: Tensor, mask: Tensor, weight: Tensor, bias: Tensor, momentum: float, eps: float = 1e-5, ) -> Tensor: From 47dab070fad5eb6832e0a9146e7b68ad34d5faab Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 01:47:00 -0800 Subject: [PATCH 24/43] Fixing in-place ops Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/radtts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index a5a73e34119e..f35f1667e1d9 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -469,7 +469,7 @@ def forward( if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias.mask_fill(voiced_mask_bool, 0.0) + f0_bias.masked_fill_(voiced_mask_bool, 0.0) # mel decoder forward pass if 'dec' in self.include_modules: @@ -727,7 +727,7 @@ def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, l if self.ap_pred_log_f0: # if variable is set, decoder sees linear f0 f0 = torch.exp(f0).to(dtype=f0.dtype) - f0.masked_fill(~voiced_mask, 0.0) + f0.masked_fill_(~voiced_mask, 0.0) return f0 def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens): From 0324f46eb40fca809d85325b0b99d60d029adad4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 12:46:47 -0800 Subject: [PATCH 25/43] fixing grad Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 1e07ec42525a..b42e12fbbafe 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -86,12 +86,12 @@ def __init__(self, *args, **kwargs): self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] def forward(self, input, mask_in): + if mask_in is None: + mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) + else: + mask = mask_in + input = torch.mul(input, mask) with torch.no_grad(): - if mask_in is None: - mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) - else: - mask = mask_in - input = torch.mul(input, mask) update_mask = F.conv1d( mask, self.weight_maskUpdater, From cb8858ace7c77c55f056f280515503fea6564ad4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 18:28:06 -0800 Subject: [PATCH 26/43] Small export fixes Signed-off-by: Boris Fomitchev --- nemo/utils/cast_utils.py | 7 ++----- scripts/export.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index b57b78decd6f..9eb064936ea5 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -70,9 +70,6 @@ def __init__(self, mod): self.mod = mod def forward(self, x): - if torch.is_autocast_enabled(): - with avoid_float16_autocast_context(): - ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) - else: - ret = self.mod.forward(x) + with avoid_float16_autocast_context(): + ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) return ret diff --git a/scripts/export.py b/scripts/export.py index a8c8ad6605de..2d7514483fc9 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -139,8 +139,8 @@ def nemo_export(argv): with autocast(), torch.no_grad(), torch.inference_mode(): model.to(device=args.device).freeze() model.eval() - input_example = model.input_module.input_example(**in_args) if check_trace and len(in_args) > 0: + input_example = model.input_module.input_example(**in_args) check_trace = [input_example] for key, arg in in_args.items(): in_args[key] = (arg + 1) // 2 From c17ba735d591aa22c205c8f766c7e62efd23cb55 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 18:50:31 -0800 Subject: [PATCH 27/43] LGTM cleanup Signed-off-by: Boris Fomitchev --- nemo/collections/tts/models/radtts.py | 1 - nemo/collections/tts/modules/common.py | 4 --- nemo/collections/tts/modules/radtts.py | 1 - nemo/collections/tts/modules/submodules.py | 10 +++---- nemo/core/classes/exportable.py | 1 - nemo/utils/export_utils.py | 31 ---------------------- 6 files changed, 3 insertions(+), 45 deletions(-) diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 73e2adb5f8bb..75c6273860ca 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -27,7 +27,6 @@ from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss from nemo.collections.tts.models.base import SpectrogramGenerator -from nemo.collections.tts.modules.submodules import PartialConv1d from nemo.core.classes import Exportable from nemo.core.classes.common import typecheck from nemo.core.neural_types.elements import Index, MelSpectrogramType, TokenIndex diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 3b333a12d19e..e5c03e43d6c1 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,11 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - # if not (torch.jit.is_tracing() or torch.jit.is_scripting()): - # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - # if not (torch.jit.is_tracing() or torch.jit.is_scripting()): - # self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index f35f1667e1d9..87ccf7881a0f 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -32,7 +32,6 @@ get_mask_from_lengths, getRadTTSEncoder, ) -from nemo.collections.tts.modules.submodules import PartialConv1d from nemo.core.classes import Exportable, NeuralModule from nemo.core.neural_types.elements import Index, LengthsType, MelSpectrogramType, TokenDurationType, TokenIndex from nemo.core.neural_types.neural_type import NeuralType diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index b42e12fbbafe..970d72a7ea4b 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Tuple import torch from torch import Tensor @@ -28,9 +28,6 @@ def masked_instance_norm( See :class:`~MaskedInstanceNorm1d` for details. """ shape = input.shape - b, c = shape[:2] - - num_dims = 1 lengths = mask.sum((-1,)) mean = (input * mask).sum((-1,)) / lengths # (N, C) @@ -101,9 +98,8 @@ def forward(self, input, mask_in): dilation=self.dilation, groups=1, ) - # update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) - mask_ratio = self.slide_winsize / (update_mask + 1e-6) - # mask_ratio = self.slide_winsize / update_mask_filled + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / update_mask_filled update_mask = torch.clamp(update_mask, 0, 1) mask_ratio = torch.mul(mask_ratio, update_mask) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 9204eb357b78..76ca8525ac01 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -28,7 +28,6 @@ parse_input_example, replace_for_export, verify_runtime, - verify_torchscript, wrap_forward_method, ) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 11dc5416b9b0..a3d0cf13944e 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -134,20 +134,6 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list): return odict -def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01): - ts_model = torch.jit.load(output) - - all_good = True - for input_example in input_examples: - input_list, input_dict = parse_input_example(input_example) - output_example = model.forward(*input_list, **input_dict) - # ts_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) - all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance) - status = "SUCCESS" if all_good else "FAIL" - logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status) - return all_good - - def verify_runtime(model, output, input_examples, input_names, check_tolerance=0.01): onnx_model = onnx.load(output) ort_input_names = [node.name for node in onnx_model.graph.input] @@ -172,23 +158,6 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 return all_good -def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01): - # Verify the model can be read, and is valid - ts_out = ts_model(*ts_input_list, **ts_input_dict) - - all_good = True - for i, out in enumerate(ts_out): - expected = output_example[i] - - if torch.is_tensor(expected): - tout = out.to('cpu') - logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") - if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): - all_good = False - logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") - return all_good - - def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): # Verify the model can be read, and is valid ort_out = sess.run(None, ort_input) From 16c0b7281c57838f08366434e362dc17c167d9f0 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 21:27:50 -0800 Subject: [PATCH 28/43] Fixed lstm_tensor Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 13 ++++++------- nemo/collections/tts/modules/radtts.py | 12 +++++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index e5c03e43d6c1..184c4388756d 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,26 +123,25 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor: - lens_sorted, ids_sorted = torch.sort(lens, descending=True) - unsort_ids = torch.zeros_like(ids_sorted) - for i in range(ids_sorted.shape[0]): - unsort_ids[ids_sorted[i]] = i - context = context[ids_sorted] + context, lens_sorted, unsort_ids = sort_tensor(context, lens) seq = nn.utils.rnn.pack_padded_sequence( context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True ) - ret, _ = self.bilstm(seq) - return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0][unsort_ids] + return self.lstm_sequence(seq)[0][unsort_ids] class ConvLSTMLinear(BiLSTM): diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 87ccf7881a0f..8bec71074e49 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -356,7 +356,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg): context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1) unfolded_out_lens = out_lens // self.n_group_size - context_lstm_padded_output, _ = self.context_lstm.lstm_tensor( + context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor( context_w_spkvec.transpose(1, 2), unfolded_out_lens ) context_w_spkvec = context_lstm_padded_output.transpose(1, 2) @@ -680,8 +680,8 @@ def infer( # map from z sample to data num_steps_to_exit = len(self.exit_steps) - mel = residual[:, num_steps_to_exit * self.n_early_size :] - remaining_residual = residual[:, : num_steps_to_exit * self.n_early_size] + # mel = residual[:, num_steps_to_exit * self.n_early_size :] + remaining_residual, mel = torch.tensor_split(residual, [num_steps_to_exit * self.n_early_size,], dim=1) for i, flow_step in enumerate(reversed(self.flows)): curr_step = self.n_flows - i - 1 @@ -689,8 +689,10 @@ def infer( if num_steps_to_exit > 0 and curr_step == self.exit_steps[num_steps_to_exit - 1]: # concatenate the next chunk of z num_steps_to_exit = num_steps_to_exit - 1 - residual_to_add = remaining_residual[:, num_steps_to_exit * self.n_early_size :] - remaining_residual = remaining_residual[:, : num_steps_to_exit * self.n_early_size] + # residual_to_add = remaining_residual[:, num_steps_to_exit * self.n_early_size :] + remaining_residual, residual_to_add = torch.tensor_split( + remaining_residual, [num_steps_to_exit * self.n_early_size,], dim=1 + ) mel = torch.cat((residual_to_add, mel), 1) if self.n_group_size > 1: From 32a1943fb294db4d458ba72b14e840c4c5000b26 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 21:50:49 -0800 Subject: [PATCH 29/43] restored TS check routine Signed-off-by: Boris Fomitchev --- nemo/core/classes/exportable.py | 9 +++++++++ nemo/utils/export_utils.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 76ca8525ac01..5a9ab55a4ee7 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -28,6 +28,7 @@ parse_input_example, replace_for_export, verify_runtime, + verify_torchscript, wrap_forward_method, ) @@ -166,6 +167,14 @@ def _export( jitted_model.save(output) assert os.path.exists(output) + if check_trace: + if isinstance(check_trace, bool): + check_trace_input = [input_example] + else: + check_trace_input = check_trace + + verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) + elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index a3d0cf13944e..11dc5416b9b0 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -134,6 +134,20 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list): return odict +def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01): + ts_model = torch.jit.load(output) + + all_good = True + for input_example in input_examples: + input_list, input_dict = parse_input_example(input_example) + output_example = model.forward(*input_list, **input_dict) + # ts_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) + all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance) + status = "SUCCESS" if all_good else "FAIL" + logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status) + return all_good + + def verify_runtime(model, output, input_examples, input_names, check_tolerance=0.01): onnx_model = onnx.load(output) ort_input_names = [node.name for node in onnx_model.graph.input] @@ -158,6 +172,23 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 return all_good +def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01): + # Verify the model can be read, and is valid + ts_out = ts_model(*ts_input_list, **ts_input_dict) + + all_good = True + for i, out in enumerate(ts_out): + expected = output_example[i] + + if torch.is_tensor(expected): + tout = out.to('cpu') + logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): + all_good = False + logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") + return all_good + + def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): # Verify the model can be read, and is valid ort_out = sess.run(None, ort_input) From 14561be00eda47c6bf309c9b6e2471abcfdbb7f6 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 22:05:47 -0800 Subject: [PATCH 30/43] Fixed config error Signed-off-by: Boris Fomitchev --- examples/tts/conf/rad-tts_feature_pred.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tts/conf/rad-tts_feature_pred.yaml b/examples/tts/conf/rad-tts_feature_pred.yaml index cd7cce38f1a1..ed881d31b70f 100644 --- a/examples/tts/conf/rad-tts_feature_pred.yaml +++ b/examples/tts/conf/rad-tts_feature_pred.yaml @@ -12,8 +12,8 @@ whitelist_path: "nemo_text_processing/text_normalization/en/data/whitelist/tts.t # these frame-wise values depend on pitch_fmin and pitch_fmax, you can get values # by running `scripts/dataset_processing/tts/extract_sup_data.py` -pitch_mean: ??? # e.g. 212.35873413085938 for LJSpeech -pitch_std: ??? # e.g. 68.52806091308594 for LJSpeech +pitch_mean: 145.0 # e.g. 212.35873413085938 for LJSpeech +pitch_std: 30.0 # e.g. 68.52806091308594 for LJSpeech # default values from librosa.pyin pitch_fmin: 65.40639132514966 From bb834da4cbe7ff7ef6d6e5c972bc86c3196269d7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 02:26:53 -0800 Subject: [PATCH 31/43] reverting some bad optimizations Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/radtts.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 8bec71074e49..96cdec6246cd 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -411,18 +411,18 @@ def get_first_order_features(self, feats, out_lens, dilation=1): return (dfeats_R + dfeats_L) * 0.5 - def apply_voice_mask_to_text(self, text_enc, voiced_mask_bool): + def apply_voice_mask_to_text(self, text_enc, voiced_mask): """ text_enc: b x C x N - voiced_mask_bool: b x N + voiced_mask: b x N """ - voiced_mask_bool = voiced_mask_bool.unsqueeze(1) + voiced_mask = voiced_mask.unsqueeze(1) voiced_embedding_s = self.v_embeddings.weight[0:1, :, None] unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None] voiced_embedding_b = self.v_embeddings.weight[2:3, :, None] unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None] - scale = torch.sigmoid(torch.where(voiced_mask_bool, voiced_embedding_s, unvoiced_embedding_s)) - bias = 0.1 * torch.tanh(torch.where(voiced_mask_bool, voiced_embedding_b, unvoiced_embedding_b)) + scale = torch.sigmoid(voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask)) + bias = 0.1 * torch.tanh(voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask)) return text_enc * scale + bias def forward( @@ -535,7 +535,7 @@ def forward( # affine transform context using voiced mask if self.ap_use_voiced_embeddings: - text_enc_time_expanded = self.apply_voice_mask_to_text(text_enc_time_expanded, voiced_mask_bool) + text_enc_time_expanded = self.apply_voice_mask_to_text(text_enc_time_expanded, voiced_mask) if self.ap_use_unvoiced_bias: # whether to use the unvoiced bias in the attribute predictor f0_target = torch.detach(f0 * voiced_mask + f0_bias) else: @@ -810,7 +810,7 @@ def input_example(self, max_batch=1, max_dim=256): par = next(self.parameters()) sz = (max_batch, max_dim) inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64) - lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int) + lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int) speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64) inputs = { 'text': inp, From b95fb2c34713d44ad7f0cb7d9c220af8dcd29be2 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Wed, 9 Nov 2022 21:45:40 -0800 Subject: [PATCH 32/43] [TTS] add CI test for RADTTS training recipe. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- Jenkinsfile | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index b6eeb7c53ade..fbbcfc53ac40 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4206,7 +4206,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' validation_datasets=/home/TestData/an4_dataset/an4_val.json \ sup_data_path=/home/TestData/an4_dataset/beta_priors \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ trainer.strategy=null \ model.pitch_mean=212.35873413085938 \ model.pitch_std=68.52806091308594 \ @@ -4224,6 +4226,31 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' ~model.text_normalizer_call_kwargs' } } + stage('RADTTS') { + steps { + sh 'python examples/tts/radtts.py \ + train_dataset=/home/TestData/an4_dataset/an4_train.json \ + validation_datasets=/home/TestData/an4_dataset/an4_val.json \ + sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ + trainer.strategy=null \ + model.pitch_mean=212.35873413085938 \ + model.pitch_std=68.52806091308594 \ + model.train_ds.dataloader_params.batch_size=4 \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.batch_size=4 \ + model.validation_ds.dataloader_params.num_workers=0 \ + export_dir=/home/TestData/radtts_test \ + model.optim.lr=0.0001 \ + model.modelConfig.decoder_use_partial_padding=True \ + ~trainer.check_val_every_n_epoch \ + ~model.text_normalizer \ + ~model.text_normalizer_call_kwargs' + } + } stage('Mixer-TTS') { steps { sh 'python examples/tts/mixer_tts.py \ @@ -4231,7 +4258,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' validation_datasets=/home/TestData/an4_dataset/an4_val.json \ sup_data_path=/home/TestData/an4_dataset/sup_data \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ trainer.strategy=null \ model.pitch_mean=212.35873413085938 \ model.pitch_std=68.52806091308594 \ @@ -4250,7 +4279,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' train_dataset=/home/TestData/an4_dataset/an4_train.json \ validation_datasets=/home/TestData/an4_dataset/an4_val.json \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 +trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + +trainer.max_epochs=1 \ trainer.strategy=null \ model.train_ds.dataloader_params.batch_size=4 \ model.train_ds.dataloader_params.num_workers=0 \ From c620de4b2a9a35a8ec956aa7d25cec20ab9e8c02 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 15:11:39 -0800 Subject: [PATCH 33/43] Addressing code review Signed-off-by: Boris Fomitchev --- examples/tts/conf/rad-tts_feature_pred.yaml | 4 ++-- nemo/collections/tts/modules/radtts.py | 24 +++++++++---------- nemo/core/classes/exportable.py | 21 ++++++---------- nemo/utils/export_utils.py | 6 ++--- tests/collections/tts/test_tts_exportables.py | 8 +++---- 5 files changed, 26 insertions(+), 37 deletions(-) diff --git a/examples/tts/conf/rad-tts_feature_pred.yaml b/examples/tts/conf/rad-tts_feature_pred.yaml index ed881d31b70f..cd7cce38f1a1 100644 --- a/examples/tts/conf/rad-tts_feature_pred.yaml +++ b/examples/tts/conf/rad-tts_feature_pred.yaml @@ -12,8 +12,8 @@ whitelist_path: "nemo_text_processing/text_normalization/en/data/whitelist/tts.t # these frame-wise values depend on pitch_fmin and pitch_fmax, you can get values # by running `scripts/dataset_processing/tts/extract_sup_data.py` -pitch_mean: 145.0 # e.g. 212.35873413085938 for LJSpeech -pitch_std: 30.0 # e.g. 68.52806091308594 for LJSpeech +pitch_mean: ??? # e.g. 212.35873413085938 for LJSpeech +pitch_std: ??? # e.g. 68.52806091308594 for LJSpeech # default values from librosa.pyin pitch_fmin: 65.40639132514966 diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 96cdec6246cd..35c443acde58 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -59,12 +59,11 @@ def pad_energy_avg_and_f0(energy_avg, f0, max_out_len): def adjust_f0(f0, f0_mean, f0_std, vmask_bool): if f0_mean > 0.0: - masked_f0 = f0[vmask_bool] - f0_sigma, f0_mu = torch.std_mean(masked_f0) - f0[vmask_bool] = ((masked_f0 - f0_mu) / f0_sigma).to(dtype=f0.dtype) + f0_sigma, f0_mu = torch.std_mean(f0[vmask_bool]) + f0 = ((f0 - f0_mu) / f0_sigma).to(dtype=f0.dtype) f0_std = f0_std if f0_std > 0 else f0_sigma - f0[vmask_bool] = (f0[vmask_bool] * f0_std + f0_mean).to(dtype=f0.dtype) - return f0 + f0 = (f0 * f0_std + f0_mean).to(dtype=f0.dtype) + return f0.masked_fill(~vmask_bool, 0.0) class FlowStep(nn.Module): @@ -636,9 +635,10 @@ def infer( if self.use_vpred_module: # get logits voiced_mask = self.v_pred_module.infer(None, txt_enc_time_expanded, spk_vec_attributes, lens=out_lens) - voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5 + voiced_mask_bool = torch.sigmoid(voiced_mask[:, 0]) > 0.5 + voiced_mask = voiced_mask_bool.to(dur.dtype) else: - voiced_mask = voiced_mask.bool() + voiced_mask_bool = voiced_mask.bool() ap_txt_enc_time_expanded = txt_enc_time_expanded # voice mask augmentation only used for attribute prediction @@ -650,14 +650,13 @@ def infer( if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(txt_enc_time_expanded.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias.masked_fill_(voiced_mask, 0.0) if f0 is None: n_f0_feature_channels = 2 if self.use_first_order_features else 1 z_f0 = torch.normal(txt_enc.new_zeros(batch_size, n_f0_feature_channels, max_out_len)) * sigma_f0 - f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec_attributes, voiced_mask, out_lens)[:, 0] + f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec_attributes, voiced_mask_bool, out_lens)[:, 0] - f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask) + f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask_bool) if energy_avg is None: n_energy_feature_channels = 2 if self.use_first_order_features else 1 @@ -673,7 +672,7 @@ def infer( # print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) context_w_spkvec = self.preprocess_context( - txt_enc_time_expanded, spk_vec, out_lens, f0.masked_fill(~voiced_mask, 0.0) + f0_bias, energy_avg + txt_enc_time_expanded, spk_vec, out_lens, (f0 + f0_bias) * voiced_mask, energy_avg ) residual = torch.normal(txt_enc.new_zeros(batch_size, 80 * self.n_group_size, torch.max(n_groups))) * sigma @@ -719,7 +718,6 @@ def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, l if voiced_mask is None: voiced_mask = f0 > 0.0 else: - voiced_mask = voiced_mask.bool() if len(voiced_mask.shape) == 2: voiced_mask = voiced_mask[:, None] # due to grouping, f0 might be 1 frame short @@ -810,7 +808,7 @@ def input_example(self, max_batch=1, max_dim=256): par = next(self.parameters()) sz = (max_batch, max_dim) inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64) - lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int) + lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int) speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64) inputs = { 'text': inp, diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5a9ab55a4ee7..60099d921f59 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -147,12 +147,14 @@ def _export( output_names = self.output_names output_example = tuple(self.forward(*input_list, **input_dict)) + if check_trace: + if isinstance(check_trace, bool): + check_trace_input = [input_example] + else: + check_trace_input = check_trace + if format == ExportFormat.TORCHSCRIPT: - if check_trace: - if isinstance(check_trace, bool): - check_trace_input = {"forward": tuple(input_list) + tuple(input_dict.values())} - else: - check_trace_input = check_trace + jitted_model = torch.jit.trace_module( self, {"forward": tuple(input_list) + tuple(input_dict.values())}, @@ -168,11 +170,6 @@ def _export( assert os.path.exists(output) if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace - verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) elif format == ExportFormat.ONNX: @@ -196,10 +193,6 @@ def _export( ) if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace verify_runtime(self, output, check_trace_input, input_names) else: raise ValueError(f'Encountered unknown export format {format}.') diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 11dc5416b9b0..4af379803786 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -141,7 +141,7 @@ def verify_torchscript(model, output, input_examples, input_names, check_toleran for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) output_example = model.forward(*input_list, **input_dict) - # ts_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) + all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance) status = "SUCCESS" if all_good else "FAIL" logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status) @@ -182,7 +182,7 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c if torch.is_tensor(expected): tout = out.to('cpu') - logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): all_good = False logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") @@ -198,7 +198,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): if torch.is_tensor(expected): tout = torch.from_numpy(out) - logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): all_good = False logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 4b64b5096bc3..e3e496373271 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -43,6 +43,9 @@ def radtts_model(): cfg.model.train_ds.dataset.sup_data_path = "dummy.json" cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json" cfg.model.validation_ds.dataset.sup_data_path = "dummy.json" + cfg.pitch_mean = 212.35 + cfg.pitch_std = 68.52 + app_state = AppState() app_state.is_model_being_restored = True model = RadTTSModel(cfg=cfg.model) @@ -77,8 +80,3 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model): with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'rad.ts') model.export(output=filename, verbose=True, check_trace=True) - - -if __name__ == "__main__": - t = TestExportable() - t.test_RadTTSModel_export_to_torchscript(radtts_model()) From cb6c459ccfd6e2b36bef425fe8ffdf67834a516a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 19:24:18 -0800 Subject: [PATCH 34/43] Removing unused var Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/submodules.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 970d72a7ea4b..ca6b6b7581fe 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -27,14 +27,10 @@ def masked_instance_norm( See :class:`~MaskedInstanceNorm1d` for details. """ - shape = input.shape - lengths = mask.sum((-1,)) mean = (input * mask).sum((-1,)) / lengths # (N, C) var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths # (N, C) - out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps) # (N, C, ...) - out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)] return out From 87179ab640b2b9b91c9df185f0fdb4b7ffe52d62 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 19:39:04 -0800 Subject: [PATCH 35/43] Adding debug Signed-off-by: Boris Fomitchev --- examples/tts/radtts.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py index 7260e8d9907f..ca63e1efb2fa 100644 --- a/examples/tts/radtts.py +++ b/examples/tts/radtts.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import signal +import sys +import time + import pytorch_lightning as pl from nemo.collections.common.callbacks import LogEpochTimeCallback @@ -21,6 +26,12 @@ from nemo.utils.exp_manager import exp_manager +def handle_pdb(sig, frame): + import pdb + + pdb.Pdb().set_trace(frame) + + def freeze(model): for p in model.parameters(): p.requires_grad = False @@ -59,6 +70,7 @@ def prepare_model_weights(model, unfreeze_modules): @hydra_runner(config_path="conf", config_name="rad-tts_dec") def main(cfg): + signal.signal(signal.SIGUSR1, handle_pdb) trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get('exp_manager', None)) model = RadTTSModel(cfg=cfg.model, trainer=trainer) From c9d9e4f57d73caed073ddb65e5077cc94a142894 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 20:45:05 -0800 Subject: [PATCH 36/43] Logging fixes Signed-off-by: Boris Fomitchev --- nemo/collections/tts/models/radtts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 75c6273860ca..a2ba2387b998 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -158,7 +158,7 @@ def training_step(self, batch, batch_idx): loss_outputs['binarization_loss'] = (binarization_loss, 1.0) for k, (v, w) in loss_outputs.items(): - self.log("train/" + k, loss_outputs[k][0]) + self.log("train/" + k, loss_outputs[k][0], on_step=True, sync_dist=True) return {'loss': loss} @@ -228,7 +228,7 @@ def validation_epoch_end(self, outputs): for k, v in loss_outputs.items(): if k != "binarization_loss": - self.log("val/" + k, loss_outputs[k][0]) + self.log("val/" + k, loss_outputs[k][0], sync_dist=True, on_epoch=True) attn = outputs[0]["attn"] attn_soft = outputs[0]["attn_soft"] From 55e95b1e45bb64ba1c5909f80b4876832898b161 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 23:13:18 -0800 Subject: [PATCH 37/43] Fixing training warnings Signed-off-by: Boris Fomitchev --- nemo/collections/common/callbacks/callbacks.py | 2 +- nemo/collections/tts/models/radtts.py | 2 +- nemo/core/optim/radam.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index 489c862b3780..1a6c011c38df 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -13,7 +13,7 @@ # limitations under the License. import time -from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only # from sacrebleu import corpus_bleu diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index a2ba2387b998..30b6189c484f 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -158,7 +158,7 @@ def training_step(self, batch, batch_idx): loss_outputs['binarization_loss'] = (binarization_loss, 1.0) for k, (v, w) in loss_outputs.items(): - self.log("train/" + k, loss_outputs[k][0], on_step=True, sync_dist=True) + self.log("train/" + k, loss_outputs[k][0], on_step=True) return {'loss': loss} diff --git a/nemo/core/optim/radam.py b/nemo/core/optim/radam.py index 62a5ecff87be..40fcf5e827b0 100644 --- a/nemo/core/optim/radam.py +++ b/nemo/core/optim/radam.py @@ -81,7 +81,7 @@ def step(self, closure=None): exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2) exp_avg.mul_(beta1).add_(1 - beta1, grad) state['step'] += 1 From 647640c9b5d167ad03ac723a8a3dd083cbf280ae Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 23:30:16 -0800 Subject: [PATCH 38/43] Fixing more warnings Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 2 +- nemo/core/optim/radam.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 184c4388756d..cea745a2a6a3 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -235,7 +235,7 @@ class Invertible1x1ConvLUS(torch.nn.Module): def __init__(self, c): super(Invertible1x1ConvLUS, self).__init__() # Sample a random orthonormal matrix to initialize weights - W = torch.qr(torch.FloatTensor(c, c).normal_())[0] + W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_()) # Ensure determinant is 1.0 not -1.0 if torch.det(W) < 0: W[:, 0] = -1 * W[:, 0] diff --git a/nemo/core/optim/radam.py b/nemo/core/optim/radam.py index 40fcf5e827b0..7571adc9be70 100644 --- a/nemo/core/optim/radam.py +++ b/nemo/core/optim/radam.py @@ -81,7 +81,7 @@ def step(self, closure=None): exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(1 - beta1, grad) state['step'] += 1 From c5649d0dc62b736841251816a748fa8852192378 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2022 23:50:07 -0800 Subject: [PATCH 39/43] Fixing more warnings 2 Signed-off-by: Boris Fomitchev --- nemo/core/optim/radam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/core/optim/radam.py b/nemo/core/optim/radam.py index 7571adc9be70..69cfab4bf858 100644 --- a/nemo/core/optim/radam.py +++ b/nemo/core/optim/radam.py @@ -81,8 +81,8 @@ def step(self, closure=None): exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2)) + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) state['step'] += 1 buffered = self.buffer[int(state['step'] % 10)] From d3d05292cef5f9d8da8c23670564ac5ea0b1b89c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 16 Nov 2022 17:35:08 -0800 Subject: [PATCH 40/43] Code review fixes Signed-off-by: Boris Fomitchev --- examples/tts/radtts.py | 12 ------------ nemo/collections/tts/modules/submodules.py | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py index ca63e1efb2fa..7260e8d9907f 100644 --- a/examples/tts/radtts.py +++ b/examples/tts/radtts.py @@ -12,11 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import signal -import sys -import time - import pytorch_lightning as pl from nemo.collections.common.callbacks import LogEpochTimeCallback @@ -26,12 +21,6 @@ from nemo.utils.exp_manager import exp_manager -def handle_pdb(sig, frame): - import pdb - - pdb.Pdb().set_trace(frame) - - def freeze(model): for p in model.parameters(): p.requires_grad = False @@ -70,7 +59,6 @@ def prepare_model_weights(model, unfreeze_modules): @hydra_runner(config_path="conf", config_name="rad-tts_dec") def main(cfg): - signal.signal(signal.SIGUSR1, handle_pdb) trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get('exp_manager', None)) model = RadTTSModel(cfg=cfg.model, trainer=trainer) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index ca6b6b7581fe..e61b9b224885 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -305,7 +305,7 @@ def __init__(self, c): self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False) # Sample a random orthonormal matrix to initialize weights - W = torch.qr(torch.FloatTensor(c, c).normal_())[0] + W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0] # Ensure determinant is 1.0 not -1.0 if torch.det(W) < 0: From 79535a0d5226ed67e02cc62d31b83835a7bd0dcc Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 16 Nov 2022 22:39:46 -0800 Subject: [PATCH 41/43] Improving TS check Signed-off-by: Boris Fomitchev --- nemo/core/classes/exportable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 60099d921f59..5eecfbeca1d3 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -167,7 +167,7 @@ def _export( if verbose: logging.info(f"JIT code:\n{jitted_model.code}") jitted_model.save(output) - assert os.path.exists(output) + jitted_model = torch.jit.load(output) if check_trace: verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) From af8bc14d1c95e43fe3fcc8e1b9fce091db5ca18c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 17 Nov 2022 02:34:31 -0800 Subject: [PATCH 42/43] Addressing code review comments, optimizing script Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 24 ++++++++-------- nemo/collections/tts/modules/radtts.py | 39 ++------------------------ nemo/core/classes/exportable.py | 2 +- nemo/utils/export_utils.py | 14 ++++++++- 4 files changed, 28 insertions(+), 51 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index cea745a2a6a3..6a1ee56753a2 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,16 +123,22 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not torch.jit.is_scripting(): + if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - ret, _ = self.bilstm(seq) + if hasattr(self.bilstm, 'forward'): + ret, _ = self.bilstm.forward(seq) + else: + ret, _ = self.bilstm.forward_1(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not torch.jit.is_scripting(): + if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - ret, _ = self.bilstm(seq) + if hasattr(self.bilstm, 'forward'): + ret, _ = self.bilstm.forward(seq) + elif hasattr(self.bilstm, 'forward_1'): + ret, _ = self.bilstm.forward_1(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export @@ -155,7 +161,6 @@ def __init__( p_dropout=0.1, use_partial_padding=False, norm_fn=None, - lstm_norm_fn="spectral", ): super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) self.convolutions = nn.ModuleList() @@ -212,12 +217,8 @@ def forward(self, context: Tensor, lens: Tensor) -> Tensor: return context -def getRadTTSEncoder( - encoder_n_convolutions=3, - encoder_embedding_dim=512, - encoder_kernel_size=5, - norm_fn=MaskedInstanceNorm1d, - lstm_norm_fn=None, +def get_radtts_encoder( + encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d, ): return ConvLSTMLinear( in_dim=encoder_embedding_dim, @@ -227,7 +228,6 @@ def getRadTTSEncoder( p_dropout=0.5, use_partial_padding=True, norm_fn=norm_fn, - lstm_norm_fn=lstm_norm_fn, ) diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 35c443acde58..dca0f0ede62c 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -30,7 +30,7 @@ Invertible1x1ConvLUS, LinearNorm, get_mask_from_lengths, - getRadTTSEncoder, + get_radtts_encoder, ) from nemo.core.classes import Exportable, NeuralModule from nemo.core.neural_types.elements import Index, LengthsType, MelSpectrogramType, TokenDurationType, TokenIndex @@ -143,8 +143,6 @@ def __init__( n_flows, n_conv_layers_per_step, n_mel_channels, - n_hidden, - mel_encoder_n_hidden, dummy_speaker_embedding, n_early_size, n_early_every, @@ -182,7 +180,7 @@ def __init__( self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim) self.embedding = torch.nn.Embedding(n_text, n_text_dim) self.flows = torch.nn.ModuleList() - self.encoder = getRadTTSEncoder(encoder_embedding_dim=n_text_dim, lstm_norm_fn=text_encoder_lstm_norm) + self.encoder = get_radtts_encoder(encoder_embedding_dim=n_text_dim) self.dummy_speaker_embedding = dummy_speaker_embedding self.learn_alignments = learn_alignments self.affine_activation = affine_activation @@ -191,12 +189,10 @@ def __init__( self.use_context_lstm = bool(use_context_lstm) self.context_lstm_norm = context_lstm_norm self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy - # self.length_regulator = LengthRegulator() self.use_first_order_features = bool(use_first_order_features) self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias'] self.ap_pred_log_f0 = ap_pred_log_f0 self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias'] - self.prepared_for_export = False if 'atn' in include_modules or 'dec' in include_modules: if self.learn_alignments: @@ -215,12 +211,6 @@ def __init__( n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim n_in_context_lstm *= n_group_size n_in_context_lstm += self.n_speaker_dim - - n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim - n_context_hidden = n_context_hidden * n_group_size / 2 - n_context_hidden = self.n_speaker_dim + n_context_hidden - n_context_hidden = int(n_context_hidden) - n_flowstep_cond_dims = self.n_speaker_dim + n_text_dim * n_group_size self.context_lstm = BiLSTM( @@ -476,7 +466,6 @@ def forward( # sometimes referred to as the "squeeze" operation # invert this by calling self.fold(mel_or_z) mel = self.unfold(mel.unsqueeze(-1)) - z_out = [] # where context is folded # mask f0 in case values are interpolated context_w_spkvec = self.preprocess_context( @@ -601,8 +590,6 @@ def infer( voiced_mask=None, ): - # print ("Text, lens: ", text.shape, in_lens.shape) - batch_size = text.shape[0] n_tokens = text.shape[1] spk_vec = self.encode_speaker(speaker_id) @@ -613,7 +600,6 @@ def infer( spk_vec_text = self.encode_speaker(speaker_id_text) spk_vec_attributes = self.encode_speaker(speaker_id_attributes) txt_enc, _ = self.encode_text(text, in_lens) - # print("txt_enc: ", txt_enc.shape) if dur is None: # get token durations @@ -624,9 +610,7 @@ def infer( dur = dur[:, 0] dur = dur.clamp(0, token_duration_max) - # get attributes f0, energy, vpred, etc) txt_enc_time_expanded, out_lens = regulate_len(dur, txt_enc.transpose(1, 2), pace) - # print ("txt_enc_time_expanded, out_lens, dur: ", txt_enc_time_expanded.shape, out_lens, dur) n_groups = torch.div(out_lens, self.n_group_size, rounding_mode='floor') max_out_len = torch.max(out_lens) @@ -669,7 +653,6 @@ def infer( # may lead to mismatched lengths # FIXME: use replication pad (energy_avg, f0) = pad_energy_avg_and_f0(energy_avg, f0, max_out_len) - # print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) context_w_spkvec = self.preprocess_context( txt_enc_time_expanded, spk_vec, out_lens, (f0 + f0_bias) * voiced_mask, energy_avg @@ -679,7 +662,6 @@ def infer( # map from z sample to data num_steps_to_exit = len(self.exit_steps) - # mel = residual[:, num_steps_to_exit * self.n_early_size :] remaining_residual, mel = torch.tensor_split(residual, [num_steps_to_exit * self.n_early_size,], dim=1) for i, flow_step in enumerate(reversed(self.flows)): @@ -688,7 +670,6 @@ def infer( if num_steps_to_exit > 0 and curr_step == self.exit_steps[num_steps_to_exit - 1]: # concatenate the next chunk of z num_steps_to_exit = num_steps_to_exit - 1 - # residual_to_add = remaining_residual[:, num_steps_to_exit * self.n_early_size :] remaining_residual, residual_to_add = torch.tensor_split( remaining_residual, [num_steps_to_exit * self.n_early_size,], dim=1 ) @@ -697,8 +678,6 @@ def infer( if self.n_group_size > 1: mel = self.fold(mel) - # print ("mel=", mel.shape, "out_lens=", out_lens, "dur=", dur.shape) - return {'mel': mel, 'out_lens': out_lens, 'dur': dur, 'f0': f0, 'energy_avg': energy_avg} def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None): @@ -784,20 +763,6 @@ def output_types(self): def _prepare_for_export(self, **kwargs): self.remove_norms() super()._prepare_for_export(**kwargs) - if self.prepared_for_export: - return - self.encoder = torch.jit.script(self.encoder) - self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) - if hasattr(self, 'f0_pred_module'): - self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) - if hasattr(self, 'energy_pred_module'): - self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) - if hasattr(self, 'dur_pred_layer'): - self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn) - - if self.use_context_lstm: - self.context_lstm = torch.jit.script(self.context_lstm) - self.prepared_for_export = True def input_example(self, max_batch=1, max_dim=256): """ diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5eecfbeca1d3..e5f7b5231600 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -128,7 +128,7 @@ def _export( # Set module mode with torch.onnx.select_model_mode_for_export( self, training - ), torch.inference_mode(), torch.jit.optimized_execution(True): + ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True): if input_example is None: input_example = self.input_module.input_example() diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 4af379803786..807e696d5cf0 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -15,7 +15,7 @@ import os from contextlib import nullcontext from enum import Enum -from typing import Callable, Dict, Optional, Type +from typing import Callable, Dict, List, Optional, Type import onnx import torch @@ -386,12 +386,22 @@ def replace_modules( return model +def script_module(m: nn.Module): + m1 = torch.jit.script(m) + return m1 + + default_replacements = { "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), } +script_replacements = { + "BiLSTM": script_module, + "ConvLSTMLinear": script_module, +} + def replace_for_export(model: nn.Module) -> nn.Module: """ @@ -405,3 +415,5 @@ def replace_for_export(model: nn.Module) -> nn.Module: """ replace_modules(model, default_Apex_replacements) replace_modules(model, default_replacements) + # This one has to be the last + replace_modules(model, script_replacements) From c335d2f77ad2b78838432ec6b76c1fad07272abf Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 17 Nov 2022 03:48:41 -0800 Subject: [PATCH 43/43] Forced no-autocast Signed-off-by: Boris Fomitchev --- nemo/utils/cast_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 9eb064936ea5..f973a4719e24 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -70,6 +70,6 @@ def __init__(self, mod): self.mod = mod def forward(self, x): - with avoid_float16_autocast_context(): + with torch.cuda.amp.autocast(enabled=False): ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) return ret