From 077c69156bcb19326e9caa3ff73193a6a0c15b8c Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Sat, 3 Feb 2024 05:03:14 -0800 Subject: [PATCH] Force to use the keras v2 version to resolve breakages in OSS PiperOrigin-RevId: 603926226 --- .../python/core/keras/compat.py | 18 +++++++------- .../python/core/quantization/keras/BUILD | 3 --- .../core/sparsity/keras/prune_registry.py | 2 -- .../examples/sparsity/keras/mnist/BUILD | 24 ------------------- 4 files changed, 8 insertions(+), 39 deletions(-) diff --git a/tensorflow_model_optimization/python/core/keras/compat.py b/tensorflow_model_optimization/python/core/keras/compat.py index 67b5a30b3..034eca897 100644 --- a/tensorflow_model_optimization/python/core/keras/compat.py +++ b/tensorflow_model_optimization/python/core/keras/compat.py @@ -19,25 +19,23 @@ from __future__ import print_function import collections +import os import weakref import tensorflow as tf def _get_keras_instance(): - from pkg_resources import parse_version - - required_tensorflow_version = '2.16.0' - if parse_version(tf.__version__) < parse_version(required_tensorflow_version): - return tf.keras + # Keep using keras-2 (tf-keras) rather than keras-3 (keras). + os.environ['TF_USE_LEGACY_KERAS'] = '1' + # Use Keras 2. version_fn = getattr(tf.keras, 'version', None) if version_fn and version_fn().startswith('3.'): - try: - import tf_keras as keras - except ImportError: - pass - return tf.keras + import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import + else: + keras_internal = tf.keras + return keras_internal keras = _get_keras_instance() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/BUILD index ac0a4c856..5a4d32e9c 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/BUILD +++ b/tensorflow_model_optimization/python/core/quantization/keras/BUILD @@ -326,7 +326,6 @@ py_strict_test( srcs = ["quantize_models_test.py"], flaky = True, python_version = "PY3", - shard_count = 10, deps = [ ":quantize", ":utils", @@ -343,8 +342,6 @@ py_strict_test( size = "large", srcs = ["quantize_functional_test.py"], python_version = "PY3", - # To match parallel runs of run_all_keras_modes. - shard_count = 4, deps = [ ":quantize", ":utils", diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index 82e2fcef0..b3dd74be6 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -21,8 +21,6 @@ try: - # OSS case. - import keras # pylint: disable=g-import-not-at-top if hasattr(keras, 'src'): # Path as seen in pip packages as of TF/Keras 2.13. from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD index 283717425..a26fad0f6 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD @@ -11,30 +11,6 @@ filegroup( srcs = glob(["**"]), ) -py_strict_binary( - name = "mnist_estimator", - srcs = [ - "dataset.py", - "mnist_estimator.py", - ], - python_version = "PY3", - deps = [ - # absl/flags dep1, - # google/protobuf:use_fast_cpp_protos dep1, # Automatically added - # numpy dep1, - # six dep1, - # tensorflow dep1, - # tensorflow:tensorflow_compat_v1_estimator dep1, - "//tensorflow_model_optimization/python/core/keras:compat", - "//tensorflow_model_optimization/python/core/sparsity/keras:estimator_utils", - "//tensorflow_model_optimization/python/core/sparsity/keras:prune", - "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule", - "//third_party/tensorflow_models/official/common:distribute_utils", - "//third_party/tensorflow_models/official/r1/utils/logs:hooks_helper", - "//third_party/tensorflow_models/official/utils", - ], -) - py_strict_binary( name = "mnist_cnn", srcs = [