Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DCGAN overriding training step- Keras 3 Migration (Only Tensorflow Backend) #1693

Merged
merged 11 commits into from
Jan 4, 2024
Merged

Conversation

sineeli
Copy link
Collaborator

@sineeli sineeli commented Dec 21, 2023

DCGAN overriding training step- Keras 3 Migration (Only Tensorflow Backend) fix it - KerasCV-Fixit

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR!

@@ -154,10 +156,10 @@ def train_step(self, real_images):
)

# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
random_latent_vectors = keras.random.normal(shape=(batch_size, self.latent_dim))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critically, this needs to be seeded with a SeedGenerator instance attached to the model.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, made the required changes

import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile

seed_generator = keras.random.SeedGenerator(42)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seed generator must be attached to the layer. So you should do:

class GAN(keras.Model):
    def __init__(self, **kwargs):
        ...
        self.seed_generator = keras.random.SeedGenerator(1337)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it

@@ -185,9 +193,13 @@ class GANMonitor(keras.callbacks.Callback):
def __init__(self, num_img=3, latent_dim=128):
self.num_img = num_img
self.latent_dim = latent_dim
self.seed_generator = keras.random.SeedGenerator(42)
Copy link
Collaborator Author

@sineeli sineeli Dec 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope this is the correct way to declare for epoch end random number generation

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks good! Please regenerate the other files.

@sineeli
Copy link
Collaborator Author

sineeli commented Jan 2, 2024

Retrained and generated the required files, Thanks!

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@fchollet fchollet merged commit 3e68c0f into keras-team:master Jan 4, 2024
1 check passed
SuryanarayanaY pushed a commit to SuryanarayanaY/keras-io that referenced this pull request Jan 19, 2024
…ckend) (keras-team#1693)

* Migrate to Keras 3

* Keras 3 Migration

* Add seed generator to keras.random

* Train using seed generator

* Update dcgan_overriding_train_step.py

* Rgenerate files

* Migration to Keras 3

* Revert "Migration to Keras 3"

This reverts commit 7b003f5.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants