forked from ibab/tensorflow-wavenet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
524 lines (446 loc) · 21.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
import tensorflow as tf
from .ops import causal_conv, mu_law_encode
def create_variable(name, shape):
    '''Create a convolution filter variable with the given name and shape.

    The initial value is sampled from a Xavier (Glorot) initializer built
    for convolution filters (`xavier_initializer_conv2d`).

    Args:
        name: Name given to the created `tf.Variable`.
        shape: Filter shape, e.g. [filter_width, in_channels, out_channels].
    Returns:
        The newly created `tf.Variable`.
    '''
    return tf.Variable(
        tf.contrib.layers.xavier_initializer_conv2d()(shape=shape),
        name=name)
def create_bias_variable(name, shape):
    '''Create a bias variable with the specified name and shape and
    initialize it to zero.

    Args:
        name: Name given to the created variable. It must be passed as the
            `name` keyword: the second positional parameter of
            `tf.Variable` is `trainable`, so passing it positionally leaves
            the variable unnamed. Code that selects variables by name (for
            example, skipping variables whose name contains 'bias' during
            L2 regularization) depends on this name being set.
        shape: Shape of the bias, e.g. [channels].
    Returns:
        A float32 `tf.Variable` filled with zeros.
    '''
    initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
    return tf.Variable(initializer(shape=shape), name=name)
