From 4735be168ba65f9bb4d74d08450fe6c600f252a2 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 7 Sep 2023 08:31:08 -0700 Subject: [PATCH] Remove private Keras imports. PiperOrigin-RevId: 563439874 --- .../keras/default_8bit/default_8bit_transforms.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index 58ec82303..c6feb138a 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -29,10 +29,17 @@ from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms try: - from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top + # OSS + import keras # pylint: disable=g-import-not-at-top + if hasattr(keras, 'src'): + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + else: + from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member except ImportError: - # Path as seen in pip packages as of TF/Keras 2.13. - from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + # Internal + unique_object_name = tf._keras_internal.backend.unique_object_name # pylint: disable=protected-access + LayerNode = transforms.LayerNode LayerPattern = transforms.LayerPattern