diff --git a/neural_compressor/model/tensorflow_model.py b/neural_compressor/model/tensorflow_model.py index c9a92220b5d..063b95ab052 100644 --- a/neural_compressor/model/tensorflow_model.py +++ b/neural_compressor/model/tensorflow_model.py @@ -313,17 +313,10 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names): from tensorflow.python.saved_model import signature_constants, tag_constants - from neural_compressor.adaptor.tf_utils.util import parse_saved_model - saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] saved_model_tags = set([tag_constants.SERVING]) - try: - graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model( - saved_model_dir, True, input_tensor_names, output_tensor_names - ) - except: - return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) - return graph_def, input_names, output_names + + return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) def _get_graph_from_original_keras_v2(model, output_dir): @@ -467,6 +460,15 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): try: tf.keras.backend.set_learning_phase(0) graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + except: + keras_format = "saved_model_general" + if keras_format == "saved_model_general": # pargma: no cover + try: + from neural_compressor.adaptor.tf_utils.util import parse_saved_model + + graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model( + temp_dir, True, input_tensor_names, output_tensor_names + ) except: raise ValueError("Not supported keras model type...")