From 1b26c0dc28cc0bc995c8532d82791a74a7aebe13 Mon Sep 17 00:00:00 2001
From: "Wang, Mengni"
Date: Tue, 13 Jun 2023 13:04:16 +0800
Subject: [PATCH] Fix onnxrt smooth quant (#951)

Signed-off-by: Mengni Wang
---
 .../adaptor/ox_utils/smooth_quant.py | 91 +++++++++----------
 test/algorithm/test_smooth_quant.py  |  6 +-
 2 files changed, 45 insertions(+), 52 deletions(-)

diff --git a/neural_compressor/adaptor/ox_utils/smooth_quant.py b/neural_compressor/adaptor/ox_utils/smooth_quant.py
index 0a5f2ee88f3..fded97bf149 100644
--- a/neural_compressor/adaptor/ox_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/ox_utils/smooth_quant.py
@@ -17,13 +17,15 @@
 """SmoothQuant for onnxrt adaptor."""
 
 import os
+import copy
 import onnx
 import logging
 import numpy as np
 from onnx import onnx_pb as onnx_proto
 from neural_compressor.model.model import BaseModel
 from neural_compressor.model.onnx_model import ONNXModel
-from neural_compressor.adaptor.ox_utils.util import find_by_name, quantize_data, _get_qrange_for_qType
+from neural_compressor.adaptor.ox_utils.util import find_by_name, \
+    quantize_data, _get_qrange_for_qType, is_B_transposed
 from onnx import numpy_helper, helper
 
 logger = logging.getLogger("neural_compressor")
@@ -68,17 +70,6 @@ def make_sub_graph(node, inits, input_data, output_data, reduce_range, opset, ir
     from onnx import helper, TensorProto, numpy_helper
     input = helper.make_tensor_value_info(node.input[0], dtype_map[input_data.dtype], input_data.shape)
     output = helper.make_tensor_value_info(node.output[0], dtype_map[output_data.dtype], output_data.shape)
-
-    for init in inits:
-        q_dq_val = quant_dequant_data(numpy_helper.to_array(init), reduce_range)
-        new_tensor = helper.make_tensor(
-            name=init.name,
-            data_type=dtype_map[numpy_helper.to_array(init).dtype],
-            dims=numpy_helper.to_array(init).shape if \
-                len(numpy_helper.to_array(init).shape) != 0 else [],
-            vals=q_dq_val if \
-                len(numpy_helper.to_array(init)) != 0 else [numpy_helper.to_array(init)])
-        init.CopyFrom(new_tensor)
     graph = helper.make_graph([node], 'sub_graph', [input], [output], inits)
     model = helper.make_model(graph, opset_imports=opset)
     model.ir_version = ir_version
@@ -110,11 +101,15 @@ class ORTSmoothQuant:
     def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionProvider'):
         """Initialize the attributes of class."""
         self.model = model if isinstance(model, BaseModel) else ONNXModel(model)
+        self.value_infos = {vi.name: vi for vi in self.model.model.graph.value_info}
+        self.value_infos.update({ot.name: ot for ot in self.model.model.graph.output})
+        self.value_infos.update({it.name: it for it in self.model.model.graph.input})
         self.dataloader = dataloader
         self.reduce_range = reduce_range
         self.backend = backend
         self.tensor_scales_info = {}
         self.new_added_mul_nodes = []
+        self.new_added_value_info = []
         self.new_init_tensors = []  # scales_tensor
         self.alpha = None
         self.percentile = None
@@ -129,7 +124,7 @@ def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionP
         self._build_absorb_function()
 
     def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm', 'Conv', 'MatMul', 'FusedConv'],
-                  scales_per_op=False, calib_iter=100, quantize_config=None,
+                  scales_per_op=True, calib_iter=100, quantize_config=None,
                   auto_alpha_args={'alpha_min': 0.3, 'alpha_max': 0.7, 'alpha_step': 0.05, 'attn_method': 'min'}):
         """The main entry of smooth quant.
 
@@ -167,6 +162,7 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm'
         self._insert_smooth_mul_op(scales)
         self._adjust_weights(scales)
         self.model.add_nodes(self.new_added_mul_nodes)
+        self.model.model.graph.value_info.extend(self.new_added_value_info)
         self.model.add_initializers(self.new_init_tensors)
         for node, old_input_name, new_input_name in self.replace_input:
             self.model.replace_node_input(node, old_input_name, new_input_name)
@@ -194,9 +190,15 @@ def recover(self):
         for node, old_input_name, new_input_name in self.replace_input:
             self.model.replace_node_input(node, new_input_name, old_input_name)
 
+        for value_info in self.new_added_value_info:
+            self.model.model.graph.value_info.remove(value_info)
+
         self.model.remove_nodes(self.new_added_mul_nodes)
         self.model.remove_initializers(self.new_init_tensors)
         self.tensor_scales_info = {}
+        self.new_added_mul_nodes = []
+        self.new_init_tensors = []
+        self.new_added_value_info = []
 
     def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter):
         """Check need calibration or not.
@@ -340,7 +342,7 @@ def _get_output_loss(self, node_name, scale, calib_iter):
         loss = 0
         if len(node) > 0:
             node = node[0]
-
+            orig_outputs = self.model.output()
             added_tensors = [node.input[0], node.output[0]]
             self.model.add_tensors_to_outputs(added_tensors)
 
@@ -350,20 +352,8 @@ def _get_output_loss(self, node_name, scale, calib_iter):
                 ort.InferenceSession(self.model.model.SerializeToString(), providers=[self.backend])
 
             base_dir = '' if not self.model.is_large_model else os.path.dirname(self.model.model_path)
-            if node.op_type in ['Conv', 'FusedConv']:
-                weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
-                weight_q = quant_dequant_data(weight)
-            elif node.op_type in ['MatMul', 'Gemm']:
-                weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
-                weight_q = quant_dequant_data(weight)
-
-            base_dir = '' if not self.model.is_large_model else os.path.dirname(self.model.model_path)
-            if node.op_type in ['Conv', 'FusedConv']:
-                weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
-                weight_q = quant_dequant_data(weight)
-            elif node.op_type in ['MatMul', 'Gemm']:
-                weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
-                weight_q = quant_dequant_data(weight)
+            weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
+            weight_q = quant_dequant_data(weight)
             self.model.set_initializer(node.input[1], weight_q)
             inits = [self.model.get_initializer(i) for i in node.input if self.model.get_initializer(i) is not None]
@@ -383,9 +373,9 @@ def _get_output_loss(self, node_name, scale, calib_iter):
                 if model is None:
                     model = make_sub_graph(node, inits, outputs[0], outputs[1],
                                            self.reduce_range, self.model.model.opset_import, self.model.model.ir_version)
-                loss += get_quant_dequant_output(model, outputs[0], outputs[1], self.reduce_range, self.backend)
+                loss += get_quant_dequant_output(model, outputs[0] * scale, outputs[1], self.reduce_range, self.backend)
 
-            self.model.remove_tensors_from_outputs(added_tensors)
+            self.model.remove_tensors_from_outputs([i for i in added_tensors if i not in orig_outputs])
             self.model.set_initializer(node.input[1], weight)
         return loss
@@ -430,15 +420,16 @@ def _auto_tune_alpha(self, calib_iter, alpha_min=0.3, alpha_max=0.7, alpha_step=
 
         ## Searching optimal alphas
         for tensor_name, node_infos in self.tensors_to_node.items():
-            loss_all_ops = {}
             for node_info in node_infos:
                 loss_alpha = {}
                 key = node_info[0] if self.scales_per_op else tensor_name
-
+                node = self.model.get_node(node_info[0])
                 for alpha in alpha_space:
                     scale = self._get_smooth_scales(alpha, [key])
                     self._adjust_weights(scale)
-                    input_scale = self._reshape_scale_for_input(tensor_name, key)
+                    input_scale = self._reshape_scale_for_input(tensor_name, key) if \
+                        not (node.op_type == 'Gemm' and is_B_transposed(node)) else \
+                        self.tensor_scales_info[key]
                     loss = self._get_output_loss(node_info[0], input_scale, calib_iter)
                     loss_alpha[alpha] = loss
                     if key not in optimal_alphas:  # Update alpha results
@@ -447,7 +438,6 @@ def _auto_tune_alpha(self, calib_iter, alpha_min=0.3, alpha_max=0.7, alpha_step=
                         optimal_alphas[key] = alpha if optimal_alphas[key] in loss_alpha and \
                             loss < loss_alpha[optimal_alphas[key]] else optimal_alphas[key]
                     self.recover()
-                loss_all_ops[key] = loss_alpha
         logger.info("auto tuning alpha done")
         if self.model.is_large_model:
             from onnx.external_data_helper import load_external_data_for_model
@@ -474,28 +464,25 @@ def _get_smooth_scales(self, alpha, target_list=[]):
         # if scales_per_op the key of scales is the node name, otherwise the activation of node
         if self.scales_per_op:
             for node_info in nodes:
+                node = self.model.input_name_to_nodes[node_info[1][1]][0]
                 if len(target_list) > 0 and node_info[0] not in target_list:
                     continue
                 weight = numpy_helper.to_array(self.model.get_initializer(node_info[1][1]))
-                if len(weight.shape) == 4:  # conv
-                    if weight.shape[1] == 1:  # depthwise conv
-                        pass
-                    else:
-                        weight = np.moveaxis(weight, 0, 1)
+                if (len(weight.shape) == 4 and weight.shape[1] != 1) or \
+                    (node.op_type == 'Gemm' and is_B_transposed(node)):
+                    weight = np.moveaxis(weight, 0, 1)
                 specific_alpha = alpha[node_info[0]] if isinstance(alpha, dict) else alpha
                 scales[node_info[0]] = self._get_smooth_scale(weight, specific_alpha, tensor)
         else:
             if len(target_list) > 0 and tensor not in target_list:
                 continue
-            weights = [numpy_helper.to_array(self.model.get_initializer(node_info[1][1])) for \
-                       node_info in nodes]
             weights_in_channel_max = []
-            for weight in weights:  # mamul ic*oc, conv oc*ic*k*k
-                if len(weight.shape) == 4:  # conv
-                    if weight.shape[1] == 1:  # depthwise conv
-                        pass
-                    else:
-                        weight = np.moveaxis(weight, 0, 1)
+            for node_info in nodes:
+                node = self.model.input_name_to_nodes[node_info[1][1]][0]
+                weight = numpy_helper.to_array(self.model.get_initializer(node_info[1][1]))
+                if (len(weight.shape) == 4 and weight.shape[1] != 1) or \
+                    (node.op_type == 'Gemm' and is_B_transposed(node)):
+                    weight = np.moveaxis(weight, 0, 1)
                 weight = weight.reshape(weight.shape[0], -1)
                 cur_max = np.amax(weight, axis=-1)
                 weights_in_channel_max.append(cur_max)
@@ -555,6 +542,10 @@ def _insert_smooth_mul_op(self, scales):
                 name=key + "_smooth_mul"
             )
             self.new_added_mul_nodes.append(mul_node)
+            if input_name in self.value_infos:
+                value_info = copy.deepcopy(self.value_infos[input_name])
+                value_info.name = mul_node.output[0]
+                self.new_added_value_info.append(value_info)
             if self.scales_per_op:
                 self.replace_input.append([self.model.get_node(key), input_name, mul_output_name])
             else:
@@ -573,10 +564,12 @@ def _adjust_weights(self, scales):
             if key not in scales:
                 continue
             input = node_info[1][1]
+            node = self.model.input_name_to_nodes[input][0]
             weight = numpy_helper.to_array(self.model.get_initializer(input))
             if len(weight.shape) == 2:
-                scale = np.expand_dims(scales[key],
-                                       axis=-1)  # TODO, to support conv
+                scale = np.expand_dims(scales[key], axis=0) if \
+                    node.op_type == 'Gemm' and is_B_transposed(node) else \
+                    np.expand_dims(scales[key], axis=-1)
                 new_weight = weight * scale
             elif len(weight.shape) == 4:  # TODO need to check conv
                 node = self.model.input_name_to_nodes[input][0]
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index 19c77d36ba2..20730a07e3c 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -58,7 +58,7 @@ def tearDownClass(self):
 
     def test_sq(self):
         sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
-        model = sq.transform(calib_iter=5)
+        model = sq.transform(calib_iter=5, scales_per_op=False)
         self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 1)
         sq.recover()
         self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
@@ -68,7 +68,7 @@ def test_sq(self):
         self.assertAlmostEqual(tensor[0][0], sq_tensor[0][0], 4)
 
         sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
-        model = sq.transform(calib_iter=5, folding=False)
+        model = sq.transform(calib_iter=5, folding=False, scales_per_op=False)
         self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 2)
         sq.recover()
         self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
@@ -109,7 +109,7 @@ def test_sq(self):
 
         sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
-        model = sq.transform(calib_iter=5, alpha='auto')
+        model = sq.transform(calib_iter=5, alpha='auto', scales_per_op=False)
         self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 1)
         sq.recover()
         self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
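A minimal usage sketch of the patched API, not part of the patch itself; `model_path` and `dataloader` are assumed placeholders for an ONNX model file and a calibration dataloader accepted by ORTSmoothQuant, and the call mirrors the updated tests, which pass scales_per_op=False now that transform() defaults to True.

# Usage sketch only; model_path and dataloader are assumptions, not part of the patch.
import copy
import onnx
from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant

model = onnx.load(model_path)
sq = ORTSmoothQuant(copy.deepcopy(model), dataloader)

# transform() now defaults to scales_per_op=True (an individual scale per op);
# the updated tests pass scales_per_op=False to keep the shared-scale behavior
# that inserts a single Mul node per activation.
smoothed = sq.transform(alpha=0.5, calib_iter=5, scales_per_op=False)

# recover() removes the inserted Mul nodes, scale initializers, and the newly
# tracked value_info entries, and now also resets its bookkeeping lists.
sq.recover()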