# train_variational_autoencoder_tensorflow.py
# Forked from jaanli/variational-autoencoder.
import itertools
import matplotlib as mpl
import numpy as np
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import time
import seaborn as sns
from matplotlib import pyplot as plt
from imageio import imwrite
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
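
# NOTE: this script uses TensorFlow 1.x APIs (tf.contrib.slim,
# tf.distributions, tf.app.flags, tf.placeholder) and will not run on
# TensorFlow 2.x without the tf.compat.v1 compatibility shims.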
sns.set_style('whitegrid')
distributions = tf.distributions
flags = tf.app.flags
flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data')
flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs')
# For making plots:
# flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
# flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
# flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
# flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
# flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
# flags.DEFINE_integer('n_iterations', 1000, 'Number of iterations')
# For bigger model:
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
flags.DEFINE_integer('n_samples', 1, 'Number of samples to save')
flags.DEFINE_integer('print_every', 1000, 'Print every n iterations')
flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
flags.DEFINE_integer('n_iterations', 100000, 'Number of iterations')
FLAGS = flags.FLAGS
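
# When FLAGS.latent_dim == 2 (the commented "plots" preset above), train()
# additionally renders 2-D posterior scatter plots and a latent-space
# manifold, and stitches the saved frames into GIFs at the end of training.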


def inference_network(x, latent_dim, hidden_size):
  """Construct an inference network parametrizing a Gaussian.

  Args:
    x: A batch of MNIST digits.
    latent_dim: The latent dimensionality.
    hidden_size: The size of the neural net hidden layers.

  Returns:
    mu: Mean parameters for the variational family Normal.
    sigma: Standard deviation parameters for the variational family Normal.
  """
  with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
    net = slim.flatten(x)
    net = slim.fully_connected(net, hidden_size)
    net = slim.fully_connected(net, hidden_size)
    gaussian_params = slim.fully_connected(
        net, latent_dim * 2, activation_fn=None)
    # The mean parameter is unconstrained
    mu = gaussian_params[:, :latent_dim]
    # The standard deviation must be positive, so parametrize it with a
    # softplus: sigma = log(1 + exp(raw)), mapping the real line to (0, inf)
    sigma = tf.nn.softplus(gaussian_params[:, latent_dim:])
  return mu, sigma
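
# Shape sketch (assuming the default flags above): for x of shape
# [batch_size, 28, 28, 1], inference_network returns mu and sigma, each of
# shape [batch_size, latent_dim], e.g. [64, 100] for the bigger model.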


def generative_network(z, hidden_size):
  """Build a generative network parametrizing the likelihood of the data.

  Args:
    z: Samples of latent variables.
    hidden_size: Size of the hidden state of the neural net.

  Returns:
    bernoulli_logits: Logits for the Bernoulli likelihood of the data.
  """
  with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
    net = slim.fully_connected(z, hidden_size)
    net = slim.fully_connected(net, hidden_size)
    bernoulli_logits = slim.fully_connected(net, 784, activation_fn=None)
    bernoulli_logits = tf.reshape(bernoulli_logits, [-1, 28, 28, 1])
  return bernoulli_logits
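
# Under this likelihood each pixel is an independent Bernoulli, so for
# binarized data x in {0, 1} the log-likelihood is
#   log p(x | z) = sum_d [ x_d * log sigmoid(l_d)
#                          + (1 - x_d) * log(1 - sigmoid(l_d)) ],
# where l_d are the logits returned above. p_x_given_z.log_prob(x) in
# train() computes the per-pixel terms, which are then summed over the
# pixel axes.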


def train():
  """Train a Variational Autoencoder on MNIST."""
  # Input placeholders
  with tf.name_scope('data'):
    x = tf.placeholder(tf.float32, [None, 28, 28, 1])
    tf.summary.image('data', x)

  with tf.variable_scope('variational'):
    q_mu, q_sigma = inference_network(x=x,
                                      latent_dim=FLAGS.latent_dim,
                                      hidden_size=FLAGS.hidden_size)
    # The variational distribution is a Normal with mean and standard
    # deviation given by the inference network
    q_z = distributions.Normal(loc=q_mu, scale=q_sigma)
    assert q_z.reparameterization_type == distributions.FULLY_REPARAMETERIZED
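
  # FULLY_REPARAMETERIZED means q_z.sample() is differentiable with respect
  # to (q_mu, q_sigma): internally the draw is z = q_mu + q_sigma * eps with
  # eps ~ Normal(0, I), so gradients of the ELBO flow through the sampler
  # (the reparameterization trick of Kingma & Welling, 2014).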

  with tf.variable_scope('model'):
    # The likelihood is Bernoulli-distributed with logits given by the
    # generative network
    p_x_given_z_logits = generative_network(z=q_z.sample(),
                                            hidden_size=FLAGS.hidden_size)
    p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits)
    posterior_predictive_samples = p_x_given_z.sample()
    tf.summary.image('posterior_predictive',
                     tf.cast(posterior_predictive_samples, tf.float32))

  # Take samples from the prior
  with tf.variable_scope('model', reuse=True):
    p_z = distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
                               scale=np.ones(FLAGS.latent_dim, dtype=np.float32))
    p_z_sample = p_z.sample(FLAGS.n_samples)
    p_x_given_z_logits = generative_network(z=p_z_sample,
                                            hidden_size=FLAGS.hidden_size)
    prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_samples = prior_predictive.sample()
    tf.summary.image('prior_predictive',
                     tf.cast(prior_predictive_samples, tf.float32))

  # Take samples from the prior with a placeholder
  with tf.variable_scope('model', reuse=True):
    z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim])
    p_x_given_z_logits = generative_network(z=z_input,
                                            hidden_size=FLAGS.hidden_size)
    prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_inp_sample = prior_predictive_inp.sample()
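
  # All three generative_network calls live in the 'model' variable scope;
  # reuse=True makes the second and third calls share the weights created by
  # the first, so the decoder that reconstructs data is the same network that
  # decodes prior samples and the z_input placeholder.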

  # Build the evidence lower bound (ELBO), the negative of the loss
  kl = tf.reduce_sum(distributions.kl_divergence(q_z, p_z), 1)
  expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x),
                                          [1, 2, 3])
  elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
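
  # Per example, ELBO(x) = E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z)).
  # The expectation is estimated with the single z ~ q(z|x) drawn above, the
  # KL between two Gaussians is available in closed form, and the per-example
  # ELBOs are summed over the minibatch.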

  optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
  train_op = optimizer.minimize(-elbo)

  # Merge all the summaries
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  # Run training
  sess = tf.InteractiveSession()
  sess.run(init_op)
  mnist = read_data_sets(FLAGS.data_dir, one_hot=True)
  print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
  train_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

  # Get fixed MNIST digits for plotting posterior means during training
  np_x_fixed, np_y = mnist.test.next_batch(5000)
  np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1)
  np_x_fixed = (np_x_fixed > 0.5).astype(np.float32)

  t0 = time.time()
  for i in range(FLAGS.n_iterations):
    # Re-binarize the data at every batch; this improves results
    np_x, _ = mnist.train.next_batch(FLAGS.batch_size)
    np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1)
    np_x = (np_x > 0.5).astype(np.float32)
    sess.run(train_op, {x: np_x})

    # Print progress and save samples every so often
    if i % FLAGS.print_every == 0:
      np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x})
      train_writer.add_summary(summary_str, i)
      print('Iteration: {0:d} ELBO: {1:.3f} s/iter: {2:.3e}'.format(
          i,
          np_elbo / FLAGS.batch_size,
          (time.time() - t0) / FLAGS.print_every))
      t0 = time.time()

      # Save samples
      np_posterior_samples, np_prior_samples = sess.run(
          [posterior_predictive_samples, prior_predictive_samples], {x: np_x})
      for k in range(FLAGS.n_samples):
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k))
        imwrite(f_name, np_x[k, :, :, 0])
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k))
        imwrite(f_name, np_posterior_samples[k, :, :, 0])
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
        imwrite(f_name, np_prior_samples[k, :, :, 0])

      # Plot the posterior means in latent space (only in two dimensions)
      if FLAGS.latent_dim == 2:
        np_q_mu = sess.run(q_mu, {x: np_x_fixed})
        cmap = mpl.colors.ListedColormap(sns.color_palette("husl"))
        f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6))
        im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1),
                        cmap=cmap, alpha=0.7)
        ax.set_xlabel('Posterior mean of first latent dimension $z_1$')
        ax.set_ylabel('Posterior mean of second latent dimension $z_2$')
        ax.set_xlim([-10., 10.])
        ax.set_ylim([-10., 10.])
        f.colorbar(im, ax=ax, label='Digit class')
        plt.tight_layout()
        plt.savefig(os.path.join(FLAGS.logdir,
                                 'posterior_predictive_map_frame_%d.png' % i))
        plt.close()

        nx = ny = 20
        x_values = np.linspace(-3, 3, nx)
        y_values = np.linspace(-3, 3, ny)
        canvas = np.empty((28 * ny, 28 * nx))
        # Decode a grid of 2-D latent codes to visualize the learned manifold
        for row, yi in enumerate(y_values):
          for col, xi in enumerate(x_values):
            np_z = np.array([[xi, yi]])
            # Despite the name, x_mean is a binarized sample from the
            # Bernoulli likelihood at this latent code
            x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z})
            canvas[(ny - row - 1) * 28:(ny - row) * 28,
                   col * 28:(col + 1) * 28] = x_mean[0].reshape(28, 28)
        imwrite(os.path.join(FLAGS.logdir,
                             'prior_predictive_map_frame_%d.png' % i), canvas)
        # plt.figure(figsize=(8, 10))
        # Xi, Yi = np.meshgrid(x_values, y_values)
        # plt.imshow(canvas, origin="upper")
        # plt.tight_layout()
        # plt.savefig()

  # Make the gifs (requires ImageMagick's `convert` on the PATH)
  if FLAGS.latent_dim == 2:
    os.system(
        'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png '
        '{0}/posterior_predictive.gif'.format(FLAGS.logdir))
    os.system(
        'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png '
        '{0}/prior_predictive.gif'.format(FLAGS.logdir))


def main(_):
  if tf.gfile.Exists(FLAGS.logdir):
    tf.gfile.DeleteRecursively(FLAGS.logdir)
  tf.gfile.MakeDirs(FLAGS.logdir)
  train()


if __name__ == '__main__':
  tf.app.run()