Skip to content

Commit

Permalink
Fix TF saved_model issues (#1659)
Browse files Browse the repository at this point in the history
Signed-off-by: zehao-intel <[email protected]>
  • Loading branch information
zehao-intel authored Mar 12, 2024
1 parent d07175c commit d8e60b8
Showing 1 changed file with 131 additions and 40 deletions.
171 changes: 131 additions & 40 deletions neural_compressor/model/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,45 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_
return opt, input_tensor_names, output_tensor_names


def _get_graph_from_saved_model_v3(model, input_tensor_names, output_tensor_names):
"""The version 3 function that get graph from saved_model.
Args:
model (string or tf.keras.Model): model path or tf.keras.Model object.
input_tensor_names (list of string): input tensor names of the model.
output_tensor_names (list of string): output tensor names of the model.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
inputs (list of string): validated input names.
outputs (list of string): validated output names.
"""
from neural_compressor.adaptor.tf_utils.util import parse_saved_model

if isinstance(model, tf.keras.Model):
tmp_dir = cfg.default_workspace + "/saved_model"
model.save(tmp_dir)
model = tmp_dir
graph_def, _, _, _, input_names, output_names = parse_saved_model(
model, True, input_tensor_names, output_tensor_names
)

return graph_def, input_names, output_names


def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
"""The version 2 function that get graph from the original keras model.
Args:
saved_model_dir (string): model path of a temporary saved_model.
input_tensor_names (list of string): input tensor names of the model.
output_tensor_names (list of string): output tensor names of the model.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
input_names (list of string): validated input names.
output_names (list of string): validated output names.
"""
from tensorflow.python.saved_model import signature_constants, tag_constants

saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
Expand All @@ -319,7 +357,17 @@ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_t
return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)


def _get_graph_from_original_keras_v2(model, output_dir):
def _get_graph_from_original_keras_v2(model):
"""The version 2 function that get graph from the original keras model.
Args:
model (string or tf.keras.Model): model path or tf.keras.Model object.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
input_names (list of string): validated input names.
output_names (list of string): validated output names.
"""
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.util import (
get_grappler_config,
Expand Down Expand Up @@ -364,6 +412,17 @@ def _get_graph_from_original_keras_v2(model, output_dir):


def _check_keras_format(model, saved_model_dir):
"""Decide which method will be used to get graph from the saved_model .
Args:
model (string or tf.keras.Model): model path or tf.keras.Model object.
saved_model_dir (string): the path to save a temporary saved_model.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
inputs (list of string): validated input names.
outputs (list of string): validated output names.
"""
from tensorflow.python import saved_model
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model.load import load
Expand All @@ -384,6 +443,16 @@ def _check_keras_format(model, saved_model_dir):


def _get_graph_from_saved_model_v1(model):
"""The version 1 function that get graph from saved_model.
Args:
model (string or tf.keras.Model): model path or tf.keras.Model object.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
inputs (list of string): validated input names.
outputs (list of string): validated output names.
"""
from tensorflow.lite.python.convert_saved_model import get_inputs_outputs, get_meta_graph_def, get_signature_def
from tensorflow.python.client import session
from tensorflow.python.framework import ops
Expand Down Expand Up @@ -424,6 +493,51 @@ def _get_graph_from_saved_model_v1(model):
return graph_def, inputs, outputs


def try_loading_keras(model, input_tensor_names, output_tensor_names):
"""Try different ways of loading keras models.
Args:
model (string or tf.keras.Model): model path or tf.keras.Model object.
input_tensor_names (list of string): input tensor names of the model.
output_tensor_names (list of string): output tensor names of the model.
Returns:
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
input_names (list of string): validated input names.
output_names (list of string): validated output names.
"""
temp_dir = tempfile.mkdtemp()
if not isinstance(model, tf.keras.Model):
model = tf.keras.models.load_model(model)
keras_format = _check_keras_format(model, temp_dir)

if keras_format == "saved_model_v2":
try:
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
temp_dir, input_tensor_names, output_tensor_names
)
if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
keras_format = "trackable_object"
except:
keras_format = "trackable_object"

if keras_format == "trackable_object":
try:
graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model)
except:
keras_format = "saved_model_v1"

if keras_format == "saved_model_v1": # pragma: no cover
try:
tf.keras.backend.set_learning_phase(0)
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
except:
raise ValueError("Not supported keras model type...")

shutil.rmtree(temp_dir, True)
return graph_def, input_names, output_names


def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
"""Build session with keras model.
Expand All @@ -434,49 +548,19 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
Returns:
sess (tf.compat.v1.Session): tf.compat.v1.Session object.
input_tensor_names (list of string): validated input_tensor_names.
output_tensor_names (list of string): validated output_tensor_names.
"""
temp_dir = tempfile.mkdtemp()
if tf.version.VERSION > "2.1.0":
if not isinstance(model, tf.keras.Model):
model = tf.keras.models.load_model(model)
keras_format = _check_keras_format(model, temp_dir)
if keras_format == "saved_model_v2":
try:
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
temp_dir, input_tensor_names, output_tensor_names
)
if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
keras_format = "trackable_object"
except:
keras_format = "trackable_object"
if keras_format == "trackable_object":
try:
graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir)
except:
keras_format = "saved_model_v1"
if keras_format == "saved_model_v1": # pragma: no cover
try:
tf.keras.backend.set_learning_phase(0)
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
except:
keras_format = "saved_model_general"
if keras_format == "saved_model_general": # pargma: no cover
try:
from neural_compressor.adaptor.tf_utils.util import parse_saved_model

graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
temp_dir, True, input_tensor_names, output_tensor_names
)
except:
raise ValueError("Not supported keras model type...")

try:
graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
model, input_tensor_names, output_tensor_names
)
except:
graph_def, input_names, output_names = try_loading_keras(model, input_tensor_names, output_tensor_names)
# tensorflow 1.x use v1 convert method
else:
tf.keras.backend.set_learning_phase(0)
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
shutil.rmtree(temp_dir, True)

return graph_def_session(graph_def, input_names, output_names, **kwargs)


Expand Down Expand Up @@ -645,12 +729,19 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs
output_tensor_names (list of string): validated output_tensor_names.
"""
try:
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
model, input_tensor_names, output_tensor_names
)
except:
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
try:
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
model, input_tensor_names, output_tensor_names
)
except:
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)

assert graph_def is not None, "Can not parse the saved model..."

return graph_def_session(graph_def, input_names, output_names, **kwargs)


Expand Down

0 comments on commit d8e60b8

Please sign in to comment.