diff --git a/ci/kokoro/gcp_ubuntu/Dockerfile b/ci/kokoro/gcp_ubuntu/Dockerfile
index a7490ba91..342f28b2e 100644
--- a/ci/kokoro/gcp_ubuntu/Dockerfile
+++ b/ci/kokoro/gcp_ubuntu/Dockerfile
@@ -26,10 +26,10 @@
 # run CI against.
 
-# Latest Ubuntu LTS (Focal), at the moment.
-FROM ubuntu:20.04
+# Latest Ubuntu LTS (Jammy), at the moment.
+FROM ubuntu:22.04
 
-ARG BAZEL_VERSION=4.2.2
-ARG TENSORFLOW_VERSION=2.12.0
+ARG BAZEL_VERSION=7.0.2
+ARG TENSORFLOW_VERSION=2.15.0
 
 RUN apt-get update -y
 
diff --git a/ci/kokoro/run_bazel_unittests.sh b/ci/kokoro/run_bazel_unittests.sh
index 8dd4fb530..336864441 100755
--- a/ci/kokoro/run_bazel_unittests.sh
+++ b/ci/kokoro/run_bazel_unittests.sh
@@ -33,9 +33,7 @@ set -o pipefail # Treat the failure of a command in a pipeline as error.
 # set -x
 
 pip install --requirement "requirements.txt"
-# TODO(b/232345872): Not in list of requirements, but needed for EPR test.
-# The EPR test relies on a feature (PowerLawEntropyModel) introduced in 2.10.0.
-pip install tensorflow-compression~=2.11.0
+pip install "tensorflow-compression>=2.11.0"
 
 # Run the tests.
 # Some tests requiring more RAM than the CI machine provides are disabled.
diff --git a/tensorflow_model_optimization/g3doc/_index.yaml b/tensorflow_model_optimization/g3doc/_index.yaml
index 80817fde9..37c299200 100644
--- a/tensorflow_model_optimization/g3doc/_index.yaml
+++ b/tensorflow_model_optimization/g3doc/_index.yaml
@@ -56,8 +56,9 @@ landing_page:
         import tensorflow as tf
         import tensorflow_model_optimization as tfmot
+        import tf_keras as keras
 
-        model = tf.keras.Sequential([...])
+        model = keras.Sequential([...])
 
         pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
                               initial_sparsity=0.0, final_sparsity=0.5,
diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb
index 886f31ed2..4bc478eb8 100644
--- a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb
@@ -102,6 +102,7 @@
         "! pip install -q tensorflow-model-optimization\n",
         "\n",
         "import tensorflow as tf\n",
+        "import tf_keras as keras\n",
         "import numpy as np\n",
         "import tempfile\n",
         "import os\n",
@@ -110,18 +111,18 @@
         "input_dim = 20\n",
         "output_dim = 20\n",
         "x_train = np.random.randn(1, input_dim).astype(np.float32)\n",
-        "y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=output_dim)\n",
+        "y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=output_dim)\n",
         "\n",
         "def setup_model():\n",
-        "  model = tf.keras.Sequential([\n",
-        "      tf.keras.layers.Dense(input_dim, input_shape=[input_dim]),\n",
-        "      tf.keras.layers.Flatten()\n",
+        "  model = keras.Sequential([\n",
+        "      keras.layers.Dense(input_dim, input_shape=[input_dim]),\n",
+        "      keras.layers.Flatten()\n",
         "  ])\n",
         "  return model\n",
         "\n",
         "def train_model(model):\n",
         "  model.compile(\n",
-        "      loss=tf.keras.losses.categorical_crossentropy,\n",
+        "      loss=keras.losses.categorical_crossentropy,\n",
         "      optimizer='adam',\n",
         "      metrics=['accuracy']\n",
         "  )\n",
@@ -243,7 +244,7 @@
         "**Tips** for better model accuracy:\n",
         "\n",
         "* You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.\n",
-        "* Cluster later layers with more redundant parameters (e.g. `tf.keras.layers.Dense`, `tf.keras.layers.Conv2D`), as opposed to the early layers.\n",
+        "* Cluster later layers with more redundant parameters (e.g. `keras.layers.Dense`, `keras.layers.Conv2D`), as opposed to the early layers.\n",
         "* Freeze early layers prior to the clustered layers during fine-tuning. Treat the number of frozen layers as a hyperparameter. Empirically, freezing most early layers is ideal for the current clustering API.\n",
         "* Avoid clustering critical layers (e.g. attention mechanism).\n",
         "\n",
@@ -265,13 +266,13 @@
         "# Helper function uses `cluster_weights` to make only \n",
         "# the Dense layers train with clustering\n",
         "def apply_clustering_to_dense(layer):\n",
-        "  if isinstance(layer, tf.keras.layers.Dense):\n",
+        "  if isinstance(layer, keras.layers.Dense):\n",
         "    return cluster_weights(layer, **clustering_params)\n",
         "  return layer\n",
         "\n",
-        "# Use `tf.keras.models.clone_model` to apply `apply_clustering_to_dense` \n",
+        "# Use `keras.models.clone_model` to apply `apply_clustering_to_dense` \n",
         "# to the layers of the model.\n",
-        "clustered_model = tf.keras.models.clone_model(\n",
+        "clustered_model = keras.models.clone_model(\n",
         "    base_model,\n",
         "    clone_function=apply_clustering_to_dense,\n",
         ")\n",
@@ -326,7 +327,7 @@
       },
       "outputs": [],
       "source": [
-        "class MyDenseLayer(tf.keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):\n",
+        "class MyDenseLayer(keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):\n",
         "\n",
         "  def get_clusterable_weights(self):\n",
         "   # Cluster kernel and bias. This is just an example, clustering\n",
@@ -334,9 +335,9 @@
         "   return [('kernel', self.kernel), ('bias', self.bias)]\n",
         "\n",
         "# Use `cluster_weights` to make the `MyDenseLayer` layer train with clustering as usual.\n",
-        "model_for_clustering = tf.keras.Sequential([\n",
+        "model_for_clustering = keras.Sequential([\n",
         "  tfmot.clustering.keras.cluster_weights(MyDenseLayer(20, input_shape=[input_dim]), **clustering_params),\n",
-        "  tf.keras.layers.Flatten()\n",
+        "  keras.layers.Flatten()\n",
         "])\n",
         "\n",
         "model_for_clustering.summary()"
@@ -348,7 +349,7 @@
         "id": "SYlWPXEWmxTs"
       },
       "source": [
-        "You may also use `tfmot.clustering.keras.ClusterableLayer` to cluster a keras custom layer. To do this, you extend `tf.keras.Layer` as usual and implement the `__init__`, `call`, and `build` functions, but you also need to extend the `clusterable_layer.ClusterableLayer` class and implement:\n",
+        "You may also use `tfmot.clustering.keras.ClusterableLayer` to cluster a keras custom layer. To do this, you extend `keras.Layer` as usual and implement the `__init__`, `call`, and `build` functions, but you also need to extend the `clusterable_layer.ClusterableLayer` class and implement:\n",
         "1. `get_clusterable_weights`, where you specify the weight kernel to be clustered, as shown above.\n",
         "2. `get_clusterable_algorithm`, where you specify the clustering algorithm for the weight tensor. This is because you need to specify how the custom layer weights are shaped for clustering. The returned clustering algorithm class should be derived from the `clustering_algorithm.ClusteringAlgorithm` class and the function `get_pulling_indices` should be overwritten. An example of this function, which supports weights of ranks 1D, 2D, and 3D, can be found [here]( https://github.com/tensorflow/model-optimization/blob/18e87d262e536c9a742aef700880e71b47a7f768/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py#L62).\n",
         "\n",
@@ -392,7 +393,7 @@
         "\n",
         "# `cluster_scope` is needed for deserializing HDF5 models.\n",
         "with tfmot.clustering.keras.cluster_scope():\n",
-        "  loaded_model = tf.keras.models.load_model(keras_model_file)\n",
+        "  loaded_model = keras.models.load_model(keras_model_file)\n",
         "\n",
         "loaded_model.summary()"
       ]
@@ -460,7 +461,7 @@
         "clustered_model = cluster_weights(model, **clustering_params)\n",
         "\n",
         "clustered_model.compile(\n",
-        "    loss=tf.keras.losses.categorical_crossentropy,\n",
+        "    loss=keras.losses.categorical_crossentropy,\n",
         "    optimizer='adam',\n",
         "    metrics=['accuracy']\n",
         ")\n",
diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb
index 3b4c8b867..d2944cdcf 100755
--- a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb
@@ -82,7 +82,7 @@
         "\n",
         "In the tutorial, you will:\n",
         "\n",
-        "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n",
+        "1. Train a `keras` model for the MNIST dataset from scratch.\n",
         "2. Fine-tune the model by applying the weight clustering API and see the accuracy.\n",
         "3. Create a 6x smaller TF and TFLite models from clustering.\n",
         "4. Create a 8x smaller TFLite model from combining weight clustering and post-training quantization.\n",
@@ -120,7 +120,7 @@
       "outputs": [],
       "source": [
         "import tensorflow as tf\n",
-        "from tensorflow import keras\n",
+        "import tf_keras as keras\n",
         "\n",
         "import numpy as np\n",
         "import tempfile\n",
@@ -134,7 +134,7 @@
         "id": "dKzOfl5FSGPL"
       },
       "source": [
-        "## Train a tf.keras model for MNIST without clustering"
+        "## Train a keras model for MNIST without clustering"
       ]
     },
     {
@@ -146,8 +146,7 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = keras.datasets.mnist\n",
-        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
+        "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 to 1.\n",
         "train_images = train_images / 255.0\n",
@@ -165,7 +164,7 @@
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -200,7 +199,7 @@
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving model to: ', keras_file)\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
+        "keras.models.save_model(model, keras_file, include_optimizer=False)"
       ]
     },
     {
@@ -261,10 +260,10 @@
         "clustered_model = cluster_weights(model, **clustering_params)\n",
         "\n",
         "# Use smaller learning rate for fine-tuning clustered model\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-5)\n",
         "\n",
         "clustered_model.compile(\n",
-        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "  optimizer=opt,\n",
         "  metrics=['accuracy'])\n",
         "\n",
@@ -362,7 +361,7 @@
         "\n",
         "_, clustered_keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving clustered model to: ', clustered_keras_file)\n",
-        "tf.keras.models.save_model(final_model, clustered_keras_file, \n",
+        "keras.models.save_model(final_model, clustered_keras_file, \n",
         "                           include_optimizer=False)"
       ]
     },
diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/index.md b/tensorflow_model_optimization/g3doc/guide/clustering/index.md
index 9c7312f85..d0c2a1e32 100644
--- a/tensorflow_model_optimization/g3doc/guide/clustering/index.md
+++ b/tensorflow_model_optimization/g3doc/guide/clustering/index.md
@@ -21,7 +21,7 @@ Please note that clustering will provide reduced benefits for convolution and de
 
 Users can apply clustering with the following APIs:
 
-*   Model building: `tf.keras` with only Sequential and Functional models
+*   Model building: `keras` with only Sequential and Functional models
 *   TensorFlow versions: TF 1.x for versions 1.14+ and 2.x.
     *   `tf.compat.v1` with a TF 2.X package and `tf.compat.v2` with a TF 1.X
         package are not supported.
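As a quick reference, weight clustering on a compatible model might look like this minimal sketch (`model` is assumed to be a trained Sequential or Functional model; the cluster count and centroid initialization are arbitrary illustrative choices):

```python
import tensorflow_model_optimization as tfmot

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustered_model = tfmot.clustering.keras.cluster_weights(
    model,  # a trained Sequential or Functional model (assumption)
    number_of_clusters=16,
    cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
)
```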
diff --git a/tensorflow_model_optimization/g3doc/guide/combine/cqat_example.ipynb b/tensorflow_model_optimization/g3doc/guide/combine/cqat_example.ipynb
index 7d0b6f281..a754a2836 100644
--- a/tensorflow_model_optimization/g3doc/guide/combine/cqat_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/combine/cqat_example.ipynb
@@ -80,7 +80,7 @@
         "\n",
         "In the tutorial, you will:\n",
         "\n",
-        "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n",
+        "1. Train a `keras` model for the MNIST dataset from scratch.\n",
         "2. Fine-tune the model with clustering and see the accuracy.\n",
         "3. Apply QAT and observe the loss of clusters.\n",
         "4. Apply CQAT and observe that the clustering applied earlier has been preserved.\n",
@@ -119,6 +119,7 @@
       "outputs": [],
       "source": [
         "import tensorflow as tf\n",
+        "import tf_keras as keras\n",
         "\n",
         "import numpy as np\n",
         "import tempfile\n",
@@ -132,7 +133,7 @@
         "id": "dKzOfl5FSGPL"
       },
       "source": [
-        "## Train a tf.keras model for MNIST without clustering"
+        "## Train a keras model for MNIST without clustering"
       ]
     },
     {
@@ -144,26 +145,26 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = tf.keras.datasets.mnist\n",
+        "mnist = keras.datasets.mnist\n",
         "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 to 1.\n",
         "train_images = train_images / 255.0\n",
         "test_images  = test_images / 255.0\n",
         "\n",
-        "model = tf.keras.Sequential([\n",
-        "  tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
-        "  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
-        "  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
+        "model = keras.Sequential([\n",
+        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
+        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
+        "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
         "                         activation=tf.nn.relu),\n",
-        "  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
-        "  tf.keras.layers.Flatten(),\n",
-        "  tf.keras.layers.Dense(10)\n",
+        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
+        "  keras.layers.Flatten(),\n",
+        "  keras.layers.Dense(10)\n",
         "])\n",
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -198,7 +199,7 @@
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving model to: ', keras_file)\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
+        "keras.models.save_model(model, keras_file, include_optimizer=False)"
       ]
     },
     {
@@ -259,10 +260,10 @@
         "clustered_model = cluster_weights(model, **clustering_params)\n",
         "\n",
         "# Use smaller learning rate for fine-tuning\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-5)\n",
         "\n",
         "clustered_model.compile(\n",
-        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "  optimizer=opt,\n",
         "  metrics=['accuracy'])\n",
         "\n",
@@ -323,7 +324,7 @@
         "def print_model_weight_clusters(model):\n",
         "\n",
         "    for layer in model.layers:\n",
-        "        if isinstance(layer, tf.keras.layers.Wrapper):\n",
+        "        if isinstance(layer, keras.layers.Wrapper):\n",
         "            weights = layer.trainable_weights\n",
         "        else:\n",
         "            weights = layer.weights\n",
@@ -414,7 +415,7 @@
         "qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)\n",
         "\n",
         "qat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train qat model:')\n",
         "qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)\n",
@@ -427,7 +428,7 @@
         "              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())\n",
         "\n",
         "cqat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train cqat model:')\n",
         "cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)"
diff --git a/tensorflow_model_optimization/g3doc/guide/combine/pcqat_example.ipynb b/tensorflow_model_optimization/g3doc/guide/combine/pcqat_example.ipynb
index 16f12ad5c..af269f9e4 100755
--- a/tensorflow_model_optimization/g3doc/guide/combine/pcqat_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/combine/pcqat_example.ipynb
@@ -80,7 +80,7 @@
         "\n",
         "In the tutorial, you will:\n",
         "\n",
-        "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n",
+        "1. Train a `keras` model for the MNIST dataset from scratch.\n",
         "2. Fine-tune the model with pruning and see the accuracy and observe that the model was successfully pruned.\n",
         "3. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.\n",
         "4. Apply QAT and observe the loss of sparsity and clusters.\n",
@@ -121,6 +121,7 @@
       "outputs": [],
       "source": [
         "import tensorflow as tf\n",
+        "import tf_keras as keras\n",
         "\n",
         "import numpy as np\n",
         "import tempfile\n",
@@ -134,7 +135,7 @@
         "id": "dKzOfl5FSGPL"
       },
       "source": [
-        "## Train a tf.keras model for MNIST to be pruned and clustered"
+        "## Train a keras model for MNIST to be pruned and clustered"
       ]
     },
     {
@@ -146,28 +147,28 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = tf.keras.datasets.mnist\n",
+        "mnist = keras.datasets.mnist\n",
         "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 to 1.\n",
         "train_images = train_images / 255.0\n",
         "test_images  = test_images / 255.0\n",
         "\n",
-        "model = tf.keras.Sequential([\n",
-        "  tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
-        "  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
-        "  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
+        "model = keras.Sequential([\n",
+        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
+        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
+        "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
         "                         activation=tf.nn.relu),\n",
-        "  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
-        "  tf.keras.layers.Flatten(),\n",
-        "  tf.keras.layers.Dense(10)\n",
+        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
+        "  keras.layers.Flatten(),\n",
+        "  keras.layers.Dense(10)\n",
         "])\n",
         "\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-3)\n",
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer=opt,\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -202,7 +203,7 @@
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving model to: ', keras_file)\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
+        "keras.models.save_model(model, keras_file, include_optimizer=False)"
       ]
     },
     {
@@ -257,10 +258,10 @@
         "pruned_model = prune_low_magnitude(model, **pruning_params)\n",
         "\n",
         "# Use smaller learning rate for fine-tuning\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-5)\n",
         "\n",
         "pruned_model.compile(\n",
-        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "  optimizer=opt,\n",
         "  metrics=['accuracy'])"
       ]
@@ -312,7 +313,7 @@
       "source": [
         "def print_model_weights_sparsity(model):\n",
         "    for layer in model.layers:\n",
-        "        if isinstance(layer, tf.keras.layers.Wrapper):\n",
+        "        if isinstance(layer, keras.layers.Wrapper):\n",
         "            weights = layer.trainable_weights\n",
         "        else:\n",
         "            weights = layer.weights\n",
@@ -328,7 +329,7 @@
         "\n",
         "def print_model_weight_clusters(model):\n",
         "    for layer in model.layers:\n",
-        "        if isinstance(layer, tf.keras.layers.Wrapper):\n",
+        "        if isinstance(layer, keras.layers.Wrapper):\n",
         "            weights = layer.trainable_weights\n",
         "        else:\n",
         "            weights = layer.weights\n",
@@ -410,7 +411,7 @@
         "sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)\n",
         "\n",
         "sparsity_clustered_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "print('Train sparsity preserving clustering model:')\n",
@@ -473,7 +474,7 @@
         "qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)\n",
         "\n",
         "qat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train qat model:')\n",
         "qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)\n",
@@ -486,7 +487,7 @@
         "              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))\n",
         "\n",
         "pcqat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train pcqat model:')\n",
         "pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)"
diff --git a/tensorflow_model_optimization/g3doc/guide/combine/pqat_example.ipynb b/tensorflow_model_optimization/g3doc/guide/combine/pqat_example.ipynb
index 51a9313aa..a5ee8c45a 100755
--- a/tensorflow_model_optimization/g3doc/guide/combine/pqat_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/combine/pqat_example.ipynb
@@ -80,7 +80,7 @@
         "\n",
         "In the tutorial, you will:\n",
         "\n",
-        "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n",
+        "1. Train a `keras` model for the MNIST dataset from scratch.\n",
         "2. Fine-tune the model with pruning, using the sparsity API, and see the accuracy.\n",
         "3. Apply QAT and observe the loss of sparsity.\n",
         "4. Apply PQAT and observe that the sparsity applied earlier has been preserved.\n",
@@ -119,6 +119,7 @@
       "outputs": [],
       "source": [
         "import tensorflow as tf\n",
+        "import tf_keras as keras\n",
         "\n",
         "import numpy as np\n",
         "import tempfile\n",
@@ -132,7 +133,7 @@
         "id": "dKzOfl5FSGPL"
       },
       "source": [
-        "## Train a tf.keras model for MNIST without pruning"
+        "## Train a keras model for MNIST without pruning"
       ]
     },
     {
@@ -144,26 +145,26 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = tf.keras.datasets.mnist\n",
+        "mnist = keras.datasets.mnist\n",
         "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 to 1.\n",
         "train_images = train_images / 255.0\n",
         "test_images  = test_images / 255.0\n",
         "\n",
-        "model = tf.keras.Sequential([\n",
-        "  tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
-        "  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
-        "  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
+        "model = keras.Sequential([\n",
+        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
+        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
+        "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
         "                         activation=tf.nn.relu),\n",
-        "  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
-        "  tf.keras.layers.Flatten(),\n",
-        "  tf.keras.layers.Dense(10)\n",
+        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
+        "  keras.layers.Flatten(),\n",
+        "  keras.layers.Dense(10)\n",
         "])\n",
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -198,7 +199,7 @@
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving model to: ', keras_file)\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
+        "keras.models.save_model(model, keras_file, include_optimizer=False)"
       ]
     },
     {
@@ -260,10 +261,10 @@
         "pruned_model = prune_low_magnitude(model, **pruning_params)\n",
         "\n",
         "# Use smaller learning rate for fine-tuning\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-5)\n",
         "\n",
         "pruned_model.compile(\n",
-        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "  optimizer=opt,\n",
         "  metrics=['accuracy'])\n",
         "\n",
@@ -325,7 +326,7 @@
         "def print_model_weights_sparsity(model):\n",
         "\n",
         "    for layer in model.layers:\n",
-        "        if isinstance(layer, tf.keras.layers.Wrapper):\n",
+        "        if isinstance(layer, keras.layers.Wrapper):\n",
         "            weights = layer.trainable_weights\n",
         "        else:\n",
         "            weights = layer.weights\n",
@@ -417,7 +418,7 @@
         "qat_model = tfmot.quantization.keras.quantize_model(stripped_pruned_model)\n",
         "\n",
         "qat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train qat model:')\n",
         "qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)\n",
@@ -430,7 +431,7 @@
         "              tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())\n",
         "\n",
         "pqat_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "print('Train pqat model:')\n",
         "pqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)"
diff --git a/tensorflow_model_optimization/g3doc/guide/combine/sparse_clustering_example.ipynb b/tensorflow_model_optimization/g3doc/guide/combine/sparse_clustering_example.ipynb
index 7cb680833..46a9d50fc 100644
--- a/tensorflow_model_optimization/g3doc/guide/combine/sparse_clustering_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/combine/sparse_clustering_example.ipynb
@@ -80,7 +80,7 @@
         "\n",
         "In the tutorial, you will:\n",
         "\n",
-        "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n",
+        "1. Train a `keras` model for the MNIST dataset from scratch.\n",
         "2. Fine-tune the model with sparsity and see the accuracy and observe that the model was successfully pruned.\n",
         "3. Apply weight clustering to the pruned model and observe the loss of sparsity.\n",
         "4. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.\n",
@@ -119,6 +119,7 @@
       "outputs": [],
       "source": [
         "import tensorflow as tf\n",
+        "import tf_keras as keras\n",
         "\n",
         "import numpy as np\n",
         "import tempfile\n",
@@ -132,7 +133,7 @@
         "id": "dKzOfl5FSGPL"
       },
       "source": [
-        "## Train a tf.keras model for MNIST to be pruned and clustered"
+        "## Train a keras model for MNIST to be pruned and clustered"
       ]
     },
     {
@@ -144,26 +145,26 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = tf.keras.datasets.mnist\n",
+        "mnist = keras.datasets.mnist\n",
         "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 to 1.\n",
         "train_images = train_images / 255.0\n",
         "test_images  = test_images / 255.0\n",
         "\n",
-        "model = tf.keras.Sequential([\n",
-        "  tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
-        "  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
-        "  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
+        "model = keras.Sequential([\n",
+        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
+        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
+        "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),\n",
         "                         activation=tf.nn.relu),\n",
-        "  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
-        "  tf.keras.layers.Flatten(),\n",
-        "  tf.keras.layers.Dense(10)\n",
+        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
+        "  keras.layers.Flatten(),\n",
+        "  keras.layers.Dense(10)\n",
         "])\n",
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -198,7 +199,7 @@
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
         "print('Saving model to: ', keras_file)\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
+        "keras.models.save_model(model, keras_file, include_optimizer=False)"
       ]
     },
     {
@@ -253,10 +254,10 @@
         "pruned_model = prune_low_magnitude(model, **pruning_params)\n",
         "\n",
         "# Use smaller learning rate for fine-tuning\n",
-        "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n",
+        "opt = keras.optimizers.Adam(learning_rate=1e-5)\n",
         "\n",
         "pruned_model.compile(\n",
-        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "  optimizer=opt,\n",
         "  metrics=['accuracy'])\n",
         "\n",
@@ -311,7 +312,7 @@
         "def print_model_weights_sparsity(model):\n",
         "\n",
         "    for layer in model.layers:\n",
-        "        if isinstance(layer, tf.keras.layers.Wrapper):\n",
+        "        if isinstance(layer, keras.layers.Wrapper):\n",
         "            weights = layer.trainable_weights\n",
         "        else:\n",
         "            weights = layer.weights\n",
@@ -347,7 +348,7 @@
         "\n",
         "print_model_weights_sparsity(stripped_pruned_model)\n",
         "\n",
-        "stripped_pruned_model_copy = tf.keras.models.clone_model(stripped_pruned_model)\n",
+        "stripped_pruned_model_copy = keras.models.clone_model(stripped_pruned_model)\n",
         "stripped_pruned_model_copy.set_weights(stripped_pruned_model.get_weights())"
       ]
     },
@@ -389,7 +390,7 @@
         "clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)\n",
         "\n",
         "clustered_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "print('Train clustering model:')\n",
@@ -414,7 +415,7 @@
         "sparsity_clustered_model = cluster_weights(stripped_pruned_model_copy, **clustering_params)\n",
         "\n",
         "sparsity_clustered_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "print('Train sparsity preserving clustering model:')\n",
@@ -598,7 +599,7 @@
       "source": [
         "# Keras model evaluation\n",
         "stripped_sparsity_clustered_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "_, sparsity_clustered_keras_accuracy = stripped_sparsity_clustered_model.evaluate(\n",
         "    test_images, test_labels, verbose=0)\n",
diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb
index d674f47f0..0b91512ff 100644
--- a/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb
@@ -118,6 +118,7 @@
         "import tensorflow as tf\n",
         "import numpy as np\n",
         "import tensorflow_model_optimization as tfmot\n",
+        "import tf_keras as keras\n",
         "\n",
         "%load_ext tensorboard\n",
         "\n",
@@ -125,12 +126,12 @@
         "\n",
         "input_shape = [20]\n",
         "x_train = np.random.randn(1, 20).astype(np.float32)\n",
-        "y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n",
+        "y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n",
         "\n",
         "def setup_model():\n",
-        "  model = tf.keras.Sequential([\n",
-        "      tf.keras.layers.Dense(20, input_shape=input_shape),\n",
-        "      tf.keras.layers.Flatten()\n",
+        "  model = keras.Sequential([\n",
+        "      keras.layers.Dense(20, input_shape=input_shape),\n",
+        "      keras.layers.Flatten()\n",
         "  ])\n",
         "  return model\n",
         "\n",
@@ -138,7 +139,7 @@
         "  model = setup_model()\n",
         "\n",
         "  model.compile(\n",
-        "      loss=tf.keras.losses.categorical_crossentropy,\n",
+        "      loss=keras.losses.categorical_crossentropy,\n",
         "      optimizer='adam',\n",
         "      metrics=['accuracy']\n",
         "  )\n",
@@ -259,13 +260,13 @@
         "# Helper function uses `prune_low_magnitude` to make only the \n",
         "# Dense layers train with pruning.\n",
         "def apply_pruning_to_dense(layer):\n",
-        "  if isinstance(layer, tf.keras.layers.Dense):\n",
+        "  if isinstance(layer, keras.layers.Dense):\n",
         "    return tfmot.sparsity.keras.prune_low_magnitude(layer)\n",
         "  return layer\n",
         "\n",
-        "# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` \n",
+        "# Use `keras.models.clone_model` to apply `apply_pruning_to_dense` \n",
         "# to the layers of the model.\n",
-        "model_for_pruning = tf.keras.models.clone_model(\n",
+        "model_for_pruning = keras.models.clone_model(\n",
         "    base_model,\n",
         "    clone_function=apply_pruning_to_dense,\n",
         ")\n",
@@ -332,10 +333,10 @@
       "outputs": [],
       "source": [
         "# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.\n",
-        "i = tf.keras.Input(shape=(20,))\n",
-        "x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)\n",
-        "o = tf.keras.layers.Flatten()(x)\n",
-        "model_for_pruning = tf.keras.Model(inputs=i, outputs=o)\n",
+        "i = keras.Input(shape=(20,))\n",
+        "x = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))(i)\n",
+        "o = keras.layers.Flatten()(x)\n",
+        "model_for_pruning = keras.Model(inputs=i, outputs=o)\n",
         "\n",
         "model_for_pruning.summary()"
       ]
@@ -358,9 +359,9 @@
       "outputs": [],
       "source": [
         "# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.\n",
-        "model_for_pruning = tf.keras.Sequential([\n",
-        "  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),\n",
-        "  tf.keras.layers.Flatten()\n",
+        "model_for_pruning = keras.Sequential([\n",
+        "  tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(20, input_shape=input_shape)),\n",
+        "  keras.layers.Flatten()\n",
         "])\n",
         "\n",
         "model_for_pruning.summary()"
@@ -399,16 +400,16 @@
       },
       "outputs": [],
       "source": [
-        "class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):\n",
+        "class MyDenseLayer(keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):\n",
         "\n",
         "  def get_prunable_weights(self):\n",
         "    # Prune bias also, though that usually harms model accuracy too much.\n",
         "    return [self.kernel, self.bias]\n",
         "\n",
         "# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.\n",
-        "model_for_pruning = tf.keras.Sequential([\n",
+        "model_for_pruning = keras.Sequential([\n",
         "  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),\n",
-        "  tf.keras.layers.Flatten()\n",
+        "  keras.layers.Flatten()\n",
         "])\n",
         "\n",
         "model_for_pruning.summary()\n"
@@ -464,7 +465,7 @@
         "]\n",
         "\n",
         "model_for_pruning.compile(\n",
-        "      loss=tf.keras.losses.categorical_crossentropy,\n",
+        "      loss=keras.losses.categorical_crossentropy,\n",
         "      optimizer='adam',\n",
         "      metrics=['accuracy']\n",
         ")\n",
@@ -523,8 +524,8 @@
         "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n",
         "\n",
         "# Boilerplate\n",
-        "loss = tf.keras.losses.categorical_crossentropy\n",
-        "optimizer = tf.keras.optimizers.Adam()\n",
+        "loss = keras.losses.categorical_crossentropy\n",
+        "optimizer = keras.optimizers.Adam()\n",
         "log_dir = tempfile.mkdtemp()\n",
         "unused_arg = -1\n",
         "epochs = 2\n",
@@ -650,7 +651,7 @@
       "source": [
         "# Deserialize model.\n",
         "with tfmot.sparsity.keras.prune_scope():\n",
-        "  loaded_model = tf.keras.models.load_model(keras_model_file)\n",
+        "  loaded_model = keras.models.load_model(keras_model_file)\n",
         "\n",
         "loaded_model.summary()"
       ]
diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/index.md b/tensorflow_model_optimization/g3doc/guide/pruning/index.md
index 927d4f303..8dc89e4f8 100644
--- a/tensorflow_model_optimization/g3doc/guide/pruning/index.md
+++ b/tensorflow_model_optimization/g3doc/guide/pruning/index.md
@@ -29,7 +29,7 @@ various vision and translation models.
 ### API Compatibility Matrix
 Users can apply pruning with the following APIs:
 
-*   Model building: `tf.keras` with only Sequential and Functional models
+*   Model building: `keras` with only Sequential and Functional models
 *   TensorFlow versions: TF 1.x for versions 1.14+ and 2.x.
     *   `tf.compat.v1` with a TF 2.X package and `tf.compat.v2` with a TF 1.X
         package are not supported.
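A minimal usage sketch for reference, assuming a trained `model` and arbitrary schedule constants:

```python
import tensorflow_model_optimization as tfmot

# Ramp sparsity from 0% to 50% over 1000 steps (illustrative values).
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=0, end_step=1000)

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule)
```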
diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb
index 7b92dd055..fcb5640f6 100644
--- a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb
@@ -118,6 +118,6 @@
-        "from tensorflow import keras\n",
         "import tensorflow_datasets as tfds\n",
         "import tensorflow_model_optimization as tfmot\n",
+        "import tf_keras as keras\n",
         "\n",
         "%load_ext tensorboard"
       ]
@@ -209,7 +210,7 @@
         "\n",
         "# Compile and train the dense model for 10 epochs.\n",
         "dense_model.compile(\n",
-        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "    optimizer='adam',\n",
         "    metrics=['accuracy'])\n",
         "\n",
@@ -368,7 +369,7 @@
         "]\n",
         "\n",
         "model_for_pruning.compile(\n",
-        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "    optimizer='adam',\n",
         "    metrics=['accuracy'])\n",
         "\n",
diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb
index da8ae175b..b2fcda92d 100644
--- a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb
@@ -82,7 +82,7 @@
         "\n",
         "In this tutorial, you will:\n",
         "\n",
-        "1.   Train a `tf.keras` model for MNIST from scratch.\n",
+        "1.   Train a `keras` model for MNIST from scratch.\n",
         "2.   Fine tune the model by applying the pruning API and see the accuracy.\n",
         "3.   Create 3x smaller TF and TFLite models from pruning.\n",
         "4.   Create a 10x smaller TFLite model from combining pruning and post-training quantization.\n",
@@ -123,7 +123,7 @@
         "import tensorflow as tf\n",
         "import numpy as np\n",
         "\n",
-        "from tensorflow import keras\n",
+        "import tf_keras as keras\n",
         "\n",
         "%load_ext tensorboard"
       ]
@@ -146,8 +146,7 @@
       "outputs": [],
       "source": [
         "# Load MNIST dataset\n",
-        "mnist = keras.datasets.mnist\n",
-        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
+        "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
         "\n",
         "# Normalize the input image so that each pixel value is between 0 and 1.\n",
         "train_images = train_images / 255.0\n",
@@ -165,7 +164,7 @@
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -199,7 +198,7 @@
         "print('Baseline test accuracy:', baseline_model_accuracy)\n",
         "\n",
         "_, keras_file = tempfile.mkstemp('.h5')\n",
-        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)\n",
+        "keras.models.save_model(model, keras_file, include_optimizer=False)\n",
         "print('Saved baseline model to:', keras_file)"
       ]
     },
@@ -267,7 +266,7 @@
         "\n",
         "# `prune_low_magnitude` requires a recompile.\n",
         "model_for_pruning.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model_for_pruning.summary()"
@@ -403,7 +402,7 @@
         "model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)\n",
         "\n",
         "_, pruned_keras_file = tempfile.mkstemp('.h5')\n",
-        "tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)\n",
+        "keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)\n",
         "print('Saved pruned Keras model to:', pruned_keras_file)"
       ]
     },
diff --git a/tensorflow_model_optimization/g3doc/guide/quantization/training.md b/tensorflow_model_optimization/g3doc/guide/quantization/training.md
index 90cfb54e3..462bbbb09 100644
--- a/tensorflow_model_optimization/g3doc/guide/quantization/training.md
+++ b/tensorflow_model_optimization/g3doc/guide/quantization/training.md
@@ -49,7 +49,7 @@ compatibility.
 
 Users can apply quantization with the following APIs:
 
-*   Model building: `tf.keras` with only Sequential and Functional models.
+*   Model building: `keras` with only Sequential and Functional models.
 *   TensorFlow versions: TF 2.x for tf-nightly.
     *   `tf.compat.v1` with a TF 2.X package is not supported.
 *   TensorFlow execution mode: eager execution
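A minimal usage sketch for reference, assuming a compatible trained `model`; the compile settings are illustrative:

```python
import tensorflow_model_optimization as tfmot

# Make the whole model quantization aware.
q_aware_model = tfmot.quantization.keras.quantize_model(model)

# `quantize_model` requires a recompile before (re)training.
q_aware_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
```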
diff --git a/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb
index 844894113..da383bc41 100644
--- a/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb
@@ -121,17 +121,18 @@
         "import tensorflow as tf\n",
         "import numpy as np\n",
         "import tensorflow_model_optimization as tfmot\n",
+        "import tf_keras as keras\n",
         "\n",
         "import tempfile\n",
         "\n",
         "input_shape = [20]\n",
         "x_train = np.random.randn(1, 20).astype(np.float32)\n",
-        "y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n",
+        "y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n",
         "\n",
         "def setup_model():\n",
-        "  model = tf.keras.Sequential([\n",
-        "      tf.keras.layers.Dense(20, input_shape=input_shape),\n",
-        "      tf.keras.layers.Flatten()\n",
+        "  model = keras.Sequential([\n",
+        "      keras.layers.Dense(20, input_shape=input_shape),\n",
+        "      keras.layers.Flatten()\n",
         "  ])\n",
         "  return model\n",
         "\n",
@@ -139,7 +140,7 @@
         "  model= setup_model()\n",
         "\n",
         "  model.compile(\n",
-        "      loss=tf.keras.losses.categorical_crossentropy,\n",
+        "      loss=keras.losses.categorical_crossentropy,\n",
         "      optimizer='adam',\n",
         "      metrics=['accuracy']\n",
         "  )\n",
@@ -280,13 +281,13 @@
         "# Helper function uses `quantize_annotate_layer` to annotate that only the \n",
         "# Dense layers should be quantized.\n",
         "def apply_quantization_to_dense(layer):\n",
-        "  if isinstance(layer, tf.keras.layers.Dense):\n",
+        "  if isinstance(layer, keras.layers.Dense):\n",
         "    return tfmot.quantization.keras.quantize_annotate_layer(layer)\n",
         "  return layer\n",
         "\n",
-        "# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` \n",
+        "# Use `keras.models.clone_model` to apply `apply_quantization_to_dense` \n",
         "# to the layers of the model.\n",
-        "annotated_model = tf.keras.models.clone_model(\n",
+        "annotated_model = keras.models.clone_model(\n",
         "    base_model,\n",
         "    clone_function=apply_quantization_to_dense,\n",
         ")\n",
@@ -354,10 +355,10 @@
       "source": [
         "# Use `quantize_annotate_layer` to annotate that the `Dense` layer\n",
         "# should be quantized.\n",
-        "i = tf.keras.Input(shape=(20,))\n",
-        "x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)\n",
-        "o = tf.keras.layers.Flatten()(x)\n",
-        "annotated_model = tf.keras.Model(inputs=i, outputs=o)\n",
+        "i = keras.Input(shape=(20,))\n",
+        "x = tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(10))(i)\n",
+        "o = keras.layers.Flatten()(x)\n",
+        "annotated_model = keras.Model(inputs=i, outputs=o)\n",
         "\n",
         "# Use `quantize_apply` to actually make the model quantization aware.\n",
         "quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)\n",
@@ -386,9 +387,9 @@
       "source": [
         "# Use `quantize_annotate_layer` to annotate that the `Dense` layer\n",
         "# should be quantized.\n",
-        "annotated_model = tf.keras.Sequential([\n",
-        "  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),\n",
-        "  tf.keras.layers.Flatten()\n",
+        "annotated_model = keras.Sequential([\n",
+        "  tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(20, input_shape=input_shape)),\n",
+        "  keras.layers.Flatten()\n",
         "])\n",
         "\n",
         "# Use `quantize_apply` to actually make the model quantization aware.\n",
@@ -434,7 +435,7 @@
         "\n",
         "# `quantize_scope` is needed for deserializing HDF5 models.\n",
         "with tfmot.quantization.keras.quantize_scope():\n",
-        "  loaded_model = tf.keras.models.load_model(keras_model_file)\n",
+        "  loaded_model = keras.models.load_model(keras_model_file)\n",
         "\n",
         "loaded_model.summary()"
       ]
@@ -601,12 +602,12 @@
         "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
         "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
         "\n",
-        "class CustomLayer(tf.keras.layers.Dense):\n",
+        "class CustomLayer(keras.layers.Dense):\n",
         "  pass\n",
         "\n",
-        "model = quantize_annotate_model(tf.keras.Sequential([\n",
+        "model = quantize_annotate_model(keras.Sequential([\n",
         "   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n",
-        "   tf.keras.layers.Flatten()\n",
+        "   keras.layers.Flatten()\n",
         "]))\n",
         "\n",
         "# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`\n",
@@ -680,10 +681,10 @@
       },
       "outputs": [],
       "source": [
-        "model = quantize_annotate_model(tf.keras.Sequential([\n",
+        "model = quantize_annotate_model(keras.Sequential([\n",
         "   # Pass in modified `QuantizeConfig` to modify this Dense layer.\n",
-        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
-        "   tf.keras.layers.Flatten()\n",
+        "   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
+        "   keras.layers.Flatten()\n",
         "]))\n",
         "\n",
         "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
@@ -757,10 +758,10 @@
       },
       "outputs": [],
       "source": [
-        "model = quantize_annotate_model(tf.keras.Sequential([\n",
+        "model = quantize_annotate_model(keras.Sequential([\n",
         "   # Pass in modified `QuantizeConfig` to modify this Dense layer.\n",
-        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
-        "   tf.keras.layers.Flatten()\n",
+        "   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
+        "   keras.layers.Flatten()\n",
         "]))\n",
         "\n",
         "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
@@ -816,7 +817,7 @@
         "    return {}\n",
         "\n",
         "  def __call__(self, inputs, training, weights, **kwargs):\n",
-        "    return tf.keras.backend.clip(inputs, -1.0, 1.0)\n",
+        "    return keras.backend.clip(inputs, -1.0, 1.0)\n",
         "\n",
         "  def get_config(self):\n",
         "    # Not needed. No __init__ parameters to serialize.\n",
@@ -851,10 +852,10 @@
       },
       "outputs": [],
       "source": [
-        "model = quantize_annotate_model(tf.keras.Sequential([\n",
+        "model = quantize_annotate_model(keras.Sequential([\n",
         "   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.\n",
-        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
-        "   tf.keras.layers.Flatten()\n",
+        "   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
+        "   keras.layers.Flatten()\n",
         "]))\n",
         "\n",
         "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
diff --git a/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb b/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb
index 3cf88af2a..5c131be66 100644
--- a/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb
+++ b/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb
@@ -82,7 +82,7 @@
         "\n",
         "In this tutorial, you will:\n",
         "\n",
-        "1.   Train a `tf.keras` model for MNIST from scratch.\n",
+        "1.   Train a `keras` model for MNIST from scratch.\n",
         "2.   Fine tune the model by applying the quantization aware training API, see the accuracy, and\n",
         "     export a quantization aware model.\n",
         "3.   Use the model to create an actually quantized model for the TFLite\n",
@@ -165,7 +165,7 @@
         "\n",
         "# Train the digit classification model\n",
         "model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "model.fit(\n",
@@ -216,6 +216,7 @@
       "outputs": [],
       "source": [
         "import tensorflow_model_optimization as tfmot\n",
+        "import tf_keras as keras\n",
         "\n",
         "quantize_model = tfmot.quantization.keras.quantize_model\n",
         "\n",
@@ -224,7 +225,7 @@
         "\n",
         "# `quantize_model` requires a recompile.\n",
         "q_aware_model.compile(optimizer='adam',\n",
-        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
         "              metrics=['accuracy'])\n",
         "\n",
         "q_aware_model.summary()"
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD
index a56cb264a..7a776f188 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD
@@ -1,5 +1,5 @@
-# Placeholder: load py_test
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
+# Placeholder: load py_test
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -30,6 +30,7 @@ py_strict_library(
         ":cluster_wrapper",
         ":clustering_centroids",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -49,6 +50,7 @@ py_strict_library(
         ":clusterable_layer",
         ":clustering_algorithm",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -72,6 +74,7 @@ py_strict_library(
         # six dep1,
         # tensorflow dep1,
         # python/ops:clustering_ops tensorflow dep2,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -86,6 +89,7 @@ py_strict_library(
         ":clustering_centroids",
         ":clustering_registry",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -128,6 +132,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -144,6 +149,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -161,6 +167,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -177,6 +184,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -194,6 +202,7 @@ py_test(
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -211,6 +220,7 @@ py_strict_test(
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -228,6 +238,7 @@ py_strict_test(
         ":clustering_algorithm",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -244,5 +255,6 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
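The `:compat` dependency added to each target above backs the one-line import change repeated throughout this patch. Assuming the shim simply re-exports the Keras 2 (`tf_keras`) implementation that TF-MOT targets, the swap is:

```python
# Before: resolve Keras through TensorFlow.
#   import tensorflow as tf
#   keras = tf.keras

# After: resolve Keras through the TF-MOT compat shim.
from tensorflow_model_optimization.python.core.keras.compat import keras
```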
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
index bbfee03e2..4535a9493 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
@@ -21,12 +21,14 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-k = tf.keras.backend
-CustomObjectScope = tf.keras.utils.CustomObjectScope
+
+k = keras.backend
+CustomObjectScope = keras.utils.CustomObjectScope
 CentroidInitialization = cluster_config.CentroidInitialization
-Layer = tf.keras.layers.Layer
-InputLayer = tf.keras.layers.InputLayer
+Layer = keras.layers.Layer
+InputLayer = keras.layers.InputLayer
 
 
 def cluster_scope():
@@ -42,10 +44,10 @@ def cluster_scope():
 
   ```python
   clustered_model = cluster_weights(model, **self.params)
-  tf.keras.models.save_model(clustered_model, keras_file)
+  keras.models.save_model(clustered_model, keras_file)
 
   with cluster_scope():
-    loaded_model = tf.keras.models.load_model(keras_file)
+    loaded_model = keras.models.load_model(keras_file)
   ```
   """
   return CustomObjectScope({'ClusterWeights': cluster_wrapper.ClusterWeights})
@@ -93,29 +95,27 @@ def cluster_weights(
     'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
   }
 
-  model = tf.keras.Sequential([
+  model = keras.Sequential([
       layers.Dense(10, activation='relu', input_shape=(100,)),
       cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
   ])
   ```
 
   Arguments:
-      to_cluster: A single keras layer, list of keras layers, or a
-        `tf.keras.Model` instance.
+      to_cluster: A single keras layer, list of keras layers, or a `keras.Model`
+        instance.
       number_of_clusters: the number of cluster centroids to form when
         clustering a layer/model. For example, if number_of_clusters=8 then only
         8 unique values will be used in each weight array.
       cluster_centroids_init: enum value that determines how the cluster
-        centroids will be initialized.
-        Can have following values:
-          1. RANDOM : centroids are sampled using the uniform distribution
-            between the minimum and maximum weight values in a given layer
-          2. DENSITY_BASED : density-based sampling. First, cumulative
-            distribution function is built for weights, then y-axis is evenly
-            spaced into number_of_clusters regions. After this the corresponding
-            x values are obtained and used to initialize clusters centroids.
-          3. LINEAR : cluster centroids are evenly spaced between the minimum
-            and maximum values of a given weight
+        centroids will be initialized. Can have the following values:
+          1. RANDOM: centroids are sampled using a uniform distribution
+            between the minimum and maximum weight values in a given layer.
+          2. DENSITY_BASED: a cumulative distribution function is built for
+            the weights, then the y-axis is evenly spaced into
+            number_of_clusters regions whose x values initialize the centroids.
+          3. LINEAR: cluster centroids are evenly spaced between the minimum
+            and maximum values of a given weight.
       **kwargs: Additional keyword arguments to be passed to the keras layer.
         Ignored when to_cluster is not a keras layer.
 
@@ -177,7 +177,7 @@ def _cluster_weights(to_cluster,
     'preserve_sparsity': False
   }
 
-  model = tf.keras.Sequential([
+  model = keras.Sequential([
       layers.Dense(10, activation='relu', input_shape=(100,)),
       cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
   ])
@@ -192,15 +192,15 @@ def _cluster_weights(to_cluster,
     'preserve_sparsity': True
   }
 
-  model = tf.keras.Sequential([
+  model = keras.Sequential([
       layers.Dense(10, activation='relu', input_shape=(100,)),
       cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
   ])
   ```
 
   Arguments:
-      to_cluster: A single keras layer, list of keras layers, or a
-        `tf.keras.Model` instance.
+      to_cluster: A single keras layer, list of keras layers, or a `keras.Model`
+        instance.
       number_of_clusters: the number of cluster centroids to form when
         clustering a layer/model. For example, if number_of_clusters=8 then only
         8 unique values will be used in each weight array.
@@ -235,23 +235,26 @@ def _cluster_weights(to_cluster,
         cluster_centroids_init))
 
   def _add_clustering_wrapper(layer):
-    if isinstance(layer, tf.keras.Model):
+    if isinstance(layer, keras.Model):
       # Check whether the model is a subclass.
       # NB: This check is copied from keras.py file in tensorflow.
       # There is no available public API to do this check.
       # pylint: disable=protected-access
-      if (not layer._is_graph_network and
-          not isinstance(layer, tf.keras.models.Sequential)):
+      if not layer._is_graph_network and not isinstance(
+          layer, keras.models.Sequential
+      ):
         raise ValueError('Subclassed models are not supported currently.')
 
-      return tf.keras.models.clone_model(
-          layer, input_tensors=None, clone_function=_add_clustering_wrapper)
+      return keras.models.clone_model(
+          layer, input_tensors=None, clone_function=_add_clustering_wrapper
+      )
     if isinstance(layer, cluster_wrapper.ClusterWeights):
       return layer
     if isinstance(layer, InputLayer):
       return layer.__class__.from_config(layer.get_config())
-    if isinstance(layer, tf.keras.layers.RNN) or isinstance(
-        layer, tf.keras.layers.Bidirectional):
+    if isinstance(layer, keras.layers.RNN) or isinstance(
+        layer, keras.layers.Bidirectional
+    ):
       return cluster_wrapper.ClusterWeightsRNN(
           layer,
           number_of_clusters,
@@ -259,7 +262,7 @@ def _add_clustering_wrapper(layer):
           preserve_sparsity,
           **kwargs,
       )
-    if isinstance(layer, tf.keras.layers.MultiHeadAttention):
+    if isinstance(layer, keras.layers.MultiHeadAttention):
       return cluster_wrapper.ClusterWeightsMHA(
           layer,
           number_of_clusters,
@@ -271,9 +274,10 @@ def _add_clustering_wrapper(layer):
     # Skip clustering if the Conv2D layer has too few weights for the
     # requested type of clustering.
     if isinstance(
-        layer,
-        tf.keras.layers.Conv2D) and not layer_has_enough_weights_to_cluster(
-            layer, number_of_clusters, cluster_per_channel):
+        layer, keras.layers.Conv2D
+    ) and not layer_has_enough_weights_to_cluster(
+        layer, number_of_clusters, cluster_per_channel
+    ):
       return layer
 
     return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
@@ -288,9 +292,10 @@ def _wrap_list(layers):
 
     return output
 
-  if isinstance(to_cluster, tf.keras.Model):
-    return tf.keras.models.clone_model(
-        to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
+  if isinstance(to_cluster, keras.Model):
+    return keras.models.clone_model(
+        to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper
+    )
   if isinstance(to_cluster, Layer):
     return _add_clustering_wrapper(layer=to_cluster)
   if isinstance(to_cluster, list):
@@ -306,32 +311,34 @@ def strip_clustering(model):
   Only sequential and functional models are supported for now.
 
   Arguments:
-      model: A `tf.keras.Model` instance with clustered layers.
+      model: A `keras.Model` instance with clustered layers.
 
   Returns:
     A keras model with clustering wrappers removed.
 
   Raises:
-    ValueError: if the model is not a `tf.keras.Model` instance.
+    ValueError: if the model is not a `keras.Model` instance.
     NotImplementedError: if the model is a subclass model.
 
   Usage:
 
   ```python
-  orig_model = tf.keras.Model(inputs, outputs)
+  orig_model = keras.Model(inputs, outputs)
   clustered_model = cluster_weights(orig_model)
   exported_model = strip_clustering(clustered_model)
   ```
   The exported_model and the orig_model have the same structure.
   """
-  if not isinstance(model, tf.keras.Model):
+  if not isinstance(model, keras.Model):
     raise ValueError(
-        'Expected model to be a `tf.keras.Model` instance but got: ', model)
+        'Expected model to be a `keras.Model` instance but got: ', model
+    )
 
   def _strip_clustering_wrapper(layer):
-    if isinstance(layer, tf.keras.Model):
-      return tf.keras.models.clone_model(
-          layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
+    if isinstance(layer, keras.Model):
+      return keras.models.clone_model(
+          layer, input_tensors=None, clone_function=_strip_clustering_wrapper
+      )
 
     elif isinstance(layer, cluster_wrapper.ClusterWeightsMHA):
       # Update cluster associations in order to get the latest weights
@@ -363,8 +370,9 @@ def _strip_clustering_wrapper(layer):
     return layer
 
   # Just copy the model with the right callback
-  return tf.keras.models.clone_model(
-      model, input_tensors=None, clone_function=_strip_clustering_wrapper)
+  return keras.models.clone_model(
+      model, input_tensors=None, clone_function=_strip_clustering_wrapper
+  )
 
 
 def layer_has_enough_weights_to_cluster(layer, number_of_clusters,
@@ -379,7 +387,7 @@ def layer_has_enough_weights_to_cluster(layer, number_of_clusters,
     number_of_clusters: A number of cluster centroids to form clusters.
     cluster_per_channel: An optional boolean value.
   """
-  if not isinstance(layer, tf.keras.layers.Conv2D):
+  if not isinstance(layer, keras.layers.Conv2D):
     raise ValueError(f'Input layer should be a Conv2D layer: {layer.name} given.')
 
   if not layer.trainable_weights:
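Taken together, the docstrings above describe a cluster / save / load / strip workflow. A minimal end-to-end sketch using the internal module paths shown in this file; the `/tmp` path and the tiny Dense model are illustrative only.

```python
import tf_keras as keras

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init':
        cluster_config.CentroidInitialization.DENSITY_BASED,
}

model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,))])
clustered_model = cluster.cluster_weights(model, **clustering_params)

# Saving and reloading needs the wrapper registered, per the
# `cluster_scope` docstring above.
keras.models.save_model(clustered_model, '/tmp/clustered.h5')
with cluster.cluster_scope():
  loaded = keras.models.load_model('/tmp/clustered.h5')

# Strip the wrappers before export; the result mirrors the original model.
exported_model = cluster.strip_clustering(loaded)
```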
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
index 4d6abf7f4..18245b21a 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
@@ -25,8 +25,9 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
 from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 CentroidInitialization = cluster_config.CentroidInitialization
 
 
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
index 09b084c93..d01960304 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
@@ -24,8 +24,9 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 layers = keras.layers
 test = tf.test
 
@@ -355,7 +356,7 @@ def testStripClusteringSequentialModelWithRegulariser(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     original_model = keras.Sequential([
         layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
+        layers.Dense(5, kernel_regularizer=keras.regularizers.L1(0.01)),
     ])
 
     def clusters_check(stripped_model):
@@ -385,8 +386,7 @@ def clusters_check(stripped_model):
 
   def testEndToEndDeepLayer(self):
     """Test End to End clustering for the model with deep layer."""
-    internal_model = tf.keras.Sequential(
-        [tf.keras.layers.Dense(5, input_shape=(5,))])
+    internal_model = keras.Sequential([keras.layers.Dense(5, input_shape=(5,))])
     original_model = keras.Sequential([
         internal_model,
         layers.Dense(5),
@@ -411,8 +411,7 @@ def clusters_check(stripped_model):
 
   def testEndToEndDeepLayer2(self):
     """Test End to End clustering for the model with 2 deep layers."""
-    internal_model = tf.keras.Sequential(
-        [tf.keras.layers.Dense(5, input_shape=(5,))])
+    internal_model = keras.Sequential([keras.layers.Dense(5, input_shape=(5,))])
     intermediate_model = keras.Sequential([
         internal_model,
         layers.Dense(5),
@@ -626,9 +625,12 @@ def testClusterStackedRNNCells(self):
     model.add(
         keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
     model.add(
-        tf.keras.layers.RNN(
-            tf.keras.layers.StackedRNNCells(
-                [keras.layers.SimpleRNNCell(16) for _ in range(2)])))
+        keras.layers.RNN(
+            keras.layers.StackedRNNCells(
+                [keras.layers.SimpleRNNCell(16) for _ in range(2)]
+            )
+        )
+    )
     model.add(keras.layers.Dense(1))
     model.add(keras.layers.Activation("sigmoid"))
 
@@ -664,11 +666,12 @@ def setUp(self):
 
   def _get_model(self):
     """Returns functional model with MHA layer."""
-    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
-    x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(
-        query=inp, value=inp)
-    out = tf.keras.layers.Flatten()(x)
-    model = tf.keras.Model(inputs=inp, outputs=out)
+    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
+    x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(
+        query=inp, value=inp
+    )
+    out = keras.layers.Flatten()(x)
+    model = keras.Model(inputs=inp, outputs=out)
     return model
 
   def testMHA(self):
@@ -677,9 +680,10 @@ def testMHA(self):
     clustered_model = cluster.cluster_weights(model, **self.params_clustering)
 
     clustered_model.compile(
-        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
-        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
+    )
     clustered_model.fit(
         self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
 
@@ -711,14 +715,14 @@ def setUp(self):
 
   def _get_model(self):
     """Returns functional model with Conv2D layer."""
-    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
-    x = tf.keras.layers.Reshape((32, 32, 1))(inp)
-    x = tf.keras.layers.Conv2D(
-        filters=self.num_channels, kernel_size=(3, 3),
-        activation="relu")(x)
-    x = tf.keras.layers.MaxPool2D(2, 2)(x)
-    out = tf.keras.layers.Flatten()(x)
-    model = tf.keras.Model(inputs=inp, outputs=out)
+    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
+    x = keras.layers.Reshape((32, 32, 1))(inp)
+    x = keras.layers.Conv2D(
+        filters=self.num_channels, kernel_size=(3, 3), activation="relu"
+    )(x)
+    x = keras.layers.MaxPool2D(2, 2)(x)
+    out = keras.layers.Flatten()(x)
+    model = keras.Model(inputs=inp, outputs=out)
     return model
 
   def testPerChannel(self):
@@ -727,9 +731,10 @@ def testPerChannel(self):
     clustered_model = cluster.cluster_weights(model, **self.params_clustering)
 
     clustered_model.compile(
-        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
-        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
+    )
     clustered_model.fit(
         self.x_train, self.y_train, epochs=2, batch_size=100, verbose=1)
 
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
index 03fa4bbd8..3cea083d0 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
@@ -27,8 +27,9 @@
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
 from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 errors_impl = tf.errors
 layers = keras.layers
 test = tf.test
@@ -345,7 +346,7 @@ def testStripClusteringSequentialModelWithKernelRegularizer(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10,)),
-        layers.Dense(10, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
+        layers.Dense(10, kernel_regularizer=keras.regularizers.L1(0.01)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
@@ -359,7 +360,7 @@ def testStripClusteringSequentialModelWithBiasRegularizer(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10,)),
-        layers.Dense(10, bias_regularizer=tf.keras.regularizers.L1(0.01)),
+        layers.Dense(10, bias_regularizer=keras.regularizers.L1(0.01)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
@@ -373,7 +374,7 @@ def testStripClusteringSequentialModelWithActivityRegularizer(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10,)),
-        layers.Dense(10, activity_regularizer=tf.keras.regularizers.L1(0.01)),
+        layers.Dense(10, activity_regularizer=keras.regularizers.L1(0.01)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
@@ -387,7 +388,7 @@ def testStripClusteringSequentialModelWithKernelConstraint(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10,)),
-        layers.Dense(10, kernel_constraint=tf.keras.constraints.max_norm(2.)),
+        layers.Dense(10, kernel_constraint=keras.constraints.max_norm(2.0)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
@@ -401,7 +402,7 @@ def testStripClusteringSequentialModelWithBiasConstraint(self):
     """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10,)),
-        layers.Dense(10, bias_constraint=tf.keras.constraints.max_norm(2.)),
+        layers.Dense(10, bias_constraint=keras.constraints.max_norm(2.0)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
index 4a0142df7..82c63460f 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
@@ -22,9 +22,10 @@
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
 attrgetter = operator.attrgetter  # pylint: disable=invalid-name
-keras = tf.keras
 k = keras.backend
 Layer = keras.layers.Layer
 Wrapper = keras.layers.Wrapper
@@ -106,8 +107,8 @@ def __init__(self,
     # Whether to cluster Conv2D kernels per-channel.
     # If the layer isn't a Conv2D, this isn't applicable.
     self.cluster_per_channel = (
-        cluster_per_channel if isinstance(layer, tf.keras.layers.Conv2D)
-        else False)
+        cluster_per_channel if isinstance(layer, keras.layers.Conv2D) else False
+    )
 
     # Number of channels in a Conv2D layer, to be used in the case of
     # per-channel clustering.
@@ -226,15 +227,16 @@ def build(self, input_shape):
           shape=(cluster_centroids.shape),
           dtype=weight.dtype,
           trainable=True,
-          initializer=tf.keras.initializers.Constant(value=cluster_centroids))
+          initializer=keras.initializers.Constant(value=cluster_centroids),
+      )
 
       # Init the weight clustering algorithm
-      if isinstance(self.layer, tf.keras.layers.RNN):
-        if isinstance(self.layer.cell, tf.keras.layers.StackedRNNCells):
+      if isinstance(self.layer, keras.layers.RNN):
+        if isinstance(self.layer.cell, keras.layers.StackedRNNCells):
           weight_name_no_index = weight_name.split('/')[0]
         else:
           weight_name_no_index = weight_name
-      elif isinstance(self.layer, tf.keras.layers.Bidirectional):
+      elif isinstance(self.layer, keras.layers.Bidirectional):
         weight_name_no_index = weight_name.split('/')[0]
       else:
         weight_name_no_index = weight_name
@@ -258,7 +260,8 @@ def build(self, input_shape):
           trainable=False,
           synchronization=tf.VariableSynchronization.ON_READ,
           aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
-          initializer=tf.keras.initializers.Constant(value=pulling_indices))
+          initializer=keras.initializers.Constant(value=pulling_indices),
+      )
 
       if self.preserve_sparsity:
         # Init the sparsity mask
@@ -360,8 +363,9 @@ def from_config(cls, config, custom_objects=None):
     config['cluster_gradient_aggregation'] = cluster_gradient_aggregation
     config['cluster_per_channel'] = cluster_per_channel
 
-    layer = tf.keras.layers.deserialize(
-        config.pop('layer'), custom_objects=custom_objects)
+    layer = keras.layers.deserialize(
+        config.pop('layer'), custom_objects=custom_objects
+    )
     config['layer'] = layer
 
     return cls(**config)
@@ -417,11 +421,11 @@ def get_return_layer_cell(self, index):
   def get_weight_from_layer(self, weight_name):
     weight_name_no_index, i = self.get_weight_name_without_index(weight_name)
     if hasattr(self.layer, 'cell'):
-      if isinstance(self.layer.cell, tf.keras.layers.StackedRNNCells):
+      if isinstance(self.layer.cell, keras.layers.StackedRNNCells):
         return getattr(self.layer.cell.cells[i], weight_name_no_index)
       else:
         return getattr(self.layer.cell, weight_name_no_index)
-    elif isinstance(self.layer, tf.keras.layers.Bidirectional):
+    elif isinstance(self.layer, keras.layers.Bidirectional):
       if i < 0 or i > 1:
         raise ValueError(
             'Unsupported number of cells in the layer to get weights from.')
@@ -433,13 +437,13 @@ def get_weight_from_layer(self, weight_name):
   def set_weight_to_layer(self, weight_name, new_weight):
     weight_name_no_index, i = self.get_weight_name_without_index(weight_name)
     if hasattr(self.layer, 'cell'):
-      if isinstance(self.layer.cell, tf.keras.layers.StackedRNNCells):
+      if isinstance(self.layer.cell, keras.layers.StackedRNNCells):
         return setattr(self.layer.cell.cells[i],
                        weight_name_no_index,
                        new_weight)
       else:
         return setattr(self.layer.cell, weight_name_no_index, new_weight)
-    elif isinstance(self.layer, tf.keras.layers.Bidirectional):
+    elif isinstance(self.layer, keras.layers.Bidirectional):
       if i < 0 or i > 1:
         raise ValueError(
             'Unsupported number of cells in the layer to set weights for.')
@@ -481,4 +485,3 @@ def strip_clustering(self):
       setattr(self.layer, weight_name, original_weight)
 
     return self.layer
-
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
index ad9435063..038580929 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
@@ -25,8 +25,9 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 errors_impl = tf.errors
 layers = keras.layers
 test = tf.test
@@ -146,7 +147,7 @@ def testValuesAreClusteredAfterStripping(self,
                                            number_of_clusters,
                                            cluster_centroids_init):
     """Verifies that, for any number of clusters and any centroid initialization  method, the number of unique weight values after stripping is always less or equal to number_of_clusters."""
-    original_model = tf.keras.Sequential([
+    original_model = keras.Sequential([
         layers.Dense(32, input_shape=(10,)),
     ])
     self.assertGreater(
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py
index d4a3c5945..7bb825bae 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py
@@ -17,19 +17,20 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
-class ClusteringSummaries(tf.keras.callbacks.TensorBoard):
+class ClusteringSummaries(keras.callbacks.TensorBoard):
   """Helper class to create tensorboard summaries for the clustering progress.
 
-    This class is derived from tf.keras.callbacks.TensorBoard and just adds
-    functionality to write histograms with batch-wise frequency.
+  This class is derived from keras.callbacks.TensorBoard and just adds
+  functionality to write histograms with batch-wise frequency.
 
-    Arguments:
-        log_dir: The path to the directory where the log files are saved
-        cluster_update_freq: determines the frequency of updates of the
-          clustering histograms. Same behaviour as parameter update_freq of the
-          base class, i.e. it accepts `'batch'`, `'epoch'` or integer.
+  Arguments:
+      log_dir: The path to the directory where the log files are saved.
+      cluster_update_freq: determines the frequency of updates of the clustering
+        histograms. Same behaviour as the update_freq parameter of the base
+        class, i.e. it accepts `'batch'`, `'epoch'`, or an integer.
   """
 
   def __init__(self, log_dir='logs', cluster_update_freq='epoch', **kwargs):
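For reference, a usage sketch of the callback this docstring describes; `clustered_model`, `x_train`, and `y_train` are assumed to come from surrounding code, and the log path is illustrative.

```python
from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks

summaries = clustering_callbacks.ClusteringSummaries(
    log_dir='/tmp/clustering_logs', cluster_update_freq='epoch')

# Histograms of the clustering progress are written alongside the usual
# TensorBoard summaries during training.
clustered_model.fit(x_train, y_train, epochs=3, callbacks=[summaries])
```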
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py
index 80df9ffc8..b71f86f79 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py
@@ -19,8 +19,10 @@
 import tensorflow as tf
 from tensorflow.python.ops import clustering_ops
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-k = tf.keras.backend
+
+k = keras.backend
 CentroidInitialization = cluster_config.CentroidInitialization
 
 
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py
index 73389aa3c..80d3d338a 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py
@@ -20,8 +20,10 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-K = tf.keras.backend
+
+K = keras.backend
 errors_impl = tf.errors
 
 CentroidInitialization = cluster_config.CentroidInitialization
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
index 0421ca607..8d55cd387 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
@@ -18,8 +18,10 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-layers = tf.keras.layers
+
+layers = keras.layers
 ClusteringAlgorithm = clustering_algorithm.ClusteringAlgorithm
 ClusteringAlgorithmPerChannel = clustering_algorithm.ClusteringAlgorithmPerChannel
 
@@ -42,7 +44,7 @@ def get_clustering_impl(cls, layer, weight_name, cluster_per_channel=False):
 
     # Per-channel clustering is only applied if the layer is a Conv2D,
     # and ignored otherwise.
-    if cluster_per_channel and isinstance(layer, tf.keras.layers.Conv2D):
+    if cluster_per_channel and isinstance(layer, keras.layers.Conv2D):
       return ClusteringAlgorithmPerChannel
 
     # Clusterable layer could provide own implementation of get_pulling_indices
@@ -90,6 +92,13 @@ class ClusteringRegistry(object):
       tf.compat.v2.keras.layers.SimpleRNNCell,
       tf.compat.v1.keras.layers.StackedRNNCells,
       tf.compat.v2.keras.layers.StackedRNNCells,
+      tf.compat.v1.keras.layers.Bidirectional,
+      tf.compat.v2.keras.layers.Bidirectional,
+      layers.GRUCell,
+      layers.LSTMCell,
+      layers.SimpleRNNCell,
+      layers.StackedRNNCells,
+      layers.Bidirectional,
   })
 
   _SUPPORTED_RNN_LAYERS = frozenset([
@@ -101,7 +110,7 @@ class ClusteringRegistry(object):
   ])
 
   _SUPPORTED_MHA_LAYERS = {
-      tf.keras.layers.MultiHeadAttention,
+      keras.layers.MultiHeadAttention,
   }
 
   @classmethod
@@ -141,9 +150,9 @@ def supports(cls, layer):
   def _get_rnn_cells(rnn_layer):  # pylint: disable=no-self-argument
     """Get rnn cells from layer."""
 
-    if isinstance(rnn_layer, tf.keras.layers.Bidirectional):
+    if isinstance(rnn_layer, keras.layers.Bidirectional):
       return [rnn_layer.forward_layer.cell, rnn_layer.backward_layer.cell]
-    if isinstance(rnn_layer.cell, tf.keras.layers.StackedRNNCells):
+    if isinstance(rnn_layer.cell, keras.layers.StackedRNNCells):
       return rnn_layer.cell.cells
     # The case when RNN contains multiple cells
     if isinstance(rnn_layer.cell, (list, tuple)):
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
index fb470bb83..543cb9ae1 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
@@ -20,8 +20,9 @@
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
 from tensorflow_model_optimization.python.core.clustering.keras.cluster_config import GradientAggregation
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 k = keras.backend
 layers = keras.layers
 
@@ -525,7 +526,7 @@ def testMakeClusterableWorksOnKerasRNNLayerWithRNNCellsParams(self):
     """A built-in RNN layer with built-in RNN cells is clusterable."""
     cell1 = layers.LSTMCell(10)
     cell2 = layers.GRUCell(5)
-    cell_list = tf.keras.layers.StackedRNNCells([cell1, cell2])
+    cell_list = keras.layers.StackedRNNCells([cell1, cell2])
 
     layer = layers.RNN(cell_list)
     with self.assertRaises(AttributeError):
@@ -553,7 +554,7 @@ def testMakeClusterableWorksOnKerasBidirectionalLayerWithLSTM(self):
     Verifies that make_clusterable() works as expected on a Bidirectional
     wrapper with a LSTM layer
     """
-    layer = tf.keras.layers.Bidirectional(layers.LSTM(10))
+    layer = keras.layers.Bidirectional(layers.LSTM(10))
     with self.assertRaises(AttributeError):
       layer.get_clusterable_weights()
 
@@ -593,7 +594,7 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
 
   def testSupportsMultiHeadAttentionLayer(self):
     """Verifies that ClusterRegistry supports a MultiHeadAttention layer."""
-    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
+    layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
     self.assertTrue(ClusterRegistry.supports(layer))
     ClusterRegistry.make_clusterable(layer)
 
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py
index 982afbac4..970c9d8da 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py
@@ -88,8 +88,8 @@ def cluster_weights(
   ```
 
   Arguments:
-      to_cluster: A single keras layer, list of keras layers, or a
-        `tf.keras.Model` instance.
+      to_cluster: A single keras layer, list of keras layers, or a `keras.Model`
+        instance.
       number_of_clusters: the number of cluster centroids to form when
         clustering a layer/model. For example, if number_of_clusters=8 then only
         8 unique values will be used in each weight array.
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py
index 5685c8b90..afbac881c 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py
@@ -20,10 +20,10 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-tf.random.set_seed(42)
 
-keras = tf.keras
+tf.random.set_seed(42)
 
 EPOCHS = 7
 EPOCHS_FINE_TUNING = 4
@@ -102,36 +102,36 @@ def get_clusterable_algorithm(self, weight_name):
 
 def _build_model():
   """Builds model with MyDenseLayer."""
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      filters=12, kernel_size=(3, 3), activation='relu', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      filters=12, kernel_size=(3, 3), activation='relu', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
   output = MyDenseLayer(units=10)(x)
 
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _build_model_2():
   """Builds model with MyClusterableLayer layer."""
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      filters=12, kernel_size=(3, 3), activation='relu', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      filters=12, kernel_size=(3, 3), activation='relu', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
   output = MyClusterableLayer(units=10)(x)
 
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -141,7 +141,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -166,12 +166,13 @@ def _cluster_model(model, number_of_clusters):
 
   # Use smaller learning rate for fine-tuning
   # clustered model
-  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
+  opt = keras.optimizers.Adam(learning_rate=1e-5)
 
   clustered_model.compile(
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer=opt,
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   # Fine-tune clustered model
   clustered_model.fit(x_train, y_train, epochs=EPOCHS_FINE_TUNING)
@@ -179,9 +180,10 @@ def _cluster_model(model, number_of_clusters):
   stripped_model = cluster.strip_clustering(clustered_model)
 
   stripped_model.compile(
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer=opt,
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   return stripped_model
 
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
index ca6b11adc..caa30ce18 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
@@ -19,10 +19,10 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-tf.random.set_seed(42)
 
-keras = tf.keras
+tf.random.set_seed(42)
 
 EPOCHS = 7
 EPOCHS_FINE_TUNING = 4
@@ -32,21 +32,24 @@
 
 def _build_model():
   """Builds a simple CNN model."""
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      filters=NUMBER_OF_CHANNELS, kernel_size=(3, 3),
-      activation='relu', name='conv1')(x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  output = tf.keras.layers.Dense(units=10)(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      filters=NUMBER_OF_CHANNELS,
+      kernel_size=(3, 3),
+      activation='relu',
+      name='conv1',
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  output = keras.layers.Dense(units=10)(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -56,7 +59,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -89,21 +92,23 @@ def _cluster_model(model,
 
   # Use smaller learning rate for fine-tuning
   # clustered model
-  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
+  opt = keras.optimizers.Adam(learning_rate=1e-5)
 
   clustered_model.compile(
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer=opt,
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   # Fine-tune clustered model
   clustered_model.fit(x_train, y_train, epochs=EPOCHS_FINE_TUNING)
 
   stripped_model = cluster.strip_clustering(clustered_model)
   stripped_model.compile(
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer=opt,
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   return stripped_model
 
@@ -174,8 +179,9 @@ def testMnist(self, preserve_sparsity, cluster_per_channel):
     for i in layer_indices:
       nr_of_unique_weights = _get_number_of_unique_weights(
           clustered_model, i, 'kernel')
-      if (cluster_per_channel
-          and isinstance(clustered_model.layers[i], tf.keras.layers.Conv2D)):
+      if cluster_per_channel and isinstance(
+          clustered_model.layers[i], keras.layers.Conv2D
+      ):
         self.assertLessEqual(nr_of_unique_weights,
                              NUMBER_OF_CLUSTERS * NUMBER_OF_CHANNELS)
       else:
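The uniqueness check driving the assertion above can be approximated as follows; `_get_number_of_unique_weights` is a test helper not shown in this diff, so this is an assumed equivalent that treats a layer's first weight tensor as its kernel.

```python
import numpy as np

def count_unique_kernel_values(model, layer_index):
  # Assumes the kernel is the layer's first weight tensor.
  kernel = model.layers[layer_index].get_weights()[0]
  return len(np.unique(kernel))

# For a per-channel-clustered Conv2D, expect at most
# NUMBER_OF_CLUSTERS * NUMBER_OF_CHANNELS unique kernel values;
# otherwise at most NUMBER_OF_CLUSTERS.
```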
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/BUILD b/tensorflow_model_optimization/python/core/common/keras/compression/BUILD
index 1027551ee..1d9f5d824 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/BUILD
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/BUILD
@@ -1,5 +1,5 @@
-load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "pytype_strict_library")
+load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 
 licenses(["notice"])
 
@@ -11,6 +11,7 @@ pytype_strict_library(
     deps = [
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression/internal:optimize",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py
index f70ca9f38..370711671 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py
@@ -14,12 +14,13 @@
 # ==============================================================================
 """Public APIs for algorithm developer using weight compression API."""
 import abc
-from typing import List, Any
 import dataclasses
+from typing import Any, List
 
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression.internal import optimize
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 @dataclasses.dataclass
@@ -41,12 +42,13 @@ class WeightCompressor(metaclass=abc.ABCMeta):
 
   # TODO(tfmot): Consider separating this from the algorithm API to support custom layers.
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[tf.Variable]:
+      self, original_layer: keras.layers.Layer
+  ) -> List[tf.Variable]:
     """Define compressible weights for each layer.
 
     Args:
-       original_layer: tf.keras.layers.Layer representing a layer from the
-       original model.
+       original_layer: keras.layers.Layer representing a layer from the original
+         model.
 
     Returns:
        List of compressible weights for the given layer.
@@ -175,12 +177,12 @@ def decompress_weights(
 
 
 def create_layer_for_training(
-    layer: tf.keras.layers.Layer,
-    algorithm: WeightCompressor) -> tf.keras.layers.Layer:
+    layer: keras.layers.Layer, algorithm: WeightCompressor
+) -> keras.layers.Layer:
   return optimize.create_layer_for_training(layer, algorithm)
 
 
 def create_layer_for_inference(
-    layer_for_training: tf.keras.layers.Layer,
-    algorithm: WeightCompressor) -> tf.keras.layers.Layer:
+    layer_for_training: keras.layers.Layer, algorithm: WeightCompressor
+) -> keras.layers.Layer:
   return optimize.create_layer_for_inference(layer_for_training, algorithm)
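A minimal sketch of subclassing the `WeightCompressor` API whose signatures are updated above; `MeanBias` is hypothetical, and the remaining hooks (`project_training_weights`, `decompress_weights`) are elided since `bias_only.py` below shows them in full.

```python
from typing import List

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
from tensorflow_model_optimization.python.core.keras.compat import keras


class MeanBias(algorithm.WeightCompressor):
  """Hypothetical compressor that keeps only the mean of each bias."""

  def get_compressible_weights(
      self, original_layer: keras.layers.Layer) -> List[tf.Variable]:
    if isinstance(original_layer, keras.layers.Dense):
      return [original_layer.bias]
    return []

  def init_training_weights(self, pretrained_weight: tf.Tensor):
    bias_mean = tf.reduce_mean(pretrained_weight)
    self.add_training_weight(
        name='bias_mean',
        shape=bias_mean.shape,
        dtype=bias_mean.dtype,
        initializer=keras.initializers.Constant(bias_mean))

  # `project_training_weights` / `decompress_weights` omitted; see
  # bias_only.py below for complete implementations.
```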
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
index d44b93bbc..39772ca97 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
@@ -1,5 +1,5 @@
-load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "pytype_strict_library")
+load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 
 package(default_visibility = ["//visibility:private"])
 
@@ -13,6 +13,7 @@ pytype_strict_library(
         # tensorflow dep1,
         # tensorflow_compression dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -31,6 +32,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -41,6 +43,7 @@ pytype_strict_library(
     deps = [
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -54,6 +57,7 @@ py_strict_test(
         ":same_training_and_inference",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
     ],
 )
@@ -65,6 +69,7 @@ pytype_strict_library(
     deps = [
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -78,6 +83,7 @@ py_strict_test(
         ":different_training_and_inference",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
     ],
 )
@@ -89,6 +95,7 @@ pytype_strict_library(
     deps = [
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -102,6 +109,7 @@ py_strict_test(
         ":bias_only",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
     ],
 )
@@ -116,6 +124,7 @@ py_strict_library(
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_centroids",
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -129,6 +138,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -139,6 +149,7 @@ pytype_strict_library(
     deps = [
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -152,6 +163,7 @@ py_strict_test(
         ":periodical_update_and_scheduling",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
     ],
 )
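
A minimal usage sketch (illustrative, with a throwaway `Dense` layer) of the import pattern the new `//tensorflow_model_optimization/python/core/keras:compat` deps enable across these targets:

    # The compat shim resolves to the Keras 2 implementation: `tf_keras` when
    # TF >= 2.16 ships Keras 3 as `tf.keras`, and `tf.keras` itself otherwise.
    from tensorflow_model_optimization.python.core.keras.compat import keras

    layer = keras.layers.Dense(10)  # a Keras 2 layer either way
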
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only.py
index c5e0929ba..792966fc3 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 # TODO(tfmot): This algorithm is showcase for bias only compression. if we find
@@ -41,12 +42,14 @@ def init_training_weights(
         name='bias_mean',
         shape=bias_mean.shape,
         dtype=bias_mean.dtype,
-        initializer=tf.keras.initializers.Constant(bias_mean))
+        initializer=keras.initializers.Constant(bias_mean),
+    )
     self.add_training_weight(
         name='bias_shape',
         shape=bias_shape.shape,
         dtype=bias_shape.dtype,
-        initializer=tf.keras.initializers.Constant(bias_shape))
+        initializer=keras.initializers.Constant(bias_shape),
+    )
 
   def decompress_weights(
       self, bias_mean: tf.Tensor, bias_shape: tf.Tensor) -> tf.Tensor:
@@ -57,20 +60,25 @@ def project_training_weights(
     return self.decompress_weights(bias_mean, bias_shape)
 
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    if isinstance(original_layer, tf.keras.layers.Conv2D) or \
-       isinstance(original_layer, tf.keras.layers.Dense):
+      self, original_layer: keras.layers.Layer
+  ) -> List[str]:
+    if isinstance(original_layer, keras.layers.Conv2D) or isinstance(
+        original_layer, keras.layers.Dense
+    ):
       return [original_layer.bias]
     return []
 
-  def compress_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, to_optimize: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model."""
     # pylint: disable=protected-access
-    if not isinstance(to_optimize, tf.keras.Sequential) \
-        and not to_optimize._is_graph_network:
+    if (
+        not isinstance(to_optimize, keras.Sequential)
+        and not to_optimize._is_graph_network
+    ):
       raise ValueError(
-          '`compress_model` can only either be a tf.keras Sequential or '
-          'Functional model.')
+          '`compress_model` can only be used on a keras Sequential or '
+          'Functional model.'
+      )
     # pylint: enable=protected-access
 
     def _optimize_layer(layer):
@@ -82,5 +90,4 @@ def _optimize_layer(layer):
 
       return algorithm.create_layer_for_training(layer, algorithm=self)
 
-    return tf.keras.models.clone_model(
-        to_optimize, clone_function=_optimize_layer)
+    return keras.models.clone_model(to_optimize, clone_function=_optimize_layer)
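
A usage sketch for the bias-only algorithm above. `BiasOnly` is a stand-in name for the module's WeightCompressor subclass (its definition lies outside these hunks), so treat the constructor as an assumption; the `get_compressible_weights` contract itself is shown verbatim in the hunk:

    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import bias_only
    from tensorflow_model_optimization.python.core.keras.compat import keras

    algo = bias_only.BiasOnly()  # hypothetical class name

    dense = keras.layers.Dense(3)
    dense.build((None, 2))
    print(algo.get_compressible_weights(dense))                   # [dense.bias]
    print(algo.get_compressible_weights(keras.layers.Flatten()))  # []
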
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py
index 040165882..b520d43e9 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py
@@ -20,30 +20,31 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import bias_only
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
 
 
 def _build_model():
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      20, 5, activation='relu', padding='valid', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, activation='relu', padding='valid', name='conv2')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      20, 5, activation='relu', padding='valid', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, activation='relu', padding='valid', name='conv2'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -53,7 +54,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -118,7 +119,7 @@ def testBiasOnly_HasReasonableAccuracy_TF(self):
 
     _, (x_test, y_test) = _get_dataset()
 
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
     compressed_model.compile(
         optimizer='adam', loss=loss_fn, metrics=['accuracy'])
@@ -156,9 +157,9 @@ def testBiasOnly_BreaksDownLayerWeights(self):
 
   # TODO(tfmot): can simplify to single layer test.
   def testBiasOnly_PreservesPretrainedWeights(self):
-    i = tf.keras.layers.Input(shape=(2), name='input')
-    output = tf.keras.layers.Dense(3, name='fc1')(i)
-    model = tf.keras.Model(inputs=[i], outputs=[output])
+    i = keras.layers.Input(shape=(2), name='input')
+    output = keras.layers.Dense(3, name='fc1')(i)
+    model = keras.Model(inputs=[i], outputs=[output])
 
     dense_layer_weights = model.layers[1].get_weights()
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py
index a5f5e1390..706b64274 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class SVD(algorithm.WeightCompressor):
@@ -35,7 +36,8 @@ def init_training_weights(
         name='w',
         shape=pretrained_weight.shape,
         dtype=pretrained_weight.dtype,
-        initializer=tf.keras.initializers.Constant(pretrained_weight))
+        initializer=keras.initializers.Constant(pretrained_weight),
+    )
 
   def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
     return tf.matmul(u, sv)
@@ -66,22 +68,27 @@ def project_training_weights(self, weight: tf.Tensor) -> tf.Tensor:
     return weight
 
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    if isinstance(original_layer, tf.keras.layers.Conv2D) or \
-       isinstance(original_layer, tf.keras.layers.Dense):
+      self, original_layer: keras.layers.Layer
+  ) -> List[str]:
+    if isinstance(original_layer, keras.layers.Conv2D) or isinstance(
+        original_layer, keras.layers.Dense
+    ):
       return [original_layer.kernel]
     return []
 
   # TODO(tfmot): consider if we can simplify `create_model_for_training` and
   # `create_model_for_inference` into a single API for algorithm developers.
-  def compress_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, to_optimize: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model."""
     # pylint: disable=protected-access
-    if not isinstance(to_optimize, tf.keras.Sequential) \
-        and not to_optimize._is_graph_network:
+    if (
+        not isinstance(to_optimize, keras.Sequential)
+        and not to_optimize._is_graph_network
+    ):
       raise ValueError(
-          '`compress_model` can only either be a tf.keras Sequential or '
-          'Functional model.')
+          '`compress_model` can only be used on a keras Sequential or '
+          'Functional model.'
+      )
     # pylint: enable=protected-access
 
     def _create_layer_for_training(layer):
@@ -96,8 +103,10 @@ def _create_layer_for_training(layer):
     def _create_layer_for_inference(layer):
       return algorithm.create_layer_for_inference(layer, algorithm=self)
 
-    intermediate_model = tf.keras.models.clone_model(
-        to_optimize, clone_function=_create_layer_for_training)
+    intermediate_model = keras.models.clone_model(
+        to_optimize, clone_function=_create_layer_for_training
+    )
 
-    return tf.keras.models.clone_model(
-        intermediate_model, clone_function=_create_layer_for_inference)
+    return keras.models.clone_model(
+        intermediate_model, clone_function=_create_layer_for_inference
+    )
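
The `compress_model` above runs two `clone_model` passes: the first wraps compressible layers for training, the second converts the trained wrappers into their inference form. A call-sequence sketch, where the `SVD` constructor argument is an assumption (it is defined outside these hunks):

    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import (
        different_training_and_inference as svd,
    )

    algo = svd.SVD(rank=32)  # `rank` assumed; see the module for the real signature
    # `trained_model` is any built Sequential/Functional keras model.
    inference_model = algo.compress_model(trained_model)
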
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py
index 60d0ae490..d375ad813 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py
@@ -19,31 +19,32 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import different_training_and_inference as svd
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
 
 
 # TODO(tfmot): dedup.
 def _build_model():
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      20, 5, activation='relu', padding='valid', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, activation='relu', padding='valid', name='conv2')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      20, 5, activation='relu', padding='valid', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, activation='relu', padding='valid', name='conv2'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -53,7 +54,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -126,7 +127,7 @@ def testSVD_HasReasonableAccuracy_TF(self):
 
     _, (x_test, y_test) = _get_dataset()
 
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
     model_for_inference.compile(
         optimizer='adam', loss=loss_fn, metrics=['accuracy'])
@@ -179,9 +180,9 @@ def testSVD_BreaksDownLayerWeights(self):
 
   # TODO(tfmot): can simplify to single layer test.
   def testSVD_PreservesPretrainedWeights(self):
-    i = tf.keras.layers.Input(shape=(2), name='input')
-    output = tf.keras.layers.Dense(3, name='fc1')(i)
-    model = tf.keras.Model(inputs=[i], outputs=[output])
+    i = keras.layers.Input(shape=(2), name='input')
+    output = keras.layers.Dense(3, name='fc1')(i)
+    model = keras.Model(inputs=[i], outputs=[output])
 
     dense_layer_weights = model.layers[1].get_weights()
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
index c21f952bb..09bae47c0 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
@@ -31,6 +31,7 @@
 import tensorflow as tf
 import tensorflow_compression as tfc
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 @tf.custom_gradient
@@ -76,9 +77,9 @@ class EPRBase(algorithm.WeightCompressor):
   """Defines how to apply the EPR algorithm."""
 
   _compressible_classes = (
-      tf.keras.layers.Dense,
-      tf.keras.layers.Conv1D,
-      tf.keras.layers.Conv2D,
+      keras.layers.Dense,
+      keras.layers.Conv1D,
+      keras.layers.Conv2D,
   )
 
   def __init__(self, regularization_weight: float):
@@ -108,7 +109,8 @@ def _init_training_weights_reparam(
           name=weight_name,
           shape=shape,
           dtype=dtype,
-          initializer=tf.keras.initializers.Constant(pretrained_weight))
+          initializer=keras.initializers.Constant(pretrained_weight),
+      )
       prior_shape = tf.TensorShape(())
     elif 3 <= shape.rank <= 4:
       # Convolution kernel.
@@ -125,7 +127,8 @@ def _init_training_weights_reparam(
           name="kernel_rdft",
           shape=kernel_rdft.shape,
           dtype=kernel_rdft.dtype,
-          initializer=tf.keras.initializers.Constant(kernel_rdft))
+          initializer=keras.initializers.Constant(kernel_rdft),
+      )
       self.add_training_weight(
           name="kernel_shape",
           shape=kernel_shape.shape,
@@ -133,7 +136,8 @@ def _init_training_weights_reparam(
           # TODO(jballe): If False, breaks optimize.create_layer_for_training().
           # If True, throws warnings that int tensors have no gradient.
           # trainable=False,
-          initializer=tf.keras.initializers.Constant(kernel_shape))
+          initializer=keras.initializers.Constant(kernel_shape),
+      )
       prior_shape = kernel_rdft.shape[2:]
     else:
       raise ValueError(
@@ -146,18 +150,22 @@ def _init_training_weights_reparam(
         name=f"{weight_name}_log_step",
         shape=log_step.shape,
         dtype=log_step.dtype,
-        initializer=tf.keras.initializers.Constant(log_step))
+        initializer=keras.initializers.Constant(log_step),
+    )
 
     return prior_shape, dtype, weight_name
 
-  def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model:
+  def get_training_model(self, model: keras.Model) -> keras.Model:
     """Augments a model for training with EPR."""
-    if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network):  # pylint: disable=protected-access
+    if not (isinstance(model, keras.Sequential) or model._is_graph_network):  # pylint: disable=protected-access
       raise ValueError("`model` must be either sequential or functional.")
 
-    training_model = tf.keras.models.clone_model(
-        model, clone_function=functools.partial(
-            algorithm.create_layer_for_training, algorithm=self))
+    training_model = keras.models.clone_model(
+        model,
+        clone_function=functools.partial(
+            algorithm.create_layer_for_training, algorithm=self
+        ),
+    )
     training_model.build(model.input.shape)
 
     # Divide regularization weight by number of original model parameters to
@@ -178,13 +186,16 @@ def regularization_loss(layer, name):
     # different optimizer/learning rate. How to do this?
     return training_model
 
-  def compress_model(self, model: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, model: keras.Model) -> keras.Model:
     """Compresses a model after training with EPR."""
-    if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network):  # pylint: disable=protected-access
+    if not (isinstance(model, keras.Sequential) or model._is_graph_network):  # pylint: disable=protected-access
       raise ValueError("`model` must be either sequential or functional.")
-    return tf.keras.models.clone_model(
-        model, clone_function=functools.partial(
-            algorithm.create_layer_for_inference, algorithm=self))
+    return keras.models.clone_model(
+        model,
+        clone_function=functools.partial(
+            algorithm.create_layer_for_inference, algorithm=self
+        ),
+    )
 
 
 class EPR(EPRBase):
@@ -201,7 +212,8 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
         name=f"{weight_name}_log_scale",
         shape=log_scale.shape,
         dtype=log_scale.dtype,
-        initializer=tf.keras.initializers.Constant(log_scale))
+        initializer=keras.initializers.Constant(log_scale),
+    )
 
   def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor:
     if len(training_weights) == 3:
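
An end-to-end sketch of the EPR flow defined above: `get_training_model` clones the model with training wrappers and attaches the scaled entropy penalty, so a plain `fit` optimizes task loss and rate jointly. The model and dataset here are assumed:

    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import epr
    from tensorflow_model_optimization.python.core.keras.compat import keras

    algo = epr.EPR(regularization_weight=1e-3)       # weight value is illustrative
    training_model = algo.get_training_model(model)  # model: built Sequential/Functional
    training_model.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    training_model.fit(x_train, y_train, epochs=3)
    compressed_model = algo.compress_model(training_model)
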
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
index d97560dd7..64316bfcf 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
@@ -19,27 +19,28 @@
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import epr
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 def build_model():
-  inputs = tf.keras.layers.Input(shape=(28, 28), name="input")
-  x = tf.keras.layers.Reshape((28, 28, 1))(inputs)
-  x = tf.keras.layers.Conv2D(
-      20, 5, use_bias=True, activation="relu", padding="valid", name="conv1")(x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, use_bias=True, activation="relu", padding="valid", name="conv2")(x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(
-      500, use_bias=True, activation="relu", name="fc1")(x)
-  outputs = tf.keras.layers.Dense(
-      10, use_bias=True, name="fc2")(x)
-  return tf.keras.Model(inputs=[inputs], outputs=[outputs])
+  inputs = keras.layers.Input(shape=(28, 28), name="input")
+  x = keras.layers.Reshape((28, 28, 1))(inputs)
+  x = keras.layers.Conv2D(
+      20, 5, use_bias=True, activation="relu", padding="valid", name="conv1"
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, use_bias=True, activation="relu", padding="valid", name="conv2"
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, use_bias=True, activation="relu", name="fc1")(x)
+  outputs = keras.layers.Dense(10, use_bias=True, name="fc2")(x)
+  return keras.Model(inputs=[inputs], outputs=[outputs])
 
 
 def get_dataset():
-  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
   x_train = (x_train / 255).astype("float32")
   x_test = (x_test / 255).astype("float32")
   return (x_train, y_train), (x_test, y_test)
@@ -47,9 +48,9 @@ def get_dataset():
 
 def train_model(model):
   model.compile(
-      optimizer=tf.keras.optimizers.Adam(1e-2),
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
+      optimizer=keras.optimizers.Adam(1e-2),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
   )
   (x_train, y_train), _ = get_dataset()
   model.fit(x_train, y_train, batch_size=128, epochs=3)
@@ -57,7 +58,7 @@ def train_model(model):
 
 def evaluate_model(model):
   model.compile(
-      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
+      metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
   )
   _, (x_test, y_test) = get_dataset()
   results = model.evaluate(x_test, y_test, batch_size=128, return_dict=True)
@@ -90,7 +91,7 @@ def test_project_training_weights_has_gradients(self, *shape):
     algorithm = self.get_algorithm()
     init = tf.ones(shape, dtype=tf.float32)
     algorithm.init_training_weights(init)
-    layer = tf.keras.layers.Layer()
+    layer = keras.layers.Layer()
     for weight_repr in algorithm.weight_reprs:
       layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
     with tf.GradientTape() as tape:
@@ -106,7 +107,7 @@ def test_regularization_loss_has_gradients(self, *shape):
     algorithm = self.get_algorithm()
     init = tf.ones(shape, dtype=tf.float32)
     algorithm.init_training_weights(init)
-    layer = tf.keras.layers.Layer()
+    layer = keras.layers.Layer()
     for weight_repr in algorithm.weight_reprs:
       layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
     with tf.GradientTape() as tape:
@@ -117,16 +118,16 @@ def test_regularization_loss_has_gradients(self, *shape):
         [w.dtype.is_floating for w in layer.weights])
 
   @parameterized.parameters(
-      ((2, 3), tf.keras.layers.Dense, 5),
+      ((2, 3), keras.layers.Dense, 5),
       # TODO(jballe): This fails with: 'You called `set_weights(weights)` on
       # layer "private__training_wrapper" with a weight list of length 0, but
       # the layer was expecting 5 weights.' Find fix.
-      # ((3, 10, 2), tf.keras.layers.Conv1D, 5, 3),
-      ((1, 8, 9, 2), tf.keras.layers.Conv2D, 5, 3),
+      # ((3, 10, 2), keras.layers.Conv1D, 5, 3),
+      ((1, 8, 9, 2), keras.layers.Conv2D, 5, 3),
   )
   def test_model_has_gradients(self, input_shape, layer_cls, *args):
     algorithm = self.get_algorithm()
-    model = tf.keras.Sequential([layer_cls(*args, use_bias=True)])
+    model = keras.Sequential([layer_cls(*args, use_bias=True)])
     inputs = tf.random.normal(input_shape)
     model(inputs)
     training_model = algorithm.get_training_model(model)
@@ -145,7 +146,7 @@ def test_train_and_test_weights_are_equal(self, *shape):
     algorithm = self.get_algorithm()
     init = tf.random.uniform(shape, dtype=tf.float32)
     algorithm.init_training_weights(init)
-    layer = tf.keras.layers.Layer()
+    layer = keras.layers.Layer()
     for weight_repr in algorithm.weight_reprs:
       layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
     train_weight = algorithm.project_training_weights(*layer.weights)
@@ -158,7 +159,7 @@ def test_initialized_value_is_close_enough(self, *shape):
     algorithm = self.get_algorithm()
     init = tf.random.uniform(shape, -10., 10., dtype=tf.float32)
     algorithm.init_training_weights(init)
-    layer = tf.keras.layers.Layer()
+    layer = keras.layers.Layer()
     for weight_repr in algorithm.weight_reprs:
       layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
     weight = algorithm.project_training_weights(*layer.weights)
@@ -200,7 +201,7 @@ def test_reduces_model_size_at_reasonable_accuracy(self):
       self.assertLess(compressed_size, 0.2 * original_size)
 
     with self.subTest("compressed_model_has_reasonable_accuracy"):
-      compressed_model = tf.keras.models.load_model(compressed_model_dir)
+      compressed_model = keras.models.load_model(compressed_model_dir)
       accuracy = evaluate_model(compressed_model)
       self.assertGreater(accuracy, .9)
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling.py
index 17d79e690..8164a9f10 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class SVD(algorithm.WeightCompressor):
@@ -42,12 +43,14 @@ def init_training_weights(
         name='w',
         shape=pretrained_weight.shape,
         dtype=pretrained_weight.dtype,
-        initializer=tf.keras.initializers.Constant(pretrained_weight))
+        initializer=keras.initializers.Constant(pretrained_weight),
+    )
     self.add_training_weight(
         name='step',
         shape=(),
         dtype=tf.int32,
-        initializer=tf.keras.initializers.Constant(0))
+        initializer=keras.initializers.Constant(0),
+    )
 
   def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
     return tf.matmul(u, sv)
@@ -109,13 +112,13 @@ def compress_training_weights(self, weight: tf.Tensor, _) -> List[tf.Tensor]:
     return [u, sv]
 
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    if isinstance(original_layer, (tf.keras.layers.Conv2D,
-                                   tf.keras.layers.Dense)):
+      self, original_layer: keras.layers.Layer
+  ) -> List[str]:
+    if isinstance(original_layer, (keras.layers.Conv2D, keras.layers.Dense)):
       return [original_layer.kernel]
     return []
 
-  def optimize_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
+  def optimize_model(self, to_optimize: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model for training.
 
     The returned model should be used for compression aware training.
@@ -125,11 +128,14 @@ def optimize_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
       A wrapped model that has compression optimizers.
     """
     # pylint: disable=protected-access
-    if not isinstance(
-        to_optimize, tf.keras.Sequential) and not to_optimize._is_graph_network:
+    if (
+        not isinstance(to_optimize, keras.Sequential)
+        and not to_optimize._is_graph_network
+    ):
       raise ValueError(
-          '`optimize_model` can only either be a tf.keras Sequential or '
-          'Functional model.')
+          '`optimize_model` can only be used on a keras Sequential or '
+          'Functional model.'
+      )
     # pylint: enable=protected-access
 
     def _optimize_layer(layer):
@@ -141,10 +147,9 @@ def _optimize_layer(layer):
 
       return algorithm.create_layer_for_training(layer, algorithm=self)
 
-    return tf.keras.models.clone_model(
-        to_optimize, clone_function=_optimize_layer)
+    return keras.models.clone_model(to_optimize, clone_function=_optimize_layer)
 
-  def compress_model(self, to_compress: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, to_compress: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model for inference.
 
     Args:
@@ -162,5 +167,4 @@ def _optimize_layer(layer):
 
       return algorithm.create_layer_for_inference(layer, algorithm=self)
 
-    return tf.keras.models.clone_model(
-        to_compress, clone_function=_optimize_layer)
+    return keras.models.clone_model(to_compress, clone_function=_optimize_layer)
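
This variant stores a `step` counter as an extra training weight so the compressed factors can be refreshed periodically during training. The intended call sequence, with constructor arguments assumed (they live outside these hunks):

    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import (
        periodical_update_and_scheduling as svd,
    )

    algo = svd.SVD(rank=32, update_freq=100)        # args assumed
    training_model = algo.optimize_model(model)     # compression-aware training
    training_model.fit(x_train, y_train, epochs=1)  # data assumed
    inference_model = algo.compress_model(training_model)
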
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling_test.py
index fcf0726de..227b141be 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling_test.py
@@ -20,30 +20,31 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import periodical_update_and_scheduling as svd
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
 
 
 def _build_model():
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      20, 5, activation='relu', padding='valid', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, activation='relu', padding='valid', name='conv2')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      20, 5, activation='relu', padding='valid', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, activation='relu', padding='valid', name='conv2'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -53,7 +54,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -132,7 +133,7 @@ def testSVD_HasReasonableAccuracy_TF(self):
 
     _, (x_test, y_test) = _get_dataset()
 
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
     compressed_model.compile(
         optimizer='adam', loss=loss_fn, metrics=['accuracy'])
@@ -193,9 +194,9 @@ def testSVD_BreaksDownLayerWeights(self):
 
   # TODO(tfmot): can simplify to single layer test.
   def testSVD_PreservesPretrainedWeights(self):
-    i = tf.keras.layers.Input(shape=(2), name='input')
-    output = tf.keras.layers.Dense(3, name='fc1')(i)
-    model = tf.keras.Model(inputs=[i], outputs=[output])
+    i = keras.layers.Input(shape=(2), name='input')
+    output = keras.layers.Dense(3, name='fc1')(i)
+    model = keras.Model(inputs=[i], outputs=[output])
 
     dense_layer_weights = model.layers[1].get_weights()
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference.py
index 2d975dda0..81bf01dee 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class SVD(algorithm.WeightCompressor):
@@ -55,12 +56,14 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
         name='u',
         shape=u.shape,
         dtype=u.dtype,
-        initializer=tf.keras.initializers.Constant(u))
+        initializer=keras.initializers.Constant(u),
+    )
     self.add_training_weight(
         name='sv',
         shape=sv.shape,
         dtype=sv.dtype,
-        initializer=tf.keras.initializers.Constant(sv))
+        initializer=keras.initializers.Constant(sv),
+    )
 
   def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
     return tf.matmul(u, sv)
@@ -69,20 +72,25 @@ def project_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
     return self.decompress_weights(u, sv)
 
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    if isinstance(original_layer, tf.keras.layers.Conv2D) or \
-       isinstance(original_layer, tf.keras.layers.Dense):
+      self, original_layer: keras.layers.Layer
+  ) -> List[str]:
+    if isinstance(original_layer, keras.layers.Conv2D) or isinstance(
+        original_layer, keras.layers.Dense
+    ):
       return [original_layer.kernel]
     return []
 
-  def compress_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, to_optimize: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model."""
     # pylint: disable=protected-access
-    if not isinstance(to_optimize, tf.keras.Sequential) \
-        and not to_optimize._is_graph_network:
+    if (
+        not isinstance(to_optimize, keras.Sequential)
+        and not to_optimize._is_graph_network
+    ):
       raise ValueError(
-          '`compress_model` can only either be a tf.keras Sequential or '
-          'Functional model.')
+          '`compress_model` can only be used on a keras Sequential or '
+          'Functional model.'
+      )
     # pylint: enable=protected-access
 
     def _optimize_layer(layer):
@@ -94,5 +102,4 @@ def _optimize_layer(layer):
 
       return algorithm.create_layer_for_training(layer, algorithm=self)
 
-    return tf.keras.models.clone_model(
-        to_optimize, clone_function=_optimize_layer)
+    return keras.models.clone_model(to_optimize, clone_function=_optimize_layer)
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py
index 2f24c0945..474a7bcb5 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py
@@ -20,30 +20,31 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import same_training_and_inference as svd
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
 
 
 def _build_model():
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      20, 5, activation='relu', padding='valid', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, activation='relu', padding='valid', name='conv2')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      20, 5, activation='relu', padding='valid', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, activation='relu', padding='valid', name='conv2'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -53,7 +54,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
 
@@ -127,7 +128,7 @@ def testSVD_HasReasonableAccuracy_TF(self):
 
     _, (x_test, y_test) = _get_dataset()
 
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
     compressed_model.compile(
         optimizer='adam', loss=loss_fn, metrics=['accuracy'])
@@ -181,9 +182,9 @@ def testSVD_BreaksDownLayerWeights(self):
 
   # TODO(tfmot): can simplify to single layer test.
   def testSVD_PreservesPretrainedWeights(self):
-    i = tf.keras.layers.Input(shape=(2), name='input')
-    output = tf.keras.layers.Dense(3, name='fc1')(i)
-    model = tf.keras.Model(inputs=[i], outputs=[output])
+    i = keras.layers.Input(shape=(2), name='input')
+    output = keras.layers.Dense(3, name='fc1')(i)
+    model = keras.Model(inputs=[i], outputs=[output])
 
     dense_layer_weights = model.layers[1].get_weights()
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py
index e81475367..21ce1040f 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py
@@ -23,6 +23,7 @@
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
 from tensorflow_model_optimization.python.core.clustering.keras.cluster_config import GradientAggregation
 from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class ConvolutionalWeightsCA(clustering_registry.ClusteringAlgorithm):
@@ -79,12 +80,14 @@ def init_training_weights(
         name='cluster_centroids',
         shape=cluster_centroids.shape,
         dtype=cluster_centroids.dtype,
-        initializer=tf.keras.initializers.Constant(cluster_centroids))
+        initializer=keras.initializers.Constant(cluster_centroids),
+    )
     self.add_training_weight(
         name='pulling_indices',
         shape=pulling_indices.shape,
         dtype=pulling_indices.dtype,
-        initializer=tf.keras.initializers.Constant(pulling_indices))
+        initializer=keras.initializers.Constant(pulling_indices),
+    )
 
   def decompress_weights(self,
                          cluster_centroids: tf.Tensor,
@@ -100,15 +103,15 @@ def project_training_weights(self,
     return self.decompress_weights(cluster_centroids, pulling_indices)
 
   def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    if (isinstance(original_layer, tf.keras.layers.Conv2D) or
-        isinstance(original_layer, tf.keras.layers.Dense)):
+      self, original_layer: keras.layers.Layer
+  ) -> List[str]:
+    if isinstance(original_layer, keras.layers.Conv2D) or isinstance(
+        original_layer, keras.layers.Dense
+    ):
       return [original_layer.kernel]
     return []
 
-  def compress_model(
-      self,
-      to_optimize: tf.keras.Model) -> tf.keras.Model:
+  def compress_model(self, to_optimize: keras.Model) -> keras.Model:
     """Model developer API for optimizing a model."""
 
     def _optimize_layer(layer):
@@ -122,5 +125,4 @@ def _optimize_layer(layer):
       return algorithm.create_layer_for_training(
           layer, algorithm=self)
 
-    return tf.keras.models.clone_model(
-        to_optimize, clone_function=_optimize_layer)
+    return keras.models.clone_model(to_optimize, clone_function=_optimize_layer)
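
A wiring sketch for the clustering algorithm above. The constructor arguments are inferred from the accompanying test (`number_of_clusters`, a centroid initialization from `cluster_config`); treat the exact signature as an assumption:

    from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import weight_clustering

    algo = weight_clustering.WeightClustering(
        number_of_clusters=8,
        cluster_centroids_init=cluster_config.CentroidInitialization.DENSITY_BASED,
    )  # signature assumed
    compressed_model = algo.compress_model(model)  # model assumed built
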
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py
index d35a7a518..abafe90dd 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py
@@ -22,29 +22,30 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import weight_clustering
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 def _build_model():
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      20, 5, activation='relu', padding='valid', name='conv1')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Conv2D(
-      50, 5, activation='relu', padding='valid', name='conv2')(
-          x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      20, 5, activation='relu', padding='valid', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Conv2D(
+      50, 5, activation='relu', padding='valid', name='conv2'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+
+  model = keras.Model(inputs=[i], outputs=[output])
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -55,7 +56,7 @@ def _get_dataset():
 
 
 def _train_model(model):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
   model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
   (x_train, y_train), _ = _get_dataset()
   model.fit(x_train, y_train, epochs=1)
@@ -102,7 +103,7 @@ def testWeightClustering_TrainingE2E(self):
 
     _, (x_test, y_test) = _get_dataset()
 
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
 
     compressed_model.compile(
         optimizer='adam', loss=loss_fn, metrics=['accuracy'])
@@ -120,9 +121,9 @@ def testWeightClustering_TrainingE2E(self):
 
   def testWeightClustering_SingleLayer(self):
     number_of_clusters = 8
-    i = tf.keras.layers.Input(shape=(2), name='input')
-    output = tf.keras.layers.Dense(3, name='fc1')(i)
-    model = tf.keras.Model(inputs=[i], outputs=[output])
+    i = keras.layers.Input(shape=(2), name='input')
+    output = keras.layers.Dense(3, name='fc1')(i)
+    model = keras.Model(inputs=[i], outputs=[output])
 
     dense_layer_weights = model.layers[1].get_weights()
 
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/internal/BUILD b/tensorflow_model_optimization/python/core/common/keras/compression/internal/BUILD
index 764f838ef..be7cae93f 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/internal/BUILD
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/internal/BUILD
@@ -12,5 +12,6 @@ pytype_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py b/tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py
index 79fa42c4c..1e29333d0 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py
@@ -14,8 +14,11 @@
 # ==============================================================================
 """Internal APIs and core implementation of weight compression API."""
 from typing import List, Mapping
+
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
 # Workaround to prevent MLIR from constant folding the
 # compressed weights into the original weights. For instance,
@@ -35,7 +38,7 @@ def _prevent_constant_folding(tensor, dummy_inputs):
   return outputs
 
 
-class _TrainingWrapper(tf.keras.layers.Wrapper):
+class _TrainingWrapper(keras.layers.Wrapper):
   """Represent modifications to training graph for weight compression."""
 
   def __init__(self, layer, algorithm, compressible_weights: List[str]):
@@ -156,7 +159,7 @@ def call(self, inputs):
 
 
 # TODO(tfmot): deduplicate code with _TrainingWrapper.
-class _InferenceWrapper(tf.keras.layers.Wrapper):
+class _InferenceWrapper(keras.layers.Wrapper):
   """Represent modifications to inference graph for weight compression."""
 
   def __init__(self, layer, algorithm,
@@ -218,8 +221,11 @@ def build(self, input_shape):  # pytype: disable=signature-mismatch  # overridin
       weights = []
       for t in compressed_tensors:
         weight = self.add_weight(
-            name='TODO', dtype=t.dtype, shape=t.shape,
-            initializer=tf.keras.initializers.Constant(t))
+            name='TODO',
+            dtype=t.dtype,
+            shape=t.shape,
+            initializer=keras.initializers.Constant(t),
+        )
         weights.append(weight)
 
       self.compressed_weights[attr_name] = weights
@@ -254,7 +260,7 @@ def _map_to_training_weights(
   """Construct the training weight values from the layer's pretrained weights.
 
     The weight values have the same structure as the output of
-    `tf.keras.layers.Layer.get_weights`.
+    `keras.layers.Layer.get_weights`.
 
   Args:
     algorithm: weight compression algorithm
@@ -271,17 +277,18 @@ def _map_to_training_weights(
   # TODO(tfmot): see if Keras can introduce changes to simplify this.
   original_weights = []
   training_weights = []
-  if isinstance(layer, tf.keras.layers.Conv2D) or \
-     isinstance(layer, tf.keras.layers.Dense):
+  if isinstance(layer, keras.layers.Conv2D) or isinstance(
+      layer, keras.layers.Dense
+  ):
     for weight in layer.weights:
       if _find(weight, compressible_weights):
         algorithm.weight_reprs = []
         algorithm.init_training_weights(weight)
         for weight_repr in algorithm.weight_reprs:
-          # Assumes initializer is tf.keras.initializers.Constant.
+          # Assumes initializer is keras.initializers.Constant.
           # TODO(tfmot): add check for this assumption.
           # TODO(tfmot): the documentation for
-          # tf.keras.initializers.Constant(value)
+          # keras.initializers.Constant(value)
           # suggests that the `value` cannot be any arbitrary shape and
           # only a single scalar value. It works in this implementation
           # to make `value` any tensor - check this.
@@ -298,7 +305,7 @@ def _map_to_inference_weights(training_weights, algorithm, training_tensors):
   """Construct the inference weight values from the weights after training.
 
     The weight values have the same structure as the output of
-    `tf.keras.layers.Layer.get_weights`.
+    `keras.layers.Layer.get_weights`.
 
   Args:
     training_weights: layer's weights from training, retrieved via
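
Both wrapper classes above are swapped in through `keras.models.clone_model`, which rebuilds the graph by invoking a `clone_function` on every original layer; returning a wrapper from that function is what substitutes the layer. A condensed sketch of the mechanism (`algorithm` and `model` assumed):

    from tensorflow_model_optimization.python.core.keras.compat import keras

    def _wrap_for_training(layer):
      weights = algorithm.get_compressible_weights(layer)
      if not weights:
        return layer  # non-compressible layers pass through unchanged
      return _TrainingWrapper(layer, algorithm, weights)

    training_model = keras.models.clone_model(
        model, clone_function=_wrap_for_training)
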
diff --git a/tensorflow_model_optimization/python/core/keras/BUILD b/tensorflow_model_optimization/python/core/keras/BUILD
index 7ef4a1e6c..f78fe8d25 100644
--- a/tensorflow_model_optimization/python/core/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/keras/BUILD
@@ -21,6 +21,7 @@ py_strict_library(
     srcs = ["test_utils.py"],
     srcs_version = "PY3",
     deps = [
+        ":compat",
         # numpy dep1,
         # tensorflow dep1,
     ],
@@ -61,6 +62,7 @@ py_strict_test(
     srcs = ["metrics_test.py"],
     python_version = "PY3",
     deps = [
+        ":compat",
         ":metrics",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # mock dep1,
diff --git a/tensorflow_model_optimization/python/core/keras/compat.py b/tensorflow_model_optimization/python/core/keras/compat.py
index 02486849b..67b5a30b3 100644
--- a/tensorflow_model_optimization/python/core/keras/compat.py
+++ b/tensorflow_model_optimization/python/core/keras/compat.py
@@ -12,15 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Functions for TF 1.X and 2.X compatibility."""
+"""Global variables and functions for TF/Keras compatibility."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import weakref
+
 import tensorflow as tf
 
 
+def _get_keras_instance():
+  """Returns Keras 2: `tf_keras` when TF defaults to Keras 3, else `tf.keras`."""
+  from pkg_resources import parse_version  # pylint: disable=g-import-not-at-top
+
+  required_tensorflow_version = '2.16.0'
+  if parse_version(tf.__version__) < parse_version(required_tensorflow_version):
+    return tf.keras
+
+  version_fn = getattr(tf.keras, 'version', None)
+  if version_fn and version_fn().startswith('3.'):
+    try:
+      # `tf.keras` is Keras 3 here; use the Keras 2 fork shipped as `tf_keras`.
+      import tf_keras  # pylint: disable=g-import-not-at-top
+      return tf_keras
+    except ImportError:
+      pass
+  return tf.keras
+
+
+keras = _get_keras_instance()
+
+
 def assign(ref, value, name=None):
   if hasattr(tf, 'assign'):
     return tf.assign(ref, value, name=name)
@@ -40,3 +65,78 @@
 
 def is_v1_apis():
   return hasattr(tf, 'assign')
+
+
+# A global dictionary mapping graph objects to an index of counters used
+# for various layer/optimizer names in each graph.
+# Allows autogenerated layer names to be unique in a graph-specific way.
+PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
+
+# Names that `unique_object_name` must avoid when called with
+# `avoid_observed_names=True`; starts empty and is populated by callers.
+OBSERVED_NAMES = set()
+
+
+def get_default_graph_uid_map():
+  graph = tf.compat.v1.get_default_graph()
+  name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
+  if name_uid_map is None:
+    name_uid_map = collections.defaultdict(int)
+    PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
+  return name_uid_map
+
+
+def unique_object_name(
+    name,
+    name_uid_map=None,
+    avoid_names=None,
+    namespace='',
+    zero_based=False,
+    avoid_observed_names=False,
+):
+  """Makes a object name (or any string) unique within a TF-Keras session.
+
+  Args:
+    name: String name to make unique.
+    name_uid_map: An optional defaultdict(int) to use when creating unique
+      names. If None (default), uses a per-Graph dictionary.
+    avoid_names: An optional set or dict with names which should not be used. If
+      None (default), don't avoid any names unless `avoid_observed_names` is
+      True.
+    namespace: Gets a name which is unique within the (graph, namespace). Layers
+      which are not Networks use a blank namespace and so get graph-global
+      names.
+    zero_based: If True, name sequences start with no suffix (e.g. "dense",
+      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
+    avoid_observed_names: If True, avoid any names that have been observed by
+      `backend.observe_object_name`.
+
+  Returns:
+    Unique string name.
+
+  Example:
+
+    unique_object_name('dense')  # dense_1
+    unique_object_name('dense')  # dense_2
+  """
+  if name_uid_map is None:
+    name_uid_map = get_default_graph_uid_map()
+  if avoid_names is None:
+    if avoid_observed_names:
+      avoid_names = OBSERVED_NAMES
+    else:
+      avoid_names = set()
+  proposed_name = None
+  while proposed_name is None or proposed_name in avoid_names:
+    name_key = (namespace, name)
+    if zero_based:
+      number = name_uid_map[name_key]
+      if number:
+        proposed_name = name + '_' + str(number)
+      else:
+        proposed_name = name
+      name_uid_map[name_key] += 1
+    else:
+      name_uid_map[name_key] += 1
+      proposed_name = name + '_' + str(name_uid_map[name_key])
+  return proposed_name
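
A behavior sketch for `unique_object_name`, exercising an explicit `name_uid_map` so the example does not depend on graph state:

    import collections

    names = collections.defaultdict(int)
    unique_object_name('dense', name_uid_map=names)  # -> 'dense_1'
    unique_object_name('dense', name_uid_map=names)  # -> 'dense_2'
    unique_object_name(
        'dense', name_uid_map=names, avoid_names={'dense_3'})  # -> 'dense_4'

    fresh = collections.defaultdict(int)
    unique_object_name('conv', name_uid_map=fresh, zero_based=True)  # -> 'conv'
    unique_object_name('conv', name_uid_map=fresh, zero_based=True)  # -> 'conv_1'
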
diff --git a/tensorflow_model_optimization/python/core/keras/metrics_test.py b/tensorflow_model_optimization/python/core/keras/metrics_test.py
index ffd327816..b42c518f2 100644
--- a/tensorflow_model_optimization/python/core/keras/metrics_test.py
+++ b/tensorflow_model_optimization/python/core/keras/metrics_test.py
@@ -19,6 +19,7 @@
 
 from tensorflow.python.eager import monitoring
 from tensorflow_model_optimization.python.core.keras import metrics
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class MetricsTest(tf.test.TestCase):
@@ -27,7 +28,7 @@ class MetricsTest(tf.test.TestCase):
 
   def setUp(self):
     super(MetricsTest, self).setUp()
-    self.test_label = tf.keras.layers.Conv2D(1, 1).__class__.__name__
+    self.test_label = keras.layers.Conv2D(1, 1).__class__.__name__
     for label in [
         self.test_label, metrics.MonitorBoolGauge._SUCCESS_LABEL,
         metrics.MonitorBoolGauge._FAILURE_LABEL
diff --git a/tensorflow_model_optimization/python/core/keras/test_utils.py b/tensorflow_model_optimization/python/core/keras/test_utils.py
index 1d3953bd9..da39df501 100644
--- a/tensorflow_model_optimization/python/core/keras/test_utils.py
+++ b/tensorflow_model_optimization/python/core/keras/test_utils.py
@@ -18,7 +18,10 @@
 import numpy as np
 import tensorflow as tf
 
-l = tf.keras.layers
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
+
+l = keras.layers
 
 
 class ModelCompare(object):
@@ -45,8 +48,8 @@ def _assert_weights_same_values(self, model1, model2):
     self.assertEqual(
         len(model1.trainable_weights), len(model2.trainable_weights))
 
-    model1_weights = tf.keras.backend.batch_get_value(model1.trainable_weights)
-    model2_weights = tf.keras.backend.batch_get_value(model2.trainable_weights)
+    model1_weights = keras.backend.batch_get_value(model1.trainable_weights)
+    model2_weights = keras.backend.batch_get_value(model2.trainable_weights)
     for w1, w2 in zip(model1_weights, model2_weights):
       self.assertAllClose(w1, w2)
 
@@ -54,16 +57,16 @@ def _assert_weights_different_values(self, model1, model2):
     self.assertEqual(
         len(model1.trainable_weights), len(model2.trainable_weights))
 
-    model1_weights = tf.keras.backend.batch_get_value(model1.trainable_weights)
-    model2_weights = tf.keras.backend.batch_get_value(model2.trainable_weights)
+    model1_weights = keras.backend.batch_get_value(model1.trainable_weights)
+    model2_weights = keras.backend.batch_get_value(model2.trainable_weights)
     for w1, w2 in zip(model1_weights, model2_weights):
       self.assertNotAllClose(w1, w2)
 
 
 def build_simple_dense_model():
-  return tf.keras.Sequential([
+  return keras.Sequential([
       l.Dense(8, activation='relu', input_shape=(10,)),
-      l.Dense(5, activation='softmax')
+      l.Dense(5, activation='softmax'),
   ])
 
 
@@ -72,9 +75,9 @@ def get_preprocessed_mnist_data(img_rows=28,
                                 num_classes=10,
                                 is_quantized_model=False):
   """Get data for mnist training and evaluation."""
-  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
-  if tf.keras.backend.image_data_format() == 'channels_first':
+  if keras.backend.image_data_format() == 'channels_first':
     x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
     x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
     input_shape = (1, img_rows, img_cols)
@@ -90,8 +93,8 @@ def get_preprocessed_mnist_data(img_rows=28,
     x_test /= 255
 
   # convert class vectors to binary class matrices
-  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-  y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+  y_train = keras.utils.to_categorical(y_train, num_classes)
+  y_test = keras.utils.to_categorical(y_test, num_classes)
 
   return (x_train, y_train), (x_test, y_test), input_shape
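
A minimal usage sketch for the helper above (assuming the default 28x28 grayscale MNIST in channels_last layout):

    # The returned input_shape matches the per-example shape of x_train.
    (x_train, y_train), (x_test, y_test), input_shape = (
        get_preprocessed_mnist_data())
    assert x_train.shape[1:] == input_shape  # e.g. (28, 28, 1)
    assert y_train.shape[1] == 10            # one-hot over num_classes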
 
diff --git a/tensorflow_model_optimization/python/core/keras/testing/BUILD b/tensorflow_model_optimization/python/core/keras/testing/BUILD
index ef91b62a2..acf9a4d4a 100644
--- a/tensorflow_model_optimization/python/core/keras/testing/BUILD
+++ b/tensorflow_model_optimization/python/core/keras/testing/BUILD
@@ -15,5 +15,6 @@ py_strict_library(
     deps = [
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/keras/testing/test_utils_mnist.py b/tensorflow_model_optimization/python/core/keras/testing/test_utils_mnist.py
index a9306def6..1738ab942 100644
--- a/tensorflow_model_optimization/python/core/keras/testing/test_utils_mnist.py
+++ b/tensorflow_model_optimization/python/core/keras/testing/test_utils_mnist.py
@@ -18,7 +18,10 @@
 import tensorflow as tf
 from tensorflow import keras
 
-l = tf.keras.layers
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
+
+l = keras.layers
 
 
 def layers_list():
@@ -59,7 +62,7 @@ def functional_model():
 
 
 def image_input_shape(img_rows=28, img_cols=28):
-  if tf.keras.backend.image_data_format() == 'channels_first':
+  if keras.backend.image_data_format() == 'channels_first':
     return 1, img_rows, img_cols
   else:
     return img_rows, img_cols, 1
@@ -69,9 +72,9 @@ def preprocessed_data(img_rows=28,
                       img_cols=28,
                       num_classes=10):
   """Get data for mnist training and evaluation."""
-  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
-  if tf.keras.backend.image_data_format() == 'channels_first':
+  if keras.backend.image_data_format() == 'channels_first':
     x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
     x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
   else:
@@ -84,8 +87,8 @@ def preprocessed_data(img_rows=28,
   x_test /= 255
 
   # convert class vectors to binary class matrices
-  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-  y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+  y_train = keras.utils.to_categorical(y_train, num_classes)
+  y_test = keras.utils.to_categorical(y_test, num_classes)
 
   return x_train, y_train, x_test, y_test
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/BUILD
index 4c4855a70..ac0a4c856 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/BUILD
@@ -60,6 +60,7 @@ py_strict_library(
         ":quant_ops",
         # six dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -130,6 +131,7 @@ py_strict_library(
     deps = [
         ":utils",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -146,6 +148,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -159,6 +162,7 @@ py_strict_library(
     deps = [
         ":utils",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:utils",
     ],
 )
@@ -178,6 +182,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -192,6 +197,7 @@ py_strict_library(
         ":quantizers",
         ":utils",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:utils",
     ],
 )
@@ -209,6 +215,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -224,6 +231,7 @@ py_strict_library(
         ":utils",
         # tensorflow dep1,
         # python/util:tf_inspect tensorflow dep2,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:metrics",
         "//tensorflow_model_optimization/python/core/keras:utils",
     ],
@@ -243,6 +251,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
     ],
 )
@@ -263,6 +272,7 @@ py_strict_library(
         ":quantizers",
         ":utils",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:metrics",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
@@ -286,6 +296,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
     ],
@@ -323,6 +334,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/BUILD
index 76ffb4ab0..bca7e2881 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/BUILD
@@ -1,5 +1,5 @@
-# Placeholder: load py_test
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library")
+# Placeholder: load py_test
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -26,6 +26,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -40,6 +41,7 @@ py_strict_library(
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quant_ops",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
@@ -60,6 +62,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
     ],
@@ -78,6 +81,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_callbacks",
@@ -110,6 +114,7 @@ py_test(
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster",
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
         "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
     ],
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_integration_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_integration_test.py
index f8f87aca0..99008a206 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_integration_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_integration_test.py
@@ -20,13 +20,15 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import (
     default_8bit_cluster_preserve_quantize_scheme,)
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import (
     strip_clustering_cqat,)
 
-layers = tf.keras.layers
+
+layers = keras.layers
 
 
 class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
@@ -41,14 +43,15 @@ def setUp(self):
   def compile_and_fit(self, model):
     """Here we compile and fit the model."""
     model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
+        loss=keras.losses.categorical_crossentropy,
         optimizer='adam',
         metrics=['accuracy'],
     )
     model.fit(
         np.random.rand(20, 10),
-        tf.keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
-        batch_size=20)
+        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
+        batch_size=20,
+    )
 
   def _get_number_of_unique_weights(self, stripped_model, layer_nr,
                                     weight_name):
@@ -68,7 +71,7 @@ def _get_sparsity(self, model):
     for layer in model.layers:
       for weights in layer.trainable_weights:
         if 'kernel' in weights.name:
-          np_weights = tf.keras.backend.get_value(weights)
+          np_weights = keras.backend.get_value(weights)
           sparsity = 1.0 - np.count_nonzero(np_weights) / float(
               np_weights.size)
           sparsity_list.append(sparsity)
@@ -78,7 +81,7 @@ def _get_sparsity(self, model):
   def _get_clustered_model(self, preserve_sparsity):
     """Cluster the (sparse) model and return clustered_model."""
     tf.random.set_seed(1)
-    original_model = tf.keras.Sequential([
+    original_model = keras.Sequential([
         layers.Dense(5, activation='softmax', input_shape=(10,)),
         layers.Flatten(),
     ])
@@ -106,18 +109,18 @@ def _get_conv_model(self,
                       data_format=None,
                       kernel_size=(3, 3)):
     """Returns functional model with Conv2D layer."""
-    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
+    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
     shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
-    x = tf.keras.layers.Reshape(shape)(inp)
-    x = tf.keras.layers.Conv2D(
+    x = keras.layers.Reshape(shape)(inp)
+    x = keras.layers.Conv2D(
         filters=nr_of_channels,
         kernel_size=kernel_size,
         data_format=data_format,
-        activation='relu')(
-            x)
-    x = tf.keras.layers.MaxPool2D(2, 2)(x)
-    out = tf.keras.layers.Flatten()(x)
-    model = tf.keras.Model(inputs=inp, outputs=out)
+        activation='relu',
+    )(x)
+    x = keras.layers.MaxPool2D(2, 2)(x)
+    out = keras.layers.Flatten()(x)
+    model = keras.Model(inputs=inp, outputs=out)
     return model
 
   def _compile_and_fit_conv_model(self, model, nr_epochs=1):
@@ -125,9 +128,10 @@ def _compile_and_fit_conv_model(self, model, nr_epochs=1):
     x_train = np.random.uniform(size=(500, 32, 32))
     y_train = np.random.randint(low=0, high=1024, size=(500,))
     model.compile(
-        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
-        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
+        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
+    )
 
     model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1)
 
@@ -197,9 +201,9 @@ def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
 
   def testEndToEndClusterPreserve(self):
     """Runs CQAT end to end and whole model is quantized."""
-    original_model = tf.keras.Sequential([
-        layers.Dense(5, activation='softmax', input_shape=(10,))
-    ])
+    original_model = keras.Sequential(
+        [layers.Dense(5, activation='softmax', input_shape=(10,))]
+    )
     clustered_model = cluster.cluster_weights(
         original_model,
         **self.cluster_params)
@@ -228,9 +232,9 @@ def testEndToEndClusterPreserve(self):
 
   def testEndToEndClusterPreservePerLayer(self):
     """Runs CQAT end to end and model is quantized per layers."""
-    original_model = tf.keras.Sequential([
+    original_model = keras.Sequential([
         layers.Dense(5, activation='relu', input_shape=(10,)),
-        layers.Dense(5, activation='softmax', input_shape=(10,))
+        layers.Dense(5, activation='softmax', input_shape=(10,)),
     ])
     clustered_model = cluster.cluster_weights(
         original_model,
@@ -241,11 +245,11 @@ def testEndToEndClusterPreservePerLayer(self):
         clustered_model, 1, 'kernel')
 
     def apply_quantization_to_dense(layer):
-      if isinstance(layer, tf.keras.layers.Dense):
+      if isinstance(layer, keras.layers.Dense):
         return quantize.quantize_annotate_layer(layer)
       return layer
 
-    quant_aware_annotate_model = tf.keras.models.clone_model(
+    quant_aware_annotate_model = keras.models.clone_model(
         clustered_model,
         clone_function=apply_quantization_to_dense,
     )
@@ -268,9 +272,9 @@ def apply_quantization_to_dense(layer):
 
   def testEndToEndClusterPreserveOneLayer(self):
     """Runs CQAT end to end and model is quantized only for a single layer."""
-    original_model = tf.keras.Sequential([
+    original_model = keras.Sequential([
         layers.Dense(5, activation='relu', input_shape=(10,)),
-        layers.Dense(5, activation='softmax', input_shape=(10,), name='qat')
+        layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
     ])
     clustered_model = cluster.cluster_weights(
         original_model,
@@ -281,12 +285,12 @@ def testEndToEndClusterPreserveOneLayer(self):
         clustered_model, 1, 'kernel')
 
     def apply_quantization_to_dense(layer):
-      if isinstance(layer, tf.keras.layers.Dense):
+      if isinstance(layer, keras.layers.Dense):
         if layer.name == 'qat':
           return quantize.quantize_annotate_layer(layer)
       return layer
 
-    quant_aware_annotate_model = tf.keras.models.clone_model(
+    quant_aware_annotate_model = keras.models.clone_model(
         clustered_model,
         clone_function=apply_quantization_to_dense,
     )
@@ -591,7 +595,7 @@ def testPassingNonPrunedModelToPCQAT(self):
   def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights):
     """If pruned_clustered_model has uniform weights, it won't break PCQAT."""
     preserve_sparsity = True
-    original_model = tf.keras.Sequential([
+    original_model = keras.Sequential([
         layers.Dense(5, activation='softmax', input_shape=(10,)),
         layers.Flatten(),
     ])
@@ -643,12 +647,12 @@ def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self):
         .Default8BitClusterPreserveQuantizeScheme(True))
 
     quant_aware_model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
+        loss=keras.losses.categorical_crossentropy,
         optimizer='adam',
         metrics=['accuracy'],
     )
 
-    class CheckCentroidsAndTrainableVarsCallback(tf.keras.callbacks.Callback):
+    class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback):
       """Check the updates of trainable variables and centroid masks."""
 
       def on_epoch_begin(self, batch, logs=None):
@@ -692,12 +696,13 @@ def on_epoch_end(self, batch, logs=None):
     # Use many epochs to verify layer's kernel weights are updating because
     # they can stay the same after being trained using only the first batch
     # of data for instance
-    quant_aware_model.fit(np.random.rand(20, 10),
-                          tf.keras.utils.to_categorical(
-                              np.random.randint(5, size=(20, 1)), 5),
-                          steps_per_epoch=5,
-                          epochs=3,
-                          callbacks=[CheckCentroidsAndTrainableVarsCallback()])
+    quant_aware_model.fit(
+        np.random.rand(20, 10),
+        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
+        steps_per_epoch=5,
+        epochs=3,
+        callbacks=[CheckCentroidsAndTrainableVarsCallback()],
+    )
 
 
 if __name__ == '__main__':
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry.py
index bc34c3bdc..ff2381f72 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry.py
@@ -21,13 +21,15 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
 
-layers = tf.keras.layers
-K = tf.keras.backend
+
+layers = keras.layers
+K = keras.backend
 
 CLUSTER_CENTROIDS = 'cluster_centroids_tf'
 PULLING_INDICES = 'pulling_indices_tf'
@@ -77,8 +79,9 @@ def get_centroids(layer, weight, data_format):
     A 4-tuple of centroids (unique values), number of centroids, lookup index,
     whether to cluster per channel (boolean).
   """
-  cluster_per_channel = (
-      layer.layer and isinstance(layer.layer, tf.keras.layers.Conv2D))
+  cluster_per_channel = layer.layer and isinstance(
+      layer.layer, keras.layers.Conv2D
+  )
 
   if not cluster_per_channel:
     centroids, index = get_unique(weight)
@@ -373,18 +376,22 @@ def _build_clusters(self, name, layer):
       clst_centroids_tf = layer.add_weight(
           CLUSTER_CENTROIDS,
           shape=centroids.shape,
-          initializer=tf.keras.initializers.Constant(
-              value=K.batch_get_value([centroids])[0]),
+          initializer=keras.initializers.Constant(
+              value=K.batch_get_value([centroids])[0]
+          ),
           dtype=centroids.dtype,
-          trainable=True)
+          trainable=True,
+      )
 
       ori_weights_tf = layer.add_weight(
           ORIGINAL_WEIGHTS,
           shape=weights.shape,
-          initializer=tf.keras.initializers.Constant(
-              value=K.batch_get_value([weights])[0]),
+          initializer=keras.initializers.Constant(
+              value=K.batch_get_value([weights])[0]
+          ),
           dtype=weights.dtype,
-          trainable=True)
+          trainable=True,
+      )
 
       # Get clustering implementation according to layer type
       clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
@@ -402,10 +409,12 @@ def _build_clusters(self, name, layer):
       pulling_indices_tf = layer.add_weight(
           PULLING_INDICES,
           shape=lookup.shape,
-          initializer=tf.keras.initializers.Constant(
-              value=K.batch_get_value([pulling_indices])[0]),
+          initializer=keras.initializers.Constant(
+              value=K.batch_get_value([pulling_indices])[0]
+          ),
           dtype=lookup.dtype,
-          trainable=False)
+          trainable=False,
+      )
 
       result_clst = {
           CLUSTER_CENTROIDS: clst_centroids_tf,
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry_test.py
index 915f1c111..e6a25dd99 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_preserve_quantize_registry_test.py
@@ -16,14 +16,15 @@
 
 import tensorflow as tf
 
-
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
+
 QuantizeConfig = quantize_config.QuantizeConfig
-layers = tf.keras.layers
+layers = keras.layers
 
 
 class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_utils.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_utils.py
index a38313156..022616341 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_utils.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/cluster_utils.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Util functions for weight clustering."""
 import tensorflow as tf
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 def _type_model(model):
@@ -26,15 +27,17 @@ def _type_model(model):
       is_keras_layer, is_subclassed_model)
   """
   # pylint:disable=protected-access
-  is_sequential_or_functional = isinstance(
-      model, tf.keras.Model) and (isinstance(model, tf.keras.Sequential) or
-                                  model._is_graph_network)
+  is_sequential_or_functional = isinstance(model, keras.Model) and (
+      isinstance(model, keras.Sequential) or model._is_graph_network
+  )
 
-  is_keras_layer = isinstance(
-      model, tf.keras.layers.Layer) and not isinstance(model, tf.keras.Model)
+  is_keras_layer = isinstance(model, keras.layers.Layer) and not isinstance(
+      model, keras.Model
+  )
 
-  is_subclassed_model = isinstance(model, tf.keras.Model) and (
-      not model._is_graph_network)
+  is_subclassed_model = isinstance(model, keras.Model) and (
+      not model._is_graph_network
+  )
 
   return (is_sequential_or_functional, is_keras_layer, is_subclassed_model)
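
An illustrative check of the three-way classification above (a sketch that assumes the compat `keras` import resolves as elsewhere in this patch):

    # Functional models are graph networks; a bare layer is neither a
    # Sequential/functional model nor a subclassed model.
    inp = keras.Input(shape=(4,))
    functional = keras.Model(inp, keras.layers.Dense(1)(inp))
    assert _type_model(functional) == (True, False, False)

    layer_only = keras.layers.Dense(1)
    assert _type_model(layer_only) == (False, True, False)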
 
@@ -48,29 +51,32 @@ def strip_clustering_cqat(to_strip):
   with the clustered weights should be restored.
 
   Arguments:
-      to_strip: A `tf.keras.Model` instance with clustered layers or a
-      `tf.keras.layers.Layer` instance
+      to_strip: A `keras.Model` instance with clustered layers or a
+        `keras.layers.Layer` instance
 
   Returns:
     A keras model or layer with clustering variables removed.
 
   Raises:
-    ValueError: if the model is not a `tf.keras.Model` instance.
+    ValueError: if the model is not a `keras.Model` instance.
     NotImplementedError: if the model is a subclassed model.
-
   """
-  if not isinstance(to_strip, tf.keras.Model) and not isinstance(
-      to_strip, tf.keras.layers.Layer):
+  if not isinstance(to_strip, keras.Model) and not isinstance(
+      to_strip, keras.layers.Layer
+  ):
     raise ValueError(
-        ('Expected to_strip to be a `tf.keras.Model` or'
-         '`tf.keras.layers.Layer` instance but got: '), to_strip)
+        (
+            'Expected to_strip to be a `keras.Model` or'
+            '`keras.layers.Layer` instance but got: '
+        ),
+        to_strip,
+    )
 
   def _strip_clustering_ops(layer):
-    if isinstance(layer, tf.keras.Model):
-      return tf.keras.models.clone_model(
-          layer,
-          input_tensors=None,
-          clone_function=_strip_clustering_ops)
+    if isinstance(layer, keras.Model):
+      return keras.models.clone_model(
+          layer, input_tensors=None, clone_function=_strip_clustering_ops
+      )
 
     # set the attributes of the layer to the result after cqat
     # and remove all other variables, we do not remove the
@@ -81,8 +87,7 @@ def _strip_clustering_ops(layer):
     if hasattr(layer, 'layer'):
       # pylint:disable=protected-access
       if 'depthwise' not in layer.layer.name:
-        if isinstance(layer.layer,
-                      (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
+        if isinstance(layer.layer, (keras.layers.Conv2D, keras.layers.Dense)):
           new_variables = []
           for v in layer._trainable_weights:
             if 'cluster_centroids_tf' in v.name or (
@@ -105,10 +110,11 @@ def _strip_clustering_ops(layer):
 
   # Just copy the model with the right callback
   if is_sequential_or_functional:
-    return tf.keras.models.clone_model(
-        to_strip, input_tensors=None, clone_function=_strip_clustering_ops)
+    return keras.models.clone_model(
+        to_strip, input_tensors=None, clone_function=_strip_clustering_ops
+    )
   elif is_keras_layer:
-    if isinstance(to_strip, tf.keras.layers.Layer):
+    if isinstance(to_strip, keras.layers.Layer):
       return _strip_clustering_ops(to_strip)
   elif is_subclassed_model:
     to_strip_model = to_strip.model
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/mnist_prune_cluster_preserve_qat_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/mnist_prune_cluster_preserve_qat_test.py
index 52ae828f2..fcc7e61da 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/mnist_prune_cluster_preserve_qat_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve/mnist_prune_cluster_preserve_qat_test.py
@@ -19,6 +19,7 @@
 from tensorflow_model_optimization.python.core.clustering.keras import cluster as tfmot_cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config as tfmot_cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as exp_tfmot_cluster
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_utils
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import (
@@ -28,27 +29,28 @@
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 
 
-layers = tf.keras.layers
+layers = keras.layers
 np.random.seed(1)
 tf.random.set_seed(3)
 
 
 def _build_model():
   """Create the baseline model."""
-  i = tf.keras.layers.Input(shape=(28, 28), name='input')
-  x = tf.keras.layers.Reshape((28, 28, 1))(i)
-  x = tf.keras.layers.Conv2D(
-      filters=12, kernel_size=(3, 3), activation='relu', name='conv1')(x)
-  x = tf.keras.layers.MaxPool2D(2, 2)(x)
-  x = tf.keras.layers.Flatten()(x)
-  output = tf.keras.layers.Dense(10, name='fc2')(x)
-  model = tf.keras.Model(inputs=[i], outputs=[output])
+  i = keras.layers.Input(shape=(28, 28), name='input')
+  x = keras.layers.Reshape((28, 28, 1))(i)
+  x = keras.layers.Conv2D(
+      filters=12, kernel_size=(3, 3), activation='relu', name='conv1'
+  )(x)
+  x = keras.layers.MaxPool2D(2, 2)(x)
+  x = keras.layers.Flatten()(x)
+  output = keras.layers.Dense(10, name='fc2')(x)
+  model = keras.Model(inputs=[i], outputs=[output])
 
   return model
 
 
 def _get_dataset():
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (x_train, y_train), (x_test, y_test) = mnist.load_data()
   x_train, x_test = x_train / 255.0, x_test / 255.0
   # Use subset of 60000 examples to keep unit test speed fast.
@@ -59,7 +61,7 @@ def _get_dataset():
 
 
 def _train_model(model, callback_to_use, num_of_epochs):
-  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
   model.compile(optimizer='adam',
                 loss=loss_fn,
                 metrics=['accuracy'],)
@@ -134,11 +136,11 @@ def selective_cluster_model(original_model, sparsity_flag):
   }
 
   def apply_clustering_to_conv2d(layer):
-    if isinstance(layer, tf.keras.layers.Conv2D):
+    if isinstance(layer, keras.layers.Conv2D):
       return exp_tfmot_cluster.cluster_weights(layer, **clustering_params)
     return layer
 
-  cluster_model = tf.keras.models.clone_model(
+  cluster_model = keras.models.clone_model(
       original_model,
       clone_function=apply_clustering_to_conv2d,
   )
@@ -177,13 +179,17 @@ def _get_num_unique_weights_kernel(model):
 
   num_unique_weights_list = []
   for layer in model.layers:
-    if isinstance(layer,
-                  (tf.keras.layers.Conv2D,
-                   tf.keras.layers.Dense,
-                   quantize.quantize_wrapper.QuantizeWrapper)):
+    if isinstance(
+        layer,
+        (
+            keras.layers.Conv2D,
+            keras.layers.Dense,
+            quantize.quantize_wrapper.QuantizeWrapper,
+        ),
+    ):
       for weights in layer.trainable_weights:
         if 'kernel' in weights.name:
-          np_weights = tf.keras.backend.get_value(weights)
+          np_weights = keras.backend.get_value(weights)
           unique_weights = len(np.unique(np_weights))
           num_unique_weights_list.append(unique_weights)
 
@@ -198,7 +204,7 @@ def _check_sparsity_kernel(model):
                    quantize.quantize_wrapper.QuantizeWrapper)):
       for weights in layer.trainable_weights:
         if 'kernel' in weights.name:
-          np_weights = tf.keras.backend.get_value(weights)
+          np_weights = keras.backend.get_value(weights)
           sparsity = 1.0 - np.count_nonzero(np_weights) / float(
               np_weights.size)
           sparsity_list.append(sparsity)
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/BUILD
index e08952720..74b56495a 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/BUILD
@@ -1,5 +1,5 @@
-# Placeholder: load py_test
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library")
+# Placeholder: load py_test
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -26,6 +26,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quant_ops",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
@@ -45,6 +46,7 @@ py_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune_registry",
     ],
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry.py
index 59e490eae..e8301dffe 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry.py
@@ -16,6 +16,7 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
@@ -23,7 +24,8 @@
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
     default_8bit_quantizers,)
 
-layers = tf.keras.layers
+
+layers = keras.layers
 
 
 class _PrunePreserveInfo(object):
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry_test.py
index 46f0fea5f..29e27e197 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve/prune_preserve_quantize_registry_test.py
@@ -15,14 +15,16 @@
 """Tests for PrunePreserveQuantizeRegistry."""
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.prune_preserve import (
     prune_preserve_quantize_registry,)
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
 
+
 QuantizeConfig = quantize_config.QuantizeConfig
-layers = tf.keras.layers
+layers = keras.layers
 
 
 class PrunePreserveQuantizeRegistryTest(tf.test.TestCase):
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD
index ac3bf9a8a..bbd61529e 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD
@@ -1,6 +1,6 @@
+load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 # Placeholder: load py_library
 # Placeholder: load py_test
-load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -25,6 +25,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
     ],
 )
@@ -41,6 +42,7 @@ py_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -64,6 +66,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_registry",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
@@ -85,6 +88,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
     ],
 )
@@ -99,6 +103,7 @@ py_library(
     deps = [
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
@@ -123,6 +128,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
@@ -141,6 +147,7 @@ py_strict_library(
     deps = [
         ":default_8bit_transforms",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layout_transform",
         "//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
     ],
@@ -157,6 +164,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
         "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py
index 015ffda80..70b0681f7 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py
@@ -20,15 +20,15 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layout_transform
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_transforms
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
 
-keras = tf.keras
-
 
 class Default8BitQuantizeLayoutTransform(
-    quantize_layout_transform.QuantizeLayoutTransform):
+    quantize_layout_transform.QuantizeLayoutTransform
+):
   """Default model transformations."""
 
   def apply(self, model, layer_quantize_map):
@@ -72,5 +72,5 @@ def apply(self, model, layer_quantize_map):
         default_8bit_transforms.LayerReluActivationQuantize(),
     ]
     return model_transformer.ModelTransformer(
-        model, transforms,
-        set(layer_quantize_map.keys()), layer_quantize_map).transform()
+        model, transforms, set(layer_quantize_map.keys()), layer_quantize_map
+    ).transform()
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py
index 248394e79..244a6c027 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py
@@ -20,15 +20,17 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
 
+
 QuantizeConfig = quantize_config.QuantizeConfig
 
-layers = tf.keras.layers
+layers = keras.layers
 
 
 class _QuantizeInfo(object):
@@ -83,13 +85,10 @@ class Default8BitQuantizeRegistry(
       _QuantizeInfo(layers.LeakyReLU, [], [], True),
       # layers.PReLU,
       # layers.ThresholdedReLU,
-
       # Convolution Layers
       # _QuantizeInfo(layers.Conv1D, ['kernel'], ['activation']),
-
       # layers.Conv2D is supported and handled in code below.
       # layers.DepthwiseConv2D is supported and handled in code below.
-
       # _QuantizeInfo(layers.Conv3D, ['kernel'], ['activation']),
       # _QuantizeInfo(layers.Conv3DTranspose, ['kernel'], ['activation']),
       _QuantizeInfo(layers.Concatenate, [], [], True),
@@ -97,7 +96,6 @@ class Default8BitQuantizeRegistry(
       _no_quantize(layers.Cropping2D),
       _no_quantize(layers.Cropping3D),
       # _no_quantize(layers.UpSampling1D),
-
       # TODO(tfmot): Reduce the quantization errors for bilinear interpolation
       # type for UpSampling2D op. UpSampling2D supports two interpolation types,
       # nearest and bilinear. We convert the op to ResizeBilinear integer op on
@@ -111,15 +109,12 @@ class Default8BitQuantizeRegistry(
       # (Note that the nearest case just copies the number, so there are no
       # more errors even if the quantization order is different.)
       _QuantizeInfo(layers.UpSampling2D, [], [], True),
-
       # _no_quantize(layers.UpSampling3D),
       _no_quantize(layers.ZeroPadding1D),
       _no_quantize(layers.ZeroPadding2D),
       # _no_quantize(layers.ZeroPadding3D),
-
       # Supported via modifications in Transforms.
       # layers.SeparableConv1D, layers.SeparableConv2D,
-
       # Core Layers
       _no_quantize(layers.ActivityRegularization),
       _QuantizeInfo(layers.Dense, ['kernel'], ['activation']),
@@ -133,7 +128,6 @@ class Default8BitQuantizeRegistry(
       _no_quantize(layers.SpatialDropout2D),
       _no_quantize(layers.SpatialDropout3D),
       # layers.Lambda needs custom handling by the user.
-
       # Pooling Layers
       _QuantizeInfo(layers.AveragePooling1D, [], [], True),
       _QuantizeInfo(layers.AveragePooling2D, [], [], True),
@@ -147,34 +141,29 @@ class Default8BitQuantizeRegistry(
       # _no_quantize(layers.MaxPooling1D),
       _no_quantize(layers.MaxPooling2D),
       # _no_quantize(layers.MaxPooling3D),
-
       # _QuantizeInfo(layers.LocallyConnected1D, ['kernel'], ['activation']),
       # _QuantizeInfo(layers.LocallyConnected2D, ['kernel'], ['activation']),
       _QuantizeInfo(layers.Add, [], [], True),
-
       # Enable once verified with TFLite behavior.
       # layers.Embedding: ['embeddings'],
-
       # BatchNormalization is handled elsewhere, in the cases
       # where it's preceded by convolutional layers.
       #   layers.BatchNormalization: [],
-
       # Merge layers to be added.
-
       # RNN Cells
       # TODO(pulkitb): Verify RNN layers behavior.
       # TODO(tfmot): check if we still need to allowlist via compat.v1 and
       # compat.v2 to support legacy TensorFlow 2.X
       # behavior where the v2 RNN uses the v1 RNNCell instead of the v2 RNNCell.
       # See b/145939875 for details.
-      # _QuantizeInfo(tf.keras.layers.GRUCell, ['kernel', 'recurrent_kernel'],
+      # _QuantizeInfo(keras.layers.GRUCell, ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.layers.LSTMCell, ['kernel', 'recurrent_kernel'],
+      # _QuantizeInfo(keras.layers.LSTMCell, ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.experimental.PeepholeLSTMCell,
+      # _QuantizeInfo(keras.experimental.PeepholeLSTMCell,
       #               ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.layers.SimpleRNNCell,
+      # _QuantizeInfo(keras.layers.SimpleRNNCell,
       #               ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
   ]
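
To read the registry entries above (an illustrative gloss of entries already visible in this hunk): a _QuantizeInfo lists the layer class, its quantizable weight attributes, its quantizable activation attributes, and optionally a flag indicating that only the layer output is quantized.

    _QuantizeInfo(layers.Dense, ['kernel'], ['activation'])  # quantize kernel and activation
    _QuantizeInfo(layers.Add, [], [], True)                  # no parameters; quantize output only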
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py
index 64aa7421d..9c14d52ab 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py
@@ -19,18 +19,19 @@
 from __future__ import print_function
 
 import unittest
-from absl.testing import parameterized
 
+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
-keras = tf.keras
-K = tf.keras.backend
-l = tf.keras.layers
+
+K = keras.backend
+l = keras.layers
 
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 serialize_keras_object = quantize_utils.serialize_keras_object
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers.py
index 59705ef96..71ac3d799 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers.py
@@ -16,6 +16,7 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 
 
@@ -32,13 +33,15 @@ def build(self, tensor_shape, name, layer):
     min_weight = layer.add_weight(
         name + '_min',
         shape=(tensor_shape[-1],),
-        initializer=tf.keras.initializers.Constant(-6.0),
-        trainable=False)
+        initializer=keras.initializers.Constant(-6.0),
+        trainable=False,
+    )
     max_weight = layer.add_weight(
         name + '_max',
         shape=(tensor_shape[-1],),
-        initializer=tf.keras.initializers.Constant(6.0),
-        trainable=False)
+        initializer=keras.initializers.Constant(6.0),
+        trainable=False,
+    )
 
     return {'min_var': min_weight, 'max_var': max_weight}
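
The build method above allocates one (min, max) range pair per output channel, initialized to [-6.0, 6.0]. A small numpy sketch of the resulting shapes (illustrative only):

    import numpy as np

    tensor_shape = (3, 3, 8, 16)  # e.g. a Conv2D kernel in HWIO layout
    min_var = np.full((tensor_shape[-1],), -6.0)
    max_var = np.full((tensor_shape[-1],), 6.0)
    assert min_var.shape == max_var.shape == (16,)  # one pair per channel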
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py
index 7da1d114a..f2d284451 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py
@@ -19,14 +19,15 @@
 from __future__ import print_function
 
 from absl.testing import parameterized
-
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
 
-Default8BitConvWeightsQuantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer
 
-keras = tf.keras
+Default8BitConvWeightsQuantizer = (
+    default_8bit_quantizers.Default8BitConvWeightsQuantizer
+)
 
 
 class Default8BitConvWeightsQuantizerTest(tf.test.TestCase,
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py
index c6feb138a..bda38f6fb 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py
@@ -20,6 +20,8 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
+from tensorflow_model_optimization.python.core.keras.compat import unique_object_name
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -28,24 +30,10 @@
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
 
-try:
-  # OSS
-  import keras  # pylint: disable=g-import-not-at-top
-  if hasattr(keras, 'src'):
-    # Path as seen in pip packages as of TF/Keras 2.13.
-    from keras.src.backend import unique_object_name  # pylint: disable=g-import-not-at-top
-  else:
-    from keras.backend import unique_object_name  # pylint: disable=g-import-not-at-top,g-importing-member
-except ImportError:
-  # Internal
-  unique_object_name = tf._keras_internal.backend.unique_object_name  # pylint: disable=protected-access
-
 
 LayerNode = transforms.LayerNode
 LayerPattern = transforms.LayerPattern
 
-keras = tf.keras
-
 
 def _get_conv_bn_layers(bn_layer_node):
   bn_layer = bn_layer_node.layer
@@ -397,14 +385,14 @@ def replacement(self, match_layer):
 
     # TODO(pulkitb): Handle other base_layer args such as dtype, input_dim etc.
 
-    sepconv2d_layer = tf.keras.layers.SeparableConv2D(
+    sepconv2d_layer = keras.layers.SeparableConv2D(
         filters=sepconv1d_config['filters'],
         kernel_size=(1,) + _normalize_tuple(sepconv1d_config['kernel_size']),
         strides=_normalize_tuple(sepconv1d_config['strides']) * 2,
         padding=padding,
         data_format=sepconv1d_config['data_format'],
-        dilation_rate=(1,) + _normalize_tuple(
-            sepconv1d_config['dilation_rate']),
+        dilation_rate=(1,)
+        + _normalize_tuple(sepconv1d_config['dilation_rate']),
         depth_multiplier=sepconv1d_config['depth_multiplier'],
         activation=sepconv1d_config['activation'],
         use_bias=sepconv1d_config['use_bias'],
@@ -421,7 +409,7 @@ def replacement(self, match_layer):
         # TODO(pulkitb): Rethink what to do for name. Using the same name leads
         # to confusion, since it's typically separable_conv1d
         name=sepconv1d_config['name'] + '_QAT_SepConv2D',
-        trainable=sepconv1d_config['trainable']
+        trainable=sepconv1d_config['trainable'],
     )
 
     sepconv2d_weights = collections.OrderedDict()
@@ -448,9 +436,10 @@ def replacement(self, match_layer):
     # TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.
 
     # Layer before SeparableConv2D which expands input tensors to match 2D.
-    expand_layer = tf.keras.layers.Lambda(
+    expand_layer = keras.layers.Lambda(
         lambda x: tf.expand_dims(x, spatial_dim),
-        name=self._get_name('sepconv1d_expand'))
+        name=self._get_name('sepconv1d_expand'),
+    )
     expand_layer_config = quantize_utils.serialize_layer(
         expand_layer, use_legacy_format=True
     )
@@ -458,9 +447,10 @@ def replacement(self, match_layer):
     expand_layer_metadata = {
         'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
 
-    squeeze_layer = tf.keras.layers.Lambda(
+    squeeze_layer = keras.layers.Lambda(
         lambda x: tf.squeeze(x, [spatial_dim]),
-        name=self._get_name('sepconv1d_squeeze'))
+        name=self._get_name('sepconv1d_squeeze'),
+    )
     squeeze_layer_config = quantize_utils.serialize_layer(
         squeeze_layer, use_legacy_format=True
     )
@@ -512,7 +502,7 @@ def replacement(self, match_layer):
     # Needs special handling: weights
     # Unknown: dynamic, autocast
 
-    dconv_layer = tf.keras.layers.DepthwiseConv2D(
+    dconv_layer = keras.layers.DepthwiseConv2D(
         kernel_size=sepconv_layer['config']['kernel_size'],
         strides=sepconv_layer['config']['strides'],
         padding=sepconv_layer['config']['padding'],
@@ -524,7 +514,7 @@ def replacement(self, match_layer):
         depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
         depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
         depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
-        trainable=sepconv_layer['config']['trainable']
+        trainable=sepconv_layer['config']['trainable'],
     )
     dconv_weights = collections.OrderedDict()
     dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
@@ -535,7 +525,7 @@ def replacement(self, match_layer):
     # Needed to ensure these new layers are considered for quantization.
     dconv_metadata = {'quantize_config': None}
 
-    conv_layer = tf.keras.layers.Conv2D(
+    conv_layer = keras.layers.Conv2D(
         filters=sepconv_layer['config']['filters'],
         kernel_size=(1, 1),  # (1,) * rank
         strides=(1, 1),
@@ -552,7 +542,7 @@ def replacement(self, match_layer):
         activity_regularizer=sepconv_layer['config']['activity_regularizer'],
         kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
         bias_constraint=sepconv_layer['config']['bias_constraint'],
-        trainable=sepconv_layer['config']['trainable']
+        trainable=sepconv_layer['config']['trainable'],
     )
     conv_weights = collections.OrderedDict()
     conv_weights['kernel:0'] = sepconv_weights[1]
@@ -659,7 +649,7 @@ def pattern(self):
   def _get_layer_type(self, layer_class_name):
     if layer_class_name == 'QuantizeLayer':
       return quantize_layer.QuantizeLayer
-    keras_layers = inspect.getmembers(tf.keras.layers, inspect.isclass)
+    keras_layers = inspect.getmembers(keras.layers, inspect.isclass)
     for layer_name, layer_type in keras_layers:
       if layer_name == layer_class_name:
         return layer_type
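
The try/except block removed above is the motivation for the new compat import: `unique_object_name` moved from `keras.backend` to `keras.src.backend` in the TF/Keras 2.13 pip layout, while internal builds reach it through the private `tf._keras_internal` hook. Centralizing that probing in compat keeps each transforms module down to one import line. A sketch of the fallback chain, mirroring the code deleted here (how compat.py actually structures it may differ):

    # Mirrors the removed try/except; placement inside compat.py is assumed.
    try:
      import keras  # OSS
      if hasattr(keras, 'src'):
        # Path as seen in pip packages as of TF/Keras 2.13.
        from keras.src.backend import unique_object_name
      else:
        from keras.backend import unique_object_name
    except ImportError:
      import tensorflow as tf
      # Internal builds expose the helper through a private TF hook.
      unique_object_name = tf._keras_internal.backend.unique_object_name
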
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py
index f096ab963..2c5ed815d 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py
@@ -22,6 +22,7 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -32,14 +33,13 @@
 from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
 from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils
 
+
 ModelTransformer = model_transformer.ModelTransformer
 
 Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
 DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
 DenseModel = dense_batchnorm_test_utils.DenseModel
 
-keras = tf.keras
-
 Conv2DBatchNormActivationQuantize = default_8bit_transforms.Conv2DBatchNormActivationQuantize
 Conv2DBatchNormReLUQuantize = default_8bit_transforms.Conv2DBatchNormReLUQuantize
 
@@ -287,15 +287,23 @@ def testDenseBatchNormActivationQuantize(self, layer_type,
       ('strides', {'strides': 2}),
       ('dilation_rate', {'dilation_rate': 2}),
       ('depth_multiplier', {'depth_multiplier': 2}),
-      ('regularizer', {
-          'depthwise_regularizer': 'l2',
-          'pointwise_regularizer': 'l2',
-          'bias_regularizer': 'l2',
-          'activity_regularizer': 'l2'}),
-      ('constraint', {
-          'depthwise_constraint': tf.keras.constraints.max_norm(2.),
-          'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
-          'bias_constraint': tf.keras.constraints.unit_norm()}),
+      (
+          'regularizer',
+          {
+              'depthwise_regularizer': 'l2',
+              'pointwise_regularizer': 'l2',
+              'bias_regularizer': 'l2',
+              'activity_regularizer': 'l2',
+          },
+      ),
+      (
+          'constraint',
+          {
+              'depthwise_constraint': keras.constraints.max_norm(2.0),
+              'pointwise_constraint': keras.constraints.min_max_norm(0.0, 2.0),
+              'bias_constraint': keras.constraints.unit_norm(),
+          },
+      ),
       ('activation_relu', {'activation': 'relu'}),
       # TODO(pulkitb): Temporarily disabling due to numerical errors resulting
       # from caching of activation logits in TF code.
@@ -308,10 +316,10 @@ def testSeparableConv1DQuantize_(self, kwargs):
     stack_size = 3
     num_row = 7
 
-    sepconv_model = tf.keras.Sequential([
-        tf.keras.Input(
-            shape=(num_row, stack_size), batch_size=num_samples),
-        tf.keras.layers.SeparableConv1D(**kwargs)])
+    sepconv_model = keras.Sequential([
+        keras.Input(shape=(num_row, stack_size), batch_size=num_samples),
+        keras.layers.SeparableConv1D(**kwargs),
+    ])
 
     transformed_model, updated_metadata = ModelTransformer(
         sepconv_model,
@@ -344,21 +352,28 @@ def testSeparableConv1DQuantize_(self, kwargs):
   @parameterized.named_parameters(
       ('padding_valid', {'padding': 'valid'}),
       ('padding_same', {'padding': 'same'}),
-      ('padding_same_dilation_2',
-       {'padding': 'same', 'dilation_rate': 2}),
+      ('padding_same_dilation_2', {'padding': 'same', 'dilation_rate': 2}),
       ('strides', {'strides': 2}),
       ('dilation_rate', {'dilation_rate': 2}),
       ('depth_multiplier', {'depth_multiplier': 2}),
-      ('regularizer', {
-          'depthwise_regularizer': 'l2',
-          'pointwise_regularizer': 'l2',
-          'bias_regularizer': 'l2',
-          'activity_regularizer': 'l2'}),
+      (
+          'regularizer',
+          {
+              'depthwise_regularizer': 'l2',
+              'pointwise_regularizer': 'l2',
+              'bias_regularizer': 'l2',
+              'activity_regularizer': 'l2',
+          },
+      ),
       ('use_bias', {'use_bias': False}),
-      ('constraint', {
-          'depthwise_constraint': tf.keras.constraints.max_norm(2.),
-          'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
-          'bias_constraint': tf.keras.constraints.unit_norm()})
+      (
+          'constraint',
+          {
+              'depthwise_constraint': keras.constraints.max_norm(2.0),
+              'pointwise_constraint': keras.constraints.min_max_norm(0.0, 2.0),
+              'bias_constraint': keras.constraints.unit_norm(),
+          },
+      ),
   )
   def testSeparableConvQuantize_(self, kwargs):
     kwargs['filters'] = 2
@@ -368,10 +383,12 @@ def testSeparableConvQuantize_(self, kwargs):
     num_row = 7
     num_col = 6
 
-    sepconv_model = tf.keras.Sequential([
-        tf.keras.Input(
-            shape=(num_row, num_col, stack_size), batch_size=num_samples),
-        tf.keras.layers.SeparableConv2D(**kwargs)])
+    sepconv_model = keras.Sequential([
+        keras.Input(
+            shape=(num_row, num_col, stack_size), batch_size=num_samples
+        ),
+        keras.layers.SeparableConv2D(**kwargs),
+    ])
 
     transformed_model, updated_metadata = ModelTransformer(
         sepconv_model,
@@ -439,13 +456,13 @@ def testAddReLUQuantize(self, activation_type, transform_type):
   def testLayerReLUQuantize(self, activation_type, transform_type):
     # TODO(b/185727342): Add tests for DepthConv and Dense
     input_shape = (3, 3, 3)
-    conv_layer = tf.keras.layers.Conv2D(5, 2, input_shape=input_shape)
+    conv_layer = keras.layers.Conv2D(5, 2, input_shape=input_shape)
     if activation_type == 'relu':
       act_layer = keras.layers.ReLU(6.0)
     elif activation_type == 'act_relu':
       act_layer = keras.layers.Activation('relu')
 
-    model = tf.keras.Sequential([conv_layer, act_layer])
+    model = keras.Sequential([conv_layer, act_layer])
 
     transformed_model, updated_metadata = ModelTransformer(
         model,
@@ -707,6 +724,6 @@ def testConcatConcatTransformDisablesOutput(self):
 
 
 if __name__ == '__main__':
-  if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
-    tf.keras.__internal__.enable_unsafe_deserialization()
+  if hasattr(keras.__internal__, 'enable_unsafe_deserialization'):
+    keras.__internal__.enable_unsafe_deserialization()
   tf.test.main()
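
The `__main__` guard switched to the compat import above matters for these tests in particular: the SeparableConv1D transforms insert Lambda layers (`sepconv1d_expand`, `sepconv1d_squeeze`), and newer Keras releases refuse to deserialize Lambda layers unless unsafe deserialization is explicitly enabled, while older releases lack the hook entirely, hence the `hasattr` check. Standalone, the guard reads:

    from tensorflow_model_optimization.python.core.keras.compat import keras

    # Only newer Keras exposes this hook; older releases deserialize Lambda
    # layers without it, so the hasattr check is safe on both.
    if hasattr(keras.__internal__, 'enable_unsafe_deserialization'):
      keras.__internal__.enable_unsafe_deserialization()
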
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py
index 96eae7e39..f66d69859 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py
@@ -18,10 +18,10 @@
 import tempfile
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import utils
 
@@ -57,104 +57,108 @@ def _execute_tflite(self, tflite_file, x_test, y_test):
     return y_
 
   def _get_single_conv_model(self):
-    i = tf.keras.Input(shape=(32, 32, 3))
-    x = tf.keras.layers.Conv2D(2, kernel_size=(3, 3), strides=(2, 2))(i)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(32, 32, 3))
+    x = keras.layers.Conv2D(2, kernel_size=(3, 3), strides=(2, 2))(i)
+    return keras.Model(i, x)
 
   def _get_single_dense_model(self):
-    i = tf.keras.Input(shape=(5,))
-    x = tf.keras.layers.Dense(3)(i)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(5,))
+    x = keras.layers.Dense(3)(i)
+    return keras.Model(i, x)
 
   def _get_single_conv_relu_model(self):
-    i = tf.keras.Input(shape=(6, 6, 3))
-    x = tf.keras.layers.Conv2D(
-        2, kernel_size=(3, 3), strides=(2, 2), activation='relu')(i)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(6, 6, 3))
+    x = keras.layers.Conv2D(
+        2, kernel_size=(3, 3), strides=(2, 2), activation='relu'
+    )(i)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_stacked_convs_model(self):
-    i = tf.keras.Input(shape=(64, 64, 3))
-    x = tf.keras.layers.Conv2D(
-        10, kernel_size=(3, 3), strides=(1, 1), activation='relu')(i)
-    x = tf.keras.layers.Conv2D(
+    i = keras.Input(shape=(64, 64, 3))
+    x = keras.layers.Conv2D(
+        10, kernel_size=(3, 3), strides=(1, 1), activation='relu'
+    )(i)
+    x = keras.layers.Conv2D(
         # Setting strides to (1, 1) passes test, (2, 2) fails test?
         # Somehow one value is at border.
         # Train over 100 epochs, and issue goes away.
         # Why are all the first values zero?
-        10, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
-    x = tf.keras.layers.Conv2D(
-        10, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
-    x = tf.keras.layers.Conv2D(
-        5, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
-    x = tf.keras.layers.Conv2D(
-        2, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
-    return tf.keras.Model(i, x)
+        10,
+        kernel_size=(3, 3),
+        strides=(2, 2),
+        activation='relu',
+    )(x)
+    x = keras.layers.Conv2D(
+        10, kernel_size=(3, 3), strides=(2, 2), activation='relu'
+    )(x)
+    x = keras.layers.Conv2D(
+        5, kernel_size=(3, 3), strides=(2, 2), activation='relu'
+    )(x)
+    x = keras.layers.Conv2D(
+        2, kernel_size=(3, 3), strides=(2, 2), activation='relu'
+    )(x)
+    return keras.Model(i, x)
 
   def _get_conv_bn_relu_model(self):
-    i = tf.keras.Input(shape=(6, 6, 3))
-    x = tf.keras.layers.Conv2D(3, kernel_size=(3, 3), strides=(2, 2))(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(6, 6, 3))
+    x = keras.layers.Conv2D(3, kernel_size=(3, 3), strides=(2, 2))(i)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_depthconv_bn_relu_model(self):
-    i = tf.keras.Input(shape=(6, 6, 3))
-    x = tf.keras.layers.DepthwiseConv2D(kernel_size=(3, 3), strides=(2, 2))(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(6, 6, 3))
+    x = keras.layers.DepthwiseConv2D(kernel_size=(3, 3), strides=(2, 2))(i)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_separable_conv2d_model(self):
-    i = tf.keras.Input(shape=(12, 12, 3))
-    x = tf.keras.layers.SeparableConv2D(
-        filters=5, kernel_size=(3, 3), strides=(2, 2))(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(12, 12, 3))
+    x = keras.layers.SeparableConv2D(
+        filters=5, kernel_size=(3, 3), strides=(2, 2)
+    )(i)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_sepconv1d_bn_relu_model(self):
-    i = tf.keras.Input(shape=(8, 3))
-    x = tf.keras.layers.SeparableConv1D(
-        filters=5, kernel_size=3, strides=2)(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(8, 3))
+    x = keras.layers.SeparableConv1D(filters=5, kernel_size=3, strides=2)(i)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_sepconv1d_bn_model(self):
-    i = tf.keras.Input(shape=(8, 3))
-    x = tf.keras.layers.SeparableConv1D(
-        filters=5, kernel_size=3, strides=2)(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(8, 3))
+    x = keras.layers.SeparableConv1D(filters=5, kernel_size=3, strides=2)(i)
+    x = keras.layers.BatchNormalization()(x)
+    return keras.Model(i, x)
 
   def _get_sepconv1d_stacked_model(self):
-    i = tf.keras.Input(shape=(8, 3))
-    x = tf.keras.layers.SeparableConv1D(
-        filters=5, kernel_size=3, strides=2)(i)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.SeparableConv1D(
-        filters=5, kernel_size=3, strides=2)(x)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.ReLU()(x)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(8, 3))
+    x = keras.layers.SeparableConv1D(filters=5, kernel_size=3, strides=2)(i)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.SeparableConv1D(filters=5, kernel_size=3, strides=2)(x)
+    x = keras.layers.BatchNormalization()(x)
+    x = keras.layers.ReLU()(x)
+    return keras.Model(i, x)
 
   def _get_upsampling2d_nearest_model(self):
-    i = tf.keras.Input(shape=(32, 32, 3))
-    x = tf.keras.layers.UpSampling2D(size=(3, 4), interpolation='nearest')(i)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(32, 32, 3))
+    x = keras.layers.UpSampling2D(size=(3, 4), interpolation='nearest')(i)
+    return keras.Model(i, x)
 
   def _get_upsampling2d_bilinear_model(self):
-    i = tf.keras.Input(shape=(1, 3, 1))
-    x = tf.keras.layers.UpSampling2D(size=(1, 5), interpolation='bilinear')(i)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(1, 3, 1))
+    x = keras.layers.UpSampling2D(size=(1, 5), interpolation='bilinear')(i)
+    return keras.Model(i, x)
 
   def _get_conv2d_transpose_model(self):
-    i = tf.keras.Input(shape=(32, 32, 3))
-    x = tf.keras.layers.Conv2DTranspose(
-        2, kernel_size=(3, 3), strides=(2, 2))(
-            i)
-    return tf.keras.Model(i, x)
+    i = keras.Input(shape=(32, 32, 3))
+    x = keras.layers.Conv2DTranspose(2, kernel_size=(3, 3), strides=(2, 2))(i)
+    return keras.Model(i, x)
 
   @parameterized.parameters([
       _get_single_conv_model, _get_single_dense_model,
@@ -200,6 +204,6 @@ def testModelEndToEnd(self, model_fn):
 
 
 if __name__ == '__main__':
-  if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
-    tf.keras.__internal__.enable_unsafe_deserialization()
+  if hasattr(keras.__internal__, 'enable_unsafe_deserialization'):
+    keras.__internal__.enable_unsafe_deserialization()
   tf.test.main()
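
For context, these numerical tests compare a quantization-aware Keras model against its TFLite conversion on the same inputs. A minimal sketch of that round trip, assuming the public tfmot quantize_model API and default converter settings (the tests' actual conversion helpers live in quantization.keras.utils and may set different flags):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def quantize_and_convert(model):
      # Wrap the model with fake-quant ops, then convert to TFLite bytes.
      quantized = tfmot.quantization.keras.quantize_model(model)
      converter = tf.lite.TFLiteConverter.from_keras_model(quantized)
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      return converter.convert()
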
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD
index 9acacf98a..d0fe054aa 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD
@@ -1,6 +1,6 @@
+load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 # Placeholder: load py_library
 # Placeholder: load py_test
-load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -25,6 +25,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
     ],
 )
@@ -40,6 +41,7 @@ py_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -63,6 +65,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_registry",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
@@ -83,6 +86,7 @@ py_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
     ],
 )
@@ -97,6 +101,7 @@ py_library(
     deps = [
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
@@ -115,6 +120,7 @@ py_strict_library(
     deps = [
         ":default_n_bit_transforms",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layout_transform",
         "//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
     ],
@@ -134,6 +140,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py
index ed19e39dd..50089c509 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py
@@ -20,16 +20,15 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layout_transform
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_transforms
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
 
 
-keras = tf.keras
-
-
 class DefaultNBitQuantizeLayoutTransform(
-    quantize_layout_transform.QuantizeLayoutTransform):
+    quantize_layout_transform.QuantizeLayoutTransform
+):
   """Default model transformations."""
 
   def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
@@ -58,62 +57,81 @@ def apply(self, model, layer_quantize_map):
     transforms = [
         default_n_bit_transforms.InputLayerQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.SeparableConv1DQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.SeparableConvQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DBatchNormReLUQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DBatchNormActivationQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DReshapeBatchNormQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.Conv2DBatchNormQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.ConcatTransform6Inputs(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.ConcatTransform5Inputs(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.ConcatTransform4Inputs(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.ConcatTransform3Inputs(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.ConcatTransform(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.LayerReLUQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.LayerReluActivationQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.DenseBatchNormQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.DenseBatchNormReLUQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
         default_n_bit_transforms.DenseBatchNormActivationQuantize(
             num_bits_weight=self._num_bits_weight,
-            num_bits_activation=self._num_bits_activation),
+            num_bits_activation=self._num_bits_activation,
+        ),
     ]
     return model_transformer.ModelTransformer(
-        model, transforms,
-        set(layer_quantize_map.keys()), layer_quantize_map).transform()
+        model, transforms, set(layer_quantize_map.keys()), layer_quantize_map
+    ).transform()
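
Note how every transform in the list now threads `num_bits_weight` and `num_bits_activation` through, so one layout-transform instance configures the whole pipeline. A hypothetical usage, with `model` and `layer_quantize_map` standing in for the arguments the quantize machinery normally supplies:

    from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_layout_transform

    # Illustrative only: 4-bit weights, 8-bit activations; model and
    # layer_quantize_map are assumed to come from the surrounding quantize flow.
    transform = default_n_bit_quantize_layout_transform.DefaultNBitQuantizeLayoutTransform(
        num_bits_weight=4, num_bits_activation=8)
    transformed_model, layer_metadata = transform.apply(model, layer_quantize_map)
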
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry.py
index 61c368850..d33dc67be 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry.py
@@ -22,15 +22,17 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_configs as n_bit_configs
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantizers as n_bit_quantizers
 
+
 QuantizeConfig = quantize_config.QuantizeConfig
 
-layers = tf.keras.layers
+layers = keras.layers
 
 
 class _QuantizeInfo(object):
@@ -92,13 +94,10 @@ class DefaultNBitQuantizeRegistry(
       _QuantizeInfo(layers.LeakyReLU, [], [], True),
       # layers.PReLU,
       # layers.ThresholdedReLU,
-
       # Convolution Layers
       # _QuantizeInfo(layers.Conv1D, ['kernel'], ['activation']),
-
       # layers.Conv2D is supported and handled in code below.
       # layers.DepthwiseConv2D is supported and handled in code below.
-
       # _QuantizeInfo(layers.Conv3D, ['kernel'], ['activation']),
       # _QuantizeInfo(layers.Conv3DTranspose, ['kernel'], ['activation']),
       _QuantizeInfo(layers.Concatenate, [], [], True),
@@ -106,7 +105,6 @@ class DefaultNBitQuantizeRegistry(
       _no_quantize(layers.Cropping2D),
       _no_quantize(layers.Cropping3D),
       # _no_quantize(layers.UpSampling1D),
-
       # TODO(tfmot): Reduce the quantization errors for bilinear interpolation
       # type for UpSampling2D op. UpSampling2D supports two interpolation types,
       # nearest and bilinear. We convert the op to ResizeBilinear integer op on
@@ -120,15 +118,12 @@ class DefaultNBitQuantizeRegistry(
       # (Note that the nearest case just copies the number so there’s no more
       # errors even if the quantization order is different.)
       _QuantizeInfo(layers.UpSampling2D, [], [], True),
-
       # _no_quantize(layers.UpSampling3D),
       _no_quantize(layers.ZeroPadding1D),
       _no_quantize(layers.ZeroPadding2D),
       # _no_quantize(layers.ZeroPadding3D),
-
       # Supported via modifications in Transforms.
       # layers.SeparableConv1D, layers.SeparableConv2D,
-
       # Core Layers
       _no_quantize(layers.ActivityRegularization),
       _QuantizeInfo(layers.Dense, ['kernel'], ['activation']),
@@ -142,7 +137,6 @@ class DefaultNBitQuantizeRegistry(
       _no_quantize(layers.SpatialDropout2D),
       _no_quantize(layers.SpatialDropout3D),
       # layers.Lambda needs custom handling by the user.
-
       # Pooling Layers
       _QuantizeInfo(layers.AveragePooling1D, [], [], True),
       _QuantizeInfo(layers.AveragePooling2D, [], [], True),
@@ -156,34 +150,29 @@ class DefaultNBitQuantizeRegistry(
       # _no_quantize(layers.MaxPooling1D),
       _no_quantize(layers.MaxPooling2D),
       # _no_quantize(layers.MaxPooling3D),
-
       # _QuantizeInfo(layers.LocallyConnected1D, ['kernel'], ['activation']),
       # _QuantizeInfo(layers.LocallyConnected2D, ['kernel'], ['activation']),
       _QuantizeInfo(layers.Add, [], [], True),
-
       # Enable once verified with TFLite behavior.
       # layers.Embedding: ['embeddings'],
-
       # BatchNormalization is handled elsewhere, in the cases
       # where it's preceded by convolutional layers.
       #   layers.BatchNormalization: [],
-
       # Merge layers to be added.
-
       # RNN Cells
       # TODO(pulkitb): Verify RNN layers behavior.
       # TODO(tfmot): check if we still need to allowlist via compat.v1 and
       # compat.v2 to support legacy TensorFlow 2.X
       # behavior where the v2 RNN uses the v1 RNNCell instead of the v2 RNNCell.
       # See b/145939875 for details.
-      # _QuantizeInfo(tf.keras.layers.GRUCell, ['kernel', 'recurrent_kernel'],
+      # _QuantizeInfo(keras.layers.GRUCell, ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.layers.LSTMCell, ['kernel', 'recurrent_kernel'],
+      # _QuantizeInfo(keras.layers.LSTMCell, ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.experimental.PeepholeLSTMCell,
+      # _QuantizeInfo(keras.experimental.PeepholeLSTMCell,
       #               ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
-      # _QuantizeInfo(tf.keras.layers.SimpleRNNCell,
+      # _QuantizeInfo(keras.layers.SimpleRNNCell,
       #               ['kernel', 'recurrent_kernel'],
       #               ['activation', 'recurrent_activation']),
   ]
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py
index 231ef4695..47bd9bc4d 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py
@@ -24,13 +24,14 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry
 
-keras = tf.keras
-K = tf.keras.backend
-l = tf.keras.layers
+
+K = keras.backend
+l = keras.layers
 
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 serialize_keras_object = quantize_utils.serialize_keras_object
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers.py
index b36491eed..ac5276936 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers.py
@@ -16,6 +16,7 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 
 
@@ -37,13 +38,15 @@ def build(self, tensor_shape, name, layer):
     min_weight = layer.add_weight(
         name + '_min',
         shape=(tensor_shape[-1],),
-        initializer=tf.keras.initializers.Constant(-6.0),
-        trainable=False)
+        initializer=keras.initializers.Constant(-6.0),
+        trainable=False,
+    )
     max_weight = layer.add_weight(
         name + '_max',
         shape=(tensor_shape[-1],),
-        initializer=tf.keras.initializers.Constant(6.0),
-        trainable=False)
+        initializer=keras.initializers.Constant(6.0),
+        trainable=False,
+    )
 
     return {'min_var': min_weight, 'max_var': max_weight}
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py
index dcbcc90bf..5a23cd653 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py
@@ -19,14 +19,15 @@
 from __future__ import print_function
 
 from absl.testing import parameterized
-
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantizers
 
-DefaultNBitConvWeightsQuantizer = default_n_bit_quantizers.DefaultNBitConvWeightsQuantizer
 
-keras = tf.keras
+DefaultNBitConvWeightsQuantizer = (
+    default_n_bit_quantizers.DefaultNBitConvWeightsQuantizer
+)
 
 
 class DefaultNBitConvWeightsQuantizerTest(tf.test.TestCase,
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py
index 8a8ed2c1a..0aa66292f 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py
@@ -20,6 +20,8 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
+from tensorflow_model_optimization.python.core.keras.compat import unique_object_name
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -28,21 +30,10 @@
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
 
-try:
-  import keras  # pylint: disable=g-import-not-at-top
-  if hasattr(keras, 'src'):
-    # Path as seen in pip packages as of TF/Keras 2.13.
-    from keras.src.backend import unique_object_name  # pylint: disable=g-import-not-at-top,g-importing-member
-  else:
-    from keras.backend import unique_object_name  # pylint: disable=g-import-not-at-top,g-importing-member
-except ImportError:
-  unique_object_name = tf._keras_internal.backend.unique_object_name  # pylint: disable=protected-access
 
 LayerNode = transforms.LayerNode
 LayerPattern = transforms.LayerPattern
 
-keras = tf.keras
-
 
 def _get_conv_bn_layers(bn_layer_node):
   bn_layer = bn_layer_node.layer
@@ -425,14 +416,14 @@ def replacement(self, match_layer):
 
     # TODO(pulkitb): Handle other base_layer args such as dtype, input_dim etc.
 
-    sepconv2d_layer = tf.keras.layers.SeparableConv2D(
+    sepconv2d_layer = keras.layers.SeparableConv2D(
         filters=sepconv1d_config['filters'],
         kernel_size=(1,) + _normalize_tuple(sepconv1d_config['kernel_size']),
         strides=_normalize_tuple(sepconv1d_config['strides']) * 2,
         padding=padding,
         data_format=sepconv1d_config['data_format'],
-        dilation_rate=(1,) + _normalize_tuple(
-            sepconv1d_config['dilation_rate']),
+        dilation_rate=(1,)
+        + _normalize_tuple(sepconv1d_config['dilation_rate']),
         depth_multiplier=sepconv1d_config['depth_multiplier'],
         activation=sepconv1d_config['activation'],
         use_bias=sepconv1d_config['use_bias'],
@@ -449,7 +440,7 @@ def replacement(self, match_layer):
         # TODO(pulkitb): Rethink what to do for name. Using the same name leads
         # to confusion, since it's typically separable_conv1d
         name=sepconv1d_config['name'] + '_QAT_SepConv2D',
-        trainable=sepconv1d_config['trainable']
+        trainable=sepconv1d_config['trainable'],
     )
 
     sepconv2d_weights = collections.OrderedDict()
@@ -476,9 +467,10 @@ def replacement(self, match_layer):
     # TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.
 
     # Layer before SeparableConv2D which expands input tensors to match 2D.
-    expand_layer = tf.keras.layers.Lambda(
+    expand_layer = keras.layers.Lambda(
         lambda x: tf.expand_dims(x, spatial_dim),
-        name=self._get_name('sepconv1d_expand'))
+        name=self._get_name('sepconv1d_expand'),
+    )
     expand_layer_config = quantize_utils.serialize_layer(
         expand_layer, use_legacy_format=True
     )
@@ -487,9 +479,10 @@ def replacement(self, match_layer):
         'quantize_config':
             configs.NoOpQuantizeConfig()}
 
-    squeeze_layer = tf.keras.layers.Lambda(
+    squeeze_layer = keras.layers.Lambda(
         lambda x: tf.squeeze(x, [spatial_dim]),
-        name=self._get_name('sepconv1d_squeeze'))
+        name=self._get_name('sepconv1d_squeeze'),
+    )
     squeeze_layer_config = quantize_utils.serialize_layer(
         squeeze_layer, use_legacy_format=True
     )
@@ -546,7 +539,7 @@ def replacement(self, match_layer):
     # Needs special handling: weights
     # Unknown: dynamic, autocast
 
-    dconv_layer = tf.keras.layers.DepthwiseConv2D(
+    dconv_layer = keras.layers.DepthwiseConv2D(
         kernel_size=sepconv_layer['config']['kernel_size'],
         strides=sepconv_layer['config']['strides'],
         padding=sepconv_layer['config']['padding'],
@@ -558,7 +551,7 @@ def replacement(self, match_layer):
         depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
         depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
         depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
-        trainable=sepconv_layer['config']['trainable']
+        trainable=sepconv_layer['config']['trainable'],
     )
     dconv_weights = collections.OrderedDict()
     dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
@@ -569,7 +562,7 @@ def replacement(self, match_layer):
     # Needed to ensure these new layers are considered for quantization.
     dconv_metadata = {'quantize_config': None}
 
-    conv_layer = tf.keras.layers.Conv2D(
+    conv_layer = keras.layers.Conv2D(
         filters=sepconv_layer['config']['filters'],
         kernel_size=(1, 1),  # (1,) * rank
         strides=(1, 1),
@@ -586,7 +579,7 @@ def replacement(self, match_layer):
         activity_regularizer=sepconv_layer['config']['activity_regularizer'],
         kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
         bias_constraint=sepconv_layer['config']['bias_constraint'],
-        trainable=sepconv_layer['config']['trainable']
+        trainable=sepconv_layer['config']['trainable'],
     )
     conv_weights = collections.OrderedDict()
     conv_weights['kernel:0'] = sepconv_weights[1]
@@ -704,7 +697,7 @@ def pattern(self):
         'Concatenate', inputs=[LayerPattern('.*'), LayerPattern('.*')])
 
   def _get_layer_type(self, layer_class_name):
-    keras_layers = inspect.getmembers(tf.keras.layers, inspect.isclass)
+    keras_layers = inspect.getmembers(keras.layers, inspect.isclass)
     for layer_name, layer_type in keras_layers:
       if layer_name == layer_class_name:
         return layer_type
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py
index e5ba016e3..2f7107516 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py
@@ -22,6 +22,7 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -31,14 +32,13 @@
 from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
 from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils
 
+
 ModelTransformer = model_transformer.ModelTransformer
 
 Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
 DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
 DenseModel = dense_batchnorm_test_utils.DenseModel
 
-keras = tf.keras
-
 Conv2DBatchNormActivationQuantize = default_n_bit_transforms.Conv2DBatchNormActivationQuantize
 Conv2DBatchNormReLUQuantize = default_n_bit_transforms.Conv2DBatchNormReLUQuantize
 
@@ -288,15 +288,23 @@ def testDenseBatchNormActivationQuantize(self, layer_type,
       ('strides', {'strides': 2}),
       ('dilation_rate', {'dilation_rate': 2}),
       ('depth_multiplier', {'depth_multiplier': 2}),
-      ('regularizer', {
-          'depthwise_regularizer': 'l2',
-          'pointwise_regularizer': 'l2',
-          'bias_regularizer': 'l2',
-          'activity_regularizer': 'l2'}),
-      ('constraint', {
-          'depthwise_constraint': tf.keras.constraints.max_norm(2.),
-          'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
-          'bias_constraint': tf.keras.constraints.unit_norm()}),
+      (
+          'regularizer',
+          {
+              'depthwise_regularizer': 'l2',
+              'pointwise_regularizer': 'l2',
+              'bias_regularizer': 'l2',
+              'activity_regularizer': 'l2',
+          },
+      ),
+      (
+          'constraint',
+          {
+              'depthwise_constraint': keras.constraints.max_norm(2.0),
+              'pointwise_constraint': keras.constraints.min_max_norm(0.0, 2.0),
+              'bias_constraint': keras.constraints.unit_norm(),
+          },
+      ),
       ('activation_relu', {'activation': 'relu'}),
       # TODO(pulkitb): Temporarily disabling due to numerical errors resulting
       # from caching of activation logits in TF code.
@@ -309,10 +317,10 @@ def testSeparableConv1DQuantize_(self, kwargs):
     stack_size = 3
     num_row = 7
 
-    sepconv_model = tf.keras.Sequential([
-        tf.keras.Input(
-            shape=(num_row, stack_size), batch_size=num_samples),
-        tf.keras.layers.SeparableConv1D(**kwargs)])
+    sepconv_model = keras.Sequential([
+        keras.Input(shape=(num_row, stack_size), batch_size=num_samples),
+        keras.layers.SeparableConv1D(**kwargs),
+    ])
 
     transformed_model, updated_metadata = ModelTransformer(
         sepconv_model,
@@ -345,21 +353,28 @@ def testSeparableConv1DQuantize_(self, kwargs):
   @parameterized.named_parameters(
       ('padding_valid', {'padding': 'valid'}),
       ('padding_same', {'padding': 'same'}),
-      ('padding_same_dilation_2',
-       {'padding': 'same', 'dilation_rate': 2}),
+      ('padding_same_dilation_2', {'padding': 'same', 'dilation_rate': 2}),
       ('strides', {'strides': 2}),
       ('dilation_rate', {'dilation_rate': 2}),
       ('depth_multiplier', {'depth_multiplier': 2}),
-      ('regularizer', {
-          'depthwise_regularizer': 'l2',
-          'pointwise_regularizer': 'l2',
-          'bias_regularizer': 'l2',
-          'activity_regularizer': 'l2'}),
+      (
+          'regularizer',
+          {
+              'depthwise_regularizer': 'l2',
+              'pointwise_regularizer': 'l2',
+              'bias_regularizer': 'l2',
+              'activity_regularizer': 'l2',
+          },
+      ),
       ('use_bias', {'use_bias': False}),
-      ('constraint', {
-          'depthwise_constraint': tf.keras.constraints.max_norm(2.),
-          'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
-          'bias_constraint': tf.keras.constraints.unit_norm()})
+      (
+          'constraint',
+          {
+              'depthwise_constraint': keras.constraints.max_norm(2.0),
+              'pointwise_constraint': keras.constraints.min_max_norm(0.0, 2.0),
+              'bias_constraint': keras.constraints.unit_norm(),
+          },
+      ),
   )
   def testSeparableConvQuantize_(self, kwargs):
     kwargs['filters'] = 2
@@ -369,10 +384,12 @@ def testSeparableConvQuantize_(self, kwargs):
     num_row = 7
     num_col = 6
 
-    sepconv_model = tf.keras.Sequential([
-        tf.keras.Input(
-            shape=(num_row, num_col, stack_size), batch_size=num_samples),
-        tf.keras.layers.SeparableConv2D(**kwargs)])
+    sepconv_model = keras.Sequential([
+        keras.Input(
+            shape=(num_row, num_col, stack_size), batch_size=num_samples
+        ),
+        keras.layers.SeparableConv2D(**kwargs),
+    ])
 
     transformed_model, updated_metadata = ModelTransformer(
         sepconv_model,
@@ -440,13 +457,13 @@ def testAddReLUQuantize(self, activation_type, transform_type):
   def testLayerReLUQuantize(self, activation_type, transform_type):
     # TODO(b/185727342): Add tests for DepthConv and Dense
     input_shape = (3, 3, 3)
-    conv_layer = tf.keras.layers.Conv2D(5, 2, input_shape=input_shape)
+    conv_layer = keras.layers.Conv2D(5, 2, input_shape=input_shape)
     if activation_type == 'relu':
       act_layer = keras.layers.ReLU(6.0)
     elif activation_type == 'act_relu':
       act_layer = keras.layers.Activation('relu')
 
-    model = tf.keras.Sequential([conv_layer, act_layer])
+    model = keras.Sequential([conv_layer, act_layer])
 
     transformed_model, updated_metadata = ModelTransformer(
         model,
@@ -580,6 +597,6 @@ def testConcatMultipleLevels(self):
 
 
 if __name__ == '__main__':
-  if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
-    tf.keras.__internal__.enable_unsafe_deserialization()
+  if hasattr(keras.__internal__, 'enable_unsafe_deserialization'):
+    keras.__internal__.enable_unsafe_deserialization()
   tf.test.main()
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
index 8a83996de..c77365dbd 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
@@ -51,6 +51,7 @@ py_strict_library(
     deps = [
         ":transforms",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -68,6 +69,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py
index 4a732a792..596ff35df 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # pylint: disable=g-explicit-length-test
-"""Apply graph transformations to a tf.keras model."""
+"""Apply graph transformations to a keras model."""
 
 import collections
 import copy
@@ -21,16 +21,17 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms as transforms_mod
 
+
 LayerNode = transforms_mod.LayerNode
 
-keras = tf.keras
-K = tf.keras.backend
+K = keras.backend
 
 
 class ModelTransformer(object):
-  """Matches patterns to apply transforms in a tf.keras model graph."""
+  """Matches patterns to apply transforms in a keras model graph."""
 
   def __init__(self,
                model,
@@ -50,7 +51,8 @@ def __init__(self,
     """
     if not self._is_sequential_or_functional_model(model):
       raise ValueError(
-          'Only tf.keras sequential or functional models can be transformed.')
+          'Only keras sequential or functional models can be transformed.'
+      )
 
     if layer_metadata is None:
       layer_metadata = {}
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
index 014424769..37973aae4 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
@@ -19,21 +19,20 @@
 from __future__ import print_function
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
 
+
 ModelTransformer = model_transformer.ModelTransformer
 Transform = transforms.Transform
 LayerPattern = transforms.LayerPattern
 LayerNode = transforms.LayerNode
 
-keras = tf.keras
-
 
 class ModelTransformerTest(tf.test.TestCase, parameterized.TestCase):
 
@@ -558,10 +557,10 @@ def replacement(self, match_layer):
         match_layer.metadata['key'] = 'value'
         return match_layer
 
-    model = tf.keras.Sequential([
-        tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1)),
-        tf.keras.layers.BatchNormalization(),
-        tf.keras.layers.ReLU(),
+    model = keras.Sequential([
+        keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1)),
+        keras.layers.BatchNormalization(),
+        keras.layers.ReLU(),
     ])
     model_layer_names = [layer.name for layer in model.layers]
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD
index 1d01aae58..efdfb3174 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD
@@ -21,6 +21,7 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -30,5 +31,6 @@ py_strict_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm_test_utils.py b/tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm_test_utils.py
index 476610e35..a2f8998ce 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm_test_utils.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm_test_utils.py
@@ -20,7 +20,7 @@
 
 import tensorflow as tf
 
-keras = tf.keras
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 def _get_conv2d_params():
@@ -74,8 +74,9 @@ def get_nonfolded_batchnorm_model(cls,
       normalization = keras.layers.experimental.SyncBatchNormalization
 
     if squeeze_type == 'sepconv1d_squeeze':
-      squeeze_layer = tf.keras.layers.Lambda(
-          lambda x: tf.squeeze(x, [1]), name='sepconv1d_squeeze_1')
+      squeeze_layer = keras.layers.Lambda(
+          lambda x: tf.squeeze(x, [1]), name='sepconv1d_squeeze_1'
+      )
     else:
       squeeze_layer = None
 
@@ -91,7 +92,7 @@ def get_nonfolded_batchnorm_model(cls,
       layers.append(normalization(axis=-1))
       if post_bn_activation is not None:
         layers += post_bn_activation
-      return tf.keras.Sequential(layers)
+      return keras.Sequential(layers)
     else:
       inp = keras.layers.Input(cls.params['input_shape'],
                                cls.params['batch_size'])
@@ -106,7 +107,7 @@ def get_nonfolded_batchnorm_model(cls,
       out = normalization(axis=-1)(x)
       if post_bn_activation is not None:
         out = post_bn_activation(out)
-      return tf.keras.Model(inp, out)
+      return keras.Model(inp, out)
 
 
 class DepthwiseConv2DModel(Conv2DModel):
@@ -135,8 +136,9 @@ def get_nonfolded_batchnorm_model(cls,
       normalization = keras.layers.experimental.SyncBatchNormalization
 
     if squeeze_type == 'sepconv1d_squeeze':
-      squeeze_layer = tf.keras.layers.Lambda(
-          lambda x: tf.squeeze(x, [1]), name='sepconv1d_squeeze_1')
+      squeeze_layer = keras.layers.Lambda(
+          lambda x: tf.squeeze(x, [1]), name='sepconv1d_squeeze_1'
+      )
     else:
       squeeze_layer = None
 
@@ -152,7 +154,7 @@ def get_nonfolded_batchnorm_model(cls,
       layers.append(normalization(axis=-1))
       if post_bn_activation is not None:
         layers += post_bn_activation
-      return tf.keras.Sequential(layers)
+      return keras.Sequential(layers)
     else:
       inp = keras.layers.Input(cls.params['input_shape'],
                                cls.params['batch_size'])
@@ -166,4 +168,4 @@ def get_nonfolded_batchnorm_model(cls,
       out = normalization(axis=-1)(x)
       if post_bn_activation is not None:
         out = post_bn_activation(out)
-      return tf.keras.Model(inp, out)
+      return keras.Model(inp, out)
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/layers/dense_batchnorm_test_utils.py b/tensorflow_model_optimization/python/core/quantization/keras/layers/dense_batchnorm_test_utils.py
index 8cca0320c..dbb1c4a4d 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/layers/dense_batchnorm_test_utils.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/layers/dense_batchnorm_test_utils.py
@@ -20,7 +20,7 @@
 
 import tensorflow as tf
 
-keras = tf.keras
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 class DenseModel(object):
@@ -55,4 +55,4 @@ def get_nonfolded_batchnorm_model(cls,
     out = normalization(axis=-1)(x)
     if post_bn_activation is not None:
       out = post_bn_activation(out)
-    return tf.keras.Model(inp, out)
+    return keras.Model(inp, out)
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py
index 1d0c4154a..b5e4337e9 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Quantization API functions for tf.keras models."""
+"""Quantization API functions for keras models."""
 import warnings
 
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import metrics
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quantize_annotate_mod
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
@@ -29,32 +30,30 @@
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_scheme
 from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
 
-keras = tf.keras
-
 
 def quantize_scope(*args):
   """Scope which can be used to deserialize quantized Keras models and layers.
 
-  Under `quantize_scope`, Keras methods such as `tf.keras.load_model` and
-  `tf.keras.models.model_from_config` will be able to deserialize Keras models
+  Under `quantize_scope`, Keras methods such as `keras.models.load_model` and
+  `keras.models.model_from_config` will be able to deserialize Keras models
   and layers which contain quantization classes such as `QuantizeConfig`
   and `Quantizer`.
 
   Example:
 
   ```python
-  tf.keras.models.save_model(quantized_model, keras_file)
+  keras.models.save_model(quantized_model, keras_file)
 
   with quantize_scope():
-    loaded_model = tf.keras.models.load_model(keras_file)
+    loaded_model = keras.models.load_model(keras_file)
 
   # If your quantized model uses custom objects such as a specific `Quantizer`,
   # you can pass them to quantize_scope to deserialize your model.
   with quantize_scope({'FixedRangeQuantizer': FixedRangeQuantizer}):
-    loaded_model = tf.keras.models.load_model(keras_file)
+    loaded_model = keras.models.load_model(keras_file)
   ```
 
-  For further understanding, see `tf.keras.utils.custom_object_scope`.
+  For further understanding, see `keras.utils.custom_object_scope`.
 
   Args:
     *args: Variable length list of dictionaries of `{name, class}` pairs to add
@@ -78,11 +77,11 @@ def quantize_scope(*args):
   quantization_objects.update(default_n_bit_quantize_registry._types_dict())  # pylint: disable=protected-access
   quantization_objects.update(quantizers._types_dict())  # pylint: disable=protected-access
 
-  return tf.keras.utils.custom_object_scope(*(args + (quantization_objects,)))
+  return keras.utils.custom_object_scope(*(args + (quantization_objects,)))
 
 
 def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
-  """Quantize a `tf.keras` model with the default quantization implementation.
+  """Quantize a `keras` model with the default quantization implementation.
 
   Quantization constructs a model which emulates quantization during training.
   This allows the model to learn parameters robust to quantization loss, and
@@ -102,9 +101,9 @@ def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
       ]))
 
   # Quantize functional model
-  in = tf.keras.Input((3,))
-  out = tf.keras.Dense(2)(in)
-  model = tf.keras.Model(in, out)
+  inputs = keras.Input((3,))
+  outputs = keras.layers.Dense(2)(inputs)
+  model = keras.Model(inputs, outputs)
 
   quantized_model = quantize_model(model)
   ```
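
A runnable counterpart of the docstring example above, assuming the public `tfmot.quantization.keras.quantize_model` alias; the compile settings are illustrative.

```python
import tensorflow_model_optimization as tfmot

from tensorflow_model_optimization.python.core.keras.compat import keras

# Functional model, mirroring the docstring example.
inputs = keras.Input((3,))
outputs = keras.layers.Dense(2)(inputs)
model = keras.Model(inputs, outputs)

quantized_model = tfmot.quantization.keras.quantize_model(model)
quantized_model.compile(optimizer='adam', loss='mse')  # illustrative settings
```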
@@ -116,13 +115,12 @@ def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
   of the original model.
 
   Args:
-    to_quantize: tf.keras model to be quantized. It can have pre-trained
-      weights.
+    to_quantize: keras model to be quantized. It can have pre-trained weights.
     quantized_layer_name_prefix: Name prefix for the quantized layers. The
       default is `quant_`.
 
   Returns:
-    Returns a new `tf.keras` model prepared for quantization.
+    Returns a new `keras` model prepared for quantization.
   """
   if to_quantize is None:
     raise ValueError('`to_quantize` cannot be None')
@@ -130,18 +128,14 @@ def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
   if quantized_layer_name_prefix is None:
     quantized_layer_name_prefix = ''
 
-  if not isinstance(to_quantize, keras.Model):
-    raise ValueError(
-        '`to_quantize` can only be a `tf.keras.Model` instance. Use '
-        'the `quantize_annotate_layer` API to handle individual layers.'
-        'You passed an instance of type: {input}.'.format(
-            input=to_quantize.__class__.__name__))
-
-  if not isinstance(
-      to_quantize, keras.Sequential) and not to_quantize._is_graph_network:  # pylint: disable=protected-access
+  if not isinstance(to_quantize, keras.Sequential) and not (
+      hasattr(to_quantize, '_is_graph_network')
+      and to_quantize._is_graph_network
+  ):  # pylint: disable=protected-access
     raise ValueError(
-        '`to_quantize` can only either be a tf.keras Sequential or '
-        'Functional model.')
+        '`to_quantize` can only be a keras Sequential or '
+        'Functional model.'
+    )
 
   annotated_model = quantize_annotate_model(to_quantize)
   return quantize_apply(
@@ -149,7 +143,7 @@ def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
 
 
 def quantize_annotate_model(to_annotate):
-  """Annotate a `tf.keras` model to be quantized.
+  """Annotate a `keras` model to be quantized.
 
   This function does not actually quantize the model. It merely specifies
   that the model needs to be quantized. `quantize_apply` can then be used
@@ -180,10 +174,10 @@ def quantize_annotate_model(to_annotate):
   Note that this function removes the optimizer from the original model.
 
   Args:
-    to_annotate: `tf.keras` model which needs to be quantized.
+    to_annotate: `keras` model which needs to be quantized.
 
   Returns:
-    New tf.keras model with each layer in the model wrapped with
+    New keras model with each layer in the model wrapped with
     `QuantizeAnnotate`. The new model preserves weights from the original
     model.
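
A short sketch of the annotate-then-apply flow this docstring describes, assuming the public `tfmot.quantization.keras` aliases; layer sizes are illustrative.

```python
import tensorflow_model_optimization as tfmot

from tensorflow_model_optimization.python.core.keras.compat import keras

base_model = keras.Sequential([
    keras.layers.Dense(8, input_shape=(10,)),
    keras.layers.Dense(5),
])

# Annotation alone does not quantize; it marks layers for quantize_apply.
# Weights carry over, but the optimizer is dropped, per the note above.
annotated = tfmot.quantization.keras.quantize_annotate_model(base_model)
quantized = tfmot.quantization.keras.quantize_apply(annotated)
```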
 
@@ -195,16 +189,19 @@ def quantize_annotate_model(to_annotate):
 
   if not isinstance(to_annotate, keras.Model):
     raise ValueError(
-        '`to_annotate` can only be a `tf.keras.Model` instance. Use '
+        '`to_annotate` can only be a `keras.Model` instance. Use '
         'the `quantize_annotate_layer` API to handle individual layers. '
         'You passed an instance of type: {input}.'.format(
-            input=to_annotate.__class__.__name__))
+            input=to_annotate.__class__.__name__
+        )
+    )
 
   if not isinstance(
       to_annotate, keras.Sequential) and not to_annotate._is_graph_network:  # pylint: disable=protected-access
     raise ValueError(
-        '`to_annotate` can only either be a tf.keras Sequential or '
-        'Functional model.')
+        '`to_annotate` can only be a keras Sequential or '
+        'Functional model.'
+    )
 
   def _add_quant_wrapper(layer):
     """Add annotation wrapper."""
@@ -212,7 +209,7 @@ def _add_quant_wrapper(layer):
     if isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
       return layer
 
-    if isinstance(layer, tf.keras.layers.Lambda):
+    if isinstance(layer, keras.layers.Lambda):
       warnings.warn(
           'Lambda layers are not supported by automatic model annotation '
           'because the internal functionality cannot always be determined by '
@@ -221,9 +218,10 @@ def _add_quant_wrapper(layer):
           'be quantized which may lead to unexpected results.')
       return layer
 
-    if isinstance(layer, tf.keras.Model):
+    if isinstance(layer, keras.Model):
       raise ValueError(
-          'Quantizing a tf.keras Model inside another tf.keras Model is not supported.'
+          'Quantizing a keras Model inside another keras Model is not'
+          ' supported.'
       )
 
     return quantize_annotate_mod.QuantizeAnnotate(layer)
@@ -233,7 +231,7 @@ def _add_quant_wrapper(layer):
 
 
 def quantize_annotate_layer(to_annotate, quantize_config=None):
-  """Annotate a `tf.keras` layer to be quantized.
+  """Annotate a `keras` layer to be quantized.
 
   This function does not actually quantize the layer. It is merely used to
   specify that the layer should be quantized. The layer then gets quantized
@@ -256,12 +254,12 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
   ```
 
   Args:
-    to_annotate: `tf.keras` layer which needs to be quantized.
+    to_annotate: `keras` layer which needs to be quantized.
     quantize_config: optional `QuantizeConfig` which controls how the layer is
       quantized. In its absence, the default behavior for the layer is used.
 
   Returns:
-    `tf.keras` layer wrapped with `QuantizeAnnotate`.
+    `keras` layer wrapped with `QuantizeAnnotate`.
   """
   if to_annotate is None:
     raise ValueError('`to_annotate` cannot be None')
@@ -270,9 +268,11 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
   if not isinstance(to_annotate, keras.layers.Layer) or isinstance(
       to_annotate, keras.Model):
     raise ValueError(
-        '`to_annotate` can only be a `tf.keras.layers.Layer` instance. '
+        '`to_annotate` can only be a `keras.layers.Layer` instance. '
         'You passed an instance of type: {input}.'.format(
-            input=to_annotate.__class__.__name__))
+            input=to_annotate.__class__.__name__
+        )
+    )
 
   if quantize_config is not None and not isinstance(
       quantize_config, quantize_config_mod.QuantizeConfig):
@@ -290,7 +290,7 @@ def quantize_apply(
     model,
     scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme(),
     quantized_layer_name_prefix='quant_'):
-  """Quantize a `tf.keras` model that has been annotated for quantization.
+  """Quantize a `keras` model that has been annotated for quantization.
 
   Quantization constructs a model which emulates quantization during training.
   This allows the model to learn parameters robust to quantization loss, and
@@ -300,7 +300,7 @@ def quantize_apply(
   https://www.tensorflow.org/model_optimization/guide/quantization/training
   TODO(tfmot): Link blog once launched.
 
-  This function takes a `tf.keras` model in which the desired layers for
+  This function takes a `keras` model in which the desired layers for
   quantization have already been annotated. See `quantize_annotate_model`
   and `quantize_annotate_layer`.
 
@@ -323,7 +323,7 @@ def quantize_apply(
   of the original model.
 
   Args:
-    model: A `tf.keras` Sequential or Functional model which has been annotated
+    model: A `keras` Sequential or Functional model which has been annotated
       with `quantize_annotate`. It can have pre-trained weights.
     scheme: A `QuantizeScheme` which specifies transformer and quantization
       registry. The default is `Default8BitQuantizeScheme()`.
@@ -331,7 +331,7 @@ def quantize_apply(
       is `quant_`.
 
   Returns:
-    Returns a new `tf.keras` model in which the annotated layers have been
+    Returns a new `keras` model in which the annotated layers have been
     prepared for quantization.
   """
   if model is None:
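
For selective quantization, the clone-and-annotate pattern used later in `quantize_test.py` reads roughly as below; `annotate_dense` is an illustrative helper, not part of this change.

```python
import tensorflow_model_optimization as tfmot

from tensorflow_model_optimization.python.core.keras.compat import keras


def annotate_dense(layer):
  # Hypothetical helper: only Dense layers are marked for quantization.
  if isinstance(layer, keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer


model = keras.Sequential([
    keras.layers.Dense(8, input_shape=(10,)),
    keras.layers.ReLU(),
    keras.layers.Dense(5),
])
annotated = keras.models.clone_model(model, clone_function=annotate_dense)
quantized = tfmot.quantization.keras.quantize_apply(annotated)
```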
@@ -341,13 +341,17 @@ def quantize_apply(
     quantized_layer_name_prefix = ''
 
   if not isinstance(model, keras.Model):
-    raise ValueError('`model` can only be a `tf.keras.Model` instance.'
-                     'You passed an instance of type: {input}.'.format(
-                         input=model.__class__.__name__))
+    raise ValueError(
+        '`model` can only be a `keras.Model` instance. '
+        'You passed an instance of type: {input}.'.format(
+            input=model.__class__.__name__
+        )
+    )
 
   if not isinstance(model, keras.Sequential) and not model._is_graph_network:  # pylint: disable=protected-access
-    raise ValueError('`model` can only either be a tf.keras Sequential or '
-                     'Functional model.')
+    raise ValueError(
+        '`model` can only be a keras Sequential or Functional model.'
+    )
 
   # Have at least 1 layer annotated with QuantizeAnnotate
   if not any(isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
@@ -586,18 +590,18 @@ def fix_input_output_range(
   altered during training. To set these values, use the arguments as follows:
 
   Args:
-    model: A `tf.keras` Sequential or Functional model which has been quantized.
+    model: A `keras` Sequential or Functional model which has been quantized.
     num_bits: Number of bits for quantization
     input_min: The lower end of quantization interval for the input.
     input_max: The upper end of quantization interval for the input.
     output_min: The lower end of quantization interval for the output.
     output_max: The upper end of quantization interval for the output.
-    narrow_range: In case of 8 bits, narrow_range nudges the quantized range
-      to be [-127, 127] instead of [-128, 127]. This ensures symmetric
-      range has 0 as the centre.
+    narrow_range: In case of 8 bits, narrow_range nudges the quantized range to
+      be [-127, 127] instead of [-128, 127]. This ensures the symmetric range
+      is centred at 0.
 
   Returns:
-    Returns a new `tf.keras` model fixed input range set to (input_min,
+    Returns a new `keras` model fixed input range set to (input_min,
     input_max) and fixed output range set to (output_min, output_max).
   """
   config = model.get_config()
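
A hedged sketch of pinning I/O ranges after quantization, calling the module-level function directly; the keyword names are taken from the Args list above, and the ranges are illustrative.

```python
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize

model = keras.Sequential([keras.layers.Dense(2, input_shape=(3,))])
quantized_model = quantize.quantize_model(model)

# Pin input to [0, 1] and output to [-1, 1]; 8-bit, full (non-narrow) range.
fixed_model = quantize.fix_input_output_range(
    quantized_model,
    num_bits=8,
    input_min=0.0,
    input_max=1.0,
    output_min=-1.0,
    output_max=1.0,
    narrow_range=False,
)
```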
@@ -684,10 +688,10 @@ def remove_input_range(model):
   internally used.
 
   Args:
-    model: A `tf.keras` Sequential or Functional model which has been quantized.
+    model: A `keras` Sequential or Functional model which has been quantized.
 
   Returns:
-    Returns a new `tf.keras` model removed input range.
+    Returns a new `keras` model with the input range removed.
   """
   config = model.get_config()
   no_input_quantizer = quantizers.NoQuantizer()
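
Continuing the sketch above, removing the input range afterwards is a single call.

```python
from tensorflow_model_optimization.python.core.quantization.keras import quantize

# `fixed_model` from the previous sketch; this drops its input-range quantizer
# so the caller can handle input quantization itself.
model_without_input_range = quantize.remove_input_range(fixed_model)
```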
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py
index 8359aeebd..40e637b61 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py
@@ -23,13 +23,15 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 
+
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 serialize_keras_object = quantize_utils.serialize_keras_object
 
 
-class QuantizeAnnotate(tf.keras.layers.Wrapper):
+class QuantizeAnnotate(keras.layers.Wrapper):
   """Annotates layers which quantization should be applied to.
 
   QuantizeAnnotate does not actually apply quantization to the underlying
@@ -60,12 +62,15 @@ def __init__(self, layer, quantize_config=None, **kwargs):
       raise ValueError('`layer` cannot be None.')
 
     # Check against keras.Model since it is an instance of keras.layers.Layer.
-    if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
-        layer, tf.keras.Model):
+    if not isinstance(layer, keras.layers.Layer) or isinstance(
+        layer, keras.Model
+    ):
       raise ValueError(
-          '`layer` can only be a `tf.keras.layers.Layer` instance. '
+          '`layer` can only be a `keras.layers.Layer` instance. '
           'You passed an instance of type: {input}.'.format(
-              input=layer.__class__.__name__))
+              input=layer.__class__.__name__
+          )
+      )
 
     self.quantize_config = quantize_config
 
@@ -73,14 +78,14 @@ def __init__(self, layer, quantize_config=None, **kwargs):
     # Enables end-user to annotate the first layer in Sequential models, while
     # passing the input shape to the original layer.
     #
-    # tf.keras.Sequential(
-    #   quantize_annotate_layer(tf.keras.layers.Dense(2, input_shape=(3,)))
+    # keras.Sequential(
+    #   quantize_annotate_layer(keras.layers.Dense(2, input_shape=(3,)))
     # )
     #
     # as opposed to
     #
-    # tf.keras.Sequential(
-    #   quantize_annotate_layer(tf.keras.layers.Dense(2), input_shape=(3,))
+    # keras.Sequential(
+    #   quantize_annotate_layer(keras.layers.Dense(2), input_shape=(3,))
     # )
     #
     # Without this code, the QuantizeAnnotate wrapper doesn't have an input
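
The supported spelling this comment describes, in runnable form, assuming the public `quantize_annotate_layer` alias.

```python
import tensorflow_model_optimization as tfmot

from tensorflow_model_optimization.python.core.keras.compat import keras

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer

# input_shape on the inner layer is picked up by the wrapper, so annotating
# the first layer of a Sequential model still builds correctly.
model = keras.Sequential([
    quantize_annotate_layer(keras.layers.Dense(2, input_shape=(3,)))
])
```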
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py
index 5b3b93bd3..f7e8c080f 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py
@@ -21,12 +21,13 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
 
-keras = tf.keras
-deserialize_layer = tf.keras.layers.deserialize
-serialize_layer = tf.keras.layers.serialize
+
+deserialize_layer = keras.layers.deserialize
+serialize_layer = keras.layers.serialize
 
 
 class QuantizeAnnotateTest(tf.test.TestCase):
@@ -53,7 +54,7 @@ def get_config(self):
 
   def testAnnotateLayerCallPassesTrainingBoolean(self):
 
-    class MockLayer(tf.keras.layers.Layer):
+    class MockLayer(keras.layers.Layer):
       self.training = None
 
       def call(self, training=None):
@@ -98,7 +99,7 @@ def testSerializationQuantizeAnnotate(self):
     }
 
     serialized_wrapper = serialize_layer(wrapper)
-    with tf.keras.utils.custom_object_scope(custom_objects):
+    with keras.utils.custom_object_scope(custom_objects):
       wrapper_from_config = deserialize_layer(serialized_wrapper)
 
     self.assertEqual(wrapper_from_config.get_config(), wrapper.get_config())
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py
index 8a4d58914..a46da3fcc 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py
@@ -21,9 +21,11 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 
-activations = tf.keras.activations
+
+activations = keras.activations
 
 
 class NoOpActivation(object):
@@ -92,9 +94,10 @@ class QuantizeAwareActivation(object):
   _NO_QUANTIZE_ACTIVATIONS = frozenset({'NoOpActivation'})
 
   _CUSTOM_ACTIVATION_ERR_MSG = (
-      'Only some Keras activations under `tf.keras.activations` are supported. '
+      'Only some Keras activations under `keras.activations` are supported. '
       'For other activations, use `Quantizer` directly, and update layer '
-      'config using `QuantizeConfig`.')
+      'config using `QuantizeConfig`.'
+  )
 
   def __init__(self, activation, quantizer, step, quantize_wrapper):
     """Constructs object, and initializes weights for quantization.
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py
index 6c547ec67..7968e2202 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py
@@ -19,17 +19,17 @@
 from __future__ import print_function
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 
-keras = tf.keras
-activations = tf.keras.activations
-K = tf.keras.backend
+
+activations = keras.activations
+K = keras.backend
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 serialize_keras_object = quantize_utils.serialize_keras_object
 
@@ -155,7 +155,7 @@ def testSerializationReturnsWrappedActivation(
         'config': activation_config
     }
     self.assertEqual(expected_config, serialized_quantize_activation)
-    with tf.keras.utils.custom_object_scope({
+    with keras.utils.custom_object_scope({
         'QuantizeAwareActivation': QuantizeAwareActivation,
         'NoOpActivation': quantize_aware_activation.NoOpActivation,
     }):
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py
index 6e31353cf..b14fa616d 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py
@@ -26,11 +26,13 @@
 
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import utils as test_utils
 
-layers = tf.keras.layers
+
+layers = keras.layers
 
 
 @tf.__internal__.distribute.combinations.generate(
@@ -310,7 +312,7 @@ def testQuantizeSingleLayer_ProducesFullIntegerModel_TF2(
       kwargs['input_shape'] = (5,)
 
     layer = layer_type(**kwargs)
-    model = tf.keras.Sequential([layer])
+    model = keras.Sequential([layer])
     quantized_model = quantize.quantize_model(model)
 
     _, quantized_tflite_file = tempfile.mkstemp('.tflite')
@@ -399,7 +401,7 @@ def testQuantizeSingleLayer_ProducesFullIntegerModel_TF1(
       kwargs['input_shape'] = (5,)
 
     layer = layer_type(**kwargs)
-    model = tf.keras.Sequential([layer])
+    model = keras.Sequential([layer])
     quantized_model = quantize.quantize_model(model)
 
     with quantize.quantize_scope():
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py
index 923ff8433..9042dd023 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py
@@ -21,7 +21,6 @@
 import tempfile
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
 
@@ -29,15 +28,17 @@
 
 from tensorflow_model_optimization.python.core.keras import compat
 from tensorflow_model_optimization.python.core.keras import test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 
+
 QuantizeConfig = quantize_config.QuantizeConfig
 Quantizer = quantizers.Quantizer
 MovingAverageQuantizer = quantizers.MovingAverageQuantizer
 
-l = tf.keras.layers
+l = keras.layers
 
 
 # TODO(tfmot): enable for v1. Currently fails because the decorator
@@ -106,8 +107,9 @@ def _train_model(self, model):
         loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
     model.fit(
         np.random.rand(20, 10),
-        tf.keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
-        batch_size=20)
+        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
+        batch_size=20,
+    )
 
   ####################################################################
   # Tests for research with quantization.
@@ -121,7 +123,7 @@ def build(self, tensor_shape, name, layer):
       return {}
 
     def __call__(self, inputs, training, weights, **kwargs):
-      return tf.keras.backend.clip(inputs, -1.0, 1.0)
+      return keras.backend.clip(inputs, -1.0, 1.0)
 
     def get_config(self):
       return {}
@@ -173,11 +175,11 @@ def get_output_quantizers(self, layer):
       def get_config(self):
         return {}
 
-    annotated_model = tf.keras.Sequential([
-        quantize.quantize_annotate_layer(
-            l.Dense(8, input_shape=(10,)), DenseQuantizeConfig()),
+    annotated_model = keras.Sequential([
         quantize.quantize_annotate_layer(
-            l.Dense(5), DenseQuantizeConfig())
+            l.Dense(8, input_shape=(10,)), DenseQuantizeConfig()
+        ),
+        quantize.quantize_annotate_layer(l.Dense(5), DenseQuantizeConfig()),
     ])
 
     with quantize.quantize_scope(
@@ -197,9 +199,9 @@ def testSerialization_KerasModel(self):
     self._train_model(quantized_model)
 
     _, model_file = tempfile.mkstemp('.h5')
-    tf.keras.models.save_model(quantized_model, model_file)
+    keras.models.save_model(quantized_model, model_file)
     with quantize.quantize_scope():
-      loaded_model = tf.keras.models.load_model(model_file)
+      loaded_model = keras.models.load_model(model_file)
 
     self._assert_models_equal(quantized_model, loaded_model)
 
@@ -226,8 +228,8 @@ def testSerialization_SavedModel(self):
     self._train_model(quantized_model)
 
     model_dir = tempfile.mkdtemp()
-    tf.keras.models.save_model(quantized_model, model_dir)
-    loaded_model = tf.keras.models.load_model(model_dir)
+    keras.models.save_model(quantized_model, model_dir)
+    loaded_model = keras.models.load_model(model_dir)
 
     self._assert_outputs_equal(quantized_model, loaded_model)
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py
index 1393388e6..59df68c1c 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py
@@ -24,15 +24,16 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import utils
-
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 
+
 serialize_keras_object = quantize_utils.serialize_keras_object
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 
 
-class QuantizeLayer(tf.keras.layers.Layer):
+class QuantizeLayer(keras.layers.Layer):
   """Emulate quantization of tensors passed through the layer."""
 
   def __init__(self, quantizer, **kwargs):
@@ -59,16 +60,17 @@ def build(self, input_shape):
 
     self.optimizer_step = self.add_weight(
         'optimizer_step',
-        initializer=tf.keras.initializers.Constant(-1),
+        initializer=keras.initializers.Constant(-1),
         dtype=tf.dtypes.int32,
-        trainable=False)
+        trainable=False,
+    )
 
   def call(self, inputs, training=None):
     if not self.quantizer:
       return inputs
 
     if training is None:
-      training = tf.keras.backend.learning_phase()
+      training = keras.backend.learning_phase()
 
     def _make_quantizer_fn(train_var):
       def quantizer_fn():
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer_test.py
index e7ad5cffd..774294ca3 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer_test.py
@@ -21,12 +21,14 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 
+
 QuantizeLayer = quantize_layer.QuantizeLayer
-deserialize_layer = tf.keras.layers.deserialize
-serialize_layer = tf.keras.layers.serialize
+deserialize_layer = keras.layers.deserialize
+serialize_layer = keras.layers.serialize
 
 
 class QuantizeLayerTest(tf.test.TestCase):
@@ -41,11 +43,9 @@ def setUp(self):
         per_axis=False, symmetric=True, **self.quant_params)
 
   def testQuantizesTensors(self):
-    model = tf.keras.Sequential([
-        QuantizeLayer(
-            quantizer=self.quantizer,
-            input_shape=(4,)
-        )])
+    model = keras.Sequential(
+        [QuantizeLayer(quantizer=self.quantizer, input_shape=(4,))]
+    )
 
     x = np.random.rand(1, 4)
     quant_x = tf.quantization.fake_quant_with_min_max_vars(
@@ -64,14 +64,14 @@ def testSerializationQuantizeLayer(self):
     }
 
     serialized_layer = serialize_layer(layer)
-    with tf.keras.utils.custom_object_scope(custom_objects):
+    with keras.utils.custom_object_scope(custom_objects):
       layer_from_config = deserialize_layer(serialized_layer)
 
     self.assertEqual(layer_from_config.get_config(), layer.get_config())
 
   def testNoQuantizeLayer(self):
     layer = QuantizeLayer(quantizer=None, input_shape=(4,))
-    model = tf.keras.Sequential([layer])
+    model = keras.Sequential([layer])
     x = np.random.rand(1, 4)
     self.assertAllClose(x, model.predict(x))
 
@@ -80,7 +80,7 @@ def testNoQuantizeLayer(self):
     }
 
     serialized_layer = serialize_layer(layer)
-    with tf.keras.utils.custom_object_scope(custom_objects):
+    with keras.utils.custom_object_scope(custom_objects):
       layer_from_config = deserialize_layer(serialized_layer)
 
     self.assertEqual(layer_from_config.get_config(), layer.get_config())
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py
index 2f5754c8d..55d0c1f2c 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py
@@ -25,6 +25,7 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import utils
 
@@ -32,7 +33,7 @@
 class QuantizeModelsTest(tf.test.TestCase, parameterized.TestCase):
 
   # Derived using
-  # `inspect.getmembers(tf.keras.applications, inspect.isfunction)`
+  # `inspect.getmembers(keras.applications, inspect.isfunction)`
   _KERAS_APPLICATION_MODELS = [
       # 'DenseNet121',
       # 'DenseNet169',
@@ -65,8 +66,7 @@ def _batch(self, dims, batch_size):
 
   def _get_model(self, model_type):
     model_fn = [
-        y for x, y in inspect.getmembers(tf.keras.applications)
-        if x == model_type
+        y for x, y in inspect.getmembers(keras.applications) if x == model_type
     ][0]
 
     input_shape = QuantizeModelsTest._MODEL_INPUT_SHAPES.get(
@@ -77,8 +77,9 @@ def _get_model(self, model_type):
   def _create_test_data(self, model):
     x_train = np.random.randn(
         *self._batch(model.input.get_shape().as_list(), 2)).astype('float32')
-    y_train = tf.keras.utils.to_categorical(
-        np.random.randint(1000, size=(2, 1)), 1000)
+    y_train = keras.utils.to_categorical(
+        np.random.randint(1000, size=(2, 1)), 1000
+    )
 
     return x_train, y_train
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py
index 1074a68a4..2a1bde102 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py
@@ -23,6 +23,7 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quantize_annotate_mod
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
@@ -31,6 +32,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
+
 quantize_annotate_layer = quantize.quantize_annotate_layer
 quantize_annotate_model = quantize.quantize_annotate_model
 quantize_apply = quantize.quantize_apply
@@ -38,9 +40,8 @@
 QuantizeAnnotate = quantize_annotate_mod.QuantizeAnnotate
 QuantizeWrapper = quantize_wrapper_mod.QuantizeWrapper
 
-keras = tf.keras
-K = tf.keras.backend
-custom_object_scope = tf.keras.utils.custom_object_scope
+K = keras.backend
+custom_object_scope = keras.utils.custom_object_scope
 
 
 class _TestQuantizeConfig(quantize_config_mod.QuantizeConfig):
@@ -531,11 +532,11 @@ def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):
 
     quantize_apply(annotated_model)
 
-  class CustomConvLayer(tf.keras.layers.Layer):
+  class CustomConvLayer(keras.layers.Layer):
 
     def __init__(self, name=None, **kwargs):
       super().__init__(name=name, **kwargs)
-      self.conv1 = tf.keras.layers.Conv2D(2, 2)
+      self.conv1 = keras.layers.Conv2D(2, 2)
 
     def build(self, input_shape):
       self.conv1.build(input_shape)
@@ -578,7 +579,7 @@ def apply_quantization_to_dense(layer):
             layer, quantize_config=self.CustomConvQuantizeConfig())
       return layer
 
-    annotated_model = tf.keras.models.clone_model(
+    annotated_model = keras.models.clone_model(
         model,
         clone_function=apply_quantization_to_dense,
     )
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py
index db39a5317..32a6e2dec 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py
@@ -29,17 +29,18 @@
 import tensorflow as tf
 
 from tensorflow.python.util import tf_inspect
-
 from tensorflow_model_optimization.python.core.keras import metrics
 from tensorflow_model_optimization.python.core.keras import utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 
+
 deserialize_keras_object = quantize_utils.deserialize_keras_object
 serialize_keras_object = quantize_utils.serialize_keras_object
 
 
-class QuantizeWrapper(tf.keras.layers.Wrapper):
+class QuantizeWrapper(keras.layers.Wrapper):
   """Quantizes the weights and activations of the keras layer it wraps."""
 
   def __init__(self, layer, quantize_config, name_prefix='quant_', **kwargs):
@@ -59,12 +60,15 @@ def __init__(self, layer, quantize_config, name_prefix='quant_', **kwargs):
       name_prefix = ''
 
     # Check against keras.Model since it is an instance of keras.layers.Layer.
-    if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
-        layer, tf.keras.Model):
+    if not isinstance(layer, keras.layers.Layer) or isinstance(
+        layer, keras.Model
+    ):
       raise ValueError(
-          '`layer` can only be a `tf.keras.layers.Layer` instance. '
+          '`layer` can only be a `keras.layers.Layer` instance. '
           'You passed an instance of type: {input}.'.format(
-              input=layer.__class__.__name__))
+              input=layer.__class__.__name__
+          )
+      )
 
     if quantize_config is None:
       raise ValueError('quantize_config cannot be None. It is needed to '
@@ -101,9 +105,10 @@ def build(self, input_shape):
 
     self.optimizer_step = self.add_weight(
         'optimizer_step',
-        initializer=tf.keras.initializers.Constant(-1),
+        initializer=keras.initializers.Constant(-1),
         dtype=tf.dtypes.int32,
-        trainable=False)
+        trainable=False,
+    )
 
     self._weight_vars = []
     for weight, quantizer in (
@@ -142,7 +147,7 @@ def quantizer_fn():
 
   def call(self, inputs, training=None, **kwargs):
     if training is None:
-      training = tf.keras.backend.learning_phase()
+      training = keras.backend.learning_phase()
 
     # Quantize all weights, and replace them in the underlying layer.
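
A minimal sketch of wrapping a single layer by hand, assuming the registry's `get_quantize_config` accessor used by the 8-bit scheme; the layer and input shape are illustrative.

```python
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

layer = keras.layers.Dense(4)
registry = default_8bit_quantize_registry.Default8BitQuantizeRegistry()

# Wrap one layer with the registry-provided QuantizeConfig.
wrapped = quantize_wrapper.QuantizeWrapper(
    layer, quantize_config=registry.get_quantize_config(layer)
)
model = keras.Sequential([keras.layers.InputLayer(input_shape=(3,)), wrapped])
```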
 
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py
index aa257dedb..e760e9a37 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py
@@ -19,24 +19,24 @@
 from __future__ import print_function
 
 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
+
 QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
 QuantizeWrapper = quantize_wrapper.QuantizeWrapper
 QuantizeRegistry = default_8bit_quantize_registry.Default8BitQuantizeRegistry
 
-keras = tf.keras
-layers = tf.keras.layers
+layers = keras.layers
 
-custom_object_scope = tf.keras.utils.custom_object_scope
-deserialize_layer = tf.keras.layers.deserialize
-serialize_layer = tf.keras.layers.serialize
+custom_object_scope = keras.utils.custom_object_scope
+deserialize_layer = keras.layers.deserialize
+serialize_layer = keras.layers.serialize
 
 
 class QuantizeWrapperTest(tf.test.TestCase, parameterized.TestCase):
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantizers.py b/tensorflow_model_optimization/python/core/quantization/keras/quantizers.py
index 98b33f7cf..a5a1f73f5 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/quantizers.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/quantizers.py
@@ -22,14 +22,13 @@
 from __future__ import print_function
 
 import abc
-import six
 
+import six
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
 
-keras = tf.keras
-
 
 @six.add_metaclass(abc.ABCMeta)
 class Quantizer(object):
@@ -59,7 +58,7 @@ def build(self, tensor_shape, name, layer):
       }
 
     def __call__(self, inputs, training, weights, **kwargs):
-      return tf.keras.backend.clip(
+      return keras.backend.clip(
           inputs, 0.0, weights['range_var'])
 
     def get_config(self):
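
For reference, a filled-out version of the docstring's fixed-range quantizer under the new import; the constant 6.0 range initializer is an illustrative choice.

```python
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantizers


class FixedRangeQuantizer(quantizers.Quantizer):
  """Clips tensors to a fixed, non-trainable range."""

  def build(self, tensor_shape, name, layer):
    range_var = layer.add_weight(
        name + '_range',
        initializer=keras.initializers.Constant(6.0),  # illustrative range
        trainable=False,
    )
    return {'range_var': range_var}

  def __call__(self, inputs, training, weights, **kwargs):
    return keras.backend.clip(inputs, 0.0, weights['range_var'])

  def get_config(self):
    return {}
```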
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/utils.py b/tensorflow_model_optimization/python/core/quantization/keras/utils.py
index ed309c4e1..44c09bdc0 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/utils.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/utils.py
@@ -21,72 +21,68 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
 def serialize_keras_object(obj):
-  if hasattr(tf.keras.utils, "legacy"):
-    return tf.keras.utils.legacy.serialize_keras_object(obj)
+  if hasattr(keras.utils, "legacy"):
+    return keras.utils.legacy.serialize_keras_object(obj)
   else:
-    return tf.keras.utils.serialize_keras_object(obj)
+    return keras.utils.serialize_keras_object(obj)
 
 
 def deserialize_keras_object(
     config, module_objects=None, custom_objects=None, printable_module_name=None
 ):
-  if hasattr(tf.keras.utils, "legacy"):
-    return tf.keras.utils.legacy.deserialize_keras_object(
+  if hasattr(keras.utils, "legacy"):
+    return keras.utils.legacy.deserialize_keras_object(
         config, custom_objects, module_objects, printable_module_name
     )
   else:
-    return tf.keras.utils.deserialize_keras_object(
+    return keras.utils.deserialize_keras_object(
         config, custom_objects, module_objects, printable_module_name
     )
 
 
 def serialize_layer(layer, use_legacy_format=False):
-  if (
-      "use_legacy_format"
-      in inspect.getfullargspec(tf.keras.layers.serialize).args
-  ):
-    return tf.keras.layers.serialize(layer, use_legacy_format=use_legacy_format)
+  if "use_legacy_format" in inspect.getfullargspec(keras.layers.serialize).args:
+    return keras.layers.serialize(layer, use_legacy_format=use_legacy_format)
   else:
-    return tf.keras.layers.serialize(layer)
+    return keras.layers.serialize(layer)
 
 
 def deserialize_layer(config, use_legacy_format=False):
   if (
       "use_legacy_format"
-      in inspect.getfullargspec(tf.keras.layers.deserialize).args
+      in inspect.getfullargspec(keras.layers.deserialize).args
   ):
-    return tf.keras.layers.deserialize(
-        config, use_legacy_format=use_legacy_format
-    )
+    return keras.layers.deserialize(config, use_legacy_format=use_legacy_format)
   else:
-    return tf.keras.layers.deserialize(config)
+    return keras.layers.deserialize(config)
 
 
 def serialize_activation(activation, use_legacy_format=False):
   if (
       "use_legacy_format"
-      in inspect.getfullargspec(tf.keras.activations.serialize).args
+      in inspect.getfullargspec(keras.activations.serialize).args
   ):
-    return tf.keras.activations.serialize(
+    return keras.activations.serialize(
         activation, use_legacy_format=use_legacy_format
     )
   else:
-    return tf.keras.activations.serialize(activation)
+    return keras.activations.serialize(activation)
 
 
 def deserialize_activation(config, use_legacy_format=False):
   if (
       "use_legacy_format"
-      in inspect.getfullargspec(tf.keras.activations.deserialize).args
+      in inspect.getfullargspec(keras.activations.deserialize).args
   ):
-    return tf.keras.activations.deserialize(
+    return keras.activations.deserialize(
         config, use_legacy_format=use_legacy_format
     )
   else:
-    return tf.keras.activations.deserialize(config)
+    return keras.activations.deserialize(config)
 
 
 def convert_keras_to_tflite(model,
@@ -104,7 +100,7 @@ def convert_keras_to_tflite(model,
     converter = tf.lite.TFLiteConverter.from_keras_model(model)
   else:
     _, keras_file = tempfile.mkstemp(".h5")
-    tf.keras.models.save_model(model, keras_file)
+    keras.models.save_model(model, keras_file)
     converter = tf.lite.TFLiteConverter.from_keras_model_file(
         keras_file, custom_objects=custom_objects)
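
A round trip through these shims: each one inspects the installed Keras and picks the matching entry point, so the same call works across versions.

```python
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils

layer = keras.layers.Dense(4)

# serialize_layer/deserialize_layer forward use_legacy_format only when the
# installed keras.layers.serialize/deserialize accept it.
config = quantize_utils.serialize_layer(layer)
restored = quantize_utils.deserialize_layer(config)
```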
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
index 6f42d2881..292691d13 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
@@ -1,5 +1,5 @@
-# Placeholder: load py_test
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
+# Placeholder: load py_test
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -31,6 +31,7 @@ py_strict_library(
         ":pruning_schedule",
         ":pruning_wrapper",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:metrics",
     ],
 )
@@ -53,6 +54,7 @@ py_strict_library(
     deps = [
         ":prunable_layer",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -64,6 +66,7 @@ py_strict_library(
     deps = [
         ":pruning_wrapper",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )
@@ -141,6 +144,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -193,6 +197,7 @@ py_strict_library(
         ":pruning_wrapper",
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -211,6 +216,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -233,6 +239,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -253,6 +260,7 @@ py_strict_test(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -269,6 +277,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -283,6 +292,7 @@ py_strict_test(
         ":pruning_wrapper",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -332,5 +342,6 @@ py_test(
         ":pruning_wrapper",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
index 6c85d6ae1..8d40c52d0 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
@@ -18,11 +18,12 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import metrics
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
-custom_object_scope = tf.keras.utils.custom_object_scope
+
+custom_object_scope = keras.utils.custom_object_scope
 
 
 def prune_scope():
@@ -31,7 +32,7 @@ def prune_scope():
   For TF 2.X: this is not needed for SavedModel or TF checkpoints, which are
   the recommended serialization formats.
 
-  For TF 1.X: if a tf.keras h5 model or layer has been pruned, it needs to be
+  For TF 1.X: if a keras h5 model or layer has been pruned, it needs to be
   within this scope to be successfully deserialized. This is not needed for
   loading just keras weights.
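
The TF 1.X / h5 case in runnable form; the file path is illustrative.

```python
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune

# TF 1.X h5 deserialization: pruning wrappers resolve inside the scope.
with prune.prune_scope():
  restored_model = keras.models.load_model('pruned_model.h5')  # illustrative path
```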
@@ -61,15 +62,15 @@ def prune_low_magnitude(to_prune,
                         pruning_policy=None,
                         sparsity_m_by_n=None,
                         **kwargs):
-  """Modify a tf.keras layer or model to be pruned during training.
+  """Modify a keras layer or model to be pruned during training.
 
-  This function wraps a tf.keras model or layer with pruning functionality which
+  This function wraps a keras model or layer with pruning functionality which
   sparsifies the layer's weights during training. For example, using this with
   50% sparsity will ensure that 50% of the layer's weights are zero.
 
   The function accepts either a single keras layer
-  (subclass of `tf.keras.layers.Layer`), list of keras layers or a Sequential
-  or Functional tf.keras model and handles them appropriately.
+  (subclass of `keras.layers.Layer`), a list of keras layers, or a Sequential
+  or Functional keras model and handles them appropriately.
 
   If it encounters a layer it does not know how to handle, it will throw an
   error. While pruning an entire model, even a single unknown layer would lead
@@ -127,8 +128,8 @@ def prune_low_magnitude(to_prune,
   (https://github.com/tensorflow/model-optimization/issues/206).
 
   Arguments:
-      to_prune: A single keras layer, list of keras layers, or a
-        `tf.keras.Model` instance.
+      to_prune: A single keras layer, list of keras layers, or a `keras.Model`
+        instance.
       pruning_schedule: A `PruningSchedule` object that controls pruning rate
         throughout training.
       block_size: (optional) The dimensions (height, weight) for the block
@@ -140,8 +141,8 @@ def prune_low_magnitude(to_prune,
         and is subject to change.
       sparsity_m_by_n: default None, otherwise a tuple of 2 integers, indicates
         pruning with m_by_n sparsity, e.g., (2, 4): 2 zeros out of 4 consecutive
-        elements. It check whether we can do pruning with m_by_n sparsity.
-        If this type of sparsity is not applicable, then an error is thrown.
+        elements. It checks whether we can do pruning with m_by_n sparsity. If
+        this type of sparsity is not applicable, then an error is thrown.
       **kwargs: Additional keyword arguments to be passed to the keras layer.
         Ignored when to_prune is not a keras layer.
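
An end-to-end sketch of the flow this docstring describes, using `ConstantSparsity` and `UpdatePruningStep` from this package; sizes and training data are illustrative.

```python
import numpy as np

from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule

model = keras.Sequential([
    keras.layers.Dense(8, input_shape=(10,)),
    keras.layers.Dense(5),
])
pruned = prune.prune_low_magnitude(
    model,
    pruning_schedule=pruning_schedule.ConstantSparsity(0.5, begin_step=0),
)
pruned.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']
)
# UpdatePruningStep is required so the pruning schedule advances during fit.
pruned.fit(
    np.random.rand(20, 10),
    keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
    callbacks=[pruning_callbacks.UpdatePruningStep()],
    batch_size=20,
)
exported = prune.strip_pruning(pruned)  # remove wrappers before export
```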
 
@@ -214,9 +215,10 @@ def _add_pruning_wrapper(layer):
   else:
     raise ValueError(
         '`prune_low_magnitude` can only prune an object of the following '
-        'types: tf.keras.models.Sequential, tf.keras functional model, '
-        'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
-        'an object of type: {input}.'.format(input=to_prune.__class__.__name__))
+        'types: keras.models.Sequential, keras functional model, '
+        'keras.layers.Layer, list of keras.layers.Layer. You passed '
+        'an object of type: {input}.'.format(input=to_prune.__class__.__name__)
+    )
 
 
 def strip_pruning(model):
@@ -228,19 +230,19 @@ def strip_pruning(model):
   Only sequential and functional models are supported for now.
 
   Arguments:
-      model: A `tf.keras.Model` instance with pruned layers.
+      model: A `keras.Model` instance with pruned layers.
 
   Returns:
     A keras model with pruning wrappers removed.
 
   Raises:
-    ValueError: if the model is not a `tf.keras.Model` instance.
+    ValueError: if the model is not a `keras.Model` instance.
     NotImplementedError: if the model is a subclass model.
 
   Usage:
 
   ```python
-  orig_model = tf.keras.Model(inputs, outputs)
+  orig_model = keras.Model(inputs, outputs)
   pruned_model = prune_low_magnitude(orig_model)
   exported_model = strip_pruning(pruned_model)
   ```
@@ -249,10 +251,11 @@ def strip_pruning(model):
 
   if not isinstance(model, keras.Model):
     raise ValueError(
-        'Expected model to be a `tf.keras.Model` instance but got: ', model)
+        'Expected model to be a `keras.Model` instance but got: ', model
+    )
 
   def _strip_pruning_wrapper(layer):
-    if isinstance(layer, tf.keras.Model):
+    if isinstance(layer, keras.Model):
       # A keras model with prunable layers
       return keras.models.clone_model(
           layer, input_tensors=None, clone_function=_strip_pruning_wrapper)
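
For reference, a minimal sketch (not part of this change) of exercising the `sparsity_m_by_n` argument documented above, followed by `strip_pruning` before export. The layer sizes, training data, and the use of `tf_keras` as the backing package are illustrative assumptions.

```python
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras  # assumption: tf_keras backs the compat alias

# Wrap a single Dense layer with 2-out-of-4 structured sparsity.
pruned_layer = tfmot.sparsity.keras.prune_low_magnitude(
    keras.layers.Dense(16), sparsity_m_by_n=(2, 4))

model = keras.Sequential([keras.layers.InputLayer((8,)), pruned_layer])
model.compile(loss='mse', optimizer='sgd')
model.fit(
    np.random.rand(4, 8), np.random.rand(4, 16),
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], verbose=0)

# Drop the wrappers before export, per strip_pruning above.
exported = tfmot.sparsity.keras.strip_pruning(model)
```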
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py
index 6b1ba4226..c9f5fad35 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py
@@ -15,18 +15,18 @@
 """Distributed pruning test."""
 
 import tempfile
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import test_utils
 
-keras = tf.keras
-
 
 def _distribution_strategies():
   return [
@@ -45,18 +45,18 @@ def setUp(self):
     self.params = {
         'pruning_schedule': pruning_schedule.ConstantSparsity(0.5, 0, -1, 1),
         'block_size': (1, 1),
-        'block_pooling_type': 'AVG'
+        'block_pooling_type': 'AVG',
     }
 
   @parameterized.parameters(_distribution_strategies())
   def testPrunesSimpleDenseModel(self, distribution):
     with distribution.scope():
       model = prune.prune_low_magnitude(
-          keras_test_utils.build_simple_dense_model(), **self.params)
+          keras_test_utils.build_simple_dense_model(), **self.params
+      )
       model.compile(
-          loss='categorical_crossentropy',
-          optimizer='sgd',
-          metrics=['accuracy'])
+          loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']
+      )
 
     # Model hasn't been trained yet. Sparsity 0.0
     test_utils.assert_model_sparsity(self, 0.0, model)
@@ -67,7 +67,8 @@ def testPrunesSimpleDenseModel(self, distribution):
         keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
         epochs=2,
         callbacks=[pruning_callbacks.UpdatePruningStep()],
-        batch_size=20)
+        batch_size=20,
+    )
     model.predict(np.random.rand(20, 10))
     test_utils.assert_model_sparsity(self, 0.5, model)
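
The recurring change in these files swaps the old `keras = tf.keras` alias for the shared compat module. A minimal sketch of what such a shim typically looks like; the actual `python/core/keras/compat.py` may differ:

```python
# Sketch of a Keras-2 compat shim; the real compat module may differ.
try:
  # Newer TF releases ship Keras 3; Keras 2 lives in the separate
  # tf_keras package.
  import tf_keras as keras
except ImportError:
  # Older TF releases still bundle Keras 2 as tf.keras.
  from tensorflow import keras
```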
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
index ae1c1b117..5e749455e 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
@@ -22,6 +22,7 @@
 
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
@@ -29,7 +30,7 @@
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 from tensorflow_model_optimization.python.core.sparsity.keras import test_utils
 
-keras = tf.keras
+
 layers = keras.layers
 
 list_to_named_parameters = test_utils.list_to_named_parameters
@@ -41,10 +42,13 @@ class PruneIntegrationTest(tf.test.TestCase, parameterized.TestCase,
 
   # Fetch all the prunable layers from the registry.
   _PRUNABLE_LAYERS = [
-      layer for layer, weights in
-      prune_registry.PruneRegistry._LAYERS_WEIGHTS_MAP.items()
-      if (weights and layer != tf.keras.layers.Conv3DTranspose and
-          layer != tf.keras.layers.Conv2DTranspose)
+      layer
+      for layer, weights in prune_registry.PruneRegistry._LAYERS_WEIGHTS_MAP.items()
+      if (
+          weights
+          and layer != keras.layers.Conv3DTranspose
+          and layer != keras.layers.Conv2DTranspose
+      )
   ]
 
   # Fetch all the non-prunable layers from the registry.
@@ -206,8 +210,7 @@ def testPruneWithHighSparsity(self):
     for layer in model.layers:
       if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
         for weight in layer.layer.get_prunable_weights():
-          self.assertEqual(1,
-                           np.count_nonzero(tf.keras.backend.get_value(weight)))
+          self.assertEqual(1, np.count_nonzero(keras.backend.get_value(weight)))
 
   ###################################################################
   # Tests for training with pruning with pretrained models or weights.
@@ -335,40 +338,40 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
   @parameterized.named_parameters(
       {
           'testcase_name': 'Conv2D',
-          'layer_type': tf.keras.layers.Conv2D,
+          'layer_type': keras.layers.Conv2D,
           'layer_arg': [16, (5, 7)],
           'input_shape': (10, 10, 8),
       },
       {
           'testcase_name': 'Dense',
-          'layer_type': tf.keras.layers.Dense,
+          'layer_type': keras.layers.Dense,
           'layer_arg': [16],
           'input_shape': [(8)],
       },
       {
           'testcase_name': 'Conv2D_not_multiple_4',
-          'layer_type': tf.keras.layers.Conv2D,
+          'layer_type': keras.layers.Conv2D,
           'layer_arg': [16, (5, 7)],
           'input_shape': (10, 10, 7),
           'sparsity_ratio': 0.428571,
       },
       {
           'testcase_name': 'Conv2D_1by2',
-          'layer_type': tf.keras.layers.Conv2D,
+          'layer_type': keras.layers.Conv2D,
           'layer_arg': [16, (5, 7)],
           'input_shape': (10, 10, 8),
           'm_by_n': (1, 2),
       },
       {
           'testcase_name': 'Dense_1by2',
-          'layer_type': tf.keras.layers.Dense,
+          'layer_type': keras.layers.Dense,
           'layer_arg': [16],
           'input_shape': [(8)],
           'm_by_n': (1, 2),
       },
       {
           'testcase_name': 'DepthwiseConv_2by4',
-          'layer_type': tf.keras.layers.DepthwiseConv2D,
+          'layer_type': keras.layers.DepthwiseConv2D,
           'layer_arg': [3],
           'input_shape': (7, 7, 32),
           'm_by_n': (2, 4),
@@ -405,13 +408,14 @@ def testSparsityPruningMbyN_SupportedSubclassLayers(self):
     m_by_n = (2, 4)
     self.params.update({'sparsity_m_by_n': m_by_n})
 
-    class SubclassLayer(tf.keras.layers.Layer):
+    class SubclassLayer(keras.layers.Layer):
 
       def __init__(self):
         super(SubclassLayer, self).__init__()
-        self.conv1 = tf.keras.layers.Conv2D(
-            2, 3, activation='relu', padding='same', input_shape=[7, 7, 3])
-        self.conv2 = tf.keras.layers.DepthwiseConv2D(3)
+        self.conv1 = keras.layers.Conv2D(
+            2, 3, activation='relu', padding='same', input_shape=[7, 7, 3]
+        )
+        self.conv2 = keras.layers.DepthwiseConv2D(3)
         self.flatten = keras.layers.Flatten()
         self.dense = layers.Dense(10, activation='sigmoid')
 
@@ -529,17 +533,20 @@ def testPruneRecursivelyReachesTargetSparsity(self):
     self._check_strip_pruning_matches_original(model, 0.5, input_data)
 
   def testMHALayerReachesTargetSparsity(self):
-    inp = tf.keras.layers.Input(shape=(32,32), batch_size=100)
-    x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(query=inp, value=inp)
-    out = tf.keras.layers.Flatten()(x)
-    model = tf.keras.Model(inputs=inp, outputs=out)
+    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
+    x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(
+        query=inp, value=inp
+    )
+    out = keras.layers.Flatten()(x)
+    model = keras.Model(inputs=inp, outputs=out)
     model = prune.prune_low_magnitude(model, **self.params)
     x_train = np.random.uniform(size=(500, 32, 32))
     y_train = np.random.randint(low=0, high=1024, size=(500,))
     model.compile(
-      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
+        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
+    )
     test_utils.assert_model_sparsity(self, 0.0, model)
     model.fit(
         x_train,
@@ -595,8 +602,9 @@ def testPruneCheckpoints_CheckpointsNotSparse(self):
 
       callbacks = [
           pruning_callbacks.UpdatePruningStep(),
-          tf.keras.callbacks.ModelCheckpoint(
-              filepath=checkpoint_path, save_weights_only=True, save_freq=1)
+          keras.callbacks.ModelCheckpoint(
+              filepath=checkpoint_path, save_weights_only=True, save_freq=1
+          ),
       ]
 
       # Train one step. Sparsity reaches final sparsity.
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py
index 1aaca9ce7..82e2fcef0 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py
@@ -16,8 +16,10 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 
+
 try:
   # OSS case.
   import keras  # pylint: disable=g-import-not-at-top
@@ -28,11 +30,14 @@
     from keras.engine import base_layer  # pylint: disable=g-import-not-at-top,g-importing-member
 except ImportError:
   # Internal case.
-  base_layer = tf._keras_internal.engine.base_layer  # pylint: disable=protected-access
+  try:
+    base_layer = tf._keras_internal.engine.base_layer  # pylint: disable=protected-access
+  except AttributeError:
+    # Newer TF builds no longer expose tf._keras_internal.
+    base_layer = None
 
 # TODO(b/139939526): move to public API.
 
-layers = tf.keras.layers
+layers = keras.layers
 layers_compat_v1 = tf.compat.v1.keras.layers
 
 
@@ -107,15 +112,27 @@ class PruneRegistry(object):
       layers.MaxPooling2D: [],
       layers.MaxPooling3D: [],
       layers.MultiHeadAttention: [
-          '_query_dense.kernel', '_key_dense.kernel', '_value_dense.kernel',
-          '_output_dense.kernel'
+          '_query_dense.kernel',
+          '_key_dense.kernel',
+          '_value_dense.kernel',
+          '_output_dense.kernel',
       ],
-      layers.experimental.SyncBatchNormalization: [],
-      layers.experimental.preprocessing.Rescaling.__class__: [],
-      base_layer.TensorFlowOpLayer: [],
       layers_compat_v1.BatchNormalization: [],
   }
 
+  if hasattr(layers, 'experimental'):
+    if hasattr(layers.experimental, 'SyncBatchNormalization'):
+      _LAYERS_WEIGHTS_MAP[layers.experimental.SyncBatchNormalization] = []
+    if hasattr(layers.experimental, 'preprocessing') and hasattr(
+        layers.experimental.preprocessing, 'Rescaling'
+    ):
+      _LAYERS_WEIGHTS_MAP[
+          layers.experimental.preprocessing.Rescaling.__class__
+      ] = []
+
+  if base_layer:
+    _LAYERS_WEIGHTS_MAP[base_layer.TensorFlowOpLayer] = []
+
   _RNN_CELLS_WEIGHTS_MAP = {
       # Allowlist via compat.v1 and compat.v2 to support legacy TensorFlow 2.X
       # behavior where the v2 RNN uses the v1 RNNCell instead of the v2 RNNCell.
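
The guarded registration above keeps the registry importable on Keras builds that lack the `experimental` symbols. The same idiom in isolation (a sketch; `optional_layers` is a hypothetical name):

```python
from tensorflow_model_optimization.python.core.keras.compat import keras

# Register optional layers only when the running Keras provides them.
optional_layers = {}
experimental = getattr(keras.layers, 'experimental', None)
if experimental is not None:
  sync_bn = getattr(experimental, 'SyncBatchNormalization', None)
  if sync_bn is not None:
    optional_layers[sync_bn] = []  # no prunable weights
```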
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py
index bfe9c3ecc..d3b2a1b42 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py
@@ -17,10 +17,11 @@
 from absl.testing import parameterized
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
 
-keras = tf.keras
+
 layers = keras.layers
 PruneRegistry = prune_registry.PruneRegistry
 
@@ -87,6 +88,17 @@ class PruneRegistryTest(tf.test.TestCase, parameterized.TestCase):
       ]),
       keras.layers.RNN(MinimalRNNCellPrunable(32)),
   ]
+  if hasattr(layers, 'experimental'):
+    if hasattr(layers.experimental, 'SyncBatchNormalization'):
+      _PRUNE_REGISTRY_SUPPORTED_LAYERS += [
+          layers.experimental.SyncBatchNormalization()
+      ]
+    if hasattr(layers.experimental, 'preprocessing') and hasattr(
+        layers.experimental.preprocessing, 'Rescaling'
+    ):
+      _PRUNE_REGISTRY_SUPPORTED_LAYERS += [
+          layers.experimental.preprocessing.Rescaling
+      ]
 
   @parameterized.parameters(_PRUNE_REGISTRY_SUPPORTED_LAYERS)
   def testSupportsLayer(self, layer):
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
index 2105f37f6..36761ff42 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf.keras pruning APIs under prune.py."""
+"""Tests for keras pruning APIs under prune.py."""
 
 import json
 import tempfile
@@ -22,12 +22,13 @@
 
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
+
 errors_impl = tf.errors
 layers = keras.layers
 test = tf.test
@@ -57,18 +58,20 @@ class CustomNonPrunableLayer(layers.Dense):
 
 class PruneTest(test.TestCase):
 
-  INVALID_TO_PRUNE_PARAM_ERROR = ('`prune_low_magnitude` can only prune an '
-                                  'object of the following types: '
-                                  'tf.keras.models.Sequential, tf.keras '
-                                  'functional model, tf.keras.layers.Layer, '
-                                  'list of tf.keras.layers.Layer. You passed an'
-                                  ' object of type: {input}.')
+  INVALID_TO_PRUNE_PARAM_ERROR = (
+      '`prune_low_magnitude` can only prune an '
+      'object of the following types: '
+      'keras.models.Sequential, keras '
+      'functional model, keras.layers.Layer, '
+      'list of keras.layers.Layer. You passed an'
+      ' object of type: {input}.'
+  )
 
   def setUp(self):
     super(PruneTest, self).setUp()
 
     # Layers passed in for Pruning can either be standard Keras layers provided
-    # by the tf.keras API (these fall under the `keras.layers` namespace), or
+    # by the keras API (these fall under the `keras.layers` namespace), or
     # custom layers provided by the user which inherit the base
     # `keras.layers.Layer`.
     # Standard Keras layers can either be Prunable (we know how to prune them),
@@ -362,11 +365,11 @@ def testPruneScope_NeededForKerasModel(self):
     pruned_model.save(keras_model)
 
     with self.assertRaises(ValueError):
-      tf.keras.models.load_model(keras_model)
+      keras.models.load_model(keras_model)
 
     # works with `prune_scope`
     with prune.prune_scope():
-      tf.keras.models.load_model(keras_model)
+      keras.models.load_model(keras_model)
 
   def testPruneScope_NotNeededForKerasCheckpoint(self):
     model = keras_test_utils.build_simple_dense_model()
@@ -421,13 +424,13 @@ def testPruneScope_NeededForTF1SavedModel(self):
 
     saved_model_dir = tempfile.mkdtemp()
 
-    tf.keras.experimental.export_saved_model(pruned_model, saved_model_dir)
+    keras.experimental.export_saved_model(pruned_model, saved_model_dir)
     with self.assertRaises(ValueError):
-      tf.keras.experimental.load_from_saved_model(saved_model_dir)
+      keras.experimental.load_from_saved_model(saved_model_dir)
 
     # works with `prune_scope`
     with prune.prune_scope():
-      tf.keras.experimental.load_from_saved_model(saved_model_dir)
+      keras.experimental.load_from_saved_model(saved_model_dir)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py
index a51b3faa6..555de78be 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py
@@ -19,15 +19,18 @@
 from __future__ import print_function
 
 # import g3
+
 import numpy as np
 import six
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-K = tf.keras.backend
-callbacks = tf.keras.callbacks
+
+K = keras.backend
+callbacks = keras.callbacks
 
 
 class UpdatePruningStep(callbacks.Callback):
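
For custom training loops, the callback's hooks are driven by hand, as in the `testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop` test further below. A self-contained sketch, assuming the public `tfmot` API and `tf_keras`:

```python
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tf_keras as keras  # assumption: tf_keras backs the compat alias

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    keras.Sequential([keras.layers.Dense(4, input_shape=(8,))]))
optimizer = keras.optimizers.SGD()
loss_fn = keras.losses.MeanSquaredError()
x = np.random.rand(16, 8).astype('float32')
y = np.random.rand(16, 4).astype('float32')

# Drive the callback by hand, mirroring the custom-training-loop test.
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
step_callback.on_train_begin()
for batch in range(3):
  step_callback.on_train_batch_begin(batch=batch)
  with tf.GradientTape() as tape:
    loss = loss_fn(y, pruned_model(x, training=True))
  grads = tape.gradient(loss, pruned_model.trainable_variables)
  optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
step_callback.on_epoch_end(batch=-1)
```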
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py
index 31106100d..30398357e 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py
@@ -22,10 +22,11 @@
 
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 
-keras = tf.keras
+
 errors_impl = tf.errors
 
 
@@ -67,9 +68,11 @@ def testUpdatePruningStepsAndLogsSummaries(self):
         ])
 
     self.assertEqual(
-        3, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+        3, keras.backend.get_value(pruned_model.layers[0].pruning_step)
+    )
     self.assertEqual(
-        3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+        3, keras.backend.get_value(pruned_model.layers[1].pruning_step)
+    )
 
     self._assertLogsExist(log_dir)
 
@@ -107,9 +110,11 @@ def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self):
       step_callback.on_epoch_end(batch=unused_arg)
 
     self.assertEqual(
-        3, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+        3, keras.backend.get_value(pruned_model.layers[0].pruning_step)
+    )
     self.assertEqual(
-        3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+        3, keras.backend.get_value(pruned_model.layers[1].pruning_step)
+    )
     self._assertLogsExist(log_dir)
 
   def testUpdatePruningStepsAndLogsSummaries_RunInference(self):
@@ -119,9 +124,11 @@ def testUpdatePruningStepsAndLogsSummaries_RunInference(self):
     del model_output
 
     self.assertEqual(
-        -1, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+        -1, keras.backend.get_value(pruned_model.layers[0].pruning_step)
+    )
     self.assertEqual(
-        -1, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+        -1, keras.backend.get_value(pruned_model.layers[1].pruning_step)
+    )
 
   def testPruneTrainingRaisesError_PruningStepCallbackMissing(self):
     pruned_model, x_train, y_train = self._pruned_model_setup()
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py
index dd721ea37..d8b98e8af 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py
@@ -26,10 +26,12 @@
 
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 
-K = tf.keras.backend
+
+K = keras.backend
 dtypes = tf.dtypes
 test = tf.test
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
index 39c5fe1fe..64a6c296d 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
@@ -16,13 +16,16 @@
 """Pruning Policy classes to control application of pruning wrapper."""
 
 import abc
+
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-layers = tf.keras.layers
-activations = tf.keras.activations
+
+layers = keras.layers
+activations = keras.activations
 
 
 class PruningPolicy(abc.ABC):
@@ -71,7 +74,7 @@ def ensure_model_supports_pruning(self, model):
     """Checks that the model contains only supported layers.
 
     Args:
-      model: A `tf.keras.Model` instance which is going to be pruned.
+      model: A `keras.Model` instance which is going to be pruned.
 
     Raises:
       ValueError: if the keras model doesn't support pruning policy, i.e. keras
@@ -108,8 +111,11 @@ def _get_producers(self, layer):
   def _get_consumers(self, layer):
 
     def unpack(layer):
-      return (unpack(layer.layers[0])
-              if isinstance(layer, tf.keras.Sequential) else layer)
+      return (
+          unpack(layer.layers[0])
+          if isinstance(layer, keras.Sequential)
+          else layer
+      )
 
     return [unpack(node.outbound_layer) for node in layer._outbound_nodes]
 
@@ -221,8 +227,20 @@ def _check_layer_support(self, layer):
           layer.activation, use_legacy_format=True
       ) in ('relu', 'relu6', 'leaky_relu', 'elu', 'sigmoid')
     elif layer.__class__.__name__ == 'TFOpLambda':
-      return layer.function in (tf.identity, tf.__operators__.add, tf.math.add,
-                                tf.math.subtract, tf.math.multiply)
+      if layer.function in (
+          tf.identity,
+          tf.__operators__.add,
+          tf.math.add,
+          tf.math.subtract,
+          tf.math.multiply,
+      ):
+        return True
+      return layer.function.__name__ in [
+          'identity',
+          'add',
+          'subtract',
+          'multiply',
+      ]
     elif isinstance(layer, pruning_wrapper.PruneLowMagnitude):
       return self._check_layer_support(layer.layer)
     return False
@@ -231,8 +249,9 @@ def ensure_model_supports_pruning(self, model):
     """Ensures that the model contains only supported layers."""
 
     # Check whether the model is a subclass model.
-    if (not model._is_graph_network and
-        not isinstance(model, tf.keras.models.Sequential)):
+    if not model._is_graph_network and not isinstance(
+        model, keras.models.Sequential
+    ):
       raise ValueError('Subclassed models are not supported currently.')
 
     if not model.built:
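
The name-based fallback added above presumably covers TF builds where the function object stored on a `TFOpLambda` layer is not identical to the public `tf.*` symbol; comparing `__name__` still recognizes the op. The two-stage check in isolation:

```python
import tensorflow as tf

# Sketch: identity check first, then a name-based fallback.
def _supports_op(fn):
  if fn in (tf.identity, tf.__operators__.add, tf.math.add,
            tf.math.subtract, tf.math.multiply):
    return True
  return fn.__name__ in ('identity', 'add', 'subtract', 'multiply')
```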
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py
index a1ebe24cd..d9bb3504c 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py
@@ -18,12 +18,13 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_policy
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
+
 layers = keras.layers
 
 
@@ -316,8 +317,9 @@ def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
     o = layers.GlobalAveragePooling2D()(x)
     original_model = keras.Model(inputs=[i], outputs=[o])
 
-    cloned_model = tf.keras.models.clone_model(
-        original_model, clone_function=lambda l: l)
+    cloned_model = keras.models.clone_model(
+        original_model, clone_function=lambda l: l
+    )
     pruned_model = prune.prune_low_magnitude(
         cloned_model,
         pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils_test.py
index 2c43d06cb..dcdb36ebd 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils_test.py
@@ -19,13 +19,16 @@
 from __future__ import print_function
 
 # import g3
-from absl.testing import parameterized
 
+from absl.testing import parameterized
 import tensorflow as tf
+
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 
-glorot_uniform_initializer = tf.keras.initializers.glorot_uniform
+
+glorot_uniform_initializer = keras.initializers.glorot_uniform
 
 
 @parameterized.named_parameters(
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
index 4845abcd7..d134640a3 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
@@ -36,8 +36,9 @@
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
 from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import convert_to_tuple_of_two_int
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
-keras = tf.keras
 K = keras.backend
 Wrapper = keras.layers.Wrapper
 
@@ -166,7 +167,7 @@ def __init__(self,
           'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
           .format(block_pooling_type))
 
-    if not isinstance(layer, tf.keras.layers.Layer):
+    if not isinstance(layer, keras.layers.Layer):
       raise ValueError(
           'Please initialize `Prune` layer with a '
           '`Layer` instance. You passed: {input}'.format(input=layer))
@@ -203,14 +204,14 @@ def __init__(self,
     # Enables end-user to prune the first layer in Sequential models, while
     # passing the input shape to the original layer.
     #
-    # tf.keras.Sequential(
-    #   prune_low_magnitude(tf.keras.layers.Dense(2, input_shape=(3,)))
+    # keras.Sequential(
+    #   prune_low_magnitude(keras.layers.Dense(2, input_shape=(3,)))
     # )
     #
     # as opposed to
     #
-    # tf.keras.Sequential(
-    #   prune_low_magnitude(tf.keras.layers.Dense(2), input_shape=(3,))
+    # keras.Sequential(
+    #   prune_low_magnitude(keras.layers.Dense(2), input_shape=(3,))
     # )
     #
     # Without this code, the pruning wrapper doesn't have an input
@@ -235,17 +236,19 @@ def build(self, input_shape):
       mask = self.add_weight(
           'mask',
           shape=weight.shape,
-          initializer=tf.keras.initializers.get('ones'),
+          initializer=keras.initializers.get('ones'),
           dtype=weight.dtype,
           trainable=False,
-          aggregation=tf.VariableAggregation.MEAN)
+          aggregation=tf.VariableAggregation.MEAN,
+      )
       threshold = self.add_weight(
           'threshold',
           shape=[],
-          initializer=tf.keras.initializers.get('zeros'),
+          initializer=keras.initializers.get('zeros'),
           dtype=weight.dtype,
           trainable=False,
-          aggregation=tf.VariableAggregation.MEAN)
+          aggregation=tf.VariableAggregation.MEAN,
+      )
 
       weight_vars.append(weight)
       mask_vars.append(mask)
@@ -256,9 +259,10 @@ def build(self, input_shape):
     self.pruning_step = self.add_weight(
         'pruning_step',
         shape=[],
-        initializer=tf.keras.initializers.Constant(-1),
+        initializer=keras.initializers.Constant(-1),
         dtype=tf.int64,
-        trainable=False)
+        trainable=False,
+    )
 
     def training_step_fn():
       return self.pruning_step
@@ -392,7 +396,7 @@ def collect_prunable_layers(model):
   prunable_layers = []
   for layer in model.layers:
     # A keras model may have other models as layers.
-    if isinstance(layer, tf.keras.Model):
+    if isinstance(layer, keras.Model):
       prunable_layers += collect_prunable_layers(layer)
     if isinstance(layer, PruneLowMagnitude):
       prunable_layers.append(layer)
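
Per the `build` changes above, every prunable weight gains a same-shaped non-trainable `mask` and a scalar `threshold`, plus one shared `pruning_step`. A quick way to see them (sketch, assuming the public tfmot API):

```python
import tensorflow_model_optimization as tfmot
import tf_keras as keras  # assumption: tf_keras backs the compat alias

wrapped = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(4))
wrapped.build(input_shape=(None, 8))
print([w.name for w in wrapped.non_trainable_weights])
# Expect names containing 'mask', 'threshold', and 'pruning_step'.
```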
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
index 456c0ff96..7e270a780 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
@@ -20,10 +20,11 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
+
 layers = keras.layers
 Prune = pruning_wrapper.PruneLowMagnitude
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/test_utils.py b/tensorflow_model_optimization/python/core/sparsity/keras/test_utils.py
index b8af569e8..073d60b62 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/test_utils.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/test_utils.py
@@ -15,15 +15,16 @@
 """Test utility to generate models for testing."""
 
 import tempfile
-import numpy as np
 
+import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
+
 l = keras.layers
 
 
@@ -148,9 +149,9 @@ def _save_restore_keras_model(model):
 
 def _save_restore_tf_model(model):
   tmpdir = tempfile.mkdtemp()
-  tf.keras.models.save_model(model, tmpdir, save_format='tf')
+  keras.models.save_model(model, tmpdir, save_format='tf')
   with prune.prune_scope():
-    loaded_model = tf.keras.models.load_model(tmpdir)
+    loaded_model = keras.models.load_model(tmpdir)
   return loaded_model
 
 
@@ -171,9 +172,10 @@ def assert_model_sparsity(test_case, sparsity, model, rtol=1e-6, atol=1e-6):
       for weight in layer.layer.get_prunable_weights():
         test_case.assertAllClose(
             sparsity,
-            _get_sparsity(tf.keras.backend.get_value(weight)),
+            _get_sparsity(keras.backend.get_value(weight)),
             rtol=rtol,
-            atol=atol)
+            atol=atol,
+        )
 
 
 # Check if model does not have target sparsity.
@@ -181,7 +183,7 @@ def is_model_sparsity_not(sparsity, model):
   for layer in model.layers:
     if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
       for weight in layer.layer.get_prunable_weights():
-        if sparsity != _get_sparsity(tf.keras.backend.get_value(weight)):
+        if sparsity != _get_sparsity(keras.backend.get_value(weight)):
           return True
   return False
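
For context, `_get_sparsity` measures the zero fraction of a weight array, presumably along these lines:

```python
import numpy as np

def _get_sparsity(weights):
  # Fraction of exactly-zero entries in the array.
  return 1.0 - np.count_nonzero(weights) / float(weights.size)
```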
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/tools/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/tools/BUILD
index 936453455..95ef7ed83 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/tools/BUILD
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/tools/BUILD
@@ -1,6 +1,6 @@
-# Placeholder: load py_test
-# Placeholder: load py_binary
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library")
+# Placeholder: load py_binary
+# Placeholder: load py_test
 
 package(default_visibility = [
     "//tensorflow_model_optimization:__subpackages__",
@@ -15,6 +15,7 @@ py_strict_library(
     visibility = ["//visibility:public"],
     deps = [
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:metrics",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
@@ -46,6 +47,7 @@ py_binary(
         ":sparsity_tooling",
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/tools/evaluate_pruning.py b/tensorflow_model_optimization/python/core/sparsity/keras/tools/evaluate_pruning.py
index eba50e22f..ab897d98c 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/tools/evaluate_pruning.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/tools/evaluate_pruning.py
@@ -34,6 +34,7 @@
 from absl import flags
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras.tools import sparsity_tooling
 
@@ -100,7 +101,7 @@ def run(input_model_path, output_dir, target_sparsity, block_size):
              are not intended to be served in production, but to be used for
              performance benchmarking."""))
 
-  input_model = tf.keras.models.load_model(input_model_path)
+  input_model = keras.models.load_model(input_model_path)
 
   os.makedirs(output_dir, exist_ok=True)
   unpruned_tflite_path = os.path.join(
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling.py b/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling.py
index 658b2a580..e6e88709d 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling.py
@@ -26,12 +26,11 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import metrics
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
-
 
 class StepIndependentConstantSparsity(pruning_schedule.PruningSchedule):
   """Pruning schedule with constant sparsity, applied at any step."""
@@ -70,16 +69,17 @@ def _apply_pruning(prunable_object):
 def prune_for_benchmark(keras_model,
                         target_sparsity,
                         block_size=(1, 1)):
-  """Prunes a tf.keras model in a single step, without re-training.
+  """Prunes a keras model in a single step, without re-training.
 
-  This function is intented to quickly apply sparsity to a model, without
+  This function is intended to quickly apply sparsity to a model, without
   consideration for accuracy.
 
   Args:
-    keras_model: A `tf.keras.Model` instance.
+    keras_model: A `keras.Model` instance.
     target_sparsity: Target sparsity as float, in [0, 1] interval.
-    block_size: The dimensions (height, weight) for the block sparse
-      pattern in rank-2 weight tensors.
+    block_size: The dimensions (height, width) for the block sparse pattern in
+      rank-2 weight tensors.
+
   Returns:
     A pruned model, modified with pruning wrappers.
   """
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling_test.py
index 05431066f..329a3fe90 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/tools/sparsity_tooling_test.py
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf.keras pruning tools in sparsity_tooling.py."""
+"""Tests for keras pruning tools in sparsity_tooling.py."""
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 from tensorflow_model_optimization.python.core.sparsity.keras import test_utils
 from tensorflow_model_optimization.python.core.sparsity.keras.tools import sparsity_tooling
 
-keras = tf.keras
+
 test = tf.test
 
 
diff --git a/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/BUILD b/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/BUILD
index 6d05a622e..0c402a982 100644
--- a/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/BUILD
+++ b/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/BUILD
@@ -15,6 +15,7 @@ py_strict_binary(
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster",
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
         "//tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve:cluster_utils",
         "//tensorflow_model_optimization/python/core/quantization/keras/collab_opts/cluster_preserve:default_8bit_cluster_preserve_quantize_scheme",
diff --git a/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/mnist_cnn.py b/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/mnist_cnn.py
index de8e2d1f3..bfe3454e0 100644
--- a/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/mnist_cnn.py
+++ b/tensorflow_model_optimization/python/examples/cluster_preserve_qat/keras/mnist_cnn.py
@@ -24,25 +24,28 @@
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster as tfmot_cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config as tfmot_cluster_config
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import (
     default_8bit_cluster_preserve_quantize_scheme,)
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import (
     strip_clustering_cqat,)
 
-layers = tf.keras.layers
+
+layers = keras.layers
 
 
 def setup_model(input_shape, image_train, label_train):
   """Baseline model."""
-  model = tf.keras.Sequential([
-      tf.keras.layers.InputLayer(input_shape),
-      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
-      tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
-                             activation=tf.nn.relu),
-      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
-      tf.keras.layers.Flatten(),
-      tf.keras.layers.Dense(10)
+  model = keras.Sequential([
+      keras.layers.InputLayer(input_shape),
+      keras.layers.Reshape(target_shape=(28, 28, 1)),
+      keras.layers.Conv2D(
+          filters=12, kernel_size=(3, 3), activation=tf.nn.relu
+      ),
+      keras.layers.MaxPooling2D(pool_size=(2, 2)),
+      keras.layers.Flatten(),
+      keras.layers.Dense(10),
   ])
   compile_and_fit(model, image_train, label_train, 5)
 
@@ -51,12 +54,12 @@ def setup_model(input_shape, image_train, label_train):
 
 def _get_callback(model_dir):
   """Create callbacks for Keras model training."""
-  check_point = tf.keras.callbacks.ModelCheckpoint(
+  check_point = keras.callbacks.ModelCheckpoint(
       save_best_only=True,
       filepath=os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
-      verbose=1)
-  tensorboard = tf.keras.callbacks.TensorBoard(
-      log_dir=model_dir, update_freq=100)
+      verbose=1,
+  )
+  tensorboard = keras.callbacks.TensorBoard(log_dir=model_dir, update_freq=100)
   return [check_point, tensorboard]
 
 
@@ -66,8 +69,8 @@ def compile_and_fit(model,
                     epochs):
   model.compile(
       optimizer='adam',
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-      metrics=['accuracy']
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      metrics=['accuracy'],
   )
 
   callbacks_to_use = _get_callback(model_dir='./logs')
@@ -126,11 +129,13 @@ def evaluate_model_fp32(model, image_test, label_test):
 def print_unique_weights(model):
   """Check Dense and Conv2D layers."""
   for layer in model.layers:
-    if (isinstance(layer, tf.keras.layers.Conv2D)
-        or isinstance(layer, tf.keras.layers.Dense)
-        or isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper)):
+    if (
+        isinstance(layer, keras.layers.Conv2D)
+        or isinstance(layer, keras.layers.Dense)
+        or isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper)
+    ):
       for weights in layer.trainable_weights:
-        np_weights = tf.keras.backend.get_value(weights)
+        np_weights = keras.backend.get_value(weights)
         unique_weights = len(np.unique(np_weights))
         if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
           print(layer.layer.__class__.__name__, ' (', weights.name,
@@ -172,7 +177,7 @@ def evaluate_model(interpreter, test_images, test_labels):
 
 def main(unused_args):
   # Load the MNIST dataset.
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   # Shuffle and split data to generate training and testing datasets
   (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
   # Normalize the input images so that each pixel value is between 0 and 1.
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/BUILD b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/BUILD
index c02c087cd..3c1073ea5 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/BUILD
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/BUILD
@@ -1,5 +1,5 @@
-# Placeholder: load py_binary
 load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library")
+# Placeholder: load py_binary
 
 package(
     default_visibility = ["//visibility:public"],
@@ -67,6 +67,7 @@ py_strict_library(
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster",
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
 
@@ -82,5 +83,6 @@ py_binary(
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster",
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_multiple_cells.py b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_multiple_cells.py
index 3c46d6b38..17eb5647d 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_multiple_cells.py
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_multiple_cells.py
@@ -22,6 +22,7 @@
 
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.examples.clustering.keras.imdb.imdb_utils import cluster_train_eval_strip
 from tensorflow_model_optimization.python.examples.clustering.keras.imdb.imdb_utils import prepare_dataset
 
@@ -33,29 +34,32 @@
 x_train, y_train, x_test, y_test = prepare_dataset()
 
 print("Build a model with the StackedRNNCells with LSTMCell...")
-model = tf.keras.models.Sequential()
+model = keras.models.Sequential()
 
-model.add(tf.keras.layers.Embedding(max_features, 128, input_length=maxlen))
+model.add(keras.layers.Embedding(max_features, 128, input_length=maxlen))
 model.add(
-    tf.keras.layers.RNN(
-        tf.keras.layers.StackedRNNCells(
-            [tf.keras.layers.LSTMCell(128) for _ in range(2)])))
-model.add(tf.keras.layers.Dropout(0.5))
-model.add(tf.keras.layers.Dense(1))
-model.add(tf.keras.layers.Activation("sigmoid"))
+    keras.layers.RNN(
+        keras.layers.StackedRNNCells(
+            [keras.layers.LSTMCell(128) for _ in range(2)]
+        )
+    )
+)
+model.add(keras.layers.Dropout(0.5))
+model.add(keras.layers.Dense(1))
+model.add(keras.layers.Activation("sigmoid"))
 
 test_case = "StackedRNNCells_LSTMCell"
 cluster_train_eval_strip(
     model, x_train, y_train, x_test, y_test, batch_size, test_case)
 
 print("Build a model with the Bidirectional wrapper with LSTM layer...")
-model = tf.keras.models.Sequential()
+model = keras.models.Sequential()
 
-model.add(tf.keras.layers.Embedding(max_features, 128, input_length=maxlen))
-model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)))
-model.add(tf.keras.layers.Dropout(0.5))
-model.add(tf.keras.layers.Dense(1))
-model.add(tf.keras.layers.Activation("sigmoid"))
+model.add(keras.layers.Embedding(max_features, 128, input_length=maxlen))
+model.add(keras.layers.Bidirectional(keras.layers.LSTM(128)))
+model.add(keras.layers.Dropout(0.5))
+model.add(keras.layers.Dense(1))
+model.add(keras.layers.Activation("sigmoid"))
 
 test_case = "Bidirectional_LSTM"
 cluster_train_eval_strip(
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_utils.py b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_utils.py
index 5cd177f87..ace0b6c87 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_utils.py
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/imdb/imdb_utils.py
@@ -20,11 +20,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
+
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
 
-sequence = tf.keras.preprocessing.sequence
+sequence = keras.preprocessing.sequence
 
 
 def prepare_dataset():
@@ -34,9 +36,9 @@ def prepare_dataset():
   maxlen = 100  # cut texts after this number of words
 
   print("Loading data...")
-  (x_train,
-   y_train), (x_test,
-              y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
+  (x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
+      num_words=max_features
+  )
   print(len(x_train), "train sequences")
   print(len(x_test), "test sequences")
 
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD
index 3194028bb..5f14a886e 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD
@@ -25,5 +25,6 @@ py_strict_binary(
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster",
         "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
         "//tensorflow_model_optimization/python/core/clustering/keras:clustering_callbacks",
+        "//tensorflow_model_optimization/python/core/keras:compat",
     ],
 )
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py
index 4da722635..8c83c8ae5 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py
@@ -20,18 +20,19 @@
 """
 
 from __future__ import print_function
+
 import datetime
 import os
 
 from absl import app as absl_app
 from absl import flags
-
 import tensorflow as tf
+
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-keras = tf.keras
 
 FLAGS = flags.FLAGS
 
@@ -69,9 +70,10 @@ def build_sequential_model():
 
 def train_model(model, x_train, y_train, x_test, y_test):
   model.compile(
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       optimizer='adam',
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   # Print the model summary.
   model.summary()
@@ -106,12 +108,13 @@ def cluster_model(model, x_train, y_train, x_test, y_test):
 
   # Use smaller learning rate for fine-tuning
   # clustered model
-  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
+  opt = keras.optimizers.Adam(learning_rate=1e-5)
 
   clustered_model.compile(
-  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-  optimizer=opt,
-  metrics=['accuracy'])
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      optimizer=opt,
+      metrics=['accuracy'],
+  )
 
   # Add callback for tensorboard summaries
   log_dir = os.path.join(
@@ -157,9 +160,10 @@ def test_clustered_model(clustered_model, x_test, y_test):
   # Ensure accuracy persists after stripping the model
   stripped_model = cluster.strip_clustering(loaded_clustered_model)
   stripped_model.compile(
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    optimizer='adam',
-    metrics=['accuracy'])
+      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      optimizer='adam',
+      metrics=['accuracy'],
+  )
 
   # Checking that the stripped model's accuracy matches the clustered model
   score = stripped_model.evaluate(x_test, y_test, verbose=0)
diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_mha.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_mha.py
index feb70962a..e6547c2d2 100644
--- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_mha.py
+++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_mha.py
@@ -16,15 +16,17 @@
 """Train a simple convnet with MultiHeadAttention layer on MNIST dataset
 and cluster it.
 """
+import numpy as np
 import tensorflow as tf
+
 import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.keras.compat import keras
 
-import numpy as np
 
 NUMBER_OF_CLUSTERS = 3
 
 # Load MNIST dataset
-mnist = tf.keras.datasets.mnist
+mnist = keras.datasets.mnist
 (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
 # Normalize the input image so that each pixel value is between 0 to 1.
@@ -32,19 +34,19 @@
 test_images = test_images / 255.0
 
 # define model
-input = tf.keras.layers.Input(shape=(28, 28))
-x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name="mha")(
+input = keras.layers.Input(shape=(28, 28))
+x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')(
     query=input, value=input
 )
-x = tf.keras.layers.Flatten()(x)
-out = tf.keras.layers.Dense(10)(x)
-model = tf.keras.Model(inputs=input, outputs=out)
+x = keras.layers.Flatten()(x)
+out = keras.layers.Dense(10)(x)
+model = keras.Model(inputs=input, outputs=out)
 
 # Train the digit classification model
 model.compile(
-    optimizer="adam",
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    metrics=["accuracy"],
+    optimizer='adam',
+    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    metrics=['accuracy'],
 )
 
 model.fit(
@@ -72,9 +74,9 @@
 
 # `cluster_weights` requires a recompile.
 model_for_clustering.compile(
-    optimizer="adam",
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    metrics=["accuracy"],
+    optimizer='adam',
+    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    metrics=['accuracy'],
 )
 
 model_for_clustering.fit(
@@ -92,9 +94,10 @@
 # Strip clustering from the model
 clustered_model = tfmot.clustering.keras.strip_clustering(model_for_clustering)
 clustered_model.compile(
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
     optimizer='adam',
-    metrics=['accuracy'])
+    metrics=['accuracy'],
+)
 
 score = clustered_model.evaluate(test_images, test_labels, verbose=0)
 print('Stripped clustered model test loss:', score[0])
diff --git a/tensorflow_model_optimization/python/examples/quantization/keras/BUILD b/tensorflow_model_optimization/python/examples/quantization/keras/BUILD
index 71d0fe3c7..2633b1000 100644
--- a/tensorflow_model_optimization/python/examples/quantization/keras/BUILD
+++ b/tensorflow_model_optimization/python/examples/quantization/keras/BUILD
@@ -11,6 +11,7 @@ py_strict_binary(
     deps = [
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
     ],
 )
@@ -25,6 +26,7 @@ py_strict_binary(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
     ],
 )
diff --git a/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn.py b/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn.py
index dc7f0c24e..07d84a4e1 100644
--- a/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn.py
+++ b/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn.py
@@ -23,6 +23,8 @@
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
 batch_size = 128
 num_classes = 10
@@ -32,9 +34,9 @@
 img_rows, img_cols = 28, 28
 
 # the data, shuffled and split between train and test sets
-(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
-if tf.keras.backend.image_data_format() == 'channels_first':
+if keras.backend.image_data_format() == 'channels_first':
   x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
   x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
   input_shape = (1, img_rows, img_cols)
@@ -52,17 +54,21 @@
 print(x_test.shape[0], 'test samples')
 
 # convert class vectors to binary class matrices
-y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
 
-l = tf.keras.layers
+l = keras.layers
 
-model = tf.keras.Sequential([
+model = keras.Sequential([
     quantize.quantize_annotate_layer(
-        l.Conv2D(32, 5, padding='same', activation='relu', input_shape=input_shape)),
+        l.Conv2D(
+            32, 5, padding='same', activation='relu', input_shape=input_shape
+        )
+    ),
     l.MaxPooling2D((2, 2), (2, 2), padding='same'),
     quantize.quantize_annotate_layer(
-        l.Conv2D(64, 5, padding='same', activation='relu')),
+        l.Conv2D(64, 5, padding='same', activation='relu')
+    ),
     l.MaxPooling2D((2, 2), (2, 2), padding='same'),
     l.Flatten(),
     quantize.quantize_annotate_layer(l.Dense(1024, activation='relu')),
@@ -80,9 +86,10 @@
   f.write(str(graph_def))
 
 model.compile(
-    loss=tf.keras.losses.categorical_crossentropy,
-    optimizer=tf.keras.optimizers.Adadelta(),
-    metrics=['accuracy'])
+    loss=keras.losses.categorical_crossentropy,
+    optimizer=keras.optimizers.Adadelta(),
+    metrics=['accuracy'],
+)
 
 model.fit(x_train, y_train,
           batch_size=batch_size,
@@ -95,7 +102,7 @@
 
 # Export to Keras.
 keras_file = '/tmp/quantized_mnist.h5'
-tf.keras.models.save_model(model, keras_file)
+keras.models.save_model(model, keras_file)
 
 # Convert to TFLite model.
 with quantize.quantize_scope():
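The conversion body that follows this `quantize_scope()` line sits outside the hunks shown. For orientation, a sketch of the usual completion, assuming the script reloads the saved model and converts it (the converter flags are standard `tf.lite` API, not taken from the file):

    with quantize.quantize_scope():
      # The custom quantize wrappers must be registered while deserializing.
      loaded = keras.models.load_model(keras_file)
      converter = tf.lite.TFLiteConverter.from_keras_model(loaded)
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      tflite_model = converter.convert()

    with open('/tmp/quantized_mnist.tflite', 'wb') as f:
      f.write(tflite_model)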
diff --git a/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn_cont_quant.py b/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn_cont_quant.py
index 248255b2a..5371feb23 100644
--- a/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn_cont_quant.py
+++ b/tensorflow_model_optimization/python/examples/quantization/keras/mnist_cnn_cont_quant.py
@@ -25,6 +25,8 @@
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
+from tensorflow_model_optimization.python.core.keras.compat import keras
+
 
 batch_size = 128
 num_classes = 10
@@ -34,9 +36,9 @@
 img_rows, img_cols = 28, 28
 
 # the data, shuffled and split between train and test sets
-(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
-if tf.keras.backend.image_data_format() == 'channels_first':
+if keras.backend.image_data_format() == 'channels_first':
   x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
   x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
   input_shape = (1, img_rows, img_cols)
@@ -56,20 +58,21 @@
 print(x_test.shape[0], 'test samples')
 
 # convert class vectors to binary class matrices
-y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
 
-l = tf.keras.layers
+l = keras.layers
 
 keras_file = '/tmp/quantized_mnist.h5'
 if not os.path.exists(keras_file):
-  model = tf.keras.Sequential([
+  model = keras.Sequential([
      # Only the first layer is quantization-aware trained.
       # The rest of the layers are not quantization-aware.
       quantize.quantize_annotate_layer(
           l.Conv2D(
-              32, 5, padding='same', activation='relu',
-              input_shape=input_shape)),
+              32, 5, padding='same', activation='relu', input_shape=input_shape
+          )
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.Conv2D(64, 5, padding='same', activation='relu'),
       l.BatchNormalization(),
@@ -82,9 +85,10 @@
   ])
   model = quantize.quantize_apply(model)
   model.compile(
-      loss=tf.keras.losses.categorical_crossentropy,
-      optimizer=tf.keras.optimizers.Adadelta(),
-      metrics=['accuracy'])
+      loss=keras.losses.categorical_crossentropy,
+      optimizer=keras.optimizers.Adadelta(),
+      metrics=['accuracy'],
+  )
 
   model.fit(
       x_train,
@@ -95,10 +99,10 @@
       validation_data=(x_test, y_test))
 
   # Export to Keras.
-  tf.keras.models.save_model(model, keras_file)
+  keras.models.save_model(model, keras_file)
 
 with quantize.quantize_scope():
-  model = tf.keras.models.load_model(keras_file)
+  model = keras.models.load_model(keras_file)
 
 score = model.evaluate(x_test, y_test, verbose=1)
 print('Test loss:', score[0])
@@ -130,7 +134,7 @@ def calibration_gen():
   # }  # mean, std_dev values for float [0, 1] quantized to [-128, 127]
   # Set the representative dataset for post-training quantization.
 
-  model = tf.keras.models.load_model(keras_file)
+  model = keras.models.load_model(keras_file)
   converter = tf.lite.TFLiteConverter.from_keras_model(model)
 
 converter.representative_dataset = calibration_gen
@@ -174,4 +178,3 @@ def calibration_gen():
 # model. There is no clear metric for quantization quality, but for MNIST,
 # results that differ a lot likely indicate an error in quantization.
 np.testing.assert_allclose(score[1], quantized_score, rtol=0.2, atol=0.2)
-
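For context, the calibration path this script exercises is plain post-training quantization: the converter runs representative inputs through the float model to choose activation ranges. A minimal sketch using the same names (`calibration_gen` and `x_train` are from the script; the sample count is illustrative):

    def calibration_gen():
      # A few hundred samples is usually enough to estimate activation ranges.
      for i in range(100):
        yield [x_train[i:i + 1].astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    quantized_tflite_model = converter.convert()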
diff --git a/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/BUILD b/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/BUILD
index b84874826..c70a3716f 100644
--- a/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/BUILD
+++ b/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/BUILD
@@ -13,6 +13,7 @@ py_strict_binary(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/quantization/keras:quantize",
         "//tensorflow_model_optimization/python/core/quantization/keras/collab_opts/prune_preserve:default_8bit_prune_preserve_quantize_scheme",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
diff --git a/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/mnist_cnn.py b/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/mnist_cnn.py
index 96b7d88ad..603362a09 100644
--- a/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/mnist_cnn.py
+++ b/tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/mnist_cnn.py
@@ -24,6 +24,7 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.prune_preserve import (
     default_8bit_prune_preserve_quantize_scheme,)
@@ -32,26 +33,24 @@
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 
 
-layers = tf.keras.layers
+layers = keras.layers
 
 
 def build_sequential_model(input_shape=(28, 28)):
   num_classes = 12
 
-  return tf.keras.Sequential([
+  return keras.Sequential([
       layers.InputLayer(input_shape=input_shape),
-      layers.Conv2D(32,
-                    5,
-                    padding='same',
-                    activation='relu',
-                    input_shape=input_shape),
+      layers.Conv2D(
+          32, 5, padding='same', activation='relu', input_shape=input_shape
+      ),
       layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
       layers.Conv2D(64, 5, padding='same', activation='relu'),
       layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
       layers.Flatten(),
       layers.Dense(1024, activation='relu'),
       layers.Dropout(0.4),
-      layers.Dense(num_classes, activation='softmax')
+      layers.Dense(num_classes, activation='softmax'),
   ])
 
 
@@ -85,7 +84,7 @@ def evaluate_and_show_sparsity(model, image_test, label_test):
                   prune.pruning_wrapper.PruneLowMagnitude) or isinstance(
                       layer, quantize.quantize_wrapper.QuantizeWrapper):
       for weights in layer.trainable_weights:
-        np_weights = tf.keras.backend.get_value(weights)
+        np_weights = keras.backend.get_value(weights)
         sparsity = 1.0 - np.count_nonzero(np_weights) / float(np_weights.size)
         print(layer.layer.__class__.__name__, ' (', weights.name,
               ') sparsity: ', sparsity)
@@ -145,7 +144,7 @@ def prune_preserve_quantize_model(pruned_model, train_images, train_labels):
 
 def main(unused_args):
   # Load the MNIST dataset.
-  mnist = tf.keras.datasets.mnist
+  mnist = keras.datasets.mnist
   (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
   # data preprocessing
   # normalize the input images so that each pixel value is between 0 and 1.
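The collaborative step in this file (`prune_preserve_quantize_model`) hands a pruned model to quantization-aware training with a scheme that keeps the pruned zeros in place. A sketch of that hand-off, assuming the imported module exports `Default8BitPrunePreserveQuantizeScheme`:

    # Strip the pruning wrappers first, then run QAT with a prune-preserving
    # scheme so fine-tuning does not refill the zeroed weights.
    stripped = prune.strip_pruning(pruned_model)
    annotated = quantize.quantize_annotate_model(stripped)
    quant_model = quantize.quantize_apply(
        annotated,
        scheme=default_8bit_prune_preserve_quantize_scheme
        .Default8BitPrunePreserveQuantizeScheme(),
    )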
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/BUILD b/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/BUILD
index 227770ddf..3c4cb71be 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/BUILD
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/BUILD
@@ -21,6 +21,7 @@ py_strict_binary(
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_callbacks",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/imdb_lstm.py b/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/imdb_lstm.py
index 8c0d6790e..a2da24224 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/imdb_lstm.py
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/imdb/imdb_lstm.py
@@ -21,13 +21,14 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
-keras = tf.keras
-K = tf.keras.backend
+
+K = keras.backend
 
 
 def print_model_sparsity(pruned_model):
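The body of `print_model_sparsity` is not shown in this hunk; a plausible sketch of such a helper, using only the wrapper API imported above (the exact report format is an assumption):

    def print_model_sparsity(pruned_model):
      # Report the fraction of exactly-zero values in each prunable weight.
      for layer in pruned_model.layers:
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
          for weight in layer.layer.get_prunable_weights():
            values = K.get_value(weight)
            sparsity = 1.0 - np.count_nonzero(values) / float(values.size)
            print(layer.layer.name, weight.name, 'sparsity:', sparsity)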
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD
index 083cd426f..283717425 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD
@@ -25,6 +25,7 @@ py_strict_binary(
         # six dep1,
         # tensorflow dep1,
         # tensorflow:tensorflow_compat_v1_estimator dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/sparsity/keras:estimator_utils",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
@@ -45,6 +46,7 @@ py_strict_binary(
         # absl/flags dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_callbacks",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
@@ -62,6 +64,7 @@ py_strict_binary(
         # absl/flags dep1,
         # google/protobuf:use_fast_cpp_protos dep1,  # Automatically added
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_callbacks",
@@ -76,6 +79,7 @@ py_strict_binary(
     deps = [
         # absl:app dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
         "//tensorflow_model_optimization/python/core/sparsity/keras:prune",
         "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_callbacks",
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py
index 7f4ea2188..eb476863b 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py
@@ -18,15 +18,15 @@
 
 from absl import app as absl_app
 from absl import flags
-
 import tensorflow as tf
 
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 
+
 PolynomialDecay = pruning_schedule.PolynomialDecay
-keras = tf.keras
 l = keras.layers
 
 FLAGS = flags.FLAGS
@@ -40,9 +40,10 @@
 
 
 def build_sequential_model(input_shape):
-  return tf.keras.Sequential([
+  return keras.Sequential([
       l.Conv2D(
-          32, 5, padding='same', activation='relu', input_shape=input_shape),
+          32, 5, padding='same', activation='relu', input_shape=input_shape
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.BatchNormalization(),
       l.Conv2D(64, 5, padding='same', activation='relu'),
@@ -50,12 +51,12 @@ def build_sequential_model(input_shape):
       l.Flatten(),
       l.Dense(1024, activation='relu'),
       l.Dropout(0.4),
-      l.Dense(num_classes, activation='softmax')
+      l.Dense(num_classes, activation='softmax'),
   ])
 
 
 def build_functional_model(input_shape):
-  inp = tf.keras.Input(shape=input_shape)
+  inp = keras.Input(shape=input_shape)
   x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
   x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
   x = l.BatchNormalization()(x)
@@ -66,35 +67,40 @@ def build_functional_model(input_shape):
   x = l.Dropout(0.4)(x)
   out = l.Dense(num_classes, activation='softmax')(x)
 
-  return tf.keras.models.Model([inp], [out])
+  return keras.models.Model([inp], [out])
 
 
 def build_layerwise_model(input_shape, **pruning_params):
-  return tf.keras.Sequential([
+  return keras.Sequential([
       prune.prune_low_magnitude(
           l.Conv2D(32, 5, padding='same', activation='relu'),
           input_shape=input_shape,
-          **pruning_params),
+          **pruning_params
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.BatchNormalization(),
       prune.prune_low_magnitude(
-          l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params),
+          l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.Flatten(),
       prune.prune_low_magnitude(
-          l.Dense(1024, activation='relu'), **pruning_params),
+          l.Dense(1024, activation='relu'), **pruning_params
+      ),
       l.Dropout(0.4),
       prune.prune_low_magnitude(
-          l.Dense(num_classes, activation='softmax'), **pruning_params)
+          l.Dense(num_classes, activation='softmax'), **pruning_params
+      ),
   ])
 
 
 def train_and_save(models, x_train, y_train, x_test, y_test):
   for model in models:
     model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
+        loss=keras.losses.categorical_crossentropy,
         optimizer='adam',
-        metrics=['accuracy'])
+        metrics=['accuracy'],
+    )
 
     # Print the model summary.
     model.summary()
@@ -121,9 +127,9 @@ def train_and_save(models, x_train, y_train, x_test, y_test):
     # Export and import the model. Check that accuracy persists.
     saved_model_dir = '/tmp/saved_model'
     print('Saving model to: ', saved_model_dir)
-    tf.keras.models.save_model(model, saved_model_dir, save_format='tf')
+    keras.models.save_model(model, saved_model_dir, save_format='tf')
     print('Loading model from: ', saved_model_dir)
-    loaded_model = tf.keras.models.load_model(saved_model_dir)
+    loaded_model = keras.models.load_model(saved_model_dir)
 
     score = loaded_model.evaluate(x_test, y_test, verbose=0)
     print('Test loss:', score[0])
@@ -135,9 +141,9 @@ def main(unused_argv):
   img_rows, img_cols = 28, 28
 
   # the data, shuffled and split between train and test sets
-  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
-  if tf.keras.backend.image_data_format() == 'channels_first':
+  if keras.backend.image_data_format() == 'channels_first':
     x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
     x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
     input_shape = (1, img_rows, img_cols)
@@ -155,8 +161,8 @@ def main(unused_argv):
   print(x_test.shape[0], 'test samples')
 
   # convert class vectors to binary class matrices
-  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-  y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+  y_train = keras.utils.to_categorical(y_train, num_classes)
+  y_test = keras.utils.to_categorical(y_test, num_classes)
 
   pruning_params = {
       'pruning_schedule':
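The hunk above ends inside the `pruning_params` dict; a typical completion with the `PolynomialDecay` schedule this file aliases (the step counts below are illustrative, not taken from the file):

    pruning_params = {
        'pruning_schedule':
            PolynomialDecay(
                initial_sparsity=0.0,
                final_sparsity=0.5,
                begin_step=0,
                end_step=2000,  # Illustrative; usually steps-per-epoch * epochs.
            ),
    }

    pruned_model = prune.prune_low_magnitude(
        build_sequential_model(input_shape), **pruning_params)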
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py
index 28e744231..adfb275df 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py
@@ -18,16 +18,16 @@
 
 from absl import app as absl_app
 from absl import flags
-
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 
+
 ConstantSparsity = pruning_schedule.ConstantSparsity
-keras = tf.keras
 l = keras.layers
 
 FLAGS = flags.FLAGS
@@ -40,9 +40,10 @@
 
 
 def build_layerwise_model(input_shape, **pruning_params):
-  return tf.keras.Sequential([
+  return keras.Sequential([
       l.Conv2D(
-          32, 5, padding='same', activation='relu', input_shape=input_shape),
+          32, 5, padding='same', activation='relu', input_shape=input_shape
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.Conv2D(64, 5, padding='same'),
       l.BatchNormalization(),
@@ -50,18 +51,21 @@ def build_layerwise_model(input_shape, **pruning_params):
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.Flatten(),
       prune.prune_low_magnitude(
-          l.Dense(1024, activation='relu'), **pruning_params),
+          l.Dense(1024, activation='relu'), **pruning_params
+      ),
       l.Dropout(0.4),
       prune.prune_low_magnitude(
-          l.Dense(num_classes, activation='softmax'), **pruning_params)
+          l.Dense(num_classes, activation='softmax'), **pruning_params
+      ),
   ])
 
 
 def train(model, x_train, y_train, x_test, y_test):
   model.compile(
-      loss=tf.keras.losses.categorical_crossentropy,
+      loss=keras.losses.categorical_crossentropy,
       optimizer='adam',
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
 
   # Print the model summary.
   model.summary()
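Training a `prune_low_magnitude`-wrapped model requires the pruning-step callback; `fit` raises an error without it. A minimal sketch of the loop this script runs, using the callbacks module imported above (the log directory is illustrative):

    callbacks = [
        # Advances the pruning step and applies the schedule every batch.
        pruning_callbacks.UpdatePruningStep(),
        # Optional TensorBoard sparsity summaries.
        pruning_callbacks.PruningSummaries(log_dir='/tmp/pruning_logs'),
    ]
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              callbacks=callbacks,
              validation_data=(x_test, y_test))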
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py
index b5c58f737..0520978f5 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py
@@ -20,18 +20,18 @@
 from __future__ import print_function
 
 from absl import app as absl_app
-
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
+
 ConstantSparsity = pruning_schedule.ConstantSparsity
-keras = tf.keras
 l = keras.layers
 
 tf.random.set_seed(42)
@@ -40,7 +40,7 @@
 num_classes = 10
 epochs = 1
 
-PRUNABLE_2x4_LAYERS = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)
+PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
 
 
 def check_model_sparsity_2x4(model):
@@ -54,30 +54,35 @@ def check_model_sparsity_2x4(model):
 
 
 def build_layerwise_model(input_shape, **pruning_params):
-  return tf.keras.Sequential([
+  return keras.Sequential([
       prune.prune_low_magnitude(
           l.Conv2D(
-              32, 5, padding='same', activation='relu',
-              input_shape=input_shape), **pruning_params),
+              32, 5, padding='same', activation='relu', input_shape=input_shape
+          ),
+          **pruning_params
+      ),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       prune.prune_low_magnitude(
-          l.Conv2D(64, 5, padding='same'), **pruning_params),
+          l.Conv2D(64, 5, padding='same'), **pruning_params
+      ),
       l.BatchNormalization(),
       l.ReLU(),
       l.MaxPooling2D((2, 2), (2, 2), padding='same'),
       l.Flatten(),
       prune.prune_low_magnitude(
-          l.Dense(1024, activation='relu'), **pruning_params),
+          l.Dense(1024, activation='relu'), **pruning_params
+      ),
       l.Dropout(0.4),
-      l.Dense(num_classes, activation='softmax')
+      l.Dense(num_classes, activation='softmax'),
   ])
 
 
 def train(model, x_train, y_train, x_test, y_test):
   model.compile(
-      loss=tf.keras.losses.categorical_crossentropy,
+      loss=keras.losses.categorical_crossentropy,
       optimizer='adam',
-      metrics=['accuracy'])
+      metrics=['accuracy'],
+  )
   model.run_eagerly = True
 
   # Print the model summary.
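A 2x4 pattern means that within every group of four consecutive weights along the reduced axis, at most two are nonzero. A self-contained sketch of a checker in the spirit of `check_model_sparsity_2x4` (the flattening convention is an assumption; the file's `pruning_utils` import likely provides its own):

    import numpy as np

    def weights_are_2x4_sparse(weights):
      # Pad the flattened kernel to a multiple of 4 and verify each group
      # of four values contains at least two zeros.
      flat = weights.reshape(-1)
      flat = np.pad(flat, (0, (-flat.size) % 4))
      groups = flat.reshape(-1, 4)
      return bool(np.all(np.count_nonzero(groups, axis=1) <= 2))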
diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py
index 713c14936..fefd4a8e3 100644
--- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py
+++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py
@@ -19,18 +19,20 @@
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from tensorflow_model_optimization.python.core.keras.compat import keras
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
 
+
 tf.random.set_seed(42)
 
 ConstantSparsity = pruning_schedule.ConstantSparsity
 
 # Load MNIST dataset
-mnist = tf.keras.datasets.mnist
+mnist = keras.datasets.mnist
 (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
 # Normalize the input image so that each pixel value is between 0 and 1.
@@ -38,18 +40,18 @@
 test_images = test_images / 255.0
 
 # define model
-input = tf.keras.layers.Input(shape=(28, 28))
-x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')(
+input = keras.layers.Input(shape=(28, 28))
+x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')(
     query=input, value=input
 )
-x = tf.keras.layers.Flatten()(x)
-out = tf.keras.layers.Dense(10)(x)
-model = tf.keras.Model(inputs=input, outputs=out)
+x = keras.layers.Flatten()(x)
+out = keras.layers.Dense(10)(x)
+model = keras.Model(inputs=input, outputs=out)
 
 # Train the digit classification model
 model.compile(
     optimizer='adam',
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
     metrics=['accuracy'],
 )
 
@@ -81,7 +83,7 @@
 # `prune_low_magnitude` requires a recompile.
 model_for_pruning.compile(
     optimizer='adam',
-    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
     metrics=['accuracy'],
 )
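After fine-tuning, the export path mirrors the clustering example earlier in this diff: strip the wrappers, then save or convert. A short sketch, assuming `model_for_pruning` has been fitted:

    # Remove the pruning wrappers so the exported graph has plain layers
    # whose kernels carry the final (sparse) values.
    final_model = prune.strip_pruning(model_for_pruning)

    score = final_model.evaluate(test_images, test_labels, verbose=0)
    print('Stripped pruned model test accuracy:', score[1])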