Add WaveRNN Model #735
Conversation
Three failed tests look irrelevant to this PR.
torchaudio/models/_wavernn.py
Outdated
It is a block used in WaveRNN. WaveRNN is based on the paper "Efficient Neural Audio Synthesis".
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart,
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
It is a block used in WaveRNN.
In Python, the first line has to be complete and give a summary of what it is.
See https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
Starting with "This is ..." is lengthy and does not add value; instead, try to come up with some useful information for a first-time reader. In this case, something like `ResNet block based on "Deep Residual Learning for Image Recognition"` would do.
torchaudio/models/_wavernn.py
Outdated
for i in range(res_blocks):
    ResBlocks.append(_ResBlock(hidden_dims))
for i in range(n_res_block):
    ResBlocks.append(_ResBlock(n_hidden))
You can use a one-liner here for better readability.
ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]
Use an underscore as the variable name when you do not use the variable.
Nice suggestion! Fixed.
This part is in the melresnet PR #751.
torchaudio/models/_wavernn.py
Outdated
up_layers = []
for scale in upsample_scales:
    k_size = (1, scale * 2 + 1)
    padding = (0, scale)
Defining variables like this looks uncommon, yet it does not improve readability. Just put the values in keyword arguments; that is readable enough.
conv = nn.Conv2d(
    in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False)
Fixed, thanks.
This part is in the upsampling PR #724.
torchaudio/models/_wavernn.py
Outdated
x: the input sequence to the _UpsampleNetwork layer

Shape:
    - x: :math:`(batch, freq, time)`.
In Sphinx docstrings, code is marked with double back-ticks.
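For illustration, a docstring using the double-back-tick convention might look like this (a hypothetical function, not code from the PR):

```python
def upsample(x):
    """Upsample the input sequence.

    Args:
        x: input tensor of shape ``(batch, freq, time)``.
    """
    return x
```

Sphinx renders the double-back-ticked span as inline code, whereas a single back-tick pair would be interpreted as default (often italic) markup.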
I used the math block here, inspired by the format in Transformer. But I agree that the code block is a good idea here, since it is clear and matches other places. I will update this part.
torchaudio/models/_wavernn.py
Outdated
mode='RAW')
>>> x = torch.rand(10, 24800, 512)
>>> mels = torch.rand(10, 128, 512)
>>> output = upsamplenetwork(x, mels)
Are these examples run by the documentation generator?
I do not see anything special about this usage, so unless it is here to be tested at documentation generation time, it is not adding much value to the documentation.
I agree that the example doesn't show the relationship between x and mels. How about the following?
>>> waveform = torchaudio.load(file)  # shape: batch x channel x time
>>> specgram = MelSpectrogram(waveform)  # shape: batch x freq x time
>>> output = upsamplenetwork(waveform, specgram)  # shape: ...
torchaudio/models/_wavernn.py
Outdated
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
mode: the type of input waveform (default='RAW')
What are the other available options for `mode`? Can you list all of them in the docstring?
Updated in this PR.
if self.mode == 'RAW':
    self.n_classes = 2 ** n_bits
elif self.mode == 'MOL':
    self.n_classes = 30
Can you throw an error when `mode` is invalid?
Updated in this PR.
torchaudio/models/_wavernn.py
Outdated
batch_size = x.size(0)
h1 = torch.zeros(1, batch_size, self.n_rnn, device=x.device)
h2 = torch.zeros(1, batch_size, self.n_rnn, device=x.device)
`dtype` is missing from `h1` and `h2`, so if `float64` is passed, this will fail.
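A sketch of the fix, assuming a helper named `init_hidden` (hypothetical name; the PR does this inline in `forward`):

```python
import torch

def init_hidden(x, n_rnn):
    # Propagate dtype as well as device from the input, so that
    # float64 (or half-precision, or CUDA) inputs no longer mismatch
    # the hidden states downstream.
    batch_size = x.size(0)
    h1 = torch.zeros(1, batch_size, n_rnn, dtype=x.dtype, device=x.device)
    h2 = torch.zeros(1, batch_size, n_rnn, dtype=x.dtype, device=x.device)
    return h1, h2
```

With `dtype=x.dtype`, calling the model on a `float64` input produces `float64` hidden states instead of the default `float32`.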
`dtype` is added in this PR.
test/test_models.py
Outdated
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):

    def test_waveform(self):
For each test, can you add a comment on what is tested?
The test docstring can be one line or multiple lines, but it should express what you are trying to test here.
It is common that a written test does not check what was intended, and such a wrong test is very difficult to detect and fix.
Use `pytest test --collect-only -v` and see examples of the current tests.
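As a sketch, a test whose docstring shows up meaningfully in the collected listing (test name is hypothetical; the numbers follow the PR's defaults, where `upsample_scales=[5,5,8]` and `hop_length=200`):

```python
import unittest

class TestUpsampleNetwork(unittest.TestCase):

    def test_total_scale(self):
        """Product of upsample_scales equals hop_length, one hop of audio per frame."""
        upsample_scales = [5, 5, 8]
        hop_length = 200
        total_scale = 1
        for scale in upsample_scales:
            total_scale *= scale
        self.assertEqual(total_scale, hop_length)
```

`pytest --collect-only -v` (and verbose test runners generally) print the first docstring line next to the test id, which is what makes a one-line summary valuable.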
class TestWaveRNN(common_utils.TorchaudioTestCase):

    def test_waveform(self):
Same here, an explanation of the test is required; otherwise it will be difficult to make proper changes to this test later.
Codecov Report
@@ Coverage Diff @@
## master #735 +/- ##
==========================================
+ Coverage 89.16% 89.34% +0.17%
==========================================
Files 32 32
Lines 2566 2627 +61
==========================================
+ Hits 2288 2347 +59
- Misses 278 280 +2
Continue to review full report at Codecov.
torchaudio/models/_wavernn.py
Outdated
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart,
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
r"""ResNet block layer based on
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
The short description has to fit in one line: https://www.python.org/dev/peps/pep-0257/#one-line-docstrings
"""ResNet block based on "Deep Residual Learning for Image Recognition"
more detailed explanation here.
"""
test/test_models.py
Outdated
""" | ||
Create a tensor as the input of _MelResNet layer | ||
and test if the output dimensions are correct. | ||
""" |
Please use the same docstring notation.
"""Short description on the first line as the opening of docstring
Then the more detailed explanation only if necessary.
Every test should be simple so if you need to add more explanation,
you should also consider the design of test and simplify it.
"""
Updated.
This change is updated in the melresnet PR #751.
Thanks for also opening #751 to update these changes.
torchaudio/models/_wavernn.py
Outdated
)

def forward(self, x: Tensor) -> Tensor:
    r"""Pass the input through the _MelResNet layer.

    r"""
What's your rationale behind removing the original short description, `Pass the input through the _MelResNet layer.`?
torchaudio/models/_wavernn.py
Outdated
class _Stretch2d(nn.Module):
    r"""Two-dimensional stretch layer.
I do not think `stretch layer` is common enough, and reading the description does not help one understand what it does. Can you mention the outcome of the layer?
I agree; the description of the outcome has been added.
I updated this change in PR #724.
torchaudio/models/_wavernn.py
Outdated
class _UpsampleNetwork(nn.Module):
    r"""Upsample block based on a stack of Conv2d and Strech2d layers.
based on a stack of Conv2d and Strech2d layers
Instead of mentioning how it's implemented internally, explain what it does.
I updated this change in PR #724 .
torchaudio/models/_wavernn.py
Outdated
n_output=128,
kernel_size=5)
>>> input = torch.rand(10, 128, 512)
>>> output = upsamplenetwork(input)
This example is missing the most important information: what the output shape looks like.
torchaudio/models/_wavernn.py
Outdated
class _WaveRNN(nn.Module):
    r"""WaveRNN model based on
    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_
Remove the link from the summary of docstring and put it in the detail section.
Fixed.
torchaudio/models/_wavernn.py
Outdated
elif self.mode == 'MOL':
    self.n_classes = 30
else:
    raise ValueError("Unknown input mode - {}".format(self.mode))
Adding what is expected values would improve user experience.
ValueError(f"Expected mode: `RAW` or `MOL`, but found {self.mode}")
Also, instead of `.format`, using an f-string will improve the readability of the code.
Updated in this PR.
torchaudio/models/_wavernn.py
Outdated
>>> upsamplenetwork = _waveRNN(upsample_scales=[5,5,8],
                               n_bits=9,
                               sample_rate=24000,
                               hop_length=200,
                               n_res_block=10,
                               n_rnn=512,
                               n_fc=512,
                               kernel_size=5,
                               n_freq=128,
                               n_hidden=128,
                               n_output=128,
                               mode='RAW')
nit: This whole block needs to be a code block. Would it be nicer to write the following?
>>> ... = _waveRNN(
>>> n_bits=9,
>>> ...
>>> )
In fact, can you specify only the parameters that are required, and omit the optional ones in this list?
>>> upsamplenetwork = _waveRNN(
>>> upsample_scales=[5,5,8],
>>> n_bits=9,
>>> sample_rate=24000,
>>> hop_length=200,
>>> )
torchaudio/models/_wavernn.py
Outdated
n_output: the number of output dimensions (default=128)
mode: the type of input waveform in ['RAW', 'MOL'] (default='RAW')

Examples::
nit: no need for `::`, see example.
torchaudio/models/_wavernn.py
Outdated
batch_size = waveform.size(0)
h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
mels, aux = self.upsample(specgram)
Need to add a transpose here because the transpose operation in the upsampling part was removed.
Since this may not be "standard" compared to other implementations, could you add a comment like this?
# output of upsample is batch x ...
Listed a few nits, but overall looks good :)
torchaudio/models/_wavernn.py
Outdated
Examples
    >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
    >>> waveform, sample_rate = torchaudio.load(file)  # waveform shape:
    >>> (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
nit: Presented like this, the shape looks like a new command. How about the following if the line is too long?
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
Good suggestion. Updated.
torchaudio/models/_wavernn.py
Outdated
>>> (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
>>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1))  # shape:
>>> (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
nit: same
>>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1))
>>> # output shape: (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
Updated.
n_res_block: the number of ResBlock in stack (default=10)
n_rnn: the dimension of RNN layer (default=512)
n_fc: the dimension of fully connected layer (default=512)
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Failed tests relate to #766.
torchaudio/models/_wavernn.py
Outdated
class _WaveRNN(nn.Module):
    r"""WaveRNN model based on "Efficient Neural Audio Synthesis".

    The paper link is `<https://arxiv.org/pdf/1802.08435.pdf>`_. The input channels of waveform
nit: For more info see the paper `Efficient Neural Audio Synthesis <https://arxiv.org/pdf/1802.08435.pdf>`_
Given the feedback in comment, let's make it more explicit that this version is not the original.
"""
WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_ .
The original implementation was introduced in `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
The input channels of waveform and spectrogram have to be 1. The product of `upsample_scales` must equal `hop_length`.
"""
If it doesn't make the first line too long, I would merge the first and second paragraph.
Another option is to add a flag in WaveRNN init, say `version`, that can only take the value `fatchord` for now, but could eventually also take the value `deepmind` later if/when implemented. Thoughts?
(Thanks for the feedback in comment @PetrochukM :))
torchaudio/models/_wavernn.py
Outdated
@@ -320,5 +324,6 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
x = torch.cat([x, a4], dim=-1)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x).unsqueeze(1)
nit: let's make that a new paragraph :)
x = torch.cat([x, a4], dim=-1)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)
# bring back channel dimension
return x.unsqueeze(1)
The comment has been added.
""" | ||
|
||
assert waveform.size(1) == 1, 'Require the input channel of waveform is 1' | ||
assert specgram.size(1) == 1, 'Require the input channel of specgram is 1' | ||
waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) |
nit:
# remove channel dimension until the end
waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
The comment has been added.
Let's add the two minor comments I suggested, but otherwise LGTM!
if self.mode == 'waveform':
    self.n_classes = 2 ** n_bits
elif self.mode == 'mol':
    self.n_classes = 30
Let's replace the `mode` and `n_bits` parameters simply by `n_classes`.
cc comment
Yes, I changed `mode` to `loss` and replaced `n_bits` by `n_classes` in #797.
if total_scale != self.hop_length:
    raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
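The check quoted above can be sketched in isolation; `check_upsample_scales` is a hypothetical helper name, not from the PR:

```python
from functools import reduce
from operator import mul

def check_upsample_scales(upsample_scales, hop_length):
    # One spectrogram frame must upsample to exactly `hop_length`
    # audio samples, so the product of the scales has to equal it.
    total_scale = reduce(mul, upsample_scales, 1)
    if total_scale != hop_length:
        raise ValueError(
            f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}"
        )
    return total_scale
```

With the PR's defaults, `[5, 5, 8]` multiplies out to `200`, matching `hop_length=200`; any other combination fails at construction time rather than producing misaligned output later.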
Let's change to `n_hidden_resblock` and `n_output_upsample`. Thoughts?
I updated the names as `n_hidden_resblock` and `n_output_melresnet` in #797, because `n_output` is used in the melresnet block, and the melresnet block is one part of the upsample block, so I use `n_output_melresnet`. Any suggestions?
* Add WaveRNN example
This is the pipeline example based on the [WaveRNN model](#735) in torchaudio. The design of this pipeline is inspired by [#632](#632). It offers a standardized implementation of the WaveRNN vocoder in torchaudio.
* Add utils and readme
The metric logger is added based on the Wav2letter pipeline [#632](#632). It offers a way to parse the standard output as described in the readme.
* Add channel dimension
The channel dimension of the waveform in the datasets is added to match the input dimensions of the WaveRNN model, because the channel dimensions of waveform and spectrogram are added in [this part](https://github.com/pytorch/audio/blob/master/torchaudio/models/_wavernn.py#L281) of the WaveRNN model.
* Update data split and transform
The design of the dataset structure is discussed in [this comment](#749 (comment)). Now the dataset file has a clearer workflow after using the random-split function instead of walking through all the files. All transform functions are put together inside the transforms block.
Co-authored-by: Ji Chen <[email protected]>
This is the WaveRNN model.
Related to #446
Stack:
- Add MelResNet Block #705, #751
- Add Upsampling Block #724
- Add WaveRNN Model #735
- Add example pipeline with WaveRNN #749