diff --git a/docs/source/export.md b/docs/source/export.md index 6dd1fc35c2f..5989286975e 100644 --- a/docs/source/export.md +++ b/docs/source/export.md @@ -37,15 +37,15 @@ Here is the workflow of our export API for PyTorch/Tensorflow FP32/INT8 model. Post-Training Static Quantized INT8 - QLinear/QDQ INT8 + QOperator/QDQ INT8 Post-Training Dynamic Quantized INT8 - / + QOperator INT8 Quantization-aware Training INT8 - QLinear/QDQ INT8 + QOperator/QDQ INT8 TensorFlow @@ -63,10 +63,6 @@ Here is the workflow of our export API for PyTorch/Tensorflow FP32/INT8 model. -> **Note**: Follow this step to export a post training dynamic quantized ONNX model from PyTorch model: \ - 1. export FP32 PyTorch model to FP32 ONNX model. \ - 2. use FP32 ONNX model as the input model for post training dynamic quantization. - ## Examples ### PyTorch Model @@ -96,7 +92,7 @@ from neural_compressor.config import Torch2ONNXConfig int8_onnx_config = Torch2ONNXConfig( dtype="int8", opset_version=14, - quant_format="QLinear", # or QDQ + quant_format="QOperator", # or QDQ example_inputs=torch.randn(1, 3, 224, 224), input_names=['input'], output_names=['output'], diff --git a/examples/.config/model_params_pt2onnx.json b/examples/.config/model_params_pt2onnx.json index c383a53e2a3..9f7cd7f7112 100644 --- a/examples/.config/model_params_pt2onnx.json +++ b/examples/.config/model_params_pt2onnx.json @@ -8,6 +8,14 @@ "main_script": "main.py", "batch_size": 100 }, + "resnet18_dynamic": { + "model_src_dir": "image_recognition/torchvision_models/export/fx", + "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw", + "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000", + "input_model": "resnet18", + "main_script": "main.py", + "batch_size": 100 + }, "resnet50": { "model_src_dir": "image_recognition/torchvision_models/export/fx", "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw", @@ -16,6 +24,14 @@ "main_script": "main.py", "batch_size": 100 }, + "resnet50_dynamic": { + "model_src_dir": "image_recognition/torchvision_models/export/fx", + "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw", + "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000", + "input_model": "resnet50", + "main_script": "main.py", + "batch_size": 100 + }, "bert_base_MRPC": { "model_src_dir": "nlp/huggingface_models/text-classification/export/fx", "source_model_dataset": "mrpc", @@ -24,6 +40,14 @@ "main_script": "run_glue.py", "batch_size": 64 }, + "bert_base_MRPC_dynamic": { + "model_src_dir": "nlp/huggingface_models/text-classification/export/fx", + "source_model_dataset": "mrpc", + "target_model_dataset": "mrpc", + "input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_MRPC_output", + "main_script": "run_glue.py", + "batch_size": 64 + }, "bert_large_MRPC": { "model_src_dir": "nlp/huggingface_models/text-classification/export/fx", "source_model_dataset": "mrpc", @@ -31,6 +55,14 @@ "input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output", "main_script": "run_glue.py", "batch_size": 64 + }, + "bert_large_MRPC_dynamic": { + "model_src_dir": "nlp/huggingface_models/text-classification/export/fx", + "source_model_dataset": "mrpc", + "target_model_dataset": "mrpc", + "input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output", + "main_script": "run_glue.py", + "batch_size": 64 } } } \ No newline at end of file diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md 
b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md index e2e9b22b658..27b2557f8ce 100644 --- a/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md +++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md @@ -34,7 +34,7 @@ Run run_export.sh to get ONNX model from PyTorch model. # export fp32 model bash run_export.sh --input_model=resnet50 --dtype=fp32 --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-fp32.onnx # export int8 model -bash run_export.sh --input_model=resnet50 --dtype=int8 --quant_format=[QDQ|QLinear] --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-int8.onnx +bash run_export.sh --input_model=resnet50 --dtype=int8 --quant_format=[QDQ|QOperator] --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-int8.onnx --approach=[static|dynamic] ``` ### 2. To get the benchmark of exported and tuned models, includes Batch_size and Throughput: diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py index b9fc4016792..ace68d6fde7 100644 --- a/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py +++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py @@ -90,8 +90,10 @@ parser.add_argument('--export', dest='export', action='store_true', help='run export') parser.add_argument('--export_dtype', default='fp32', choices=['fp32', 'int8'], help='choose the data type [fp32/int8] of PyTorch model to be exported.') -parser.add_argument('--quant_format', default='QDQ', choices=['QDQ', 'QLinear'], - help='choose the format [QDQ/QLinear] of int8 ONNX model exported.') +parser.add_argument('--quant_format', default='QDQ', choices=['QDQ', 'QOperator'], + help='choose the format [QDQ/QOperator] of int8 ONNX model exported.') +parser.add_argument('--approach', default='static', choices=['static', 'dynamic'], + help='Post-Training Quantization method.') best_acc1 = 0 @@ -190,7 +192,7 @@ def eval_func(model): if args.export and args.export_dtype == 'int8': from neural_compressor import PostTrainingQuantConfig from neural_compressor import quantization - conf = PostTrainingQuantConfig() + conf = PostTrainingQuantConfig(approach=args.approach) q_model = quantization.fit(model, conf, calib_dataloader=val_loader, diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh index 366db7e850b..ebab5a27988 100644 --- a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh +++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh @@ -11,7 +11,7 @@ function main { # init params function init_params { dtype='fp32' - quant_format='QDQ' # or QLinear + quant_format='QDQ' # or QOperator tuned_checkpoint=saved_results for var in "$@" do @@ -31,6 +31,9 @@ function init_params { --quant_format=*) quant_format=$(echo $var |cut -f2 -d=) ;; + --approach=*) + approach=$(echo $var |cut -f2 -d=) + ;; esac done @@ -48,6 +51,7 @@ function run_tuning { --export \ --export_dtype ${dtype} \ --quant_format ${quant_format} \ + --approach ${approach} \ ${dataset_location} } diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md index ab65fa1c09f..aa0bc54a1db 100644 --- 
a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md +++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md @@ -45,7 +45,7 @@ Please pass in the name of dataset, supported datasets are 'mrpc', 'qqp', 'qnli' # export fp32 model bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=fp32 --output_model=bert-fp32.onnx # export int8 model -bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=int8 --quant_format=[QDQ/QLinear] --output_model=bert-int8.onnx +bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=int8 --quant_format=[QDQ/QOperator] --output_model=bert-int8.onnx --approach=[static|dynamic] ``` ### 2. Get the benchmark results of exported and tuned models, including Batch_size and Throughput: diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh index 02226b0aa83..62f6b0c7273 100644 --- a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh +++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh @@ -11,7 +11,7 @@ function main { # init params function init_params { dtype='fp32' - quant_format='QDQ' # or QLinear + quant_format='QDQ' # or QOperator for var in "$@" do case $var in @@ -30,6 +30,9 @@ function init_params { --quant_format=*) quant_format=$(echo $var |cut -f2 -d=) ;; + --approach=*) + approach=$(echo $var |cut -f2 -d=) + ;; esac done @@ -60,6 +63,7 @@ function run_tuning { --quant_format ${quant_format} \ --output_dir ${tuned_checkpoint} \ --overwrite_output_dir \ + --approach ${approach} \ ${extra_cmd} } diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py index 22c6ba12fbc..7e8a0805af4 100644 --- a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py +++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py @@ -190,7 +190,7 @@ class ModelArguments: default="fp32", metadata={"help": "choose the data type [fp32/int8] of PyTorch model to be exported."} ) quant_format: str = field( - default="QDQ", metadata={"help": "choose the format [QDQ/QLinear] of int8 ONNX model exported."} + default="QDQ", metadata={"help": "choose the format [QDQ/QOperator] of int8 ONNX model exported."} ) output_model: str = field( default="model.onnx", metadata={"help": "the name of exported model."} @@ -210,6 +210,12 @@ class ModelArguments: "help": "The inference iterations to run for benchmark." }, ) + approach: str = field( + default='static', + metadata={ + "help": "Post-Training Quantization method." 
+ }, + ) def main(): @@ -541,13 +547,18 @@ def eval_func(model): strategy_kwargs={"confidence_batches": 1}, max_trials=600, ) - conf = PostTrainingQuantConfig( - approach="static", - quant_level=1, - tuning_criterion=tuning_criterion, - op_type_dict={"Embedding":FP32}, - calibration_sampling_size=[300], - ) + if model_args.approach == "static": + conf = PostTrainingQuantConfig( + approach=model_args.approach, + quant_level=1, + tuning_criterion=tuning_criterion, + op_type_dict={"Embedding":FP32}, + calibration_sampling_size=[300], + ) + elif model_args.approach == "dynamic": + conf = PostTrainingQuantConfig( + approach=model_args.approach, + ) q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func) from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir) diff --git a/neural_compressor/experimental/export/torch2onnx.py b/neural_compressor/experimental/export/torch2onnx.py index 57461bd2083..99f0d2f656a 100644 --- a/neural_compressor/experimental/export/torch2onnx.py +++ b/neural_compressor/experimental/export/torch2onnx.py @@ -54,9 +54,233 @@ def _prepare_inputs(pt_model, input_names, example_inputs): example_inputs = input2tuple(example_inputs) return input_names, example_inputs +def get_node_mapping( + fp32_model, + fp32_onnx_path, +): + """Get PyTorch module and ONNX node mapping. + + Args: + fp32_model (torch.nn.Module): PyTorch FP32 model. + fp32_onnx_path (str): path to fp32 onnx model. + + Returns: + module_node_mapping: op mapping from PyTorch to ONNX. + """ + def check_data(op_type, data, module_dict): + for name, value in module_dict.items(): + if value.shape == data.shape: + if (value == data).all(): + module_dict.pop(name) + return name + elif op_type == 'Conv': + # Convolution weight data may fluctuate and BN fusion will insert a scale. + # We use the weight scale of the first output channel to check. + weight_scale = value[0] / data[0] + if np.allclose(weight_scale - np.mean(weight_scale), 0, atol=1.e-5): + module_dict.pop(name) + return name + return None + + module_dict = {} + for name, module in fp32_model.named_modules(): + if 'Conv' in str(module.__class__.__name__) or \ + 'Embedding' in str(module.__class__.__name__) or \ + 'Linear' in str(module.__class__.__name__): + if hasattr(module, 'weight'): + value = module.weight.detach().cpu().numpy() + module_dict[name] = value + + module_node_mapping = {} + fp32_onnx_model = onnx.load(fp32_onnx_path) + initializer_data = {tensor.name: tensor for tensor in fp32_onnx_model.graph.initializer} + from onnx import numpy_helper + for node in fp32_onnx_model.graph.node: + if node.op_type in op_types_to_quantize: + if node.op_type == 'MatMul' and node.input[1] in initializer_data: + data = numpy_helper.to_array(initializer_data[node.input[1]]).T + elif node.op_type == 'Gather' and node.input[0] in initializer_data: + data = numpy_helper.to_array(initializer_data[node.input[0]]) + elif node.op_type in ['Conv', 'Gemm']: + data = numpy_helper.to_array(initializer_data[node.input[1]]) + else: + continue + pt_name = check_data(node.op_type, data, module_dict) + if pt_name: + module_node_mapping[pt_name] = node.name + return module_node_mapping + +def get_quantizable_onnx_ops( + int8_model, + module_node_mapping +): + """Get quantizable onnx ops. + + Args: + int8_model (torch.nn.Module): PyTorch int8 model. + module_node_mapping (dict): op mapping from PyTorch to ONNX.
+ + Returns: + quantize_nodes: all ONNX nodes that should be quantized. + """ + quantize_nodes = [] + for name, module in int8_model.named_modules(): + if 'Conv' in str(module.__class__.__name__) or \ + 'Embedding' in str(module.__class__.__name__) or \ + 'Linear' in str(module.__class__.__name__): + if hasattr(module, 'weight') and callable(module.weight): + if module.weight().dtype in [torch.qint8, torch.quint8]: + if name.split('.module')[0] in module_node_mapping: + node = module_node_mapping[name.split('.module')[0]] + quantize_nodes.append(node) + return quantize_nodes + +def dynamic_quant_export( + pt_fp32_model, + pt_int8_model, + save_path, + example_inputs, + q_config, + opset_version, + dynamic_axes, + input_names, + output_names, + weight_type, +): + """Export dynamic quantized model. + + Args: + pt_fp32_model (torch.nn.module): PyTorch FP32 model. + pt_int8_model (torch.nn.module): PyTorch INT8 model. + save_path (str): save path of ONNX model. + example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + q_config (dict): containing quantization configuration. + opset_version (int, optional): opset version. Defaults to 14. + dynamic_axes (dict, optional): dynamic axes. Defaults to + {"input": {0: "batch_size"}, "output": {0: "batch_size"}}. + input_names (dict, optional): input names. Defaults to None. + output_names (dict, optional): output names. Defaults to None. + weight_type (str, optional): data types of weight of ONNX model + (only needed for exporting dynamic quantized model). Defaults to 'S8'. + """ + global op_types_to_quantize + op_types_to_quantize=['MatMul', 'Gemm', 'Gather'] + + # pylint: disable=E1101 + fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp' + torch_to_fp32_onnx( + pt_fp32_model, + fp32_onnx_path, + example_inputs, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + verbose=False, + ) + + module_node_mapping = get_node_mapping(pt_fp32_model, fp32_onnx_path) + quantize_nodes = get_quantizable_onnx_ops(pt_int8_model, module_node_mapping) + + REDUCE_RANGE = q_config['reduce_range'] + if REDUCE_RANGE: + logger.info("Reduce range is {}".format(str(REDUCE_RANGE))) + + logger.info("Quantization format is not available when executing dynamic quantization.") + + if weight_type.upper() == 'S8': + weight_type = ortq.QuantType.QInt8 + elif weight_type.upper() == 'U8': + weight_type = ortq.QuantType.QUInt8 + else: + assert False, "Right now, we don't support weight type: {}, " \ + "please use S8/U8.".format(weight_type) + + ortq.quantize_dynamic( + fp32_onnx_path, + save_path, + per_channel=True, + reduce_range=REDUCE_RANGE, + weight_type=weight_type, + nodes_to_quantize=quantize_nodes, + nodes_to_exclude=[], + extra_options={} + ) + + os.remove(fp32_onnx_path) + +def static_quant_export( + pt_int8_model, + save_path, + example_inputs, + q_config, + opset_version, + dynamic_axes, + input_names, + output_names, + quant_format, +): + """Export static quantized model. + + Args: + pt_int8_model (torch.nn.module): PyTorch INT8 model. + save_path (str): save path of ONNX model. + example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + q_config (dict): containing quantization configuration. + opset_version (int, optional): opset version. Defaults to 14. + dynamic_axes (dict, optional): dynamic axes. Defaults to + {"input": {0: "batch_size"}, "output": {0: "batch_size"}}. + input_names (dict, optional): input names. Defaults to None.
+ output_names (dict, optional): output names. Defaults to None. + quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'. + """ + input_names, example_inputs = _prepare_inputs(pt_int8_model, input_names, example_inputs) + + def model_wrapper(model_fn): + # export doesn't support a dictionary output, so manually turn it into a tuple + # refer to https://discuss.tvm.apache.org/t/how-to-deal-with-prim-dictconstruct/11978 + def wrapper(*args, **kwargs): + output = model_fn(*args, **kwargs) + if isinstance(output, dict): + return tuple(v for v in output.values() if v is not None) + else: + return output + return wrapper + pt_int8_model.forward = model_wrapper(pt_int8_model.forward) + with torch.no_grad(): + try: + torch.onnx.export( + pt_int8_model, + input2tuple(example_inputs), + save_path, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + except TypeError: + config_name = "QuantizationAwareTrainingConfig" \ + if q_config['approach'] == "quant_aware_training" else "PostTrainingQuantConfig" + logger.error("Export failed, possibly because of unsupported quantized ops. Check " + "neural-compressor/docs/source/export.md#supported-quantized-ops " + "for supported ops.") + logger.error("Please fallback unsupported quantized ops by setting 'op_type_dict' or " + "'op_name_dict' in '{}' config. ".format(config_name)) + exit(0) + except Exception as e: + # unexpected export failure: log the error and exit + logger.error(e) + exit(0) + + if quant_format != "QDQ": + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level=ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED + sess_options.optimized_model_filepath=save_path + ort.InferenceSession(save_path, sess_options) + def torch_to_fp32_onnx( - pt_model, + pt_fp32_model, save_path, example_inputs, opset_version=14, @@ -70,7 +294,7 @@ def torch_to_fp32_onnx( """Export FP32 PyTorch model into FP32 ONNX model. Args: - pt_model (torch.nn.module): PyTorch model. + pt_fp32_model (torch.nn.module): PyTorch FP32 model. save_path (str): save path of ONNX model. example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. opset_version (int, optional): opset version. Defaults to 14. @@ -82,14 +306,14 @@ def torch_to_fp32_onnx( verbose (bool, optional): dump verbose or not. Defaults to True. """ from neural_compressor.utils.pytorch import is_int8_model - assert is_int8_model(pt_model) == False, "The fp32 model is replaced during quantization. " + \ + assert is_int8_model(pt_fp32_model) == False, "The fp32 model is replaced during quantization. " + \ "please customize a eval_func when quantizing, if not, such as `lambda x: 1`." - input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs) + input_names, example_inputs = _prepare_inputs(pt_fp32_model, input_names, example_inputs) with torch.no_grad(): torch.onnx.export( - pt_model, + pt_fp32_model, example_inputs, save_path, opset_version=opset_version, @@ -107,7 +331,8 @@ def torch_to_fp32_onnx( def torch_to_int8_onnx( - pt_model, + pt_fp32_model, + pt_int8_model, save_path, example_inputs, q_config, @@ -117,12 +342,14 @@ def torch_to_int8_onnx( input_names=None, output_names=None, quant_format: str = 'QDQ', + weight_type: str = 'S8', verbose=True, ): """Export INT8 PyTorch model into INT8 ONNX model. Args: - pt_model (torch.nn.module): PyTorch model. + pt_fp32_model (torch.nn.module): PyTorch FP32 model. + pt_int8_model (torch.nn.module): PyTorch INT8 model.
save_path (str): save path of ONNX model. example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. q_config (dict): containing quantization configuration. @@ -132,61 +359,32 @@ def torch_to_int8_onnx( input_names (dict, optional): input names. Defaults to None. output_names (dict, optional): output names. Defaults to None. quant_format (str, optional): _quantization format of ONNX model. Defaults to 'QDQ'. + weight_type (str, optional): data types of weight of ONNX model + (only needed for exporting dynamic quantized model). Defaults to 'S8'. verbose (bool, optional): dump verbose or not. Defaults to True. """ from neural_compressor.utils.pytorch import is_int8_model - assert is_int8_model(pt_model), "The exported model is not INT8 model, "\ + assert is_int8_model(pt_int8_model), "The exported model is not INT8 model, "\ "please reset 'dtype' to 'FP32' or check your model." assert not q_config is None, "'q_config' is needed when export an INT8 model." - if q_config['approach'] == 'post_training_dynamic_quant': # pragma: no cover - assert False, "Post training dynamic quantized PyTorch model is not supported " \ - "to export to ONNX directly. Please follow this step to get a post training " \ - "dynamic quantized ONNX model: " \ - "1. export FP32 PyTorch model to FP32 ONNX model. " \ - "2. use FP32 ONNX model as the input model for post training dynamic quantization." - - input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs) - - def model_wrapper(model_fn): - # export doesn't support a dictionary output, so manually turn it into a tuple - # refer to https://discuss.tvm.apache.org/t/how-to-deal-with-prim-dictconstruct/11978 - def wrapper(*args, **kwargs): - output = model_fn(*args, **kwargs) - if isinstance(output, dict): - return tuple(v for v in output.values() if v is not None) - else: - return output - return wrapper - pt_model.forward = model_wrapper(pt_model.forward) + quant_format = quant_format.upper() + if quant_format == 'QDQ' and opset_version < 13: # pragma: no cover + opset_version = 13 + logger.warning("QDQ format requires opset_version >= 13, " + + "we reset opset_version={} here".format(opset_version)) - with torch.no_grad(): - try: - torch.onnx.export( - pt_model, - input2tuple(example_inputs), - save_path, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - ) - except Exception as e: - config_name = "QuantizationAwareTrainingConfig" \ - if q_config['approach'] == "quant_aware_training" else "PostTrainingQuantConfig" - logger.error("Export failed, possibly because unsupported quantized ops. Check " - "neural-compressor/docs/source/export.md#supported-quantized-ops " - "for supported ops.") - logger.error("Please fallback unsupported quantized ops by setting 'op_type_dict' or " - "'op_name_dict' in '{}' config. ".format(config_name)) - return - - if quant_format != "QDQ": - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level=ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - sess_options.optimized_model_filepath=save_path - ort.InferenceSession(save_path, sess_options) + if q_config['approach'] == 'post_training_dynamic_quant': + # dynamic quantization export follow these steps: + # "1. export FP32 PyTorch model to FP32 ONNX model. " + # "2. use FP32 ONNX model as the input model for post training dynamic quantization." 
+ # TODO: will be removed once torch supports dynamic quantization export + dynamic_quant_export(pt_fp32_model, pt_int8_model, save_path, example_inputs, q_config, + opset_version, dynamic_axes, input_names, output_names, weight_type) + else: + static_quant_export(pt_int8_model, save_path, example_inputs, q_config, opset_version, + dynamic_axes, input_names, output_names, quant_format) if verbose: info = "The INT8 ONNX Model exported to path: {0}".format(save_path) diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 30841775e35..58011a5466e 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -369,6 +369,7 @@ def export( if conf.dtype == 'int8': torch_to_int8_onnx( + self.fp32_model, self.model, save_path, conf.example_inputs, diff --git a/test/export/test_torch2onnx.py b/test/export/test_torch2onnx.py index c89ae77f784..9c2a623b1f2 100644 --- a/test/export/test_torch2onnx.py +++ b/test/export/test_torch2onnx.py @@ -159,14 +159,8 @@ def test_int8_CV_models(self): dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, ) - if fake_yaml == "dynamic": - try: - q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) - except Exception as e: - self.assertIsInstance(e, AssertionError) - elif fake_yaml == "static": - q_model.export('int8-cv-qdq-model.onnx', int8_onnx_config) - check_CV_onnx('int8-cv-qdq-model.onnx', self.cv_dataloader) + q_model.export('int8-cv-qdq-model.onnx', int8_onnx_config) + check_CV_onnx('int8-cv-qdq-model.onnx', self.cv_dataloader) int8_onnx_config = Torch2ONNXConfig( dtype="int8", @@ -178,14 +172,8 @@ def test_int8_CV_models(self): dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, ) - if fake_yaml == "dynamic": - try: - q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) - except Exception as e: - self.assertIsInstance(e, AssertionError) - elif fake_yaml == "static": - q_model.export('int8-cv-qlinear-model.onnx', int8_onnx_config) - check_CV_onnx('int8-cv-qlinear-model.onnx', self.cv_dataloader) + q_model.export('int8-cv-qlinear-model.onnx', int8_onnx_config) + check_CV_onnx('int8-cv-qlinear-model.onnx', self.cv_dataloader) def test_fp32_NLP_models(self): @@ -252,14 +240,8 @@ def test_int8_NLP_models(self): output_names=['labels'], dynamic_axes=dynamic_axes, ) - if fake_yaml == "dynamic": - try: - q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) - except Exception as e: - self.assertIsInstance(e, AssertionError) - elif fake_yaml == "static": - q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) - check_NLP_onnx('int8-nlp-qdq-model.onnx', self.nlp_input) + q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) + check_NLP_onnx('int8-nlp-qdq-model.onnx', self.nlp_input) int8_onnx_config = Torch2ONNXConfig( dtype="int8", @@ -270,16 +252,9 @@ def test_int8_NLP_models(self): output_names=['labels'], dynamic_axes=dynamic_axes, ) - if fake_yaml == "dynamic": - try: - q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config) - except Exception as e: - self.assertIsInstance(e, AssertionError) - elif fake_yaml == "static": - q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config) - check_NLP_onnx('int8-nlp-qlinear-model.onnx', self.nlp_input) + q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config) + check_NLP_onnx('int8-nlp-qlinear-model.onnx', self.nlp_input) if __name__ == "__main__": unittest.main() -
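For reviewers, below is a minimal usage sketch of the dynamic export path this patch enables, pieced together from the `Torch2ONNXConfig` example in `docs/source/export.md`, the updated `main.py`, and the unit tests. It is illustrative only and not part of the patch: the model choice, the output file name, and calling `quantization.fit(model, conf)` without a calibration dataloader (dynamic quantization needs no calibration) are assumptions.

```python
import torch
import torchvision.models as models
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.config import Torch2ONNXConfig

# Post-training dynamic quantization of a torchvision model (illustrative).
model = models.resnet18(pretrained=True)
conf = PostTrainingQuantConfig(approach="dynamic")
q_model = quantization.fit(model, conf)

# Export the dynamically quantized model to ONNX. quant_format is ignored on
# this path ("Quantization format is not available when executing dynamic
# quantization."); the export is handled by dynamic_quant_export() above.
int8_onnx_config = Torch2ONNXConfig(
    dtype="int8",
    opset_version=14,
    quant_format="QOperator",
    example_inputs=torch.randn(1, 3, 224, 224),
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
q_model.export("resnet18-int8-dynamic.onnx", int8_onnx_config)
```

Under the hood, `dynamic_quant_export()` first dumps the FP32 model to a temporary FP32 ONNX file, maps the quantized PyTorch modules to their ONNX `MatMul`/`Gemm`/`Gather` nodes, runs `ortq.quantize_dynamic()` on just those nodes with the requested `weight_type`, and then removes the temporary file.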