Skip to content

Commit

Permalink
Update InfoGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibault de Boissiere committed Oct 30, 2017
1 parent f83b1f2 commit bf9036b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 61 deletions.
10 changes: 5 additions & 5 deletions InfoGAN/src/model/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def eval(**kwargs):
batch_size = kwargs["batch_size"]
generator = kwargs["generator"]
model_name = kwargs["model_name"]
image_dim_ordering = kwargs["image_dim_ordering"]
image_data_format = kwargs["image_data_format"]
img_dim = kwargs["img_dim"]
cont_dim = (kwargs["cont_dim"],)
cat_dim = (kwargs["cat_dim"],)
Expand All @@ -29,9 +29,9 @@ def eval(**kwargs):

# Load and rescale data
if dset == "RGZ":
X_real_train = data_utils.load_RGZ(img_dim, image_dim_ordering)
X_real_train = data_utils.load_RGZ(img_dim, image_data_format)
if dset == "mnist":
X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
img_dim = X_real_train.shape[-3:]

# Load generator model
Expand Down Expand Up @@ -60,7 +60,7 @@ def eval(**kwargs):
X_gen = generator_model.predict([X_cat, X_cont, X_noise])
X_gen = data_utils.inverse_normalization(X_gen)

if image_dim_ordering == "th":
if image_data_format == "channels_first":
X_gen = X_gen.transpose(0,2,3,1)

X_gen = [X_gen[i] for i in range(len(X_gen))]
Expand Down Expand Up @@ -97,7 +97,7 @@ def eval(**kwargs):

X_gen = generator_model.predict([X_cat, X_cont, X_noise])
X_gen = data_utils.inverse_normalization(X_gen)
if image_dim_ordering == "th":
if image_data_format == "channels_first":
X_gen = X_gen.transpose(0,2,3,1)
X_gen = [X_gen[i] for i in range(len(X_gen))]
X_plot.append(np.concatenate(X_gen, axis=1))
Expand Down
10 changes: 5 additions & 5 deletions InfoGAN/src/model/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def launch_eval(**kwargs):

# Manually set the image data format (dim ordering), otherwise it is not changed
if args.backend == "theano":
image_dim_ordering = "th"
K.set_image_dim_ordering(image_dim_ordering)
image_data_format = "channels_first"
K.set_image_data_format(image_data_format)
elif args.backend == "tensorflow":
image_dim_ordering = "tf"
K.set_image_dim_ordering(image_dim_ordering)
image_data_format = "channels_last"
K.set_image_data_format(image_data_format)

