Skip to content

Commit

Permalink
keras support multiple input model version (#999)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClarkChin08 authored Sep 5, 2022
1 parent c6568c9 commit 5a6f092
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 114 deletions.
30 changes: 29 additions & 1 deletion neural_compressor/adaptor/tf_utils/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def _inference(self, model):
"""
input_tensor = model.input_tensor
output_tensor = model.output_tensor
# TF table initialization: https://github.com/tensorflow/tensorflow/issues/8665
node_names = [node.name for node in model.sess.graph.as_graph_def().node]
if 'init_all_tables' in node_names:
init_table_op = model.sess.graph.get_operation_by_name('init_all_tables')
model.sess.run(init_table_op)

logger.info("Start sampling on calibration dataset.")
for idx, (inputs, labels) in enumerate(self.data_loader):
Expand Down Expand Up @@ -190,7 +195,30 @@ def _inference(self, model):
feed_dict[tensor] = inputs[name]
break
else:
feed_dict = dict(zip(input_tensor, inputs))
# sometimes the input_tensor is not the same order with inputs
# we should check and pair them
def check_shape(tensor, data):
tensor_shape = tuple(tensor.shape)
data_shape = tuple(data.shape)
for tensor_dim, data_dim in zip(tensor_shape, data_shape):
if tensor_dim is not None and tensor_dim != data_dim:
return False
return True

disorder_tensors = []
disorder_inputs = []
for idx, sort_tensor in enumerate(input_tensor):
sort_input = inputs[idx]
if check_shape(sort_tensor, sort_input):
feed_dict.update({sort_tensor: sort_input})
else:
disorder_tensors.append(sort_tensor)
disorder_inputs.append(sort_input)
for i, dis_tensor in enumerate(disorder_tensors):
for j, dis_input in enumerate(disorder_inputs):
if check_shape(dis_tensor, dis_input):
feed_dict.update({dis_tensor: dis_input})
break
_ = model.sess.run(output_tensor, feed_dict) if model.iter_op==[] \
else iterator_sess_run(model.sess, model.iter_op, \
feed_dict, output_tensor, self.calib_iteration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ class PreOptimization():
def __init__(self, model, optimization, new_api):
self.model = model
self.optimization = optimization
# Table initialization should disable grappler dependency and pruning pass
node_names = [node.name for node in model.graph_def.node]
if 'init_all_tables' in node_names:
self.optimization['dependency'] = False
self.optimization['pruning'] = False
self.new_api = new_api

self.analyzer = GraphAnalyzer()
self.analyzer.graph = model.graph_def
self.analyzer.parse_graph()
Expand Down
3 changes: 3 additions & 0 deletions neural_compressor/adaptor/tf_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def strip_unused_nodes(graph_def, input_node_names, output_node_names):
cur_graph.graph = graph_def
graph_info = cur_graph.parse_graph()
type_attr = {"Sub": "T", "RealDiv": "T", "Identity": "T"}
# this op should not be stripped for table initialization
if 'init_all_tables' in graph_info.keys():
output_node_names.append('init_all_tables')
not_found = {name for name in input_node_names}
for node_name in list(graph_info.keys()):
if node_name in not_found:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def __init__(self, shape, low=-128., high=127., dtype='float32', label=True, \
transform=None, filter=None):

dtype_map = {'float32':np.float32, 'float16':np.float16, 'uint8':np.uint8, \
'int8':np.int8, 'int32':np.int32, 'int64':np.int64, 'bool':np.bool}
'int8': np.int8, 'int32':np.int32, 'int64':np.int64, 'bool':np.bool,\
'string': np.str}

np.random.seed(9527)
self.transform = transform
Expand Down
Loading

0 comments on commit 5a6f092

Please sign in to comment.