From eea78b289972dcfbb5d47a469f03f5f471d131bd Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Mon, 31 Jul 2023 15:32:49 -0700 Subject: [PATCH] [TTS] Create EnCodec training recipe (#6852) * [TTS] Create EnCodec training recipe Signed-off-by: Ryan * [TTS] Update encodec recipe Signed-off-by: Ryan * [TTS] Rename EnCodec to AudioCodec Signed-off-by: Ryan * [TTS] Add EnCodec unit tests Signed-off-by: Ryan * [TTS] Add copyright header to distributed.py Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: jubick1337 --- examples/tts/audio_codec.py | 31 ++ examples/tts/conf/audio_codec/encodec.yaml | 160 ++++++ examples/tts/conf/hifigan/hifigan_data.yaml | 1 + nemo/collections/tts/data/vocoder_dataset.py | 11 + .../tts/losses/audio_codec_loss.py | 294 ++++++++++ nemo/collections/tts/models/__init__.py | 2 + nemo/collections/tts/models/audio_codec.py | 385 +++++++++++++ .../tts/modules/audio_codec_modules.py | 515 ++++++++++++++++++ nemo/collections/tts/modules/common.py | 27 + .../tts/modules/vector_quantization.py | 429 +++++++++++++++ nemo/collections/tts/parts/utils/callbacks.py | 87 +++ .../tts/parts/utils/distributed.py | 42 ++ nemo/collections/tts/parts/utils/helpers.py | 5 +- .../tts/losses/test_audio_codec_loss.py | 44 ++ .../tts/modules/test_audio_codec_modules.py | 96 ++++ 15 files changed, 2128 insertions(+), 1 deletion(-) create mode 100644 examples/tts/audio_codec.py create mode 100644 examples/tts/conf/audio_codec/encodec.yaml create mode 100644 nemo/collections/tts/losses/audio_codec_loss.py create mode 100644 nemo/collections/tts/models/audio_codec.py create mode 100644 nemo/collections/tts/modules/audio_codec_modules.py create mode 100644 nemo/collections/tts/modules/vector_quantization.py create mode 100644 nemo/collections/tts/parts/utils/distributed.py create mode 100644 tests/collections/tts/losses/test_audio_codec_loss.py create mode 100644 tests/collections/tts/modules/test_audio_codec_modules.py diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py new file mode 100644 index 000000000000..ffc91cd98f01 --- /dev/null +++ b/examples/tts/audio_codec.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
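As a rough orientation for reviewers, the sketch below shows how a model trained with this recipe could be used for an encode/decode round trip. It is illustrative only: the checkpoint path "audio_codec.nemo" is a placeholder, the toy input is random noise, restore_from is the standard NeMo ModelPT loading method, and encode_audio, quantize_encode, quantize_decode, and decode_audio are the methods defined in nemo/collections/tts/models/audio_codec.py later in this patch.

import torch

from nemo.collections.tts.models import AudioCodecModel

# Placeholder checkpoint path for a model trained with this recipe.
codec = AudioCodecModel.restore_from("audio_codec.nemo").eval()

# One second of 24 kHz audio (batch size 1); real audio would be loaded from disk.
audio = torch.randn(1, 24000)
audio_len = torch.tensor([audio.shape[1]])

with torch.no_grad():
    # Continuous latent representation: [B, D, T_encoded]
    encoded, encoded_len = codec.encode_audio(audio=audio, audio_len=audio_len)
    # Discrete codes from the residual vector quantizer: [num_codebooks, B, T_encoded]
    indices = codec.quantize_encode(encoded=encoded, encoded_len=encoded_len)
    # Dequantize the codes and reconstruct the waveform: [B, T_audio]
    quantized = codec.quantize_decode(indices=indices, encoded_len=encoded_len)
    audio_out, audio_out_len = codec.decode_audio(inputs=quantized, input_len=encoded_len)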
+ +import pytorch_lightning as pl + +from nemo.collections.tts.models import AudioCodecModel +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec") +def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = AudioCodecModel(cfg=cfg.model, trainer=trainer) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/tts/conf/audio_codec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml new file mode 100644 index 000000000000..e6e9f2e7876f --- /dev/null +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -0,0 +1,160 @@ +# This config contains the default values for training a 24 kHz EnCodec model. +# If you want to train the model on another dataset, you can change the config values according to your dataset. +# Most dataset-specific arguments are at the top of the config file; see below. + +name: EnCodec + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. +weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +log_ds_meta: ??? +log_dir: ??? + +# Modify these values based on your sample rate +sample_rate: 24000 +train_n_samples: 24000 +down_sample_rates: [2, 4, 5, 8] +up_sample_rates: [8, 5, 4, 2] +# The number of samples per encoded audio frame. Should be the product of the down_sample_rates. +# For example, 2 * 4 * 5 * 8 = 320. +samples_per_frame: 320 + +model: + + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + + sample_rate: ${sample_rate} + samples_per_frame: ${samples_per_frame} + time_domain_loss_scale: 0.1 + # Probability of updating the discriminator during each training step + disc_update_prob: 0.67 + + # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] + mel_loss_resolutions: [ + [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048] + ] + + train_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + n_samples: ${train_n_samples} + min_duration: 1.01 + max_duration: null + dataset_meta: ${train_ds_meta} + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 4 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss + dataset_meta: ${val_ds_meta} + + dataloader_params: + batch_size: 8 + num_workers: 2 + + # Configures how audio samples are generated and saved during training. + # Remove this section to disable logging.
+ log_config: + log_dir: ${log_dir} + log_epochs: [10, 50] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator + log_audio: true + log_encoding: true + log_quantized: true + + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 15.0 # Only log the first 15 seconds of generated audio. + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + audio_encoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetEncoder + down_sample_rates: ${down_sample_rates} + + audio_decoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetDecoder + up_sample_rates: ${up_sample_rates} + + vector_quantizer: + _target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer + num_codebooks: 8 + + discriminator: + _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT + resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + + # The original EnCodec uses hinged loss, but squared-GAN loss is more stable + # and reduces the need to tune the loss weights or use a gradient balancer. + generator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss + + discriminator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss + + optim: + _target_: torch.optim.Adam + lr: 3e-4 + betas: [0.5, 0.9] + + sched: + name: ExponentialLR + gamma: 0.999 + +trainer: + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp + precision: 32 # Vector quantization only works with 32-bit precision. + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 5 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + create_wandb_logger: false + checkpoint_callback_params: + monitor: val_loss + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/hifigan/hifigan_data.yaml b/examples/tts/conf/hifigan/hifigan_data.yaml index fde2f169aa8d..62ce3344636e 100644 --- a/examples/tts/conf/hifigan/hifigan_data.yaml +++ b/examples/tts/conf/hifigan/hifigan_data.yaml @@ -92,6 +92,7 @@ model: n_samples: null min_duration: null max_duration: null + trunc_duration: 15.0 dataset_meta: ${log_ds_meta} dataloader_params: diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py index 9bb115ba2448..6bf03068a395 100644 --- a/nemo/collections/tts/data/vocoder_dataset.py +++ b/nemo/collections/tts/data/vocoder_dataset.py @@ -65,6 +65,7 @@ class VocoderDataset(Dataset): will be ignored. max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' will be ignored. + trunc_duration: Optional int, if provided audio will be truncated to at most 'trunc_duration' seconds. num_audio_retries: Number of read attempts to make when sampling audio file, to avoid training failing from sporadic IO errors. 
""" @@ -78,6 +79,7 @@ def __init__( feature_processors: Optional[Dict[str, FeatureProcessor]] = None, min_duration: Optional[float] = None, max_duration: Optional[float] = None, + trunc_duration: Optional[float] = None, num_audio_retries: int = 5, ): super().__init__() @@ -88,6 +90,11 @@ def __init__( self.num_audio_retries = num_audio_retries self.load_precomputed_mel = False + if trunc_duration: + self.trunc_samples = int(trunc_duration * self.sample_rate) + else: + self.trunc_samples = None + if feature_processors: logging.info(f"Found feature processors {feature_processors.keys()}") self.feature_processors = list(feature_processors.values()) @@ -132,6 +139,10 @@ def _sample_audio(self, audio_filepath: Path) -> Tuple[torch.Tensor, torch.Tenso else: audio_segment = self._segment_audio(audio_filepath) audio_array = audio_segment.samples + + if self.trunc_samples: + audio_array = audio_array[: self.trunc_samples] + audio = torch.tensor(audio_array) audio_len = torch.tensor(audio.shape[0]) return audio, audio_len diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py new file mode 100644 index 000000000000..bde96fadb4c2 --- /dev/null +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -0,0 +1,294 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch +import torch.nn.functional as F +from einops import rearrange + +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import ( + AudioSignal, + LengthsType, + LossType, + NeuralType, + PredictionsType, + RegressionValuesType, + VoidType, +) + + +class MaskedLoss(Loss): + def __init__(self, loss_fn, loss_scale: float = 1.0): + super(MaskedLoss, self).__init__() + self.loss_scale = loss_scale + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), + "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), + "target_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, predicted, target, target_len): + assert target.shape[2] == predicted.shape[2] + + # [B, D, T] + loss = self.loss_fn(input=predicted, target=target) + # [B, T] + loss = torch.mean(loss, dim=1) + # [B] + loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0) + + # [1] + loss = torch.mean(loss) + loss = self.loss_scale * loss + + return loss + + +class MaskedMAELoss(MaskedLoss): + def __init__(self, loss_scale: float = 1.0): + loss_fn = torch.nn.L1Loss(reduction='none') + super(MaskedMAELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) + + +class MaskedMSELoss(MaskedLoss): + def __init__(self, loss_scale: float = 1.0): + loss_fn = torch.nn.MSELoss(reduction='none') + super(MaskedMSELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) + + +class TimeDomainLoss(Loss): + def __init__(self): + super(TimeDomainLoss, self).__init__() + self.loss_fn = MaskedMAELoss() + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, audio_real, audio_gen, audio_len): + audio_real = rearrange(audio_real, "B T -> B 1 T") + audio_gen = rearrange(audio_gen, "B T -> B 1 T") + loss = self.loss_fn(target=audio_real, predicted=audio_gen, target_len=audio_len) + return loss + + +class MultiResolutionMelLoss(Loss): + def __init__(self, sample_rate: int, mel_dim: int, resolutions: List[List], l1_scale: float = 1.0): + super(MultiResolutionMelLoss, self).__init__() + + self.l1_loss_fn = MaskedMAELoss(loss_scale=l1_scale) + self.l2_loss_fn = MaskedMSELoss() + + self.mel_features = torch.nn.ModuleList() + for n_fft, hop_len, win_len in resolutions: + mel_feature = FilterbankFeatures( + sample_rate=sample_rate, + nfilt=mel_dim, + n_window_size=win_len, + n_window_stride=hop_len, + n_fft=n_fft, + pad_to=1, + mag_power=1.0, + log_zero_guard_type="add", + log_zero_guard_value=1.0, + mel_norm=None, + normalize=None, + preemph=None, + dither=0.0, + use_grads=True, + ) + self.mel_features.append(mel_feature) + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, audio_real, 
audio_gen, audio_len): + loss = 0.0 + for mel_feature in self.mel_features: + mel_real, mel_real_len = mel_feature(x=audio_real, seq_len=audio_len) + mel_gen, _ = mel_feature(x=audio_gen, seq_len=audio_len) + loss += self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + loss += self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + + loss /= len(self.mel_features) + + return loss + + +class RelativeFeatureMatchingLoss(Loss): + def __init__(self, div_guard=1e-3): + super(RelativeFeatureMatchingLoss, self).__init__() + self.div_guard = div_guard + + @property + def input_types(self): + return { + "fmaps_real": [[NeuralType(elements_type=VoidType())]], + "fmaps_gen": [[NeuralType(elements_type=VoidType())]], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, fmaps_real, fmaps_gen): + loss = 0.0 + for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen): + # [B, ..., time] + for feat_real, feat_gen in zip(fmap_real, fmap_gen): + # [B, ...] + feat_mean = torch.mean(torch.abs(feat_real), dim=-1) + diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1) + feat_loss = diff / (feat_mean + self.div_guard) + # [1] + feat_loss = torch.mean(feat_loss) / len(fmap_real) + loss += feat_loss + + loss /= len(fmaps_real) + + return loss + + +class GeneratorHingedLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_gen): + loss = 0.0 + for disc_score_gen in disc_scores_gen: + loss += torch.mean(F.relu(1 - disc_score_gen)) + + loss /= len(disc_scores_gen) + + return loss + + +class GeneratorSquaredLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_gen): + loss = 0.0 + for disc_score_gen in disc_scores_gen: + loss += torch.mean((1 - disc_score_gen) ** 2) + + loss /= len(disc_scores_gen) + + return loss + + +class DiscriminatorHingedLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_real, disc_scores_gen): + loss = 0.0 + for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): + loss_real = torch.mean(F.relu(1 - disc_score_real)) + loss_gen = torch.mean(F.relu(1 + disc_score_gen)) + loss += (loss_real + loss_gen) / 2 + + loss /= len(disc_scores_real) + + return loss + + +class DiscriminatorSquaredLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_real, disc_scores_gen): + loss = 0.0 + for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): + loss_real = torch.mean((1 - disc_score_real) ** 2) + loss_gen = torch.mean(disc_score_gen ** 
2) + loss += (loss_real + loss_gen) / 2 + + loss /= len(disc_scores_real) + + return loss diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index db02a0d6bda4..4f01ea6a099e 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from nemo.collections.tts.models.aligner import AlignerModel +from nemo.collections.tts.models.audio_codec import AudioCodecModel from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel @@ -28,6 +29,7 @@ __all__ = [ "AlignerModel", + "AudioCodecModel", "FastPitchModel", "FastPitchModel_SSL", "SSLDisentangler", diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py new file mode 100644 index 000000000000..30f74dc2be2a --- /dev/null +++ b/nemo/collections/tts/models/audio_codec.py @@ -0,0 +1,385 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import random +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.tts.losses.audio_codec_loss import ( + MultiResolutionMelLoss, + RelativeFeatureMatchingLoss, + TimeDomainLoss, +) +from nemo.collections.tts.modules.common import GaussianDropout +from nemo.collections.tts.parts.utils.callbacks import LoggingCallback +from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers +from nemo.core import ModelPT +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler +from nemo.utils import model_utils +from nemo.utils.decorators import experimental + + +@experimental +class AudioCodecModel(ModelPT): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + super().__init__(cfg=cfg, trainer=trainer) + + self.sample_rate = cfg.sample_rate + self.samples_per_frame = cfg.samples_per_frame + + self.disc_update_prob = cfg.get("disc_update_prob", 1.0) + self.audio_encoder = instantiate(cfg.audio_encoder) + + # Optionally, add gaussian noise to encoder output as an information bottleneck + encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0) + if encoder_noise_stdev: + self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev) + else: + self.encoder_noise = None + + if 
"vector_quantizer" in cfg: + self.vector_quantizer = instantiate(cfg.vector_quantizer) + else: + self.vector_quantizer = None + + self.audio_decoder = instantiate(cfg.audio_decoder) + self.discriminator = instantiate(cfg.discriminator) + + mel_loss_dim = cfg.get("mel_loss_dim", 64) + mel_loss_resolutions = cfg.mel_loss_resolutions + self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0) + self.mel_loss_scale = cfg.get("mel_loss_scale", 1.0) + mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0) + self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0) + self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0) + + self.time_domain_loss_fn = TimeDomainLoss() + self.mel_loss_fn = MultiResolutionMelLoss( + sample_rate=self.sample_rate, + mel_dim=mel_loss_dim, + resolutions=mel_loss_resolutions, + l1_scale=mel_loss_l1_scale, + ) + self.gen_loss_fn = instantiate(cfg.generator_loss) + self.disc_loss_fn = instantiate(cfg.discriminator_loss) + self.feature_loss_fn = RelativeFeatureMatchingLoss() + + self.log_config = cfg.get("log_config", None) + self.lr_schedule_interval = None + self.automatic_optimization = False + + @typecheck( + input_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.pad_audio(audio, audio_len) + encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) + return encoded, encoded_len + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len) + return audio, audio_len + + @typecheck( + input_types={ + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('N', 'B', 'T_encoded'), Index())}, + ) + def quantize_encode(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + if not self.vector_quantizer: + raise ValueError("Cannot quantize without quantizer") + + indices = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('N', 'B', 'T_encoded'), Index()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"quantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),}, + ) + def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + if not self.vector_quantizer: + raise ValueError("Cannot dequantize without quantizer") + + quantized = self.vector_quantizer.decode(indices=indices, input_len=encoded_len) + return quantized + + @typecheck( + input_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "output_audio": NeuralType(('B', 'T_audio'), EncodedRepresentation()), + "output_audio_len": 
NeuralType(tuple('B'), LengthsType()), + }, + ) + def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.pad_audio(audio, audio_len) + encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) + + if self.vector_quantizer: + indices = self.quantize_encode(encoded=encoded, encoded_len=encoded_len) + quantized = self.quantize_decode(indices=indices, encoded_len=encoded_len) + output_audio, output_audio_len = self.decode_audio(inputs=quantized, input_len=encoded_len) + else: + output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len) + + return output_audio, output_audio_len + + # Zero pad the end of the audio so that we do not have a partial end frame. + def pad_audio(self, audio, audio_len): + padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int() + max_len = padded_len.max().item() + num_padding = max_len - audio.shape[1] + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + + def _process_batch(self, batch): + # [B, T_audio] + audio = batch.get("audio") + # [B] + audio_len = batch.get("audio_lens") + audio, audio_len = self.pad_audio(audio, audio_len) + + # [B, D, T_encoded] + encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) + + if self.encoder_noise is not None: + encoded = self.encoder_noise(encoded) + + if self.vector_quantizer: + encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + else: + commit_loss = None + + # [B, T] + audio_gen, audio_gen_len = self.audio_decoder(inputs=encoded, input_len=encoded_len) + + return audio, audio_len, audio_gen, commit_loss + + def training_step(self, batch, batch_idx): + optim_gen, optim_disc = self.optimizers() + optim_gen.zero_grad() + + audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) + + if self.disc_update_prob < random.random(): + loss_disc = None + else: + # Train discriminator + optim_disc.zero_grad() + + disc_scores_real, disc_scores_gen, _, _ = self.discriminator( + audio_real=audio, audio_gen=audio_gen.detach() + ) + loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen) + train_disc_loss = loss_disc + + self.manual_backward(train_disc_loss) + optim_disc.step() + + loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + train_loss_time_domain = self.time_domain_loss_scale * loss_time_domain + + loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + train_loss_mel = self.mel_loss_scale * loss_mel + + _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen) + + loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen) + train_loss_gen = self.gen_loss_scale * loss_gen + + loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen) + train_loss_feature = self.feature_loss_scale * loss_feature + + loss_gen_all = train_loss_time_domain + train_loss_mel + train_loss_gen + train_loss_feature + if commit_loss is not None: + loss_gen_all += commit_loss + + self.manual_backward(loss_gen_all) + optim_gen.step() + + self.update_lr() + + metrics = { + "g_loss_time_domain": loss_time_domain, + "g_loss_mel": loss_mel, + "g_loss_gen": loss_gen, + "g_loss_feature": loss_feature, + "g_loss": loss_gen_all, + "global_step": self.global_step, + "lr": optim_gen.param_groups[0]['lr'], + } + + if 
loss_disc is not None: + metrics["d_loss"] = loss_disc + + if commit_loss is not None: + metrics["g_loss_commit"] = commit_loss + + self.log_dict(metrics, on_step=True, sync_dist=True) + self.log("t_loss", train_loss_mel, prog_bar=True, logger=False, sync_dist=True) + + def training_epoch_end(self, outputs): + self.update_lr("epoch") + + def validation_step(self, batch, batch_idx): + audio, audio_len, audio_gen, _ = self._process_batch(batch) + loss_audio = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + metrics = {"val_loss": loss_audio + loss_mel, "val_loss_audio": loss_audio, "val_loss_mel": loss_mel} + self.log_dict(metrics, on_epoch=True, sync_dist=True) + + @staticmethod + def _setup_train_dataloader(cfg): + dataset = instantiate(cfg.dataset) + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + data_loader = torch.utils.data.DataLoader( + dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params + ) + return data_loader + + @staticmethod + def _setup_test_dataloader(cfg): + dataset = instantiate(cfg.dataset) + data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) + return data_loader + + def setup_training_data(self, cfg): + self._train_dl = self._setup_train_dataloader(cfg) + + def setup_validation_data(self, cfg): + self._validation_dl = self._setup_test_dataloader(cfg) + + def setup_test_data(self, cfg): + pass + + @property + def max_steps(self): + if "max_steps" in self._cfg: + return self._cfg.get("max_steps") + + if "max_epochs" not in self._cfg: + raise ValueError("Must specify 'max_steps' or 'max_epochs'.") + + if "steps_per_epoch" in self._cfg: + return self._cfg.max_epochs * self._cfg.steps_per_epoch + + return compute_max_steps( + max_epochs=self._cfg.max_epochs, + accumulate_grad_batches=self.trainer.accumulate_grad_batches, + limit_train_batches=self.trainer.limit_train_batches, + num_workers=get_num_workers(self.trainer), + num_samples=len(self._train_dl.dataset), + batch_size=get_batch_size(self._train_dl), + drop_last=self._train_dl.drop_last, + ) + + def configure_optimizers(self): + optim_config = self._cfg.optim.copy() + + OmegaConf.set_struct(optim_config, False) + sched_config = optim_config.pop("sched", None) + OmegaConf.set_struct(optim_config, True) + + gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters()) + disc_params = self.discriminator.parameters() + optim_g = instantiate(optim_config, params=gen_params) + optim_d = instantiate(optim_config, params=disc_params) + + if sched_config is None: + return [optim_g, optim_d] + + OmegaConf.set_struct(sched_config, False) + sched_config["max_steps"] = self.max_steps + OmegaConf.set_struct(sched_config, True) + + scheduler_g = prepare_lr_scheduler( + optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl + ) + + scheduler_d = prepare_lr_scheduler( + optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl + ) + + self.lr_schedule_interval = scheduler_g["interval"] + + return [optim_g, optim_d], [scheduler_g, scheduler_d] + + def update_lr(self, interval="step"): + schedulers = self.lr_schedulers() + if schedulers is not None and self.lr_schedule_interval == interval: + sch1, sch2 = schedulers + sch1.step() + sch2.step() + + def configure_callbacks(self): + if not self.log_config: + return [] + + data_loader 
= self._setup_test_dataloader(self.log_config) + generators = instantiate(self.log_config.generators) + log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None + log_callback = LoggingCallback( + generators=generators, + data_loader=data_loader, + log_epochs=self.log_config.log_epochs, + epoch_frequency=self.log_config.epoch_frequency, + output_dir=log_dir, + loggers=self.trainer.loggers, + log_tensorboard=self.log_config.log_tensorboard, + log_wandb=self.log_config.log_wandb, + ) + + return [log_callback] + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + return [] diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py new file mode 100644 index 000000000000..ba1c8aa9a348 --- /dev/null +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -0,0 +1,515 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType +from nemo.core.neural_types.neural_type import NeuralType + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return (kernel_size * dilation - dilation) // 2 + + +def get_padding_2d(kernel_size: Tuple[int, int], dilation: Tuple[int, int]) -> Tuple[int, int]: + paddings = (get_padding(kernel_size[0], dilation[0]), get_padding(kernel_size[1], dilation[1])) + return paddings + + +def get_down_sample_padding(kernel_size: int, stride: int) -> int: + return (kernel_size - stride + 1) // 2 + + +def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]: + output_padding = (kernel_size - stride) % 2 + padding = (kernel_size - stride + 1) // 2 + return padding, output_padding + + +class Conv1dNorm(NeuralModule): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None + ): + super().__init__() + if not padding: + padding = get_padding(kernel_size) + conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode="reflect", + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs, lengths): + out = self.conv(inputs) + out = mask_sequence_tensor(out, lengths) + return out + + +class ConvTranspose1dNorm(NeuralModule): + def 
__init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1): + super().__init__() + padding, output_padding = get_up_sample_padding(kernel_size, stride) + conv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + padding_mode="zeros", + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs, lengths): + out = self.conv(inputs) + out = mask_sequence_tensor(out, lengths) + return out + + +class Conv2dNorm(NeuralModule): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + stride: Tuple[int, int] = (1, 1), + dilation: Tuple[int, int] = (1, 1), + ): + super().__init__() + assert len(kernel_size) == len(dilation) + padding = get_padding_2d(kernel_size, dilation) + conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode="reflect", + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'H', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'H', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs): + return self.conv(inputs) + + +class SEANetResnetBlock(NeuralModule): + def __init__(self, channels: int): + super().__init__() + self.activation = nn.ELU() + hidden_channels = channels // 2 + self.pre_conv = Conv1dNorm(in_channels=channels, out_channels=channels, kernel_size=1) + self.res_conv1 = Conv1dNorm(in_channels=channels, out_channels=hidden_channels, kernel_size=3) + self.res_conv2 = Conv1dNorm(in_channels=hidden_channels, out_channels=channels, kernel_size=1) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T_input'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T_out'), VoidType())], + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.res_conv1.remove_weight_norm() + self.res_conv2.remove_weight_norm() + + def forward(self, inputs, lengths): + res = self.activation(inputs) + res = self.res_conv1(res, lengths) + res = self.activation(res) + res = self.res_conv2(res, lengths) + + out = self.pre_conv(inputs, lengths) + res + out = mask_sequence_tensor(out, lengths) + return out + + +class SEANetRNN(NeuralModule): + def __init__(self, dim: int, num_layers: int, rnn_type: str = "lstm", use_skip: bool = False): + super().__init__() + self.use_skip = use_skip + if rnn_type == "lstm": + self.rnn = torch.nn.LSTM(input_size=dim, hidden_size=dim, num_layers=num_layers) + elif rnn_type == "gru": + self.rnn = torch.nn.GRU(input_size=dim, hidden_size=dim, num_layers=num_layers) + else: + raise ValueError(f"Unknown RNN type {rnn_type}") + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), 
LengthsType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def forward(self, inputs, lengths): + inputs = rearrange(inputs, "B C T -> B T C") + + packed_inputs = nn.utils.rnn.pack_padded_sequence( + inputs, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False + ) + packed_out, _ = self.rnn(packed_inputs) + out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) + + if self.use_skip: + out = out + inputs + + out = rearrange(out, "B T C -> B C T") + return out + + +class SEANetEncoder(NeuralModule): + def __init__( + self, + down_sample_rates: Iterable[int] = (2, 4, 5, 8), + base_channels: int = 32, + in_kernel_size: int = 7, + out_kernel_size: int = 7, + encoded_dim: int = 128, + rnn_layers: int = 2, + rnn_type: str = "lstm", + rnn_skip: bool = True, + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.down_sample_rates = down_sample_rates + self.activation = nn.ELU() + self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size) + + in_channels = base_channels + self.res_blocks = nn.ModuleList([]) + self.down_sample_conv_layers = nn.ModuleList([]) + for i, down_sample_rate in enumerate(self.down_sample_rates): + res_block = SEANetResnetBlock(channels=in_channels) + self.res_blocks.append(res_block) + + out_channels = 2 * in_channels + kernel_size = 2 * down_sample_rate + down_sample_conv = Conv1dNorm( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=down_sample_rate, + padding=get_down_sample_padding(kernel_size, down_sample_rate), + ) + in_channels = out_channels + self.down_sample_conv_layers.append(down_sample_conv) + + self.rnn = SEANetRNN(dim=in_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip) + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())], + "encoded_len": [NeuralType(tuple('B'), LengthsType())], + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + for res_block in self.res_blocks: + res_block.remove_weight_norm() + for down_sample_conv in self.down_sample_conv_layers: + down_sample_conv.remove_weight_norm() + + def forward(self, audio, audio_len): + encoded_len = audio_len + audio = rearrange(audio, "B T -> B 1 T") + # [B, C, T_audio] + out = self.pre_conv(audio, encoded_len) + for res_block, down_sample_conv, down_sample_rate in zip( + self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates + ): + # [B, C, T] + out = res_block(out, encoded_len) + out = self.activation(out) + + encoded_len = encoded_len // down_sample_rate + # [B, 2 * C, T / down_sample_rate] + out = down_sample_conv(out, encoded_len) + + out = self.rnn(out, encoded_len) + out = self.activation(out) + # [B, encoded_dim, T_encoded] + encoded = self.post_conv(out, encoded_len) + return encoded, encoded_len + + +class SEANetDecoder(NeuralModule): + def __init__( + self, + up_sample_rates: Iterable[int] = (8, 5, 4, 2), + base_channels: int = 512, + in_kernel_size: int = 7, + out_kernel_size: int = 3, + encoded_dim: int = 128, + rnn_layers: int = 2, + rnn_type: str = "lstm", + rnn_skip: bool = 
True, + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.up_sample_rates = up_sample_rates + self.activation = nn.ELU() + self.pre_conv = Conv1dNorm(in_channels=encoded_dim, out_channels=base_channels, kernel_size=in_kernel_size) + self.rnn = SEANetRNN(dim=base_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip) + + in_channels = base_channels + self.res_blocks = nn.ModuleList([]) + self.up_sample_conv_layers = nn.ModuleList([]) + for i, up_sample_rate in enumerate(self.up_sample_rates): + out_channels = in_channels // 2 + kernel_size = 2 * up_sample_rate + up_sample_conv = ConvTranspose1dNorm( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate + ) + in_channels = out_channels + self.up_sample_conv_layers.append(up_sample_conv) + + res_block = SEANetResnetBlock(channels=in_channels) + self.res_blocks.append(res_block) + + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size) + self.out_activation = nn.Tanh() + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())], + "input_len": [NeuralType(tuple('B'), LengthsType())], + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + for up_sample_conv in self.up_sample_conv_layers: + up_sample_conv.remove_weight_norm() + for res_block in self.res_blocks: + res_block.remove_weight_norm() + + def forward(self, inputs, input_len): + audio_len = input_len + # [B, C, T_encoded] + out = self.pre_conv(inputs, audio_len) + out = self.rnn(out, audio_len) + for res_block, up_sample_conv, up_sample_rate in zip( + self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates + ): + audio_len *= up_sample_rate + out = self.activation(out) + # [B, C / 2, T * up_sample_rate] + out = up_sample_conv(out, audio_len) + out = res_block(out, audio_len) + + out = self.activation(out) + # [B, 1, T_audio] + out = self.post_conv(out, audio_len) + audio = self.out_activation(out) + audio = rearrange(audio, "B 1 T -> B T") + return audio, audio_len + + +class DiscriminatorSTFT(NeuralModule): + def __init__(self, resolution, lrelu_slope=0.1): + super().__init__() + + self.n_fft, self.hop_length, self.win_length = resolution + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.activation = nn.LeakyReLU(lrelu_slope) + + self.conv_layers = nn.ModuleList( + [ + Conv2dNorm(2, 32, kernel_size=(3, 9)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(1, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(2, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(4, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 3)), + ] + ) + self.conv_post = Conv2dNorm(32, 1, kernel_size=(3, 3)) + + def stft(self, audio): + # [B, fft, T_spec] + out = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + normalized=True, + center=True, + return_complex=True, + ) + out = rearrange(out, "B fft T -> B 1 T fft") + # [batch, 2, T_spec, fft] + out = torch.cat([out.real, out.imag], dim=1) + return out + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def 
output_types(self): + return { + "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), + "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], + } + + def forward(self, audio): + fmap = [] + + # [batch, 2, T_spec, fft] + out = self.stft(audio) + for conv in self.conv_layers: + # [batch, filters, T_spec, fft // 2**i] + out = conv(out) + out = self.activation(out) + fmap.append(out) + # [batch, 1, T_spec, fft // 8] + scores = self.conv_post(out) + fmap.append(scores) + scores = rearrange(scores, "B 1 T C -> B C T") + + return scores, fmap + + +class MultiResolutionDiscriminatorSTFT(NeuralModule): + def __init__(self, resolutions): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorSTFT(res) for res in resolutions]) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + def forward(self, audio_real, audio_gen): + scores_real = [] + scores_gen = [] + fmaps_real = [] + fmaps_gen = [] + + for disc in self.discriminators: + score_real, fmap_real = disc(audio=audio_real) + scores_real.append(score_real) + fmaps_real.append(fmap_real) + + score_gen, fmap_gen = disc(audio=audio_gen) + scores_gen.append(score_gen) + fmaps_gen.append(fmap_gen) + + return scores_real, scores_gen, fmaps_real, fmaps_gen diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 7f6652f8455d..5f7d6153a7d1 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -764,3 +764,30 @@ def forward(self, queries, keys, query_lens, mask=None, key_lens=None, attn_prio attn = self.softmax(attn) # softmax along T2 return attn, attn_logprob + + +class GaussianDropout(torch.nn.Module): + """ + Gaussian dropout using multiplicative gaussian noise. + + https://keras.io/api/layers/regularization_layers/gaussian_dropout/ + + Can be an effective alternative bottleneck to VAE or VQ: + + https://www.deepmind.com/publications/gaussian-dropout-as-an-information-bottleneck-layer + + Unlike some other implementations, this takes the standard deviation of the noise as input + instead of the 'rate' typically defined as: stdev = sqrt(rate / (1 - rate)) + """ + + def __init__(self, stdev=1.0): + super(GaussianDropout, self).__init__() + self.stdev = stdev + + def forward(self, inputs): + if not self.training: + return inputs + + noise = torch.normal(mean=1.0, std=self.stdev, size=inputs.shape, device=inputs.device) + out = noise * inputs + return out diff --git a/nemo/collections/tts/modules/vector_quantization.py b/nemo/collections/tts/modules/vector_quantization.py new file mode 100644 index 000000000000..ac4b3a3f9aa3 --- /dev/null +++ b/nemo/collections/tts/modules/vector_quantization.py @@ -0,0 +1,429 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor + +from nemo.collections.tts.losses.audio_codec_loss import MaskedMSELoss +from nemo.collections.tts.parts.utils.distributed import broadcast_tensors +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes.common import typecheck +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, LossType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.utils.decorators import experimental + + +def _ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def _laplace_smoothing(inputs: Tensor, n_categories: int, epsilon: float = 1e-5) -> Tensor: + input_sum = inputs.sum() + smoothed = (inputs + epsilon) / (input_sum + n_categories * epsilon) + return input_sum * smoothed + + +def _compute_distances(input1: Tensor, input2: Tensor) -> Tensor: + """ + Compute pairwise L2 distance between two input tensors + + Args: + input1: [B, D] first tensor. + input2: [N, D] second tensor. + + Returns: + [(B, D)] tensor of distances. + """ + input2 = rearrange(input2, "N D -> D N") + distances = input1.pow(2).sum(1, keepdim=True) - (2 * input1 @ input2) + input2.pow(2).sum(0, keepdim=True) + return distances + + +def _sample_vectors(samples: Tensor, num_sample: int) -> Tensor: + """ + Randomly sample from the input batch. + + Args: + samples: [B, D] tensor with features to sample. + num_sample: Number of samples to draw. + If the value is less than or equal to B, then the samples will be unique. + If the value is greater than B, then samples will be drawn with replacement. + + Returns: + Tensor with num_sample values randomly sampled from the input batch. + """ + device = samples.device + total_samples = samples.shape[0] + + if total_samples >= num_sample: + indices = torch.randperm(total_samples, device=device)[:num_sample] + else: + indices = torch.randint(low=0, high=total_samples, size=(num_sample,), device=device) + + return samples[indices] + + +def _k_means(samples: Tensor, num_clusters: int, num_iters: int = 10) -> Tuple[Tensor, Tensor]: + """ + K-means clustering algorithm. + + Args: + samples: [B, D] tensor with features to cluster + num_clusters: K, the number of clusters. + num_iters: Number of iterations of K-means to run. 
+ + Returns: + [K, D] cluster means and [K] bins counting how many input samples belong to each cluster + """ + assert num_iters > 0 + + input_dim = samples.shape[1] + # [K, D] + means = _sample_vectors(samples=samples, num_sample=num_clusters) + + for _ in range(num_iters): + # [B, K] + dists = _compute_distances(samples, means) + + # [N] + buckets = dists.min(dim=1).indices + buckets_repeated = repeat(buckets, "B -> B D", D=input_dim) + # [K] + bin_counts = torch.bincount(buckets, minlength=num_clusters) + bin_counts_expanded = rearrange(bin_counts, "K -> K ()") + + # [K, D] + new_means = buckets.new_zeros(num_clusters, input_dim, dtype=samples.dtype) + new_means.scatter_add_(dim=0, index=buckets_repeated, src=samples) + new_means = new_means / torch.clamp(bin_counts_expanded, min=1) + means = torch.where(bin_counts_expanded == 0, means, new_means) + + return means, bin_counts + + +def _mask_3d(tensor: Tensor, lengths: Tensor): + """ + Mask 3d tensor with time on 1st axis. + + Args: + tensor: tensor of shape (B, T, D) + lengths: LongTensor of shape (B,) + Returns: + Masked Tensor (B, T, D) + """ + batch_size, max_lengths, _ = tensor.shape + mask = torch.ones(batch_size, max_lengths, 1).cumsum(dim=1).type_as(lengths) + mask = mask <= rearrange(lengths, "b -> b 1 1") + return tensor * mask + + +@experimental +class EuclideanCodebook(NeuralModule): + """ + Codebook with Euclidean distance. + + Args: + codebook_size: Number of codes to use. + codebook_dim: Dimension of each code. + decay: Decay for exponential moving average over the codebooks. + threshold_ema_dead_code: Threshold for dead code expiration. + During every iteration, replace codes with exponential moving average cluster size less than threshold + with randomly selected values from the current batch. + kmeans_iters: Optional int, if provided codes will be initialized from the centroids learned from + kmeans_iters iterations of k-means clustering on the first training batch. 
+ """ + + def __init__( + self, + codebook_size: int, + codebook_dim: int, + decay: float = 0.99, + threshold_ema_dead_code: Optional[int] = 2, + kmeans_iters: Optional[int] = None, + ): + super().__init__() + self.decay = decay + + if kmeans_iters: + codes = nn.init.kaiming_uniform_(torch.empty(codebook_size, codebook_dim)) + else: + codes = torch.zeros(codebook_size, codebook_dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("initialized", Tensor([not kmeans_iters])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("codes", codes) + self.register_buffer("codes_avg", codes.clone()) + + @torch.jit.ignore + def _init_codes(self, data): + if self.initialized: + return + + codes, cluster_size = _k_means(samples=data, num_clusters=self.codebook_size, num_iters=self.kmeans_iters) + self.codes.data.copy_(codes) + self.codes_avg.data.copy_(codes.clone()) + self.cluster_size.data.copy_(cluster_size) + self.initialized.data.copy_(Tensor([True])) + broadcast_tensors(self.buffers()) + + def _expire_codes(self, inputs: Tensor) -> None: + if not self.threshold_ema_dead_code: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + samples = _sample_vectors(samples=inputs, num_sample=self.codebook_size) + expired_codes = rearrange(expired_codes, "K -> K ()") + modified_codes = torch.where(expired_codes, samples, self.codes) + self.codes.data.copy_(modified_codes) + + broadcast_tensors(self.buffers()) + + def _update_codes(self, inputs: Tensor, indices: Tensor) -> None: + code_onehot = F.one_hot(indices, self.codebook_size).type(inputs.dtype) + code_onehot = rearrange(code_onehot, "B N -> N B") + # [N] + code_counts = code_onehot.sum(1) + _ema_inplace(moving_avg=self.cluster_size, new=code_counts, decay=self.decay) + # [N, D] + code_sum = code_onehot @ inputs + _ema_inplace(moving_avg=self.codes_avg, new=code_sum, decay=self.decay) + + cluster_size_smoothed = _laplace_smoothing(self.cluster_size, n_categories=self.codebook_size) + cluster_size_smoothed = rearrange(cluster_size_smoothed, "N -> N ()") + codes_normalized = self.codes_avg / cluster_size_smoothed + self.codes.data.copy_(codes_normalized) + + def _quantize(self, inputs: Tensor) -> Tensor: + # [B, N] + dist = _compute_distances(inputs, self.codes) + # [B] + indices = dist.min(dim=1).indices + return indices + + def _dequantize(self, indices: Tensor) -> Tensor: + # [B, D] + quantized = F.embedding(indices, self.codes) + return quantized + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "indices": NeuralType(('B', 'T'), Index()), + } + + def forward(self, inputs, input_len): + input_flat = rearrange(inputs, "B T D -> (B T) D") + self._init_codes(input_flat) + # [(B T)] + indices_flat = self._quantize(inputs=input_flat) + # [B, T] + indices = indices_flat.view(*inputs.shape[:-1]) + # [B, T, D] + quantized = self._dequantize(indices=indices) + + if self.training: + # We do expiry of codes here because buffers are in sync and all the workers will make the same decision. 
+ self._expire_codes(inputs=input_flat) + self._update_codes(inputs=input_flat, indices=indices_flat) + + quantized = _mask_3d(quantized, input_len) + indices = mask_sequence_tensor(indices, input_len) + return quantized, indices + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('B', 'T'), Index())}, + ) + def encode(self, inputs, input_len): + input_flat = rearrange(inputs, "B T D -> (B T) D") + # [(B T)] + indices_flat = self._quantize(inputs=input_flat) + # [B, T] + indices = indices_flat.view(*inputs.shape[:-1]) + indices = mask_sequence_tensor(indices, input_len) + return indices + + @typecheck( + input_types={"indices": NeuralType(('B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()),}, + output_types={"quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, + ) + def decode(self, indices, input_len): + # [B, T, D] + quantized = self._dequantize(indices=indices) + quantized = _mask_3d(quantized, input_len) + return quantized + + +class ResidualVectorQuantizer(NeuralModule): + """ + Residual vector quantization (RVQ) algorithm as described in https://arxiv.org/pdf/2107.03312.pdf. + + Args: + num_codebooks: Number of codebooks to use. + commit_loss_scale: Loss scale for codebook commit loss. + codebook_size: Number of codes to use for each codebook. + codebook_dim: Dimension of each code. + decay: Decay for exponential moving average over the codebooks. + threshold_ema_dead_code: Threshold for dead code expiration. + During every iteration, replace codes with exponential moving average cluster size less than threshold + with randomly selected values from the current batch. + kmeans_iters: Optional int, if provided codes will be initialized from the centroids learned from + kmeans_iters iterations of k-means clustering on the first training batch. 
+ """ + + def __init__( + self, + num_codebooks: int, + commit_loss_scale: float = 1.0, + codebook_size: int = 1024, + codebook_dim: int = 128, + decay: float = 0.99, + threshold_ema_dead_code: Optional[int] = 2, + kmeans_iters: Optional[int] = 50, + ): + super().__init__() + self.codebook_dim = codebook_dim + + if commit_loss_scale: + self.commit_loss_fn = MaskedMSELoss(loss_scale=commit_loss_scale) + else: + self.commit_loss_fn = None + + self.codebooks = nn.ModuleList( + [ + EuclideanCodebook( + codebook_size=codebook_size, + codebook_dim=codebook_dim, + decay=decay, + threshold_ema_dead_code=threshold_ema_dead_code, + kmeans_iters=kmeans_iters, + ) + for _ in range(num_codebooks) + ] + ) + + def _commit_loss(self, input, target, input_len): + if not self.commit_loss_fn: + return 0.0 + + return self.commit_loss_fn( + predicted=rearrange(input, "B T D -> B D T"), + target=rearrange(target, "B T D -> B D T"), + target_len=input_len, + ) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('B', 'T'), Index()), + "commit_loss": NeuralType((), LossType()), + } + + def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]: + commit_loss = 0.0 + residual = rearrange(inputs, "B D T -> B T D") + + index_list = [] + quantized = torch.zeros_like(residual) + for codebook in self.codebooks: + quantized_i, indices_i = codebook(inputs=residual, input_len=input_len) + + if self.training: + quantized_i = residual + (quantized_i - residual).detach() + quantized_i_const = quantized_i.detach() + commit_loss_i = self._commit_loss(input=residual, target=quantized_i_const, input_len=input_len) + commit_loss = commit_loss + commit_loss_i + + residual = residual - quantized_i_const + + else: + residual = residual - quantized_i + + quantized = quantized + quantized_i + index_list.append(indices_i) + + # [N, B, T] + indices = torch.stack(index_list) + quantized = rearrange(quantized, "B T D -> B D T") + return quantized, indices, commit_loss + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('N', 'B', 'T'), Index())}, + ) + def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: + residual = rearrange(inputs, "B D T -> B T D") + index_list = [] + for codebook in self.codebooks: + # [B, T] + indices_i = codebook.encode(inputs=residual, input_len=input_len) + # [B, D, T] + quantized_i = codebook.decode(indices=indices_i, input_len=input_len) + residual = residual - quantized_i + index_list.append(indices_i) + # [N, B, T] + indices = torch.stack(index_list) + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('N', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: + # [B, T, D] + quantized = torch.zeros([indices.shape[1], indices.shape[2], self.codebook_dim], device=indices.device) + for codebook_indices, codebook in zip(indices, self.codebooks): + quantized_i = codebook.decode(indices=codebook_indices, input_len=input_len) + quantized = quantized + quantized_i + quantized = 
rearrange(quantized, "B T D -> B D T") + return quantized diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 2320e5b21a7c..0d408658d8ad 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -265,6 +265,93 @@ def generate_artifacts( return audio_artifacts, [] +class AudioCodecArtifactGenerator(ArtifactGenerator): + """ + Generator for logging Audio Codec model outputs. + """ + + def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_quantized: bool = False): + self.log_audio = log_audio + self.log_encoding = log_encoding + self.log_quantized = log_quantized + + def _generate_audio(self, model, audio_ids, audio, audio_len): + if not self.log_audio: + return [] + + with torch.no_grad(): + # [B, T] + audio_pred, audio_pred_len = model(audio=audio, audio_len=audio_len) + + audio_artifacts = [] + for i, audio_id in enumerate(audio_ids): + audio_pred_i = audio_pred[i, : audio_pred_len[i]].cpu().numpy() + audio_artifact = AudioArtifact( + id=f"audio_{audio_id}", data=audio_pred_i, filename=f"{audio_id}.wav", sample_rate=model.sample_rate, + ) + audio_artifacts.append(audio_artifact) + + return audio_artifacts + + def _generate_images(self, model, audio_ids, audio, audio_len): + image_artifacts = [] + + if not self.log_encoding and not self.log_quantized: + return image_artifacts + + with torch.no_grad(): + # [B, D, T] + encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len) + + if self.log_encoding: + for i, audio_id in enumerate(audio_ids): + encoded_i = encoded[i, :, : encoded_len[i]].cpu().numpy() + encoded_artifact = ImageArtifact( + id=f"encoded_{audio_id}", + data=encoded_i, + filename=f"{audio_id}_encode.png", + x_axis="Audio Frames", + y_axis="Channels", + ) + image_artifacts.append(encoded_artifact) + + if not self.log_quantized: + return image_artifacts + + with torch.no_grad(): + # [B, D, T] + indices = model.quantize_encode(encoded=encoded, encoded_len=encoded_len) + quantized = model.quantize_decode(indices=indices, encoded_len=encoded_len) + + for i, audio_id in enumerate(audio_ids): + quantized_i = quantized[i, :, : encoded_len[i]].cpu().numpy() + quantized_artifact = ImageArtifact( + id=f"quantized_{audio_id}", + data=quantized_i, + filename=f"{audio_id}_quantized.png", + x_axis="Audio Frames", + y_axis="Channels", + ) + image_artifacts.append(quantized_artifact) + + return image_artifacts + + def generate_artifacts( + self, model: LightningModule, batch_dict: Dict + ) -> Tuple[List[AudioArtifact], List[ImageArtifact]]: + + audio_filepaths = batch_dict.get("audio_filepaths") + audio_ids = [create_id(p) for p in audio_filepaths] + + audio = batch_dict.get("audio") + audio_len = batch_dict.get("audio_lens") + + audio_artifacts = self._generate_audio(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) + image_artifacts = self._generate_images(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) + + return audio_artifacts, image_artifacts + + class FastPitchArtifactGenerator(ArtifactGenerator): """ Generator for logging FastPitch model outputs. diff --git a/nemo/collections/tts/parts/utils/distributed.py b/nemo/collections/tts/parts/utils/distributed.py new file mode 100644 index 000000000000..cbe102bcfdcd --- /dev/null +++ b/nemo/collections/tts/parts/utils/distributed.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable + +import torch + + +def _is_distributed(): + return torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1 + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0): + """ + Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not _is_distributed(): + return + + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index b9ea0854e48c..72048882fe78 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -733,7 +733,10 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): """ batch_size, *_, max_lengths = tensor.shape - if len(tensor.shape) == 3: + if len(tensor.shape) == 2: + mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = mask <= rearrange(lengths, "b -> b 1") + elif len(tensor.shape) == 3: mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) mask = mask <= rearrange(lengths, "b -> b 1 1") elif len(tensor.shape) == 4: diff --git a/tests/collections/tts/losses/test_audio_codec_loss.py b/tests/collections/tts/losses/test_audio_codec_loss.py new file mode 100644 index 000000000000..0fe7991e92cb --- /dev/null +++ b/tests/collections/tts/losses/test_audio_codec_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
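+
+# The expected values asserted below are consistent with a masked loss that averages
+# the per-frame error over each example's valid length and then over the batch:
+#   L1: example 1 -> (|1 - 0.5| + |2 - 1|) / 2 = 0.75, example 2 -> |3 - 4.5| = 1.5,
+#       batch mean = (0.75 + 1.5) / 2 = 1.125
+#   L2: example 1 -> (0.25 + 1.0 + 0.0) / 3, example 2 -> 2.25, batch mean = 4 / 3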
+ +import pytest +import torch + +from nemo.collections.tts.losses.audio_codec_loss import MaskedMAELoss, MaskedMSELoss + + +class TestAudioCodecLoss: + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_masked_loss_l1(self): + loss_fn = MaskedMAELoss() + target = torch.tensor([[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) + predicted = torch.tensor([[[0.5], [1.0], [0.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) + target_len = torch.tensor([2, 1]) + + loss = loss_fn(predicted=predicted, target=target, target_len=target_len) + + assert loss == 1.125 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_masked_loss_l2(self): + loss_fn = MaskedMSELoss() + target = torch.tensor([[[1.0], [2.0], [4.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) + predicted = torch.tensor([[[0.5], [1.0], [4.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) + target_len = torch.tensor([3, 1]) + + loss = loss_fn(predicted=predicted, target=target, target_len=target_len) + + assert loss == (4 / 3) diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py new file mode 100644 index 000000000000..948b1220f39c --- /dev/null +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
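+
+# The shape checks below assume get_down_sample_padding picks padding so that a strided
+# Conv1dNorm yields floor(input_length / stride) output frames, while ConvTranspose1dNorm
+# yields input_length * stride frames, with positions beyond each example's length zeroed.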
+ +import pytest +import torch + +from nemo.collections.tts.modules.audio_codec_modules import ( + Conv1dNorm, + ConvTranspose1dNorm, + get_down_sample_padding, + get_up_sample_padding, +) + + +class TestAudioCodecModules: + def setup_class(self): + self.in_channels = 8 + self.out_channels = 16 + self.batch_size = 2 + self.len1 = 4 + self.len2 = 8 + self.max_len = 10 + self.kernel_size = 3 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d(self): + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32) + + conv = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, self.max_len) + assert torch.all(out[0, :, : self.len1] != 0.0) + assert torch.all(out[0, :, self.len1 :] == 0.0) + assert torch.all(out[1, :, : self.len2] != 0.0) + assert torch.all(out[1, :, self.len2 :] == 0.0) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d_downsample(self): + stride = 2 + out_len = self.max_len // stride + out_len_1 = self.len1 // stride + out_len_2 = self.len2 // stride + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32) + + padding = get_down_sample_padding(kernel_size=self.kernel_size, stride=stride) + conv = Conv1dNorm( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=stride, + padding=padding, + ) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, out_len) + assert torch.all(out[0, :, :out_len_1] != 0.0) + assert torch.all(out[0, :, out_len_1:] == 0.0) + assert torch.all(out[1, :, :out_len_2] != 0.0) + assert torch.all(out[1, :, out_len_2:] == 0.0) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d_transpose_upsample(self): + stride = 2 + out_len = self.max_len * stride + out_len_1 = self.len1 * stride + out_len_2 = self.len2 * stride + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32) + + conv = ConvTranspose1dNorm( + in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride + ) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, out_len) + assert torch.all(out[0, :, :out_len_1] != 0.0) + assert torch.all(out[0, :, out_len_1:] == 0.0) + assert torch.all(out[1, :, :out_len_2] != 0.0) + assert torch.all(out[1, :, out_len_2:] == 0.0)
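
For readers who want to exercise the new residual vector quantizer in isolation, here is a minimal sketch that round-trips random encoder activations through ResidualVectorQuantizer using the interfaces above: forward consumes [B, D, T] activations with per-example lengths, encode returns [N, B, T] codebook indices, and decode reconstructs a [B, D, T] tensor. The import path (nemo.collections.tts.modules.vector_quantization) and the toy sizes are assumptions for illustration only.

import torch

from nemo.collections.tts.modules.vector_quantization import ResidualVectorQuantizer

batch_size, codebook_dim, num_frames = 2, 32, 20

# Assumed toy configuration; the defaults (codebook_size=1024, kmeans_iters=50) also work.
rvq = ResidualVectorQuantizer(num_codebooks=4, codebook_size=16, codebook_dim=codebook_dim)
rvq.eval()  # skip EMA codebook updates and the commit loss for this illustration

encoded = torch.randn(batch_size, codebook_dim, num_frames)  # [B, D, T] encoder output
encoded_len = torch.tensor([num_frames, num_frames // 2])  # valid frames per example

quantized, indices, commit_loss = rvq(inputs=encoded, input_len=encoded_len)
# quantized: [B, D, T] approximation of the input; indices: [N, B, T], one row per codebook

# Inference-style round trip through the discrete codes.
codes = rvq.encode(inputs=encoded, input_len=encoded_len)  # [N, B, T]
decoded = rvq.decode(indices=codes, input_len=encoded_len)  # [B, D, T]
assert decoded.shape == encoded.shape

Because each codebook quantizes the residual left by the previous ones, decoding from only the first few rows of codes (as in EnCodec-style bandwidth scaling) still yields a coarse reconstruction.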