diff --git a/neural_compressor/model/tensorflow_model.py b/neural_compressor/model/tensorflow_model.py index 063b95ab052..e4809863a55 100644 --- a/neural_compressor/model/tensorflow_model.py +++ b/neural_compressor/model/tensorflow_model.py @@ -310,7 +310,45 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_ return opt, input_tensor_names, output_tensor_names +def _get_graph_from_saved_model_v3(model, input_tensor_names, output_tensor_names): + """The version 3 function that get graph from saved_model. + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + input_tensor_names (list of string): input tensor names of the model. + output_tensor_names (list of string): output tensor names of the model. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + inputs (list of string): validated input names. + outputs (list of string): validated output names. + """ + from neural_compressor.adaptor.tf_utils.util import parse_saved_model + + if isinstance(model, tf.keras.Model): + tmp_dir = cfg.default_workspace + "/saved_model" + model.save(tmp_dir) + model = tmp_dir + graph_def, _, _, _, input_names, output_names = parse_saved_model( + model, True, input_tensor_names, output_tensor_names + ) + + return graph_def, input_names, output_names + + def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names): + """The version 2 function that get graph from the original keras model. + + Args: + saved_model_dir (string): model path of a temporary saved_model. + input_tensor_names (list of string): input tensor names of the model. + output_tensor_names (list of string): output tensor names of the model. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + input_names (list of string): validated input names. + output_names (list of string): validated output names. + """ from tensorflow.python.saved_model import signature_constants, tag_constants saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] @@ -319,7 +357,17 @@ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_t return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) -def _get_graph_from_original_keras_v2(model, output_dir): +def _get_graph_from_original_keras_v2(model): + """The version 2 function that get graph from the original keras model. + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + input_names (list of string): validated input names. + output_names (list of string): validated output names. + """ from tensorflow.lite.python.convert import OpsSet from tensorflow.lite.python.util import ( get_grappler_config, @@ -364,6 +412,17 @@ def _get_graph_from_original_keras_v2(model, output_dir): def _check_keras_format(model, saved_model_dir): + """Decide which method will be used to get graph from the saved_model . + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + saved_model_dir (string): the path to save a temporary saved_model. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + inputs (list of string): validated input names. + outputs (list of string): validated output names. + """ from tensorflow.python import saved_model from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model.load import load @@ -384,6 +443,16 @@ def _check_keras_format(model, saved_model_dir): def _get_graph_from_saved_model_v1(model): + """The version 1 function that get graph from saved_model. + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + inputs (list of string): validated input names. + outputs (list of string): validated output names. + """ from tensorflow.lite.python.convert_saved_model import get_inputs_outputs, get_meta_graph_def, get_signature_def from tensorflow.python.client import session from tensorflow.python.framework import ops @@ -424,6 +493,51 @@ def _get_graph_from_saved_model_v1(model): return graph_def, inputs, outputs +def try_loading_keras(model, input_tensor_names, output_tensor_names): + """Try different ways of loading keras models. + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + input_tensor_names (list of string): input tensor names of the model. + output_tensor_names (list of string): output tensor names of the model. + + Returns: + graph_def (tf.compat.v1.Session): tf.compat.v1.Session object. + input_names (list of string): validated input names. + output_names (list of string): validated output names. + """ + temp_dir = tempfile.mkdtemp() + if not isinstance(model, tf.keras.Model): + model = tf.keras.models.load_model(model) + keras_format = _check_keras_format(model, temp_dir) + + if keras_format == "saved_model_v2": + try: + graph_def, input_names, output_names = _get_graph_from_saved_model_v2( + temp_dir, input_tensor_names, output_tensor_names + ) + if "_FusedBatchNormEx" in [node.op for node in graph_def.node]: + keras_format = "trackable_object" + except: + keras_format = "trackable_object" + + if keras_format == "trackable_object": + try: + graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model) + except: + keras_format = "saved_model_v1" + + if keras_format == "saved_model_v1": # pragma: no cover + try: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + except: + raise ValueError("Not supported keras model type...") + + shutil.rmtree(temp_dir, True) + return graph_def, input_names, output_names + + def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): """Build session with keras model. @@ -434,49 +548,19 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): Returns: sess (tf.compat.v1.Session): tf.compat.v1.Session object. - input_tensor_names (list of string): validated input_tensor_names. - output_tensor_names (list of string): validated output_tensor_names. """ - temp_dir = tempfile.mkdtemp() if tf.version.VERSION > "2.1.0": - if not isinstance(model, tf.keras.Model): - model = tf.keras.models.load_model(model) - keras_format = _check_keras_format(model, temp_dir) - if keras_format == "saved_model_v2": - try: - graph_def, input_names, output_names = _get_graph_from_saved_model_v2( - temp_dir, input_tensor_names, output_tensor_names - ) - if "_FusedBatchNormEx" in [node.op for node in graph_def.node]: - keras_format = "trackable_object" - except: - keras_format = "trackable_object" - if keras_format == "trackable_object": - try: - graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir) - except: - keras_format = "saved_model_v1" - if keras_format == "saved_model_v1": # pragma: no cover - try: - tf.keras.backend.set_learning_phase(0) - graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) - except: - keras_format = "saved_model_general" - if keras_format == "saved_model_general": # pargma: no cover - try: - from neural_compressor.adaptor.tf_utils.util import parse_saved_model - - graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model( - temp_dir, True, input_tensor_names, output_tensor_names - ) - except: - raise ValueError("Not supported keras model type...") - + try: + graph_def, input_names, output_names = _get_graph_from_saved_model_v3( + model, input_tensor_names, output_tensor_names + ) + except: + graph_def, input_names, output_names = try_loading_keras(model, input_tensor_names, output_tensor_names) # tensorflow 1.x use v1 convert method else: tf.keras.backend.set_learning_phase(0) graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) - shutil.rmtree(temp_dir, True) + return graph_def_session(graph_def, input_names, output_names, **kwargs) @@ -645,12 +729,19 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs output_tensor_names (list of string): validated output_tensor_names. """ try: - graph_def, input_names, output_names = _get_graph_from_saved_model_v2( + graph_def, input_names, output_names = _get_graph_from_saved_model_v3( model, input_tensor_names, output_tensor_names ) except: - graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + try: + graph_def, input_names, output_names = _get_graph_from_saved_model_v2( + model, input_tensor_names, output_tensor_names + ) + except: + graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + assert graph_def is not None, "Can not parse the saved model..." + return graph_def_session(graph_def, input_names, output_names, **kwargs)