diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index 58ec82303..c6feb138a 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -29,10 +29,17 @@ from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms try: - from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top + # OSS + import keras # pylint: disable=g-import-not-at-top + if hasattr(keras, 'src'): + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + else: + from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member except ImportError: - # Path as seen in pip packages as of TF/Keras 2.13. - from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + # Internal + unique_object_name = tf._keras_internal.backend.unique_object_name # pylint: disable=protected-access + LayerNode = transforms.LayerNode LayerPattern = transforms.LayerPattern