diff --git a/examples/vision/ipynb/simsiam.ipynb b/examples/vision/ipynb/simsiam.ipynb
index c5811fa09a..4cdda42861 100644
--- a/examples/vision/ipynb/simsiam.ipynb
+++ b/examples/vision/ipynb/simsiam.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)
\n",
"**Date created:** 2021/03/19
\n",
- "**Last modified:** 2021/03/20
\n",
+ "**Last modified:** 2023/12/29
\n",
"**Description:** Implementation of a self-supervised learning method for computer vision."
]
},
@@ -52,9 +52,7 @@
"fully-connected network having an\n",
"[AutoEncoder](https://en.wikipedia.org/wiki/Autoencoder) like structure.\n",
"4. We then train our encoder to maximize the cosine similarity between the two different\n",
- "versions of our dataset.\n",
- "\n",
- "This example requires TensorFlow 2.4 or higher."
+ "versions of our dataset.\n"
]
},
{
@@ -68,14 +66,19 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
- "from tensorflow.keras import layers\n",
- "from tensorflow.keras import regularizers\n",
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "import keras\n",
+ "import keras_cv\n",
+ "from keras import ops\n",
+ "from keras import layers\n",
+ "from keras import regularizers\n",
"import tensorflow as tf\n",
"\n",
"import matplotlib.pyplot as plt\n",
@@ -93,7 +96,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -121,13 +124,13 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
- "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
+ "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
"print(f\"Total training examples: {len(x_train)}\")\n",
"print(f\"Total test examples: {len(x_test)}\")"
]
@@ -151,43 +154,50 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
+ "strength = [0.4, 0.4, 0.3, 0.1]  # brightness, contrast, saturation, hue\n",
+ "\n",
+ "random_flip = layers.RandomFlip(mode=\"horizontal\")  # left-right only\n",
+ "random_crop = layers.RandomCrop(CROP_TO, CROP_TO)\n",
+ "random_brightness = layers.RandomBrightness(0.8 * strength[0])\n",
+ "random_contrast = layers.RandomContrast(0.8 * strength[1])  # scale in [1 - f, 1 + f]\n",
+ "random_saturation = keras_cv.layers.RandomSaturation(\n",
+ "    (0.5 - 0.8 * strength[2], 0.5 + 0.8 * strength[2])\n",
+ ")\n",
+ "random_hue = keras_cv.layers.RandomHue(0.2 * strength[3], [0, 255])\n",
+ "grayscale = keras_cv.layers.Grayscale()\n",
"\n",
"def flip_random_crop(image):\n",
" # With random crops we also apply horizontal flipping.\n",
- " image = tf.image.random_flip_left_right(image)\n",
- " image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))\n",
+ " image = random_flip(image)\n",
+ " image = random_crop(image)\n",
" return image\n",
"\n",
"\n",
- "def color_jitter(x, strength=[0.4, 0.4, 0.4, 0.1]):\n",
- " x = tf.image.random_brightness(x, max_delta=0.8 * strength[0])\n",
- " x = tf.image.random_contrast(\n",
- " x, lower=1 - 0.8 * strength[1], upper=1 + 0.8 * strength[1]\n",
- " )\n",
- " x = tf.image.random_saturation(\n",
- " x, lower=1 - 0.8 * strength[2], upper=1 + 0.8 * strength[2]\n",
- " )\n",
- " x = tf.image.random_hue(x, max_delta=0.2 * strength[3])\n",
+ "def color_jitter(x):\n",
+ " x = random_brightness(x)\n",
+ " x = random_contrast(x)\n",
+ " x = random_saturation(x)\n",
+ " x = random_hue(x)\n",
" # Affine transformations can disturb the natural range of\n",
" # RGB images, hence this is needed.\n",
- " x = tf.clip_by_value(x, 0, 255)\n",
+ " x = ops.clip(x, 0, 255)\n",
" return x\n",
"\n",
"\n",
"def color_drop(x):\n",
- " x = tf.image.rgb_to_grayscale(x)\n",
- " x = tf.tile(x, [1, 1, 3])\n",
+ " x = grayscale(x)\n",
+ " x = ops.tile(x, [1, 1, 3])\n",
" return x\n",
"\n",
"\n",
"def random_apply(func, x, p):\n",
- " if tf.random.uniform([], minval=0, maxval=1) < p:\n",
+ " if keras.random.uniform([], minval=0, maxval=1) < p:\n",
" return func(x)\n",
" else:\n",
" return x\n",
@@ -200,8 +210,7 @@
" image = flip_random_crop(image)\n",
" image = random_apply(color_jitter, image, p=0.8)\n",
" image = random_apply(color_drop, image, p=0.2)\n",
- " return image\n",
- ""
+ " return image\n"
]
},
{
@@ -231,7 +240,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -303,7 +312,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -314,7 +323,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -347,11 +356,11 @@
" PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)\n",
" )(x)\n",
" outputs = layers.BatchNormalization()(x)\n",
- " return tf.keras.Model(inputs, outputs, name=\"encoder\")\n",
+ " return keras.Model(inputs, outputs, name=\"encoder\")\n",
"\n",
"\n",
"def get_predictor():\n",
- " model = tf.keras.Sequential(\n",
+ " model = keras.Sequential(\n",
" [\n",
" # Note the AutoEncoder-like structure.\n",
" layers.Input((PROJECT_DIM,)),\n",
@@ -366,8 +375,7 @@
" ],\n",
" name=\"predictor\",\n",
" )\n",
- " return model\n",
- ""
+ " return model\n"
]
},
{
@@ -387,7 +395,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -398,13 +406,12 @@
" # The authors of SimSiam emphasize the impact of\n",
" # the `stop_gradient` operator in the paper as it\n",
" # has an important role in the overall optimization.\n",
- " z = tf.stop_gradient(z)\n",
- " p = tf.math.l2_normalize(p, axis=1)\n",
- " z = tf.math.l2_normalize(z, axis=1)\n",
+ " z = ops.stop_gradient(z)\n",
+ " p = keras.utils.normalize(p, axis=1, order=2)\n",
+ " z = keras.utils.normalize(z, axis=1, order=2)\n",
" # Negative cosine similarity (minimizing this is\n",
" # equivalent to maximizing the similarity).\n",
- " return -tf.reduce_mean(tf.reduce_sum((p * z), axis=1))\n",
- ""
+ " return -ops.mean(ops.sum((p * z), axis=1))\n"
]
},
{
@@ -414,24 +421,24 @@
},
"source": [
"We then define our training loop by overriding the `train_step()` function of the\n",
- "`tf.keras.Model` class."
+ "`keras.Model` class."
]
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
- "class SimSiam(tf.keras.Model):\n",
+ "class SimSiam(keras.Model):\n",
" def __init__(self, encoder, predictor):\n",
" super().__init__()\n",
" self.encoder = encoder\n",
" self.predictor = predictor\n",
- " self.loss_tracker = tf.keras.metrics.Mean(name=\"loss\")\n",
+ " self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n",
"\n",
" @property\n",
" def metrics(self):\n",
@@ -459,8 +466,7 @@
"\n",
" # Monitor loss.\n",
" self.loss_tracker.update_state(loss)\n",
- " return {\"loss\": self.loss_tracker.result()}\n",
- ""
+ " return {\"loss\": self.loss_tracker.result()}\n"
]
},
{
@@ -477,7 +483,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -486,18 +492,18 @@
"# Create a cosine decay learning scheduler.\n",
"num_training_samples = len(x_train)\n",
"steps = EPOCHS * (num_training_samples // BATCH_SIZE)\n",
- "lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(\n",
+ "lr_decayed_fn = keras.optimizers.schedules.CosineDecay(\n",
" initial_learning_rate=0.03, decay_steps=steps\n",
")\n",
"\n",
"# Create an early stopping callback.\n",
- "early_stopping = tf.keras.callbacks.EarlyStopping(\n",
+ "early_stopping = keras.callbacks.EarlyStopping(\n",
" monitor=\"loss\", patience=5, restore_best_weights=True\n",
")\n",
"\n",
"# Compile model and start training.\n",
"simsiam = SimSiam(get_encoder(), get_predictor())\n",
- "simsiam.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))\n",
+ "simsiam.compile(optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))\n",
"history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])\n",
"\n",
"# Visualize the training progress of the model.\n",
@@ -544,7 +550,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -566,7 +572,7 @@
"test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"# Extract the backbone ResNet20.\n",
- "backbone = tf.keras.Model(\n",
+ "backbone = keras.Model(\n",
" simsiam.encoder.input, simsiam.encoder.get_layer(\"backbone_pool\").output\n",
")\n",
"\n",
@@ -575,13 +581,13 @@
"inputs = layers.Input((CROP_TO, CROP_TO, 3))\n",
"x = backbone(inputs, training=False)\n",
"outputs = layers.Dense(10, activation=\"softmax\")(x)\n",
- "linear_model = tf.keras.Model(inputs, outputs, name=\"linear_model\")\n",
+ "linear_model = keras.Model(inputs, outputs, name=\"linear_model\")\n",
"\n",
"# Compile model and start training.\n",
"linear_model.compile(\n",
" loss=\"sparse_categorical_crossentropy\",\n",
" metrics=[\"accuracy\"],\n",
- " optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),\n",
+ " optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),\n",
")\n",
"history = linear_model.fit(\n",
" train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]\n",
@@ -644,4 +650,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
diff --git a/examples/vision/md/simsiam.md b/examples/vision/md/simsiam.md
index 98671355fb..34d4698fdc 100644
--- a/examples/vision/md/simsiam.md
+++ b/examples/vision/md/simsiam.md
@@ -2,7 +2,7 @@
**Author:** [Sayak Paul](https://twitter.com/RisingSayak)
**Date created:** 2021/03/19
-**Last modified:** 2021/03/20
+**Last modified:** 2023/12/29
**Description:** Implementation of a self-supervised learning method for computer vision.
@@ -44,16 +44,19 @@ fully-connected network having an
4. We then train our encoder to maximize the cosine similarity between the two different
versions of our dataset.
-This example requires TensorFlow 2.4 or higher.
-
---
## Setup
```python
-from tensorflow.keras import layers
-from tensorflow.keras import regularizers
-import tensorflow as tf
+import os
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import keras_cv
+from keras import ops
+from keras import layers
+from keras import regularizers
+import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
@@ -80,7 +80,7 @@ WEIGHT_DECAY = 0.0005
```python
-(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
```
@@ -106,36 +106,44 @@ etc.) include these in their training pipelines.
```python
+strength = [0.4, 0.4, 0.3, 0.1]  # brightness, contrast, saturation, hue
+
+random_flip = layers.RandomFlip(mode="horizontal")  # left-right only
+random_crop = layers.RandomCrop(CROP_TO, CROP_TO)
+random_brightness = layers.RandomBrightness(0.8 * strength[0])
+random_contrast = layers.RandomContrast(0.8 * strength[1])  # scale in [1 - f, 1 + f]
+random_saturation = keras_cv.layers.RandomSaturation(
+    (0.5 - 0.8 * strength[2], 0.5 + 0.8 * strength[2])
+)
+random_hue = keras_cv.layers.RandomHue(0.2 * strength[3], [0, 255])
+grayscale = keras_cv.layers.Grayscale()
+
def flip_random_crop(image):
# With random crops we also apply horizontal flipping.
- image = tf.image.random_flip_left_right(image)
- image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))
+ image = random_flip(image)
+ image = random_crop(image)
return image
-def color_jitter(x, strength=[0.4, 0.4, 0.4, 0.1]):
- x = tf.image.random_brightness(x, max_delta=0.8 * strength[0])
- x = tf.image.random_contrast(
- x, lower=1 - 0.8 * strength[1], upper=1 + 0.8 * strength[1]
- )
- x = tf.image.random_saturation(
- x, lower=1 - 0.8 * strength[2], upper=1 + 0.8 * strength[2]
- )
- x = tf.image.random_hue(x, max_delta=0.2 * strength[3])
+def color_jitter(x):
+ x = random_brightness(x)
+ x = random_contrast(x)
+ x = random_saturation(x)
+ x = random_hue(x)
# Affine transformations can disturb the natural range of
# RGB images, hence this is needed.
- x = tf.clip_by_value(x, 0, 255)
+ x = ops.clip(x, 0, 255)
return x
def color_drop(x):
- x = tf.image.rgb_to_grayscale(x)
- x = tf.tile(x, [1, 1, 3])
+ x = grayscale(x)
+ x = ops.tile(x, [1, 1, 3])
return x
def random_apply(func, x, p):
- if tf.random.uniform([], minval=0, maxval=1) < p:
+ if keras.random.uniform([], minval=0, maxval=1) < p:
return func(x)
else:
return x
@@ -263,11 +271,11 @@ def get_encoder():
PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
)(x)
outputs = layers.BatchNormalization()(x)
- return tf.keras.Model(inputs, outputs, name="encoder")
+ return keras.Model(inputs, outputs, name="encoder")
def get_predictor():
- model = tf.keras.Sequential(
+ model = keras.Sequential(
[
# Note the AutoEncoder-like structure.
layers.Input((PROJECT_DIM,)),
@@ -302,27 +310,27 @@ def compute_loss(p, z):
# The authors of SimSiam emphasize the impact of
# the `stop_gradient` operator in the paper as it
# has an important role in the overall optimization.
- z = tf.stop_gradient(z)
- p = tf.math.l2_normalize(p, axis=1)
- z = tf.math.l2_normalize(z, axis=1)
+ z = ops.stop_gradient(z)
+ p = keras.utils.normalize(p, axis=1, order=2)
+ z = keras.utils.normalize(z, axis=1, order=2)
# Negative cosine similarity (minimizing this is
# equivalent to maximizing the similarity).
- return -tf.reduce_mean(tf.reduce_sum((p * z), axis=1))
+ return -ops.mean(ops.sum((p * z), axis=1))
```
We then define our training loop by overriding the `train_step()` function of the
-`tf.keras.Model` class.
+`keras.Model` class.
```python
-class SimSiam(tf.keras.Model):
+class SimSiam(keras.Model):
def __init__(self, encoder, predictor):
super().__init__()
self.encoder = encoder
self.predictor = predictor
- self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+ self.loss_tracker = keras.metrics.Mean(name="loss")
@property
def metrics(self):
@@ -365,18 +373,18 @@ this should at least be 100 epochs.
# Create a cosine decay learning scheduler.
num_training_samples = len(x_train)
steps = EPOCHS * (num_training_samples // BATCH_SIZE)
-lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
+lr_decayed_fn = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=0.03, decay_steps=steps
)
# Create an early stopping callback.
-early_stopping = tf.keras.callbacks.EarlyStopping(
+early_stopping = keras.callbacks.EarlyStopping(
monitor="loss", patience=5, restore_best_weights=True
)
# Compile model and start training.
simsiam = SimSiam(get_encoder(), get_predictor())
-simsiam.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
+simsiam.compile(optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])
# Visualize the training progress of the model.
@@ -446,7 +454,7 @@ train_ds = (
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
# Extract the backbone ResNet20.
-backbone = tf.keras.Model(
+backbone = keras.Model(
simsiam.encoder.input, simsiam.encoder.get_layer("backbone_pool").output
)
@@ -455,13 +463,13 @@ backbone.trainable = False
inputs = layers.Input((CROP_TO, CROP_TO, 3))
x = backbone(inputs, training=False)
outputs = layers.Dense(10, activation="softmax")(x)
-linear_model = tf.keras.Model(inputs, outputs, name="linear_model")
+linear_model = keras.Model(inputs, outputs, name="linear_model")
# Compile model and start training.
linear_model.compile(
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
- optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
+ optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
)
history = linear_model.fit(
train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]
diff --git a/examples/vision/simsiam.py b/examples/vision/simsiam.py
index 2862763820..c4fae9ed52 100644
--- a/examples/vision/simsiam.py
+++ b/examples/vision/simsiam.py
@@ -2,7 +2,7 @@
Title: Self-supervised contrastive learning with SimSiam
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/03/19
-Last modified: 2021/03/20
+Last modified: 2023/12/29
Description: Implementation of a self-supervised learning method for computer vision.
Accelerator: GPU
"""
@@ -41,15 +41,18 @@
4. We then train our encoder to maximize the cosine similarity between the two different
versions of our dataset.
-This example requires TensorFlow 2.4 or higher.
"""
"""
## Setup
"""
-
-from tensorflow.keras import layers
-from tensorflow.keras import regularizers
+import os
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import keras_cv
+from keras import ops
+from keras import layers
+from keras import regularizers
import tensorflow as tf
import matplotlib.pyplot as plt
@@ -73,7 +76,7 @@
## Load the CIFAR-10 dataset
"""
-(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
@@ -90,36 +93,44 @@
"""
+strength = [0.4, 0.4, 0.3, 0.1]  # brightness, contrast, saturation, hue
+
+random_flip = layers.RandomFlip(mode="horizontal")  # left-right only
+random_crop = layers.RandomCrop(CROP_TO, CROP_TO)
+random_brightness = layers.RandomBrightness(0.8 * strength[0])
+random_contrast = layers.RandomContrast(0.8 * strength[1])  # scale in [1 - f, 1 + f]
+random_saturation = keras_cv.layers.RandomSaturation(
+    (0.5 - 0.8 * strength[2], 0.5 + 0.8 * strength[2])
+)
+random_hue = keras_cv.layers.RandomHue(0.2 * strength[3], [0, 255])
+grayscale = keras_cv.layers.Grayscale()
+
def flip_random_crop(image):
# With random crops we also apply horizontal flipping.
- image = tf.image.random_flip_left_right(image)
- image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))
+ image = random_flip(image)
+ image = random_crop(image)
return image
-def color_jitter(x, strength=[0.4, 0.4, 0.4, 0.1]):
- x = tf.image.random_brightness(x, max_delta=0.8 * strength[0])
- x = tf.image.random_contrast(
- x, lower=1 - 0.8 * strength[1], upper=1 + 0.8 * strength[1]
- )
- x = tf.image.random_saturation(
- x, lower=1 - 0.8 * strength[2], upper=1 + 0.8 * strength[2]
- )
- x = tf.image.random_hue(x, max_delta=0.2 * strength[3])
+def color_jitter(x):
+ x = random_brightness(x)
+ x = random_contrast(x)
+ x = random_saturation(x)
+ x = random_hue(x)
# Affine transformations can disturb the natural range of
# RGB images, hence this is needed.
- x = tf.clip_by_value(x, 0, 255)
+ x = ops.clip(x, 0, 255)
return x
def color_drop(x):
- x = tf.image.rgb_to_grayscale(x)
- x = tf.tile(x, [1, 1, 3])
+ x = grayscale(x)
+ x = ops.tile(x, [1, 1, 3])
return x
def random_apply(func, x, p):
- if tf.random.uniform([], minval=0, maxval=1) < p:
+ if keras.random.uniform([], minval=0, maxval=1) < p:
return func(x)
else:
return x
@@ -232,11 +243,11 @@ def get_encoder():
PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
)(x)
outputs = layers.BatchNormalization()(x)
- return tf.keras.Model(inputs, outputs, name="encoder")
+ return keras.Model(inputs, outputs, name="encoder")
def get_predictor():
- model = tf.keras.Sequential(
+ model = keras.Sequential(
[
# Note the AutoEncoder-like structure.
layers.Input((PROJECT_DIM,)),
@@ -269,26 +280,26 @@ def compute_loss(p, z):
# The authors of SimSiam emphasize the impact of
# the `stop_gradient` operator in the paper as it
# has an important role in the overall optimization.
- z = tf.stop_gradient(z)
- p = tf.math.l2_normalize(p, axis=1)
- z = tf.math.l2_normalize(z, axis=1)
+ z = ops.stop_gradient(z)
+ p = keras.utils.normalize(p, axis=1, order=2)
+ z = keras.utils.normalize(z, axis=1, order=2)
# Negative cosine similarity (minimizing this is
# equivalent to maximizing the similarity).
- return -tf.reduce_mean(tf.reduce_sum((p * z), axis=1))
+ return -ops.mean(ops.sum((p * z), axis=1))
"""
We then define our training loop by overriding the `train_step()` function of the
-`tf.keras.Model` class.
+`keras.Model` class.
"""
-class SimSiam(tf.keras.Model):
+class SimSiam(keras.Model):
def __init__(self, encoder, predictor):
super().__init__()
self.encoder = encoder
self.predictor = predictor
- self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+ self.loss_tracker = keras.metrics.Mean(name="loss")
@property
def metrics(self):
@@ -329,18 +340,18 @@ def train_step(self, data):
# Create a cosine decay learning scheduler.
num_training_samples = len(x_train)
steps = EPOCHS * (num_training_samples // BATCH_SIZE)
-lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
+lr_decayed_fn = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=0.03, decay_steps=steps
)
# Create an early stopping callback.
-early_stopping = tf.keras.callbacks.EarlyStopping(
+early_stopping = keras.callbacks.EarlyStopping(
monitor="loss", patience=5, restore_best_weights=True
)
# Compile model and start training.
simsiam = SimSiam(get_encoder(), get_predictor())
-simsiam.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
+simsiam.compile(optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])
# Visualize the training progress of the model.
@@ -391,7 +402,7 @@ def train_step(self, data):
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
# Extract the backbone ResNet20.
-backbone = tf.keras.Model(
+backbone = keras.Model(
simsiam.encoder.input, simsiam.encoder.get_layer("backbone_pool").output
)
@@ -400,13 +411,13 @@ def train_step(self, data):
inputs = layers.Input((CROP_TO, CROP_TO, 3))
x = backbone(inputs, training=False)
outputs = layers.Dense(10, activation="softmax")(x)
-linear_model = tf.keras.Model(inputs, outputs, name="linear_model")
+linear_model = keras.Model(inputs, outputs, name="linear_model")
# Compile model and start training.
linear_model.compile(
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
- optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
+ optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
)
history = linear_model.fit(
train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]