[TTS] Update aligner comments
Signed-off-by: Ryan <[email protected]>
rlangman committed Jul 10, 2023
1 parent a7ef1fa commit c67ca7b
Showing 2 changed files with 28 additions and 17 deletions.
43 changes: 27 additions & 16 deletions nemo/collections/tts/modules/aligner.py
@@ -22,7 +22,18 @@
 
 
 class AlignmentEncoder(torch.nn.Module):
-    """Module for alignment text and mel spectrogram. """
+    """
+    Module for alignment text and mel spectrogram.
+
+    Args:
+        n_mel_channels: Dimension of mel spectrogram.
+        n_text_channels: Dimension of text embeddings.
+        n_att_channels: Dimension of model
+        temperature: Temperature to scale distance by.
+            Suggested to be 0.0005 when using dist_type "l2" and 15.0 when using "cosine".
+        condition_types: List of types for nemo.collections.tts.modules.submodules.ConditionalInput.
+        dist_type: Distance type to use for similarity measurement. Supports "l2" and "cosine" distance.
+    """
 
     def __init__(
         self,
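
For reference, a minimal usage sketch of constructing the encoder with the arguments the new docstring documents. The values below are illustrative assumptions; the actual defaults live in the collapsed __init__ above.

    from nemo.collections.tts.modules.aligner import AlignmentEncoder

    aligner = AlignmentEncoder(
        n_mel_channels=80,    # dimension of the mel spectrogram
        n_text_channels=512,  # dimension of the text embeddings
        n_att_channels=80,    # dimension of the attention projections
        temperature=0.0005,   # suggested value for dist_type="l2"
        condition_types=[],   # no extra conditioning (assumed default)
        dist_type="l2",       # "l2" or "cosine"
    )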
@@ -71,40 +82,40 @@ def get_dist(self, keys, queries, mask=None):
         """Calculation of distance matrix.
 
         Args:
-            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
+            queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data).
             keys (torch.tensor): B x C2 x T2 tensor (text data).
             mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries and also can be used
                 for ignoring unnecessary elements from keys in the resulting distance matrix (True = mask element, False = leave unchanged).
         Output:
             dist (torch.tensor): B x T1 x T2 tensor.
         """
-        # B x n_attn_dims x T2
+        # B x C x T2
         keys_enc = self.key_proj(keys)
-        # B x n_attn_dims x T1
+        # B x C x T1
         queries_enc = self.query_proj(queries)
 
         # B x 1 x T1 x T2
-        dist = self.dist_fn(queries=queries_enc, keys=keys_enc)
+        dist = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc)
 
         self._apply_mask(dist, mask, float("inf"))
 
         return dist
 
     @staticmethod
-    def get_euclidean_dist(queries, keys):
-        queries = rearrange(queries, "B C T1 -> B C T1 1")
-        keys = rearrange(keys, "B C T2 -> B C 1 T2")
+    def get_euclidean_dist(queries_enc, keys_enc):
+        queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1")
+        keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2")
         # B x C x T1 x T2
-        distance = (queries - keys) ** 2
+        distance = (queries_enc - keys_enc) ** 2
         # B x 1 x T1 x T2
         l2_dist = distance.sum(axis=1, keepdim=True)
         return l2_dist
 
     @staticmethod
-    def get_cosine_dist(queries, keys):
-        queries = rearrange(queries, "B C T1 -> B C T1 1")
-        keys = rearrange(keys, "B C T2 -> B C 1 T2")
-        cosine_dist = -torch.nn.functional.cosine_similarity(queries, keys, dim=1)
+    def get_cosine_dist(queries_enc, keys_enc):
+        queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1")
+        keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2")
+        cosine_dist = -torch.nn.functional.cosine_similarity(queries_enc, keys_enc, dim=1)
         cosine_dist = rearrange(cosine_dist, "B T1 T2 -> B 1 T1 T2")
         return cosine_dist
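
To make the renamed shapes concrete, here is a small self-contained sketch of the broadcasting that both distance helpers rely on. Dummy tensors; the shapes follow the comments in the diff.

    import torch
    from einops import rearrange

    B, C, T1, T2 = 2, 80, 100, 20
    queries_enc = torch.randn(B, C, T1)  # projected mel frames
    keys_enc = torch.randn(B, C, T2)     # projected text tokens

    q = rearrange(queries_enc, "B C T1 -> B C T1 1")
    k = rearrange(keys_enc, "B C T2 -> B C 1 T2")

    # L2: squared differences broadcast to B x C x T1 x T2, then sum over C.
    l2_dist = ((q - k) ** 2).sum(dim=1, keepdim=True)  # B x 1 x T1 x T2

    # Cosine: negated similarity reduced over the channel axis, same output shape.
    cos_dist = -torch.nn.functional.cosine_similarity(q, k, dim=1)  # B x T1 x T2
    cos_dist = rearrange(cos_dist, "B T1 T2 -> B 1 T1 T2")

    assert l2_dist.shape == cos_dist.shape == (B, 1, T1, T2)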

@@ -199,12 +210,12 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
             attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
         """
         keys = self.cond_input(keys.transpose(1, 2), conditioning).transpose(1, 2)
-        # B x C x T2
-        keys_enc = self.key_proj(keys)
         # B x C x T1
         queries_enc = self.query_proj(queries)
+        # B x C x T2
+        keys_enc = self.key_proj(keys)
         # B x 1 x T1 x T2
-        distance = self.dist_fn(queries=queries_enc, keys=keys_enc)
+        distance = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc)
         attn = -self.temperature * distance
 
         if attn_prior is not None:
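A hedged end-to-end sketch of the forward pass above. It assumes the constructor is usable with its defaults and that forward returns the soft attention alongside the documented attn_logprob; the mask shape follows the get_dist docstring.

    import torch
    from nemo.collections.tts.modules.aligner import AlignmentEncoder

    aligner = AlignmentEncoder()  # assumed default arguments

    B, T1, T2 = 2, 100, 20
    queries = torch.randn(B, 80, T1)   # B x n_mel_channels x T1 (mel frames)
    keys = torch.randn(B, 512, T2)     # B x n_text_channels x T2 (text)
    mask = torch.zeros(B, T2, 1, dtype=torch.bool)  # True = masked position

    # Assumed return signature: soft attention plus the documented log-probs.
    attn, attn_logprob = aligner(queries, keys, mask=mask)
    # attn, attn_logprob: B x 1 x T1 x T2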
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/submodules.py
@@ -509,7 +509,7 @@ def forward(self, inputs, conditioning=None):
             inputs = inputs + conditioning
 
         if "concat" in self.condition_types:
-            conditioning = conditionting.repeat(1, inputs.shape[1], 1)
+            conditioning = conditioning.repeat(1, inputs.shape[1], 1)
             inputs = torch.cat([inputs, conditioning])
             inputs = self.concat_proj(inputs)
