diff --git a/neural_compressor/adaptor/ox_utils/operators/gather.py b/neural_compressor/adaptor/ox_utils/operators/gather.py index a3fd295ab1b..aac38b0c31b 100644 --- a/neural_compressor/adaptor/ox_utils/operators/gather.py +++ b/neural_compressor/adaptor/ox_utils/operators/gather.py @@ -64,6 +64,7 @@ def convert(self, convert_format): children = self.quantizer.model.get_children(node) if any([i.op_type == 'DequantizeLinear' for i in parents]): + from onnx import numpy_helper inputs = [] inputs.append(parents[0].input[0]) inputs.append(node.input[1]) @@ -80,7 +81,7 @@ def convert(self, convert_format): node.name, **kwargs) self.quantizer.new_nodes.append(gather_node) - if any([i.op_type != 'QuantizeLinear' for i in children]): # pragma: no cover + if any([i.op_type != 'QuantizeLinear' for i in children]): # pragma: no cover dq_inputs = [] dq_inputs.append(gather_new_output) dq_inputs.extend(parents[0].input[1:]) @@ -90,12 +91,22 @@ def convert(self, convert_format): node.name + '_DequantizeLinear') self.quantizer.new_nodes.append(dq_node) + out_scale = 1. 
+ out_zp = 0 for child in children: if child.op_type == 'QuantizeLinear': + out_scale = numpy_helper.to_array(self.quantizer.model.get_initializer(child.input[1])) + out_zp = numpy_helper.to_array(self.quantizer.model.get_initializer(child.input[2])) self.quantizer.remove_nodes.append(child) for n in self.quantizer.model.get_children(child): self.quantizer.model.replace_node_input(n, child.output[0], gather_new_output) + if any([child.op_type == 'QuantizeLinear' for child in children]): + int8_tensor = numpy_helper.to_array(self.quantizer.model.get_initializer(parents[0].input[0])) + in_scale = numpy_helper.to_array(self.quantizer.model.get_initializer(parents[0].input[1])) + in_zp = numpy_helper.to_array(self.quantizer.model.get_initializer(parents[0].input[2])) + new_int8_tensor = (((int8_tensor.astype('float32') - in_zp) * in_scale) / out_scale).round() + out_zp + self.quantizer.model.set_initializer(parents[0].input[0], new_int8_tensor.astype(int8_tensor.dtype)) self.quantizer.remove_nodes.extend([node, parents[0]]) @qop_registry(op_types="Gather") diff --git a/neural_compressor/adaptor/ox_utils/operators/matmul.py b/neural_compressor/adaptor/ox_utils/operators/matmul.py index de06a0bef00..23427e1f6cd 100644 --- a/neural_compressor/adaptor/ox_utils/operators/matmul.py +++ b/neural_compressor/adaptor/ox_utils/operators/matmul.py @@ -29,6 +29,16 @@ def __init__(self, onnx_quantizer, onnx_node): """Initialization.""" super(MatMulOperator, self).__init__(onnx_quantizer, onnx_node) + def quantize_check(self): + """Check if quantization can be done.""" + node = self.node + if not all([self.quantizer.model.get_initializer(i) is None for i in node.input]): + return True + elif all([i not in self.quantizer.quantized_value_map for i in node.input]): + return False + else: + return True + def quantize(self): """Do quantizaion.""" node = self.node @@ -39,7 +49,7 @@ def quantize(self): else: self.quantizer.quantize_inputs(node, [1]) - if not self.disable_qdq_for_node_output
or self.quantizer.mode != 'qdq': + if not self.disable_qdq_for_node_output: self.quantizer.quantize_outputs(node) node.name = node.name + "_quant" @@ -48,6 +58,8 @@ def convert_check(self, convert_format): node = self.node assert convert_format in ['dynamic', 'static'], \ "convert format for {} should be in ['dynamic', 'static']".format(node.op_type) + if not node.name.endswith('_quant'): + return False return True def convert(self, convert_format): @@ -112,22 +124,29 @@ def convert(self, convert_format): if len(self.quantizer.model.get_children(node)) == 0 or \ not node.name.endswith('_quant'): # pragma: no cover return - child = self.quantizer.model.get_children(node)[0] - - qlinear_matmul_output = child.output[0] qlinear_matmul_inputs = [] - for parent in parents: - qlinear_matmul_inputs.extend(parent.input) - qlinear_matmul_inputs.extend(child.input[1:]) - - qlinear_matmul_node = onnx.helper.make_node("QLinearMatMul", - qlinear_matmul_inputs, - [qlinear_matmul_output], - node.name) + if self.disable_qdq_for_node_output: + for i in range(len(parents[0].input)): + qlinear_matmul_inputs.extend([parent.input[i] for parent in parents]) + qlinear_matmul_node = onnx.helper.make_node("MatMulIntegerToFloat", + qlinear_matmul_inputs, + node.output, + node.name, + domain='com.microsoft') + else: + child = self.quantizer.model.get_children(node)[0] + qlinear_matmul_output = child.output[0] + for parent in parents: + qlinear_matmul_inputs.extend(parent.input) + qlinear_matmul_inputs.extend(child.input[1:]) + qlinear_matmul_node = onnx.helper.make_node("QLinearMatMul", + qlinear_matmul_inputs, + [qlinear_matmul_output], + node.name) + self.quantizer.remove_nodes.append(child) self.quantizer.new_nodes.append(qlinear_matmul_node) self.quantizer.remove_nodes.extend(parents) - self.quantizer.remove_nodes.append(child) self.quantizer.remove_nodes.append(node) @qop_registry(op_types="QLinearMatMul") diff --git a/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py 
b/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py index 26eec23f553..c5606d29cc2 100644 --- a/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py +++ b/test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py @@ -414,24 +414,45 @@ def build_matmul_model(): def build_matmul_model2(): A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 5, 1]) - C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 1, 5, 1]) - D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 1]) H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 1, 5, 1]) - + + C1_init = helper.make_tensor('C1', TensorProto.FLOAT, [1, 1, 5, 5], np.random.random(25).tolist()) matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul') + matmul_node2 = onnx.helper.make_node('MatMul', ['C1', 'C'], ['C2'], name='Matmul2') + matmul_node3 = onnx.helper.make_node('MatMul', ['A', 'C2'], ['C3'], name='Matmul3') e_value = np.random.randint(2, size=(5)).astype(np.float32) E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 1, 5, 1], e_value.reshape(5).tolist()) - add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add') + add = onnx.helper.make_node('Add', ['C3', 'E'], ['D'], name='add') f_value = np.random.randint(2, size=(5)).astype(np.float32) F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 1, 5, 1], e_value.reshape(5).tolist()) add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2') - graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A, B], [H], [E_init, F_init]) + graph = helper.make_graph([matmul_node, matmul_node2, matmul_node3, add, add2], 'test_graph_1', [A, B], [H], [E_init, F_init, C1_init]) model = helper.make_model(graph) model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]}) return model +def build_matmul_gather_model(): + input = helper.make_tensor_value_info('input0', TensorProto.INT64, [1, 1]) + 
output = helper.make_tensor_value_info('output0', TensorProto.FLOAT, [1, 1]) + + axes = helper.make_tensor('axes', TensorProto.INT64, [1], [1]) + squeeze = onnx.helper.make_node('Squeeze', ['input0', 'axes'], ['A'], name='squeeze') + + b_value = np.random.random((1, 2048)) + B_init = helper.make_tensor('B', TensorProto.FLOAT, [1, 2048], b_value.reshape(2048).tolist()) + + gather = onnx.helper.make_node('Gather', ['B', 'A'], ['C'], name='gather') + + d_value = np.random.random((2048, 1)).astype('float32') + D_init = helper.make_tensor('D', TensorProto.FLOAT, [2048, 1], d_value.reshape(2048).tolist()) + matmul = onnx.helper.make_node('MatMul', ['C', 'D'], ['output0']) + + graph = helper.make_graph([squeeze, gather, matmul], 'test_graph_1', [input], [output], [B_init, D_init, axes]) + model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]}) + return model + def build_model_with_gather(): b_value = np.random.randint(2, size=(10)).astype(np.int32) B_init = helper.make_tensor('B', TensorProto.INT32, [10], b_value.reshape(10).tolist()) @@ -670,6 +691,7 @@ def setUpClass(self): self.conv_model2 = build_conv_model2() export_onnx_nlp_model(self.distilbert_model, self.distilbert_export_path, 14) self.distilbert_model = onnx.load(self.distilbert_export_path) + self.gather_matmul_model = build_matmul_gather_model() build_benchmark() @classmethod @@ -1093,6 +1115,14 @@ def sub_eval(model, result): def eval(model): return sub_eval(model, result) + + dataset = Datasets("onnxrt_qdq")["dummy"]([(1,1,5,5), (1,1,5,1)]) + dataloader = DATALOADERS["onnxrt_qdq"](dataset) + config = PostTrainingQuantConfig(approach='static') + q_model = quantization.fit(self.matmul_model2, config, + calib_dataloader=dataloader, eval_func=eval) + self.assertEqual(len([i for i in q_model.nodes() if i.op_type == 'QLinearMatMul']), 2) + config = PostTrainingQuantConfig(approach='static', quant_format='QDQ') q_model = quantization.fit(self.matmul_model, config, 
calib_dataloader=self.matmul_dataloader, eval_func=eval) @@ -1118,6 +1148,28 @@ def eval(model): calib_dataloader=self.matmul_dataloader, eval_func=eval) self.assertTrue('QLinearMatMul' not in [i.op_type for i in q_model.nodes()]) + config = PostTrainingQuantConfig(approach='static', recipes={'optypes_to_exclude_output_quant': ['MatMul']}) + q_model = quantization.fit(self.matmul_model, config, + calib_dataloader=self.matmul_dataloader, eval_func=eval) + self.assertTrue('MatMulIntegerToFloat' in [i.op_type for i in q_model.nodes()]) + + dataset = Datasets("onnxrt_qdq")["dummy"]((1,1), low=0., high=0., dtype='int64') + dataloader = DATALOADERS["onnxrt_qdq"](dataset) + config = PostTrainingQuantConfig() + q_model = quantization.fit(self.gather_matmul_model, config, + calib_dataloader=dataloader, eval_func=eval) + + config = PostTrainingQuantConfig(quant_format='QDQ') + q_model2 = quantization.fit(self.gather_matmul_model, config, + calib_dataloader=dataloader, eval_func=eval) + + sess1 = ort.InferenceSession(q_model.model.SerializeToString(), providers=['CPUExecutionProvider']) + sess2 = ort.InferenceSession(q_model2.model.SerializeToString(), providers=['CPUExecutionProvider']) + for data, _ in dataloader: + output1 = sess1.run(None, {'input0': data}) + output2 = sess2.run(None, {'input0': data}) + self.assertAlmostEqual(output1[0][0], output2[0][0]) + def test_smooth_quant(self): config = PostTrainingQuantConfig(approach='static', recipes={'smooth_quant': True, \ 'smooth_quant_args': {'alpha': 0.5}}) diff --git a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py index 6c4024a4465..15f736dfbfd 100644 --- a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py +++ b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py @@ -496,10 +496,10 @@ def test_conv(self): def test_matmul(self): A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5]) - B = helper.make_tensor_value_info('B', 
TensorProto.FLOAT, [1, 1, 5, 1]) + B_init = helper.make_tensor('B', TensorProto.FLOAT, [1, 1, 5, 1], np.random.random((1, 1, 5, 1)).reshape(5).tolist()) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 1, 5, 1]) matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul') - graph = helper.make_graph([matmul_node], 'test_graph_1', [A, B], [C]) + graph = helper.make_graph([matmul_node], 'test_graph_1', [A], [C], [B_init]) model = helper.make_model(graph) q_config = {"Matmul": self.static_q_config} quantize_params = {"A": [np.uint8(10.), np.float32(0)], diff --git a/test/config/test_pythonic_config.py b/test/config/test_pythonic_config.py index 44f13857531..1c9a5fcbd27 100644 --- a/test/config/test_pythonic_config.py +++ b/test/config/test_pythonic_config.py @@ -28,7 +28,7 @@ def build_matmul_model(): A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5]) - B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 5, 1]) + B_init = helper.make_tensor('B', TensorProto.FLOAT, [1, 1, 5, 1], np.random.random([1, 1, 5, 1]).reshape(5).tolist()) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 1, 5, 1]) D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 1]) H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 1, 5, 1]) @@ -40,7 +40,7 @@ def build_matmul_model(): f_value = np.random.randint(2, size=(5)).astype(np.float32) F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 1, 5, 1], e_value.reshape(5).tolist()) add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2') - graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A, B], [H], [E_init, F_init]) + graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A], [H], [E_init, F_init, B_init]) model = helper.make_model(graph) model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]}) return model