Skip to content

Commit

Permalink
[TTS] Fixing RADTTS training - removing view buffer and fixing accura…
Browse files Browse the repository at this point in the history
…cy issue (NVIDIA#5358)

* Fixing RADTTS training - removing view buffer and fixing accuracy issue

Signed-off-by: Boris Fomitchev <[email protected]>

* Addressing code review

Signed-off-by: Boris Fomitchev <[email protected]>

* Addressing code review 2

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed assignment

Signed-off-by: Boris Fomitchev <[email protected]>

* Working script

Signed-off-by: Boris Fomitchev <[email protected]>

* restored flatten_parameters

Signed-off-by: Boris Fomitchev <[email protected]>

* Working bias alias for export

Signed-off-by: Boris Fomitchev <[email protected]>

* Removing unused import

Signed-off-by: Boris Fomitchev <[email protected]>

* Reverting PartialConv

Signed-off-by: Boris Fomitchev <[email protected]>

* Removing flatten_parameters

Signed-off-by: Boris Fomitchev <[email protected]>

* Moving mask updater to GPU

Signed-off-by: Boris Fomitchev <[email protected]>

* Restored norms

Signed-off-by: Boris Fomitchev <[email protected]>

* Restored flatten

Signed-off-by: Boris Fomitchev <[email protected]>

* Moved to sort/unsort

Signed-off-by: Boris Fomitchev <[email protected]>

* Moved to masked norm

Signed-off-by: Boris Fomitchev <[email protected]>

* Turned off cache

Signed-off-by: Boris Fomitchev <[email protected]>

* cleanup

Signed-off-by: Boris Fomitchev <[email protected]>

* Verifying cache not used

Signed-off-by: Boris Fomitchev <[email protected]>

* Removing cache

Signed-off-by: Boris Fomitchev <[email protected]>

* Working autocast export

Signed-off-by: Boris Fomitchev <[email protected]>

* restored e-6

Signed-off-by: Boris Fomitchev <[email protected]>

* Removed some casts around masks, etc

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing some casts

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing in-place  ops

Signed-off-by: Boris Fomitchev <[email protected]>

* fixing grad

Signed-off-by: Boris Fomitchev <[email protected]>

* Small export fixes

Signed-off-by: Boris Fomitchev <[email protected]>

* LGTM cleanup

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed lstm_tensor

Signed-off-by: Boris Fomitchev <[email protected]>

* restored TS check routine

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixed config error

Signed-off-by: Boris Fomitchev <[email protected]>

* reverting some bad optimizations

Signed-off-by: Boris Fomitchev <[email protected]>

* [TTS] add CI test for RADTTS training recipe.

Signed-off-by: Xuesong Yang <[email protected]>

* Addressing code review

Signed-off-by: Boris Fomitchev <[email protected]>

* Removing unused var

Signed-off-by: Boris Fomitchev <[email protected]>

* Adding debug

Signed-off-by: Boris Fomitchev <[email protected]>

* Logging fixes

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing training warnings

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing more warnings

Signed-off-by: Boris Fomitchev <[email protected]>

* Fixing more warnings 2

Signed-off-by: Boris Fomitchev <[email protected]>

* Code review fixes

Signed-off-by: Boris Fomitchev <[email protected]>

* Improving TS check

Signed-off-by: Boris Fomitchev <[email protected]>

* Addressing code review comments, optimizing script

Signed-off-by: Boris Fomitchev <[email protected]>

* Forced no-autocast

Signed-off-by: Boris Fomitchev <[email protected]>

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
3 people authored and Hainan Xu committed Nov 29, 2022
1 parent ea61803 commit 4125ab7
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 254 deletions.
37 changes: 34 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -4206,7 +4206,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
Expand All @@ -4224,14 +4226,41 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
~model.text_normalizer_call_kwargs'
}
}
stage('RADTTS') {
steps {
sh 'python examples/tts/radtts.py \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
model.train_ds.dataloader_params.batch_size=4 \
model.train_ds.dataloader_params.num_workers=0 \
model.validation_ds.dataloader_params.batch_size=4 \
model.validation_ds.dataloader_params.num_workers=0 \
export_dir=/home/TestData/radtts_test \
model.optim.lr=0.0001 \
model.modelConfig.decoder_use_partial_padding=True \
~trainer.check_val_every_n_epoch \
~model.text_normalizer \
~model.text_normalizer_call_kwargs'
}
}
stage('Mixer-TTS') {
steps {
sh 'python examples/tts/mixer_tts.py \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/sup_data \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
Expand All @@ -4250,7 +4279,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 +trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
+trainer.max_epochs=1 \
trainer.strategy=null \
model.train_ds.dataloader_params.batch_size=4 \
model.train_ds.dataloader_params.num_workers=0 \
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/common/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import time

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

# from sacrebleu import corpus_bleu
Expand Down
19 changes: 2 additions & 17 deletions nemo/collections/tts/models/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy
from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.submodules import PartialConv1d
from nemo.core.classes import Exportable
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import Index, MelSpectrogramType, TokenIndex
Expand Down Expand Up @@ -159,7 +158,7 @@ def training_step(self, batch, batch_idx):
loss_outputs['binarization_loss'] = (binarization_loss, 1.0)

for k, (v, w) in loss_outputs.items():
self.log("train/" + k, loss_outputs[k][0])
self.log("train/" + k, loss_outputs[k][0], on_step=True)

return {'loss': loss}

Expand Down Expand Up @@ -229,7 +228,7 @@ def validation_epoch_end(self, outputs):

for k, v in loss_outputs.items():
if k != "binarization_loss":
self.log("val/" + k, loss_outputs[k][0])
self.log("val/" + k, loss_outputs[k][0], sync_dist=True, on_epoch=True)

attn = outputs[0]["attn"]
attn_soft = outputs[0]["attn_soft"]
Expand Down Expand Up @@ -407,17 +406,3 @@ def output_module(self):

def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes):
return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes)

