diff --git a/neural_compressor/adaptor/tf_utils/graph_converter.py b/neural_compressor/adaptor/tf_utils/graph_converter.py index db94f70415e..3a4e618e862 100644 --- a/neural_compressor/adaptor/tf_utils/graph_converter.py +++ b/neural_compressor/adaptor/tf_utils/graph_converter.py @@ -160,6 +160,11 @@ def _inference(self, model): """ input_tensor = model.input_tensor output_tensor = model.output_tensor + # TF table initialization: https://github.com/tensorflow/tensorflow/issues/8665 + node_names = [node.name for node in model.sess.graph.as_graph_def().node] + if 'init_all_tables' in node_names: + init_table_op = model.sess.graph.get_operation_by_name('init_all_tables') + model.sess.run(init_table_op) logger.info("Start sampling on calibration dataset.") for idx, (inputs, labels) in enumerate(self.data_loader): @@ -190,7 +195,30 @@ def _inference(self, model): feed_dict[tensor] = inputs[name] break else: - feed_dict = dict(zip(input_tensor, inputs)) + # sometimes the input_tensor is not the same order with inputs + # we should check and pair them + def check_shape(tensor, data): + tensor_shape = tuple(tensor.shape) + data_shape = tuple(data.shape) + for tensor_dim, data_dim in zip(tensor_shape, data_shape): + if tensor_dim is not None and tensor_dim != data_dim: + return False + return True + + disorder_tensors = [] + disorder_inputs = [] + for idx, sort_tensor in enumerate(input_tensor): + sort_input = inputs[idx] + if check_shape(sort_tensor, sort_input): + feed_dict.update({sort_tensor: sort_input}) + else: + disorder_tensors.append(sort_tensor) + disorder_inputs.append(sort_input) + for i, dis_tensor in enumerate(disorder_tensors): + for j, dis_input in enumerate(disorder_inputs): + if check_shape(dis_tensor, dis_input): + feed_dict.update({dis_tensor: dis_input}) + break _ = model.sess.run(output_tensor, feed_dict) if model.iter_op==[] \ else iterator_sess_run(model.sess, model.iter_op, \ feed_dict, output_tensor, self.calib_iteration) diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py index ed0a01cd4ad..99a4eb3b6aa 100644 --- a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py +++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py @@ -49,8 +49,12 @@ class PreOptimization(): def __init__(self, model, optimization, new_api): self.model = model self.optimization = optimization + # Table initialization should disable grappler dependency and pruning pass + node_names = [node.name for node in model.graph_def.node] + if 'init_all_tables' in node_names: + self.optimization['dependency'] = False + self.optimization['pruning'] = False self.new_api = new_api - self.analyzer = GraphAnalyzer() self.analyzer.graph = model.graph_def self.analyzer.parse_graph() diff --git a/neural_compressor/adaptor/tf_utils/util.py b/neural_compressor/adaptor/tf_utils/util.py index c112412fc25..6452ac51b17 100644 --- a/neural_compressor/adaptor/tf_utils/util.py +++ b/neural_compressor/adaptor/tf_utils/util.py @@ -302,6 +302,9 @@ def strip_unused_nodes(graph_def, input_node_names, output_node_names): cur_graph.graph = graph_def graph_info = cur_graph.parse_graph() type_attr = {"Sub": "T", "RealDiv": "T", "Identity": "T"} + # this op should not be stripped for table initialization + if 'init_all_tables' in graph_info.keys(): + output_node_names.append('init_all_tables') not_found = {name for name in input_node_names} for node_name in list(graph_info.keys()): if node_name in not_found: diff --git a/neural_compressor/experimental/data/datasets/dummy_dataset.py b/neural_compressor/experimental/data/datasets/dummy_dataset.py index cc8de47c1b4..878c8d3136d 100644 --- a/neural_compressor/experimental/data/datasets/dummy_dataset.py +++ b/neural_compressor/experimental/data/datasets/dummy_dataset.py @@ -60,7 +60,8 @@ def __init__(self, shape, low=-128., high=127., dtype='float32', label=True, \ transform=None, filter=None): dtype_map = {'float32':np.float32, 'float16':np.float16, 'uint8':np.uint8, \ - 'int8':np.int8, 'int32':np.int32, 'int64':np.int64, 'bool':np.bool} + 'int8': np.int8, 'int32':np.int32, 'int64':np.int64, 'bool':np.bool,\ + 'string': np.str} np.random.seed(9527) self.transform = transform diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index b6497f584d0..c997163a847 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -17,9 +17,11 @@ import copy import os +import shutil import importlib from collections import OrderedDict from abc import abstractmethod +import tempfile from neural_compressor.utils.utility import LazyImport, compute_sparsity, get_backend from neural_compressor.utils.utility import version1_lt_version2, version1_gt_version2, version1_gte_version2 from neural_compressor.utils import logger @@ -313,6 +315,208 @@ def frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs): return graph_def_session(graph_def, input_tensor_names, \ output_tensor_names, **kwargs) +def _contains_function_with_implements_attr(saved_model_proto): + meta_graph = saved_model_proto.meta_graphs[0] + for function in meta_graph.graph_def.library.function: + if function.attr.get("_implements", None) or function.attr.get( + "api_implements", None): + return True + return False + +def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names): + """Load graph_def from saved model with the default serving signature key. + + Args: + saved_model_dir: Directory of the SavedModel. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. + + Returns: + graph_def: The loaded GraphDef. + input_tensors: List of input tensors. + output_tensors: List of output tensors. + """ + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + if get_backend() == 'tensorflow_itex_qdq': + from tensorflow.core.protobuf import rewriter_config_pb2 + config.graph_options.rewrite_options.constant_folding = \ + rewriter_config_pb2.RewriterConfig.OFF + if not os.listdir(os.path.join(model,'variables')): + sess = tf.compat.v1.Session(graph=tf.Graph(), config=config) + loader = tf.compat.v1.saved_model.loader.load(sess, ["serve"], model) + if len(input_tensor_names) == 0: + input_tensor_names = [i.name for _, i in \ + loader.signature_def['serving_default'].inputs.items()] + else: + assert validate_graph_node(\ + sess.graph.as_graph_def(), tensor_to_node(input_tensor_names)), \ + 'tensor names {} not in the graph'.format(input_tensor_names) + + if len(output_tensor_names) == 0: + output_tensor_names = [i.name for _, i in \ + loader.signature_def['serving_default'].outputs.items()] + else: + assert validate_graph_node(\ + sess.graph.as_graph_def(), tensor_to_node(output_tensor_names)), \ + 'tensor names {} not in the graph'.format(output_tensor_names) + + return sess.graph.as_graph_def(), input_tensor_names, output_tensor_names + else: + from tensorflow.python.eager import context + from tensorflow.python.saved_model import load + from tensorflow.python.saved_model import tag_constants + from tensorflow.python.saved_model import signature_constants + from tensorflow.python.framework.convert_to_constants import \ + convert_variables_to_constants_v2 + from tensorflow.python.training import saver + from tensorflow.core.protobuf import config_pb2 + from tensorflow.python.grappler import tf_optimizer + from tensorflow.core.protobuf import meta_graph_pb2 + _saved_model = load.load(model, [tag_constants.SERVING]) + func = _saved_model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + frozen_func = convert_variables_to_constants_v2(func) + grappler_meta_graph_def = saver.export_meta_graph( + graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph) + if len(input_tensor_names) == 0: + input_tensor_names = [i.name.split(':')[0] for i in frozen_func.inputs] + if len(output_tensor_names) == 0: + output_tensor_names = [i.name.split(':')[0] for i in frozen_func.outputs] + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.inputs + frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + grappler_meta_graph_def.collection_def["train_op"].CopyFrom( + fetch_collection) + from tensorflow.python.eager import context + grappler_session_config = config_pb2.ConfigProto() + rewrite_options = grappler_session_config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 + opt = tf_optimizer.OptimizeGraph(grappler_session_config, + grappler_meta_graph_def, graph_id=b"tf_graph") + return opt, input_tensor_names, output_tensor_names + +def check_keras_format(model, saved_model_dir): + from tensorflow.python import saved_model + from tensorflow.python.saved_model.load import load + from tensorflow.python.saved_model import save_options + from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info + version = 'saved_model_v2' + try: + saved_model.save( + model, + saved_model_dir, + options=save_options.SaveOptions(save_debug_info=True)) + except: + return 'trackable_object' + saved_model_proto, _ = parse_saved_model_with_debug_info(saved_model_dir) + saved_model_version = saved_model_proto.saved_model_schema_version + if saved_model_version == 0: + return 'saved_model_v1' + if saved_model_version not in [1, 2]: + raise ValueError("SavedModel file format({0}) is not supported".format( + saved_model_version)) + return version + +def get_graph_from_saved_model_v2(saved_model_dir, + input_tensor_names, output_tensor_names): + from tensorflow.python.saved_model import tag_constants + from tensorflow.python.saved_model import signature_constants + saved_model_exported_names = [ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + ] + saved_model_tags = set([tag_constants.SERVING]) + return load_saved_model(saved_model_dir, saved_model_tags, + input_tensor_names, output_tensor_names) + +def get_graph_from_original_keras_v2(model, output_dir): + from tensorflow.python.eager import def_function + from tensorflow.lite.python.util import trace_model_call + from tensorflow.lite.python.util import model_input_signature + from tensorflow.python.framework import convert_to_constants + from tensorflow.python.framework import dtypes + from tensorflow.lite.python.util import run_graph_optimizations + from tensorflow.lite.python.convert import OpsSet + from tensorflow.lite.python.util import get_grappler_config + input_signature = None + # If the model's call is not a `tf.function`, then we need to first get its + # input signature from `model_input_signature` method. + if not isinstance(model.call, def_function.Function): + input_signature = model_input_signature(model, keep_original_batch_size=False) + + func = trace_model_call(model, input_signature) + concrete_func = func.get_concrete_function() + funcs = [concrete_func] + + frozen_func, graph_def = ( + convert_to_constants.convert_variables_to_constants_v2_as_graph( + funcs[0], lower_control_flow=False)) + + input_tensors = [ + tensor for tensor in frozen_func.inputs + if tensor.dtype != dtypes.resource + ] + output_tensors = frozen_func.outputs + # Grappler will also try to lower while loop into switch merge + # representation which is undesired for Ophints, so we simply remove + # those attributes to prevent Grappler from doing so. + graph = convert_to_constants.disable_lower_using_switch_merge(graph_def) + # Run function inlining optimization to ensure any models generated + # through the from_frozen_graph path have been inlined. + # grappler_config = get_grappler_config(['function']) + # graph_def = run_graph_optimizations( + # graph, + # input_tensors, + # output_tensors, + # config=grappler_config) + input_names = [tensor.name.split(':')[0] for tensor in input_tensors] + output_names = [tensor.name.split(':')[0] for tensor in output_tensors] + return graph_def, input_names, output_names + +def get_graph_from_saved_model_v1(model): + from tensorflow.python.framework import ops + from tensorflow.python.saved_model import constants + from tensorflow.python.client import session + from tensorflow.python.saved_model import tag_constants + from tensorflow.python.saved_model import signature_constants + from tensorflow.lite.python.convert_saved_model import get_meta_graph_def + from tensorflow.lite.python.convert_saved_model import get_signature_def + from tensorflow.lite.python.convert_saved_model import get_inputs_outputs + saved_model_tags = set([tag_constants.SERVING]) + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + meta_graph = get_meta_graph_def(model, saved_model_tags) + signature_def = get_signature_def(meta_graph, signature_key) + inputs, outputs = get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + + from tensorflow.python.saved_model import loader + from tensorflow.python.framework import graph_util as tf_graph_util + graph = ops.Graph() + import tensorflow as tf + with session.Session(graph=graph) as sess: + loader.load(sess, meta_graph.meta_info_def.tags, model) + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) + output_nodes = list(set([output.split(':')[0] for output in outputs])) + node_ops = [node.op for node in graph.as_graph_def().node] + if 'MakeIterator' in node_ops: + output_nodes.append('MakeIterator') + table_ops = tf.compat.v1.get_collection( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) + # For table initialization + for table_op in table_ops: + output_nodes.append(table_op.name) + if len(table_ops) > 0: + output_nodes.append('init_all_tables') + graph_def = tf_graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), output_nodes) + return graph_def, inputs, outputs + def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): """Build session with keras model @@ -326,54 +530,37 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): input_tensor_names (list of string): validated input_tensor_names output_tensor_names (list of string): validated output_tensor_names """ - - assert version1_gte_version2(tf.version.VERSION, '2.3.0'), 'keras model need tensorflow version >= 2.3.0....' - from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 - if not isinstance(model, tf.keras.Model): - model = tf.keras.models.load_model(model) - kwargs = dict(zip(model.input_names, model.inputs)) - if version1_gt_version2(tf.version.VERSION, '2.2.0') and version1_lt_version2(tf.version.VERSION, '2.5.0'): - from tensorflow.python.keras.engine import keras_tensor - if keras_tensor.keras_tensors_enabled(): - for name, tensor in kwargs.items(): - kwargs[name] = tensor.type_spec - elif version1_gte_version2(tf.version.VERSION, '2.5.0'): - for name, tensor in kwargs.items(): - kwargs[name] = tensor.type_spec - full_model = tf.function(lambda **kwargs: model(kwargs.values(), training=False)) - concrete_function = full_model.get_concrete_function(**kwargs) - frozen_model = convert_variables_to_constants_v2(concrete_function) - - from tensorflow.python.training import saver - from tensorflow.core.protobuf import config_pb2 - from tensorflow.python.grappler import tf_optimizer - from tensorflow.core.protobuf import meta_graph_pb2 - graph_def = frozen_model.graph.as_graph_def() - input_names = [node.name for node in graph_def.node if node.op == 'Placeholder'] - output_names = [output.split(':')[0] for output in model.output_names] - # replace the output name with squential - for output_name in output_names: - for node in graph_def.node[::-1]: - if node.op == 'Identity' and output_name in node.input[0]: - node.name = output_name - break - - grappler_meta_graph_def = saver.export_meta_graph( - graph_def=graph_def, graph=frozen_model.graph) - - # Add a collection 'train_op' so that Grappler knows the outputs. - fetch_collection = meta_graph_pb2.CollectionDef() - for array in model.output_names: - fetch_collection.node_list.value.append(array) - grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - fetch_collection) - grappler_session_config = config_pb2.ConfigProto() - rewrite_options = grappler_session_config.graph_options.rewrite_options - rewrite_options.optimizers.append('constfold') - rewrite_options.min_graph_nodes = -1 - graph_def = tf_optimizer.OptimizeGraph(grappler_session_config, \ - grappler_meta_graph_def, graph_id=b"tf_graph") - + temp_dir = tempfile.mkdtemp() + if tf.version.VERSION > '2.1.0': + if not isinstance(model, tf.keras.Model): + model = tf.keras.models.load_model(model) + keras_format = check_keras_format(model, temp_dir) + if keras_format == 'saved_model_v2': + try: + graph_def, input_names, output_names = get_graph_from_saved_model_v2( + temp_dir, input_tensor_names, output_tensor_names) + if '_FusedBatchNormEx' in [node.op for node in graph_def.node]: + keras_format = 'trackable_object' + except: + keras_format = 'trackable_object' + if keras_format == 'trackable_object': + try: + graph_def, input_names, output_names = get_graph_from_original_keras_v2( + model, temp_dir) + except: + keras_format = 'saved_model_v1' + if keras_format == 'saved_model_v1': + try: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) + except: + raise ValueError('Not supported keras model type...') + + # tensorflow 1.x use v1 convert method + else: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) + shutil.rmtree(temp_dir, True) return graph_def_session(graph_def, input_names, output_names, **kwargs) def slim_session(model, input_tensor_names, output_tensor_names, **kwargs): @@ -528,66 +715,13 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs input_tensor_names (list of string): validated input_tensor_names output_tensor_names (list of string): validated output_tensor_names """ - config = tf.compat.v1.ConfigProto() - config.use_per_session_threads = 1 - config.inter_op_parallelism_threads = 1 - if get_backend() == 'tensorflow_itex': - from tensorflow.core.protobuf import rewriter_config_pb2 - config.graph_options.rewrite_options.constant_folding = \ - rewriter_config_pb2.RewriterConfig.OFF - if not os.listdir(os.path.join(model,'variables')): - sess = tf.compat.v1.Session(graph=tf.Graph(), config=config) - loader = tf.compat.v1.saved_model.loader.load(sess, ["serve"], model) - if len(input_tensor_names) == 0: - input_tensor_names = [i.name for _, i in \ - loader.signature_def['serving_default'].inputs.items()] - else: - assert validate_graph_node(\ - sess.graph.as_graph_def(), tensor_to_node(input_tensor_names)), \ - 'tensor names {} not in the graph'.format(input_tensor_names) - - if len(output_tensor_names) == 0: - output_tensor_names = [i.name for _, i in \ - loader.signature_def['serving_default'].outputs.items()] - else: - assert validate_graph_node(\ - sess.graph.as_graph_def(), tensor_to_node(output_tensor_names)), \ - 'tensor names {} not in the graph'.format(output_tensor_names) - - return sess, input_tensor_names, output_tensor_names - else: - from tensorflow.python.eager import context - from tensorflow.python.saved_model import load - from tensorflow.python.saved_model import tag_constants - from tensorflow.python.saved_model import signature_constants - from tensorflow.python.framework.convert_to_constants import \ - convert_variables_to_constants_v2 - from tensorflow.python.training import saver - from tensorflow.core.protobuf import config_pb2 - from tensorflow.python.grappler import tf_optimizer - from tensorflow.core.protobuf import meta_graph_pb2 - _saved_model = load.load(model, [tag_constants.SERVING]) - func = _saved_model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - frozen_func = convert_variables_to_constants_v2(func) - grappler_meta_graph_def = saver.export_meta_graph( - graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph) - if len(input_tensor_names) == 0: - input_tensor_names = [i.name.split(':')[0] for i in frozen_func.inputs] - if len(output_tensor_names) == 0: - output_tensor_names = [i.name.split(':')[0] for i in frozen_func.outputs] - # Add a collection 'train_op' so that Grappler knows the outputs. - fetch_collection = meta_graph_pb2.CollectionDef() - for array in frozen_func.inputs + frozen_func.outputs: - fetch_collection.node_list.value.append(array.name) - grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - fetch_collection) - from tensorflow.python.eager import context - grappler_session_config = config_pb2.ConfigProto() - rewrite_options = grappler_session_config.graph_options.rewrite_options - rewrite_options.min_graph_nodes = -1 - opt = tf_optimizer.OptimizeGraph(grappler_session_config, - grappler_meta_graph_def, graph_id=b"tf_graph") - return graph_def_session(opt, input_tensor_names, output_tensor_names, **kwargs) + try: + graph_def, input_names, output_names = get_graph_from_saved_model_v2( + model, input_tensor_names, output_tensor_names) + except: + graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) + assert graph_def is not None, 'Can not parse the saved model...' + return graph_def_session(graph_def, input_names, output_names, **kwargs) # it's necessary that a session with input output tensors to run the model SESSIONS = {'frozen_pb': frozen_pb_session, @@ -944,9 +1078,7 @@ def graph_def(self, graph_def): 'saved_model': TensorflowSavedModelModel, 'keras': TensorflowSavedModelModel,} - -class TensorflowModel(object): - +class TensorflowModel(object): def __new__(cls, model_type, root, **kwargs): model = TENSORFLOW_MODELS[model_type](root, **kwargs) model.model_type = model_type