diff --git a/examples/tts/conf/fastpitch_align_44100_adapter.yaml b/examples/tts/conf/fastpitch_align_44100_adapter.yaml index 032ab1da501f..bac6a64b06e9 100644 --- a/examples/tts/conf/fastpitch_align_44100_adapter.yaml +++ b/examples/tts/conf/fastpitch_align_44100_adapter.yaml @@ -7,7 +7,7 @@ name: FastPitch train_dataset: ??? validation_datasets: ??? sup_data_path: ??? -sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id"] +sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id", "reference_audio"] # Default values from librosa.pyin pitch_fmin: 65.40639132514966 @@ -35,10 +35,8 @@ model: learn_alignment: true bin_loss_warmup_epochs: 100 - n_speakers: 1 max_token_duration: 75 symbols_embedding_dim: 384 - speaker_embedding_dim: 384 pitch_embedding_kernel_size: 3 pitch_fmin: ${pitch_fmin} @@ -248,6 +246,28 @@ model: n_layers: 2 condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ] + speaker_encoder: + _target_: nemo.collections.tts.modules.submodules.SpeakerEncoder + lookup_module: + _target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable + n_speakers: ??? + embedding_dim: ${model.symbols_embedding_dim} + gst_module: + _target_: nemo.collections.tts.modules.submodules.GlobalStyleToken + gst_size: ${model.symbols_embedding_dim} + n_style_token: 10 + n_style_attn_head: 4 + reference_encoder: + _target_: nemo.collections.tts.modules.submodules.ReferenceEncoder + n_mels: ${model.n_mel_channels} + cnn_filters: [32, 32, 64, 64, 128, 128] + dropout: 0.2 + gru_hidden: ${model.symbols_embedding_dim} + kernel_size: 3 + stride: 2 + padding: 1 + bias: true + optim: name: adamw lr: 1e-3 diff --git a/nemo/collections/tts/data/dataset.py b/nemo/collections/tts/data/dataset.py index af4df1e58668..6bb41d341b31 100644 --- a/nemo/collections/tts/data/dataset.py +++ b/nemo/collections/tts/data/dataset.py @@ -50,6 +50,7 @@ LogMel, P_voiced, Pitch, + ReferenceAudio, SpeakerID, TTSDataType, Voiced_mask, @@ -483,6 +484,13 @@ def add_energy(self, **kwargs): def add_speaker_id(self, **kwargs): pass + def add_reference_audio(self, **kwargs): + assert SpeakerID in self.sup_data_types, "Please add speaker_id in sup_data_types." 
+ """Add a mapping for each speaker to their manifest indexes""" + self.speaker_to_index_map = defaultdict(set) + for i, d in enumerate(self.data): + self.speaker_to_index_map[d['speaker_id']].add(i) + def get_spec(self, audio): with torch.cuda.amp.autocast(enabled=False): spec = self.stft(audio) @@ -522,6 +530,12 @@ def _pad_wav_to_multiple(self, wav): ) return wav + # Random sample a reference index from the same speaker + def sample_reference_index(self, speaker_id): + reference_pool = self.speaker_to_index_map[speaker_id] + reference_index = random.sample(reference_pool, 1)[0] + return reference_index + def __getitem__(self, index): sample = self.data[index] @@ -683,6 +697,19 @@ def __getitem__(self, index): if SpeakerID in self.sup_data_types_set: speaker_id = torch.tensor(sample["speaker_id"]).long() + reference_audio, reference_audio_length = None, None + if ReferenceAudio in self.sup_data_types_set: + reference_index = self.sample_reference_index(sample["speaker_id"]) + reference_audio = self.featurizer.process( + self.data[reference_index]["audio_filepath"], + trim=self.trim, + trim_ref=self.trim_ref, + trim_top_db=self.trim_top_db, + trim_frame_length=self.trim_frame_length, + trim_hop_length=self.trim_hop_length, + ) + reference_audio_length = torch.tensor(reference_audio.shape[0]).long() + return ( audio, audio_length, @@ -700,6 +727,8 @@ def __getitem__(self, index): voiced_mask, p_voiced, audio_shifted, + reference_audio, + reference_audio_length, ) def __len__(self): @@ -733,6 +762,8 @@ def general_collate_fn(self, batch): voiced_masks, p_voiceds, _, + _, + reference_audio_lengths, ) = zip(*batch) max_audio_len = max(audio_lengths).item() @@ -741,6 +772,9 @@ def general_collate_fn(self, batch): max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None + max_reference_audio_len = ( + max(reference_audio_lengths).item() if ReferenceAudio in self.sup_data_types_set else None + ) if LogMel in self.sup_data_types_set: log_mel_pad = torch.finfo(batch[0][4].dtype).tiny @@ -765,6 +799,7 @@ def general_collate_fn(self, batch): voiced_masks, p_voiceds, audios_shifted, + reference_audios, ) = ( [], [], @@ -776,6 +811,7 @@ def general_collate_fn(self, batch): [], [], [], + [], ) for i, sample_tuple in enumerate(batch): @@ -796,6 +832,8 @@ def general_collate_fn(self, batch): voiced_mask, p_voiced, audio_shifted, + reference_audio, + reference_audios_length, ) = sample_tuple audio = general_padding(audio, audio_len.item(), max_audio_len) @@ -834,6 +872,11 @@ def general_collate_fn(self, batch): if SpeakerID in self.sup_data_types_set: speaker_ids.append(speaker_id) + if ReferenceAudio in self.sup_data_types_set: + reference_audios.append( + general_padding(reference_audio, reference_audios_length.item(), max_reference_audio_len) + ) + data_dict = { "audio": torch.stack(audios), "audio_lens": torch.stack(audio_lengths), @@ -851,6 +894,10 @@ def general_collate_fn(self, batch): "voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None, "p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None, "audio_shifted": torch.stack(audios_shifted) if audio_shifted is not None else None, + "reference_audio": torch.stack(reference_audios) if ReferenceAudio in self.sup_data_types_set else None, + 
"reference_audio_lens": torch.stack(reference_audio_lengths) + if ReferenceAudio in self.sup_data_types_set + else None, } return data_dict diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 76eaae2f9ba2..5502e69a3111 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -139,14 +139,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): output_fft = instantiate(self._cfg.output_fft) duration_predictor = instantiate(self._cfg.duration_predictor) pitch_predictor = instantiate(self._cfg.pitch_predictor) + speaker_encoder = instantiate(self._cfg.get("speaker_encoder", None)) energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0) energy_predictor = instantiate(self._cfg.get("energy_predictor", None)) # [TODO] may remove if we change the pre-trained config + # cfg: condition_types = [ "add" ] + n_speakers = cfg.get("n_speakers", 0) speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False) speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False) speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False) - if cfg.n_speakers > 1: + if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types: input_fft.cond_input.condition_types.append("add") if speaker_emb_condition_prosody: duration_predictor.cond_input.condition_types.append("add") @@ -163,7 +166,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): pitch_predictor, energy_predictor, self.aligner, - cfg.n_speakers, + speaker_encoder, + n_speakers, cfg.symbols_embedding_dim, cfg.pitch_embedding_kernel_size, energy_embedding_kernel_size, @@ -305,6 +309,9 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor: "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True), "mel_lens": NeuralType(('B'), LengthsType(), optional=True), "input_lens": NeuralType(('B'), LengthsType(), optional=True), + # reference_* data is used for multi-speaker FastPitch training + "reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True), + "reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True), } ) def forward( @@ -320,6 +327,8 @@ def forward( attn_prior=None, mel_lens=None, input_lens=None, + reference_spec=None, + reference_spec_lens=None, ): return self.fastpitch( text=text, @@ -332,21 +341,43 @@ def forward( attn_prior=attn_prior, mel_lens=mel_lens, input_lens=input_lens, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_lens, ) @typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())}) def generate_spectrogram( - self, tokens: 'torch.tensor', speaker: Optional[int] = None, pace: float = 1.0 + self, + tokens: 'torch.tensor', + speaker: Optional[int] = None, + pace: float = 1.0, + reference_spec: Optional['torch.tensor'] = None, + reference_spec_lens: Optional['torch.tensor'] = None, ) -> torch.tensor: if self.training: logging.warning("generate_spectrogram() is meant to be called in eval mode.") if isinstance(speaker, int): speaker = torch.tensor([speaker]).to(self.device) - spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace) + spect, *_ = self( + text=tokens, + durs=None, + pitch=None, + speaker=speaker, + pace=pace, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_lens, + ) return spect def training_step(self, batch, batch_idx): - attn_prior, durs, speaker, energy = 
None, None, None, None + attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = ( + None, + None, + None, + None, + None, + None, + ) if self.learn_alignment: assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}" batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) @@ -358,10 +389,17 @@ def training_step(self, batch, batch_idx): pitch = batch_dict.get("pitch", None) energy = batch_dict.get("energy", None) speaker = batch_dict.get("speaker_id", None) + reference_audio = batch_dict.get("reference_audio", None) + reference_audio_len = batch_dict.get("reference_audio_lens", None) else: audio, audio_lens, text, text_lens, durs, pitch, speaker = batch mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens) + reference_spec, reference_spec_len = None, None + if reference_audio is not None: + reference_spec, reference_spec_len = self.preprocessor( + input_signal=reference_audio, length=reference_audio_len + ) ( mels_pred, @@ -384,6 +422,8 @@ def training_step(self, batch, batch_idx): speaker=speaker, pace=1.0, spec=mels if self.learn_alignment else None, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_len, attn_prior=attn_prior, mel_lens=spec_len, input_lens=text_lens, @@ -441,7 +481,14 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - attn_prior, durs, speaker, energy = None, None, None, None + attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = ( + None, + None, + None, + None, + None, + None, + ) if self.learn_alignment: assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}" batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) @@ -453,10 +500,17 @@ def validation_step(self, batch, batch_idx): pitch = batch_dict.get("pitch", None) energy = batch_dict.get("energy", None) speaker = batch_dict.get("speaker_id", None) + reference_audio = batch_dict.get("reference_audio", None) + reference_audio_len = batch_dict.get("reference_audio_lens", None) else: audio, audio_lens, text, text_lens, durs, pitch, speaker = batch mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens) + reference_spec, reference_spec_len = None, None + if reference_audio is not None: + reference_spec, reference_spec_len = self.preprocessor( + input_signal=reference_audio, length=reference_audio_len + ) # Calculate val loss on ground truth durations to better align L2 loss in time (mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self( @@ -467,6 +521,8 @@ def validation_step(self, batch, batch_idx): speaker=speaker, pace=1.0, spec=mels if self.learn_alignment else None, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_len, attn_prior=attn_prior, mel_lens=mel_lens, input_lens=text_lens, @@ -496,13 +552,13 @@ def validation_epoch_end(self, outputs): mel_loss = collect("mel_loss") dur_loss = collect("dur_loss") pitch_loss = collect("pitch_loss") - self.log("val_loss", val_loss) - self.log("val_mel_loss", mel_loss) - self.log("val_dur_loss", dur_loss) - self.log("val_pitch_loss", pitch_loss) + self.log("val_loss", val_loss, sync_dist=True) + self.log("val_mel_loss", mel_loss, sync_dist=True) + self.log("val_dur_loss", dur_loss, sync_dist=True) + self.log("val_pitch_loss", pitch_loss, sync_dist=True) if outputs[0]["energy_loss"] is not None: energy_loss = collect("energy_loss") - 
self.log("val_energy_loss", energy_loss) + self.log("val_energy_loss", energy_loss, sync_dist=True) _, _, _, _, _, spec_target, spec_predict = outputs[0].values() diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index 83ec35d58693..e2da672cf9c7 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -103,7 +103,6 @@ class TemporalPredictor(NeuralModule): def __init__(self, input_size, filter_size, kernel_size, dropout, n_layers=2, condition_types=[]): super(TemporalPredictor, self).__init__() - self.cond_input = ConditionalInput(input_size, input_size, condition_types) self.layers = torch.nn.ModuleList() for i in range(n_layers): @@ -158,6 +157,7 @@ def __init__( pitch_predictor: NeuralModule, energy_predictor: NeuralModule, aligner: NeuralModule, + speaker_encoder: NeuralModule, n_speakers: int, symbols_embedding_dim: int, pitch_embedding_kernel_size: int, @@ -173,11 +173,15 @@ def __init__( self.pitch_predictor = pitch_predictor self.energy_predictor = energy_predictor self.aligner = aligner + self.speaker_encoder = speaker_encoder self.learn_alignment = aligner is not None self.use_duration_predictor = True self.binarize = False - if n_speakers > 1: + # TODO: combine self.speaker_emb with self.speaker_encoder + # cfg: remove `n_speakers`, create `speaker_encoder.lookup_module` + # state_dict: move `speaker_emb.weight` to `speaker_encoder.lookup_module.table.weight` + if n_speakers > 1 and speaker_encoder is None: self.speaker_emb = torch.nn.Embedding(n_speakers, symbols_embedding_dim) else: self.speaker_emb = None @@ -219,6 +223,8 @@ def input_types(self): "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True), "mel_lens": NeuralType(('B'), LengthsType(), optional=True), "input_lens": NeuralType(('B'), LengthsType(), optional=True), + "reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True), + "reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True), } @property @@ -238,6 +244,19 @@ def output_types(self): "energy_tgt": NeuralType(('B', 'T_audio'), RegressionValuesType()), } + def get_speaker_embedding(self, speaker, reference_spec, reference_spec_lens): + """spk_emb: Bx1xD""" + if self.speaker_encoder is not None: + spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens).unsqueeze(1) + elif self.speaker_emb is not None: + if speaker is None: + raise ValueError('Please give speaker id to get lookup speaker embedding.') + spk_emb = self.speaker_emb(speaker).unsqueeze(1) + else: + spk_emb = None + + return spk_emb + @typecheck() def forward( self, @@ -252,6 +271,8 @@ def forward( attn_prior=None, mel_lens=None, input_lens=None, + reference_spec=None, + reference_spec_lens=None, ): if not self.learn_alignment and self.training: @@ -259,10 +280,9 @@ def forward( assert pitch is not None # Calculate speaker embedding - if self.speaker_emb is None or speaker is None: - spk_emb = None - else: - spk_emb = self.speaker_emb(speaker).unsqueeze(1) + spk_emb = self.get_speaker_embedding( + speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens, + ) # Input FFT enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb) @@ -347,12 +367,23 @@ def forward( energy_tgt, ) - def infer(self, *, text, pitch=None, speaker=None, energy=None, pace=1.0, volume=None): + def infer( + self, + *, + text, + pitch=None, + speaker=None, + energy=None, + pace=1.0, + volume=None, + 
reference_spec=None, + reference_spec_lens=None, + ): + # Calculate speaker embedding - if self.speaker_emb is None or speaker is None: - spk_emb = 0 - else: - spk_emb = self.speaker_emb(speaker).unsqueeze(1) + spk_emb = self.get_speaker_embedding( + speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens, + ) # Input FFT enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb) diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 44ed0a92d776..dbf26f1ceeee 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -18,8 +18,12 @@ from torch import Tensor from torch.autograd import Variable from torch.nn import functional as F +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence -from nemo.core.classes import adapter_mixins +from nemo.core.classes import NeuralModule, adapter_mixins +from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, MelSpectrogramType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.utils import logging SUPPORTED_CONDITION_TYPES = ["add", "concat", "layernorm"] @@ -456,8 +460,10 @@ def forward(self, inputs, conditioning=None): if self.condition: if conditioning is None: raise ValueError( - 'You should add additional data types as conditions e.g. speaker id or reference audio' + """You should add additional data types as conditions (e.g. speaker id or reference audio) + and define speaker_encoder in your config.""" ) + inputs = inputs * self.cond_weight(conditioning) inputs = inputs + self.cond_bias(conditioning) @@ -493,7 +499,8 @@ def forward(self, inputs, conditioning=None): if len(self.condition_types) > 0: if conditioning is None: raise ValueError( - 'You should add additional data types as conditions e.g. speaker id or reference audio' + """You should add additional data types as conditions (e.g. 
speaker id or reference audio) + and define speaker_encoder in your config.""" ) if "add" in self.condition_types: @@ -507,3 +514,241 @@ def forward(self, inputs, conditioning=None): inputs = self.concat_proj(inputs) return inputs + + +class StyleAttention(NeuralModule): + def __init__(self, gst_size=128, n_style_token=10, n_style_attn_head=4): + super(StyleAttention, self).__init__() + + token_size = gst_size // n_style_attn_head + self.tokens = torch.nn.Parameter(torch.FloatTensor(n_style_token, token_size)) + self.mha = torch.nn.MultiheadAttention( + embed_dim=gst_size, + num_heads=n_style_attn_head, + dropout=0.0, + bias=True, + kdim=token_size, + vdim=token_size, + batch_first=True, + ) + torch.nn.init.normal_(self.tokens) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D'), EncodedRepresentation()), + "token_id": NeuralType(('B'), Index(), optional=True), + } + + @property + def output_types(self): + return { + "style_emb": NeuralType(('B', 'D'), EncodedRepresentation()), + } + + def forward(self, inputs): + batch_size = inputs.size(0) + query = inputs.unsqueeze(1) + tokens = F.tanh(self.tokens).unsqueeze(0).expand(batch_size, -1, -1) + + style_emb, _ = self.mha(query=query, key=tokens, value=tokens) + style_emb = style_emb.squeeze(1) + return style_emb + + +class Conv2DReLUNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=True, dropout=0.0): + super(Conv2DReLUNorm, self).__init__() + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + self.norm = torch.nn.LayerNorm(out_channels) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x, x_mask=None): + if x_mask is not None: + x = x * x_mask + + # bhwc -> bchw + x = x.contiguous().permute(0, 3, 1, 2) + x = F.relu(self.conv(x)) + # bchw -> bhwc + x = x.contiguous().permute(0, 2, 3, 1) + x = self.norm(x) + x = self.dropout(x) + return x + + +class ReferenceEncoder(NeuralModule): + """ + Encode mel-spectrograms to an utterance level feature + """ + + def __init__(self, n_mels, cnn_filters, dropout, gru_hidden, kernel_size, stride, padding, bias): + super(ReferenceEncoder, self).__init__() + self.filter_size = [1] + list(cnn_filters) + self.layers = torch.nn.ModuleList( + [ + Conv2DReLUNorm( + in_channels=int(self.filter_size[i]), + out_channels=int(self.filter_size[i + 1]), + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + dropout=dropout, + ) + for i in range(len(cnn_filters)) + ] + ) + post_conv_height = self.calculate_post_conv_lengths(n_mels, n_convs=len(cnn_filters)) + self.gru = torch.nn.GRU( + input_size=cnn_filters[-1] * post_conv_height, hidden_size=gru_hidden, batch_first=True, + ) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "inputs_lengths": NeuralType(('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'D'), EncodedRepresentation()), + } + + def forward(self, inputs, inputs_lengths): + # BMW -> BWMC (M: mels) + x = inputs.transpose(1, 2).unsqueeze(3) + x_lens = inputs_lengths + x_masks = self.lengths_to_masks(x_lens).unsqueeze(2).unsqueeze(3) + + for layer in self.layers: + x = layer(x, x_masks) + x_lens = self.calculate_post_conv_lengths(x_lens) + x_masks = self.lengths_to_masks(x_lens).unsqueeze(2).unsqueeze(3) + + # BWMC -> BWC + x = x.contiguous().view(x.shape[0], x.shape[1], -1) + + 
self.gru.flatten_parameters() + packed_x = pack_padded_sequence(x, x_lens.cpu(), batch_first=True, enforce_sorted=False) + packed_x, _ = self.gru(packed_x) + x, x_lens = pad_packed_sequence(packed_x, batch_first=True) + x = x[torch.arange(len(x_lens)), (x_lens - 1), :] + return x + + @staticmethod + def calculate_post_conv_lengths(lengths, n_convs=1, kernel_size=3, stride=2, pad=1): + """Batch lengths after n convolution with fixed kernel/stride/pad.""" + for _ in range(n_convs): + lengths = (lengths - kernel_size + 2 * pad) // stride + 1 + return lengths + + @staticmethod + def lengths_to_masks(lengths): + """Batch of lengths to batch of masks""" + # B -> BxT + masks = torch.arange(lengths.max()).to(lengths.device).expand( + lengths.shape[0], lengths.max() + ) < lengths.unsqueeze(1) + return masks + + +class GlobalStyleToken(NeuralModule): + """ + Global Style Token based Speaker Embedding + """ + + def __init__( + self, reference_encoder, gst_size=128, n_style_token=10, n_style_attn_head=4, + ): + super(GlobalStyleToken, self).__init__() + self.reference_encoder = reference_encoder + self.style_attention = StyleAttention( + gst_size=gst_size, n_style_token=n_style_token, n_style_attn_head=n_style_attn_head + ) + + @property + def input_types(self): + return { + "inp": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "inp_lengths": NeuralType(('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "gst": NeuralType(('B', 'D'), EncodedRepresentation()), + } + + def forward(self, inp, inp_lengths): + style_embedding = self.reference_encoder(inp, inp_lengths) + gst = self.style_attention(style_embedding) + return gst + + +class SpeakerLookupTable(torch.nn.Module): + """ + LookupTable based Speaker Embedding + """ + + def __init__(self, n_speakers, embedding_dim): + super(SpeakerLookupTable, self).__init__() + self.table = torch.nn.Embedding(n_speakers, embedding_dim) + + def forward(self, speaker): + return self.table(speaker) + + +class SpeakerEncoder(NeuralModule): + """ + class SpeakerEncoder represents speakers representation. + This module can combine GST (global style token) based speaker embeddings and lookup table speaker embeddings. + """ + + def __init__(self, lookup_module=None, gst_module=None): + """ + lookup_module: Torch module to get lookup based speaker embedding + gst_module: Neural module to get GST based speaker embedding + """ + super(SpeakerEncoder, self).__init__() + self.lookup_module = lookup_module + self.gst_module = gst_module + + @property + def input_types(self): + return { + "speaker": NeuralType(('B'), Index(), optional=True), + "reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True), + "reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True), + } + + @property + def output_types(self): + return { + "embs": NeuralType(('B', 'D'), EncodedRepresentation()), + } + + def forward(self, speaker=None, reference_spec=None, reference_spec_lens=None): + embs = None + + # Get Lookup table speaker embedding + if self.lookup_module is not None and speaker is not None: + embs = self.lookup_module(speaker) + + # Get GST based speaker embedding + if self.gst_module is not None: + if reference_spec is None or reference_spec_lens is None: + raise ValueError( + "You should add `reference_audio` in sup_data_types or remove `speaker_encoder`in config." 
+ ) + out = self.gst_module(reference_spec, reference_spec_lens) + embs = out if embs is None else embs + out + + elif self.gst_module is None and reference_spec is not None and reference_spec_lens is not None: + logging.warning("You may add `gst_module` in speaker_encoder to use reference_audio.") + + return embs diff --git a/nemo/collections/tts/torch/tts_data_types.py b/nemo/collections/tts/torch/tts_data_types.py index 899e5da7d801..ae7516009cd9 100644 --- a/nemo/collections/tts/torch/tts_data_types.py +++ b/nemo/collections/tts/torch/tts_data_types.py @@ -67,6 +67,10 @@ class LMTokens(TTSDataType): name = "lm_tokens" +class ReferenceAudio(TTSDataType, WithLens): + name = "reference_audio" + + MAIN_DATA_TYPES = [Audio, Text] VALID_SUPPLEMENTARY_DATA_TYPES = [ LogMel, @@ -78,5 +82,6 @@ class LMTokens(TTSDataType): LMTokens, Voiced_mask, P_voiced, + ReferenceAudio, ] DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES}
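Usage sketch (not part of the patch): once a FastPitch model is trained with the `speaker_encoder` block and the `reference_audio` sup data type added above, the new `reference_spec`/`reference_spec_lens` arguments allow inference to condition on a reference utterance. The checkpoint name below is hypothetical and the random tensor is a stand-in for a real waveform from the target speaker; the mel extraction simply mirrors what `training_step()`/`validation_step()` do with `self.preprocessor`.

```python
import torch

from nemo.collections.tts.models import FastPitchModel

# Hypothetical checkpoint trained with speaker_encoder + reference_audio in sup_data_types.
model = FastPitchModel.restore_from("multi_speaker_fastpitch.nemo").eval()

# Stand-in for a real reference utterance from the target speaker (1 second at 44.1 kHz).
reference_audio = torch.randn(1, 44100)
reference_audio_len = torch.tensor([reference_audio.shape[1]])

# Convert the reference waveform to a mel spectrogram with the model's own preprocessor,
# the same way the training/validation steps in this patch compute reference_spec.
reference_spec, reference_spec_len = model.preprocessor(
    input_signal=reference_audio, length=reference_audio_len
)

tokens = model.parse("Conditioning on a reference utterance.")
spect = model.generate_spectrogram(
    tokens=tokens,
    speaker=0,  # optional lookup-table id; the GST embedding comes from the reference audio
    reference_spec=reference_spec,
    reference_spec_lens=reference_spec_len,
)
```

Note that when a `gst_module` is configured, a reference spectrogram is required (`SpeakerEncoder.forward()` raises a ValueError without one), while the lookup-table speaker id stays optional; when both are available, the two embeddings are summed.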