-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add wavernn example pipeline #749
Merged
Merged
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
49b511d
add wavernn example
4728e3d
update dataset
d17b9d1
update input type
17560e7
add transform and mol loss
9717b75
update format and add utils and readme
131fe33
update model import
1d1c683
update readme
969966c
update readme
4f9cc60
add reference in readme
4671bed
add channel dimension
55f866b
Update the transform and dataset function
c43149d
Add function doctring
0b944b4
Use default argument
79c0ded
Update dataset
553b170
update dataset function
73f22d2
update data split and transform
a8aca08
update format
a47d00f
update logger
c4ff493
update import format
3434e16
move condition in statement
1295d8a
add loss class and change function name
ee1d702
update format
f725457
update varible name
b6198d8
update variable name in wavernn
a43b8a2
update mode
6f8660a
update loss class
0df67b7
add underscore in mol loss
1c426c0
add jit and underscore
b306f68
change two command line parameters
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
This is an example vocoder pipeline using the WaveRNN model trained with LJSpeech. WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was | ||
introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. | ||
|
||
### Usage | ||
|
||
An example can be invoked as follows. | ||
``` | ||
python main.py \ | ||
--batch-size 256 \ | ||
--learning-rate 1e-4 \ | ||
--n-freq 80 \ | ||
--mode 'waveform' \ | ||
--n-bits 8 \ | ||
``` | ||
|
||
### Output | ||
|
||
The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the output if redirected to a file. | ||
```python | ||
def read_json(filename): | ||
""" | ||
Convert the standard output saved to filename into a pandas dataframe for analysis. | ||
""" | ||
|
||
import pandas | ||
import json | ||
|
||
with open(filename, "r") as f: | ||
data = f.read() | ||
|
||
# pandas doesn't read single quotes for json | ||
data = data.replace("'", '"') | ||
|
||
data = [json.loads(l) for l in data.splitlines()] | ||
return pandas.DataFrame(data) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
import random | ||
|
||
import torch | ||
import torchaudio | ||
from torch.utils.data.dataset import random_split | ||
from torchaudio.datasets import LJSPEECH | ||
from torchaudio.transforms import MuLawEncoding | ||
|
||
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits | ||
|
||
|
||
class MapMemoryCache(torch.utils.data.Dataset): | ||
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory. | ||
""" | ||
|
||
def __init__(self, dataset): | ||
self.dataset = dataset | ||
self._cache = [None] * len(dataset) | ||
|
||
def __getitem__(self, n): | ||
if self._cache[n] is not None: | ||
return self._cache[n] | ||
|
||
item = self.dataset[n] | ||
self._cache[n] = item | ||
|
||
return item | ||
|
||
def __len__(self): | ||
return len(self.dataset) | ||
|
||
|
||
class Processed(torch.utils.data.Dataset): | ||
def __init__(self, dataset, transforms): | ||
self.dataset = dataset | ||
self.transforms = transforms | ||
|
||
def __getitem__(self, key): | ||
item = self.dataset[key] | ||
return self.process_datapoint(item) | ||
|
||
def __len__(self): | ||
return len(self.dataset) | ||
|
||
def process_datapoint(self, item): | ||
specgram = self.transforms(item[0]) | ||
return item[0].squeeze(0), specgram | ||
|
||
|
||
def split_process_ljspeech(args, transforms): | ||
data = LJSPEECH(root=args.file_path, download=False) | ||
vincentqb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
val_length = int(len(data) * args.val_ratio) | ||
lengths = [len(data) - val_length, val_length] | ||
train_dataset, val_dataset = random_split(data, lengths) | ||
|
||
train_dataset = Processed(train_dataset, transforms) | ||
val_dataset = Processed(val_dataset, transforms) | ||
|
||
train_dataset = MapMemoryCache(train_dataset) | ||
val_dataset = MapMemoryCache(val_dataset) | ||
|
||
return train_dataset, val_dataset | ||
vincentqb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def collate_factory(args): | ||
def raw_collate(batch): | ||
|
||
pad = (args.kernel_size - 1) // 2 | ||
|
||
# input waveform length | ||
wave_length = args.hop_length * args.seq_len_factor | ||
# input spectrogram length | ||
spec_length = args.seq_len_factor + pad * 2 | ||
|
||
# max start postion in spectrogram | ||
max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch] | ||
|
||
# random start postion in spectrogram | ||
spec_offsets = [random.randint(0, offset) for offset in max_offsets] | ||
# random start postion in waveform | ||
wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] | ||
|
||
waveform_combine = [ | ||
x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1] | ||
for i, x in enumerate(batch) | ||
] | ||
specgram = [ | ||
x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length] | ||
for i, x in enumerate(batch) | ||
] | ||
|
||
specgram = torch.stack(specgram) | ||
waveform_combine = torch.stack(waveform_combine) | ||
|
||
waveform = waveform_combine[:, :wave_length] | ||
target = waveform_combine[:, 1:] | ||
|
||
# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy' | ||
if args.loss == "crossentropy": | ||
|
||
if args.mulaw: | ||
mulaw_encode = MuLawEncoding(2 ** args.n_bits) | ||
waveform = mulaw_encode(waveform) | ||
target = mulaw_encode(target) | ||
|
||
waveform = bits_to_normalized_waveform(waveform, args.n_bits) | ||
|
||
else: | ||
target = normalized_waveform_to_bits(target, args.n_bits) | ||
|
||
return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) | ||
|
||
return raw_collate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import math | ||
|
||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
|
||
class LongCrossEntropyLoss(nn.Module): | ||
r""" CrossEntropy loss | ||
""" | ||
|
||
def __init__(self): | ||
super(LongCrossEntropyLoss, self).__init__() | ||
|
||
def forward(self, output, target): | ||
output = output.transpose(1, 2) | ||
target = target.long() | ||
|
||
criterion = nn.CrossEntropyLoss() | ||
return criterion(output, target) | ||
|
||
|
||
class MoLLoss(nn.Module): | ||
r""" Discretized mixture of logistic distributions loss | ||
|
||
Adapted from wavenet vocoder | ||
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py) | ||
Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155) | ||
|
||
Args: | ||
y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) | ||
y (Tensor): Target (n_batch x n_time x 1) | ||
num_classes (int): Number of classes | ||
log_scale_min (float): Log scale minimum value | ||
reduce (bool): If True, the losses are averaged or summed for each minibatch | ||
|
||
Returns | ||
Tensor: loss | ||
""" | ||
|
||
def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): | ||
super(MoLLoss, self).__init__() | ||
self.num_classes = num_classes | ||
self.log_scale_min = log_scale_min | ||
self.reduce = reduce | ||
|
||
def forward(self, y_hat, y): | ||
y = y.unsqueeze(-1) | ||
|
||
if self.log_scale_min is None: | ||
self.log_scale_min = math.log(1e-14) | ||
|
||
assert y_hat.dim() == 3 | ||
assert y_hat.size(-1) % 3 == 0 | ||
|
||
nr_mix = y_hat.size(-1) // 3 | ||
|
||
# unpack parameters (n_batch, n_time, num_mixtures) x 3 | ||
logit_probs = y_hat[:, :, :nr_mix] | ||
means = y_hat[:, :, nr_mix: 2 * nr_mix] | ||
log_scales = torch.clamp( | ||
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min | ||
) | ||
|
||
# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) | ||
y = y.expand_as(means) | ||
|
||
centered_y = y - means | ||
inv_stdv = torch.exp(-log_scales) | ||
plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1)) | ||
cdf_plus = torch.sigmoid(plus_in) | ||
min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1)) | ||
cdf_min = torch.sigmoid(min_in) | ||
|
||
# log probability for edge case of 0 (before scaling) | ||
# equivalent: torch.log(F.sigmoid(plus_in)) | ||
log_cdf_plus = plus_in - F.softplus(plus_in) | ||
|
||
# log probability for edge case of 255 (before scaling) | ||
# equivalent: (1 - F.sigmoid(min_in)).log() | ||
log_one_minus_cdf_min = -F.softplus(min_in) | ||
|
||
# probability for all other cases | ||
cdf_delta = cdf_plus - cdf_min | ||
|
||
mid_in = inv_stdv * centered_y | ||
# log probability in the center of the bin, to be used in extreme cases | ||
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) | ||
|
||
inner_inner_cond = (cdf_delta > 1e-5).float() | ||
|
||
inner_inner_out = inner_inner_cond * torch.log( | ||
torch.clamp(cdf_delta, min=1e-12) | ||
) + (1.0 - inner_inner_cond) * ( | ||
log_pdf_mid - math.log((self.num_classes - 1) / 2) | ||
) | ||
inner_cond = (y > 0.999).float() | ||
inner_out = ( | ||
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out | ||
) | ||
cond = (y < -0.999).float() | ||
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out | ||
|
||
log_probs = log_probs + F.log_softmax(logit_probs, -1) | ||
|
||
if self.reduce: | ||
return -torch.mean(_log_sum_exp(log_probs)) | ||
else: | ||
return -_log_sum_exp(log_probs).unsqueeze(-1) | ||
|
||
|
||
def _log_sum_exp(x): | ||
r""" Numerically stable log_sum_exp implementation that prevents overflow | ||
""" | ||
|
||
axis = len(x.size()) - 1 | ||
m, _ = torch.max(x, dim=axis) | ||
m2, _ = torch.max(x, dim=axis, keepdim=True) | ||
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a separate pull request, we could consider adding the plots for training and validation errors in this readme for the pre-trained weight. Either in #776 or another one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion, I will have a separate pull request to add the plots.