Add wavernn example pipeline #749

Merged 29 commits on Jul 21, 2020
36 changes: 36 additions & 0 deletions examples/pipeline_wavernn/README.md
@@ -0,0 +1,36 @@
This is an example vocoder pipeline using the WaveRNN model trained on LJSpeech. The WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN), originally introduced in "Efficient Neural Audio Synthesis". Both the WaveRNN model and the LJSpeech dataset are available in torchaudio.
Contributor:

As a separate pull request, we could consider adding plots of the training and validation errors for the pre-trained weights to this README. Either in #776 or another one.

Contributor Author:

Good suggestion; I will open a separate pull request to add the plots.


### Usage

An example can be invoked as follows:
```
python main.py \
--batch-size 256 \
--learning-rate 1e-4 \
--n-freq 80 \
--mode 'waveform' \
--n-bits 8
```

### Output

The information reported at each iteration and epoch (e.g. loss) is printed to standard output as one JSON object per line. Here is an example Python function to parse the output if it has been redirected to a file.
```python
def read_json(filename):
"""
Convert the standard output saved to filename into a pandas dataframe for analysis.
"""

import pandas
import json

with open(filename, "r") as f:
data = f.read()

# pandas doesn't read single quotes for json
data = data.replace("'", '"')

data = [json.loads(l) for l in data.splitlines()]
return pandas.DataFrame(data)
```
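
For example, if the standard output of a training run was redirected to a file (the filename below is purely illustrative), the resulting dataframe can be inspected directly:
```python
# "train.log" is a hypothetical filename; use whatever file stdout was redirected to.
df = read_json("train.log")

print(df.columns)  # see which fields were logged
print(df.tail())   # the most recent iterations / epochs
```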
115 changes: 115 additions & 0 deletions examples/pipeline_wavernn/datasets.py
@@ -0,0 +1,115 @@
import os
import random

import torch
import torchaudio
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH
from torchaudio.transforms import MuLawEncoding

from processing import bits_to_normalized_waveform, normalized_waveform_to_bits


class MapMemoryCache(torch.utils.data.Dataset):
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""

def __init__(self, dataset):
self.dataset = dataset
self._cache = [None] * len(dataset)

def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]

item = self.dataset[n]
self._cache[n] = item

return item

def __len__(self):
return len(self.dataset)
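
As a quick illustration of the caching behavior (the toy dataset below is purely hypothetical):
```python
import torch


class Squares(torch.utils.data.Dataset):
    # Toy dataset used only to illustrate MapMemoryCache.
    def __len__(self):
        return 5

    def __getitem__(self, n):
        return n * n


cached = MapMemoryCache(Squares())
assert cached[3] == 9  # computed on first access
assert cached[3] == 9  # served from the in-memory cache on later accesses
```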


class Processed(torch.utils.data.Dataset):
def __init__(self, dataset, transforms):
self.dataset = dataset
self.transforms = transforms

def __getitem__(self, key):
item = self.dataset[key]
return self.process_datapoint(item)

def __len__(self):
return len(self.dataset)

def process_datapoint(self, item):
specgram = self.transforms(item[0])
return item[0].squeeze(0), specgram


def split_process_ljspeech(args, transforms):
data = LJSPEECH(root=args.file_path, download=False)

val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)

train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms)

train_dataset = MapMemoryCache(train_dataset)
val_dataset = MapMemoryCache(val_dataset)

return train_dataset, val_dataset


def collate_factory(args):
def raw_collate(batch):

pad = (args.kernel_size - 1) // 2

# input waveform length
wave_length = args.hop_length * args.seq_len_factor
# input spectrogram length
spec_length = args.seq_len_factor + pad * 2

# max start position in spectrogram
max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]

# random start position in spectrogram
spec_offsets = [random.randint(0, offset) for offset in max_offsets]
# random start position in waveform
wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]

waveform_combine = [
x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1]
for i, x in enumerate(batch)
]
specgram = [
x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length]
for i, x in enumerate(batch)
]

specgram = torch.stack(specgram)
waveform_combine = torch.stack(waveform_combine)

waveform = waveform_combine[:, :wave_length]
target = waveform_combine[:, 1:]

# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
if args.loss == "crossentropy":

if args.mulaw:
mulaw_encode = MuLawEncoding(2 ** args.n_bits)
waveform = mulaw_encode(waveform)
target = mulaw_encode(target)

waveform = bits_to_normalized_waveform(waveform, args.n_bits)

else:
target = normalized_waveform_to_bits(target, args.n_bits)

return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)

return raw_collate
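
For reference, here is a minimal sketch of how these pieces might be wired together with a `DataLoader`. The transform and every argument value below are assumptions for illustration; the real configuration lives in main.py, which is not part of this diff, and the sketch assumes LJSpeech has already been downloaded to `file_path`:
```python
import argparse

import torch
import torchaudio

from datasets import collate_factory, split_process_ljspeech

# Illustrative hyperparameters only; the actual defaults are defined in main.py.
args = argparse.Namespace(
    file_path="./data", val_ratio=0.1,
    kernel_size=5, hop_length=200, seq_len_factor=5,
    loss="crossentropy", mulaw=True, n_bits=8,
)


class MelTransform(torch.nn.Module):
    # Assumed stand-in transform: (1, time) waveform -> (n_freq, time) spectrogram,
    # squeezed so that raw_collate's (freq, time) slicing applies.
    def __init__(self):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050, hop_length=args.hop_length, n_mels=80
        )

    def forward(self, waveform):
        return self.mel(waveform).squeeze(0)


train_dataset, val_dataset = split_process_ljspeech(args, MelTransform())
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, collate_fn=collate_factory(args), shuffle=True
)
```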
119 changes: 119 additions & 0 deletions examples/pipeline_wavernn/losses.py
@@ -0,0 +1,119 @@
import math

import torch
from torch import nn as nn
from torch.nn import functional as F


class LongCrossEntropyLoss(nn.Module):
r""" CrossEntropy loss
"""

def __init__(self):
super(LongCrossEntropyLoss, self).__init__()

def forward(self, output, target):
output = output.transpose(1, 2)
target = target.long()

criterion = nn.CrossEntropyLoss()
return criterion(output, target)
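
A shape-check sketch for the loss above (all sizes are illustrative): `output` carries class scores over time and is transposed internally to the `(batch, classes, time)` layout that `nn.CrossEntropyLoss` expects.
```python
import torch

criterion = LongCrossEntropyLoss()

# Illustrative shapes: 4 sequences, 100 time steps, 2**8 = 256 classes.
output = torch.randn(4, 100, 256)         # (n_batch, n_time, n_classes)
target = torch.randint(0, 256, (4, 100))  # class indices; cast to long in forward
print(criterion(output, target))          # scalar loss
```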


class MoLLoss(nn.Module):
r""" Discretized mixture of logistic distributions loss

Adapted from wavenet vocoder
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)

Args:
y_hat (Tensor): Predicted output (n_batch x n_time x n_channel)
y (Tensor): Target (n_batch x n_time x 1)
num_classes (int): Number of classes
log_scale_min (float): Log scale minimum value
reduce (bool): If True, the losses are averaged or summed for each minibatch

Returns:
Tensor: loss
"""

def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
super(MoLLoss, self).__init__()
self.num_classes = num_classes
self.log_scale_min = log_scale_min
self.reduce = reduce

def forward(self, y_hat, y):
y = y.unsqueeze(-1)

if self.log_scale_min is None:
self.log_scale_min = math.log(1e-14)

assert y_hat.dim() == 3
assert y_hat.size(-1) % 3 == 0

nr_mix = y_hat.size(-1) // 3

# unpack parameters (n_batch, n_time, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = torch.clamp(
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
)

# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
y = y.expand_as(means)

centered_y = y - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
cdf_plus = torch.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
cdf_min = torch.sigmoid(min_in)

# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(F.sigmoid(plus_in))
log_cdf_plus = plus_in - F.softplus(plus_in)

# log probability for edge case of 255 (before scaling)
# equivalent: (1 - F.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)

# probability for all other cases
cdf_delta = cdf_plus - cdf_min

mid_in = inv_stdv * centered_y
# log probability in the center of the bin, to be used in extreme cases
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

inner_inner_cond = (cdf_delta > 1e-5).float()

inner_inner_out = inner_inner_cond * torch.log(
torch.clamp(cdf_delta, min=1e-12)
) + (1.0 - inner_inner_cond) * (
log_pdf_mid - math.log((self.num_classes - 1) / 2)
)
inner_cond = (y > 0.999).float()
inner_out = (
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
)
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

log_probs = log_probs + F.log_softmax(logit_probs, -1)

if self.reduce:
return -torch.mean(_log_sum_exp(log_probs))
else:
return -_log_sum_exp(log_probs).unsqueeze(-1)


def _log_sum_exp(x):
r""" Numerically stable log_sum_exp implementation that prevents overflow
"""

axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
m2, _ = torch.max(x, dim=axis, keepdim=True)
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
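
A corresponding sketch for `MoLLoss`, following the shape conventions in its docstring (the mixture count and sizes are illustrative; targets are assumed normalized to [-1, 1]):
```python
import torch

num_mix = 10
criterion = MoLLoss(num_classes=65536)

y_hat = torch.randn(4, 100, 3 * num_mix)  # logits, means, log-scales per mixture
y = torch.empty(4, 100).uniform_(-1, 1)   # forward unsqueezes the last dimension
print(criterion(y_hat, y))                # scalar loss (reduce=True)
```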