Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DCGAN overriding training step - Keras 3 Migration (Only TensorFlow Backend) #1693

Merged
merged 11 commits on Jan 4, 2024
44 changes: 27 additions & 17 deletions examples/generative/dcgan_overriding_train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@
Title: DCGAN to generate face images
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/04/29
Last modified: 2021/01/01
Last modified: 2023/12/21
Description: A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images.
Accelerator: GPU
"""
"""
## Setup
"""

import keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile

seed_generator = keras.random.SeedGenerator(42)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seed generator must be attached to the layer. So you should do:

class GAN(keras.Model):
    def __init__(self, **kwargs):
        ...
        self.seed_generator = keras.random.SeedGenerator(1337)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it


"""
## Prepare CelebA data

Expand Down Expand Up @@ -64,11 +68,11 @@
[
keras.Input(shape=(64, 64, 3)),
layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Flatten(),
layers.Dropout(0.2),
layers.Dense(1, activation="sigmoid"),
Expand All @@ -91,11 +95,11 @@
layers.Dense(8 * 8 * 128),
layers.Reshape((8, 8, 128)),
layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
],
name="generator",
Expand Down Expand Up @@ -128,18 +132,20 @@ def metrics(self):

def train_step(self, real_images):
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
batch_size = ops.shape(real_images)[0]
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=seed_generator
)

# Decode them to fake images
generated_images = self.generator(random_latent_vectors)

# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
combined_images = ops.concatenate([generated_images, real_images], axis=0)

# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
labels = ops.concatenate(
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(tf.shape(labels))
Expand All @@ -154,10 +160,12 @@ def train_step(self, real_images):
)

# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=seed_generator
)

# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
misleading_labels = ops.zeros((batch_size, 1))

# Train the generator (note that we should *not* update the weights
# of the discriminator)!
Expand Down Expand Up @@ -187,7 +195,9 @@ def __init__(self, num_img=3, latent_dim=128):
self.latent_dim = latent_dim

def on_epoch_end(self, epoch, logs=None):
random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
random_latent_vectors = keras.random.normal(
shape=(self.num_img, self.latent_dim), seed=seed_generator
)
generated_images = self.model.generator(random_latent_vectors)
generated_images *= 255
generated_images.numpy()
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 29 additions & 19 deletions examples/generative/ipynb/dcgan_overriding_train_step.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
"**Date created:** 2019/04/29<br>\n",
"**Last modified:** 2021/01/01<br>\n",
"**Last modified:** 2023/12/21<br>\n",
"**Description:** A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images."
]
},
Expand All @@ -31,13 +31,17 @@
},
"outputs": [],
"source": [
"import keras\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"from keras import layers\n",
"from keras import ops\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import gdown\n",
"from zipfile import ZipFile"
"from zipfile import ZipFile\n",
"\n",
"seed_generator = keras.random.SeedGenerator(42)"
]
},
{
Expand Down Expand Up @@ -141,11 +145,11 @@
" [\n",
" keras.Input(shape=(64, 64, 3)),\n",
" layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Flatten(),\n",
" layers.Dropout(0.2),\n",
" layers.Dense(1, activation=\"sigmoid\"),\n",
Expand Down Expand Up @@ -182,11 +186,11 @@
" layers.Dense(8 * 8 * 128),\n",
" layers.Reshape((8, 8, 128)),\n",
" layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding=\"same\"),\n",
" layers.LeakyReLU(alpha=0.2),\n",
" layers.LeakyReLU(negative_slope=0.2),\n",
" layers.Conv2D(3, kernel_size=5, padding=\"same\", activation=\"sigmoid\"),\n",
" ],\n",
" name=\"generator\",\n",
Expand Down Expand Up @@ -233,18 +237,20 @@
"\n",
" def train_step(self, real_images):\n",
" # Sample random points in the latent space\n",
" batch_size = tf.shape(real_images)[0]\n",
" random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
" batch_size = ops.shape(real_images)[0]\n",
" random_latent_vectors = keras.random.normal(\n",
" shape=(batch_size, self.latent_dim), seed=seed_generator\n",
" )\n",
"\n",
" # Decode them to fake images\n",
" generated_images = self.generator(random_latent_vectors)\n",
"\n",
" # Combine them with real images\n",
" combined_images = tf.concat([generated_images, real_images], axis=0)\n",
" combined_images = ops.concatenate([generated_images, real_images], axis=0)\n",
"\n",
" # Assemble labels discriminating real from fake images\n",
" labels = tf.concat(\n",
" [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n",
" labels = ops.concatenate(\n",
" [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0\n",
" )\n",
" # Add random noise to the labels - important trick!\n",
" labels += 0.05 * tf.random.uniform(tf.shape(labels))\n",
Expand All @@ -259,10 +265,12 @@
" )\n",
"\n",
" # Sample random points in the latent space\n",
" random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
" random_latent_vectors = keras.random.normal(\n",
" shape=(batch_size, self.latent_dim), seed=seed_generator\n",
" )\n",
"\n",
" # Assemble labels that say \"all real images\"\n",
" misleading_labels = tf.zeros((batch_size, 1))\n",
" misleading_labels = ops.zeros((batch_size, 1))\n",
"\n",
" # Train the generator (note that we should *not* update the weights\n",
" # of the discriminator)!\n",
Expand Down Expand Up @@ -306,7 +314,9 @@
" self.latent_dim = latent_dim\n",
"\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))\n",
" random_latent_vectors = keras.random.normal(\n",
" shape=(self.num_img, self.latent_dim), seed=seed_generator\n",
" )\n",
" generated_images = self.model.generator(random_latent_vectors)\n",
" generated_images *= 255\n",
" generated_images.numpy()\n",
Expand Down Expand Up @@ -389,4 +399,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading