From 49b511d0550da4c953a6e46fdc1bd0113ed7a42a Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 24 Jun 2020 14:59:14 -0700 Subject: [PATCH 01/29] add wavernn example --- examples/pipeline_wavernn/datasets.py | 85 +++++ examples/pipeline_wavernn/wavernn.py | 489 ++++++++++++++++++++++++++ 2 files changed, 574 insertions(+) create mode 100644 examples/pipeline_wavernn/datasets.py create mode 100644 examples/pipeline_wavernn/wavernn.py diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py new file mode 100644 index 0000000000..d196ee1c18 --- /dev/null +++ b/examples/pipeline_wavernn/datasets.py @@ -0,0 +1,85 @@ +import os +import random +import torch +import torchaudio +from torchaudio.datasets import LJSPEECH + + +class ProcessedLJSPEECH(LJSPEECH): + + def __init__(self, + files, + transforms, + mode, + n_bits): + + self.transforms = transforms + self.files = files + self.mode = mode + self.n_bits = n_bits + + def __getitem__(self, index): + + file = self.files[index] + x, sample_rate = torchaudio.load(file) + mel = self.transforms(x) + + bits = 16 if self.mode == 'MOL' else self.n_bits + + x = (x + 1.) * (2 ** bits - 1) / 2 + x = torch.clamp(x, min=0, max=2 ** bits - 1) + + return mel.squeeze(0), x.int().squeeze(0) + + def __len__(self): + return len(self.files) + + +def datasets_ljspeech(args, transforms): + + root = args.file_path + wavefiles = [os.path.join(root, file) for file in os.listdir(root)] + + random.seed(args.seed) + random.shuffle(wavefiles) + + train_files = wavefiles[:-args.test_samples] + test_files = wavefiles[-args.test_samples:] + + train_dataset = ProcessedLJSPEECH(train_files, transforms, args.mode, args.n_bits) + test_dataset = ProcessedLJSPEECH(test_files, transforms, args.mode, args.n_bits) + + return train_dataset, test_dataset + + +def collate_factory(args): + + def raw_collate(batch): + + pad = (args.kernel_size - 1) // 2 + seq_len = args.hop_length * args.seq_len_factor + mel_win = args.seq_len_factor + 2 * pad + + max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch] + mel_offsets = [random.randint(0, offset) for offset in max_offsets] + wav_offsets = [(offset + pad) * args.hop_length for offset in mel_offsets] + + mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)] + waves = [x[1][wav_offsets[i]:wav_offsets[i] + seq_len + 1] for i, x in enumerate(batch)] + + mels = torch.stack(mels) + waves = torch.stack(waves).long() + + x_input = waves[:, :seq_len] + y_coarse = waves[:, 1:] + + bits = 16 if args.mode == 'MOL' else args.n_bits + + x_input = 2 * x_input / (2**bits - 1.) - 1 + + if args.mode == 'MOL': + y_coarse = 2 * y_coarse.float() / (2**bits - 1.) 
- 1 + + return x_input, mels, y_coarse + + return raw_collate diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py new file mode 100644 index 0000000000..c4aa97bd9d --- /dev/null +++ b/examples/pipeline_wavernn/wavernn.py @@ -0,0 +1,489 @@ +import argparse +import os +import shutil +from collections import defaultdict +from datetime import datetime + +import torch +import torch.nn as nn +import torchaudio +from datasets import datasets_ljspeech, collate_factory +from typing import List +from torchaudio.models import _WaveRNN +from torch.utils.data import DataLoader +from torch.optim import Adam +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser() + + # training parameters + parser.add_argument( + "--workers", + default=2, + type=int, + metavar="N", + help="number of data loading workers", + ) + parser.add_argument( + "--checkpoint", + default="checkpoint.pth.par", + type=str, + metavar="PATH", + help="filename to latest checkpoint", + ) + parser.add_argument( + "--epochs", + default=10000, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", + default=0, + type=int, + metavar="N", + help="manual epoch number" + ) + parser.add_argument( + "--print-freq", + default=2500, + type=int, + metavar="N", + help="print frequency in epochs", + ) + parser.add_argument( + "--batch-size", + default=32, + type=int, + metavar="N", + help="mini-batch size" + ) + parser.add_argument( + "--learning-rate", + default=1e-4, + type=float, + metavar="LR", + help="initial learning rate", + ) + parser.add_argument( + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay" + ) + parser.add_argument( + "--adam-beta1", + default=0.9, + type=float, + metavar="BETA1", + help="adam_beta1" + ) + parser.add_argument( + "--adam-beta2", + default=0.999, + type=float, + metavar="BETA2", + help="adam_beta2" + ) + parser.add_argument( + "--eps", + default=1e-8, + type=float, + metavar="EPS", + help="eps") + parser.add_argument( + "--clip-norm", + metavar="NORM", + type=float, + default=4.0, + help="clip norm value") + + parser.add_argument("--progress-bar", action="store_true", help="use progress bar while training") + parser.add_argument("--seed", type=int, default=1000, help="random seed") + # parser.add_argument("--jit", action="store_true", help="if used, model is jitted") + # parser.add_argument("--distributed", action="store_true", help="enable DistributedDataParallel") + + # model parameters + parser.add_argument( + "--upsample-scales", + default=[5, 5, 11], + type=List[int], + help="the list of upsample scales", + ) + parser.add_argument( + "--n-bits", + default=9, + type=int, + help="the bits of output waveform", + ) + parser.add_argument( + "--sample-rate", + default=22050, + type=int, + help="the rate of audio dimensions (samples per second)", + ) + parser.add_argument( + "--hop-length", + default=275, + type=int, + help="the number of samples between the starts of consecutive frames", + ) + parser.add_argument( + "--win-length", + default=1100, + type=int, + help="the length of the STFT window", + ) + parser.add_argument( + "--f-min", + default=40., + type=float, + help="the lowest frequency of the lowest band in a spectrogram", + ) + parser.add_argument( + "--n-res-block", + default=10, + type=int, + help="the number of ResBlock in stack", + ) + parser.add_argument( + "--n-rnn", + default=512, + type=int, + help="the dimension of RNN layer", + ) + 
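    # How these defaults fit together (illustrative note; values are the defaults
    # above): the model's upsample network stretches each spectrogram frame by the
    # product of --upsample-scales, and that product must equal --hop-length so that
    # one frame lines up with exactly one hop of waveform samples; main() asserts
    # this below.
    #
    #     5 * 5 * 11 == 275    # product of upsample_scales == hop_length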
parser.add_argument( + "--n-fc", + default=512, + type=int, + help="the dimension of fully connected layer ", + ) + parser.add_argument( + "--kernel-size", + default=5, + type=int, + help="the number of kernel size in the first Conv1d layer", + ) + parser.add_argument( + "--n-freq", + default=80, + type=int, + help="the number of bins in a spectrogram", + ) + parser.add_argument( + "--n-hidden", + default=128, + type=int, + help="the number of hidden dimensions", + ) + parser.add_argument( + "--n-output", + default=128, + type=int, + help="the number of output dimensions", + ) + parser.add_argument( + "--mode", + default="RAW", + type=str, + help="the type of input waveform in ['RAW', 'MOL']", + ) + parser.add_argument( + "--seq-len-factor", + default=5, + type=int, + help="seq_length = hop_length * seq_len_factor, the length of sequence for training", + ) + parser.add_argument( + "--test-samples", + default=50, + type=float, + help="the number of files for test", + ) + parser.add_argument( + "--file-path", + default="/private/home/jimchen90/datasets/LJSpeech-1.1/wavs/", + type=str, + help="the path of audio files", + ) + + args = parser.parse_args() + return args + + +def save_checkpoint(state, is_best, filename): + + if filename == "": + return + + tempfile = filename + ".temp" + + # Remove tempfile in case interuption during the copying from tempfile to filename + if os.path.isfile(tempfile): + os.remove(tempfile) + + torch.save(state, tempfile) + if os.path.isfile(tempfile): + os.rename(tempfile, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + print("Checkpoint: saved", flush=True) + + +# count total parameters in the model +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, pbar=None): + + model.train() + + sums = defaultdict(lambda: 0.0) + + for i, (x, m, y) in enumerate(data_loader): + x = x.to(device, non_blocking=True) + m = m.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + y_hat = model(x, m) + + if model.mode == 'RAW': + y_hat = y_hat.transpose(1, 2) + + elif model.mode == 'MOL': + y = y.float().unsqueeze(-1) + + else: + raise ValueError('This input mode is not valid.') + + loss = criterion(y_hat, y) + + sums["loss"] += loss.item() + + optimizer.zero_grad() + loss.backward() + + if args.clip_norm > 0: + sums["gradient"] += torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_norm + ) + + optimizer.step() + + if pbar is not None: + pbar.update(1 / len(data_loader)) + + avg_loss = sums["loss"] / len(data_loader) + print(f"Training loss: {avg_loss:4.5f}", flush=True) + + if "gradient" in sums: + avg_gradient = sums["gradient"] / len(data_loader) + print(f"Average gradient norm: {avg_gradient:4.8f}", flush=True) + + +def evaluate(model, criterion, data_loader, device): + + with torch.no_grad(): + + model.eval() + + sums = defaultdict(lambda: 0.0) + + for i, (x, m, y) in enumerate(data_loader): + + x = x.to(device, non_blocking=True) + m = m.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + y_hat = model(x, m) + + if model.mode == 'RAW': + y_hat = y_hat.transpose(1, 2) + + elif model.mode == 'MOL': + y = y.float().unsqueeze(-1) + + else: + raise ValueError('This input mode is not valid.') + + loss = criterion(y_hat, y) + sums["loss"] += loss.item() + + avg_loss = sums["loss"] / len(data_loader) + print(f"Validation loss: {avg_loss:.5f}", flush=True) + + return avg_loss + + 
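# The 'RAW' branches above transpose the network output before the loss because
# nn.CrossEntropyLoss expects logits shaped (batch, n_classes, time) with one class
# index per sample in the target, while _WaveRNN returns (batch, time, n_classes).
# A small self-contained sketch of that shape convention; the sizes follow the
# defaults (seq_len = hop_length * seq_len_factor = 1375, n_bits = 9), but the
# function itself is only an illustration:
def _example_raw_mode_loss():
    import torch
    import torch.nn as nn

    batch, time, n_classes = 4, 1375, 2 ** 9         # n_bits = 9 in the defaults
    y_hat = torch.randn(batch, time, n_classes)      # model output: (batch, time, classes)
    y = torch.randint(0, n_classes, (batch, time))   # quantized samples as class indices
    return nn.CrossEntropyLoss()(y_hat.transpose(1, 2), y)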
+def main(args): + + devices = ["cuda:0" if torch.cuda.is_available() else "cpu"] + + print("Start time: {}".format(str(datetime.now())), flush=True) + + # Empty CUDA cache + torch.cuda.empty_cache() + + # parameters for melspectrogram + melkwargs = { + "n_fft": 2048, + "n_mels": args.n_freq, + "hop_length": args.hop_length, + "f_min": args.f_min, + "win_length": args.win_length + } + + transforms = torch.nn.Sequential( + # torchaudio.transforms.Resample(sample_rate_original, sample_rate_input), + torchaudio.transforms.MelSpectrogram( + sample_rate=args.sample_rate, **melkwargs + ), + ) + + # Dataloader + train_dataset, test_dataset = datasets_ljspeech(args, transforms) + + loader_training_params = { + "num_workers": args.workers, + "pin_memory": True, + "shuffle": True, + "drop_last": False, + } + loader_validation_params = loader_training_params.copy() + loader_validation_params["shuffle"] = False + + collate_fn = collate_factory(args) + + loader_training = DataLoader( + train_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_training_params, + ) + + loader_test = DataLoader( + test_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_validation_params, + ) + + # model + model = _WaveRNN(upsample_scales=args.upsample_scales, + n_bits=args.n_bits, + sample_rate=args.sample_rate, + hop_length=args.hop_length, + n_res_block=args.n_res_block, + n_rnn=args.n_rnn, + n_fc=args.n_fc, + kernel_size=args.kernel_size, + n_freq=args.n_freq, + n_hidden=args.n_hidden, + n_output=args.n_output, + mode=args.mode) + +# if args.jit: +# model = torch.jit.script(model) + +# if not args.distributed: +# model = torch.nn.DataParallel(model) +# else: +# model.cuda() +# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices) + + model = model.to(devices[0], non_blocking=True) + + n = count_parameters(model) + print(f"Number of parameters: {n}", flush=True) + + # Check the hop length is correctly factorised + total_scale = 1 + for upsample_scale in args.upsample_scales: + total_scale *= upsample_scale + assert total_scale == args.hop_length + + # Optimizer + optimizer_params = { + "lr": args.learning_rate, + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.eps, + "weight_decay": args.weight_decay, + } + + optimizer = Adam(model.parameters(), **optimizer_params) + + # This is for 'RAW' input, I need to add loss function for 'MOL' input here. 
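    # In 'RAW' mode every target sample is a class index in [0, 2 ** n_bits - 1], so
    # plain cross-entropy over 2 ** n_bits classes is the natural criterion. 'MOL' mode
    # instead predicts the parameters of a discretized mixture of logistics and needs a
    # dedicated negative log-likelihood; a later patch in this series adds one as
    # loss_mol.py and switches on the mode, roughly:
    #
    #     criterion = nn.CrossEntropyLoss() if args.mode == "RAW" else LossFn_Mol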
+ criterion = nn.CrossEntropyLoss() + + best_loss = 1.0 + + load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) + + if load_checkpoint: + print("Checkpoint: loading '{}'".format(args.checkpoint), flush=True) + checkpoint = torch.load(args.checkpoint) + + args.start_epoch = checkpoint["epoch"] + best_loss = checkpoint["best_loss"] + + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + # scheduler.load_state_dict(checkpoint["scheduler"]) + + print("Checkpoint: loaded '{}' at epoch {}".format(args.checkpoint, checkpoint["epoch"]), flush=True,) + + else: + print("Checkpoint: not found", flush=True) + + save_checkpoint( + { + "epoch": args.start_epoch, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + # "scheduler": scheduler.state_dict(), + }, + False, + args.checkpoint, + ) + + with tqdm(total=args.epochs, unit_scale=1, disable=not args.progress_bar) as pbar: + + for epoch in range(args.start_epoch, args.epochs): + + train_one_epoch( + model, + criterion, + optimizer, + loader_training, + devices[0], + pbar=pbar, + ) + + if not (epoch + 1) % args.print_freq or epoch + 1 == args.epochs: + + sum_loss = evaluate(model, criterion, loader_test, devices[0]) + + is_best = sum_loss < best_loss + best_loss = min(sum_loss, best_loss) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + }, + is_best, + args.checkpoint, + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) From 4728e3d501211463923f29c9758cc2e513c27106 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 24 Jun 2020 15:10:24 -0700 Subject: [PATCH 02/29] update dataset --- examples/pipeline_wavernn/datasets.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index d196ee1c18..f49f18b502 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -73,13 +73,6 @@ def raw_collate(batch): x_input = waves[:, :seq_len] y_coarse = waves[:, 1:] - bits = 16 if args.mode == 'MOL' else args.n_bits - - x_input = 2 * x_input / (2**bits - 1.) - 1 - - if args.mode == 'MOL': - y_coarse = 2 * y_coarse.float() / (2**bits - 1.) - 1 - return x_input, mels, y_coarse return raw_collate From d17b9d12171f9de7dab1deb63480250f15f72991 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 25 Jun 2020 10:46:16 -0700 Subject: [PATCH 03/29] update input type --- examples/pipeline_wavernn/datasets.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index f49f18b502..d196ee1c18 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -73,6 +73,13 @@ def raw_collate(batch): x_input = waves[:, :seq_len] y_coarse = waves[:, 1:] + bits = 16 if args.mode == 'MOL' else args.n_bits + + x_input = 2 * x_input / (2**bits - 1.) - 1 + + if args.mode == 'MOL': + y_coarse = 2 * y_coarse.float() / (2**bits - 1.) 
- 1 + return x_input, mels, y_coarse return raw_collate From 17560e74083a01c251fa226d7866aaaea4e30e73 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 29 Jun 2020 06:25:09 -0700 Subject: [PATCH 04/29] add transform and mol loss --- examples/pipeline_wavernn/datasets.py | 90 ++++++++----- examples/pipeline_wavernn/loss_mol.py | 81 ++++++++++++ examples/pipeline_wavernn/transform.py | 59 +++++++++ examples/pipeline_wavernn/wavernn.py | 171 +++++++++++++++---------- 4 files changed, 306 insertions(+), 95 deletions(-) create mode 100644 examples/pipeline_wavernn/loss_mol.py create mode 100644 examples/pipeline_wavernn/transform.py diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index d196ee1c18..b56f78c3aa 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -7,34 +7,54 @@ class ProcessedLJSPEECH(LJSPEECH): - def __init__(self, - files, - transforms, - mode, - n_bits): + def __init__(self, files, transforms, mode, mulaw, n_bits): self.transforms = transforms self.files = files - self.mode = mode + self.mulaw = mulaw self.n_bits = n_bits + self.mode = mode def __getitem__(self, index): file = self.files[index] - x, sample_rate = torchaudio.load(file) - mel = self.transforms(x) - bits = 16 if self.mode == 'MOL' else self.n_bits + # use torchaudio transform to get waveform and specgram + # waveform, sample_rate = torchaudio.load(file) + # specgram = self.transforms(x) + # return waveform.squeeze(0), mel.squeeze(0) + + # use librosa transform to get waveform and specgram + waveform = self.transforms.load(file) + specgram = self.transforms.melspectrogram(waveform) - x = (x + 1.) * (2 ** bits - 1) / 2 - x = torch.clamp(x, min=0, max=2 ** bits - 1) + # waveform: [0, 2**bits-1] in all cases. + # It is better than [-1, 1] in all cases because of mulaw-encode. + if self.mode == 'waveform': + waveform = self.transforms.mulaw_encode(waveform, 2**self.n_bits) if self.mulaw \ + else float_2_int(waveform, self.n_bits) - return mel.squeeze(0), x.int().squeeze(0) + elif self.mode == 'mol': + waveform = float_2_int(waveform, 16) + + return waveform, specgram def __len__(self): return len(self.files) +# From float waveform [-1, 1] to integer label [0, 2 ** bits - 1] +def float_2_int(waveform, bits): + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.) * (2**bits - 1) / 2 + return torch.clamp(waveform, 0, 2**bits - 1).int() + + +# From integer label [0, 2 ** bits - 1] to float waveform [-1, 1] +def int_2_float(waveform, bits): + return 2 * waveform / (2**bits - 1.) - 1. + + def datasets_ljspeech(args, transforms): root = args.file_path @@ -46,8 +66,8 @@ def datasets_ljspeech(args, transforms): train_files = wavefiles[:-args.test_samples] test_files = wavefiles[-args.test_samples:] - train_dataset = ProcessedLJSPEECH(train_files, transforms, args.mode, args.n_bits) - test_dataset = ProcessedLJSPEECH(test_files, transforms, args.mode, args.n_bits) + train_dataset = ProcessedLJSPEECH(train_files, transforms, args.mode, args.mulaw, args.n_bits) + test_dataset = ProcessedLJSPEECH(test_files, transforms, args.mode, args.mulaw, args.n_bits) return train_dataset, test_dataset @@ -57,29 +77,39 @@ def collate_factory(args): def raw_collate(batch): pad = (args.kernel_size - 1) // 2 - seq_len = args.hop_length * args.seq_len_factor - mel_win = args.seq_len_factor + 2 * pad + # input sequence length, increase seq_len_factor to increase it. 
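        # Worked numbers for the defaults (hop_length=275, seq_len_factor=5,
        # kernel_size=5), shown for illustration only:
        #     pad         = (5 - 1) // 2  = 2 spectrogram frames
        #     wave_length = 275 * 5       = 1375 waveform samples per training window
        #     spec_length = 5 + 2 * 2     = 9 spectrogram frames
        # The extra `pad` frames on each side are the context consumed by the first
        # Conv1d (no padding, kernel_size=5) in the MelResNet, so the upsampled
        # conditioning still covers exactly wave_length samples.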
+ wave_length = args.hop_length * args.seq_len_factor + # input spectrogram length + spec_length = args.seq_len_factor + pad * 2 + + # max start postion in spectrogram + max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch] + + # random start postion in spectrogram + spec_offsets = [random.randint(0, offset) for offset in max_offsets] - max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch] - mel_offsets = [random.randint(0, offset) for offset in max_offsets] - wav_offsets = [(offset + pad) * args.hop_length for offset in mel_offsets] + # random start postion in waveform + wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] - mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)] - waves = [x[1][wav_offsets[i]:wav_offsets[i] + seq_len + 1] for i, x in enumerate(batch)] + waveform_combine = [x[0][wave_offsets[i]:wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch)] + specgram = [x[1][:, spec_offsets[i]:spec_offsets[i] + spec_length] for i, x in enumerate(batch)] - mels = torch.stack(mels) - waves = torch.stack(waves).long() + # stack batch + specgram = torch.stack(specgram) + waveform_combine = torch.stack(waveform_combine) - x_input = waves[:, :seq_len] - y_coarse = waves[:, 1:] + waveform = waveform_combine[:, :wave_length] + target = waveform_combine[:, 1:] - bits = 16 if args.mode == 'MOL' else args.n_bits + # waveform: [-1, 1], target: [0, 2**bits-1] if mode = 'waveform' + # waveform: [-1, 1], target: [-1, 1] if mode = 'mol' + bits = 16 if args.mode == 'mol' else args.n_bits - x_input = 2 * x_input / (2**bits - 1.) - 1 + waveform = int_2_float(waveform.float(), bits) - if args.mode == 'MOL': - y_coarse = 2 * y_coarse.float() / (2**bits - 1.) - 1 + if args.mode == 'mol': + target = int_2_float(target.float(), bits) - return x_input, mels, y_coarse + return waveform, specgram, target return raw_collate diff --git a/examples/pipeline_wavernn/loss_mol.py b/examples/pipeline_wavernn/loss_mol.py new file mode 100644 index 0000000000..70d1a18548 --- /dev/null +++ b/examples/pipeline_wavernn/loss_mol.py @@ -0,0 +1,81 @@ +import torch +import torch.nn.functional as F + +# Adapted from wavenet vocoder: +# https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +# Explain mol loss: +# https://github.com/Rayhane-mamah/Tacotron-2/issues/155 + +# Remove numpy dependency + + +def log_sum_exp(x): + """ numerically stable log_sum_exp implementation that prevents overflow """ + + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +def LossFn_Mol(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + """ calculate the loss of mol mode""" + + min_value = -32.23619130191664 # = float(np.log(1e-14)) + if log_scale_min is None: + log_scale_min = min_value + + assert y_hat.dim() == 3 + assert y_hat.size(-1) % 3 == 0 + + nr_mix = y_hat.size(-1) // 3 + + # unpack parameters. (n_batch, n_time, n_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) + + # n_batch x n_time x 1 -> n_batch x n_time x n_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. 
/ (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) + + inner_inner_cond = (cdf_delta > 1e-5).float() + + tmp = 10.397192449493701 # = np.log((num_classes - 1) / 2) + inner_inner_out = inner_inner_cond * \ + torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - tmp) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + # Add the 10 distributions probabilities and compute the new probabilities: + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + else: + return -log_sum_exp(log_probs).unsqueeze(-1) diff --git a/examples/pipeline_wavernn/transform.py b/examples/pipeline_wavernn/transform.py new file mode 100644 index 0000000000..7d7f7e1501 --- /dev/null +++ b/examples/pipeline_wavernn/transform.py @@ -0,0 +1,59 @@ +import librosa +import numpy as np +import torch + + +class Transform(): + def __init__(self, + sample_rate, + n_fft, + hop_length, + win_length, + num_mels, + fmin, + min_level_db): + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.num_mels = num_mels + self.fmin = fmin + self.min_level_db = min_level_db + + def load(self, path): + waveform = librosa.load(path, sr=self.sample_rate)[0] + return torch.from_numpy(waveform) + + def stft(self, y): + return librosa.stft( + y=y, + n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) + + def linear_to_mel(self, spectrogram): + return librosa.feature.melspectrogram( + S=spectrogram, sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.num_mels, fmin=self.fmin) + + def normalize(self, S): + return np.clip((S - self.min_level_db) / - self.min_level_db, 0, 1) + + def denormalize(self, S): + return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db + + def amp_to_db(self, x): + return 20 * np.log10(np.maximum(1e-5, x)) + + def db_to_amp(self, x): + return np.power(10.0, x * 0.05) + + def melspectrogram(self, y): + D = self.stft(y.numpy()) + S = self.amp_to_db(self.linear_to_mel(np.abs(D))) + S = self.normalize(S) + return torch.from_numpy(S).float() + + def mulaw_encode(self, x, mu): + x = x.numpy() + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return torch.from_numpy(np.floor((fx + 1) / 2 * mu + 0.5)).int() diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py index c4aa97bd9d..840236b61b 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/wavernn.py @@ -7,12 +7,14 @@ import torch import torch.nn as nn import torchaudio +from transform import Transform from datasets import datasets_ljspeech, collate_factory from typing import List from torchaudio.models import _WaveRNN from torch.utils.data import DataLoader 
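# The Transform imported above swaps the torchaudio MelSpectrogram pipeline for the
# librosa-based one in transform.py. Its mulaw_encode maps a [-1, 1] waveform onto
# 2 ** n_bits integer classes with the mu-law companding curve
#     f(x) = sign(x) * log(1 + mu * |x|) / log(1 + mu),  mu = 2 ** n_bits - 1,
# which spends more of the available classes on low-amplitude samples. An
# illustrative comparison against torchaudio's built-in encoder (the values are made
# up, and the two implementations may differ slightly in rounding):
#
#     import torch, torchaudio
#     x = torch.rand(1000) * 2 - 1                                    # fake waveform in [-1, 1]
#     y = torchaudio.transforms.MuLawEncoding(quantization_channels=512)(x)
#     # comparable to Transform(...).mulaw_encode(x, 512) for n_bits = 9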
from torch.optim import Adam from tqdm import tqdm +from loss_mol import LossFn_Mol def parse_args(): @@ -21,7 +23,7 @@ def parse_args(): # training parameters parser.add_argument( "--workers", - default=2, + default=8, type=int, metavar="N", help="number of data loading workers", @@ -30,12 +32,12 @@ def parse_args(): "--checkpoint", default="checkpoint.pth.par", type=str, - metavar="PATH", + metavar="FILE", help="filename to latest checkpoint", ) parser.add_argument( "--epochs", - default=10000, + default=2000, type=int, metavar="N", help="number of total epochs to run", @@ -49,7 +51,7 @@ def parse_args(): ) parser.add_argument( "--print-freq", - default=2500, + default=100, type=int, metavar="N", help="print frequency in epochs", @@ -102,18 +104,22 @@ def parse_args(): default=4.0, help="clip norm value") - parser.add_argument("--progress-bar", action="store_true", help="use progress bar while training") parser.add_argument("--seed", type=int, default=1000, help="random seed") - # parser.add_argument("--jit", action="store_true", help="if used, model is jitted") - # parser.add_argument("--distributed", action="store_true", help="enable DistributedDataParallel") + parser.add_argument("--progress-bar", default=False, action="store_true", help="use progress bar while training") + parser.add_argument("--mulaw", default=True, action="store_true", help="if used, waveform is mulaw encoded") + parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted") + # parser.add_argument("--distributed", default=False, action="store_true", help="enable DistributedDataParallel") # model parameters + + # the product of upsample_scales must equal hop_length parser.add_argument( "--upsample-scales", default=[5, 5, 11], type=List[int], help="the list of upsample scales", ) + # output waveform bits parser.add_argument( "--n-bits", default=9, @@ -136,13 +142,19 @@ def parse_args(): "--win-length", default=1100, type=int, - help="the length of the STFT window", + help="the number of samples between the starts of consecutive frames", ) parser.add_argument( "--f-min", default=40., type=float, - help="the lowest frequency of the lowest band in a spectrogram", + help="the number of samples between the starts of consecutive frames", + ) + parser.add_argument( + "--min-level-db", + default=-100, + type=float, + help="the min db value for spectrogam normalization", ) parser.add_argument( "--n-res-block", @@ -186,24 +198,28 @@ def parse_args(): type=int, help="the number of output dimensions", ) + # mode = ['waveform', 'mol'] parser.add_argument( "--mode", - default="RAW", + default="mol", type=str, - help="the type of input waveform in ['RAW', 'MOL']", + help="the mode of waveform", ) + # the length of input waveform and spectrogram parser.add_argument( "--seq-len-factor", default=5, type=int, - help="seq_length = hop_length * seq_len_factor, the length of sequence for training", + help="seq_length = hop_length * seq_len_factor", ) + # the number of waveforms for testing parser.add_argument( "--test-samples", default=50, type=float, - help="the number of files for test", + help="the number of test waveforms", ) + # the path to store audio files parser.add_argument( "--file-path", default="/private/home/jimchen90/datasets/LJSpeech-1.1/wavs/", @@ -215,6 +231,8 @@ def parse_args(): return args +# From wav2letter pipeline: +# https://github.com/vincentqb/audio/blob/wav2letter/examples/pipeline/wav2letter.py def save_checkpoint(state, is_best, filename): if filename == "": @@ 
-235,35 +253,37 @@ def save_checkpoint(state, is_best, filename): print("Checkpoint: saved", flush=True) -# count total parameters in the model +# count parameter numbers in model def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) -def train_one_epoch(model, criterion, optimizer, data_loader, device, pbar=None): - +# train one epoch +def train_one_epoch(model, mode, bits, mulaw, criterion, optimizer, data_loader, device, pbar=None): model.train() sums = defaultdict(lambda: 0.0) - for i, (x, m, y) in enumerate(data_loader): - x = x.to(device, non_blocking=True) - m = m.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) + for i, (waveform, specgram, target) in enumerate(data_loader, 1): + waveform = waveform.to(device, non_blocking=True) + specgram = specgram.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) - y_hat = model(x, m) + output = model(waveform, specgram) - if model.mode == 'RAW': - y_hat = y_hat.transpose(1, 2) + if mode == 'waveform': + # (n_batch, 2 ** n_bits, n_time) + output = output.transpose(1, 2) + target = target.long() - elif model.mode == 'MOL': - y = y.float().unsqueeze(-1) + elif mode == 'mol': + # (n_batch, n_time, 1) + target = target.unsqueeze(-1) else: - raise ValueError('This input mode is not valid.') - - loss = criterion(y_hat, y) + raise ValueError(f"Expected mode: `waveform` or `mol`, but found {mode}") + loss = criterion(output, target) sums["loss"] += loss.item() optimizer.zero_grad() @@ -287,7 +307,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, pbar=None) print(f"Average gradient norm: {avg_gradient:4.8f}", flush=True) -def evaluate(model, criterion, data_loader, device): +def evaluate(model, mode, bits, mulaw, criterion, data_loader, device): with torch.no_grad(): @@ -295,58 +315,74 @@ def evaluate(model, criterion, data_loader, device): sums = defaultdict(lambda: 0.0) - for i, (x, m, y) in enumerate(data_loader): + for i, (waveform, specgram, target) in enumerate(data_loader, 1): + waveform = waveform.to(device, non_blocking=True) + specgram = specgram.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) - x = x.to(device, non_blocking=True) - m = m.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) + output = model(waveform, specgram) - y_hat = model(x, m) + if mode == 'waveform': + # (batch, 2 ** bits, seq_len) + output = output.transpose(1, 2) + target = target.long() - if model.mode == 'RAW': - y_hat = y_hat.transpose(1, 2) - - elif model.mode == 'MOL': - y = y.float().unsqueeze(-1) + elif mode == 'mol': + # (batch, seq_len, 1) + target = target.unsqueeze(-1) else: - raise ValueError('This input mode is not valid.') + raise ValueError(f"Expected mode: `waveform` or `mol`, but found {mode}") - loss = criterion(y_hat, y) + loss = criterion(output, target) sums["loss"] += loss.item() avg_loss = sums["loss"] / len(data_loader) - print(f"Validation loss: {avg_loss:.5f}", flush=True) + print(f"Validation loss: {avg_loss:.8f}", flush=True) return avg_loss def main(args): - devices = ["cuda:0" if torch.cuda.is_available() else "cpu"] + devices = ["cuda" if torch.cuda.is_available() else "cpu"] print("Start time: {}".format(str(datetime.now())), flush=True) # Empty CUDA cache torch.cuda.empty_cache() - # parameters for melspectrogram + # use torchaudio transform to get waveform and specgram + +# melkwargs = { +# "n_fft": 2048, +# "n_mels": args.n_freq, +# "hop_length": args.hop_length, +# "f_min": 
args.f_min, +# "win_length": args.win_length +# } + +# transforms = torch.nn.Sequential( +# torchaudio.transforms.MelSpectrogram( +# sample_rate=args.sample_rate, **melkwargs +# ), +# torchaudio.transforms.MuLawEncoding(2**args.n_bits) +# ) + + # use librosa transform to get waveform and specgram + melkwargs = { "n_fft": 2048, - "n_mels": args.n_freq, + "num_mels": args.n_freq, "hop_length": args.hop_length, - "f_min": args.f_min, - "win_length": args.win_length + "fmin": args.f_min, + "win_length": args.win_length, + "sample_rate": args.sample_rate, + "min_level_db": args.min_level_db } + transforms = Transform(**melkwargs) - transforms = torch.nn.Sequential( - # torchaudio.transforms.Resample(sample_rate_original, sample_rate_input), - torchaudio.transforms.MelSpectrogram( - sample_rate=args.sample_rate, **melkwargs - ), - ) - - # Dataloader + # dataset train_dataset, test_dataset = datasets_ljspeech(args, transforms) loader_training_params = { @@ -396,19 +432,14 @@ def main(args): # else: # model.cuda() # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices) - + model = torch.nn.DataParallel(model) model = model.to(devices[0], non_blocking=True) n = count_parameters(model) print(f"Number of parameters: {n}", flush=True) - # Check the hop length is correctly factorised - total_scale = 1 - for upsample_scale in args.upsample_scales: - total_scale *= upsample_scale - assert total_scale == args.hop_length - # Optimizer + optimizer_params = { "lr": args.learning_rate, "betas": (args.adam_beta1, args.adam_beta2), @@ -418,8 +449,7 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - # This is for 'RAW' input, I need to add loss function for 'MOL' input here. - criterion = nn.CrossEntropyLoss() + criterion = nn.CrossEntropyLoss() if args.mode == 'waveform' else LossFn_Mol best_loss = 1.0 @@ -459,6 +489,9 @@ def main(args): train_one_epoch( model, + args.mode, + args.n_bits, + args.mulaw, criterion, optimizer, loader_training, @@ -468,7 +501,15 @@ def main(args): if not (epoch + 1) % args.print_freq or epoch + 1 == args.epochs: - sum_loss = evaluate(model, criterion, loader_test, devices[0]) + sum_loss = evaluate( + model, + args.mode, + args.n_bits, + args.mulaw, + criterion, + loader_test, + devices[0], + ) is_best = sum_loss < best_loss best_loss = min(sum_loss, best_loss) From 9717b753aa11b78744b3b3d12dfc3dd2918915e8 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 6 Jul 2020 06:32:43 -0700 Subject: [PATCH 05/29] update format and add utils and readme --- examples/pipeline_wavernn/README | 35 ++ examples/pipeline_wavernn/datasets.py | 3 +- examples/pipeline_wavernn/model.py | 324 +++++++++++++++++ .../{loss_mol.py => mol_loss.py} | 34 +- examples/pipeline_wavernn/utils.py | 56 +++ examples/pipeline_wavernn/wavernn.py | 329 ++++++++++-------- 6 files changed, 629 insertions(+), 152 deletions(-) create mode 100644 examples/pipeline_wavernn/README create mode 100644 examples/pipeline_wavernn/model.py rename examples/pipeline_wavernn/{loss_mol.py => mol_loss.py} (71%) create mode 100644 examples/pipeline_wavernn/utils.py diff --git a/examples/pipeline_wavernn/README b/examples/pipeline_wavernn/README new file mode 100644 index 0000000000..5e5552f0ea --- /dev/null +++ b/examples/pipeline_wavernn/README @@ -0,0 +1,35 @@ +This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSPEECH. WaveRNN and LJSPEECH are available in torchaudio. + +### Output + +The information reported at each iteration and epoch (e.g. 
loss) is printed to standard output in the form of one json per line. Further information is reported to standard error. Here is an example python function to parse the standard output. +```python +def read_json(filename): + """ + Convert the standard output saved to filename into a pandas dataframe for analysis. + """ + + import pandas + import json + + with open(filename, "r") as f: + data = f.read() + + # pandas doesn't read single quotes for json + data = data.replace("'", '"') + + data = [json.loads(l) for l in data.splitlines()] + return pandas.DataFrame(data) +``` + +### Usage + +More information about each command line parameters is available with the `--help` option. An example can be invoked as follows. +``` +python main.py \ + --batch-size 128 \ + --learning-rate 1e-4 \ + --n-freq 80 \ + --mode 'mol' \ + --n_bits 9 \ +``` \ No newline at end of file diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index b56f78c3aa..c61eb0ed46 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -28,8 +28,7 @@ def __getitem__(self, index): waveform = self.transforms.load(file) specgram = self.transforms.melspectrogram(waveform) - # waveform: [0, 2**bits-1] in all cases. - # It is better than [-1, 1] in all cases because of mulaw-encode. + # waveform and spectrogram: [0, 2**bits-1]. if self.mode == 'waveform': waveform = self.transforms.mulaw_encode(waveform, 2**self.n_bits) if self.mulaw \ else float_2_int(waveform, self.n_bits) diff --git a/examples/pipeline_wavernn/model.py b/examples/pipeline_wavernn/model.py new file mode 100644 index 0000000000..f808db5f74 --- /dev/null +++ b/examples/pipeline_wavernn/model.py @@ -0,0 +1,324 @@ +from typing import List + +import torch +from torch import Tensor +from torch import nn + +__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"] + + +class _ResBlock(nn.Module): + r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning + for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016. + It is a block used in WaveRNN. + + Args: + n_freq: the number of bins in a spectrogram (default=128) + + Examples:: + >>> resblock = _ResBlock(n_freq=128) + >>> input = torch.rand(10, 128, 512) + >>> output = resblock(input) + """ + + def __init__(self, n_freq: int = 128) -> None: + super().__init__() + + self.resblock_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq) + ) + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x: the input sequence to the _ResBlock layer + + Shape: + - x: :math:`(batch, freq, time)` + - output: :math:`(batch, freq, time)` + """ + + residual = x + return self.resblock_model(x) + residual + + +class _MelResNet(nn.Module): + r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN. 
+ + Args: + n_res_block: the number of ResBlock in stack (default=10) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + + Examples:: + >>> melresnet = _MelResNet(n_res_block=10, n_freq=128, n_hidden=128, + n_output=128, kernel_size=5) + >>> input = torch.rand(10, 128, 512) + >>> output = melresnet(input) + """ + + def __init__(self, + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + ResBlocks = [] + + for i in range(n_res_block): + ResBlocks.append(_ResBlock(n_hidden)) + + self.melresnet_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden), + nn.ReLU(inplace=True), + *ResBlocks, + nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) + ) + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x: the input sequence to the _MelResNet layer + + Shape: + - x: :math:`(batch, freq, time)` + - output: :math:`(batch, n_output, time - kernel_size + 1)` + """ + + return self.melresnet_model(x) + + +class _Stretch2d(nn.Module): + r"""This is a two-dimensional stretch layer. It is a block used in WaveRNN. + + Args: + x_scale: the scale factor in x axis + y_scale: the scale factor in y axis + + Examples:: + >>> stretch2d = _Stretch2d(x_scale=10, y_scale=10) + + >>> input = torch.rand(10, 1, 100, 512) + >>> output = stretch2d(input) + """ + + def __init__(self, + x_scale: int, + y_scale: int) -> None: + super().__init__() + + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x: the input sequence to the _Stretch2d layer + + Shape: + - x: :math:`(..., freq, time)` + - output: :math:`(..., freq * y_scale, time * x_scale)` + """ + + return x.repeat_interleave(self.y_scale, 2).repeat_interleave(self.x_scale, 3) + + +class _UpsampleNetwork(nn.Module): + r"""This is an upsample block based on a stack of Conv2d and Strech2d layers. + It is a block used in WaveRNN. 
+ + Args: + upsample_scales: the list of upsample scales + n_res_block: the number of ResBlock in stack (default=10) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + + Examples:: + >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16], + n_res_block=10, + n_freq=128, + n_hidden=128, + n_output=128, + kernel_size=5) + >>> input = torch.rand(10, 128, 512) + >>> output = upsamplenetwork(input) + """ + + def __init__(self, + upsample_scales: List[int], + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + self.indent = (kernel_size - 1) // 2 * total_scale + self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet_stretch = _Stretch2d(total_scale, 1) + + up_layers = [] + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = _Stretch2d(scale, 1) + conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1. / k_size[1]) + up_layers.append(stretch) + up_layers.append(conv) + self.upsample_layers = nn.Sequential(*up_layers) + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x: the input sequence to the _UpsampleNetwork layer + + Shape: + - x: :math:`(batch, freq, time)`. + - output: :math:`(batch, (time - kernel_size + 1) * total_scale, freq)`, + `(batch, (time - kernel_size + 1) * total_scale, n_output)` + where total_scale is the product of all elements in upsample_scales. 
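            For example (illustrative only), with upsample_scales=[5, 5, 11]
            (total_scale=275), kernel_size=5 and an input of shape `(10, 80, 9)`,
            the two outputs have shapes `(10, 1375, 80)` and `(10, 1375, n_output)`.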
+ """ + + resnet_output = self.resnet(x).unsqueeze(1) + resnet_output = self.resnet_stretch(resnet_output) + resnet_output = resnet_output.squeeze(1) + + x = x.unsqueeze(1) + upsampling_output = self.upsample_layers(x) + upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] + + return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2) + + +class _WaveRNN(nn.Module): + r""" + Args: + upsample_scales: the list of upsample scales + n_bits: the bits of output waveform + sample_rate: the rate of audio dimensions (samples per second) + hop_length: the number of samples between the starts of consecutive frames + n_res_block: the number of ResBlock in stack (default=10) + n_rnn: the dimension of RNN layer (default=512) + n_fc: the dimension of fully connected layer (default=512) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + mode: the type of input waveform (default='RAW') + + Examples:: + >>> upsamplenetwork = _waveRNN(upsample_scales=[5,5,8], + n_bits=9, + sample_rate=24000, + hop_length=200, + n_res_block=10, + n_rnn=512, + n_fc=512, + kernel_size=5, + n_freq=128, + n_hidden=128, + n_output=128, + mode='RAW') + >>> x = torch.rand(10, 24800, 512) + >>> mels = torch.rand(10, 128, 512) + >>> output = upsamplenetwork(x, mels) + """ + + def __init__(self, + upsample_scales: List[int], + n_bits: int, + sample_rate: int, + hop_length: int, + n_res_block: int = 10, + n_rnn: int = 512, + n_fc: int = 512, + kernel_size: int = 5, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + mode: str = 'waveform') -> None: + super().__init__() + + self.mode = mode + self.kernel_size = kernel_size + + if self.mode == 'waveform': + self.n_classes = 2 ** n_bits + elif self.mode == 'mol': + self.n_classes = 30 + + self.n_rnn = n_rnn + self.n_aux = n_output // 4 + self.hop_length = hop_length + self.sample_rate = sample_rate + + self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) + + self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) + self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) + + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + + self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) + self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) + self.fc3 = nn.Linear(n_fc, self.n_classes) + + def forward(self, x: Tensor, mels: Tensor) -> Tensor: + r""" + Args: + x: the input waveform to the _WaveRNN layer + mels: the input mel-spectrogram to the _WaveRNN layer + + Shape: + - x: :math:`(batch, time)` + - mels: :math:`(batch, freq, time_mels)` + - output: :math:`(batch, time, 2 ** n_bits)` + """ + + batch_size = x.size(0) + h1 = torch.zeros(1, batch_size, self.n_rnn, device=x.device) + h2 = torch.zeros(1, batch_size, self.n_rnn, device=x.device) + mels, aux = self.upsample(mels) + + aux_idx = [self.n_aux * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]:aux_idx[1]] + a2 = aux[:, :, aux_idx[1]:aux_idx[2]] + a3 = aux[:, :, aux_idx[2]:aux_idx[3]] + a4 = aux[:, :, aux_idx[3]:aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.fc(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = self.relu1(self.fc1(x)) + + x 
= torch.cat([x, a4], dim=2) + x = self.relu2(self.fc2(x)) + + return self.fc3(x) diff --git a/examples/pipeline_wavernn/loss_mol.py b/examples/pipeline_wavernn/mol_loss.py similarity index 71% rename from examples/pipeline_wavernn/loss_mol.py rename to examples/pipeline_wavernn/mol_loss.py index 70d1a18548..7ad8f9b6e5 100644 --- a/examples/pipeline_wavernn/loss_mol.py +++ b/examples/pipeline_wavernn/mol_loss.py @@ -1,13 +1,6 @@ import torch import torch.nn.functional as F -# Adapted from wavenet vocoder: -# https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py -# Explain mol loss: -# https://github.com/Rayhane-mamah/Tacotron-2/issues/155 - -# Remove numpy dependency - def log_sum_exp(x): """ numerically stable log_sum_exp implementation that prevents overflow """ @@ -18,12 +11,28 @@ def log_sum_exp(x): return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) -def LossFn_Mol(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): - """ calculate the loss of mol mode""" +def Mol_Loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + """ Discretized mixture of logistic distributions loss + + Adapted from wavenet vocoder: + https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py + Explanation of `mol` loss: + https://github.com/Rayhane-mamah/Tacotron-2/issues/155 + It is assumed that input is scaled to [-1, 1]. + + Args: + y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) + y (Tensor): Target (n_batch x n_time x 1). + num_classes (int): Number of classes + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each minibatch. + + Returns + Tensor: loss + """ - min_value = -32.23619130191664 # = float(np.log(1e-14)) if log_scale_min is None: - log_scale_min = min_value + log_scale_min = torch.log(torch.as_tensor(1e-14)).item() assert y_hat.dim() == 3 assert y_hat.size(-1) % 3 == 0 @@ -63,10 +72,9 @@ def LossFn_Mol(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): inner_inner_cond = (cdf_delta > 1e-5).float() - tmp = 10.397192449493701 # = np.log((num_classes - 1) / 2) inner_inner_out = inner_inner_cond * \ torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ - (1. - inner_inner_cond) * (log_pdf_mid - tmp) + (1. - inner_inner_cond) * (log_pdf_mid - torch.log(torch.as_tensor((num_classes - 1) / 2)).item()) inner_cond = (y > 0.999).float() inner_out = inner_cond * log_one_minus_cdf_min + (1. 
- inner_cond) * inner_inner_out cond = (y < -0.999).float() diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py new file mode 100644 index 0000000000..0b714495a6 --- /dev/null +++ b/examples/pipeline_wavernn/utils.py @@ -0,0 +1,56 @@ +import logging +import os +import shutil +from collections import defaultdict, deque + +import torch + + +class MetricLogger: + def __init__(self, group, print_freq=1): + self.print_freq = print_freq + self._iter = 0 + self.data = defaultdict(lambda: deque(maxlen=self.print_freq)) + self.data["group"].append(group) + + def __call__(self, key, value): + self.data[key].append(value) + + def _get_last(self): + return {k: v[-1] for k, v in self.data.items()} + + def __str__(self): + return str(self._get_last()) + + def print(self): + self._iter = (self._iter + 1) % self.print_freq + if not self._iter: + print(self, flush=True) + + +def save_checkpoint(state, is_best, filename): + """ + Save the model to a temporary file first, + then copy it to filename, in case the signal interrupts + the torch.save() process. + """ + + if filename == "": + return + + tempfile = filename + ".temp" + + # Remove tempfile in case interuption during the copying from tempfile to filename + if os.path.isfile(tempfile): + os.remove(tempfile) + + torch.save(state, tempfile) + if os.path.isfile(tempfile): + os.rename(tempfile, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + logging.info("Checkpoint: saved") + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py index 840236b61b..693ceec695 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/wavernn.py @@ -1,8 +1,10 @@ import argparse +import logging import os -import shutil +import signal from collections import defaultdict from datetime import datetime +from time import time import torch import torch.nn as nn @@ -10,34 +12,37 @@ from transform import Transform from datasets import datasets_ljspeech, collate_factory from typing import List -from torchaudio.models import _WaveRNN +from model import _WaveRNN from torch.utils.data import DataLoader +from torchaudio.datasets.utils import bg_iterator +from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau from torch.optim import Adam -from tqdm import tqdm -from loss_mol import LossFn_Mol +from mol_loss import Mol_Loss +from utils import MetricLogger, count_parameters, save_checkpoint + +SIGNAL_RECEIVED = False def parse_args(): parser = argparse.ArgumentParser() - # training parameters parser.add_argument( "--workers", - default=8, + default=4, type=int, metavar="N", help="number of data loading workers", ) parser.add_argument( "--checkpoint", - default="checkpoint.pth.par", + default="checkpoint.pth.tar", type=str, metavar="FILE", help="filename to latest checkpoint", ) parser.add_argument( "--epochs", - default=2000, + default=3000, type=int, metavar="N", help="number of total epochs to run", @@ -58,7 +63,7 @@ def parse_args(): ) parser.add_argument( "--batch-size", - default=32, + default=256, type=int, metavar="N", help="mini-batch size" @@ -81,45 +86,80 @@ def parse_args(): "--adam-beta1", default=0.9, type=float, - metavar="BETA1", + metavar="AD1", help="adam_beta1" ) parser.add_argument( "--adam-beta2", default=0.999, type=float, - metavar="BETA2", + metavar="AD2", help="adam_beta2" ) parser.add_argument( "--eps", - default=1e-8, - type=float, 
metavar="EPS", - help="eps") + type=float, + default=1e-8 + ) parser.add_argument( "--clip-norm", metavar="NORM", type=float, - default=4.0, - help="clip norm value") - - parser.add_argument("--seed", type=int, default=1000, help="random seed") - parser.add_argument("--progress-bar", default=False, action="store_true", help="use progress bar while training") - parser.add_argument("--mulaw", default=True, action="store_true", help="if used, waveform is mulaw encoded") - parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted") - # parser.add_argument("--distributed", default=False, action="store_true", help="enable DistributedDataParallel") - - # model parameters - - # the product of upsample_scales must equal hop_length + default=4.0 + ) + parser.add_argument( + "--scheduler", + metavar="S", + default="exponential", + choices=["exponential", "reduceonplateau"], + help="optimizer to use", + ) + parser.add_argument( + "--gamma", + default=0.999, + type=float, + metavar="GAMMA", + help="learning rate exponential decay constant", + ) + parser.add_argument( + "--seed", + type=int, + default=1000, + help="random seed" + ) + parser.add_argument( + "--progress-bar", + default=False, + action="store_true", + help="use progress bar while training" + ) + parser.add_argument( + "--mulaw", + default=True, + action="store_true", + help="if used, waveform is mulaw encoded" + ) + parser.add_argument( + "--jit", + default=False, + action="store_true", + help="if used, model is jitted" + ) + parser.add_argument( + '--resume', + default='', + type=str, + metavar='PATH', + help='path to latest checkpoint' + ) + # the product of `upsample_scales` must equal `hop_length` parser.add_argument( "--upsample-scales", default=[5, 5, 11], type=List[int], help="the list of upsample scales", ) - # output waveform bits parser.add_argument( "--n-bits", default=9, @@ -172,7 +212,7 @@ def parse_args(): "--n-fc", default=512, type=int, - help="the dimension of fully connected layer ", + help="the dimension of fully connected layer", ) parser.add_argument( "--kernel-size", @@ -198,28 +238,25 @@ def parse_args(): type=int, help="the number of output dimensions", ) - # mode = ['waveform', 'mol'] parser.add_argument( "--mode", - default="mol", + default="waveform", + choices=["waveform", "mol"], type=str, - help="the mode of waveform", + help="the type of waveform", ) - # the length of input waveform and spectrogram parser.add_argument( "--seq-len-factor", default=5, type=int, help="seq_length = hop_length * seq_len_factor", ) - # the number of waveforms for testing parser.add_argument( "--test-samples", default=50, type=float, - help="the number of test waveforms", + help="the number of waveforms for testing", ) - # the path to store audio files parser.add_argument( "--file-path", default="/private/home/jimchen90/datasets/LJSpeech-1.1/wavs/", @@ -231,114 +268,126 @@ def parse_args(): return args -# From wav2letter pipeline: -# https://github.com/vincentqb/audio/blob/wav2letter/examples/pipeline/wav2letter.py -def save_checkpoint(state, is_best, filename): - - if filename == "": - return - - tempfile = filename + ".temp" +def signal_handler(a, b): + global SIGNAL_RECEIVED + print("Signal received", a, datetime.now().strftime("%y%m%d.%H%M%S"), flush=True) + SIGNAL_RECEIVED = True - # Remove tempfile in case interuption during the copying from tempfile to filename - if os.path.isfile(tempfile): - os.remove(tempfile) - torch.save(state, tempfile) - if os.path.isfile(tempfile): - 
os.rename(tempfile, filename) - if is_best: - shutil.copyfile(filename, "model_best.pth.tar") +def train_one_epoch( + model, mode, criterion, optimizer, scheduler, data_loader, device, epoch +): - print("Checkpoint: saved", flush=True) + model.train() + sums = defaultdict(lambda: 0.0) + start1 = time() -# count parameter numbers in model -def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) + metric = MetricLogger("train_iteration") + metric("epoch", epoch) + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): -# train one epoch -def train_one_epoch(model, mode, bits, mulaw, criterion, optimizer, data_loader, device, pbar=None): - model.train() + start2 = time() - sums = defaultdict(lambda: 0.0) - - for i, (waveform, specgram, target) in enumerate(data_loader, 1): - waveform = waveform.to(device, non_blocking=True) - specgram = specgram.to(device, non_blocking=True) - target = target.to(device, non_blocking=True) + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) output = model(waveform, specgram) if mode == 'waveform': - # (n_batch, 2 ** n_bits, n_time) output = output.transpose(1, 2) target = target.long() elif mode == 'mol': - # (n_batch, n_time, 1) target = target.unsqueeze(-1) else: - raise ValueError(f"Expected mode: `waveform` or `mol`, but found {mode}") + raise ValueError( + f"Expected mode: `waveform` or `mol`, but found {mode}" + ) loss = criterion(output, target) - sums["loss"] += loss.item() + loss_item = loss.item() + sums["loss"] += loss_item + metric("loss", loss_item) optimizer.zero_grad() loss.backward() if args.clip_norm > 0: - sums["gradient"] += torch.nn.utils.clip_grad_norm_( + gradient = torch.nn.utils.clip_grad_norm_( model.parameters(), args.clip_norm ) + sums["gradient"] += gradient + metric("gradient", gradient.item()) optimizer.step() - if pbar is not None: - pbar.update(1 / len(data_loader)) + metric("iteration", sums["iteration"]) + metric("time", time() - start2) + metric.print() + sums["iteration"] += 1 + + if SIGNAL_RECEIVED: + return avg_loss = sums["loss"] / len(data_loader) - print(f"Training loss: {avg_loss:4.5f}", flush=True) + metric = MetricLogger("train_epoch") + metric("epoch", epoch) + metric("loss", avg_loss) if "gradient" in sums: - avg_gradient = sums["gradient"] / len(data_loader) - print(f"Average gradient norm: {avg_gradient:4.8f}", flush=True) + metric("gradient", sums["gradient"] / len(data_loader)) + metric("lr", scheduler.get_last_lr()[0]) + metric("time", time() - start1) + metric.print() + scheduler.step() -def evaluate(model, mode, bits, mulaw, criterion, data_loader, device): + +def evaluate(model, mode, criterion, data_loader, device, epoch): with torch.no_grad(): model.eval() - sums = defaultdict(lambda: 0.0) + start = time() - for i, (waveform, specgram, target) in enumerate(data_loader, 1): - waveform = waveform.to(device, non_blocking=True) - specgram = specgram.to(device, non_blocking=True) - target = target.to(device, non_blocking=True) + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): + + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) output = model(waveform, specgram) if mode == 'waveform': - # (batch, 2 ** bits, seq_len) output = output.transpose(1, 2) target = target.long() elif mode == 'mol': - # (batch, seq_len, 1) target = target.unsqueeze(-1) else: - raise ValueError(f"Expected mode: `waveform` or `mol`, but found {mode}") + raise ValueError( 
+ f"Expected mode: `waveform` or `mol`, but found {mode}" + ) loss = criterion(output, target) sums["loss"] += loss.item() + if SIGNAL_RECEIVED: + break + avg_loss = sums["loss"] / len(data_loader) - print(f"Validation loss: {avg_loss:.8f}", flush=True) + + metric = MetricLogger("validation") + metric("epoch", epoch) + metric("loss", avg_loss) + metric("time", time() - start) + metric.print() return avg_loss @@ -347,11 +396,14 @@ def main(args): devices = ["cuda" if torch.cuda.is_available() else "cpu"] - print("Start time: {}".format(str(datetime.now())), flush=True) + logging.info("Start time: {}".format(str(datetime.now()))) # Empty CUDA cache torch.cuda.empty_cache() + # Install signal handler + signal.signal(signal.SIGUSR1, lambda a, b: signal_handler(a, b)) + # use torchaudio transform to get waveform and specgram # melkwargs = { @@ -366,7 +418,6 @@ def main(args): # torchaudio.transforms.MelSpectrogram( # sample_rate=args.sample_rate, **melkwargs # ), -# torchaudio.transforms.MuLawEncoding(2**args.n_bits) # ) # use librosa transform to get waveform and specgram @@ -387,7 +438,7 @@ def main(args): loader_training_params = { "num_workers": args.workers, - "pin_memory": True, + "pin_memory": False, "shuffle": True, "drop_last": False, } @@ -402,7 +453,6 @@ def main(args): collate_fn=collate_fn, **loader_training_params, ) - loader_test = DataLoader( test_dataset, batch_size=args.batch_size, @@ -422,24 +472,19 @@ def main(args): n_freq=args.n_freq, n_hidden=args.n_hidden, n_output=args.n_output, - mode=args.mode) + mode=args.mode, + ) -# if args.jit: -# model = torch.jit.script(model) + if args.jit: + model = torch.jit.script(model) -# if not args.distributed: -# model = torch.nn.DataParallel(model) -# else: -# model.cuda() -# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices) model = torch.nn.DataParallel(model) model = model.to(devices[0], non_blocking=True) n = count_parameters(model) - print(f"Number of parameters: {n}", flush=True) + logging.info(f"Number of parameters: {n}") # Optimizer - optimizer_params = { "lr": args.learning_rate, "betas": (args.adam_beta1, args.adam_beta2), @@ -449,14 +494,19 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - criterion = nn.CrossEntropyLoss() if args.mode == 'waveform' else LossFn_Mol + if args.scheduler == "exponential": + scheduler = ExponentialLR(optimizer, gamma=args.gamma) + elif args.scheduler == "reduceonplateau": + scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3) - best_loss = 1.0 + criterion = nn.CrossEntropyLoss() if args.mode == 'waveform' else Mol_Loss + + best_loss = 10. 
load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) if load_checkpoint: - print("Checkpoint: loading '{}'".format(args.checkpoint), flush=True) + logging.info(f"Checkpoint: loading '{args.checkpoint}'") checkpoint = torch.load(args.checkpoint) args.start_epoch = checkpoint["epoch"] @@ -464,12 +514,13 @@ def main(args): model.load_state_dict(checkpoint["state_dict"]) optimizer.load_state_dict(checkpoint["optimizer"]) - # scheduler.load_state_dict(checkpoint["scheduler"]) - - print("Checkpoint: loaded '{}' at epoch {}".format(args.checkpoint, checkpoint["epoch"]), flush=True,) + scheduler.load_state_dict(checkpoint["scheduler"]) + logging.info( + f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}" + ) else: - print("Checkpoint: not found", flush=True) + logging.info("Checkpoint: not found") save_checkpoint( { @@ -477,54 +528,58 @@ def main(args): "state_dict": model.state_dict(), "best_loss": best_loss, "optimizer": optimizer.state_dict(), - # "scheduler": scheduler.state_dict(), + "scheduler": scheduler.state_dict(), }, False, args.checkpoint, ) - with tqdm(total=args.epochs, unit_scale=1, disable=not args.progress_bar) as pbar: + for epoch in range(args.start_epoch, args.epochs): + + train_one_epoch( + model, args.mode, criterion, optimizer, scheduler, loader_training, devices[0], epoch, + ) + + if SIGNAL_RECEIVED: + save_checkpoint({ + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'best_loss': best_loss, + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + }, False, args.checkpoint) - for epoch in range(args.start_epoch, args.epochs): + if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: - train_one_epoch( + sum_loss = evaluate( model, args.mode, - args.n_bits, - args.mulaw, criterion, - optimizer, - loader_training, + loader_test, devices[0], - pbar=pbar, + epoch, ) - if not (epoch + 1) % args.print_freq or epoch + 1 == args.epochs: + is_best = sum_loss < best_loss + best_loss = min(sum_loss, best_loss) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + }, + is_best, + args.checkpoint, + ) - sum_loss = evaluate( - model, - args.mode, - args.n_bits, - args.mulaw, - criterion, - loader_test, - devices[0], - ) - - is_best = sum_loss < best_loss - best_loss = min(sum_loss, best_loss) - save_checkpoint( - { - "epoch": epoch + 1, - "state_dict": model.state_dict(), - "best_loss": best_loss, - "optimizer": optimizer.state_dict(), - }, - is_best, - args.checkpoint, - ) + logging.info(f"End time: {datetime.now()}") if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + args = parse_args() main(args) From 131fe3342076946306c7bcf9e8eb9453a52dea85 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 6 Jul 2020 06:55:16 -0700 Subject: [PATCH 06/29] update model import --- examples/pipeline_wavernn/model.py | 324 --------------------------- examples/pipeline_wavernn/wavernn.py | 2 +- 2 files changed, 1 insertion(+), 325 deletions(-) delete mode 100644 examples/pipeline_wavernn/model.py diff --git a/examples/pipeline_wavernn/model.py b/examples/pipeline_wavernn/model.py deleted file mode 100644 index f808db5f74..0000000000 --- a/examples/pipeline_wavernn/model.py +++ /dev/null @@ -1,324 +0,0 @@ -from typing import List - -import torch -from torch import Tensor -from torch import nn - -__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"] 
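
As a side note on the logging introduced in `utils.py`, the sketch below shows how `MetricLogger` batches its output: values are recorded on every call, but a record is only printed on every `print_freq`-th call to `print()`. It assumes the `utils.py` from this patch is on the import path; the numbers are placeholders.

```python
# Usage sketch for MetricLogger (illustrative only; assumes utils.py from this patch is importable).
from utils import MetricLogger

metric = MetricLogger("train_iteration", print_freq=2)
metric("epoch", 0)
metric("loss", 1.23)
metric.print()   # nothing printed yet: only every print_freq-th call prints
metric("loss", 1.17)
metric.print()   # prints the latest value for every key as one dictionary-style record per line
```
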
- - -class _ResBlock(nn.Module): - r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning - for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016. - It is a block used in WaveRNN. - - Args: - n_freq: the number of bins in a spectrogram (default=128) - - Examples:: - >>> resblock = _ResBlock(n_freq=128) - >>> input = torch.rand(10, 128, 512) - >>> output = resblock(input) - """ - - def __init__(self, n_freq: int = 128) -> None: - super().__init__() - - self.resblock_model = nn.Sequential( - nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), - nn.BatchNorm1d(n_freq), - nn.ReLU(inplace=True), - nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), - nn.BatchNorm1d(n_freq) - ) - - def forward(self, x: Tensor) -> Tensor: - r""" - Args: - x: the input sequence to the _ResBlock layer - - Shape: - - x: :math:`(batch, freq, time)` - - output: :math:`(batch, freq, time)` - """ - - residual = x - return self.resblock_model(x) + residual - - -class _MelResNet(nn.Module): - r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN. - - Args: - n_res_block: the number of ResBlock in stack (default=10) - n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) - kernel_size: the number of kernel size in the first Conv1d layer (default=5) - - Examples:: - >>> melresnet = _MelResNet(n_res_block=10, n_freq=128, n_hidden=128, - n_output=128, kernel_size=5) - >>> input = torch.rand(10, 128, 512) - >>> output = melresnet(input) - """ - - def __init__(self, - n_res_block: int = 10, - n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, - kernel_size: int = 5) -> None: - super().__init__() - - ResBlocks = [] - - for i in range(n_res_block): - ResBlocks.append(_ResBlock(n_hidden)) - - self.melresnet_model = nn.Sequential( - nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), - nn.BatchNorm1d(n_hidden), - nn.ReLU(inplace=True), - *ResBlocks, - nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) - ) - - def forward(self, x: Tensor) -> Tensor: - r""" - Args: - x: the input sequence to the _MelResNet layer - - Shape: - - x: :math:`(batch, freq, time)` - - output: :math:`(batch, n_output, time - kernel_size + 1)` - """ - - return self.melresnet_model(x) - - -class _Stretch2d(nn.Module): - r"""This is a two-dimensional stretch layer. It is a block used in WaveRNN. - - Args: - x_scale: the scale factor in x axis - y_scale: the scale factor in y axis - - Examples:: - >>> stretch2d = _Stretch2d(x_scale=10, y_scale=10) - - >>> input = torch.rand(10, 1, 100, 512) - >>> output = stretch2d(input) - """ - - def __init__(self, - x_scale: int, - y_scale: int) -> None: - super().__init__() - - self.x_scale = x_scale - self.y_scale = y_scale - - def forward(self, x: Tensor) -> Tensor: - r""" - Args: - x: the input sequence to the _Stretch2d layer - - Shape: - - x: :math:`(..., freq, time)` - - output: :math:`(..., freq * y_scale, time * x_scale)` - """ - - return x.repeat_interleave(self.y_scale, 2).repeat_interleave(self.x_scale, 3) - - -class _UpsampleNetwork(nn.Module): - r"""This is an upsample block based on a stack of Conv2d and Strech2d layers. - It is a block used in WaveRNN. 
- - Args: - upsample_scales: the list of upsample scales - n_res_block: the number of ResBlock in stack (default=10) - n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) - kernel_size: the number of kernel size in the first Conv1d layer (default=5) - - Examples:: - >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16], - n_res_block=10, - n_freq=128, - n_hidden=128, - n_output=128, - kernel_size=5) - >>> input = torch.rand(10, 128, 512) - >>> output = upsamplenetwork(input) - """ - - def __init__(self, - upsample_scales: List[int], - n_res_block: int = 10, - n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, - kernel_size: int = 5) -> None: - super().__init__() - - total_scale = 1 - for upsample_scale in upsample_scales: - total_scale *= upsample_scale - - self.indent = (kernel_size - 1) // 2 * total_scale - self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) - self.resnet_stretch = _Stretch2d(total_scale, 1) - - up_layers = [] - for scale in upsample_scales: - k_size = (1, scale * 2 + 1) - padding = (0, scale) - stretch = _Stretch2d(scale, 1) - conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False) - conv.weight.data.fill_(1. / k_size[1]) - up_layers.append(stretch) - up_layers.append(conv) - self.upsample_layers = nn.Sequential(*up_layers) - - def forward(self, x: Tensor) -> Tensor: - r""" - Args: - x: the input sequence to the _UpsampleNetwork layer - - Shape: - - x: :math:`(batch, freq, time)`. - - output: :math:`(batch, (time - kernel_size + 1) * total_scale, freq)`, - `(batch, (time - kernel_size + 1) * total_scale, n_output)` - where total_scale is the product of all elements in upsample_scales. 
- """ - - resnet_output = self.resnet(x).unsqueeze(1) - resnet_output = self.resnet_stretch(resnet_output) - resnet_output = resnet_output.squeeze(1) - - x = x.unsqueeze(1) - upsampling_output = self.upsample_layers(x) - upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] - - return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2) - - -class _WaveRNN(nn.Module): - r""" - Args: - upsample_scales: the list of upsample scales - n_bits: the bits of output waveform - sample_rate: the rate of audio dimensions (samples per second) - hop_length: the number of samples between the starts of consecutive frames - n_res_block: the number of ResBlock in stack (default=10) - n_rnn: the dimension of RNN layer (default=512) - n_fc: the dimension of fully connected layer (default=512) - kernel_size: the number of kernel size in the first Conv1d layer (default=5) - n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) - mode: the type of input waveform (default='RAW') - - Examples:: - >>> upsamplenetwork = _waveRNN(upsample_scales=[5,5,8], - n_bits=9, - sample_rate=24000, - hop_length=200, - n_res_block=10, - n_rnn=512, - n_fc=512, - kernel_size=5, - n_freq=128, - n_hidden=128, - n_output=128, - mode='RAW') - >>> x = torch.rand(10, 24800, 512) - >>> mels = torch.rand(10, 128, 512) - >>> output = upsamplenetwork(x, mels) - """ - - def __init__(self, - upsample_scales: List[int], - n_bits: int, - sample_rate: int, - hop_length: int, - n_res_block: int = 10, - n_rnn: int = 512, - n_fc: int = 512, - kernel_size: int = 5, - n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, - mode: str = 'waveform') -> None: - super().__init__() - - self.mode = mode - self.kernel_size = kernel_size - - if self.mode == 'waveform': - self.n_classes = 2 ** n_bits - elif self.mode == 'mol': - self.n_classes = 30 - - self.n_rnn = n_rnn - self.n_aux = n_output // 4 - self.hop_length = hop_length - self.sample_rate = sample_rate - - self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) - self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) - - self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) - self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) - - self.relu1 = nn.ReLU() - self.relu2 = nn.ReLU() - - self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) - self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) - self.fc3 = nn.Linear(n_fc, self.n_classes) - - def forward(self, x: Tensor, mels: Tensor) -> Tensor: - r""" - Args: - x: the input waveform to the _WaveRNN layer - mels: the input mel-spectrogram to the _WaveRNN layer - - Shape: - - x: :math:`(batch, time)` - - mels: :math:`(batch, freq, time_mels)` - - output: :math:`(batch, time, 2 ** n_bits)` - """ - - batch_size = x.size(0) - h1 = torch.zeros(1, batch_size, self.n_rnn, device=x.device) - h2 = torch.zeros(1, batch_size, self.n_rnn, device=x.device) - mels, aux = self.upsample(mels) - - aux_idx = [self.n_aux * i for i in range(5)] - a1 = aux[:, :, aux_idx[0]:aux_idx[1]] - a2 = aux[:, :, aux_idx[1]:aux_idx[2]] - a3 = aux[:, :, aux_idx[2]:aux_idx[3]] - a4 = aux[:, :, aux_idx[3]:aux_idx[4]] - - x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) - x = self.fc(x) - res = x - x, _ = self.rnn1(x, h1) - - x = x + res - res = x - x = torch.cat([x, a2], dim=2) - x, _ = self.rnn2(x, h2) - - x = x + res - x = torch.cat([x, a3], dim=2) - x = self.relu1(self.fc1(x)) - - x 
= torch.cat([x, a4], dim=2) - x = self.relu2(self.fc2(x)) - - return self.fc3(x) diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py index 693ceec695..af8e4f0da5 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/wavernn.py @@ -12,7 +12,7 @@ from transform import Transform from datasets import datasets_ljspeech, collate_factory from typing import List -from model import _WaveRNN +from torchaudio.models import _WaveRNN from torch.utils.data import DataLoader from torchaudio.datasets.utils import bg_iterator from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau From 1d1c68395f22fb544ce0d6ba850f610eb97ff6b2 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 6 Jul 2020 07:00:15 -0700 Subject: [PATCH 07/29] update readme --- examples/pipeline_wavernn/README | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pipeline_wavernn/README b/examples/pipeline_wavernn/README index 5e5552f0ea..51f07c06d8 100644 --- a/examples/pipeline_wavernn/README +++ b/examples/pipeline_wavernn/README @@ -2,7 +2,7 @@ This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained ### Output -The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Further information is reported to standard error. Here is an example python function to parse the standard output. +The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the standard output. ```python def read_json(filename): """ @@ -32,4 +32,4 @@ python main.py \ --n-freq 80 \ --mode 'mol' \ --n_bits 9 \ -``` \ No newline at end of file +``` From 969966cb1e0f43a288810279d57b8c5f8878e166 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 6 Jul 2020 07:29:38 -0700 Subject: [PATCH 08/29] update readme --- examples/pipeline_wavernn/{README => README.md} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename examples/pipeline_wavernn/{README => README.md} (97%) diff --git a/examples/pipeline_wavernn/README b/examples/pipeline_wavernn/README.md similarity index 97% rename from examples/pipeline_wavernn/README rename to examples/pipeline_wavernn/README.md index 51f07c06d8..910056bc07 100644 --- a/examples/pipeline_wavernn/README +++ b/examples/pipeline_wavernn/README.md @@ -30,6 +30,6 @@ python main.py \ --batch-size 128 \ --learning-rate 1e-4 \ --n-freq 80 \ - --mode 'mol' \ + --mode 'waveform' \ --n_bits 9 \ ``` From 4f9cc60fbbfe2463afb71b691d618b5a85d786b5 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 6 Jul 2020 20:02:49 -0700 Subject: [PATCH 09/29] add reference in readme --- examples/pipeline_wavernn/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md index 910056bc07..aba40bec90 100644 --- a/examples/pipeline_wavernn/README.md +++ b/examples/pipeline_wavernn/README.md @@ -1,4 +1,4 @@ -This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSPEECH. WaveRNN and LJSPEECH are available in torchaudio. +This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSPEECH. WaveRNN model is based on the implementation from https://github.com/fatchord/WaveRNN. The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSPEECH are available in torchaudio. 
### Output From 4671bedc43d9b97b6b0a44fa0d2e9a2d2bc72616 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 8 Jul 2020 06:32:17 -0700 Subject: [PATCH 10/29] add channel dimension --- examples/pipeline_wavernn/README.md | 5 +++-- examples/pipeline_wavernn/datasets.py | 2 +- examples/pipeline_wavernn/wavernn.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md index aba40bec90..d0689f680b 100644 --- a/examples/pipeline_wavernn/README.md +++ b/examples/pipeline_wavernn/README.md @@ -1,4 +1,5 @@ -This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSPEECH. WaveRNN model is based on the implementation from https://github.com/fatchord/WaveRNN. The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSPEECH are available in torchaudio. +This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSpeech. WaveRNN model is based on the implementation from https://github.com/fatchord/WaveRNN. The original implementation was +introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. ### Output @@ -31,5 +32,5 @@ python main.py \ --learning-rate 1e-4 \ --n-freq 80 \ --mode 'waveform' \ - --n_bits 9 \ + --n_bits 8 \ ``` diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index c61eb0ed46..58273c9938 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -109,6 +109,6 @@ def raw_collate(batch): if args.mode == 'mol': target = int_2_float(target.float(), bits) - return waveform, specgram, target + return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) return raw_collate diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py index af8e4f0da5..a3fc4bb95d 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/wavernn.py @@ -162,7 +162,7 @@ def parse_args(): ) parser.add_argument( "--n-bits", - default=9, + default=8, type=int, help="the bits of output waveform", ) @@ -295,6 +295,7 @@ def train_one_epoch( target = target.to(device) output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) if mode == 'waveform': output = output.transpose(1, 2) From 55f866baf79f77808b1facff5fc5606eabfc1d05 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Sun, 12 Jul 2020 11:37:26 -0700 Subject: [PATCH 11/29] Update the transform and dataset function --- examples/pipeline_wavernn/README.md | 4 +- examples/pipeline_wavernn/datasets.py | 90 +++--- .../{mol_loss.py => losses.py} | 5 +- examples/pipeline_wavernn/transform.py | 63 +--- examples/pipeline_wavernn/utils.py | 46 ++- examples/pipeline_wavernn/wavernn.py | 298 ++++-------------- 6 files changed, 158 insertions(+), 348 deletions(-) rename examples/pipeline_wavernn/{mol_loss.py => losses.py} (96%) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md index d0689f680b..d8ec32be7c 100644 --- a/examples/pipeline_wavernn/README.md +++ b/examples/pipeline_wavernn/README.md @@ -1,4 +1,4 @@ -This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSpeech. WaveRNN model is based on the implementation from https://github.com/fatchord/WaveRNN. The original implementation was +This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSpeech. 
WaveRNN model is based on the implementation from this repository https://github.com/fatchord/WaveRNN. The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. ### Output @@ -32,5 +32,5 @@ python main.py \ --learning-rate 1e-4 \ --n-freq 80 \ --mode 'waveform' \ - --n_bits 8 \ + --n-bits 8 \ ``` diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 58273c9938..2557d9e274 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -1,40 +1,42 @@ import os import random + import torch import torchaudio from torchaudio.datasets import LJSPEECH +from transform import linear_to_mel +from utils import (label_to_waveform, mulaw_encode, specgram_normalize, + waveform_to_label) -class ProcessedLJSPEECH(LJSPEECH): - def __init__(self, files, transforms, mode, mulaw, n_bits): +class ProcessedLJSPEECH(LJSPEECH): + def __init__(self, files, transforms, args): self.transforms = transforms self.files = files - self.mulaw = mulaw - self.n_bits = n_bits - self.mode = mode + self.args = args def __getitem__(self, index): file = self.files[index] - - # use torchaudio transform to get waveform and specgram - # waveform, sample_rate = torchaudio.load(file) - # specgram = self.transforms(x) - # return waveform.squeeze(0), mel.squeeze(0) - - # use librosa transform to get waveform and specgram - waveform = self.transforms.load(file) - specgram = self.transforms.melspectrogram(waveform) - - # waveform and spectrogram: [0, 2**bits-1]. - if self.mode == 'waveform': - waveform = self.transforms.mulaw_encode(waveform, 2**self.n_bits) if self.mulaw \ - else float_2_int(waveform, self.n_bits) - - elif self.mode == 'mol': - waveform = float_2_int(waveform, 16) + args = self.args + n_fft = 2048 + waveform, sample_rate = torchaudio.load(file) + specgram = self.transforms(waveform) + specgram = linear_to_mel(specgram, sample_rate, n_fft, args.n_freq, args.f_min) + specgram = specgram_normalize(specgram, args.min_level_db) + waveform = waveform.squeeze(0) + + if args.mode == "waveform": + waveform = ( + mulaw_encode(waveform, 2 ** args.n_bits) + if args.mulaw + else waveform_to_label(waveform, args.n_bits) + ) + + elif args.mode == "mol": + waveform = waveform_to_label(waveform, 16) return waveform, specgram @@ -42,18 +44,6 @@ def __len__(self): return len(self.files) -# From float waveform [-1, 1] to integer label [0, 2 ** bits - 1] -def float_2_int(waveform, bits): - assert abs(waveform).max() <= 1.0 - waveform = (waveform + 1.) * (2**bits - 1) / 2 - return torch.clamp(waveform, 0, 2**bits - 1).int() - - -# From integer label [0, 2 ** bits - 1] to float waveform [-1, 1] -def int_2_float(waveform, bits): - return 2 * waveform / (2**bits - 1.) - 1. 
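
The quantization helpers being moved around here are easiest to see with a round trip. The sketch below mirrors the `waveform_to_label`/`label_to_waveform` pair defined later in this patch (the same math as the removed `float_2_int`/`int_2_float`); it is illustrative only, and the sample values are placeholders.

```python
# Round-trip sketch of the waveform/label quantization (mirrors the helpers in this patch; illustrative only).
import torch


def waveform_to_label(waveform, bits):
    # float waveform in [-1, 1] -> integer label in [0, 2 ** bits - 1]
    assert waveform.abs().max() <= 1.0
    waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
    return torch.clamp(waveform, 0, 2 ** bits - 1).int()


def label_to_waveform(label, bits):
    # integer label in [0, 2 ** bits - 1] -> float waveform in [-1, 1]
    return 2 * label / (2 ** bits - 1.0) - 1.0


x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
labels = waveform_to_label(x, bits=8)                  # tensor([  0,  63, 127, 191, 255], dtype=torch.int32)
recovered = label_to_waveform(labels.float(), bits=8)  # each value within one quantization step of x
print(labels, recovered)
```
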
- - def datasets_ljspeech(args, transforms): root = args.file_path @@ -62,21 +52,21 @@ def datasets_ljspeech(args, transforms): random.seed(args.seed) random.shuffle(wavefiles) - train_files = wavefiles[:-args.test_samples] - test_files = wavefiles[-args.test_samples:] + train_files = wavefiles[: -args.test_samples] + test_files = wavefiles[-args.test_samples :] - train_dataset = ProcessedLJSPEECH(train_files, transforms, args.mode, args.mulaw, args.n_bits) - test_dataset = ProcessedLJSPEECH(test_files, transforms, args.mode, args.mulaw, args.n_bits) + train_dataset = ProcessedLJSPEECH(train_files, transforms, args) + test_dataset = ProcessedLJSPEECH(test_files, transforms, args) return train_dataset, test_dataset def collate_factory(args): - def raw_collate(batch): pad = (args.kernel_size - 1) // 2 - # input sequence length, increase seq_len_factor to increase it. + + # input waveform length wave_length = args.hop_length * args.seq_len_factor # input spectrogram length spec_length = args.seq_len_factor + pad * 2 @@ -86,14 +76,18 @@ def raw_collate(batch): # random start postion in spectrogram spec_offsets = [random.randint(0, offset) for offset in max_offsets] - # random start postion in waveform wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] - waveform_combine = [x[0][wave_offsets[i]:wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch)] - specgram = [x[1][:, spec_offsets[i]:spec_offsets[i] + spec_length] for i, x in enumerate(batch)] + waveform_combine = [ + x[0][wave_offsets[i] : wave_offsets[i] + wave_length + 1] + for i, x in enumerate(batch) + ] + specgram = [ + x[1][:, spec_offsets[i] : spec_offsets[i] + spec_length] + for i, x in enumerate(batch) + ] - # stack batch specgram = torch.stack(specgram) waveform_combine = torch.stack(waveform_combine) @@ -102,12 +96,12 @@ def raw_collate(batch): # waveform: [-1, 1], target: [0, 2**bits-1] if mode = 'waveform' # waveform: [-1, 1], target: [-1, 1] if mode = 'mol' - bits = 16 if args.mode == 'mol' else args.n_bits + bits = 16 if args.mode == "mol" else args.n_bits - waveform = int_2_float(waveform.float(), bits) + waveform = label_to_waveform(waveform.float(), bits) - if args.mode == 'mol': - target = int_2_float(target.float(), bits) + if args.mode == "mol": + target = label_to_waveform(target.float(), bits) return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) diff --git a/examples/pipeline_wavernn/mol_loss.py b/examples/pipeline_wavernn/losses.py similarity index 96% rename from examples/pipeline_wavernn/mol_loss.py rename to examples/pipeline_wavernn/losses.py index 7ad8f9b6e5..9cb90327ea 100644 --- a/examples/pipeline_wavernn/mol_loss.py +++ b/examples/pipeline_wavernn/losses.py @@ -1,9 +1,10 @@ import torch -import torch.nn.functional as F +from torch.nn import functional as F def log_sum_exp(x): - """ numerically stable log_sum_exp implementation that prevents overflow """ + r""" numerically stable log_sum_exp implementation that prevents overflow + """ axis = len(x.size()) - 1 m, _ = torch.max(x, dim=axis) diff --git a/examples/pipeline_wavernn/transform.py b/examples/pipeline_wavernn/transform.py index 7d7f7e1501..9bb4aa7583 100644 --- a/examples/pipeline_wavernn/transform.py +++ b/examples/pipeline_wavernn/transform.py @@ -1,59 +1,14 @@ import librosa -import numpy as np import torch -class Transform(): - def __init__(self, - sample_rate, - n_fft, - hop_length, - win_length, - num_mels, - fmin, - min_level_db): +def linear_to_mel(specgram, sample_rate, n_fft, n_mels, 
fmin): - self.sample_rate = sample_rate - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.num_mels = num_mels - self.fmin = fmin - self.min_level_db = min_level_db - - def load(self, path): - waveform = librosa.load(path, sr=self.sample_rate)[0] - return torch.from_numpy(waveform) - - def stft(self, y): - return librosa.stft( - y=y, - n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) - - def linear_to_mel(self, spectrogram): - return librosa.feature.melspectrogram( - S=spectrogram, sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.num_mels, fmin=self.fmin) - - def normalize(self, S): - return np.clip((S - self.min_level_db) / - self.min_level_db, 0, 1) - - def denormalize(self, S): - return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db - - def amp_to_db(self, x): - return 20 * np.log10(np.maximum(1e-5, x)) - - def db_to_amp(self, x): - return np.power(10.0, x * 0.05) - - def melspectrogram(self, y): - D = self.stft(y.numpy()) - S = self.amp_to_db(self.linear_to_mel(np.abs(D))) - S = self.normalize(S) - return torch.from_numpy(S).float() - - def mulaw_encode(self, x, mu): - x = x.numpy() - mu = mu - 1 - fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) - return torch.from_numpy(np.floor((fx + 1) / 2 * mu + 0.5)).int() + specgram = librosa.feature.melspectrogram( + S=specgram.squeeze(0).numpy(), + sr=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin + ) + return torch.from_numpy(specgram) diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py index 0b714495a6..3026d2dc71 100644 --- a/examples/pipeline_wavernn/utils.py +++ b/examples/pipeline_wavernn/utils.py @@ -7,6 +7,9 @@ class MetricLogger: + r"""Logger for metrics + """ + def __init__(self, group, print_freq=1): self.print_freq = print_freq self._iter = 0 @@ -29,8 +32,7 @@ def print(self): def save_checkpoint(state, is_best, filename): - """ - Save the model to a temporary file first, + r"""Save the model to a temporary file first, then copy it to filename, in case the signal interrupts the torch.save() process. 
""" @@ -53,4 +55,44 @@ def save_checkpoint(state, is_best, filename): def count_parameters(model): + r"""Count the total parameters in the model + """ + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def specgram_normalize(specgram, min_level_db): + r"""Normalize the spectrogram with a minimum db value + """ + + specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) + return torch.clamp((min_level_db - specgram) / min_level_db, min=0, max=1) + + +def mulaw_encode(waveform, mu): + r"""mulaw encode waveform + """ + + mu = mu - 1 + fx = ( + torch.sign(waveform) + * torch.log(1 + mu * torch.abs(waveform)) + / torch.log(torch.as_tensor(1.0 + mu)) + ) + return torch.floor((fx + 1) / 2 * mu + 0.5).int() + + +def waveform_to_label(waveform, bits): + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] + """ + + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def label_to_waveform(label, bits): + r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] + """ + + return 2 * label / (2 ** bits - 1.0) - 1.0 diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/wavernn.py index a3fc4bb95d..f6c4e41790 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/wavernn.py @@ -5,22 +5,19 @@ from collections import defaultdict from datetime import datetime from time import time +from typing import List import torch -import torch.nn as nn import torchaudio -from transform import Transform -from datasets import datasets_ljspeech, collate_factory -from typing import List -from torchaudio.models import _WaveRNN +from torch import nn as nn +from torch.optim import Adam from torch.utils.data import DataLoader from torchaudio.datasets.utils import bg_iterator -from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau -from torch.optim import Adam -from mol_loss import Mol_Loss -from utils import MetricLogger, count_parameters, save_checkpoint +from torchaudio.models._wavernn import _WaveRNN -SIGNAL_RECEIVED = False +from datasets import collate_factory, datasets_ljspeech +from losses import Mol_Loss +from utils import MetricLogger, count_parameters, save_checkpoint def parse_args(): @@ -35,125 +32,45 @@ def parse_args(): ) parser.add_argument( "--checkpoint", - default="checkpoint.pth.tar", + default="", type=str, - metavar="FILE", - help="filename to latest checkpoint", + metavar="PATH", + help="path to latest checkpoint", ) parser.add_argument( "--epochs", - default=3000, + default=8000, type=int, metavar="N", help="number of total epochs to run", ) parser.add_argument( - "--start-epoch", - default=0, - type=int, - metavar="N", - help="manual epoch number" + "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number" ) parser.add_argument( "--print-freq", - default=100, + default=10, type=int, metavar="N", help="print frequency in epochs", ) parser.add_argument( - "--batch-size", - default=256, - type=int, - metavar="N", - help="mini-batch size" - ) - parser.add_argument( - "--learning-rate", - default=1e-4, - type=float, - metavar="LR", - help="initial learning rate", - ) - parser.add_argument( - "--weight-decay", - default=0.0, - type=float, - metavar="W", - help="weight decay" - ) - parser.add_argument( - "--adam-beta1", - default=0.9, - type=float, - metavar="AD1", - help="adam_beta1" - ) - parser.add_argument( - "--adam-beta2", - default=0.999, - type=float, - metavar="AD2", - help="adam_beta2" 
- ) - parser.add_argument( - "--eps", - metavar="EPS", - type=float, - default=1e-8 - ) - parser.add_argument( - "--clip-norm", - metavar="NORM", - type=float, - default=4.0 - ) - parser.add_argument( - "--scheduler", - metavar="S", - default="exponential", - choices=["exponential", "reduceonplateau"], - help="optimizer to use", - ) - parser.add_argument( - "--gamma", - default=0.999, - type=float, - metavar="GAMMA", - help="learning rate exponential decay constant", - ) - parser.add_argument( - "--seed", - type=int, - default=1000, - help="random seed" + "--batch-size", default=256, type=int, metavar="N", help="mini-batch size" ) parser.add_argument( - "--progress-bar", - default=False, - action="store_true", - help="use progress bar while training" + "--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate", ) + parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0) + parser.add_argument("--seed", type=int, default=1000, help="random seed") parser.add_argument( "--mulaw", default=True, action="store_true", - help="if used, waveform is mulaw encoded" + help="if used, waveform is mulaw encoded", ) parser.add_argument( - "--jit", - default=False, - action="store_true", - help="if used, model is jitted" + "--jit", default=False, action="store_true", help="if used, model is jitted" ) - parser.add_argument( - '--resume', - default='', - type=str, - metavar='PATH', - help='path to latest checkpoint' - ) - # the product of `upsample_scales` must equal `hop_length` parser.add_argument( "--upsample-scales", default=[5, 5, 11], @@ -161,10 +78,7 @@ def parse_args(): help="the list of upsample scales", ) parser.add_argument( - "--n-bits", - default=8, - type=int, - help="the bits of output waveform", + "--n-bits", default=8, type=int, help="the bits of output waveform", ) parser.add_argument( "--sample-rate", @@ -186,7 +100,7 @@ def parse_args(): ) parser.add_argument( "--f-min", - default=40., + default=40.0, type=float, help="the number of samples between the starts of consecutive frames", ) @@ -197,22 +111,13 @@ def parse_args(): help="the min db value for spectrogam normalization", ) parser.add_argument( - "--n-res-block", - default=10, - type=int, - help="the number of ResBlock in stack", + "--n-res-block", default=10, type=int, help="the number of ResBlock in stack", ) parser.add_argument( - "--n-rnn", - default=512, - type=int, - help="the dimension of RNN layer", + "--n-rnn", default=512, type=int, help="the dimension of RNN layer", ) parser.add_argument( - "--n-fc", - default=512, - type=int, - help="the dimension of fully connected layer", + "--n-fc", default=512, type=int, help="the dimension of fully connected layer", ) parser.add_argument( "--kernel-size", @@ -221,29 +126,20 @@ def parse_args(): help="the number of kernel size in the first Conv1d layer", ) parser.add_argument( - "--n-freq", - default=80, - type=int, - help="the number of bins in a spectrogram", + "--n-freq", default=80, type=int, help="the number of bins in a spectrogram", ) parser.add_argument( - "--n-hidden", - default=128, - type=int, - help="the number of hidden dimensions", + "--n-hidden", default=128, type=int, help="the number of hidden dimensions", ) parser.add_argument( - "--n-output", - default=128, - type=int, - help="the number of output dimensions", + "--n-output", default=128, type=int, help="the number of output dimensions", ) parser.add_argument( "--mode", - default="waveform", + default="mol", choices=["waveform", "mol"], type=str, - help="the type of waveform", 
+ help="the mode of waveform", ) parser.add_argument( "--seq-len-factor", @@ -268,15 +164,7 @@ def parse_args(): return args -def signal_handler(a, b): - global SIGNAL_RECEIVED - print("Signal received", a, datetime.now().strftime("%y%m%d.%H%M%S"), flush=True) - SIGNAL_RECEIVED = True - - -def train_one_epoch( - model, mode, criterion, optimizer, scheduler, data_loader, device, epoch -): +def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoch): model.train() @@ -297,17 +185,12 @@ def train_one_epoch( output = model(waveform, specgram) output, target = output.squeeze(1), target.squeeze(1) - if mode == 'waveform': + if mode == "waveform": output = output.transpose(1, 2) target = target.long() - elif mode == 'mol': - target = target.unsqueeze(-1) - else: - raise ValueError( - f"Expected mode: `waveform` or `mol`, but found {mode}" - ) + target = target.unsqueeze(-1) loss = criterion(output, target) loss_item = loss.item() @@ -317,9 +200,9 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - if args.clip_norm > 0: + if args.clip_grad > 0: gradient = torch.nn.utils.clip_grad_norm_( - model.parameters(), args.clip_norm + model.parameters(), args.clip_grad ) sums["gradient"] += gradient metric("gradient", gradient.item()) @@ -331,22 +214,15 @@ def train_one_epoch( metric.print() sums["iteration"] += 1 - if SIGNAL_RECEIVED: - return - avg_loss = sums["loss"] / len(data_loader) metric = MetricLogger("train_epoch") metric("epoch", epoch) metric("loss", avg_loss) - if "gradient" in sums: - metric("gradient", sums["gradient"] / len(data_loader)) - metric("lr", scheduler.get_last_lr()[0]) + metric("gradient", sums["gradient"] / len(data_loader)) metric("time", time() - start1) metric.print() - scheduler.step() - def evaluate(model, mode, criterion, data_loader, device, epoch): @@ -364,24 +240,16 @@ def evaluate(model, mode, criterion, data_loader, device, epoch): output = model(waveform, specgram) - if mode == 'waveform': + if mode == "waveform": output = output.transpose(1, 2) target = target.long() - elif mode == 'mol': - target = target.unsqueeze(-1) - else: - raise ValueError( - f"Expected mode: `waveform` or `mol`, but found {mode}" - ) + target = target.unsqueeze(-1) loss = criterion(output, target) sums["loss"] += loss.item() - if SIGNAL_RECEIVED: - break - avg_loss = sums["loss"] / len(data_loader) metric = MetricLogger("validation") @@ -402,39 +270,15 @@ def main(args): # Empty CUDA cache torch.cuda.empty_cache() - # Install signal handler - signal.signal(signal.SIGUSR1, lambda a, b: signal_handler(a, b)) - - # use torchaudio transform to get waveform and specgram - -# melkwargs = { -# "n_fft": 2048, -# "n_mels": args.n_freq, -# "hop_length": args.hop_length, -# "f_min": args.f_min, -# "win_length": args.win_length -# } - -# transforms = torch.nn.Sequential( -# torchaudio.transforms.MelSpectrogram( -# sample_rate=args.sample_rate, **melkwargs -# ), -# ) - - # use librosa transform to get waveform and specgram - melkwargs = { "n_fft": 2048, - "num_mels": args.n_freq, + "power": 1, "hop_length": args.hop_length, - "fmin": args.f_min, "win_length": args.win_length, - "sample_rate": args.sample_rate, - "min_level_db": args.min_level_db } - transforms = Transform(**melkwargs) - # dataset + transforms = torch.nn.Sequential(torchaudio.transforms.Spectrogram(**melkwargs)) + train_dataset, test_dataset = datasets_ljspeech(args, transforms) loader_training_params = { @@ -461,20 +305,20 @@ def main(args): **loader_validation_params, ) - # model - model = 
_WaveRNN(upsample_scales=args.upsample_scales, - n_bits=args.n_bits, - sample_rate=args.sample_rate, - hop_length=args.hop_length, - n_res_block=args.n_res_block, - n_rnn=args.n_rnn, - n_fc=args.n_fc, - kernel_size=args.kernel_size, - n_freq=args.n_freq, - n_hidden=args.n_hidden, - n_output=args.n_output, - mode=args.mode, - ) + model = _WaveRNN( + upsample_scales=args.upsample_scales, + n_bits=args.n_bits, + sample_rate=args.sample_rate, + hop_length=args.hop_length, + n_res_block=args.n_res_block, + n_rnn=args.n_rnn, + n_fc=args.n_fc, + kernel_size=args.kernel_size, + n_freq=args.n_freq, + n_hidden=args.n_hidden, + n_output=args.n_output, + mode=args.mode, + ) if args.jit: model = torch.jit.script(model) @@ -488,21 +332,13 @@ def main(args): # Optimizer optimizer_params = { "lr": args.learning_rate, - "betas": (args.adam_beta1, args.adam_beta2), - "eps": args.eps, - "weight_decay": args.weight_decay, } optimizer = Adam(model.parameters(), **optimizer_params) - if args.scheduler == "exponential": - scheduler = ExponentialLR(optimizer, gamma=args.gamma) - elif args.scheduler == "reduceonplateau": - scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3) - - criterion = nn.CrossEntropyLoss() if args.mode == 'waveform' else Mol_Loss + criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else Mol_Loss - best_loss = 10. + best_loss = 10.0 load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) @@ -515,7 +351,6 @@ def main(args): model.load_state_dict(checkpoint["state_dict"]) optimizer.load_state_dict(checkpoint["optimizer"]) - scheduler.load_state_dict(checkpoint["scheduler"]) logging.info( f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}" @@ -529,7 +364,6 @@ def main(args): "state_dict": model.state_dict(), "best_loss": best_loss, "optimizer": optimizer.state_dict(), - "scheduler": scheduler.state_dict(), }, False, args.checkpoint, @@ -538,27 +372,13 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): train_one_epoch( - model, args.mode, criterion, optimizer, scheduler, loader_training, devices[0], epoch, + model, args.mode, criterion, optimizer, loader_training, devices[0], epoch, ) - if SIGNAL_RECEIVED: - save_checkpoint({ - 'epoch': epoch, - 'state_dict': model.state_dict(), - 'best_loss': best_loss, - 'optimizer': optimizer.state_dict(), - 'scheduler': scheduler.state_dict(), - }, False, args.checkpoint) - if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: sum_loss = evaluate( - model, - args.mode, - criterion, - loader_test, - devices[0], - epoch, + model, args.mode, criterion, loader_test, devices[0], epoch, ) is_best = sum_loss < best_loss @@ -569,7 +389,6 @@ def main(args): "state_dict": model.state_dict(), "best_loss": best_loss, "optimizer": optimizer.state_dict(), - "scheduler": scheduler.state_dict(), }, is_best, args.checkpoint, @@ -581,6 +400,5 @@ def main(args): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - args = parse_args() main(args) From c43149d4e3084d4c3b294712fef2e6da5e0a0b06 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 13 Jul 2020 05:17:49 -0700 Subject: [PATCH 12/29] Add function doctring --- examples/pipeline_wavernn/README.md | 26 +++++------ examples/pipeline_wavernn/datasets.py | 45 +++++++++++++------ examples/pipeline_wavernn/losses.py | 44 +++++++++--------- .../pipeline_wavernn/{wavernn.py => main.py} | 11 +++-- examples/pipeline_wavernn/transform.py | 2 +- examples/pipeline_wavernn/utils.py | 6 +-- 6 files changed, 75 insertions(+), 59 
deletions(-) rename examples/pipeline_wavernn/{wavernn.py => main.py} (98%) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md index d8ec32be7c..76b7bc2e15 100644 --- a/examples/pipeline_wavernn/README.md +++ b/examples/pipeline_wavernn/README.md @@ -1,6 +1,18 @@ -This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained with LJSpeech. WaveRNN model is based on the implementation from this repository https://github.com/fatchord/WaveRNN. The original implementation was +This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained on LJSpeech. WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. +### Usage + +An example can be invoked as follows. +``` +python main.py \ + --batch-size 256 \ + --learning-rate 1e-4 \ + --n-freq 80 \ + --mode 'waveform' \ + --n-bits 8 \ +``` + ### Output The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the standard output. @@ -22,15 +34,3 @@ def read_json(filename): data = [json.loads(l) for l in data.splitlines()] return pandas.DataFrame(data) ``` - -### Usage - -More information about each command line parameters is available with the `--help` option. An example can be invoked as follows. -``` -python main.py \ - --batch-size 128 \ - --learning-rate 1e-4 \ - --n-freq 80 \ - --mode 'waveform' \ - --n-bits 8 \ -``` diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 2557d9e274..ed8298bc04 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -6,8 +6,28 @@ from torchaudio.datasets import LJSPEECH from transform import linear_to_mel -from utils import (label_to_waveform, mulaw_encode, specgram_normalize, - waveform_to_label) +from utils import label_to_waveform, mulaw_encode, specgram_normalize, waveform_to_label + + +class MapMemoryCache(torch.utils.data.Dataset): + r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory. 
+ """ + + def __init__(self, dataset): + self.dataset = dataset + self._cache = [None] * len(dataset) + + def __getitem__(self, n): + if self._cache[n] is not None: + return self._cache[n] + + item = self.dataset[n] + self._cache[n] = item + + return item + + def __len__(self): + return len(self.dataset) class ProcessedLJSPEECH(LJSPEECH): @@ -24,6 +44,7 @@ def __getitem__(self, index): n_fft = 2048 waveform, sample_rate = torchaudio.load(file) specgram = self.transforms(waveform) + # Will be replaced by torchaudio as described in https://github.com/pytorch/audio/pull/593 specgram = linear_to_mel(specgram, sample_rate, n_fft, args.n_freq, args.f_min) specgram = specgram_normalize(specgram, args.min_level_db) waveform = waveform.squeeze(0) @@ -35,9 +56,6 @@ def __getitem__(self, index): else waveform_to_label(waveform, args.n_bits) ) - elif args.mode == "mol": - waveform = waveform_to_label(waveform, 16) - return waveform, specgram def __len__(self): @@ -53,11 +71,14 @@ def datasets_ljspeech(args, transforms): random.shuffle(wavefiles) train_files = wavefiles[: -args.test_samples] - test_files = wavefiles[-args.test_samples :] + test_files = wavefiles[-args.test_samples:] train_dataset = ProcessedLJSPEECH(train_files, transforms, args) test_dataset = ProcessedLJSPEECH(test_files, transforms, args) + train_dataset = MapMemoryCache(train_dataset) + test_dataset = MapMemoryCache(test_dataset) + return train_dataset, test_dataset @@ -80,11 +101,11 @@ def raw_collate(batch): wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] waveform_combine = [ - x[0][wave_offsets[i] : wave_offsets[i] + wave_length + 1] + x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch) ] specgram = [ - x[1][:, spec_offsets[i] : spec_offsets[i] + spec_length] + x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length] for i, x in enumerate(batch) ] @@ -96,12 +117,8 @@ def raw_collate(batch): # waveform: [-1, 1], target: [0, 2**bits-1] if mode = 'waveform' # waveform: [-1, 1], target: [-1, 1] if mode = 'mol' - bits = 16 if args.mode == "mol" else args.n_bits - - waveform = label_to_waveform(waveform.float(), bits) - - if args.mode == "mol": - target = label_to_waveform(target.float(), bits) + if args.mode == "waveform": + waveform = label_to_waveform(waveform.float(), args.n_bits) return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py index 9cb90327ea..055daef53a 100644 --- a/examples/pipeline_wavernn/losses.py +++ b/examples/pipeline_wavernn/losses.py @@ -3,7 +3,7 @@ def log_sum_exp(x): - r""" numerically stable log_sum_exp implementation that prevents overflow + r""" Numerically stable log_sum_exp implementation that prevents overflow """ axis = len(x.size()) - 1 @@ -12,14 +12,12 @@ def log_sum_exp(x): return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) -def Mol_Loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): - """ Discretized mixture of logistic distributions loss +def MoLLoss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + r""" Discretized mixture of logistic distributions loss - Adapted from wavenet vocoder: - https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py - Explanation of `mol` loss: - https://github.com/Rayhane-mamah/Tacotron-2/issues/155 - It is assumed that input is scaled to [-1, 1]. 
+ Adapted from wavenet vocoder + (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py) + Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155) Args: y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) @@ -40,19 +38,19 @@ def Mol_Loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): nr_mix = y_hat.size(-1) // 3 - # unpack parameters. (n_batch, n_time, n_mixtures) x 3 + # unpack parameters (n_batch, n_time, num_mixtures) x 3 logit_probs = y_hat[:, :, :nr_mix] - means = y_hat[:, :, nr_mix:2 * nr_mix] - log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) + means = y_hat[:, :, nr_mix: 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min) - # n_batch x n_time x 1 -> n_batch x n_time x n_mixtures + # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) y = y.expand_as(means) centered_y = y - means inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) cdf_plus = torch.sigmoid(plus_in) - min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) cdf_min = torch.sigmoid(min_in) # log probability for edge case of 0 (before scaling) @@ -68,20 +66,22 @@ def Mol_Loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): mid_in = inv_stdv * centered_y # log probability in the center of the bin, to be used in extreme cases - # (not actually used in our code) - log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) inner_inner_cond = (cdf_delta > 1e-5).float() - inner_inner_out = inner_inner_cond * \ - torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ - (1. - inner_inner_cond) * (log_pdf_mid - torch.log(torch.as_tensor((num_classes - 1) / 2)).item()) + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - torch.log(torch.as_tensor((num_classes - 1) / 2)).item() + ) inner_cond = (y > 0.999).float() - inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) cond = (y < -0.999).float() - log_probs = cond * log_cdf_plus + (1. 
- cond) * inner_out + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out - # Add the 10 distributions probabilities and compute the new probabilities: log_probs = log_probs + F.log_softmax(logit_probs, -1) if reduce: diff --git a/examples/pipeline_wavernn/wavernn.py b/examples/pipeline_wavernn/main.py similarity index 98% rename from examples/pipeline_wavernn/wavernn.py rename to examples/pipeline_wavernn/main.py index f6c4e41790..5010de08d1 100644 --- a/examples/pipeline_wavernn/wavernn.py +++ b/examples/pipeline_wavernn/main.py @@ -16,7 +16,7 @@ from torchaudio.models._wavernn import _WaveRNN from datasets import collate_factory, datasets_ljspeech -from losses import Mol_Loss +from losses import MoLLoss from utils import MetricLogger, count_parameters, save_checkpoint @@ -190,6 +190,7 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc target = target.long() else: + # use mol mode target = target.unsqueeze(-1) loss = criterion(output, target) @@ -245,6 +246,7 @@ def evaluate(model, mode, criterion, data_loader, device, epoch): target = target.long() else: + # use mol mode target = target.unsqueeze(-1) loss = criterion(output, target) @@ -299,10 +301,7 @@ def main(args): **loader_training_params, ) loader_test = DataLoader( - test_dataset, - batch_size=args.batch_size, - collate_fn=collate_fn, - **loader_validation_params, + test_dataset, batch_size=1, collate_fn=collate_fn, **loader_validation_params, ) model = _WaveRNN( @@ -336,7 +335,7 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else Mol_Loss + criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else MoLLoss best_loss = 10.0 diff --git a/examples/pipeline_wavernn/transform.py b/examples/pipeline_wavernn/transform.py index 9bb4aa7583..952bdc5d82 100644 --- a/examples/pipeline_wavernn/transform.py +++ b/examples/pipeline_wavernn/transform.py @@ -9,6 +9,6 @@ def linear_to_mel(specgram, sample_rate, n_fft, n_mels, fmin): sr=sample_rate, n_fft=n_fft, n_mels=n_mels, - fmin=fmin + fmin=fmin, ) return torch.from_numpy(specgram) diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py index 3026d2dc71..5ddce0b4ec 100644 --- a/examples/pipeline_wavernn/utils.py +++ b/examples/pipeline_wavernn/utils.py @@ -7,7 +7,7 @@ class MetricLogger: - r"""Logger for metrics + r"""Logger for model metrics """ def __init__(self, group, print_freq=1): @@ -55,7 +55,7 @@ def save_checkpoint(state, is_best, filename): def count_parameters(model): - r"""Count the total parameters in the model + r"""Count the total number of parameters in the model """ return sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -70,7 +70,7 @@ def specgram_normalize(specgram, min_level_db): def mulaw_encode(waveform, mu): - r"""mulaw encode waveform + r"""Waveform mulaw encoding """ mu = mu - 1 From 0b944b49db284af73830da62fcb8e5c465cb0eca Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 13 Jul 2020 05:41:06 -0700 Subject: [PATCH 13/29] Use default argument --- examples/pipeline_wavernn/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 5010de08d1..aa004d220f 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -136,7 +136,7 @@ def parse_args(): ) parser.add_argument( "--mode", - default="mol", + default="waveform", choices=["waveform", "mol"], type=str, help="the 
mode of waveform", @@ -155,7 +155,7 @@ def parse_args(): ) parser.add_argument( "--file-path", - default="/private/home/jimchen90/datasets/LJSpeech-1.1/wavs/", + default="", type=str, help="the path of audio files", ) From 79c0ded1c480012bf8d8b747b42d57d320dafed3 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 13 Jul 2020 21:15:05 -0700 Subject: [PATCH 14/29] Update dataset --- examples/pipeline_wavernn/datasets.py | 46 +++++++++++++++---------- examples/pipeline_wavernn/functional.py | 38 ++++++++++++++++++++ examples/pipeline_wavernn/main.py | 24 ++++++------- examples/pipeline_wavernn/utils.py | 37 -------------------- 4 files changed, 78 insertions(+), 67 deletions(-) create mode 100644 examples/pipeline_wavernn/functional.py diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index ed8298bc04..e21ed86c04 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -6,7 +6,7 @@ from torchaudio.datasets import LJSPEECH from transform import linear_to_mel -from utils import label_to_waveform, mulaw_encode, specgram_normalize, waveform_to_label +from functional import label_to_waveform, mulaw_encode, specgram_normalize, waveform_to_label class MapMemoryCache(torch.utils.data.Dataset): @@ -32,19 +32,25 @@ def __len__(self): class ProcessedLJSPEECH(LJSPEECH): def __init__(self, files, transforms, args): - - self.transforms = transforms self.files = files + self.transforms = transforms self.args = args def __getitem__(self, index): + filename = self.files[index][0] + file = os.path.join(self.args.file_path, filename + '.wav') - file = self.files[index] + return self.process_datapoint(file) + + def __len__(self): + return len(self.files) + + def process_datapoint(self, file): args = self.args n_fft = 2048 waveform, sample_rate = torchaudio.load(file) specgram = self.transforms(waveform) - # Will be replaced by torchaudio as described in https://github.com/pytorch/audio/pull/593 + # TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved specgram = linear_to_mel(specgram, sample_rate, n_fft, args.n_freq, args.f_min) specgram = specgram_normalize(specgram, args.min_level_db) waveform = waveform.squeeze(0) @@ -58,28 +64,32 @@ def __getitem__(self, index): return waveform, specgram - def __len__(self): - return len(self.files) +def split_data(data, val_ratio): + files = data._walker + random.shuffle(files) + train_files = files[: -int(val_ratio * len(files))] + val_files = files[-int(val_ratio * len(files)):] -def datasets_ljspeech(args, transforms): + return train_files, val_files - root = args.file_path - wavefiles = [os.path.join(root, file) for file in os.listdir(root)] - random.seed(args.seed) - random.shuffle(wavefiles) +def gen_datasets_ljspeech( + args, + transforms, + root="datasets/", +): + data = LJSPEECH(root=root, download=False) - train_files = wavefiles[: -args.test_samples] - test_files = wavefiles[-args.test_samples:] + train_dataset, val_dataset = split_data(data, args.val_ratio) - train_dataset = ProcessedLJSPEECH(train_files, transforms, args) - test_dataset = ProcessedLJSPEECH(test_files, transforms, args) + train_dataset = ProcessedLJSPEECH(train_dataset, transforms, args) + val_dataset = ProcessedLJSPEECH(val_dataset, transforms, args) train_dataset = MapMemoryCache(train_dataset) - test_dataset = MapMemoryCache(test_dataset) + val_dataset = MapMemoryCache(val_dataset) - return train_dataset, test_dataset + return train_dataset, val_dataset def 
collate_factory(args): diff --git a/examples/pipeline_wavernn/functional.py b/examples/pipeline_wavernn/functional.py new file mode 100644 index 0000000000..15c3def8f0 --- /dev/null +++ b/examples/pipeline_wavernn/functional.py @@ -0,0 +1,38 @@ +import torch + + +def specgram_normalize(specgram, min_level_db): + r"""Normalize the spectrogram with a minimum db value + """ + + specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) + return torch.clamp((min_level_db - specgram) / min_level_db, min=0, max=1) + + +def mulaw_encode(waveform, mu): + r"""Waveform mulaw encoding + """ + + mu = mu - 1 + fx = ( + torch.sign(waveform) + * torch.log(1 + mu * torch.abs(waveform)) + / torch.log(torch.as_tensor(1.0 + mu)) + ) + return torch.floor((fx + 1) / 2 * mu + 0.5).int() + + +def waveform_to_label(waveform, bits): + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] + """ + + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def label_to_waveform(label, bits): + r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] + """ + + return 2 * label / (2 ** bits - 1.0) - 1.0 diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index aa004d220f..2ada505cc1 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -15,7 +15,7 @@ from torchaudio.datasets.utils import bg_iterator from torchaudio.models._wavernn import _WaveRNN -from datasets import collate_factory, datasets_ljspeech +from datasets import collate_factory, gen_datasets_ljspeech from losses import MoLLoss from utils import MetricLogger, count_parameters, save_checkpoint @@ -148,10 +148,10 @@ def parse_args(): help="seq_length = hop_length * seq_len_factor", ) parser.add_argument( - "--test-samples", - default=50, + "--val-ratio", + default=0.1, type=float, - help="the number of waveforms for testing", + help="the ratio of waveforms for validation", ) parser.add_argument( "--file-path", @@ -225,7 +225,7 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc metric.print() -def evaluate(model, mode, criterion, data_loader, device, epoch): +def validate(model, mode, criterion, data_loader, device, epoch): with torch.no_grad(): @@ -281,7 +281,7 @@ def main(args): transforms = torch.nn.Sequential(torchaudio.transforms.Spectrogram(**melkwargs)) - train_dataset, test_dataset = datasets_ljspeech(args, transforms) + train_dataset, val_dataset = gen_datasets_ljspeech(args, transforms) loader_training_params = { "num_workers": args.workers, @@ -294,14 +294,14 @@ def main(args): collate_fn = collate_factory(args) - loader_training = DataLoader( + train_loader = DataLoader( train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, **loader_training_params, ) - loader_test = DataLoader( - test_dataset, batch_size=1, collate_fn=collate_fn, **loader_validation_params, + val_loader = DataLoader( + val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, **loader_validation_params, ) model = _WaveRNN( @@ -371,13 +371,13 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): train_one_epoch( - model, args.mode, criterion, optimizer, loader_training, devices[0], epoch, + model, args.mode, criterion, optimizer, train_loader, devices[0], epoch, ) if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: - sum_loss = evaluate( - model, args.mode, criterion, loader_test, devices[0], epoch, + sum_loss = validate( + model, args.mode, 
criterion, val_loader, devices[0], epoch, ) is_best = sum_loss < best_loss diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py index 5ddce0b4ec..a9cae5fbdc 100644 --- a/examples/pipeline_wavernn/utils.py +++ b/examples/pipeline_wavernn/utils.py @@ -59,40 +59,3 @@ def count_parameters(model): """ return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def specgram_normalize(specgram, min_level_db): - r"""Normalize the spectrogram with a minimum db value - """ - - specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) - return torch.clamp((min_level_db - specgram) / min_level_db, min=0, max=1) - - -def mulaw_encode(waveform, mu): - r"""Waveform mulaw encoding - """ - - mu = mu - 1 - fx = ( - torch.sign(waveform) - * torch.log(1 + mu * torch.abs(waveform)) - / torch.log(torch.as_tensor(1.0 + mu)) - ) - return torch.floor((fx + 1) / 2 * mu + 0.5).int() - - -def waveform_to_label(waveform, bits): - r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] - """ - - assert abs(waveform).max() <= 1.0 - waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 - return torch.clamp(waveform, 0, 2 ** bits - 1).int() - - -def label_to_waveform(label, bits): - r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] - """ - - return 2 * label / (2 ** bits - 1.0) - 1.0 From 553b170959c04abc74a9c7a1bdcb962327673926 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Tue, 14 Jul 2020 06:47:28 -0700 Subject: [PATCH 15/29] update dataset function --- examples/pipeline_wavernn/datasets.py | 41 ++++++++++++++++----------- examples/pipeline_wavernn/main.py | 10 +++---- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index e21ed86c04..2c6f4b3696 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -6,7 +6,12 @@ from torchaudio.datasets import LJSPEECH from transform import linear_to_mel -from functional import label_to_waveform, mulaw_encode, specgram_normalize, waveform_to_label +from functional import ( + label_to_waveform, + mulaw_encode, + specgram_normalize, + waveform_to_label, +) class MapMemoryCache(torch.utils.data.Dataset): @@ -31,19 +36,20 @@ def __len__(self): class ProcessedLJSPEECH(LJSPEECH): - def __init__(self, files, transforms, args): - self.files = files + def __init__(self, dataset, transforms, args): + self.dataset = dataset self.transforms = transforms self.args = args def __getitem__(self, index): - filename = self.files[index][0] - file = os.path.join(self.args.file_path, filename + '.wav') + filename = self.dataset[index][0] + folder = "LJSpeech-1.1/wavs/" + file = os.path.join(self.args.file_path, folder, filename + ".wav") return self.process_datapoint(file) def __len__(self): - return len(self.files) + return len(self.dataset) def process_datapoint(self, file): args = self.args @@ -65,23 +71,24 @@ def process_datapoint(self, file): return waveform, specgram -def split_data(data, val_ratio): - files = data._walker - random.shuffle(files) - train_files = files[: -int(val_ratio * len(files))] - val_files = files[-int(val_ratio * len(files)):] +def split_data(data, val_ratio, seed): + dataset = data._walker + + random.seed(seed) + random.shuffle(dataset) - return train_files, val_files + train_dataset = dataset[: -int(val_ratio * len(dataset))] + val_dataset = dataset[-int(val_ratio * len(dataset)):] + + return train_dataset, val_dataset def gen_datasets_ljspeech( - args, - transforms, - root="datasets/", 
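A side note on the validation split being reworked in this hunk: the shuffle-and-slice in split_data and the torch.utils.data.random_split call that a later patch in this series moves to both come down to carving a fixed fraction off a shuffled index set. Below is a minimal, hedged sketch of the reproducible variant; the TensorDataset stand-in, the 0.1 ratio, and the seed are assumptions for illustration only.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset purely for illustration; the pipeline uses LJSPEECH here.
data = TensorDataset(torch.arange(100).unsqueeze(1))

val_ratio = 0.1                      # assumed validation fraction
val_length = int(len(data) * val_ratio)
lengths = [len(data) - val_length, val_length]

# Seeding the RNG before random_split makes the partition reproducible,
# mirroring what random.seed(args.seed) does for the shuffle-based split.
torch.manual_seed(0)
train_subset, val_subset = random_split(data, lengths)

assert len(train_subset) == 90 and len(val_subset) == 10
```

Either way, keeping the seed fixed is what makes the train/validation partition stable across restarts.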
+ args, transforms, ): - data = LJSPEECH(root=root, download=False) + data = LJSPEECH(root=args.file_path, download=False) - train_dataset, val_dataset = split_data(data, args.val_ratio) + train_dataset, val_dataset = split_data(data, args.val_ratio, args.seed) train_dataset = ProcessedLJSPEECH(train_dataset, transforms, args) val_dataset = ProcessedLJSPEECH(val_dataset, transforms, args) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 2ada505cc1..43dc295b80 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -154,10 +154,7 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", - default="", - type=str, - help="the path of audio files", + "--file-path", default="", type=str, help="the path of audio files", ) args = parser.parse_args() @@ -301,7 +298,10 @@ def main(args): **loader_training_params, ) val_loader = DataLoader( - val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, **loader_validation_params, + val_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_validation_params, ) model = _WaveRNN( From 73f22d2c9a20c99edbb6d6cad9933a5274e60dad Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 06:09:29 -0700 Subject: [PATCH 16/29] update data split and transform --- examples/pipeline_wavernn/datasets.py | 77 ++++++++----------------- examples/pipeline_wavernn/functional.py | 8 --- examples/pipeline_wavernn/main.py | 25 ++++++-- examples/pipeline_wavernn/transform.py | 39 ++++++++++--- 4 files changed, 75 insertions(+), 74 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 2c6f4b3696..2df47feb33 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -3,15 +3,10 @@ import torch import torchaudio +from torch.utils.data.dataset import random_split from torchaudio.datasets import LJSPEECH -from transform import linear_to_mel -from functional import ( - label_to_waveform, - mulaw_encode, - specgram_normalize, - waveform_to_label, -) +from functional import label_to_waveform, mulaw_encode, waveform_to_label class MapMemoryCache(torch.utils.data.Dataset): @@ -35,63 +30,32 @@ def __len__(self): return len(self.dataset) -class ProcessedLJSPEECH(LJSPEECH): - def __init__(self, dataset, transforms, args): +class Processed(torch.utils.data.Dataset): + def __init__(self, dataset, transforms): self.dataset = dataset self.transforms = transforms - self.args = args - def __getitem__(self, index): - filename = self.dataset[index][0] - folder = "LJSpeech-1.1/wavs/" - file = os.path.join(self.args.file_path, folder, filename + ".wav") - - return self.process_datapoint(file) + def __getitem__(self, key): + item = self.dataset[key][0] + return self.process_datapoint(item) def __len__(self): return len(self.dataset) - def process_datapoint(self, file): - args = self.args - n_fft = 2048 - waveform, sample_rate = torchaudio.load(file) + def process_datapoint(self, waveform): specgram = self.transforms(waveform) - # TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved - specgram = linear_to_mel(specgram, sample_rate, n_fft, args.n_freq, args.f_min) - specgram = specgram_normalize(specgram, args.min_level_db) - waveform = waveform.squeeze(0) - - if args.mode == "waveform": - waveform = ( - mulaw_encode(waveform, 2 ** args.n_bits) - if args.mulaw - else waveform_to_label(waveform, args.n_bits) - ) - - 
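For orientation on the two target encodings that this hunk moves out of __getitem__ (they resurface in the collate function): both map a waveform in [-1, 1] onto integer labels in [0, 2**n_bits - 1]. A small hedged sketch follows; the bit depth matches the --n-bits default, the sample values are made up, and the mu-law path uses torchaudio.transforms.MuLawEncoding, which a later patch in this series adopts.

```python
import torch
from torchaudio.transforms import MuLawDecoding, MuLawEncoding

bits = 9                      # --n-bits default
n_classes = 2 ** bits         # 512 discrete levels

waveform = torch.linspace(-1.0, 1.0, steps=8)   # toy waveform in [-1, 1]

# Mu-law companding: float waveform -> integer class labels, and (lossily) back.
labels = MuLawEncoding(n_classes)(waveform)
recovered = MuLawDecoding(n_classes)(labels)

# The plain linear quantization used when --mulaw is off
# (waveform_to_label / label_to_waveform in this file):
linear_labels = torch.clamp((waveform + 1.0) * (n_classes - 1) / 2, 0, n_classes - 1).int()
back_to_float = 2 * linear_labels / (n_classes - 1.0) - 1.0

print(labels, recovered, linear_labels, back_to_float, sep="\n")
```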
return waveform, specgram - - -def split_data(data, val_ratio, seed): - dataset = data._walker - - random.seed(seed) - random.shuffle(dataset) + return waveform.squeeze(0), specgram - train_dataset = dataset[: -int(val_ratio * len(dataset))] - val_dataset = dataset[-int(val_ratio * len(dataset)):] - return train_dataset, val_dataset - - -def gen_datasets_ljspeech( - args, transforms, -): +def gen_datasets_ljspeech(args, transforms): data = LJSPEECH(root=args.file_path, download=False) - train_dataset, val_dataset = split_data(data, args.val_ratio, args.seed) + val_length = int(len(data) * args.val_ratio) + lengths = [len(data) - val_length, val_length] + train_dataset, val_dataset = random_split(data, lengths) - train_dataset = ProcessedLJSPEECH(train_dataset, transforms, args) - val_dataset = ProcessedLJSPEECH(val_dataset, transforms, args) + train_dataset = Processed(train_dataset, transforms) + val_dataset = Processed(val_dataset, transforms) train_dataset = MapMemoryCache(train_dataset) val_dataset = MapMemoryCache(val_dataset) @@ -133,9 +97,16 @@ def raw_collate(batch): target = waveform_combine[:, 1:] # waveform: [-1, 1], target: [0, 2**bits-1] if mode = 'waveform' - # waveform: [-1, 1], target: [-1, 1] if mode = 'mol' if args.mode == "waveform": - waveform = label_to_waveform(waveform.float(), args.n_bits) + + if args.mulaw: + waveform = mulaw_encode(waveform, 2 ** args.n_bits) + target = mulaw_encode(target, 2 ** args.n_bits) + + waveform = label_to_waveform(waveform, args.n_bits) + + else: + target = waveform_to_label(target, args.n_bits) return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) diff --git a/examples/pipeline_wavernn/functional.py b/examples/pipeline_wavernn/functional.py index 15c3def8f0..cff982327b 100644 --- a/examples/pipeline_wavernn/functional.py +++ b/examples/pipeline_wavernn/functional.py @@ -1,14 +1,6 @@ import torch -def specgram_normalize(specgram, min_level_db): - r"""Normalize the spectrogram with a minimum db value - """ - - specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) - return torch.clamp((min_level_db - specgram) / min_level_db, min=0, max=1) - - def mulaw_encode(waveform, mu): r"""Waveform mulaw encoding """ diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 43dc295b80..fc0d3d33e0 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -17,6 +17,7 @@ from datasets import collate_factory, gen_datasets_ljspeech from losses import MoLLoss +from transform import linear_to_mel, specgram_normalize from utils import MetricLogger, count_parameters, save_checkpoint @@ -61,7 +62,6 @@ def parse_args(): "--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate", ) parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0) - parser.add_argument("--seed", type=int, default=1000, help="random seed") parser.add_argument( "--mulaw", default=True, @@ -134,6 +134,9 @@ def parse_args(): parser.add_argument( "--n-output", default=128, type=int, help="the number of output dimensions", ) + parser.add_argument( + "--n-fft", default=2048, type=int, help="the number of Fourier bins", + ) parser.add_argument( "--mode", default="waveform", @@ -154,7 +157,10 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", default="", type=str, help="the path of audio files", + "--file-path", + default="", + type=str, + help="the path of audio files", ) args = parser.parse_args() @@ -237,6 +243,7 
@@ def validate(model, mode, criterion, data_loader, device, epoch): target = target.to(device) output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) if mode == "waveform": output = output.transpose(1, 2) @@ -270,13 +277,23 @@ def main(args): torch.cuda.empty_cache() melkwargs = { - "n_fft": 2048, + "n_fft": args.n_fft, "power": 1, "hop_length": args.hop_length, "win_length": args.win_length, } - transforms = torch.nn.Sequential(torchaudio.transforms.Spectrogram(**melkwargs)) + transforms = torch.nn.Sequential( + torchaudio.transforms.Spectrogram(**melkwargs), + # TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved + linear_to_mel( + sample_rate=args.sample_rate, + n_fft=args.n_fft, + n_mels=args.n_freq, + fmin=args.f_min, + ), + specgram_normalize(min_level_db=args.min_level_db), + ) train_dataset, val_dataset = gen_datasets_ljspeech(args, transforms) diff --git a/examples/pipeline_wavernn/transform.py b/examples/pipeline_wavernn/transform.py index 952bdc5d82..6335e130c2 100644 --- a/examples/pipeline_wavernn/transform.py +++ b/examples/pipeline_wavernn/transform.py @@ -1,14 +1,35 @@ import librosa import torch +import torch.nn as nn -def linear_to_mel(specgram, sample_rate, n_fft, n_mels, fmin): +class linear_to_mel(nn.Module): + def __init__(self, sample_rate, n_fft, n_mels, fmin): + super().__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.n_mels = n_mels + self.fmin = fmin - specgram = librosa.feature.melspectrogram( - S=specgram.squeeze(0).numpy(), - sr=sample_rate, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - ) - return torch.from_numpy(specgram) + def forward(self, specgram): + specgram = librosa.feature.melspectrogram( + S=specgram.squeeze(0).numpy(), + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + ) + return torch.from_numpy(specgram) + + +class specgram_normalize(nn.Module): + r"""Normalize the spectrogram with a minimum db value + """ + + def __init__(self, min_level_db): + super().__init__() + self.min_level_db = min_level_db + + def forward(self, specgram): + specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) + return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1) From a8aca081c24634963d3e876a4c4425557450759b Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 14:06:40 -0700 Subject: [PATCH 17/29] update format --- examples/pipeline_wavernn/datasets.py | 16 +++++----- examples/pipeline_wavernn/functional.py | 30 ------------------- examples/pipeline_wavernn/losses.py | 2 +- examples/pipeline_wavernn/main.py | 16 ++++------ .../{transform.py => processing.py} | 21 +++++++++++-- examples/pipeline_wavernn/utils.py | 4 +-- 6 files changed, 37 insertions(+), 52 deletions(-) delete mode 100644 examples/pipeline_wavernn/functional.py rename examples/pipeline_wavernn/{transform.py => processing.py} (62%) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 2df47feb33..77fd03fc53 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -5,8 +5,9 @@ import torchaudio from torch.utils.data.dataset import random_split from torchaudio.datasets import LJSPEECH +from torchaudio.transforms import MuLawEncoding -from functional import label_to_waveform, mulaw_encode, waveform_to_label +from functional import label_to_waveform, waveform_to_label class MapMemoryCache(torch.utils.data.Dataset): @@ -36,18 +37,18 @@ def 
__init__(self, dataset, transforms): self.transforms = transforms def __getitem__(self, key): - item = self.dataset[key][0] + item = self.dataset[key] return self.process_datapoint(item) def __len__(self): return len(self.dataset) def process_datapoint(self, waveform): - specgram = self.transforms(waveform) - return waveform.squeeze(0), specgram + specgram = self.transforms(waveform[0]) + return waveform[0].squeeze(0), specgram -def gen_datasets_ljspeech(args, transforms): +def split_process_ljspeech(args, transforms): data = LJSPEECH(root=args.file_path, download=False) val_length = int(len(data) * args.val_ratio) @@ -100,8 +101,9 @@ def raw_collate(batch): if args.mode == "waveform": if args.mulaw: - waveform = mulaw_encode(waveform, 2 ** args.n_bits) - target = mulaw_encode(target, 2 ** args.n_bits) + mulaw_encode = MuLawEncoding(2 ** args.n_bits) + waveform = mulaw_encode(waveform) + target = mulaw_encode(waveform) waveform = label_to_waveform(waveform, args.n_bits) diff --git a/examples/pipeline_wavernn/functional.py b/examples/pipeline_wavernn/functional.py deleted file mode 100644 index cff982327b..0000000000 --- a/examples/pipeline_wavernn/functional.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch - - -def mulaw_encode(waveform, mu): - r"""Waveform mulaw encoding - """ - - mu = mu - 1 - fx = ( - torch.sign(waveform) - * torch.log(1 + mu * torch.abs(waveform)) - / torch.log(torch.as_tensor(1.0 + mu)) - ) - return torch.floor((fx + 1) / 2 * mu + 0.5).int() - - -def waveform_to_label(waveform, bits): - r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] - """ - - assert abs(waveform).max() <= 1.0 - waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 - return torch.clamp(waveform, 0, 2 ** bits - 1).int() - - -def label_to_waveform(label, bits): - r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] - """ - - return 2 * label / (2 ** bits - 1.0) - 1.0 diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py index 055daef53a..2b81d827f4 100644 --- a/examples/pipeline_wavernn/losses.py +++ b/examples/pipeline_wavernn/losses.py @@ -12,7 +12,7 @@ def log_sum_exp(x): return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) -def MoLLoss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): +def mol_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): r""" Discretized mixture of logistic distributions loss Adapted from wavenet vocoder diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index fc0d3d33e0..192ba6006d 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -15,9 +15,9 @@ from torchaudio.datasets.utils import bg_iterator from torchaudio.models._wavernn import _WaveRNN -from datasets import collate_factory, gen_datasets_ljspeech +from datasets import collate_factory, split_process_ljspeech from losses import MoLLoss -from transform import linear_to_mel, specgram_normalize +from transform import LinearToMel, NormalizeDB from utils import MetricLogger, count_parameters, save_checkpoint @@ -158,7 +158,7 @@ def parse_args(): ) parser.add_argument( "--file-path", - default="", + default="/private/home/jimchen90/datasets", type=str, help="the path of audio files", ) @@ -273,9 +273,6 @@ def main(args): logging.info("Start time: {}".format(str(datetime.now()))) - # Empty CUDA cache - torch.cuda.empty_cache() - melkwargs = { "n_fft": args.n_fft, "power": 1, @@ -285,17 +282,16 @@ def main(args): transforms = torch.nn.Sequential( 
torchaudio.transforms.Spectrogram(**melkwargs), - # TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved - linear_to_mel( + LinearToMel( sample_rate=args.sample_rate, n_fft=args.n_fft, n_mels=args.n_freq, fmin=args.f_min, ), - specgram_normalize(min_level_db=args.min_level_db), + NormalizeDB(min_level_db=args.min_level_db), ) - train_dataset, val_dataset = gen_datasets_ljspeech(args, transforms) + train_dataset, val_dataset = split_process_ljspeech(args, transforms) loader_training_params = { "num_workers": args.workers, diff --git a/examples/pipeline_wavernn/transform.py b/examples/pipeline_wavernn/processing.py similarity index 62% rename from examples/pipeline_wavernn/transform.py rename to examples/pipeline_wavernn/processing.py index 6335e130c2..8b28d2990f 100644 --- a/examples/pipeline_wavernn/transform.py +++ b/examples/pipeline_wavernn/processing.py @@ -3,7 +3,8 @@ import torch.nn as nn -class linear_to_mel(nn.Module): +# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved +class LinearToMel(nn.Module): def __init__(self, sample_rate, n_fft, n_mels, fmin): super().__init__() self.sample_rate = sample_rate @@ -22,7 +23,7 @@ def forward(self, specgram): return torch.from_numpy(specgram) -class specgram_normalize(nn.Module): +class NormalizeDB(nn.Module): r"""Normalize the spectrogram with a minimum db value """ @@ -33,3 +34,19 @@ def __init__(self, min_level_db): def forward(self, specgram): specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1) + + +def waveform_to_label(waveform, bits): + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] + """ + + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def label_to_waveform(label, bits): + r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] + """ + + return 2 * label / (2 ** bits - 1.0) - 1.0 diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py index a9cae5fbdc..7f7f7bfe5e 100644 --- a/examples/pipeline_wavernn/utils.py +++ b/examples/pipeline_wavernn/utils.py @@ -16,8 +16,8 @@ def __init__(self, group, print_freq=1): self.data = defaultdict(lambda: deque(maxlen=self.print_freq)) self.data["group"].append(group) - def __call__(self, key, value): - self.data[key].append(value) + def __setitem__(self, key): + self.data[key][-1] def _get_last(self): return {k: v[-1] for k, v in self.data.items()} From a47d00fb9b3428984d265d545ffb9cf726ef8108 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 19:57:10 -0700 Subject: [PATCH 18/29] update logger --- examples/pipeline_wavernn/datasets.py | 4 +-- examples/pipeline_wavernn/main.py | 40 +++++++++++++-------------- examples/pipeline_wavernn/utils.py | 6 ++-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 77fd03fc53..c49faabab2 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -7,7 +7,7 @@ from torchaudio.datasets import LJSPEECH from torchaudio.transforms import MuLawEncoding -from functional import label_to_waveform, waveform_to_label +from processing import label_to_waveform, waveform_to_label class MapMemoryCache(torch.utils.data.Dataset): @@ -103,7 +103,7 @@ def raw_collate(batch): if args.mulaw: mulaw_encode = 
MuLawEncoding(2 ** args.n_bits) waveform = mulaw_encode(waveform) - target = mulaw_encode(waveform) + target = mulaw_encode(target) waveform = label_to_waveform(waveform, args.n_bits) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 192ba6006d..ddf7800964 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -16,8 +16,8 @@ from torchaudio.models._wavernn import _WaveRNN from datasets import collate_factory, split_process_ljspeech -from losses import MoLLoss -from transform import LinearToMel, NormalizeDB +from losses import mol_loss +from processing import LinearToMel, NormalizeDB from utils import MetricLogger, count_parameters, save_checkpoint @@ -158,7 +158,7 @@ def parse_args(): ) parser.add_argument( "--file-path", - default="/private/home/jimchen90/datasets", + default="", type=str, help="the path of audio files", ) @@ -175,7 +175,7 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc start1 = time() metric = MetricLogger("train_iteration") - metric("epoch", epoch) + metric["epoch"] = epoch for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): @@ -199,7 +199,7 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc loss = criterion(output, target) loss_item = loss.item() sums["loss"] += loss_item - metric("loss", loss_item) + metric["loss"] = loss_item optimizer.zero_grad() loss.backward() @@ -208,24 +208,24 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc gradient = torch.nn.utils.clip_grad_norm_( model.parameters(), args.clip_grad ) - sums["gradient"] += gradient - metric("gradient", gradient.item()) + sums["gradient"] += gradient.item() + metric["gradient"] = gradient.item() optimizer.step() - metric("iteration", sums["iteration"]) - metric("time", time() - start2) - metric.print() + metric["iteration"] = sums["iteration"] + metric["time"] = time() - start2 + metric() sums["iteration"] += 1 avg_loss = sums["loss"] / len(data_loader) metric = MetricLogger("train_epoch") - metric("epoch", epoch) - metric("loss", avg_loss) - metric("gradient", sums["gradient"] / len(data_loader)) - metric("time", time() - start1) - metric.print() + metric["epoch"] = epoch + metric["loss"] = avg_loss + metric["gradient"] = sums["gradient"] / len(data_loader) + metric["time"] = time() - start1 + metric() def validate(model, mode, criterion, data_loader, device, epoch): @@ -259,10 +259,10 @@ def validate(model, mode, criterion, data_loader, device, epoch): avg_loss = sums["loss"] / len(data_loader) metric = MetricLogger("validation") - metric("epoch", epoch) - metric("loss", avg_loss) - metric("time", time() - start) - metric.print() + metric["epoch"] = epoch + metric["loss"] = avg_loss + metric["time"] = time() - start + metric() return avg_loss @@ -348,7 +348,7 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else MoLLoss + criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else mol_loss best_loss = 10.0 diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py index 7f7f7bfe5e..e924c9f512 100644 --- a/examples/pipeline_wavernn/utils.py +++ b/examples/pipeline_wavernn/utils.py @@ -16,8 +16,8 @@ def __init__(self, group, print_freq=1): self.data = defaultdict(lambda: deque(maxlen=self.print_freq)) self.data["group"].append(group) - def __setitem__(self, key): - self.data[key][-1] + def 
__setitem__(self, key, value): + self.data[key].append(value) def _get_last(self): return {k: v[-1] for k, v in self.data.items()} @@ -25,7 +25,7 @@ def _get_last(self): def __str__(self): return str(self._get_last()) - def print(self): + def __call__(self): self._iter = (self._iter + 1) % self.print_freq if not self._iter: print(self, flush=True) From c4ff493b201bffb70c6a096a3ff1e284990cf9f6 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 16 Jul 2020 06:26:03 -0700 Subject: [PATCH 19/29] update import format --- examples/pipeline_wavernn/main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index ddf7800964..d2c30af39b 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -157,10 +157,7 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", - default="", - type=str, - help="the path of audio files", + "--file-path", default="", type=str, help="the path of audio files", ) args = parser.parse_args() From 3434e16bd6b568aedc5771d3a9510c00886ece3a Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 16 Jul 2020 07:58:12 -0700 Subject: [PATCH 20/29] move condition in statement --- examples/pipeline_wavernn/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index d2c30af39b..b5bbb4baf3 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -349,9 +349,7 @@ def main(args): best_loss = 10.0 - load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) - - if load_checkpoint: + if args.checkpoint and os.path.isfile(args.checkpoint): logging.info(f"Checkpoint: loading '{args.checkpoint}'") checkpoint = torch.load(args.checkpoint) From 1295d8ad3202677a4dab18f109c9388d6eaa47fe Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 16 Jul 2020 13:57:32 -0700 Subject: [PATCH 21/29] add loss class and change function name --- examples/pipeline_wavernn/datasets.py | 2 +- examples/pipeline_wavernn/losses.py | 120 +++++++++++++----------- examples/pipeline_wavernn/main.py | 4 +- examples/pipeline_wavernn/processing.py | 4 +- 4 files changed, 69 insertions(+), 61 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index c49faabab2..e4631e1cbf 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -7,7 +7,7 @@ from torchaudio.datasets import LJSPEECH from torchaudio.transforms import MuLawEncoding -from processing import label_to_waveform, waveform_to_label +from processing import encode_waveform_into_bits, encode_bits_into_waveform class MapMemoryCache(torch.utils.data.Dataset): diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py index 2b81d827f4..6e9f20c2ed 100644 --- a/examples/pipeline_wavernn/losses.py +++ b/examples/pipeline_wavernn/losses.py @@ -1,18 +1,9 @@ import torch +import torch.nn as nn from torch.nn import functional as F -def log_sum_exp(x): - r""" Numerically stable log_sum_exp implementation that prevents overflow - """ - - axis = len(x.size()) - 1 - m, _ = torch.max(x, dim=axis) - m2, _ = torch.max(x, dim=axis, keepdim=True) - return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) - - -def mol_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): +class MoLLoss(nn.Module): r""" Discretized mixture of logistic distributions loss 
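An aside on the mixture-of-logistics loss defined in this hunk: it reduces per-component log-probabilities with the numerically stable log_sum_exp helper kept in this file. The sketch below shows why the max-subtraction matters; the helper is copied from the hunk, and the probe values are chosen only to make the naive form overflow.

```python
import torch

def log_sum_exp(x):
    # Subtract the per-row max before exponentiating so exp() cannot overflow.
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))

x = torch.tensor([[1000.0, 999.0, 998.0]])
print(torch.log(torch.exp(x).sum(dim=-1)))  # tensor([inf]): naive form overflows
print(log_sum_exp(x))                       # tensor([1000.4076]): stable form
```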
Adapted from wavenet vocoder @@ -30,61 +21,78 @@ def mol_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): Tensor: loss """ - if log_scale_min is None: - log_scale_min = torch.log(torch.as_tensor(1e-14)).item() + def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): + + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + + def forward(self, y_hat, y): + + if self.log_scale_min is None: + self.log_scale_min = torch.log(torch.as_tensor(1e-14)).item() + + assert y_hat.dim() == 3 + assert y_hat.size(-1) % 3 == 0 + + nr_mix = y_hat.size(-1) // 3 - assert y_hat.dim() == 3 - assert y_hat.size(-1) % 3 == 0 + # unpack parameters (n_batch, n_time, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix: 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min) - nr_mix = y_hat.size(-1) // 3 + # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) + y = y.expand_as(means) - # unpack parameters (n_batch, n_time, num_mixtures) x 3 - logit_probs = y_hat[:, :, :nr_mix] - means = y_hat[:, :, nr_mix: 2 * nr_mix] - log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min) + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1)) + cdf_min = torch.sigmoid(min_in) - # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) - y = y.expand_as(means) + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) - centered_y = y - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) - cdf_plus = torch.sigmoid(plus_in) - min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) - cdf_min = torch.sigmoid(min_in) + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) - # log probability for edge case of 0 (before scaling) - # equivalent: torch.log(F.sigmoid(plus_in)) - log_cdf_plus = plus_in - F.softplus(plus_in) + # probability for all other cases + cdf_delta = cdf_plus - cdf_min - # log probability for edge case of 255 (before scaling) - # equivalent: (1 - F.sigmoid(min_in)).log() - log_one_minus_cdf_min = -F.softplus(min_in) + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) - # probability for all other cases - cdf_delta = cdf_plus - cdf_min + inner_inner_cond = (cdf_delta > 1e-5).float() - mid_in = inv_stdv * centered_y - # log probability in the center of the bin, to be used in extreme cases - log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - torch.log(torch.as_tensor((self.num_classes - 1) / 2)).item() + ) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out - inner_inner_cond = (cdf_delta > 1e-5).float() + log_probs = log_probs + F.log_softmax(logit_probs, -1) - inner_inner_out = inner_inner_cond * 
torch.log( - torch.clamp(cdf_delta, min=1e-12) - ) + (1.0 - inner_inner_cond) * ( - log_pdf_mid - torch.log(torch.as_tensor((num_classes - 1) / 2)).item() - ) - inner_cond = (y > 0.999).float() - inner_out = ( - inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out - ) - cond = (y < -0.999).float() - log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + if self.reduce: + return -torch.mean(self.log_sum_exp(log_probs)) + else: + return -self.log_sum_exp(log_probs).unsqueeze(-1) - log_probs = log_probs + F.log_softmax(logit_probs, -1) + def log_sum_exp(self, x): + r""" Numerically stable log_sum_exp implementation that prevents overflow + """ - if reduce: - return -torch.mean(log_sum_exp(log_probs)) - else: - return -log_sum_exp(log_probs).unsqueeze(-1) + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index b5bbb4baf3..232debbb9b 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -16,7 +16,7 @@ from torchaudio.models._wavernn import _WaveRNN from datasets import collate_factory, split_process_ljspeech -from losses import mol_loss +from losses import MoLLoss from processing import LinearToMel, NormalizeDB from utils import MetricLogger, count_parameters, save_checkpoint @@ -345,7 +345,7 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else mol_loss + criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else MoLLoss() best_loss = 10.0 diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py index 8b28d2990f..51a1e0e7dc 100644 --- a/examples/pipeline_wavernn/processing.py +++ b/examples/pipeline_wavernn/processing.py @@ -36,7 +36,7 @@ def forward(self, specgram): return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1) -def waveform_to_label(waveform, bits): +def encode_waveform_into_bits(waveform, bits): r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] """ @@ -45,7 +45,7 @@ def waveform_to_label(waveform, bits): return torch.clamp(waveform, 0, 2 ** bits - 1).int() -def label_to_waveform(label, bits): +def encode_bits_into_waveform(label, bits): r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] """ From ee1d702e463ed48a0ba4fc1232d6130c3df389c9 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 16 Jul 2020 20:26:27 -0700 Subject: [PATCH 22/29] update format --- examples/pipeline_wavernn/README.md | 4 ++-- examples/pipeline_wavernn/datasets.py | 4 ++-- examples/pipeline_wavernn/losses.py | 31 +++++++++++++++------------ examples/pipeline_wavernn/main.py | 24 ++++++++++----------- 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md index 76b7bc2e15..084d1fb372 100644 --- a/examples/pipeline_wavernn/README.md +++ b/examples/pipeline_wavernn/README.md @@ -1,4 +1,4 @@ -This is an example pipeline for WaveRNN vocoder using the WaveRNN model trained on LJSpeech. WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was +This is an example vocoder pipeline using the WaveRNN model trained with LJSpeech. 
WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. ### Usage @@ -15,7 +15,7 @@ python main.py \ ### Output -The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the standard output. +The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the output if redirected to a file. ```python def read_json(filename): """ diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index e4631e1cbf..23ddebcfdd 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -105,10 +105,10 @@ def raw_collate(batch): waveform = mulaw_encode(waveform) target = mulaw_encode(target) - waveform = label_to_waveform(waveform, args.n_bits) + waveform = encode_bits_into_waveform(waveform, args.n_bits) else: - target = waveform_to_label(target, args.n_bits) + target = encode_waveform_into_bits(target, args.n_bits) return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py index 6e9f20c2ed..e9f39cd30d 100644 --- a/examples/pipeline_wavernn/losses.py +++ b/examples/pipeline_wavernn/losses.py @@ -3,7 +3,7 @@ from torch.nn import functional as F -class MoLLoss(nn.Module): +class MoLLoss(torch.nn.Module): r""" Discretized mixture of logistic distributions loss Adapted from wavenet vocoder @@ -12,17 +12,17 @@ class MoLLoss(nn.Module): Args: y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) - y (Tensor): Target (n_batch x n_time x 1). + y (Tensor): Target (n_batch x n_time x 1) num_classes (int): Number of classes log_scale_min (float): Log scale minimum value - reduce (bool): If True, the losses are averaged or summed for each minibatch. 
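To make the argument shapes in the MoL loss docstring concrete: the network output packs mixture logits, means, and log-scales along the channel axis, so the last dimension must be a multiple of three. Below is a hedged sketch with made-up batch, time, and mixture sizes, using the same slicing as the forward pass.

```python
import torch

n_batch, n_time, n_mix = 2, 100, 10

# Predicted output: 3 parameters per mixture component on the channel axis.
y_hat = torch.randn(n_batch, n_time, 3 * n_mix)
# Target waveform in [-1, 1], kept as (n_batch, n_time, 1) per the docstring.
y = torch.rand(n_batch, n_time, 1) * 2 - 1

nr_mix = y_hat.size(-1) // 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = y_hat[:, :, 2 * nr_mix: 3 * nr_mix]

# The target is broadcast against every component before the per-component
# log-likelihoods are mixed via log_softmax(logit_probs).
assert y.expand_as(means).shape == (n_batch, n_time, nr_mix)
```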
+ reduce (bool): If True, the losses are averaged or summed for each minibatch Returns Tensor: loss """ def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): - + super(MoLLoss, self).__init__() self.num_classes = num_classes self.log_scale_min = log_scale_min self.reduce = reduce @@ -40,7 +40,9 @@ def forward(self, y_hat, y): # unpack parameters (n_batch, n_time, num_mixtures) x 3 logit_probs = y_hat[:, :, :nr_mix] means = y_hat[:, :, nr_mix: 2 * nr_mix] - log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min) + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min + ) # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) y = y.expand_as(means) @@ -84,15 +86,16 @@ def forward(self, y_hat, y): log_probs = log_probs + F.log_softmax(logit_probs, -1) if self.reduce: - return -torch.mean(self.log_sum_exp(log_probs)) + return -torch.mean(log_sum_exp(log_probs)) else: - return -self.log_sum_exp(log_probs).unsqueeze(-1) + return -log_sum_exp(log_probs).unsqueeze(-1) - def log_sum_exp(self, x): - r""" Numerically stable log_sum_exp implementation that prevents overflow - """ - axis = len(x.size()) - 1 - m, _ = torch.max(x, dim=axis) - m2, _ = torch.max(x, dim=axis, keepdim=True) - return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) +def log_sum_exp(x): + r""" Numerically stable log_sum_exp implementation that prevents overflow + """ + + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 232debbb9b..8208d8dda2 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -68,9 +68,6 @@ def parse_args(): action="store_true", help="if used, waveform is mulaw encoded", ) - parser.add_argument( - "--jit", default=False, action="store_true", help="if used, model is jitted" - ) parser.add_argument( "--upsample-scales", default=[5, 5, 11], @@ -126,13 +123,16 @@ def parse_args(): help="the number of kernel size in the first Conv1d layer", ) parser.add_argument( - "--n-freq", default=80, type=int, help="the number of bins in a spectrogram", + "--n-freq", default=80, type=int, help="the number of spectrogram bins to use", ) parser.add_argument( "--n-hidden", default=128, type=int, help="the number of hidden dimensions", ) parser.add_argument( - "--n-output", default=128, type=int, help="the number of output dimensions", + "--n-output", + default=128, + type=int, + help="the output dimension of upsample network in WaveRNN model", ) parser.add_argument( "--n-fft", default=2048, type=int, help="the number of Fourier bins", @@ -148,7 +148,7 @@ def parse_args(): "--seq-len-factor", default=5, type=int, - help="seq_length = hop_length * seq_len_factor", + help="the length factor of input waveform, the length of input waveform = hop_length * seq_len_factor", ) parser.add_argument( "--val-ratio", @@ -157,7 +157,10 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", default="", type=str, help="the path of audio files", + "--file-path", + default="", + type=str, + help="the path of audio files", ) args = parser.parse_args() @@ -219,8 +222,8 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc metric = MetricLogger("train_epoch") metric["epoch"] = epoch - metric["loss"] = avg_loss - metric["gradient"] = 
sums["gradient"] / len(data_loader) + metric["loss"] = sums["loss"] / len(data_loader) + metric["gradient"] = avg_loss metric["time"] = time() - start1 metric() @@ -329,9 +332,6 @@ def main(args): mode=args.mode, ) - if args.jit: - model = torch.jit.script(model) - model = torch.nn.DataParallel(model) model = model.to(devices[0], non_blocking=True) From f72545794f4187d34ee3dbcf4a3d00d39b731623 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 06:16:19 -0700 Subject: [PATCH 23/29] update varible name --- examples/pipeline_wavernn/main.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 8208d8dda2..2bc74dc0f2 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -90,22 +90,16 @@ def parse_args(): help="the number of samples between the starts of consecutive frames", ) parser.add_argument( - "--win-length", - default=1100, - type=int, - help="the number of samples between the starts of consecutive frames", + "--win-length", default=1100, type=int, help="the length of the STFT window", ) parser.add_argument( - "--f-min", - default=40.0, - type=float, - help="the number of samples between the starts of consecutive frames", + "--f-min", default=40.0, type=float, help="the minimum frequency", ) parser.add_argument( "--min-level-db", default=-100, type=float, - help="the min db value for spectrogam normalization", + help="the minimum db value for spectrogam normalization", ) parser.add_argument( "--n-res-block", default=10, type=int, help="the number of ResBlock in stack", @@ -157,10 +151,7 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", - default="", - type=str, - help="the path of audio files", + "--file-path", default="", type=str, help="the path of audio files", ) args = parser.parse_args() From b6198d83441a2c045dff929ecd0f30091dece0c3 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 11:26:42 -0700 Subject: [PATCH 24/29] update variable name in wavernn --- examples/pipeline_wavernn/datasets.py | 10 ++++---- examples/pipeline_wavernn/losses.py | 6 +++-- examples/pipeline_wavernn/main.py | 36 +++++++++++++-------------- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py index 23ddebcfdd..d92165dcdc 100644 --- a/examples/pipeline_wavernn/datasets.py +++ b/examples/pipeline_wavernn/datasets.py @@ -43,9 +43,9 @@ def __getitem__(self, key): def __len__(self): return len(self.dataset) - def process_datapoint(self, waveform): - specgram = self.transforms(waveform[0]) - return waveform[0].squeeze(0), specgram + def process_datapoint(self, item): + specgram = self.transforms(item[0]) + return item[0].squeeze(0), specgram def split_process_ljspeech(args, transforms): @@ -97,8 +97,8 @@ def raw_collate(batch): waveform = waveform_combine[:, :wave_length] target = waveform_combine[:, 1:] - # waveform: [-1, 1], target: [0, 2**bits-1] if mode = 'waveform' - if args.mode == "waveform": + # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'waveform' + if args.loss == "waveform": if args.mulaw: mulaw_encode = MuLawEncoding(2 ** args.n_bits) diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py index e9f39cd30d..4b189b2edc 100644 --- a/examples/pipeline_wavernn/losses.py +++ b/examples/pipeline_wavernn/losses.py @@ -2,6 +2,8 @@ import torch.nn as nn from 
torch.nn import functional as F +import math + class MoLLoss(torch.nn.Module): r""" Discretized mixture of logistic distributions loss @@ -30,7 +32,7 @@ def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): def forward(self, y_hat, y): if self.log_scale_min is None: - self.log_scale_min = torch.log(torch.as_tensor(1e-14)).item() + self.log_scale_min = math.log(1e-14) assert y_hat.dim() == 3 assert y_hat.size(-1) % 3 == 0 @@ -74,7 +76,7 @@ def forward(self, y_hat, y): inner_inner_out = inner_inner_cond * torch.log( torch.clamp(cdf_delta, min=1e-12) ) + (1.0 - inner_inner_cond) * ( - log_pdf_mid - torch.log(torch.as_tensor((self.num_classes - 1) / 2)).item() + log_pdf_mid - math.log((self.num_classes - 1) / 2) ) inner_cond = (y > 0.999).float() inner_out = ( diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 2bc74dc0f2..a1f9cdfe80 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -120,23 +120,23 @@ def parse_args(): "--n-freq", default=80, type=int, help="the number of spectrogram bins to use", ) parser.add_argument( - "--n-hidden", default=128, type=int, help="the number of hidden dimensions", + "--n-hidden-resblock", default=128, type=int, help="the number of hidden dimensions of resblock", ) parser.add_argument( - "--n-output", + "--n-output-melresnet", default=128, type=int, - help="the output dimension of upsample network in WaveRNN model", + help="the output dimension of melresnet block in WaveRNN model", ) parser.add_argument( "--n-fft", default=2048, type=int, help="the number of Fourier bins", ) parser.add_argument( - "--mode", + "--loss", default="waveform", choices=["waveform", "mol"], type=str, - help="the mode of waveform", + help="the type of loss", ) parser.add_argument( "--seq-len-factor", @@ -151,14 +151,14 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", default="", type=str, help="the path of audio files", + "--file-path", default="/private/home/jimchen90/datasets", type=str, help="the path of audio files", ) args = parser.parse_args() return args -def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoch): +def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoch): model.train() @@ -179,12 +179,12 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc output = model(waveform, specgram) output, target = output.squeeze(1), target.squeeze(1) - if mode == "waveform": + if loss == "waveform": output = output.transpose(1, 2) target = target.long() else: - # use mol mode + # use mol loss target = target.unsqueeze(-1) loss = criterion(output, target) @@ -219,7 +219,7 @@ def train_one_epoch(model, mode, criterion, optimizer, data_loader, device, epoc metric() -def validate(model, mode, criterion, data_loader, device, epoch): +def validate(model, loss, criterion, data_loader, device, epoch): with torch.no_grad(): @@ -236,12 +236,12 @@ def validate(model, mode, criterion, data_loader, device, epoch): output = model(waveform, specgram) output, target = output.squeeze(1), target.squeeze(1) - if mode == "waveform": + if loss == "waveform": output = output.transpose(1, 2) target = target.long() else: - # use mol mode + # use mol loss target = target.unsqueeze(-1) loss = criterion(output, target) @@ -318,9 +318,9 @@ def main(args): n_fc=args.n_fc, kernel_size=args.kernel_size, n_freq=args.n_freq, - n_hidden=args.n_hidden, - n_output=args.n_output, - 
mode=args.mode, + n_hidden_resblock=args.n_hidden_resblock, + n_output_melresnet=args.n_output_melresnet, + loss=args.loss, ) model = torch.nn.DataParallel(model) @@ -336,7 +336,7 @@ def main(args): optimizer = Adam(model.parameters(), **optimizer_params) - criterion = nn.CrossEntropyLoss() if args.mode == "waveform" else MoLLoss() + criterion = nn.CrossEntropyLoss() if args.loss == "waveform" else MoLLoss() best_loss = 10.0 @@ -370,13 +370,13 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): train_one_epoch( - model, args.mode, criterion, optimizer, train_loader, devices[0], epoch, + model, args.loss, criterion, optimizer, train_loader, devices[0], epoch, ) if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: sum_loss = validate( - model, args.mode, criterion, val_loader, devices[0], epoch, + model, args.loss, criterion, val_loader, devices[0], epoch, ) is_best = sum_loss < best_loss From a43b8a2ae79ad08aed8efd522ef6ef3859022f8d Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 13:52:30 -0700 Subject: [PATCH 25/29] update mode --- examples/pipeline_wavernn/main.py | 41 +++++++++++++++++-------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index a1f9cdfe80..4c96e1c2dc 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -120,7 +120,10 @@ def parse_args(): "--n-freq", default=80, type=int, help="the number of spectrogram bins to use", ) parser.add_argument( - "--n-hidden-resblock", default=128, type=int, help="the number of hidden dimensions of resblock", + "--n-hidden-resblock", + default=128, + type=int, + help="the number of hidden dimensions of resblock", ) parser.add_argument( "--n-output-melresnet", @@ -132,17 +135,17 @@ def parse_args(): "--n-fft", default=2048, type=int, help="the number of Fourier bins", ) parser.add_argument( - "--loss", - default="waveform", - choices=["waveform", "mol"], + "--loss-fn", + default="crossentropy", + choices=["crossentropy", "mol"], type=str, - help="the type of loss", + help="the type of loss function", ) parser.add_argument( "--seq-len-factor", default=5, type=int, - help="the length factor of input waveform, the length of input waveform = hop_length * seq_len_factor", + help="the length of each waveform to process per batch = hop_length * seq_len_factor", ) parser.add_argument( "--val-ratio", @@ -151,14 +154,14 @@ def parse_args(): help="the ratio of waveforms for validation", ) parser.add_argument( - "--file-path", default="/private/home/jimchen90/datasets", type=str, help="the path of audio files", + "--file-path", default="", type=str, help="the path of audio files", ) args = parser.parse_args() return args -def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoch): +def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, epoch): model.train() @@ -179,7 +182,7 @@ def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoc output = model(waveform, specgram) output, target = output.squeeze(1), target.squeeze(1) - if loss == "waveform": + if loss_fn == "crossentropy": output = output.transpose(1, 2) target = target.long() @@ -219,7 +222,7 @@ def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoc metric() -def validate(model, loss, criterion, data_loader, device, epoch): +def validate(model, loss_fn, criterion, data_loader, device, epoch): with torch.no_grad(): @@ -236,7 +239,7 @@ 
From 6f8660a3679065f3ead0e04f142e08f51dea904e Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Fri, 17 Jul 2020 15:08:13 -0700
Subject: [PATCH 26/29] update loss class

---
 examples/pipeline_wavernn/datasets.py   | 10 +++----
 examples/pipeline_wavernn/losses.py     | 20 +++++++++++--
 examples/pipeline_wavernn/main.py       | 40 +++++++------------------
 examples/pipeline_wavernn/processing.py |  4 +--
 4 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py
index d92165dcdc..8d3068a229 100644
--- a/examples/pipeline_wavernn/datasets.py
+++ b/examples/pipeline_wavernn/datasets.py
@@ -7,7 +7,7 @@
 from torchaudio.datasets import LJSPEECH
 from torchaudio.transforms import MuLawEncoding

-from processing import encode_waveform_into_bits, encode_bits_into_waveform
+from processing import bits_to_normalized_waveform, normalized_waveform_to_bits


 class MapMemoryCache(torch.utils.data.Dataset):
@@ -97,18 +97,18 @@ def raw_collate(batch):
         waveform = waveform_combine[:, :wave_length]
         target = waveform_combine[:, 1:]

-        # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'waveform'
-        if args.loss == "waveform":
+        # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
+        if args.loss == "crossentropy":

             if args.mulaw:
                 mulaw_encode = MuLawEncoding(2 ** args.n_bits)
                 waveform = mulaw_encode(waveform)
                 target = mulaw_encode(target)

-                waveform = encode_bits_into_waveform(waveform, args.n_bits)
+                waveform = bits_to_normalized_waveform(waveform, args.n_bits)

             else:
-                target = encode_waveform_into_bits(target, args.n_bits)
+                target = normalized_waveform_to_bits(target, args.n_bits)

         return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)
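For readers following the collate changes above, here is a small self-contained sketch of the cross-entropy data path. The local helper mirrors `bits_to_normalized_waveform` from processing.py, and the batch shape and 8-bit setting are assumptions for illustration.

```python
import torch
from torchaudio.transforms import MuLawEncoding

def bits_to_normalized_waveform(label, bits):
    # mirrors processing.py: map integer labels [0, 2**bits - 1] back to [-1, 1]
    return 2 * label.float() / (2 ** bits - 1.0) - 1.0

n_bits = 8                              # assumed value of args.n_bits
waveform = torch.rand(2, 100) * 2 - 1   # fake batch of normalized audio in [-1, 1]
target = waveform.clone()

mulaw_encode = MuLawEncoding(2 ** n_bits)
waveform = mulaw_encode(waveform)       # integer labels in [0, 2**n_bits - 1]
target = mulaw_encode(target)

# the network input goes back to [-1, 1]; the target stays as class indices
waveform = bits_to_normalized_waveform(waveform, n_bits)
```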
diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py
index 4b189b2edc..a008458d75 100644
--- a/examples/pipeline_wavernn/losses.py
+++ b/examples/pipeline_wavernn/losses.py
@@ -1,8 +1,23 @@
+import math
+
 import torch
-import torch.nn as nn
+from torch import nn as nn
 from torch.nn import functional as F
-import math
+
+
+class LongCrossEntropyLoss(torch.nn.Module):
+    r""" CrossEntropy loss
+    """
+
+    def __init__(self):
+        super(LongCrossEntropyLoss, self).__init__()
+
+    def forward(self, output, target):
+        output = output.transpose(1, 2)
+        target = target.long()
+
+        criterion = nn.CrossEntropyLoss()
+        return criterion(output, target)


 class MoLLoss(torch.nn.Module):
@@ -30,6 +45,7 @@ def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
         self.reduce = reduce

     def forward(self, y_hat, y):
+        y = y.unsqueeze(-1)

         if self.log_scale_min is None:
             self.log_scale_min = math.log(1e-14)

diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py
index 4c96e1c2dc..f9aae65251 100644
--- a/examples/pipeline_wavernn/main.py
+++ b/examples/pipeline_wavernn/main.py
@@ -16,7 +16,7 @@ from torchaudio.models._wavernn import _WaveRNN

 from datasets import collate_factory, split_process_ljspeech
-from losses import MoLLoss
+from losses import LongCrossEntropyLoss, MoLLoss
 from processing import LinearToMel, NormalizeDB
 from utils import MetricLogger, count_parameters, save_checkpoint
@@ -120,13 +120,13 @@ def parse_args():
         "--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
     )
     parser.add_argument(
-        "--n-hidden-resblock",
+        "--n-hidden",
         default=128,
         type=int,
         help="the number of hidden dimensions of resblock",
     )
     parser.add_argument(
-        "--n-output-melresnet",
+        "--n-output",
         default=128,
         type=int,
         help="the output dimension of melresnet block in WaveRNN model",
@@ -135,11 +135,11 @@
         "--n-fft", default=2048, type=int, help="the number of Fourier bins",
     )
     parser.add_argument(
-        "--loss-fn",
+        "--loss",
         default="crossentropy",
         choices=["crossentropy", "mol"],
         type=str,
-        help="the type of loss function",
+        help="the type of loss",
     )
     parser.add_argument(
         "--seq-len-factor",
@@ -161,7 +161,7 @@ def parse_args():
     return args


-def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, epoch):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):

     model.train()

@@ -182,14 +182,6 @@ def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, e
         output = model(waveform, specgram)
         output, target = output.squeeze(1), target.squeeze(1)

-        if loss_fn == "crossentropy":
-            output = output.transpose(1, 2)
-            target = target.long()
-
-        else:
-            # use mol loss
-            target = target.unsqueeze(-1)
-
         loss = criterion(output, target)
         loss_item = loss.item()
         sums["loss"] += loss_item
@@ -222,7 +214,7 @@ def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, e
     metric()


-def validate(model, loss_fn, criterion, data_loader, device, epoch):
+def validate(model, criterion, data_loader, device, epoch):

     with torch.no_grad():

@@ -239,14 +231,6 @@ def validate(model, loss_fn, criterion, data_loader, device, epoch):
         output = model(waveform, specgram)
         output, target = output.squeeze(1), target.squeeze(1)

-        if loss_fn == "crossentropy":
-            output = output.transpose(1, 2)
-            target = target.long()
-
-        else:
-            # use mol loss
-            target = target.unsqueeze(-1)
-
         loss = criterion(output, target)
         sums["loss"] += loss.item()
@@ -311,7 +295,7 @@ def main(args):
         **loader_validation_params,
     )

-    n_classes = 2 ** args.n_bits if args.loss_fn == "crossentropy" else 30
+    n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30

     model = _WaveRNN(
         upsample_scales=args.upsample_scales,
@@ -339,7 +323,7 @@ def main(args):

     optimizer = Adam(model.parameters(), **optimizer_params)

-    criterion = nn.CrossEntropyLoss() if args.loss_fn == "crossentropy" else MoLLoss()
+    criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

     best_loss = 10.0

@@ -373,14 +357,12 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):

         train_one_epoch(
-            model, args.loss_fn, criterion, optimizer, train_loader, devices[0], epoch,
+            model, criterion, optimizer, train_loader, devices[0], epoch,
         )

         if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

-            sum_loss = validate(
-                model, args.loss_fn, criterion, val_loader, devices[0], epoch,
-            )
+            sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

             is_best = sum_loss < best_loss
             best_loss = min(sum_loss, best_loss)

diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py
index 51a1e0e7dc..2f4dc94d15 100644
--- a/examples/pipeline_wavernn/processing.py
+++ b/examples/pipeline_wavernn/processing.py
@@ -36,7 +36,7 @@ def forward(self, specgram):
         return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1)


-def encode_waveform_into_bits(waveform, bits):
+def normalized_waveform_to_bits(waveform, bits):
     r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]
     """

@@ -45,7 +45,7 @@ def encode_waveform_into_bits(waveform, bits):
     return torch.clamp(waveform, 0, 2 ** bits - 1).int()


-def encode_bits_into_waveform(label, bits):
+def bits_to_normalized_waveform(label, bits):
     r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]
     """
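With this patch the two criteria hide the target reshaping that main.py used to do inline. Below is a hedged sketch of their expected tensor contracts; the shapes are illustrative assumptions, and the import assumes the example directory is on the Python path.

```python
import torch
from losses import LongCrossEntropyLoss, MoLLoss  # assumed to be importable from the example dir

batch, time = 2, 100

# cross-entropy path: n_classes = 2 ** n_bits (here an assumed 8-bit setting)
ce_output = torch.randn(batch, time, 2 ** 8)        # (batch, time, n_classes)
ce_target = torch.randint(0, 2 ** 8, (batch, time))  # integer class indices
ce_loss = LongCrossEntropyLoss()(ce_output, ce_target)

# mixture-of-logistics path: 30 output channels, i.e. 10 mixtures x 3 parameters each
mol_output = torch.randn(batch, time, 30)
mol_target = torch.rand(batch, time) * 2 - 1          # float waveform targets in [-1, 1]
mol_loss = MoLLoss()(mol_output, mol_target)
```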
From 0df67b7238972cf8e824187c8f376ca251dd5849 Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Fri, 17 Jul 2020 16:34:00 -0700
Subject: [PATCH 27/29] add underscore in mol loss

---
 examples/pipeline_wavernn/losses.py | 2 +-
 examples/pipeline_wavernn/main.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py
index a008458d75..227eca573a 100644
--- a/examples/pipeline_wavernn/losses.py
+++ b/examples/pipeline_wavernn/losses.py
@@ -109,7 +109,7 @@ def forward(self, y_hat, y):
         return -log_sum_exp(log_probs).unsqueeze(-1)


-def log_sum_exp(x):
+def _log_sum_exp(x):
     r""" Numerically stable log_sum_exp implementation that prevents overflow
     """

diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py
index f9aae65251..c3013f2097 100644
--- a/examples/pipeline_wavernn/main.py
+++ b/examples/pipeline_wavernn/main.py
@@ -129,7 +129,7 @@ def parse_args():
         "--n-output",
         default=128,
         type=int,
-        help="the output dimension of melresnet block in WaveRNN model",
+        help="the output dimension of melresnet",
     )
     parser.add_argument(
         "--n-fft", default=2048, type=int, help="the number of Fourier bins",
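The rename above only marks `log_sum_exp` as private to the module. For reference, here is a minimal standalone sketch of the numerically stable log-sum-exp trick it implements, written against plain PyTorch rather than copied from losses.py.

```python
import torch

def log_sum_exp_stable(x, dim=-1):
    # subtract the per-row maximum before exponentiating so exp() cannot overflow;
    # adding it back afterwards leaves the result mathematically unchanged
    m, _ = torch.max(x, dim=dim, keepdim=True)
    return (m + torch.log(torch.sum(torch.exp(x - m), dim=dim, keepdim=True))).squeeze(dim)

x = torch.tensor([[1000.0, 1000.5, 999.0]])
print(log_sum_exp_stable(x))        # finite, roughly 1001.1
print(torch.logsumexp(x, dim=-1))   # the PyTorch built-in gives the same value
```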
From 1c426c055d1ebf3b17e4800cf851b252f6368b58 Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Sat, 18 Jul 2020 12:54:55 -0700
Subject: [PATCH 28/29] add jit and underscore

---
 examples/pipeline_wavernn/losses.py     |  8 ++++----
 examples/pipeline_wavernn/main.py       | 11 +++++++----
 examples/pipeline_wavernn/processing.py | 10 ++++++++--
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py
index 227eca573a..a4494b05fb 100644
--- a/examples/pipeline_wavernn/losses.py
+++ b/examples/pipeline_wavernn/losses.py
@@ -5,7 +5,7 @@
 from torch.nn import functional as F


-class LongCrossEntropyLoss(torch.nn.Module):
+class LongCrossEntropyLoss(nn.Module):
     r""" CrossEntropy loss
     """
@@ -20,7 +20,7 @@ def forward(self, output, target):
         return criterion(output, target)


-class MoLLoss(torch.nn.Module):
+class MoLLoss(nn.Module):
     r""" Discretized mixture of logistic distributions loss

     Adapted from wavenet vocoder
@@ -104,9 +104,9 @@ def forward(self, y_hat, y):
         log_probs = log_probs + F.log_softmax(logit_probs, -1)

         if self.reduce:
-            return -torch.mean(log_sum_exp(log_probs))
+            return -torch.mean(_log_sum_exp(log_probs))
         else:
-            return -log_sum_exp(log_probs).unsqueeze(-1)
+            return -_log_sum_exp(log_probs).unsqueeze(-1)


 def _log_sum_exp(x):

diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py
index c3013f2097..3e8b64c533 100644
--- a/examples/pipeline_wavernn/main.py
+++ b/examples/pipeline_wavernn/main.py
@@ -68,6 +68,9 @@ def parse_args():
         action="store_true",
         help="if used, waveform is mulaw encoded",
     )
+    parser.add_argument(
+        "--jit", default=False, action="store_true", help="if used, model is jitted"
+    )
     parser.add_argument(
         "--upsample-scales",
         default=[5, 5, 11],
@@ -126,10 +129,7 @@ def parse_args():
         help="the number of hidden dimensions of resblock",
     )
     parser.add_argument(
-        "--n-output",
-        default=128,
-        type=int,
-        help="the output dimension of melresnet",
+        "--n-output", default=128, type=int, help="the output dimension of melresnet",
     )
     parser.add_argument(
         "--n-fft", default=2048, type=int, help="the number of Fourier bins",
@@ -310,6 +310,9 @@ def main(args):
         n_output=args.n_output,
     )

+    if args.jit:
+        model = torch.jit.script(model)
+
     model = torch.nn.DataParallel(model)
     model = model.to(devices[0], non_blocking=True)

diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py
index 2f4dc94d15..b22d60dae4 100644
--- a/examples/pipeline_wavernn/processing.py
+++ b/examples/pipeline_wavernn/processing.py
@@ -5,12 +5,14 @@
 # TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved
 class LinearToMel(nn.Module):
-    def __init__(self, sample_rate, n_fft, n_mels, fmin):
+    def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
         super().__init__()
         self.sample_rate = sample_rate
         self.n_fft = n_fft
         self.n_mels = n_mels
         self.fmin = fmin
+        self.htk = htk
+        self.norm = norm

     def forward(self, specgram):
         specgram = librosa.feature.melspectrogram(
@@ -19,6 +21,8 @@ def forward(self, specgram):
             n_fft=self.n_fft,
             n_mels=self.n_mels,
             fmin=self.fmin,
+            htk=self.htk,
+            norm=self.norm,
         )
         return torch.from_numpy(specgram)
@@ -33,7 +37,9 @@ def __init__(self, min_level_db):

     def forward(self, specgram):
         specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5))
-        return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1)
+        return torch.clamp(
+            (self.min_level_db - specgram) / self.min_level_db, min=0, max=1
+        )


 def normalized_waveform_to_bits(waveform, bits):
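Note the ordering this patch introduces: the model is scripted before the DataParallel wrapper is applied, since `torch.jit.script` operates on the bare `nn.Module`. A minimal sketch of that pattern with a stand-in module follows; `TinyNet` is hypothetical, the pipeline itself scripts `_WaveRNN`.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # stand-in for _WaveRNN, only to show the ordering
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

model = TinyNet()
model = torch.jit.script(model)   # TorchScript compilation happens on the plain module
model = nn.DataParallel(model)    # the multi-GPU wrapper goes on afterwards
out = model(torch.randn(4, 8))
```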
From b306f68e8ac62afd96578b793788991a20f841e1 Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Mon, 20 Jul 2020 11:55:53 -0700
Subject: [PATCH 29/29] change two command line parameters

---
 examples/pipeline_wavernn/README.md |  2 +-
 examples/pipeline_wavernn/main.py   | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md
index 084d1fb372..865f06c181 100644
--- a/examples/pipeline_wavernn/README.md
+++ b/examples/pipeline_wavernn/README.md
@@ -9,7 +9,7 @@ python main.py \
     --batch-size 256 \
     --learning-rate 1e-4 \
     --n-freq 80 \
-    --mode 'waveform' \
+    --loss 'crossentropy' \
     --n-bits 8 \
 ```

diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py
index 3e8b64c533..032e9fc70e 100644
--- a/examples/pipeline_wavernn/main.py
+++ b/examples/pipeline_wavernn/main.py
@@ -123,13 +123,13 @@ def parse_args():
         "--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
     )
     parser.add_argument(
-        "--n-hidden",
+        "--n-hidden-melresnet",
         default=128,
         type=int,
-        help="the number of hidden dimensions of resblock",
+        help="the number of hidden dimensions of resblock in melresnet",
     )
     parser.add_argument(
-        "--n-output", default=128, type=int, help="the output dimension of melresnet",
+        "--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet",
     )
     parser.add_argument(
         "--n-fft", default=2048, type=int, help="the number of Fourier bins",
@@ -306,8 +306,8 @@ def main(args):
         n_fc=args.n_fc,
         kernel_size=args.kernel_size,
         n_freq=args.n_freq,
-        n_hidden=args.n_hidden,
-        n_output=args.n_output,
+        n_hidden=args.n_hidden_melresnet,
+        n_output=args.n_output_melresnet,
     )

     if args.jit:
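After this final rename the melresnet flags are `--n-hidden-melresnet` and `--n-output-melresnet`, and the criterion is selected with `--loss`. For completeness, a hypothetical invocation of the mixture-of-logistics path, mirroring the README block above with illustrative values only, would look like:

```
python main.py \
    --batch-size 256 \
    --learning-rate 1e-4 \
    --n-freq 80 \
    --loss 'mol'
```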