Migrating SimSiam Example to Keras-3 (#1697)
* Migrating SimSiam Example to Keras-3

* Removed Seed and Updated Formatting
aditya02shah authored Dec 30, 2023
1 parent 44ee771 commit e3c97a6
Showing 3 changed files with 154 additions and 129 deletions.
118 changes: 62 additions & 56 deletions examples/vision/ipynb/simsiam.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Date created:** 2021/03/19<br>\n",
"**Last modified:** 2021/03/20<br>\n",
"**Last modified:** 2023/12/29<br>\n",
"**Description:** Implementation of a self-supervised learning method for computer vision."
]
},
@@ -52,9 +52,7 @@
"fully-connected network having an\n",
"[AutoEncoder](https://en.wikipedia.org/wiki/Autoencoder) like structure.\n",
"4. We then train our encoder to maximize the cosine similarity between the two different\n",
"versions of our dataset.\n",
"\n",
"This example requires TensorFlow 2.4 or higher."
"versions of our dataset.\n"
]
},
{
@@ -68,14 +66,19 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"from tensorflow.keras import layers\n",
"from tensorflow.keras import regularizers\n",
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"import keras\n",
"import keras_cv\n",
"from keras import ops\n",
"from keras import layers\n",
"from keras import regularizers\n",
"import tensorflow as tf\n",
"\n",
"import matplotlib.pyplot as plt\n",
@@ -93,7 +96,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -121,13 +124,13 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
"print(f\"Total training examples: {len(x_train)}\")\n",
"print(f\"Total test examples: {len(x_test)}\")"
]
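For reference, CIFAR-10's shapes are fixed, so a quick check of what load_data() returns (a sanity sketch, not part of the diff):

# Sanity sketch: expected CIFAR-10 array shapes from keras.datasets.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3) and y_train.shape == (50000, 1)
assert x_test.shape == (10000, 32, 32, 3) and y_test.shape == (10000, 1)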
@@ -151,43 +154,50 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"strength = [0.4, 0.4, 0.3, 0.1]\n",
"\n",
"random_flip = layers.RandomFlip(mode=\"horizontal_and_vertical\")\n",
"random_crop = layers.RandomCrop(CROP_TO, CROP_TO)\n",
"random_brightness = layers.RandomBrightness(0.8 * strength[0])\n",
"random_contrast = layers.RandomContrast((1 - 0.8 * strength[1], 1 + 0.8 * strength[1]))\n",
"random_saturation = keras_cv.layers.RandomSaturation(\n",
" (0.5 - 0.8 * strength[2], 0.5 + 0.8 * strength[2])\n",
")\n",
"random_hue = keras_cv.layers.RandomHue(0.2 * strength[3], [0,255])\n",
"grayscale = keras_cv.layers.Grayscale()\n",
"\n",
"def flip_random_crop(image):\n",
" # With random crops we also apply horizontal flipping.\n",
" image = tf.image.random_flip_left_right(image)\n",
" image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))\n",
" image = random_flip(image)\n",
" image = random_crop(image)\n",
" return image\n",
"\n",
"\n",
"def color_jitter(x, strength=[0.4, 0.4, 0.4, 0.1]):\n",
" x = tf.image.random_brightness(x, max_delta=0.8 * strength[0])\n",
" x = tf.image.random_contrast(\n",
" x, lower=1 - 0.8 * strength[1], upper=1 + 0.8 * strength[1]\n",
" )\n",
" x = tf.image.random_saturation(\n",
" x, lower=1 - 0.8 * strength[2], upper=1 + 0.8 * strength[2]\n",
" )\n",
" x = tf.image.random_hue(x, max_delta=0.2 * strength[3])\n",
"def color_jitter(x):\n",
" x = random_brightness(x)\n",
" x = random_contrast(x)\n",
" x = random_saturation(x)\n",
" x = random_hue(x)\n",
" # Affine transformations can disturb the natural range of\n",
" # RGB images, hence this is needed.\n",
" x = tf.clip_by_value(x, 0, 255)\n",
" x = ops.clip(x, 0, 255)\n",
" return x\n",
"\n",
"\n",
"def color_drop(x):\n",
" x = tf.image.rgb_to_grayscale(x)\n",
" x = tf.tile(x, [1, 1, 3])\n",
" x = grayscale(x)\n",
" x = ops.tile(x, [1, 1, 3])\n",
" return x\n",
"\n",
"\n",
"def random_apply(func, x, p):\n",
" if tf.random.uniform([], minval=0, maxval=1) < p:\n",
" if keras.random.uniform([], minval=0, maxval=1) < p:\n",
" return func(x)\n",
" else:\n",
" return x\n",
@@ -200,8 +210,7 @@
" image = flip_random_crop(image)\n",
" image = random_apply(color_jitter, image, p=0.8)\n",
" image = random_apply(color_drop, image, p=0.2)\n",
" return image\n",
""
" return image\n"
]
},
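The hunk above replaces per-call tf.image functions with Keras preprocessing layers and KerasCV layers built once at module level. As a hypothetical usage sketch (assuming custom_augment and x_train as defined in the notebook), calling the pipeline twice on the same image yields the two independently augmented views that SimSiam compares:

# Hypothetical sketch, not committed code: two augmented views of one image.
image = x_train[0]
view_one = custom_augment(image)
view_two = custom_augment(image)  # fresh random draws -> a different view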
{
@@ -231,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -303,7 +312,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -314,7 +323,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -347,11 +356,11 @@
" PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)\n",
" )(x)\n",
" outputs = layers.BatchNormalization()(x)\n",
" return tf.keras.Model(inputs, outputs, name=\"encoder\")\n",
" return keras.Model(inputs, outputs, name=\"encoder\")\n",
"\n",
"\n",
"def get_predictor():\n",
" model = tf.keras.Sequential(\n",
" model = keras.Sequential(\n",
" [\n",
" # Note the AutoEncoder-like structure.\n",
" layers.Input((PROJECT_DIM,)),\n",
@@ -366,8 +375,7 @@
" ],\n",
" name=\"predictor\",\n",
" )\n",
" return model\n",
""
" return model\n"
]
},
{
@@ -387,7 +395,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -398,13 +406,12 @@
" # The authors of SimSiam emphasize the impact of\n",
" # the `stop_gradient` operator in the paper as it\n",
" # has an important role in the overall optimization.\n",
" z = tf.stop_gradient(z)\n",
" p = tf.math.l2_normalize(p, axis=1)\n",
" z = tf.math.l2_normalize(z, axis=1)\n",
" z = ops.stop_gradient(z)\n",
" p = keras.utils.normalize(p, axis=1, order=2)\n",
" z = keras.utils.normalize(z, axis=1, order=2)\n",
" # Negative cosine similarity (minimizing this is\n",
" # equivalent to maximizing the similarity).\n",
" return -tf.reduce_mean(tf.reduce_sum((p * z), axis=1))\n",
""
" return -ops.mean(ops.sum((p * z), axis=1))\n"
]
},
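The rewritten compute_loss swaps tf.math.l2_normalize for keras.utils.normalize, which performs the same row-wise L2 normalization. A NumPy sanity sketch of the negative cosine similarity being returned (an illustration, not committed code):

# Negative mean cosine similarity between predictions p and
# stop-gradient'ed projections z; minimizing it maximizes similarity.
import numpy as np

p = np.random.rand(4, 8).astype("float32")  # stand-in predictor outputs
z = np.random.rand(4, 8).astype("float32")  # stand-in projections

p_n = p / np.linalg.norm(p, axis=1, keepdims=True)  # row-wise L2 normalize
z_n = z / np.linalg.norm(z, axis=1, keepdims=True)
loss = -np.mean(np.sum(p_n * z_n, axis=1))  # in [-1, 0] for non-negative inputs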
{
@@ -414,24 +421,24 @@
},
"source": [
"We then define our training loop by overriding the `train_step()` function of the\n",
"`tf.keras.Model` class."
"`keras.Model` class."
]
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class SimSiam(tf.keras.Model):\n",
"class SimSiam(keras.Model):\n",
" def __init__(self, encoder, predictor):\n",
" super().__init__()\n",
" self.encoder = encoder\n",
" self.predictor = predictor\n",
" self.loss_tracker = tf.keras.metrics.Mean(name=\"loss\")\n",
" self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n",
"\n",
" @property\n",
" def metrics(self):\n",
@@ -459,8 +466,7 @@
"\n",
" # Monitor loss.\n",
" self.loss_tracker.update_state(loss)\n",
" return {\"loss\": self.loss_tracker.result()}\n",
""
" return {\"loss\": self.loss_tracker.result()}\n"
]
},
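The diff collapses most of the train_step body. As a hedged reconstruction (assuming the TensorFlow backend and the notebook's compute_loss(p, z) helper; treat the details as a sketch, not the committed code), the method computes the symmetrized SimSiam loss and updates both sub-networks:

# Sketch of the collapsed method of the SimSiam class above.
def train_step(self, data):
    ds_one, ds_two = data  # the two augmented views of the batch
    with tf.GradientTape() as tape:
        z1, z2 = self.encoder(ds_one), self.encoder(ds_two)
        p1, p2 = self.predictor(z1), self.predictor(z2)
        # Symmetrized loss: each prediction is matched against the
        # other view's (stop-gradient'ed) projection.
        loss = compute_loss(p1, z2) / 2 + compute_loss(p2, z1) / 2
    params = self.encoder.trainable_variables + self.predictor.trainable_variables
    grads = tape.gradient(loss, params)
    self.optimizer.apply_gradients(zip(grads, params))
    self.loss_tracker.update_state(loss)
    return {"loss": self.loss_tracker.result()}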
{
@@ -477,7 +483,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -486,18 +492,18 @@
"# Create a cosine decay learning scheduler.\n",
"num_training_samples = len(x_train)\n",
"steps = EPOCHS * (num_training_samples // BATCH_SIZE)\n",
"lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(\n",
"lr_decayed_fn = keras.optimizers.schedules.CosineDecay(\n",
" initial_learning_rate=0.03, decay_steps=steps\n",
")\n",
"\n",
"# Create an early stopping callback.\n",
"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"early_stopping = keras.callbacks.EarlyStopping(\n",
" monitor=\"loss\", patience=5, restore_best_weights=True\n",
")\n",
"\n",
"# Compile model and start training.\n",
"simsiam = SimSiam(get_encoder(), get_predictor())\n",
"simsiam.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))\n",
"simsiam.compile(optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))\n",
"history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])\n",
"\n",
"# Visualize the training progress of the model.\n",
@@ -544,7 +550,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -566,7 +572,7 @@
"test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"# Extract the backbone ResNet20.\n",
"backbone = tf.keras.Model(\n",
"backbone = keras.Model(\n",
" simsiam.encoder.input, simsiam.encoder.get_layer(\"backbone_pool\").output\n",
")\n",
"\n",
@@ -575,13 +581,13 @@
"inputs = layers.Input((CROP_TO, CROP_TO, 3))\n",
"x = backbone(inputs, training=False)\n",
"outputs = layers.Dense(10, activation=\"softmax\")(x)\n",
"linear_model = tf.keras.Model(inputs, outputs, name=\"linear_model\")\n",
"linear_model = keras.Model(inputs, outputs, name=\"linear_model\")\n",
"\n",
"# Compile model and start training.\n",
"linear_model.compile(\n",
" loss=\"sparse_categorical_crossentropy\",\n",
" metrics=[\"accuracy\"],\n",
" optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),\n",
" optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),\n",
")\n",
"history = linear_model.fit(\n",
" train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]\n",
@@ -644,4 +650,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
- }
+ }