Merge pull request #444 from blisc/tts_new
Update Tacotron2 Modules for v10
okuchaiev authored Mar 3, 2020
2 parents 2cf9fd7 + 21e4ea8 commit bf9c2f0
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -126,6 +126,8 @@ files, along with unit tests, examples and tutorials
 ### Fixed
 - Critical fix of the training action on CPU
   ([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia
+- Fixed issue in Tacotron 2 prenet
+  ([PR #444](https://github.com/NVIDIA/NeMo/pull/444)) - @blisc

 ### Removed
 - gradient_predivide_factor arg of train() now has no effect
4 changes: 1 addition & 3 deletions nemo/collections/tts/parts/layers.py
@@ -7,9 +7,7 @@ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
         super(LinearNorm, self).__init__()
         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

-        torch.nn.init.xavier_uniform_(
-            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain),
-        )
+        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

     def forward(self, x):
         return self.linear_layer(x)
13 changes: 6 additions & 7 deletions nemo/collections/tts/parts/tacotron2.py
@@ -56,8 +56,7 @@ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
         ------
         query: decoder output (batch, n_mel_channels * n_frames_per_step)
         processed_memory: processed encoder outputs (B, T_in, attention_dim)
-        attention_weights_cat: cumulative and prev. att weights
-            (B, 2, max_time)
+        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
         RETURNS
         -------
         alignment (batch, max_time)

@@ -108,12 +107,12 @@ def forward(self, x, inference=False):
         if inference:
             for linear in self.layers:
                 x = F.relu(linear(x))
                 x0 = x[0].unsqueeze(0)
-                mask = Variable(torch.bernoulli(x0.data.new(x0.data.size()).fill_(0.5)))
+                mask = Variable(torch.bernoulli(x0.data.new(x0.data.size()).fill_(1 - self.p_dropout)))
                 mask = mask.expand(x.size(0), x.size(1))
-                x = x * mask * 2
+                x = x * mask * 1 / (1 - self.p_dropout)
         else:
             for linear in self.layers:
-                x = F.dropout(F.relu(linear(x)), p=0.0, training=True)
+                x = F.dropout(F.relu(linear(x)), p=self.p_dropout, training=True)
         return x
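
For context, a minimal, self-contained sketch of the prenet behavior after this fix. The class name, constructor signature, and the use of torch.full_like in place of the deprecated Variable/.data idioms are illustrative assumptions, not the exact NeMo module. The substance of the change: the Bernoulli keep probability (1 - p_dropout) and the rescaling factor (1 / (1 - p_dropout)) are now both derived from the configured dropout rate instead of being hardcoded for p = 0.5, and the training path uses p_dropout instead of the p=0.0 that had effectively disabled prenet dropout.

    import torch
    import torch.nn.functional as F
    from torch import nn

    class PrenetSketch(nn.Module):
        """Simplified prenet with a configurable dropout rate (illustrative)."""

        def __init__(self, in_dim, sizes, p_dropout=0.5):
            super().__init__()
            in_sizes = [in_dim] + sizes[:-1]
            self.p_dropout = p_dropout
            self.layers = nn.ModuleList(
                nn.Linear(in_size, out_size, bias=False)
                for in_size, out_size in zip(in_sizes, sizes)
            )

        def forward(self, x, inference=False):
            if inference:
                for linear in self.layers:
                    x = F.relu(linear(x))
                    # Sample one keep-mask with keep probability 1 - p_dropout
                    # from the first row and share it across the batch,
                    # mirroring the fixed code above.
                    x0 = x[0].unsqueeze(0)
                    mask = torch.bernoulli(torch.full_like(x0, 1 - self.p_dropout))
                    mask = mask.expand(x.size(0), x.size(1))
                    # Inverted-dropout rescaling; previously hardcoded as "* 2",
                    # which is correct only for p_dropout = 0.5.
                    x = x * mask / (1 - self.p_dropout)
            else:
                # training=True keeps dropout active on this path; before the
                # fix, p was 0.0 here, silently disabling prenet dropout.
                for linear in self.layers:
                    x = F.dropout(F.relu(linear(x)), p=self.p_dropout, training=True)
            return x

    # Example usage (hypothetical dimensions):
    prenet = PrenetSketch(in_dim=80, sizes=[256, 256], p_dropout=0.5)
    out = prenet(torch.randn(4, 80), inference=True)  # dropout stays active

Since E[mask] = 1 - p_dropout, rescaling by 1 / (1 - p_dropout) keeps the expected activation unchanged. Tacotron 2 deliberately keeps prenet dropout active at inference, which is why both branches drop units.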


@@ -177,7 +176,7 @@ def __init__(

     def forward(self, x):
         for i in range(len(self.convolutions) - 1):
-            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
+            x = F.dropout(torch.tanh(self.convolutions[i](x)), self.p_dropout, self.training)
         x = F.dropout(self.convolutions[-1](x), self.p_dropout, self.training)

         return x
@@ -265,7 +264,7 @@ def __init__(
         self.p_decoder_dropout = p_decoder_dropout
         self.early_stopping = early_stopping

-        self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim], prenet_p_dropout,)
+        self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim], prenet_p_dropout)

         self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
5 changes: 3 additions & 2 deletions nemo/collections/tts/tacotron2_modules.py
@@ -5,9 +5,10 @@
 from torch import nn
 from torch.nn.functional import pad

-from .parts.layers import get_mask_from_lengths
-from .parts.tacotron2 import Decoder, Encoder, Postnet
+from nemo import logging
 from nemo.backends.pytorch.nm import LossNM, NonTrainableNM, TrainableNM
+from nemo.collections.tts.parts.layers import get_mask_from_lengths
+from nemo.collections.tts.parts.tacotron2 import Decoder, Encoder, Postnet
 from nemo.core.neural_types import *
 from nemo.utils.decorators import add_port_docs
