Skip to content

Commit

Permalink
Remove keras.Model checks for tf_keras compatibiltiy
Browse files Browse the repository at this point in the history
This cl is for fixing failures at the colab.

PiperOrigin-RevId: 605290842
  • Loading branch information
abattery authored and tensorflower-gardener committed Feb 14, 2024
1 parent 590c8de commit d4f9574
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow import keras"
"import tf_keras as keras"
]
},
{
Expand All @@ -146,7 +146,7 @@
"outputs": [],
"source": [
"# Load MNIST dataset\n",
"mnist = keras.datasets.mnist\n",
"mnist = tf.keras.datasets.mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# Normalize the input image so that each pixel value is between 0 to 1.\n",
Expand Down Expand Up @@ -216,7 +216,6 @@
"outputs": [],
"source": [
"import tensorflow_model_optimization as tfmot\n",
"import tf_keras as keras\n",
"\n",
"quantize_model = tfmot.quantization.keras.quantize_model\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,18 @@ def _cluster_weights(to_cluster,
cluster_centroids_init))

def _add_clustering_wrapper(layer):
if isinstance(layer, keras.Model):
if (isinstance(layer, keras.Sequential)) or (
hasattr(layer, '_is_graph_network') and layer._is_graph_network # pylint: disable=protected-access
):
return keras.models.clone_model(
layer, input_tensors=None, clone_function=_add_clustering_wrapper
)
if hasattr(layer, '_is_graph_network') and not layer._is_graph_network: # pylint: disable=protected-access
# Check whether the model is a subclass.
# NB: This check is copied from keras.py file in tensorflow.
# There is no available public API to do this check.
# pylint: disable=protected-access
if not layer._is_graph_network and not isinstance(
layer, keras.models.Sequential
):
raise ValueError('Subclassed models are not supported currently.')

return keras.models.clone_model(
layer, input_tensors=None, clone_function=_add_clustering_wrapper
)
raise ValueError('Subclassed models are not supported currently.')
if isinstance(layer, cluster_wrapper.ClusterWeights):
return layer
if isinstance(layer, InputLayer):
Expand Down Expand Up @@ -292,15 +291,16 @@ def _wrap_list(layers):

return output

if isinstance(to_cluster, keras.Model):
if isinstance(to_cluster, keras.Sequantial):
return keras.models.clone_model(
to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper
)
if isinstance(to_cluster, Layer):
return _add_clustering_wrapper(layer=to_cluster)
if isinstance(to_cluster, list):
return _wrap_list(to_cluster)

# Assuming the given layer is supported.
return _add_clustering_wrapper(layer=to_cluster)


def strip_clustering(model):
"""Strips clustering wrappers from the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ def _add_pruning_wrapper(layer):
'sparsity_m_by_n': sparsity_m_by_n,
}

is_sequential_or_functional = isinstance(
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
to_prune._is_graph_network)
is_sequential_or_functional = (isinstance(to_prune, keras.Sequential)) or (
hasattr(to_prune, '_is_graph_network') and to_prune._is_graph_network # pylint: disable=protected-access
)

# A subclassed model is also a subclass of keras.layers.Layer.
is_keras_layer = isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,6 @@ def __init__(self,
'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
.format(block_pooling_type))

if not isinstance(layer, keras.layers.Layer):
raise ValueError(
'Please initialize `Prune` layer with a '
'`Layer` instance. You passed: {input}'.format(input=layer))

# TODO(pulkitb): This should be pushed up to the wrappers.py
# Name the layer using the wrapper and underlying layer name.
# Prune(Dense) becomes prune_dense_1
Expand Down

0 comments on commit d4f9574

Please sign in to comment.