Skip to content

Commit

Permalink
Update Perceiver image classification example (keras-team#369)
Browse files Browse the repository at this point in the history
* added perceiver image classifier

* removed tfa dependency and replaced tfa LAMB optimizer and Input names

* modify comment in tutorial
  • Loading branch information
divyashreepathihalli authored Jun 16, 2023
1 parent 76bb8c4 commit 2ae265c
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import tensorflow as tf
import keras_core as keras
from keras_core import layers
import tensorflow_addons as tfa

"""
## Prepare the data
Expand Down Expand Up @@ -206,9 +205,9 @@ def create_cross_attention_module(
):
inputs = {
# Recieve the latent array as an input of shape [1, latent_dim, projection_dim].
"latent_array": layers.Input(shape=(latent_dim, projection_dim)),
"latent_array": layers.Input(shape=(latent_dim, projection_dim), name="latent_array"),
# Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
"data_array": layers.Input(shape=(data_dim, projection_dim)),
"data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
}

# Apply layer norm to the inputs
Expand Down Expand Up @@ -399,10 +398,10 @@ def call(self, inputs):


def run_experiment(model):
# Create LAMB optimizer with weight decay.
optimizer = tfa.optimizers.LAMB(
# Create Adam optimizer with weight decay.
optimizer = keras.optimizers.Adam(
learning_rate=learning_rate,
weight_decay_rate=weight_decay,
weight_decay=weight_decay,
)

# Compile the model.
Expand Down

0 comments on commit 2ae265c

Please sign in to comment.