import train
import eval
Expand All @@ -74,7 +74,7 @@ def launch_eval(**kwargs):
"epoch": args.epoch,
"nb_classes": args.nb_classes,
"do_plot": args.do_plot,
"image_dim_ordering": image_dim_ordering,
"image_data_format": image_data_format,
"bn_mode": args.bn_mode,
"img_dim": args.img_dim,
"noise_dim": args.noise_dim,
Expand Down
70 changes: 36 additions & 34 deletions InfoGAN/src/model/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from keras.models import Model
from keras.layers.core import Flatten, Dense, Dropout, Activation, Lambda, Reshape
from keras.layers.convolutional import Convolution2D, Deconvolution2D, ZeroPadding2D, UpSampling2D
from keras.layers.convolutional import Conv2D, Deconv2D, ZeroPadding2D, UpSampling2D
from keras.layers import Input, merge
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
Expand Down Expand Up @@ -44,11 +44,11 @@ def generator_upsampling(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, model_n
gen_input = merge([cat_input, cont_input, noise_input], mode="concat")

x = Dense(1024)(gen_input)
x = BatchNormalization(mode=1)(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

x = Dense(f * start_dim * start_dim)(x)
x = BatchNormalization(mode=1)(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

x = Reshape(reshape_shape)(x)
Expand All @@ -57,16 +57,16 @@ def generator_upsampling(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, model_n
for i in range(nb_upconv):
x = UpSampling2D(size=(2, 2))(x)
nb_filters = int(f / (2 ** (i + 1)))
x = Convolution2D(nb_filters, 3, 3, border_mode="same")(x)
x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
x = Conv2D(nb_filters, (3, 3), padding="same")(x)
x = BatchNormalization(axis=bn_axis)(x)
x = Activation("relu")(x)
# x = Convolution2D(nb_filters, 3, 3, border_mode="same")(x)
# x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
# x = Conv2D(nb_filters, (3, 3), padding="same")(x)
# x = BatchNormalization(axis=bn_axis)(x)
# x = Activation("relu")(x)

x = Convolution2D(output_channels, 3, 3, name="gen_convolution2d_final", border_mode="same", activation='tanh')(x)
x = Conv2D(output_channels, (3, 3), name="gen_Conv2D_final", padding="same", activation='tanh')(x)

generator_model = Model(input=[cat_input, cont_input, noise_input], output=[x], name=model_name)
generator_model = Model(inputs=[cat_input, cont_input, noise_input], outputs=[x], name=model_name)

return generator_model

Expand Down Expand Up @@ -105,11 +105,11 @@ def generator_deconv(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, batch_size,
gen_input = merge([cat_input, cont_input, noise_input], mode="concat")

x = Dense(1024)(gen_input)
x = BatchNormalization(mode=1)(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

x = Dense(f * start_dim * start_dim)(x)
x = BatchNormalization(mode=1)(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

x = Reshape(reshape_shape)(x)
Expand All @@ -119,17 +119,17 @@ def generator_deconv(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, batch_size,
nb_filters = int(f / (2 ** (i + 1)))
s = start_dim * (2 ** (i + 1))
o_shape = (batch_size, s, s, nb_filters)
x = Deconvolution2D(nb_filters, 3, 3, output_shape=o_shape, subsample=(2, 2), border_mode="same")(x)
x = Deconv2D(nb_filters, (3, 3), output_shape=o_shape, strides=(2, 2), padding="same")(x)
x = BatchNormalization(mode=2, axis=bn_axis)(x)
x = Activation("relu")(x)

# Last block
s = start_dim * (2 ** (nb_upconv))
o_shape = (batch_size, s, s, output_channels)
x = Deconvolution2D(output_channels, 3, 3, output_shape=o_shape, subsample=(2, 2), border_mode="same")(x)
x = Deconv2D(output_channels, (3, 3), output_shape=o_shape, strides=(2, 2), padding="same")(x)
x = Activation("tanh")(x)

generator_model = Model(input=[cat_input, cont_input, noise_input], output=[x], name=model_name)
generator_model = Model(inputs=[cat_input, cont_input, noise_input], outputs=[x], name=model_name)

return generator_model

Expand Down Expand Up @@ -158,19 +158,19 @@ def DCGAN_discriminator(cat_dim, cont_dim, img_dim, bn_mode, model_name="DCGAN_d
list_f = [64, 128, 256]

# First conv
x = Convolution2D(64, 3, 3, subsample=(2, 2), name="disc_convolution2d_1", border_mode="same")(disc_input)
x = Conv2D(64, (3, 3), strides=(2, 2), name="disc_Conv2D_1", padding="same")(disc_input)
x = LeakyReLU(0.2)(x)

# Next convs
for i, f in enumerate(list_f):
name = "disc_convolution2d_%s" % (i + 2)
x = Convolution2D(f, 3, 3, subsample=(2, 2), name=name, border_mode="same")(x)
x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
name = "disc_Conv2D_%s" % (i + 2)
x = Conv2D(f, (3, 3), strides=(2, 2), name=name, padding="same")(x)
x = BatchNormalization(axis=bn_axis)(x)
x = LeakyReLU(0.2)(x)

x = Flatten()(x)
x = Dense(1024)(x)
x = BatchNormalization(mode=1)(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)

def linmax(x):
Expand All @@ -181,7 +181,7 @@ def linmax_shape(input_shape):

# More processing for auxiliary Q
x_Q = Dense(128)(x)
x_Q = BatchNormalization(mode=1)(x_Q)
x_Q = BatchNormalization()(x_Q)
x_Q = LeakyReLU(0.2)(x_Q)
x_Q_Y = Dense(cat_dim[0], activation='softmax', name="Q_cat_out")(x_Q)
x_Q_C_mean = Dense(cont_dim[0], activation='linear', name="dense_Q_cont_mean")(x_Q)
Expand All @@ -205,7 +205,7 @@ def lambda_output(input_shape):
num_kernels = 300
dim_per_kernel = 5

M = Dense(num_kernels * dim_per_kernel, bias=False, activation=None)
M = Dense(num_kernels * dim_per_kernel, use_bias=False, activation=None)
MBD = Lambda(minb_disc, output_shape=lambda_output)

if use_mbd:
Expand All @@ -216,7 +216,7 @@ def lambda_output(input_shape):

# Create discriminator model
x_disc = Dense(2, activation='softmax', name="disc_out")(x)
discriminator_model = Model(input=[disc_input], output=[x_disc, x_Q_Y, x_Q_C], name=model_name)
discriminator_model = Model(inputs=[disc_input], outputs=[x_disc, x_Q_Y, x_Q_C], name=model_name)

return discriminator_model

Expand All @@ -230,8 +230,8 @@ def DCGAN(generator, discriminator_model, cat_dim, cont_dim, noise_dim):
generated_image = generator([cat_input, cont_input, noise_input])
x_disc, x_Q_Y, x_Q_C = discriminator_model(generated_image)

DCGAN = Model(input=[cat_input, cont_input, noise_input],
output=[x_disc, x_Q_Y, x_Q_C],
DCGAN = Model(inputs=[cat_input, cont_input, noise_input],
outputs=[x_disc, x_Q_Y, x_Q_C],
name="DCGAN")

return DCGAN
Expand All @@ -241,21 +241,23 @@ def load(model_name, cat_dim, cont_dim, noise_dim, img_dim, bn_mode, batch_size,

if model_name == "generator_upsampling":
model = generator_upsampling(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, model_name=model_name, dset=dset)
print model.summary()
from keras.utils.visualize_util import plot
plot(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
return model
if model_name == "generator_deconv":
model = generator_deconv(cat_dim, cont_dim, noise_dim, img_dim, bn_mode, batch_size, model_name=model_name, dset=dset)
print model.summary()
from keras.utils.visualize_util import plot
plot(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
model = generator_deconv(cat_dim, cont_dim, noise_dim, img_dim, bn_mode,
batch_size, model_name=model_name, dset=dset)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
return model
if model_name == "DCGAN_discriminator":
model = DCGAN_discriminator(cat_dim, cont_dim, img_dim, bn_mode, model_name=model_name, dset=dset, use_mbd=use_mbd)
model = DCGAN_discriminator(cat_dim, cont_dim, img_dim, bn_mode,
model_name=model_name, dset=dset, use_mbd=use_mbd)
model.summary()
from keras.utils.visualize_util import plot
plot(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
from keras.utils import plot_model
plot_model(model, to_file='../../figures/%s.png' % model_name, show_shapes=True, show_layer_names=True)
return model


Expand Down
8 changes: 4 additions & 4 deletions InfoGAN/src/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train(**kwargs):
nb_epoch = kwargs["nb_epoch"]
generator = kwargs["generator"]
model_name = kwargs["model_name"]
image_dim_ordering = kwargs["image_dim_ordering"]
image_data_format = kwargs["image_data_format"]
img_dim = kwargs["img_dim"]
cont_dim = (kwargs["cont_dim"],)
cat_dim = (kwargs["cat_dim"],)
Expand All @@ -58,9 +58,9 @@ def train(**kwargs):

# Load and rescale data
if dset == "celebA":
X_real_train = data_utils.load_celebA(img_dim, image_dim_ordering)
X_real_train = data_utils.load_celebA(img_dim, image_data_format)
if dset == "mnist":
X_real_train, _, _, _ = data_utils.load_mnist(image_dim_ordering)
X_real_train, _, _, _ = data_utils.load_mnist(image_data_format)
img_dim = X_real_train.shape[-3:]

try:
Expand Down Expand Up @@ -162,7 +162,7 @@ def train(**kwargs):
# Save images for visualization
if batch_counter % (n_batch_per_epoch / 2) == 0:
data_utils.plot_generated_batch(X_real_batch, generator_model,
batch_size, cat_dim, cont_dim, noise_dim, image_dim_ordering)
batch_size, cat_dim, cont_dim, noise_dim, image_data_format)

if batch_counter >= n_batch_per_epoch:
break
Expand Down
26 changes: 13 additions & 13 deletions InfoGAN/src/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def inverse_normalization(X):
return (X + 1.) / 2.


def load_mnist(image_dim_ordering):
def load_mnist(image_data_format):

(X_train, y_train), (X_test, y_test) = mnist.load_data()

if image_dim_ordering == 'th':
if image_data_format == "channels_first":
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
else:
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")

X_train = normalization(X_train)
X_test = normalization(X_test)
Expand All @@ -38,19 +38,19 @@ def load_mnist(image_dim_ordering):
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

print X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)

return X_train, Y_train, X_test, Y_test


def load_celebA(img_dim, image_dim_ordering):
def load_celebA(img_dim, image_data_format):

with h5py.File("../../data/processed/CelebA_%s_data.h5" % img_dim, "r") as hf:

X_real_train = hf["data"][:].astype(np.float32)
X_real_train = normalization(X_real_train)

if image_dim_ordering == "tf":
if image_data_format == "channels_last":
X_real_train = X_real_train.transpose(0, 2, 3, 1)

return X_real_train
Expand All @@ -70,7 +70,7 @@ def sample_noise(noise_scale, batch_size, noise_dim):

def sample_cat(batch_size, cat_dim):

y = np.zeros((batch_size, cat_dim[0]), dtype='float32')
y = np.zeros((batch_size, cat_dim[0]), dtype="float32")
random_y = np.random.randint(0, cat_dim[0], size=batch_size)
y[np.arange(batch_size), random_y] = 1

Expand Down Expand Up @@ -111,7 +111,7 @@ def get_disc_batch(X_real_batch, generator_model, batch_counter, batch_size, cat
if p > 0:
y_disc[:, [0, 1]] = y_disc[:, [1, 0]]

# Repeat y_cont to accomodate for keras' loss function conventions
# Repeat y_cont to accommodate Keras' loss function conventions
y_cont = np.expand_dims(y_cont, 1)
y_cont = np.repeat(y_cont, 2, axis=1)

Expand All @@ -127,14 +127,14 @@ def get_gen_batch(batch_size, cat_dim, cont_dim, noise_dim, noise_scale=0.5):
y_cat = sample_cat(batch_size, cat_dim)
y_cont = sample_noise(noise_scale, batch_size, cont_dim)

# Repeat y_cont to accomodate for keras' loss function conventions
# Repeat y_cont to accommodate Keras' loss function conventions
y_cont_target = np.expand_dims(y_cont, 1)
y_cont_target = np.repeat(y_cont_target, 2, axis=1)

return X_gen, y_gen, y_cat, y_cont, y_cont_target


def plot_generated_batch(X_real, generator_model, batch_size, cat_dim, cont_dim, noise_dim, image_dim_ordering, noise_scale=0.5):
def plot_generated_batch(X_real, generator_model, batch_size, cat_dim, cont_dim, noise_dim, image_data_format, noise_scale=0.5):

# Generate images
y_cat = sample_cat(batch_size, cat_dim)
Expand All @@ -149,7 +149,7 @@ def plot_generated_batch(X_real, generator_model, batch_size, cat_dim, cont_dim,
Xg = X_gen[:8]
Xr = X_real[:8]

if image_dim_ordering == "tf":
if image_data_format == "channels_last":
X = np.concatenate((Xg, Xr), axis=0)
list_rows = []
for i in range(int(X.shape[0] / 4)):
Expand All @@ -158,7 +158,7 @@ def plot_generated_batch(X_real, generator_model, batch_size, cat_dim, cont_dim,

Xr = np.concatenate(list_rows, axis=0)

if image_dim_ordering == "th":
if image_data_format == "channels_first":
X = np.concatenate((Xg, Xr), axis=0)
list_rows = []
for i in range(int(X.shape[0] / 4)):
Expand Down

0 comments on commit bf9036b

Please sign in to comment.