diff --git a/onnxruntime/test/python/quantization/test_op_argmax.py b/onnxruntime/test/python/quantization/test_op_argmax.py
index cb0a243c7e6e3..86bb187cfa54f 100644
--- a/onnxruntime/test/python/quantization/test_op_argmax.py
+++ b/onnxruntime/test/python/quantization/test_op_argmax.py
@@ -80,6 +80,7 @@ def quantize_argmax_test(self, activation_type, weight_type, extra_options = {})
         weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
         model_uint8_path = 'argmax_{}{}.onnx'.format(activation_type_str, weight_type_str)
         model_uint8_qdq_path = 'argmax_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str)
+        model_uint8_qdq_trt_path = 'argmax_{}{}_qdq_trt.onnx'.format(activation_type_str, weight_type_str)

         # Verify QOperator mode
         data_reader = self.input_feeds(1, {'input': [1, 256, 128, 128]})
@@ -105,6 +106,17 @@ def quantize_argmax_test(self, activation_type, weight_type, extra_options = {})
         data_reader.rewind()
         check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())

+        # Verify QDQ mode for TensorRT
+        data_reader.rewind()
+        quantize_static(model_fp32_path, model_uint8_qdq_trt_path, data_reader, quant_format=QuantFormat.QDQ,
+                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options,
+                        op_types_to_quantize=['ArgMax'])
+        qdqnode_counts = {'QuantizeLinear': 1, 'DequantizeLinear': 1, 'ArgMax': 1}
+        check_op_type_count(self, model_uint8_qdq_trt_path, **qdqnode_counts)
+        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
+        check_qtype_by_node_type(self, model_uint8_qdq_trt_path, qnode_io_qtypes)
+        data_reader.rewind()
+        check_model_correctness(self, model_fp32_path, model_uint8_qdq_trt_path, data_reader.get_next())

     def test_quantize_argmax(self):
         self.quantize_argmax_test(QuantType.QUInt8, QuantType.QUInt8)