[TTS] Fixing RADTTS training - removing view buffer and fixing accuracy issue #5358

Merged
58 commits merged into main from borisfom:fix-radtts-training on Nov 17, 2022
Changes from all commits (58 commits)
73f5be2
Fixing RADTTS training - removing view buffer and fixing accuracy issue
borisfom Nov 8, 2022
fe08120
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 9, 2022
cd2fbbc
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 9, 2022
44a735c
Addressing code review
borisfom Nov 9, 2022
f6d2f2c
Addressing code review 2
borisfom Nov 9, 2022
835dd4d
Fixed assignment
borisfom Nov 10, 2022
da91270
Working script
borisfom Nov 10, 2022
d48393c
restored flatten_parameters
borisfom Nov 10, 2022
378c3da
Working bias alias for export
borisfom Nov 11, 2022
dedc9db
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 11, 2022
4007c2a
Removing unused import
borisfom Nov 11, 2022
45f2d1a
Reverting PartialConv
borisfom Nov 11, 2022
406a0ba
Removing flatten_parameters
borisfom Nov 11, 2022
7bbc05a
Moving mask updater to GPU
borisfom Nov 11, 2022
7ce3881
Restored norms
borisfom Nov 11, 2022
864f9cd
Restored flatten
borisfom Nov 11, 2022
9c91f92
Moved to sort/unsort
borisfom Nov 11, 2022
89d47d8
Moved to masked norm
borisfom Nov 12, 2022
71a9684
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 12, 2022
f6c6c0c
Turned off cache
borisfom Nov 12, 2022
5febfdf
Merge branch 'noscript' into fix-radtts-training
borisfom Nov 12, 2022
7890227
cleanup
borisfom Nov 12, 2022
0c6c6fb
Verifying cache not used
borisfom Nov 12, 2022
f8f9699
Removing cache
borisfom Nov 12, 2022
495f573
Working autocast export
borisfom Nov 12, 2022
5765d3d
restored e-6
borisfom Nov 13, 2022
65427fc
Removed some casts around masks, etc
borisfom Nov 14, 2022
8e547fd
Fixing some casts
borisfom Nov 14, 2022
47dab07
Fixing in-place ops
borisfom Nov 14, 2022
0324f46
fixing grad
borisfom Nov 14, 2022
cb8858a
Small export fixes
borisfom Nov 15, 2022
96c5dec
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
c17ba73
LGTM cleanup
borisfom Nov 15, 2022
16c0b72
Fixed lstm_tensor
borisfom Nov 15, 2022
32a1943
restored TS check routine
borisfom Nov 15, 2022
14561be
Fixed config error
borisfom Nov 15, 2022
77404b9
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
bb834da
reverting some bad optimizations
borisfom Nov 15, 2022
c62d4f1
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
b95fb2c
[TTS] add CI test for RADTTS training recipe.
XuesongYang Nov 10, 2022
c620de4
Addressing code review
borisfom Nov 15, 2022
0480195
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
cb6c459
Removing unused var
borisfom Nov 16, 2022
6f112e7
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 16, 2022
87179ab
Adding debug
borisfom Nov 16, 2022
c9d9e4f
Logging fixes
borisfom Nov 16, 2022
55e95b1
Fixing training warnings
borisfom Nov 16, 2022
647640c
Fixing more warnings
borisfom Nov 16, 2022
c5649d0
Fixing more warnings 2
borisfom Nov 16, 2022
b38844f
Merge branch 'main' into fix-radtts-training
okuchaiev Nov 16, 2022
c2a8b77
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
9b3fd7e
Merge branch 'fix-radtts-training' of github.com:borisfom/NeMo into f…
borisfom Nov 17, 2022
d3d0529
Code review fixes
borisfom Nov 17, 2022
79535a0
Improving TS check
borisfom Nov 17, 2022
4cb79d8
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
af8bc14
Addressing code review comments, optimizing script
borisfom Nov 17, 2022
69f0987
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
c335d2f
Forced no-autocast
borisfom Nov 17, 2022
37 changes: 34 additions & 3 deletions Jenkinsfile
@@ -4206,7 +4206,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
@@ -4224,14 +4226,41 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
~model.text_normalizer_call_kwargs'
}
}
stage('RADTTS') {
steps {
sh 'python examples/tts/radtts.py \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
model.train_ds.dataloader_params.batch_size=4 \
model.train_ds.dataloader_params.num_workers=0 \
model.validation_ds.dataloader_params.batch_size=4 \
model.validation_ds.dataloader_params.num_workers=0 \
export_dir=/home/TestData/radtts_test \
model.optim.lr=0.0001 \
model.modelConfig.decoder_use_partial_padding=True \
~trainer.check_val_every_n_epoch \
~model.text_normalizer \
~model.text_normalizer_call_kwargs'
}
}
stage('Mixer-TTS') {
steps {
sh 'python examples/tts/mixer_tts.py \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
sup_data_path=/home/TestData/an4_dataset/sup_data \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
trainer.max_epochs=1 \
trainer.strategy=null \
model.pitch_mean=212.35873413085938 \
model.pitch_std=68.52806091308594 \
@@ -4250,7 +4279,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 +trainer.max_epochs=1 \
+trainer.limit_train_batches=1 \
+trainer.limit_val_batches=1 \
+trainer.max_epochs=1 \
trainer.strategy=null \
model.train_ds.dataloader_params.batch_size=4 \
model.train_ds.dataloader_params.num_workers=0 \
2 changes: 1 addition & 1 deletion nemo/collections/common/callbacks/callbacks.py
@@ -13,7 +13,7 @@
# limitations under the License.
import time

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

# from sacrebleu import corpus_bleu
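
The single-line change above tracks the current PyTorch Lightning package layout, where Callback is imported from pytorch_lightning.callbacks rather than the deprecated pytorch_lightning.callbacks.base module. A minimal, version-tolerant sketch of that import (an illustration, not code from this PR):

try:
    from pytorch_lightning.callbacks import Callback  # current import path
except ImportError:  # older Lightning releases that still ship callbacks.base
    from pytorch_lightning.callbacks.base import Callback
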
19 changes: 2 additions & 17 deletions nemo/collections/tts/models/radtts.py
@@ -27,7 +27,6 @@
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy
from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.submodules import PartialConv1d
from nemo.core.classes import Exportable
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import Index, MelSpectrogramType, TokenIndex
@@ -159,7 +158,7 @@ def training_step(self, batch, batch_idx):
loss_outputs['binarization_loss'] = (binarization_loss, 1.0)

for k, (v, w) in loss_outputs.items():
self.log("train/" + k, loss_outputs[k][0])
self.log("train/" + k, loss_outputs[k][0], on_step=True)

return {'loss': loss}

@@ -229,7 +228,7 @@ def validation_epoch_end(self, outputs):

for k, v in loss_outputs.items():
if k != "binarization_loss":
self.log("val/" + k, loss_outputs[k][0])
self.log("val/" + k, loss_outputs[k][0], sync_dist=True, on_epoch=True)

attn = outputs[0]["attn"]
attn_soft = outputs[0]["attn_soft"]
@@ -407,17 +406,3 @@ def output_module(self):

def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes):
return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes)

def get_export_subnet(self, subnet=None):
return self.model.get_export_subnet(subnet)

def _prepare_for_export(self, **kwargs):
"""
Override this method to prepare module for export. This is in-place operation.
Base version does common necessary module replacements (Apex etc)
"""
PartialConv1d.forward = PartialConv1d.forward_no_cache
super()._prepare_for_export(**kwargs)

def _export_teardown(self):
PartialConv1d.forward = PartialConv1d.forward_with_cache
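
For context on the self.log changes above: on_step=True makes the training losses appear as a per-step curve, while sync_dist=True together with on_epoch=True averages validation losses across DDP ranks before logging them once per epoch. A minimal standalone sketch of those flags in a generic LightningModule (illustrative only; the module and names here are invented, not taken from NeMo):

import pytorch_lightning as pl
import torch


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # Logged every optimizer step -> fine-grained training curve.
        self.log("train/loss", loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss = self.layer(batch).mean()
        # Reduced across distributed ranks and logged once at epoch end.
        self.log("val/loss", val_loss, sync_dist=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
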
95 changes: 28 additions & 67 deletions nemo/collections/tts/modules/common.py
@@ -29,7 +29,7 @@
piecewise_linear_transform,
unbounded_piecewise_quadratic_transform,
)
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d


@torch.jit.script
@@ -44,7 +44,7 @@ def get_mask_from_lengths_and_val(lengths, val):
max_len = val.shape[-1]
ids = torch.arange(0, max_len, device=lengths.device)
mask = ids < lengths.unsqueeze(1)
return mask.float()
return mask


@torch.jit.script
@@ -123,30 +123,31 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals
seq = nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
if not torch.jit.is_scripting():
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
else:
ret, _ = self.bilstm.forward_1(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
if not torch.jit.is_scripting():
if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
if hasattr(self.bilstm, 'forward'):
ret, _ = self.bilstm.forward(seq)
elif hasattr(self.bilstm, 'forward_1'):
ret, _ = self.bilstm.forward_1(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
lens_sorted, ids_sorted = torch.sort(lens, descending=True)
unsort_ids = torch.zeros_like(ids_sorted)
for i in range(ids_sorted.shape[0]):
unsort_ids[ids_sorted[i]] = i
context = context[ids_sorted]
context, lens_sorted, unsort_ids = sort_tensor(context, lens)
seq = nn.utils.rnn.pack_padded_sequence(
context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
)
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0][unsort_ids]
return self.lstm_sequence(seq)[0][unsort_ids]


class ConvLSTMLinear(BiLSTM):
@@ -160,14 +161,14 @@ def __init__(
p_dropout=0.1,
use_partial_padding=False,
norm_fn=None,
lstm_norm_fn="spectral",
):
super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
self.out_dim = out_dim
self.convolutions = nn.ModuleList()

if n_layers > 0:
self.dropout = nn.Dropout(p=p_dropout)
self.convolutions = nn.ModuleList()

use_weight_norm = norm_fn is None

for i in range(n_layers):
conv_layer = ConvNorm(
@@ -178,85 +179,46 @@
padding=int((kernel_size - 1) / 2),
dilation=1,
w_init_gain='relu',
use_weight_norm=False,
use_weight_norm=use_weight_norm,
use_partial_padding=use_partial_padding,
norm_fn=norm_fn,
)
if norm_fn is not None:
print("Applying {} norm to {}".format(norm_fn, conv_layer))
else:
conv_layer = torch.nn.utils.weight_norm(conv_layer.conv)
print("Applying weight norm to {}".format(conv_layer))
self.convolutions.append(conv_layer)

self.dense = None
if out_dim is not None:
self.dense = nn.Linear(n_channels, out_dim)

@torch.jit.export
def conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
context_embedded = []
bs: int = context.shape[0]
b_ind: int = 0
for b_ind in range(bs): # TODO: speed up
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
for conv in self.convolutions:
curr_context = self.dropout(F.relu(conv(curr_context)))
context_embedded.append(curr_context[0].transpose(0, 1))
seq = torch.nn.utils.rnn.pack_sequence(context_embedded, enforce_sorted=enforce_sorted)
return seq

@torch.jit.export
def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
context_embedded = []
bs: int = context.shape[0]
b_ind: int = 0
for b_ind in range(bs): # TODO: speed up
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
for conv in self.convolutions:
curr_context = self.dropout(F.relu(conv(curr_context)))
context_embedded.append(curr_context[0].transpose(0, 1))
ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
return ret

@torch.jit.export
def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
mask = get_mask_from_lengths_and_val(lens, context)
mask = mask.unsqueeze(1)
mask = mask.to(dtype=context.dtype).unsqueeze(1)
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context, mask)))
context = torch.mul(context, mask)

context = context.transpose(1, 2)
seq = torch.nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
return seq

def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor:
if lens is None:
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context)))
context = context.transpose(1, 2)
context, _ = self.bilstm(context)
else:
# borisf : does not match ADLR (values, lengths)
# seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
# borisf : does match ADLR
seq = self.conv_to_sequence(context, lens, enforce_sorted=False)
context, _ = self.lstm_sequence(seq)
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
context, lens, unsort_ids = sort_tensor(context, lens)
seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
context, _ = self.lstm_sequence(seq)
context = context[unsort_ids]

if self.dense is not None:
context = self.dense(context).permute(0, 2, 1)

return context


def getRadTTSEncoder(
encoder_n_convolutions=3,
encoder_embedding_dim=512,
encoder_kernel_size=5,
norm_fn=nn.BatchNorm1d,
lstm_norm_fn=None,
def get_radtts_encoder(
encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d,
):
return ConvLSTMLinear(
in_dim=encoder_embedding_dim,
@@ -266,15 +228,14 @@ def getRadTTSEncoder(
p_dropout=0.5,
use_partial_padding=True,
norm_fn=norm_fn,
lstm_norm_fn=lstm_norm_fn,
)


class Invertible1x1ConvLUS(torch.nn.Module):
def __init__(self, c):
super(Invertible1x1ConvLUS, self).__init__()
# Sample a random orthonormal matrix to initialize weights
W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:, 0] = -1 * W[:, 0]
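
Two of the edits above are general PyTorch modernizations rather than RADTTS-specific logic: the length mask now stays boolean instead of being cast to float, and the orthonormal weight initialization uses torch.linalg.qr in place of the deprecated torch.qr. A small standalone sketch of both patterns, assuming only a recent PyTorch release (illustrative, not code from this PR):

import torch

# Boolean padding mask in the spirit of get_mask_from_lengths_and_val:
# True marks valid time steps, False marks padding.
lengths = torch.tensor([3, 5])
values = torch.zeros(2, 8, 5)          # (batch, channels, max_len)
ids = torch.arange(values.shape[-1])
mask = ids < lengths.unsqueeze(1)      # shape (2, 5), dtype torch.bool

# Orthonormal initialization via torch.linalg.qr, keeping det(W) = +1
# as the Invertible1x1ConvLUS constructor does.
c = 4
W, _ = torch.linalg.qr(torch.randn(c, c))
if torch.det(W) < 0:
    W[:, 0] = -W[:, 0]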