From 3df6478686d5405ba962d1e0f5c3c46ccec810c6 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Fri, 28 Apr 2023 17:15:39 +0800 Subject: [PATCH] Add mul absorbing for smooth quant of onnxrt adaptor (#807) Signed-off-by: Mengni Wang --- neural_compressor/adaptor/onnxrt.py | 22 +++-- .../adaptor/ox_utils/calibration.py | 14 +++- neural_compressor/adaptor/ox_utils/util.py | 82 ++++++++++++++++++- neural_compressor/algorithm/smooth_quant.py | 2 +- .../onnxrt_adaptor/test_adaptor_onnxrt.py | 34 ++++++-- 5 files changed, 130 insertions(+), 24 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 32737b8c328..7cd708ddd3f 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -152,8 +152,8 @@ def __init__(self, framework_specific_info): self.optype_statistics = None - def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, folding=False, - percentile=99.999, op_types=['MatMul', 'Linear', 'Conv'], scales_per_op=True): + def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, percentile=99.999, + op_types=['FusedConv', 'MatMul', 'Linear', 'Conv'], scales_per_op=True, **kwargs): """Get augmented model with smooth quant. Args: @@ -162,7 +162,6 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, foldi iterations: iterations tune_cfg: quantization config alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ - folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant percentile:Percentile of calibration to remove outliers op_types: The op types whose input tensor will be dumped scales_per_op: True, each op will have an individual scale, mainly for accuracy @@ -173,8 +172,10 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, foldi """ if self.smooth_quant_model is not None: return self.smooth_quant_model - from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment + from onnx import numpy_helper + from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment + from neural_compressor.adaptor.ox_utils.util import fold_scale if isinstance(alpha, str): logger.warning(f"onnx backend only support float alpha, reset alpha to 0.5 ") alpha = 0.5 @@ -199,14 +200,17 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, foldi for name in max_vals_per_channel.keys(): curr_tensor_to_weight = [] curr_tensor_to_weight_nodes = [] - nodes = self.pre_optimized_model.input_name_to_nodes[name] + nodes = [i for i in self.pre_optimized_model.nodes() if name in i.input] for node in nodes: if node.op_type not in op_types: continue if len(node.input) >= 2: input = node.input[1] ##TODO always dump the index 1 to get the weight if self.pre_optimized_model.get_initializer(input): - weight = numpy_helper.to_array(self.pre_optimized_model.get_initializer(input)) + weight = numpy_helper.to_array(self.pre_optimized_model.get_initializer(input), + os.path.dirname(self.pre_optimized_model.model_path)) if \ + self.pre_optimized_model.model_path is not None else \ + numpy_helper.to_array(self.pre_optimized_model.get_initializer(input)) curr_tensor_to_weight.append(weight) curr_tensor_to_weight_nodes.append(node) input_tensors_2_weights[name] = curr_tensor_to_weight @@ -233,6 +237,9 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, foldi self.pre_optimized_model.update() self.pre_optimized_model.topological_sort() 
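+        # fold_scale (added to ox_utils/util.py in this patch) then absorbs the
+        # inserted "_smooth_mul" Mul nodes into supported parent ops such as
+        # LayerNormalization, MatMul and Conv whenever the parent has a single
+        # consumer, so no extra Mul is left in the inference graph.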
self.pre_optimized_model.remove_unused_constant() + + fold_scale(self.pre_optimized_model, scales) + self.smooth_quant_model = self.pre_optimized_model return self.smooth_quant_model @@ -327,8 +334,10 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): else: quantize_params = None self.quantize_params = quantize_params + from neural_compressor.adaptor.ox_utils.quantizer import Quantizer from neural_compressor import options + quantizer = Quantizer(tmp_model, quantize_config, format, @@ -350,7 +359,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params) tmp_model.model = quantizer.model.model self.quantize_config = quantize_config # update so other methods can know current configs - self._dump_model_op_stats(tmp_model) tmp_model.topological_sort() return tmp_model diff --git a/neural_compressor/adaptor/ox_utils/calibration.py b/neural_compressor/adaptor/ox_utils/calibration.py index a6b88531c10..6b8178bb618 100644 --- a/neural_compressor/adaptor/ox_utils/calibration.py +++ b/neural_compressor/adaptor/ox_utils/calibration.py @@ -204,7 +204,6 @@ def augment_graph(self, activation_only=False, weight_only=False): self.model_wrapper.model_path + '_augment.onnx', save_as_external_data=True, all_tensors_to_one_file=True, - location="weights.pb", convert_attribute=False) def get_intermediate_outputs(self, q_config=None): @@ -218,11 +217,11 @@ def get_intermediate_outputs(self, q_config=None): session = onnxruntime.InferenceSession( self.augmented_model.SerializeToString(), so, - provider=self.backend) if not self.model_wrapper.is_large_model else \ + providers=[self.backend]) if not self.model_wrapper.is_large_model else \ onnxruntime.InferenceSession( self.model_wrapper.model_path + '_augment.onnx', so, - provider=self.backend) + providers=[self.backend]) len_inputs = len(session.get_inputs()) @@ -681,7 +680,14 @@ def calib_smooth(self, percentile, op_types, q_config): tensors_to_dump = self._get_input_tensor_of_ops(op_types) self.model_wrapper.add_tensors_to_outputs(tensors_to_dump) self.augmented_model = self.model_wrapper.model - _, output_dicts = self.get_intermediate_outputs(q_config) + if self.model_wrapper.is_large_model: # pragma: no cover + onnx.save_model(self.augmented_model, + self.model_wrapper.model_path + '_augment.onnx', + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False) + + _, output_dicts = self.get_intermediate_outputs() # remove the input tensors of {op_types} to outputs of the model self.model_wrapper.remove_tensors_from_outputs(tensors_to_dump) diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py index 7b38d827fbe..a3173f5e2b2 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -535,7 +535,7 @@ def get_smooth_scales_per_op(max_vals_per_channel, input_tensors_2_weights, else: weight = np.moveaxis(weight, 0, 1) weight = weight.reshape(weight.shape[0], -1) - weight_max_per_channel = np.amax(weight, axis=-1) + weight_max_per_channel = np.amax(np.abs(weight), axis=-1) input_power = np.power(max_vals_per_channel[key], alpha) weight_power = np.power(weight_max_per_channel, 1 - alpha) scale = np.clip(input_power / weight_power, a_min=1e-5, a_max=None) @@ -641,7 +641,9 @@ def adjust_weights_per_op(model, nodes, scales): node = nodes[key] input = node.input[1] if input in name_to_indices.keys(): - weight = 
numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]]) + weight = numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]], + os.path.dirname(model.model_path)) if model.model_path is not None else \ + numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]]) if len(weight.shape) == 2: scale = np.expand_dims(scales[key], axis=-1) # TODO, to support conv @@ -672,7 +674,9 @@ def adjust_weights_per_input(model, nodes, scales): for node in curr_nodes: input = node.input[1] # TODO if input in name_to_indices.keys(): - weight = numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]]) + weight = numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]], + os.path.dirname(model.model_path)) if model.model_path is not None else \ + numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]]) if len(weight.shape) == 2: scale = np.expand_dims(scales[key], axis=-1) # TODO, to support conv @@ -740,6 +744,78 @@ def insert_smooth_mul_op_per_op(scales, shape_infos, input_tensors_2_weights_nod node.input[index] = mul_output_name return new_added_mul_nodes, new_init_tensors, name_2_nodes +def fold_scale(model, scales): + """Fold the scale to the operator at output channel. + + Args: + model: The neural_compressor model object + scales: A dict, tensor: smooth quant scale + """ + from onnx import numpy_helper + def norm(node, scale): # pragma: no cover + for idx in [1, 2]: + tensor = model.get_initializer(node.input[idx]) + new_tensor = numpy_helper.to_array(tensor, os.path.dirname(model.model_path)) * scale if \ + model.model_path is not None else numpy_helper.to_array(tensor) * scale + model.set_initializer(node.input[idx], new_tensor) + return True + + def mul(node, scale): # pragma: no cover + if all([model.get_initializer(inp) is None for inp in node.input]): + return False + for inp in node.input: + if model.get_initializer(inp) is not None: + tensor = model.get_initializer(inp) + new_tensor = numpy_helper.to_array(tensor, os.path.dirname(model.model_path)) * scale if \ + model.model_path is not None else numpy_helper.to_array(tensor) * scale + model.set_initializer(inp, new_tensor) + return True + + def conv(node, scale): # pragma: no cover + if len(node.input) > 2: + if model.get_initializer(node.input[2]) is not None: + tensor = model.get_initializer(node.input[2]) + new_tensor = numpy_helper.to_array(tensor, os.path.dirname(model.model_path)) * scale if \ + model.model_path is not None else numpy_helper.to_array(tensor) * scale + model.set_initializer(node.input[2], new_tensor) + scale = scale.reshape(-1, 1, 1, 1) + tensor = model.get_initializer(node.input[1]) + new_tensor = numpy_helper.to_array(tensor, os.path.dirname(model.model_path)) * scale if \ + model.model_path is not None else numpy_helper.to_array(tensor) * scale + model.set_initializer(node.input[1], new_tensor) + return True + + could_absorb_optype = {"LayerNormalization": norm, + "BatchNormalization": norm, + "InstanceNormalization": norm, + "SimplifiedLayerNormalization": mul, + "MatMul": mul, + "Gemm": mul, + "Conv": conv, + "FusedConv": conv, + "Mul": mul + } + remove_nodes = [] + + scales_per_op = model.get_initializer(list(scales.keys())[0]) is None + + for node in model.nodes(): + if node.op_type == "Mul" and node.name.endswith("_smooth_mul"): + parent = model.get_parent(node, 0) + if parent is None: + continue + if parent.op_type in could_absorb_optype and len(model.get_children(parent)) == 1: + if 
node.output[0].split("_smooth_output")[0] in scales: + if could_absorb_optype[parent.op_type](parent, + 1.0 / scales[node.output[0].split("_smooth_output")[0]]): + remove_nodes.append(node) + children = [i for i in model.nodes() if node.output[0] in i.input] + for child in children: + for idx, inp in enumerate(child.input): + if inp == node.output[0]: + child.input[idx] = node.input[0] + model.remove_nodes(remove_nodes) + def trt_env_setup(model): """Set environment variable for Tensorrt Execution Provider.""" is_int8 = False diff --git a/neural_compressor/algorithm/smooth_quant.py b/neural_compressor/algorithm/smooth_quant.py index d869bce9e76..12444db5ddf 100644 --- a/neural_compressor/algorithm/smooth_quant.py +++ b/neural_compressor/algorithm/smooth_quant.py @@ -78,13 +78,13 @@ def __call__(self, origin_model, q_model, adaptor, dataloader, calib_iter): kwargs['percentile'] = self.percentile if self.scales_per_op != None: kwargs['scales_per_op'] = self.scales_per_op + kwargs['folding'] = self.folding q_model = adaptor.smooth_quant( origin_model, dataloader, calib_iter, self.tune_cfg, alpha=self.alpha, - folding=self.folding, **kwargs, ) return q_model diff --git a/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py b/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py index 11cb65d2ae8..a2104f0a501 100644 --- a/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py +++ b/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py @@ -475,14 +475,18 @@ def build_conv_model(): np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), name='conv2_weight') conv2_node = helper.make_node('Conv', ['conv1_output', 'conv2_weight'], ['conv2_output'], name='conv2') + conv3_weight_initializer = numpy_helper.from_array( + np.random.randint(-1, 2, [3, 3, 3, 3]).astype(np.float32), name='conv3_weight') + conv3_node = helper.make_node('Conv', ['input', 'conv3_weight'], ['conv3_output'], name='conv3') + avg_args = {'kernel_shape': [3, 3]} - avgpool_node = helper.make_node('AveragePool', ['conv1_output'], ['avg_output'], name='AveragePool', **avg_args) + avgpool_node = helper.make_node('AveragePool', ['conv3_output'], ['avg_output'], name='AveragePool', **avg_args) concat_node = helper.make_node('Concat', ['avg_output', 'conv2_output'], ['concat_output'], name='Concat', axis=1) output = helper.make_tensor_value_info('concat_output', TensorProto.FLOAT, [1, 8, 220, 220]) - initializers = [conv1_weight_initializer, conv2_weight_initializer] - graph = helper.make_graph([conv1_node, conv2_node, concat_node, avgpool_node], + initializers = [conv1_weight_initializer, conv2_weight_initializer, conv3_weight_initializer] + graph = helper.make_graph([conv1_node, conv2_node, conv3_node, concat_node, avgpool_node], 'test', [input], [output], initializer=initializers) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) return model @@ -1096,12 +1100,24 @@ def test_smooth_quant(self): self.assertEqual(len([i for i in q_model.nodes() if i.op_type == 'Mul']), 2) def test_smooth_quant_args(self): - config = PostTrainingQuantConfig(approach='static', recipes={'smooth_quant': True, \ - 'smooth_quant_args': {'alpha': 0.6}}) - q_model = quantization.fit(self.conv_model, config, - calib_dataloader=self.cv_dataloader) - self.assertEqual(len([i for i in q_model.nodes() if i.op_type == 'Mul']), 2) - + from neural_compressor.model.onnx_model import ONNXModel + framework_specific_info = {"device": "cpu", + "approach": "post_training_static_quant", + "random_seed": 1234, + "q_dataloader": None, + "backend": "default", + 
"format": "default", + "domain": "auto", + "recipes": {}, + "workspace_path": './nc_workspace/{}/{}/'.format( + 'onnxrt', + 'imagenet')} + framework = "onnxrt_qlinearops" + adaptor = FRAMEWORKS[framework](framework_specific_info) + adaptor.pre_optimized_model = ONNXModel(self.conv_model) + adaptor.smooth_quant(self.conv_model, self.cv_dataloader, 1, None, scales_per_op=False) + self.assertEqual(len([i for i in adaptor.pre_optimized_model.nodes() if i.op_type == 'Mul']), 1) + def test_multi_metrics(self): conf.model.framework = 'onnxrt_qlinearops' conf.quantization.approach = 'post_training_static_quant'