class WaveNetModel(object):
    '''Implements the WaveNet network for generative audio.

    Usage (with the architecture as in the DeepMind paper):
        dilations = [2**i for i in range(N)] * M
        filter_width = 2  # Convolutions just use 2 samples.
        residual_channels = 16  # Not specified in the paper.
        dilation_channels = 32  # Not specified in the paper.
        skip_channels = 16  # Not specified in the paper.
        net = WaveNetModel(batch_size, dilations, filter_width,
                           residual_channels, dilation_channels,
                           skip_channels)
        loss = net.loss(input_batch)
    '''

    def __init__(self,
                 batch_size,
                 dilations,
                 filter_width,
                 residual_channels,
                 dilation_channels,
                 skip_channels,
                 quantization_channels=2**8,
                 use_biases=False,
                 scalar_input=False,
                 initial_filter_width=32,
                 histograms=False):
        '''Initializes the WaveNet model.

        Args:
            batch_size: How many audio files are supplied per batch
                (recommended: 1).
            dilations: A list with the dilation factor for each layer.
            filter_width: The samples that are included in each convolution,
                after dilating.
            residual_channels: How many filters to learn for the residual.
            dilation_channels: How many filters to learn for the dilated
                convolution.
            skip_channels: How many filters to learn that contribute to the
                quantized softmax output.
            quantization_channels: How many amplitude values to use for audio
                quantization and the corresponding one-hot encoding.
                Default: 256 (8-bit quantization).
            use_biases: Whether to add a bias layer to each convolution.
                Default: False.
            scalar_input: Whether to use the quantized waveform directly as
                input to the network instead of one-hot encoding it.
                Default: False.
            initial_filter_width: The width of the initial filter of the
                convolution applied to the scalar input. This is only
                relevant if scalar_input=True.
            histograms: Whether to store histograms in the summary.
                Default: False.
        '''
        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.residual_channels = residual_channels
        self.dilation_channels = dilation_channels
        self.quantization_channels = quantization_channels
        self.use_biases = use_biases
        self.skip_channels = skip_channels
        self.scalar_input = scalar_input
        self.initial_filter_width = initial_filter_width
        self.histograms = histograms

        # All variables are created once here so that the training graph
        # (loss) and the generation graphs share the same weights.
        self.variables = self._create_variables()

    def _create_variables(self):
        '''This function creates all variables used by the network.

        This allows us to share them between multiple calls to the loss
        function and generation function.

        Returns:
            A dict with the keys
              'causal_layer': dict holding the initial 'filter',
              'dilated_stack': list with one dict per entry of
                  self.dilations ('filter', 'gate', 'dense', 'skip' and,
                  if use_biases, the matching '*_bias' variables),
              'postprocessing': dict with the two 1x1 output weights
                  and, if use_biases, their biases.
        '''
        var = dict()

        with tf.variable_scope('wavenet'):
            with tf.variable_scope('causal_layer'):
                layer = dict()
                if self.scalar_input:
                    initial_channels = 1
                    initial_filter_width = self.initial_filter_width
                else:
                    initial_channels = self.quantization_channels
                    initial_filter_width = self.filter_width
                layer['filter'] = create_variable(
                    'filter',
                    [initial_filter_width,
                     initial_channels,
                     self.residual_channels])
                var['causal_layer'] = layer

            var['dilated_stack'] = list()
            with tf.variable_scope('dilated_stack'):
                for i, dilation in enumerate(self.dilations):
                    with tf.variable_scope('layer{}'.format(i)):
                        current = dict()
                        current['filter'] = create_variable(
                            'filter',
                            [self.filter_width,
                             self.residual_channels,
                             self.dilation_channels])
                        current['gate'] = create_variable(
                            'gate',
                            [self.filter_width,
                             self.residual_channels,
                             self.dilation_channels])
                        current['dense'] = create_variable(
                            'dense',
                            [1,
                             self.dilation_channels,
                             self.residual_channels])
                        current['skip'] = create_variable(
                            'skip',
                            [1,
                             self.dilation_channels,
                             self.skip_channels])

                        if self.use_biases:
                            current['filter_bias'] = create_bias_variable(
                                'filter_bias',
                                [self.dilation_channels])
                            current['gate_bias'] = create_bias_variable(
                                'gate_bias',
                                [self.dilation_channels])
                            current['dense_bias'] = create_bias_variable(
                                'dense_bias',
                                [self.residual_channels])
                            current['skip_bias'] = create_bias_variable(
                                'skip_bias',
                                [self.skip_channels])

                        var['dilated_stack'].append(current)

            with tf.variable_scope('postprocessing'):
                current = dict()
                current['postprocess1'] = create_variable(
                    'postprocess1',
                    [1, self.skip_channels, self.skip_channels])
                current['postprocess2'] = create_variable(
                    'postprocess2',
                    [1, self.skip_channels, self.quantization_channels])
                if self.use_biases:
                    current['postprocess1_bias'] = create_bias_variable(
                        'postprocess1_bias',
                        [self.skip_channels])
                    current['postprocess2_bias'] = create_bias_variable(
                        'postprocess2_bias',
                        [self.quantization_channels])
                var['postprocessing'] = current

        return var

    def _create_causal_layer(self, input_batch):
        '''Creates a single causal convolution layer.

        The layer can change the number of channels.
        '''
        with tf.name_scope('causal_layer'):
            weights_filter = self.variables['causal_layer']['filter']
            return causal_conv(input_batch, weights_filter, 1)

    def _create_dilation_layer(self, input_batch, layer_index, dilation):
        '''Creates a single causal dilated convolution layer.

        The layer contains a gated filter that connects to dense output
        and to a skip connection:

               |-> [gate]   -|        |-> 1x1 conv -> skip output
               |             |-> (*) -|
        input -|-> [filter] -|        |-> 1x1 conv -|
               |                                    |-> (+) -> dense output
               |------------------------------------|

        Where `[gate]` and `[filter]` are causal convolutions with a
        non-linear activation at the output.

        Returns:
            A tuple (skip_contribution, dense_output), where dense_output
            is input_batch plus the 1x1-transformed gated activation.
        '''
        variables = self.variables['dilated_stack'][layer_index]

        weights_filter = variables['filter']
        weights_gate = variables['gate']

        conv_filter = causal_conv(input_batch, weights_filter, dilation)
        conv_gate = causal_conv(input_batch, weights_gate, dilation)

        if self.use_biases:
            filter_bias = variables['filter_bias']
            gate_bias = variables['gate_bias']
            conv_filter = tf.add(conv_filter, filter_bias)
            conv_gate = tf.add(conv_gate, gate_bias)

        out = tf.tanh(conv_filter) * tf.sigmoid(conv_gate)

        # The 1x1 conv to produce the residual output
        weights_dense = variables['dense']
        transformed = tf.nn.conv1d(
            out, weights_dense, stride=1, padding="SAME", name="dense")

        # The 1x1 conv to produce the skip output
        weights_skip = variables['skip']
        skip_contribution = tf.nn.conv1d(
            out, weights_skip, stride=1, padding="SAME", name="skip")

        if self.use_biases:
            dense_bias = variables['dense_bias']
            skip_bias = variables['skip_bias']
            transformed = transformed + dense_bias
            skip_contribution = skip_contribution + skip_bias

        if self.histograms:
            layer = 'layer{}'.format(layer_index)
            tf.histogram_summary(layer + '_filter', weights_filter)
            tf.histogram_summary(layer + '_gate', weights_gate)
            tf.histogram_summary(layer + '_dense', weights_dense)
            tf.histogram_summary(layer + '_skip', weights_skip)
            if self.use_biases:
                tf.histogram_summary(layer + '_biases_filter', filter_bias)
                tf.histogram_summary(layer + '_biases_gate', gate_bias)
                tf.histogram_summary(layer + '_biases_dense', dense_bias)
                tf.histogram_summary(layer + '_biases_skip', skip_bias)

        # Residual connection: the layer input is added to its output.
        return skip_contribution, input_batch + transformed

    def _generator_conv(self, input_batch, state_batch, weights):
        '''Perform convolution for a single convolutional processing step.

        Computes state_batch @ weights[0] + input_batch @ weights[1], i.e.
        a width-2 causal convolution evaluated for one time step. The
        "past" sample in state_batch is expected to be supplied by the
        queues built in _create_generator.
        '''
        # TODO generalize to filter_width > 2
        past_weights = weights[0, :, :]
        curr_weights = weights[1, :, :]
        output = tf.matmul(state_batch, past_weights) + tf.matmul(
            input_batch, curr_weights)
        return output

    def _generator_causal_layer(self, input_batch, state_batch):
        '''Single-step counterpart of _create_causal_layer.'''
        with tf.name_scope('causal_layer'):
            weights_filter = self.variables['causal_layer']['filter']
            output = self._generator_conv(
                input_batch, state_batch, weights_filter)
        return output

    def _generator_dilation_layer(self, input_batch, state_batch, layer_index,
                                  dilation):
        '''Single-step counterpart of _create_dilation_layer.

        `dilation` is not used here: the dilation is realized by the depth
        of the queue that produces state_batch in _create_generator.

        Returns:
            A tuple (skip_contribution, input_batch + transformed).
        '''
        variables = self.variables['dilated_stack'][layer_index]

        weights_filter = variables['filter']
        weights_gate = variables['gate']
        output_filter = self._generator_conv(
            input_batch, state_batch, weights_filter)
        output_gate = self._generator_conv(
            input_batch, state_batch, weights_gate)

        if self.use_biases:
            output_filter = output_filter + variables['filter_bias']
            output_gate = output_gate + variables['gate_bias']

        out = tf.tanh(output_filter) * tf.sigmoid(output_gate)

        # 1x1 convolutions reduce to plain matrix products for one step.
        weights_dense = variables['dense']
        transformed = tf.matmul(out, weights_dense[0, :, :])
        if self.use_biases:
            transformed = transformed + variables['dense_bias']

        weights_skip = variables['skip']
        skip_contribution = tf.matmul(out, weights_skip[0, :, :])
        if self.use_biases:
            skip_contribution = skip_contribution + variables['skip_bias']

        return skip_contribution, input_batch + transformed

    def _create_network(self, input_batch):
        '''Construct the WaveNet network.

        Args:
            input_batch: Network input passed to the initial causal
                convolution (one-hot or scalar-valued waveform).
        Returns:
            Un-normalized logits with quantization_channels channels.
        '''
        outputs = []
        current_layer = input_batch

        # Pre-process the input with a regular convolution. The channel
        # count of the input is handled by the shape of the causal filter
        # created in _create_variables.
        current_layer = self._create_causal_layer(current_layer)

        # Add all defined dilation layers.
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):
                    output, current_layer = self._create_dilation_layer(
                        current_layer, layer_index, dilation)
                    outputs.append(output)

        with tf.name_scope('postprocessing'):
            # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
            # postprocess the output.
            w1 = self.variables['postprocessing']['postprocess1']
            w2 = self.variables['postprocessing']['postprocess2']
            if self.use_biases:
                b1 = self.variables['postprocessing']['postprocess1_bias']
                b2 = self.variables['postprocessing']['postprocess2_bias']

            if self.histograms:
                tf.histogram_summary('postprocess1_weights', w1)
                tf.histogram_summary('postprocess2_weights', w2)
                if self.use_biases:
                    tf.histogram_summary('postprocess1_biases', b1)
                    tf.histogram_summary('postprocess2_biases', b2)

            # We skip connections from the outputs of each layer, adding them
            # all up here.
            total = sum(outputs)
            transformed1 = tf.nn.relu(total)
            conv1 = tf.nn.conv1d(transformed1, w1, stride=1, padding="SAME")
            if self.use_biases:
                conv1 = tf.add(conv1, b1)
            transformed2 = tf.nn.relu(conv1)
            conv2 = tf.nn.conv1d(transformed2, w2, stride=1, padding="SAME")
            if self.use_biases:
                conv2 = tf.add(conv2, b2)

        return conv2

    def _create_generator(self, input_batch):
        '''Construct an efficient incremental generator.

        Each layer keeps a FIFO queue of its past inputs: the queue for a
        layer with dilation d is pre-filled with d zero entries, so the
        value dequeued at a step is the input pushed d steps earlier. The
        queue initializers and the push operations are stored on
        self.init_ops and self.push_ops so the caller can run them.

        Returns:
            Un-normalized logits for the current step.
        '''
        init_ops = []
        push_ops = []
        outputs = []
        current_layer = input_batch

        q = tf.FIFOQueue(
            1,
            dtypes=tf.float32,
            shapes=(self.batch_size, self.quantization_channels))
        init = q.enqueue_many(
            tf.zeros((1, self.batch_size, self.quantization_channels)))

        current_state = q.dequeue()
        push = q.enqueue([current_layer])
        init_ops.append(init)
        push_ops.append(push)

        current_layer = self._generator_causal_layer(
            current_layer, current_state)

        # Add all defined dilation layers.
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):

                    q = tf.FIFOQueue(
                        dilation,
                        dtypes=tf.float32,
                        shapes=(self.batch_size, self.residual_channels))
                    init = q.enqueue_many(
                        tf.zeros((dilation, self.batch_size,
                                  self.residual_channels)))

                    current_state = q.dequeue()
                    push = q.enqueue([current_layer])
                    init_ops.append(init)
                    push_ops.append(push)

                    output, current_layer = self._generator_dilation_layer(
                        current_layer, current_state, layer_index, dilation)
                    outputs.append(output)

        self.init_ops = init_ops
        self.push_ops = push_ops

        with tf.name_scope('postprocessing'):
            variables = self.variables['postprocessing']
            # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
            # postprocess the output.
            w1 = variables['postprocess1']
            w2 = variables['postprocess2']
            if self.use_biases:
                b1 = variables['postprocess1_bias']
                b2 = variables['postprocess2_bias']

            # We skip connections from the outputs of each layer, adding them
            # all up here.
            total = sum(outputs)
            transformed1 = tf.nn.relu(total)

            conv1 = tf.matmul(transformed1, w1[0, :, :])
            if self.use_biases:
                conv1 = conv1 + b1
            transformed2 = tf.nn.relu(conv1)
            conv2 = tf.matmul(transformed2, w2[0, :, :])
            if self.use_biases:
                conv2 = conv2 + b2

        return conv2

    def _one_hot(self, input_batch):
        '''One-hot encodes the waveform amplitudes.

        This allows the definition of the network as a categorical
        distribution over a finite set of possible amplitudes.

        Returns:
            A float32 tensor reshaped to
            [batch_size, -1, quantization_channels].
        '''
        with tf.name_scope('one_hot_encode'):
            encoded = tf.one_hot(
                input_batch,
                depth=self.quantization_channels,
                dtype=tf.float32)
            shape = [self.batch_size, -1, self.quantization_channels]
            encoded = tf.reshape(encoded, shape)
        return encoded

    def _next_sample_proba(self, raw_output):
        '''Turn network logits into the distribution for the next sample.

        The logits are flattened to [-1, quantization_channels], soft-maxed,
        and the last row (the final time step of the last batch entry) is
        returned as a 1-D tensor of length quantization_channels.
        '''
        out = tf.reshape(raw_output, [-1, self.quantization_channels])
        # Cast to float64 to avoid bug in TensorFlow
        proba = tf.cast(
            tf.nn.softmax(tf.cast(out, tf.float64)), tf.float32)
        last = tf.slice(
            proba,
            [tf.shape(proba)[0] - 1, 0],
            [1, self.quantization_channels])
        return tf.reshape(last, [-1])

    def predict_proba(self, waveform, name='wavenet'):
        '''Computes the probability distribution of the next sample based on
        all samples in the input waveform.

        If you want to generate audio by feeding the output of the network
        back as an input, see predict_proba_incremental for a faster
        alternative.
        '''
        with tf.name_scope(name):
            if self.scalar_input:
                encoded = tf.cast(waveform, tf.float32)
                # The convolutions expect a 3-D [batch, time, channels]
                # tensor; use the same layout as the scalar path in loss().
                encoded = tf.reshape(encoded, [self.batch_size, -1, 1])
            else:
                encoded = self._one_hot(waveform)

            raw_output = self._create_network(encoded)
            return self._next_sample_proba(raw_output)

    def predict_proba_incremental(self, waveform, name='wavenet'):
        '''Computes the probability distribution of the next sample
        incrementally, based on a single sample and all previously passed
        samples.

        Raises:
            NotImplementedError: if filter_width > 2 or scalar_input is set.
        '''
        if self.filter_width > 2:
            raise NotImplementedError("Incremental generation does not "
                                      "support filter_width > 2.")
        if self.scalar_input:
            raise NotImplementedError("Incremental generation does not "
                                      "support scalar input yet.")
        with tf.name_scope(name):
            encoded = tf.one_hot(waveform, self.quantization_channels)
            encoded = tf.reshape(encoded, [-1, self.quantization_channels])
            raw_output = self._create_generator(encoded)
            return self._next_sample_proba(raw_output)

    def loss(self,
             input_batch,
             l2_regularization_strength=None,
             name='wavenet'):
        '''Creates a WaveNet network and returns the autoencoding loss.

        The variables are all scoped to the given name.

        Args:
            input_batch: Raw audio batch; it is mu-law encoded and quantized
                to quantization_channels levels before use.
            l2_regularization_strength: If not None, an L2 penalty on all
                trainable variables whose name does not contain 'bias' is
                added with this weight.
            name: Name scope for the created ops.
        Returns:
            The mean cross-entropy loss, plus the weighted L2 term when
            l2_regularization_strength is given.
        '''
        with tf.name_scope(name):
            # We mu-law encode and quantize the input audioform.
            input_batch = mu_law_encode(input_batch,
                                        self.quantization_channels)

            # The one-hot encoding is always needed as the prediction target,
            # even when the network itself is fed the scalar waveform.
            encoded = self._one_hot(input_batch)
            if self.scalar_input:
                network_input = tf.reshape(
                    tf.cast(input_batch, tf.float32),
                    [self.batch_size, -1, 1])
            else:
                network_input = encoded

            raw_output = self._create_network(network_input)

            with tf.name_scope('loss'):
                # Shift original input left by one sample, which means that
                # each output sample has to predict the next input sample.
                shifted = tf.slice(encoded, [0, 1, 0],
                                   [-1, tf.shape(encoded)[1] - 1, -1])
                shifted = tf.pad(shifted, [[0, 0], [0, 1], [0, 0]])

                prediction = tf.reshape(raw_output,
                                        [-1, self.quantization_channels])
                loss = tf.nn.softmax_cross_entropy_with_logits(
                    prediction,
                    tf.reshape(shifted, [-1, self.quantization_channels]))
                reduced_loss = tf.reduce_mean(loss)

                tf.scalar_summary('loss', reduced_loss)

                if l2_regularization_strength is None:
                    return reduced_loss

                # L2 regularization for all trainable parameters
                l2_loss = tf.add_n([tf.nn.l2_loss(v)
                                    for v in tf.trainable_variables()
                                    if 'bias' not in v.name])

                # Add the regularization term to the loss
                total_loss = (reduced_loss +
                              l2_regularization_strength * l2_loss)

                tf.scalar_summary('l2_loss', l2_loss)
                tf.scalar_summary('total_loss', total_loss)

                return total_loss