diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py index 62d37344b..8a8ed2c1a 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py @@ -28,12 +28,15 @@ from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms - try: - from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top + import keras # pylint: disable=g-import-not-at-top + if hasattr(keras, 'src'): + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member + else: + from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member except ImportError: - # Path as seen in pip packages as of TF/Keras 2.13. - from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + unique_object_name = tf._keras_internal.backend.unique_object_name # pylint: disable=protected-access LayerNode = transforms.LayerNode LayerPattern = transforms.LayerPattern