def get_export_subnet(self, subnet=None):
return self.model.get_export_subnet(subnet)

def _prepare_for_export(self, **kwargs):
"""
Override this method to prepare module for export. This is in-place operation.
Base version does common necessary module replacements (Apex etc)
"""
PartialConv1d.forward = PartialConv1d.forward_no_cache
super()._prepare_for_export(**kwargs)

def _export_teardown(self):
PartialConv1d.forward = PartialConv1d.forward_with_cache
95 changes: 28 additions & 67 deletions nemo/collections/tts/modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
piecewise_linear_transform,
unbounded_piecewise_quadratic_transform,
)
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d


@torch.jit.script
Expand All @@ -44,7 +44,7 @@ def get_mask_from_lengths_and_val(lengths, val):
max_len = val.shape[-1]
ids = torch.arange(0, max_len, device=lengths.device)
mask = ids < lengths.unsqueeze(1)
return mask.float()
return mask


@torch.jit.script
Expand Down Expand Up @@ -123,30 +123,31 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals
seq = nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
if not torch.jit.is_scripting():
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
else:
ret, _ = self.bilstm.forward_1(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
if not torch.jit.is_scripting():
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
elif hasattr(self.bilstm, 'forward_1'):
ret, _ = self.bilstm.forward_1(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
lens_sorted, ids_sorted = torch.sort(lens, descending=True)
unsort_ids = torch.zeros_like(ids_sorted)
for i in range(ids_sorted.shape[0]):
unsort_ids[ids_sorted[i]] = i
context = context[ids_sorted]
context, lens_sorted, unsort_ids = sort_tensor(context, lens)
seq = nn.utils.rnn.pack_padded_sequence(
context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
)
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0][unsort_ids]
return self.lstm_sequence(seq)[0][unsort_ids]


class ConvLSTMLinear(BiLSTM):
Expand All @@ -160,14 +161,14 @@ def __init__(
p_dropout=0.1,
use_partial_padding=False,
norm_fn=None,
lstm_norm_fn="spectral",
):
super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
self.out_dim = out_dim
self.convolutions = nn.ModuleList()

if n_layers > 0:
self.dropout = nn.Dropout(p=p_dropout)
self.convolutions = nn.ModuleList()

use_weight_norm = norm_fn is None

for i in range(n_layers):
conv_layer = ConvNorm(
Expand All @@ -178,85 +179,46 @@ def __init__(
padding=int((kernel_size - 1) / 2),
dilation=1,
w_init_gain='relu',
use_weight_norm=False,
use_weight_norm=use_weight_norm,
use_partial_padding=use_partial_padding,
norm_fn=norm_fn,
)
if norm_fn is not None:
print("Applying {} norm to {}".format(norm_fn, conv_layer))
else:
conv_layer = torch.nn.utils.weight_norm(conv_layer.conv)
print("Applying weight norm to {}".format(conv_layer))
self.convolutions.append(conv_layer)

self.dense = None
if out_dim is not None:
self.dense = nn.Linear(n_channels, out_dim)

@torch.jit.export
def conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
context_embedded = []
bs: int = context.shape[0]
b_ind: int = 0
for b_ind in range(bs): # TODO: speed up
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
for conv in self.convolutions:
curr_context = self.dropout(F.relu(conv(curr_context)))
context_embedded.append(curr_context[0].transpose(0, 1))
seq = torch.nn.utils.rnn.pack_sequence(context_embedded, enforce_sorted=enforce_sorted)
return seq

@torch.jit.export
def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
context_embedded = []
bs: int = context.shape[0]
b_ind: int = 0
for b_ind in range(bs): # TODO: speed up
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
for conv in self.convolutions:
curr_context = self.dropout(F.relu(conv(curr_context)))
context_embedded.append(curr_context[0].transpose(0, 1))
ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
return ret

@torch.jit.export
def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
mask = get_mask_from_lengths_and_val(lens, context)
mask = mask.unsqueeze(1)
mask = mask.to(dtype=context.dtype).unsqueeze(1)
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context, mask)))
context = torch.mul(context, mask)

context = context.transpose(1, 2)
seq = torch.nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
return seq

def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor:
if lens is None:
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context)))
context = context.transpose(1, 2)
context, _ = self.bilstm(context)
else:
# borisf : does not match ADLR (values, lengths)
# seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
# borisf : does match ADLR
seq = self.conv_to_sequence(context, lens, enforce_sorted=False)
context, _ = self.lstm_sequence(seq)
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
context, lens, unsort_ids = sort_tensor(context, lens)
seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
context, _ = self.lstm_sequence(seq)
context = context[unsort_ids]

if self.dense is not None:
context = self.dense(context).permute(0, 2, 1)

return context


def getRadTTSEncoder(
encoder_n_convolutions=3,
encoder_embedding_dim=512,
encoder_kernel_size=5,
norm_fn=nn.BatchNorm1d,
lstm_norm_fn=None,
def get_radtts_encoder(
encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d,
):
return ConvLSTMLinear(
in_dim=encoder_embedding_dim,
Expand All @@ -266,15 +228,14 @@ def getRadTTSEncoder(
p_dropout=0.5,
use_partial_padding=True,
norm_fn=norm_fn,
lstm_norm_fn=lstm_norm_fn,
)


class Invertible1x1ConvLUS(torch.nn.Module):
def __init__(self, c):
super(Invertible1x1ConvLUS, self).__init__()
# Sample a random orthonormal matrix to initialize weights
W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:, 0] = -1 * W[:, 0]
Expand Down
Loading

0 comments on commit 4125ab7

Please sign in to comment.