diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json
index 4ade34f75..45fafcb34 100644
--- a/examples/.config/model_params_onnxrt.json
+++ b/examples/.config/model_params_onnxrt.json
@@ -18,6 +18,15 @@
       "batch_size": 1,
       "algorithm": "RTN"
     },
+    "llama-2-7b-rtn-with-past-qdq": {
+      "model_name": "meta-llama/Llama-2-7b-hf",
+      "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
+      "main_script": "main.py",
+      "batch_size": 1,
+      "algorithm": "RTN"
+    },
     "llama-2-7b-awq": {
       "model_name": "meta-llama/Llama-2-7b-hf",
       "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
@@ -36,6 +45,15 @@
       "batch_size": 1,
       "algorithm": "AWQ"
     },
+    "llama-2-7b-awq-with-past-qdq": {
+      "model_name": "meta-llama/Llama-2-7b-hf",
+      "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
+      "main_script": "main.py",
+      "batch_size": 1,
+      "algorithm": "AWQ"
+    },
     "llama-2-7b-gptq": {
       "model_name": "meta-llama/Llama-2-7b-hf",
       "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
@@ -54,6 +72,15 @@
       "batch_size": 1,
       "algorithm": "GPTQ"
     },
+    "llama-2-7b-gptq-with-past-qdq": {
+      "model_name": "meta-llama/Llama-2-7b-hf",
+      "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
+      "main_script": "main.py",
+      "batch_size": 1,
+      "algorithm": "GPTQ"
+    },
     "llama-2-7b-woq_tune": {
       "model_name": "meta-llama/Llama-2-7b-hf",
       "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md
index 11d3ed27a..6bbd8234f 100644
--- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md
+++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md
@@ -55,7 +55,7 @@ bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
                   --dataset=NeelNanda/pile-10k \
                   --tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
                   --algorithm=WOQ_TUNE # support WOQ_TUNE, RTN, AWQ, GPTQ \
-                  --quant_format=QOperator # support QOperator and QDQ
+                  --quant_format=QDQ # support QOperator and QDQ
 ```
 
 ## 2. Benchmark
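For context on the README change above: the example now documents QDQ as the default weight-only quantization format. Below is a minimal sketch of how the `--quant_format` string maps onto onnxruntime's `QuantFormat` enum; the enum is onnxruntime's public API, while the argument handling here is illustrative only and is not the example's full CLI.

```python
# Illustrative sketch: map the --quant_format string onto onnxruntime's
# QuantFormat enum, defaulting to QDQ as in the updated main.py.
import argparse

from onnxruntime.quantization import QuantFormat

parser = argparse.ArgumentParser()
parser.add_argument("--quant_format", type=str, default="QDQ", choices=["QOperator", "QDQ"])
args = parser.parse_args([])  # empty argv -> use defaults, i.e. QDQ

quant_format = QuantFormat[args.quant_format]  # enum lookup by name
print(quant_format)  # QuantFormat.QDQ
```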
diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py
index db793ea49..196d7d4d1 100644
--- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py
+++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py
@@ -106,7 +106,7 @@
     default=[],
     help="nodes that will not be quantized. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'",
 )
-parser.add_argument("--quant_format", type=str, default="QOperator", choices=["QOperator", "QDQ"])
+parser.add_argument("--quant_format", type=str, default="QDQ", choices=["QOperator", "QDQ"])
 args = parser.parse_args()
 
 if args.tune and not os.path.exists(args.output_model):
diff --git a/onnx_neural_compressor/algorithms/utility.py b/onnx_neural_compressor/algorithms/utility.py
index f7d1cf99d..a45a291f7 100644
--- a/onnx_neural_compressor/algorithms/utility.py
+++ b/onnx_neural_compressor/algorithms/utility.py
@@ -340,7 +340,7 @@ def make_weight_only_dequant_node(
     input_names = []
     kwargs = {"block_size": block_size, "axis": axis}
 
-    q_weight = q_weight.reshape((-1, weight_shape[-1])).T
+    q_weight = q_weight.reshape((weight_shape[-1], -1)).T
 
     if num_bits == 4:
         q_weight = ((q_weight[:, ::2] & 0xF | q_weight[:, 1::2] << 4) & 0xFF).astype("uint8")
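The `make_weight_only_dequant_node` fix above changes how the quantized weight is laid out before the 4-bit nibble packing on the following line of that function. Here is a small numpy sketch of that packing step with toy values; the reshape/transpose and the exact block layout used by the library are not reproduced.

```python
# Toy numpy illustration of the 4-bit packing used right after the fixed
# reshape: each pair of adjacent columns is packed into one uint8, with the
# even column in the low nibble and the odd column in the high nibble.
import numpy as np

q_weight = np.array(
    [[1, 2, 3, 4],
     [5, 6, 7, 8]],
    dtype=np.uint8,
)  # already-quantized 4-bit codes (0..15)

packed = ((q_weight[:, ::2] & 0xF) | (q_weight[:, 1::2] << 4)) & 0xFF
packed = packed.astype("uint8")

print([[hex(v) for v in row] for row in packed])
# [['0x21', '0x43'], ['0x65', '0x87']]
```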
diff --git a/onnx_neural_compressor/algorithms/weight_only/gptq.py b/onnx_neural_compressor/algorithms/weight_only/gptq.py
index a3c639bb1..ed8daf26b 100644
--- a/onnx_neural_compressor/algorithms/weight_only/gptq.py
+++ b/onnx_neural_compressor/algorithms/weight_only/gptq.py
@@ -25,6 +25,7 @@
 from packaging.version import Version
 
 from onnx_neural_compressor import constants, data_reader, onnx_model, utility
+from onnx_neural_compressor.algorithms.weight_only import rtn
 from onnx_neural_compressor.algorithms import utility as quant_utils
 from onnx_neural_compressor.algorithms.layer_wise import core
 from onnx_neural_compressor.quantization import config
@@ -301,13 +302,13 @@ def gptq_quantize(
         weight,
         H,
     ) in zip(node_list, weights, Hs):
-        weight_dtype = weight_config[node.name].get("weight_dtype", "int")
         num_bits = weight_config[node.name].get("weight_bits", 4)
         group_size = weight_config[node.name].get("weight_group_size", 32)
         sym = weight_config[node.name].get("weight_sym", True)
-        accuracy_level = weight_config[node.name].get("accuracy_level", 0)
         group_size = group_size if group_size != -1 else weight.shape[0]
-        dtype = weight.dtype
+
+        weight_tensor = model.get_initializer(node.input[1])
+        init_share_num = model.get_initializer_share_num(node.input[1])
 
         # weight -> quant -> dequant -> q_weight
         q_weight = _gptq(
@@ -322,88 +323,30 @@
             mse=mse,
             perchannel=perchannel,
         )
-
-        weight_tensor = model.get_initializer(node.input[1])
-        org_shape = weight.shape
-        init_share_num = model.get_initializer_share_num(node.input[1])
-
-        satisfy_MatMulNBits_condition = ort_version > constants.ONNXRT1161_VERSION and num_bits == 4
-        satisfy_MatMulFpQ4_condition = (
-            ort_version >= constants.ONNXRT116_VERSION and num_bits == 4 and group_size == 32
-        )
-        if (
-            quant_format == 1  # QDQ format
-            and num_bits in [4, 8]
-            and ort_version >= constants.ONNXRT119_VERSION
-            and model.opset_import[0].version > 20
-        ):
-            _, _, zp, scale, q_weight = quant_utils.quantize_data(
-                weight.T.reshape((-1, group_size)),
-                weight_dtype + str(num_bits),
-                sym,
-                axis=1,
-            )
-            dequant_node, new_inits = quant_utils.make_weight_only_dequant_node(
-                node=node,
-                weight_shape=org_shape,
-                num_bits=num_bits,
-                dtype=weight_dtype,
-                q_weight=q_weight,
-                scale=scale.astype(weight.dtype),
-                axis=0,
-                block_size=group_size,
-                zero_point=zp,
-            )
-            model.add_initializers(new_inits)
-            model.add_node(dequant_node)
-            node.name += "_Q"
-        elif ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or (
-            "CUDAExecutionProvider" not in providers
-            and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition)
-        ):
-            # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
-            # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-            k_blocks = (org_shape[0] + group_size - 1) // group_size
-            q_weight = quant_utils.pad_tensor(q_weight, group_size, k_blocks)
-            _, _, zp, scale, q_weight = quant_utils.quantize_data(
-                q_weight.T.reshape((-1, group_size)),
-                weight_dtype + str(num_bits),
-                sym,
-                axis=1,
-            )
-            q_matmul_node, new_inits = quant_utils.make_matmul_weight_only_node(
-                node=node,
-                weight_shape=org_shape,
-                num_bits=num_bits,
-                group_size=group_size,
-                k_blocks=k_blocks,
-                q_weight=q_weight,
-                scale=scale.astype(dtype),
-                zero_point=zp if not sym else None,
-                accuracy_level=accuracy_level,
-            )
-
-            model.add_initializers(new_inits)
-            model.remove_node(node)
-            model.add_node(q_matmul_node)
+        if init_share_num == 1:
+            model.set_initializer(weight_tensor, q_weight)
         else:
-            q_weight_tensor = onnx.helper.make_tensor(
-                name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)),
-                data_type=onnx.helper.np_dtype_to_tensor_dtype(dtype),
-                dims=q_weight.shape,
-                vals=q_weight.astype(dtype).tobytes(),
+            new_init = onnx.helper.make_tensor(
+                name=node.input[1] + "_GPTQ",
+                data_type=weight_tensor.data_type,
+                dims=weight_tensor.dims,
+                vals=q_weight.astype(weight.dtype).flatten().tobytes(),
                 raw=True,
             )
-            model.add_initializer(q_weight_tensor)
-            node.input[1] = q_weight_tensor.name
-            if init_share_num == 1:
-                model.remove_initializer(weight_tensor)
+            node.input[1] = new_init.name
+            model.add_initializer(new_init)
+
+    model.model = rtn.rtn_quantize(
+        model=model,
+        weight_config=weight_config,
+        ratios=full_ratio,
+        providers=providers,
+        quant_format=quant_format,
+    )
 
     model.remove_tensors_from_outputs(output_names)
     model.model.graph.output.MergeFrom(org_output)
-    model.topological_sort()
-
     # reload external data to prevent external data file path errors
     if model.is_large_model:
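The rewritten `gptq_quantize` above no longer builds DequantizeLinear/MatMulNBits nodes itself: it writes the GPTQ-adjusted, fake-quantized float weights back into the initializers and lets the shared RTN path emit the final QDQ or QOperator graph. The standalone numpy sketch below shows why this hand-off is lossless; it uses a toy symmetric per-group scheme, not the library's exact quantization code.

```python
# Toy sketch: a quantize->dequantize ("fake quantized") float weight can be
# re-quantized later with the same scheme and recover identical integer
# codes, which is what lets GPTQ delegate the final packing to RTN.
import numpy as np

rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4)).astype(np.float32)

group_size, num_bits = 4, 4
maxq = 2 ** (num_bits - 1) - 1  # symmetric int4 range: [-8, 7]

w = weight.reshape(-1, group_size)
scale = np.abs(w).max(axis=1, keepdims=True) / maxq
q = np.clip(np.round(w / scale), -maxq - 1, maxq)
fake_quant = (q * scale).reshape(weight.shape)  # float values on the int4 grid

# Re-quantizing the fake-quantized weight yields the same integer codes.
q_again = np.clip(np.round(fake_quant.reshape(-1, group_size) / scale), -maxq - 1, maxq)
print("round-trip codes identical:", np.array_equal(q, q_again))  # True
```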
+ "Use QOperator format instead.".format(quant_config.weight_bits) + ) + quant_config.quant_format = QuantFormat.QOperator + elif ( + quant_config.weight_bits == 4 + and (ort_version < constants.ONNXRT119_VERSION or model.opset_import[0].version < 21) + and quant_format == 1 + ): + logger.warning( + "QDQ format for 4 bits tensor requires onnxruntime >= 1.19.0 and the opset version of model > 20, " + "but get onnxruntime version is {}, opset version is {}. Use QOperator format instead.".format( + ort_version, model.opset_import[0].version + ) + ) + quant_config.quant_format = QuantFormat.QOperator ###################### RTN Algo Entry ################################## @@ -32,6 +59,8 @@ def rtn_quantize_entry( model: Union[pathlib.Path, str], quant_config: config.RTNConfig, *args, **kwargs ) -> onnx.ModelProto: """The main entry to apply rtn quantization.""" + _update_quant_format(constants.RTN, model, quant_config) + if len(quant_config.config_mapping) == 0: # map config to each op model_info = config.RTNConfig.get_model_info(model=model) @@ -39,6 +68,7 @@ def rtn_quantize_entry( logger.debug(config_mapping) else: config_mapping = quant_config.config_mapping + quant_kwargs = {} for key in config.RTNConfig.model_params_list: val = getattr(quant_config, key) @@ -62,6 +92,8 @@ def gptq_quantize_entry( calibration_data_reader, data_reader.CalibrationDataReader ), "Please follow onnx_neural_compressor/data_reader.py to implement calibration_data_reader" + _update_quant_format(constants.GPTQ, model, quant_config) + if len(quant_config.config_mapping) == 0: # map config to each op model_info = config.GPTQConfig.get_model_info(model=model) @@ -96,6 +128,8 @@ def awq_quantize_entry( calibration_data_reader, data_reader.CalibrationDataReader ), "Please follow onnx_neural_compressor/data_reader.py to implement calibration_data_reader" + _update_quant_format(constants.AWQ, model, quant_config) + if len(quant_config.config_mapping) == 0: # map config to each op model_info = config.AWQConfig.get_model_info(model=model)