Skip to content

Commit

Permalink
Optimize the Workflow of Parsing Keras Model (#1623)
Browse files Browse the repository at this point in the history
Signed-off-by: zehao-intel <[email protected]>
  • Loading branch information
zehao-intel authored Feb 23, 2024
1 parent f2d9b78 commit b816d77
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions neural_compressor/model/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,17 +313,10 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_
def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
from tensorflow.python.saved_model import signature_constants, tag_constants

from neural_compressor.adaptor.tf_utils.util import parse_saved_model

saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
saved_model_tags = set([tag_constants.SERVING])
try:
graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
saved_model_dir, True, input_tensor_names, output_tensor_names
)
except:
return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)
return graph_def, input_names, output_names

return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)


def _get_graph_from_original_keras_v2(model, output_dir):
Expand Down Expand Up @@ -467,6 +460,15 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
try:
tf.keras.backend.set_learning_phase(0)
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
except:
keras_format = "saved_model_general"
if keras_format == "saved_model_general": # pargma: no cover
try:
from neural_compressor.adaptor.tf_utils.util import parse_saved_model

graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
temp_dir, True, input_tensor_names, output_tensor_names
)
except:
raise ValueError("Not supported keras model type...")

Expand Down

0 comments on commit b816d77

Please sign in to comment.