[TTS] Fixing RADTTS training - removing view buffer and fixing accuracy issue #5358

Merged: 58 commits, Nov 17, 2022. The diff below shows changes from 4 commits.

Commits
73f5be2
Fixing RADTTS training - removing view buffer and fixing accuracy issue
borisfom Nov 8, 2022
fe08120
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 9, 2022
cd2fbbc
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 9, 2022
44a735c
Addressing code review
borisfom Nov 9, 2022
f6d2f2c
Addressing code review 2
borisfom Nov 9, 2022
835dd4d
Fixed assignment
borisfom Nov 10, 2022
da91270
Working script
borisfom Nov 10, 2022
d48393c
restored flatten_parameters
borisfom Nov 10, 2022
378c3da
Working bias alias for export
borisfom Nov 11, 2022
dedc9db
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 11, 2022
4007c2a
Removing unused import
borisfom Nov 11, 2022
45f2d1a
Reverting PartialConv
borisfom Nov 11, 2022
406a0ba
Removing flatten_parameters
borisfom Nov 11, 2022
7bbc05a
Moving mask updater to GPU
borisfom Nov 11, 2022
7ce3881
Restored norms
borisfom Nov 11, 2022
864f9cd
Restored flatten
borisfom Nov 11, 2022
9c91f92
Moved to sort/unsort
borisfom Nov 11, 2022
89d47d8
Moved to masked norm
borisfom Nov 12, 2022
71a9684
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 12, 2022
f6c6c0c
Turned off cache
borisfom Nov 12, 2022
5febfdf
Merge branch 'noscript' into fix-radtts-training
borisfom Nov 12, 2022
7890227
cleanup
borisfom Nov 12, 2022
0c6c6fb
Verifying cache not used
borisfom Nov 12, 2022
f8f9699
Removing cache
borisfom Nov 12, 2022
495f573
Working autocast export
borisfom Nov 12, 2022
5765d3d
restored e-6
borisfom Nov 13, 2022
65427fc
Removed some casts around masks, etc
borisfom Nov 14, 2022
8e547fd
Fixing some casts
borisfom Nov 14, 2022
47dab07
Fixing in-place ops
borisfom Nov 14, 2022
0324f46
fixing grad
borisfom Nov 14, 2022
cb8858a
Small export fixes
borisfom Nov 15, 2022
96c5dec
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
c17ba73
LGTM cleanup
borisfom Nov 15, 2022
16c0b72
Fixed lstm_tensor
borisfom Nov 15, 2022
32a1943
restored TS check routine
borisfom Nov 15, 2022
14561be
Fixed config error
borisfom Nov 15, 2022
77404b9
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
bb834da
reverting some bad optimizations
borisfom Nov 15, 2022
c62d4f1
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
b95fb2c
[TTS] add CI test for RADTTS training recipe.
XuesongYang Nov 10, 2022
c620de4
Addressing code review
borisfom Nov 15, 2022
0480195
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 15, 2022
cb6c459
Removing unused var
borisfom Nov 16, 2022
6f112e7
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 16, 2022
87179ab
Adding debug
borisfom Nov 16, 2022
c9d9e4f
Logging fixes
borisfom Nov 16, 2022
55e95b1
Fixing training warnings
borisfom Nov 16, 2022
647640c
Fixing more warnings
borisfom Nov 16, 2022
c5649d0
Fixing more warnings 2
borisfom Nov 16, 2022
b38844f
Merge branch 'main' into fix-radtts-training
okuchaiev Nov 16, 2022
c2a8b77
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
9b3fd7e
Merge branch 'fix-radtts-training' of github.com:borisfom/NeMo into f…
borisfom Nov 17, 2022
d3d0529
Code review fixes
borisfom Nov 17, 2022
79535a0
Improving TS check
borisfom Nov 17, 2022
4cb79d8
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
af8bc14
Addressing code review comments, optimizing script
borisfom Nov 17, 2022
69f0987
Merge remote-tracking branch 'upstream/main' into fix-radtts-training
borisfom Nov 17, 2022
c335d2f
Forced no-autocast
borisfom Nov 17, 2022
45 changes: 28 additions & 17 deletions nemo/collections/tts/modules/common.py
@@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False
         seq = nn.utils.rnn.pack_padded_sequence(
             context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
         )
-        if not torch.jit.is_scripting():
-            self.bilstm.flatten_parameters()
+        # if not torch.jit.is_scripting():
+        # self.bilstm.flatten_parameters()
         ret, _ = self.bilstm(seq)
         return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

     @torch.jit.export
     def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
-        if not torch.jit.is_scripting():
-            self.bilstm.flatten_parameters()
+        # if not torch.jit.is_scripting():
+        # self.bilstm.flatten_parameters()
         ret, _ = self.bilstm(seq)
         return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

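Note on the hunk above: `flatten_parameters()` only re-packs the cuDNN weights into one contiguous block, and the original code already skipped it when scripting; this PR comments it out of the forward path entirely. A minimal sketch of the usual alternative, assuming a standalone `nn.LSTM` (the variable names and sizes are illustrative, not NeMo code): call it once after the module is placed on its device, outside any traced or scripted method.

```python
import torch
import torch.nn as nn

# Hypothetical standalone example, not part of this PR: compact the cuDNN weights once,
# after the module is on its final device, instead of on every forward pass.
lstm = nn.LSTM(input_size=80, hidden_size=40, batch_first=True, bidirectional=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
lstm = lstm.to(device)
lstm.flatten_parameters()  # no-op on CPU; on GPU avoids the warning about RNN weights not being contiguous

x = torch.randn(2, 17, 80, device=device)
out, _ = lstm(x)
```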
@@ -164,10 +164,10 @@ def __init__(
     ):
         super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
         self.out_dim = out_dim
+        self.convolutions = nn.ModuleList()

         if n_layers > 0:
             self.dropout = nn.Dropout(p=p_dropout)
-            self.convolutions = nn.ModuleList()

         for i in range(n_layers):
             conv_layer = ConvNorm(
@@ -219,7 +219,6 @@ def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
         ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
         return ret

-    @torch.jit.export
     def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
         mask = get_mask_from_lengths_and_val(lens, context)
         mask = mask.unsqueeze(1)
@@ -232,24 +231,36 @@ def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
         )
         return seq

-    def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor:
-        if lens is None:
-            for conv in self.convolutions:
-                context = self.dropout(F.relu(conv(context)))
-            context = context.transpose(1, 2)
-            context, _ = self.bilstm(context)
+    def forward(self, context: Tensor, in_lens: Optional[Tensor]) -> Tensor:
+        if in_lens is None:
+            lens = context.new_ones([context.shape[0]], dtype=torch.int) * context.shape[2]
         else:
-            # borisf : does not match ADLR (values, lengths)
-            # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
-            # borisf : does match ADLR
-            seq = self.conv_to_sequence(context, lens, enforce_sorted=False)
-            context, _ = self.lstm_sequence(seq)
+            lens = in_lens
+        # borisf : does not match ADLR (values, lengths)
+        # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False)
+        # borisf : does match ADLR
+        seq = self.conv_to_sequence(context, lens, enforce_sorted=False)
+        context, _ = self.lstm_sequence(seq)

         if self.dense is not None:
             context = self.dense(context).permute(0, 2, 1)

         return context

+    def script(self):
+        traced = nn.ModuleList()
+        for conv in self.convolutions:
+            if hasattr(conv, 'conv'):
+                w = conv.conv.weight
+            else:
+                w = conv.weight
+            s = w.shape
+            rand_in = torch.ones(1, s[1], s[2], dtype=w.dtype, device=w.device)
+            # , torch.ones(1, 1, s[2], dtype=w.dtype, device=w.device))
+            traced.append(torch.jit.trace_module(conv, {"forward": rand_in}))
+        self.convolutions = traced
+        return torch.jit.script(self)


def getRadTTSEncoder(
encoder_n_convolutions=3,
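The new `ConvLSTMLinear.script()` above traces each convolution on a fixed-shape dummy input and then scripts the surrounding module, so the data-dependent Python in the rest of the class survives export. A self-contained sketch of that trace-then-script pattern (`TinyConvStack`, its channel sizes, and the dummy length are invented for illustration; only `torch.jit.trace_module` and `torch.jit.script` mirror the code above):

```python
import torch
import torch.nn as nn


class TinyConvStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.convolutions = nn.ModuleList([nn.Conv1d(8, 8, kernel_size=3, padding=1) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convolutions:
            x = torch.relu(conv(x))
        return x

    def script(self):
        # Trace each conv with a dummy input shaped from its own weight, then script the wrapper.
        traced = nn.ModuleList()
        for conv in self.convolutions:
            w = conv.weight
            dummy = torch.ones(1, w.shape[1], 16, dtype=w.dtype, device=w.device)
            traced.append(torch.jit.trace_module(conv, {"forward": dummy}))
        self.convolutions = traced
        return torch.jit.script(self)


module = TinyConvStack().script()
print(module(torch.randn(1, 8, 16)).shape)  # torch.Size([1, 8, 16])
```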
10 changes: 5 additions & 5 deletions nemo/collections/tts/modules/radtts.py
@@ -784,11 +784,11 @@ def _prepare_for_export(self, **kwargs):
         PartialConv1d.forward = PartialConv1d.forward_no_cache
         self.remove_norms()
         super()._prepare_for_export(**kwargs)
-        self.encoder = torch.jit.script(self.encoder)
-        self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn)
-        self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn)
-        self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn)
-        self.dur_pred_layer.feat_pred_fn = torch.jit.script(self.dur_pred_layer.feat_pred_fn)
+        self.encoder = self.encoder.script()
+        self.v_pred_module.feat_pred_fn = self.v_pred_module.feat_pred_fn.script()
+        self.f0_pred_module.feat_pred_fn = self.f0_pred_module.feat_pred_fn.script()
+        self.energy_pred_module.feat_pred_fn = self.energy_pred_module.feat_pred_fn.script()
+        self.dur_pred_layer.feat_pred_fn = self.dur_pred_layer.feat_pred_fn.script()

         if self.use_context_lstm:
             self.context_lstm = torch.jit.script(self.context_lstm)
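With this change `_prepare_for_export` asks each sub-network for its own TorchScript form instead of wrapping it in `torch.jit.script()` from the outside, which lets `ConvLSTMLinear.script()` (above) do its per-convolution tracing first. A minimal sketch of that delegation pattern (`ExportableBlock` and `Parent` are hypothetical stand-ins, not NeMo classes):

```python
import torch
import torch.nn as nn


class ExportableBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

    def script(self):
        # Each submodule owns its export recipe (trace, script, or a mix of both).
        return torch.jit.script(self)


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ExportableBlock()

    def _prepare_for_export(self):
        # Delegate instead of calling torch.jit.script(self.encoder) directly.
        self.encoder = self.encoder.script()


parent = Parent()
parent._prepare_for_export()
print(parent.encoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```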
22 changes: 13 additions & 9 deletions nemo/collections/tts/modules/submodules.py
@@ -30,12 +30,11 @@ def __init__(self, *args, **kwargs):
         super(PartialConv1d, self).__init__(*args, **kwargs)
         weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
         self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False)
-        slide_winsize = torch.tensor(self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2])
+        slide_winsize = torch.tensor(
+            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2], requires_grad=False
+        )
         self.register_buffer("slide_winsize", slide_winsize, persistent=False)

-        if self.bias is not None:
-            bias_view = self.bias.view(1, self.out_channels, 1)
-            self.register_buffer('bias_view', bias_view, persistent=False)
         # caching part
         self.last_size = (-1, -1, -1)

@@ -51,6 +50,8 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]):
             mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device)
         else:
             mask = mask_in
+        input = torch.mul(input, mask)
+
         update_mask = F.conv1d(
             mask,
             self.weight_maskUpdater,
@@ -60,19 +61,22 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]):
             dilation=self.dilation,
             groups=1,
         )
-        # for mixed precision training, change 1e-8 to 1e-6
-        mask_ratio = self.slide_winsize / (update_mask + 1e-6)
+
+        update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize)
+        mask_ratio = self.slide_winsize / (update_mask_filled)
+
         update_mask = torch.clamp(update_mask, 0, 1)
-        mask_ratio = torch.mul(mask_ratio.to(update_mask), update_mask)
-        return torch.mul(input, mask), mask_ratio, update_mask
+        mask_ratio = torch.mul(mask_ratio, update_mask)
+        return input, mask_ratio, update_mask

     def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor:
         assert len(input.shape) == 3

         raw_out = self._conv_forward(input, self.weight, self.bias)

         if self.bias is not None:
-            output = torch.mul(raw_out - self.bias_view, mask_ratio) + self.bias_view
+            bias_view = self.bias.view(1, self.out_channels, 1)
+            output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
             output = torch.mul(output, update_mask)
         else:
             output = torch.mul(raw_out, mask_ratio)
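Two changes in `PartialConv1d` are worth unpacking: the bias view is no longer registered as a buffer (it is rebuilt from `self.bias` on each call, so exported or restored modules cannot carry a stale copy), and the mask-ratio division replaces the `1e-6` epsilon with `masked_fill`, giving exact ratios wherever the window has support. A small numeric sketch of the `masked_fill` step (the tensor values are invented; only the `masked_fill`/`clamp` sequence mirrors the code above):

```python
import torch

slide_winsize = torch.tensor(3.0)  # kernel covers 3 positions
update_mask = torch.tensor([[[0.0, 1.0, 2.0, 3.0]]])  # valid samples under the kernel at each position

# Replace zeros before dividing, instead of adding a small epsilon to every entry.
filled = torch.masked_fill(update_mask, update_mask == 0, slide_winsize)
mask_ratio = slide_winsize / filled              # tensor([[[1.0000, 3.0000, 1.5000, 1.0000]]])
update_mask = torch.clamp(update_mask, 0, 1)     # tensor([[[0., 1., 1., 1.]]])
mask_ratio = torch.mul(mask_ratio, update_mask)  # tensor([[[0.0000, 3.0000, 1.5000, 1.0000]]])
print(mask_ratio)
```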