[TTS] FastPitch speaker encoder #6417

Merged 42 commits, Apr 18, 2023

Commits
1a15e50  Add initial codes (hsiehjackson, Apr 12, 2023)
66f7722  Remove wemb (hsiehjackson, Apr 12, 2023)
3d8c01b  Fix import (hsiehjackson, Apr 12, 2023)
86112a6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
3447da0  Restore aligner loss (hsiehjackson, Apr 13, 2023)
b84879a  Add ConditionalInput (hsiehjackson, Apr 14, 2023)
794643e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
6b88020  Fix error and support pre-trained config (hsiehjackson, Apr 14, 2023)
12e2770  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
aa9fb0a  Follow comments (hsiehjackson, Apr 14, 2023)
83e0569  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
4aaa0cd  Rename config (hsiehjackson, Apr 14, 2023)
8099093  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
a4b231a  Change copyright and random weight test (hsiehjackson, Apr 15, 2023)
b59cc77  Add initial codes (hsiehjackson, Apr 12, 2023)
2877ded  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
9763037  Fix import error (hsiehjackson, Apr 12, 2023)
3dcf96a  Add initial codes (hsiehjackson, Apr 12, 2023)
6030648  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
fb81cbd  Fix dataset error (hsiehjackson, Apr 12, 2023)
6869a74  Remove reference speaker embedding (hsiehjackson, Apr 12, 2023)
d98c344  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
9d4f7bf  Remove SV encoder (hsiehjackson, Apr 12, 2023)
84a829b  Follow comments (hsiehjackson, Apr 13, 2023)
678bdaf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 13, 2023)
8c22694  Fix length type (hsiehjackson, Apr 13, 2023)
0fdb9ec  Fix append (hsiehjackson, Apr 14, 2023)
e5d0af4  Move error msg (hsiehjackson, Apr 14, 2023)
d47f973  Add look-up into speaker encoder (hsiehjackson, Apr 14, 2023)
27b6ad0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
0324370  Add valueerror msg (hsiehjackson, Apr 14, 2023)
d012faa  Move lookup (hsiehjackson, Apr 14, 2023)
20fbf8f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
c444716  Remove unused (hsiehjackson, Apr 14, 2023)
72ab4d1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2023)
860cded  Fix error (hsiehjackson, Apr 14, 2023)
7c321ce  Rebase and Fix error (hsiehjackson, Apr 17, 2023)
aeff0f3  Fix spk encoder (hsiehjackson, Apr 17, 2023)
57d2b93  Rename n_speakers (hsiehjackson, Apr 17, 2023)
576c3ff  Follow comments (hsiehjackson, Apr 17, 2023)
e2d6ad1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 17, 2023)
e833662  Fix n_speakers None error (hsiehjackson, Apr 17, 2023)
examples/tts/conf/fastpitch_align_44100_adapter.yaml (23 additions, 3 deletions)

@@ -7,7 +7,7 @@ name: FastPitch
train_dataset: ???
validation_datasets: ???
sup_data_path: ???
sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id"]
sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id", "reference_audio"]

# Default values from librosa.pyin
pitch_fmin: 65.40639132514966
@@ -35,10 +35,8 @@ model:
learn_alignment: true
bin_loss_warmup_epochs: 100

n_speakers: 1
max_token_duration: 75
symbols_embedding_dim: 384
speaker_embedding_dim: 384
pitch_embedding_kernel_size: 3

pitch_fmin: ${pitch_fmin}
@@ -248,6 +246,28 @@ model:
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]

speaker_encoder:
_target_: nemo.collections.tts.modules.submodules.SpeakerEncoder
lookup_module:
_target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable
n_speakers: ???
embedding_dim: ${model.symbols_embedding_dim}
gst_module:
_target_: nemo.collections.tts.modules.submodules.GlobalStyleToken
gst_size: ${model.symbols_embedding_dim}
n_style_token: 10
n_style_attn_head: 4
reference_encoder:
_target_: nemo.collections.tts.modules.submodules.ReferenceEncoder
n_mels: ${model.n_mel_channels}
cnn_filters: [32, 32, 64, 64, 128, 128]
dropout: 0.2
gru_hidden: ${model.symbols_embedding_dim}
kernel_size: 3
stride: 2
padding: 1
bias: true

optim:
name: adamw
lr: 1e-3
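For intuition, here is a minimal sketch of how a speaker encoder wired like the config above could combine its two identity sources: a lookup-table embedding indexed by speaker ID and an embedding pooled from a reference mel spectrogram. The class below is illustrative only; the single GRU stands in for the GST reference encoder and this is not NeMo's actual SpeakerEncoder implementation.

import torch
import torch.nn as nn

class SpeakerEncoderSketch(nn.Module):
    """Sums a lookup-table speaker embedding with a reference-audio embedding."""

    def __init__(self, n_speakers, embedding_dim, n_mels=80):
        super().__init__()
        self.lookup = nn.Embedding(n_speakers, embedding_dim)
        # Stand-in for the GST reference encoder: one GRU over mel frames.
        self.ref_gru = nn.GRU(n_mels, embedding_dim, batch_first=True)

    def forward(self, speaker=None, reference_spec=None):
        emb = None
        if speaker is not None:
            emb = self.lookup(speaker)  # (B, embedding_dim)
        if reference_spec is not None:
            # (B, n_mels, T) -> (B, T, n_mels) for the batch-first GRU
            _, h = self.ref_gru(reference_spec.transpose(1, 2))
            ref_emb = h.squeeze(0)  # (B, embedding_dim)
            emb = ref_emb if emb is None else emb + ref_emb
        return emb

enc = SpeakerEncoderSketch(n_speakers=4, embedding_dim=384)
out = enc(speaker=torch.tensor([0, 1]), reference_spec=torch.randn(2, 80, 120))
print(out.shape)  # torch.Size([2, 384])

Either source can also be used alone, which matches the config making lookup_module and gst_module independently optional.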
nemo/collections/tts/data/dataset.py (47 additions, 0 deletions)

@@ -50,6 +50,7 @@
LogMel,
P_voiced,
Pitch,
ReferenceAudio,
SpeakerID,
TTSDataType,
Voiced_mask,
@@ -483,6 +484,13 @@ def add_energy(self, **kwargs):
def add_speaker_id(self, **kwargs):
pass

def add_reference_audio(self, **kwargs):
"""Add a mapping for each speaker to their manifest indexes."""
assert SpeakerID in self.sup_data_types, "Please add speaker_id in sup_data_types."
self.speaker_to_index_map = defaultdict(set)
for i, d in enumerate(self.data):
self.speaker_to_index_map[d['speaker_id']].add(i)

def get_spec(self, audio):
with torch.cuda.amp.autocast(enabled=False):
spec = self.stft(audio)
@@ -522,6 +530,12 @@ def _pad_wav_to_multiple(self, wav):
)
return wav

# Randomly sample a reference index from the same speaker
def sample_reference_index(self, speaker_id):
reference_pool = self.speaker_to_index_map[speaker_id]
reference_index = random.sample(reference_pool, 1)[0]
return reference_index

def __getitem__(self, index):
sample = self.data[index]

@@ -683,6 +697,19 @@ def __getitem__(self, index):
if SpeakerID in self.sup_data_types_set:
speaker_id = torch.tensor(sample["speaker_id"]).long()

reference_audio, reference_audio_length = None, None
if ReferenceAudio in self.sup_data_types_set:
reference_index = self.sample_reference_index(sample["speaker_id"])
reference_audio = self.featurizer.process(
self.data[reference_index]["audio_filepath"],
trim=self.trim,
trim_ref=self.trim_ref,
trim_top_db=self.trim_top_db,
trim_frame_length=self.trim_frame_length,
trim_hop_length=self.trim_hop_length,
)
reference_audio_length = torch.tensor(reference_audio.shape[0]).long()

return (
audio,
audio_length,
@@ -700,6 +727,8 @@ def __getitem__(self, index):
voiced_mask,
p_voiced,
audio_shifted,
reference_audio,
reference_audio_length,
)

def __len__(self):
@@ -733,6 +762,8 @@ def general_collate_fn(self, batch):
voiced_masks,
p_voiceds,
_,
_,
reference_audio_lengths,
) = zip(*batch)

max_audio_len = max(audio_lengths).item()
@@ -741,6 +772,9 @@
max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None
max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None
max_reference_audio_len = (
max(reference_audio_lengths).item() if ReferenceAudio in self.sup_data_types_set else None
)

if LogMel in self.sup_data_types_set:
log_mel_pad = torch.finfo(batch[0][4].dtype).tiny
@@ -765,6 +799,7 @@
voiced_masks,
p_voiceds,
audios_shifted,
reference_audios,
) = (
[],
[],
@@ -776,6 +811,7 @@
[],
[],
[],
[],
)

for i, sample_tuple in enumerate(batch):
@@ -796,6 +832,8 @@
voiced_mask,
p_voiced,
audio_shifted,
reference_audio,
reference_audios_length,
) = sample_tuple

audio = general_padding(audio, audio_len.item(), max_audio_len)
@@ -834,6 +872,11 @@
if SpeakerID in self.sup_data_types_set:
speaker_ids.append(speaker_id)

if ReferenceAudio in self.sup_data_types_set:
reference_audios.append(
general_padding(reference_audio, reference_audios_length.item(), max_reference_audio_len)
)

data_dict = {
"audio": torch.stack(audios),
"audio_lens": torch.stack(audio_lengths),
@@ -851,6 +894,10 @@
"voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None,
"p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None,
"audio_shifted": torch.stack(audios_shifted) if audio_shifted is not None else None,
"reference_audio": torch.stack(reference_audios) if ReferenceAudio in self.sup_data_types_set else None,
"reference_audio_lens": torch.stack(reference_audio_lengths)
if ReferenceAudio in self.sup_data_types_set
else None,
}

return data_dict
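The dataset change above hinges on one idea: each training sample can be paired with a second utterance drawn at random from the same speaker. A self-contained sketch of that mapping and sampling logic, using a toy manifest with illustrative file names, might look like this. Note it draws with random.choice over a sorted list, since random.sample on a set is rejected by Python 3.11 and later.

import random
from collections import defaultdict

# Toy manifest: each entry carries the speaker_id of its utterance.
data = [
    {"speaker_id": 0, "audio_filepath": "spk0_utt0.wav"},
    {"speaker_id": 0, "audio_filepath": "spk0_utt1.wav"},
    {"speaker_id": 1, "audio_filepath": "spk1_utt0.wav"},
]

# Build the speaker -> manifest-index map, as add_reference_audio does.
speaker_to_index_map = defaultdict(set)
for i, d in enumerate(data):
    speaker_to_index_map[d["speaker_id"]].add(i)

def sample_reference_index(speaker_id):
    # Draw one utterance index uniformly from the same speaker's pool.
    return random.choice(sorted(speaker_to_index_map[speaker_id]))

print(sample_reference_index(0))  # 0 or 1; either utterance can serve as reference

Because the sample's own index stays in the pool, an utterance can occasionally be drawn as its own reference, which is harmless for training.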
nemo/collections/tts/models/fastpitch.py (67 additions, 11 deletions)

@@ -139,14 +139,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
output_fft = instantiate(self._cfg.output_fft)
duration_predictor = instantiate(self._cfg.duration_predictor)
pitch_predictor = instantiate(self._cfg.pitch_predictor)
speaker_encoder = instantiate(self._cfg.get("speaker_encoder", None))
energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
energy_predictor = instantiate(self._cfg.get("energy_predictor", None))

# [TODO] may remove if we change the pre-trained config
# cfg: condition_types = [ "add" ]
n_speakers = cfg.get("n_speakers", 0)
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
if cfg.n_speakers > 1:
if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types:
input_fft.cond_input.condition_types.append("add")
if speaker_emb_condition_prosody:
duration_predictor.cond_input.condition_types.append("add")
@@ -163,7 +166,8 @@
pitch_predictor,
energy_predictor,
self.aligner,
cfg.n_speakers,
speaker_encoder,
n_speakers,
cfg.symbols_embedding_dim,
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
@@ -305,6 +309,9 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor:
"attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
"mel_lens": NeuralType(('B'), LengthsType(), optional=True),
"input_lens": NeuralType(('B'), LengthsType(), optional=True),
# reference_* data is used for multi-speaker FastPitch training
"reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
"reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True),
}
)
def forward(
@@ -320,6 +327,8 @@
attn_prior=None,
mel_lens=None,
input_lens=None,
reference_spec=None,
reference_spec_lens=None,
):
return self.fastpitch(
text=text,
@@ -332,21 +341,43 @@
attn_prior=attn_prior,
mel_lens=mel_lens,
input_lens=input_lens,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)

@typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
def generate_spectrogram(
self, tokens: 'torch.tensor', speaker: Optional[int] = None, pace: float = 1.0
self,
tokens: 'torch.tensor',
speaker: Optional[int] = None,
pace: float = 1.0,
reference_spec: Optional['torch.tensor'] = None,
reference_spec_lens: Optional['torch.tensor'] = None,
) -> torch.tensor:
if self.training:
logging.warning("generate_spectrogram() is meant to be called in eval mode.")
if isinstance(speaker, int):
speaker = torch.tensor([speaker]).to(self.device)
spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
spect, *_ = self(
text=tokens,
durs=None,
pitch=None,
speaker=speaker,
pace=pace,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)
return spect

def training_step(self, batch, batch_idx):
attn_prior, durs, speaker, energy = None, None, None, None
attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
None,
None,
None,
None,
None,
None,
)
if self.learn_alignment:
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
@@ -358,10 +389,17 @@
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
reference_audio = batch_dict.get("reference_audio", None)
reference_audio_len = batch_dict.get("reference_audio_lens", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

(
mels_pred,
@@ -384,6 +422,8 @@
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
attn_prior=attn_prior,
mel_lens=spec_len,
input_lens=text_lens,
@@ -441,7 +481,14 @@
return loss

def validation_step(self, batch, batch_idx):
attn_prior, durs, speaker, energy = None, None, None, None
attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
None,
None,
None,
None,
None,
None,
)
if self.learn_alignment:
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
@@ -453,10 +500,17 @@
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
reference_audio = batch_dict.get("reference_audio", None)
reference_audio_len = batch_dict.get("reference_audio_lens", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

# Calculate val loss on ground truth durations to better align L2 loss in time
(mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
@@ -467,6 +521,8 @@
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
attn_prior=attn_prior,
mel_lens=mel_lens,
input_lens=text_lens,
@@ -496,13 +552,13 @@
mel_loss = collect("mel_loss")
dur_loss = collect("dur_loss")
pitch_loss = collect("pitch_loss")
self.log("val_loss", val_loss)
self.log("val_mel_loss", mel_loss)
self.log("val_dur_loss", dur_loss)
self.log("val_pitch_loss", pitch_loss)
self.log("val_loss", val_loss, sync_dist=True)
self.log("val_mel_loss", mel_loss, sync_dist=True)
self.log("val_dur_loss", dur_loss, sync_dist=True)
self.log("val_pitch_loss", pitch_loss, sync_dist=True)
if outputs[0]["energy_loss"] is not None:
energy_loss = collect("energy_loss")
self.log("val_energy_loss", energy_loss)
self.log("val_energy_loss", energy_loss, sync_dist=True)

_, _, _, _, _, spec_target, spec_predict = outputs[0].values()

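Assuming a checkpoint trained with the new speaker_encoder, inference can pass a reference clip through the model's preprocessor and hand the resulting spectrogram to generate_spectrogram, matching the extended signature above. The checkpoint path and the random waveform below are placeholders, so treat this as a sketch of the call pattern rather than a tested recipe.

import torch
from nemo.collections.tts.models import FastPitchModel

# Placeholder path; any multi-speaker FastPitch .nemo checkpoint trained
# with a speaker_encoder would take its place.
model = FastPitchModel.restore_from("multispeaker_fastpitch.nemo").eval()

tokens = model.parse("Reference-conditioned synthesis test.")

# Dummy one-second reference waveform at 44.1 kHz; in practice, load a clip
# from the target speaker instead.
ref_audio = torch.randn(1, 44100, device=model.device)
ref_audio_len = torch.tensor([44100], device=model.device)
reference_spec, reference_spec_lens = model.preprocessor(
    input_signal=ref_audio, length=ref_audio_len
)

spect = model.generate_spectrogram(
    tokens=tokens,
    reference_spec=reference_spec,
    reference_spec_lens=reference_spec_lens,
)
print(spect.shape)  # (1, n_mel_channels, T_spec)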