diff --git a/docs/source/export.md b/docs/source/export.md
index 6dd1fc35c2f..5989286975e 100644
--- a/docs/source/export.md
+++ b/docs/source/export.md
@@ -37,15 +37,15 @@ Here is the workflow of our export API for PyTorch/Tensorflow FP32/INT8 model.
Post-Training Static Quantized INT8 |
- QLinear/QDQ INT8 |
+ QOperator/QDQ INT8 |
Post-Training Dynamic Quantized INT8 |
- / |
+ QOperator INT8 |
Quantization-aware Training INT8 |
- QLinear/QDQ INT8 |
+ QOperator/QDQ INT8 |
TensorFlow |
@@ -63,10 +63,6 @@ Here is the workflow of our export API for PyTorch/Tensorflow FP32/INT8 model.
-> **Note**: Follow this step to export a post training dynamic quantized ONNX model from PyTorch model: \
- 1. export FP32 PyTorch model to FP32 ONNX model. \
- 2. use FP32 ONNX model as the input model for post training dynamic quantization.
-
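+> **Note**: A post-training dynamic quantized PyTorch model is exported to the QOperator format; the `quant_format` setting does not apply in this case.
+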
## Examples
### PyTorch Model
@@ -96,7 +92,7 @@ from neural_compressor.config import Torch2ONNXConfig
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
opset_version=14,
- quant_format="QLinear", # or QDQ
+ quant_format="QOperator", # or QDQ
example_inputs=torch.randn(1, 3, 224, 224),
input_names=['input'],
output_names=['output'],
diff --git a/examples/.config/model_params_pt2onnx.json b/examples/.config/model_params_pt2onnx.json
index c383a53e2a3..9f7cd7f7112 100644
--- a/examples/.config/model_params_pt2onnx.json
+++ b/examples/.config/model_params_pt2onnx.json
@@ -8,6 +8,14 @@
"main_script": "main.py",
"batch_size": 100
},
+ "resnet18_dynamic": {
+ "model_src_dir": "image_recognition/torchvision_models/export/fx",
+ "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw",
+ "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000",
+ "input_model": "resnet18",
+ "main_script": "main.py",
+ "batch_size": 100
+ },
"resnet50": {
"model_src_dir": "image_recognition/torchvision_models/export/fx",
"source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw",
@@ -16,6 +24,14 @@
"main_script": "main.py",
"batch_size": 100
},
+ "resnet50_dynamic": {
+ "model_src_dir": "image_recognition/torchvision_models/export/fx",
+ "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw",
+ "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000",
+ "input_model": "resnet50",
+ "main_script": "main.py",
+ "batch_size": 100
+ },
"bert_base_MRPC": {
"model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
"source_model_dataset": "mrpc",
@@ -24,6 +40,14 @@
"main_script": "run_glue.py",
"batch_size": 64
},
+ "bert_base_MRPC_dynamic": {
+ "model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
+ "source_model_dataset": "mrpc",
+ "target_model_dataset": "mrpc",
+ "input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_MRPC_output",
+ "main_script": "run_glue.py",
+ "batch_size": 64
+ },
"bert_large_MRPC": {
"model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
"source_model_dataset": "mrpc",
@@ -31,6 +55,14 @@
"input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output",
"main_script": "run_glue.py",
"batch_size": 64
+ },
+ "bert_large_MRPC_dynamic": {
+ "model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
+ "source_model_dataset": "mrpc",
+ "target_model_dataset": "mrpc",
+ "input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output",
+ "main_script": "run_glue.py",
+ "batch_size": 64
}
}
}
\ No newline at end of file
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md
index e2e9b22b658..27b2557f8ce 100644
--- a/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md
@@ -34,7 +34,7 @@ Run run_export.sh to get ONNX model from PyTorch model.
# export fp32 model
bash run_export.sh --input_model=resnet50 --dtype=fp32 --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-fp32.onnx
# export int8 model
-bash run_export.sh --input_model=resnet50 --dtype=int8 --quant_format=[QDQ|QLinear] --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-int8.onnx
+bash run_export.sh --input_model=resnet50 --dtype=int8 --quant_format=[QDQ|QOperator] --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-int8.onnx --approach=[static|dynamic]
```
### 2. To get the benchmark of exported and tuned models, includes Batch_size and Throughput:
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py
index b9fc4016792..ace68d6fde7 100644
--- a/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py
@@ -90,8 +90,10 @@
parser.add_argument('--export', dest='export', action='store_true', help='run export')
parser.add_argument('--export_dtype', default='fp32', choices=['fp32', 'int8'],
help='choose the data type [fp32/int8] of PyTorch model to be exported.')
-parser.add_argument('--quant_format', default='QDQ', choices=['QDQ', 'QLinear'],
- help='choose the format [QDQ/QLinear] of int8 ONNX model exported.')
+parser.add_argument('--quant_format', default='QDQ', choices=['QDQ', 'QOperator'],
+ help='choose the format [QDQ/QOperator] of int8 ONNX model exported.')
+parser.add_argument('--approach', default='static', choices=['static', 'dynamic'],
+ help='Post-Training Quantization method.')
best_acc1 = 0
@@ -190,7 +192,7 @@ def eval_func(model):
if args.export and args.export_dtype == 'int8':
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization
- conf = PostTrainingQuantConfig()
+ conf = PostTrainingQuantConfig(approach=args.approach)
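+        # Dynamic quantization computes activation scales at run time, so the dataloader below is
+        # only used for calibration when approach='static'.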
q_model = quantization.fit(model,
conf,
calib_dataloader=val_loader,
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh
index 366db7e850b..ebab5a27988 100644
--- a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh
@@ -11,7 +11,7 @@ function main {
# init params
function init_params {
dtype='fp32'
- quant_format='QDQ' # or QLinear
+  quant_format='QDQ' # or QOperator
+  approach='static' # default PTQ approach; override with --approach=dynamic
tuned_checkpoint=saved_results
for var in "$@"
do
@@ -31,6 +31,9 @@ function init_params {
--quant_format=*)
quant_format=$(echo $var |cut -f2 -d=)
;;
+ --approach=*)
+ approach=$(echo $var |cut -f2 -d=)
+ ;;
esac
done
@@ -48,6 +51,7 @@ function run_tuning {
--export \
--export_dtype ${dtype} \
--quant_format ${quant_format} \
+ --approach ${approach} \
${dataset_location}
}
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md
index ab65fa1c09f..aa0bc54a1db 100644
--- a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md
@@ -45,7 +45,7 @@ Please pass in the name of dataset, supported datasets are 'mrpc', 'qqp', 'qnli'
# export fp32 model
bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=fp32 --output_model=bert-fp32.onnx
# export int8 model
-bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=int8 --quant_format=[QDQ/QLinear] --output_model=bert-int8.onnx
+bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=int8 --quant_format=[QDQ/QOperator] --output_model=bert-int8.onnx --approach=[static|dynamic]
```
### 2. Get the benchmark results of exported and tuned models, including Batch_size and Throughput:
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh
index 02226b0aa83..62f6b0c7273 100644
--- a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh
@@ -11,7 +11,7 @@ function main {
# init params
function init_params {
dtype='fp32'
- quant_format='QDQ' # or QLinear
+  quant_format='QDQ' # or QOperator
+  approach='static' # default PTQ approach; override with --approach=dynamic
for var in "$@"
do
case $var in
@@ -30,6 +30,9 @@ function init_params {
--quant_format=*)
quant_format=$(echo $var |cut -f2 -d=)
;;
+ --approach=*)
+ approach=$(echo $var |cut -f2 -d=)
+ ;;
esac
done
@@ -60,6 +63,7 @@ function run_tuning {
--quant_format ${quant_format} \
--output_dir ${tuned_checkpoint} \
--overwrite_output_dir \
+ --approach ${approach} \
${extra_cmd}
}
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py
index 22c6ba12fbc..7e8a0805af4 100644
--- a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py
@@ -190,7 +190,7 @@ class ModelArguments:
default="fp32", metadata={"help": "choose the data type [fp32/int8] of PyTorch model to be exported."}
)
quant_format: str = field(
- default="QDQ", metadata={"help": "choose the format [QDQ/QLinear] of int8 ONNX model exported."}
+ default="QDQ", metadata={"help": "choose the format [QDQ/QOperator] of int8 ONNX model exported."}
)
output_model: str = field(
default="model.onnx", metadata={"help": "the name of exported model."}
@@ -210,6 +210,12 @@ class ModelArguments:
"help": "The inference iterations to run for benchmark."
},
)
+ approach: str = field(
+ default='static',
+ metadata={
+ "help": "Post-Training Quantization method."
+ },
+ )
def main():
@@ -541,13 +547,18 @@ def eval_func(model):
strategy_kwargs={"confidence_batches": 1},
max_trials=600,
)
- conf = PostTrainingQuantConfig(
- approach="static",
- quant_level=1,
- tuning_criterion=tuning_criterion,
- op_type_dict={"Embedding":FP32},
- calibration_sampling_size=[300],
- )
+ if model_args.approach == "static":
+ conf = PostTrainingQuantConfig(
+ approach=model_args.approach,
+ quant_level=1,
+ tuning_criterion=tuning_criterion,
+ op_type_dict={"Embedding":FP32},
+ calibration_sampling_size=[300],
+ )
+ elif model_args.approach == "dynamic":
+ conf = PostTrainingQuantConfig(
+ approach=model_args.approach,
+ )
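+            # Dynamic quantization determines activation ranges at run time, so no calibration or
+            # tuning-specific options are set here.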
q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func)
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)
diff --git a/neural_compressor/experimental/export/torch2onnx.py b/neural_compressor/experimental/export/torch2onnx.py
index 57461bd2083..99f0d2f656a 100644
--- a/neural_compressor/experimental/export/torch2onnx.py
+++ b/neural_compressor/experimental/export/torch2onnx.py
@@ -54,9 +54,233 @@ def _prepare_inputs(pt_model, input_names, example_inputs):
example_inputs = input2tuple(example_inputs)
return input_names, example_inputs
+def get_node_mapping(
+ fp32_model,
+ fp32_onnx_path,
+):
+ """Get PyTorch module and ONNX node mapping.
+
+ Args:
+        fp32_model (torch.nn.Module): PyTorch FP32 model.
+ fp32_onnx_path (str): path to fp32 onnx model.
+
+ Returns:
+ module_node_mapping: op mapping from PyTorch to ONNX.
+ """
+ def check_data(op_type, data, module_dict):
+ for name, value in module_dict.items():
+ if value.shape == data.shape:
+ if (value == data).all():
+ module_dict.pop(name)
+ return name
+                elif op_type == 'Conv':
+                    # BN fusion folds a per-channel scale into the convolution weights, so values
+                    # no longer match exactly. Instead, check that the element-wise ratio over the
+                    # first output channel is (approximately) constant.
+                    weight_scale = value[0] / data[0]
+                    if np.allclose(weight_scale - np.mean(weight_scale), 0, atol=1.e-5):
+                        module_dict.pop(name)
+                        return name
+ return None
+
+ module_dict = {}
+ for name, module in fp32_model.named_modules():
+ if 'Conv' in str(module.__class__.__name__) or \
+ 'Embedding' in str(module.__class__.__name__) or \
+ 'Linear' in str(module.__class__.__name__):
+ if hasattr(module, 'weight'):
+ value = module.weight.detach().cpu().numpy()
+ module_dict[name] = value
+
+ module_node_mapping = {}
+ fp32_onnx_model = onnx.load(fp32_onnx_path)
+ initializer_data = {tensor.name: tensor for tensor in fp32_onnx_model.graph.initializer}
+ from onnx import numpy_helper
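+    # Match each quantizable ONNX node's weight initializer back to a PyTorch module by value.
+    # MatMul initializers are transposed because torch.nn.Linear stores its weight as
+    # [out_features, in_features].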
+ for node in fp32_onnx_model.graph.node:
+ if node.op_type in op_types_to_quantize:
+ if node.op_type == 'MatMul' and node.input[1] in initializer_data:
+ data = numpy_helper.to_array(initializer_data[node.input[1]]).T
+ elif node.op_type == 'Gather' and node.input[0] in initializer_data:
+ data = numpy_helper.to_array(initializer_data[node.input[0]])
+ elif node.op_type in ['Conv', 'Gemm']:
+ data = numpy_helper.to_array(initializer_data[node.input[1]])
+ else:
+ continue
+ pt_name = check_data(node.op_type, data, module_dict)
+ if pt_name:
+ module_node_mapping[pt_name] = node.name
+ return module_node_mapping
+
+def get_quantizable_onnx_ops(
+ int8_model,
+ module_node_mapping
+):
+ """Get quantizable onnx ops.
+
+ Args:
+ int8_model (torch.nn.Module): PyTorch int8 model.
+ module_node_mapping (dict): op mapping from PyTorch to ONNX.
+
+ Returns:
+        quantize_nodes: all ONNX nodes that should be quantized.
+ """
+ quantize_nodes = []
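+    # Only modules whose weights were actually quantized (qint8/quint8) are mapped to ONNX nodes;
+    # a trailing '.module' (present when modules are wrapped) is stripped so names match the FP32 model.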
+ for name, module in int8_model.named_modules():
+ if 'Conv' in str(module.__class__.__name__) or \
+ 'Embedding' in str(module.__class__.__name__) or \
+ 'Linear' in str(module.__class__.__name__):
+ if hasattr(module, 'weight') and callable(module.weight):
+ if module.weight().dtype in [torch.qint8, torch.quint8]:
+ if name.split('.module')[0] in module_node_mapping:
+ node = module_node_mapping[name.split('.module')[0]]
+ quantize_nodes.append(node)
+ return quantize_nodes
+
+def dynamic_quant_export(
+ pt_fp32_model,
+ pt_int8_model,
+ save_path,
+ example_inputs,
+ q_config,
+ opset_version,
+ dynamic_axes,
+ input_names,
+ output_names,
+ weight_type,
+):
+ """Export dynamic quantized model.
+
+ Args:
+ pt_fp32_model (torch.nn.module): PyTorch FP32 model.
+ pt_int8_model (torch.nn.module): PyTorch INT8 model.
+ save_path (str): save path of ONNX model.
+ example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
+ q_config (dict): containing quantization configuration.
+ opset_version (int, optional): opset version. Defaults to 14.
+ dynamic_axes (dict, optional): dynamic axes. Defaults to
+ {"input": {0: "batch_size"}, "output": {0: "batch_size"}}.
+ input_names (dict, optional): input names. Defaults to None.
+ output_names (dict, optional): output names. Defaults to None.
+        weight_type (str, optional): data type of the ONNX model's weights
+            (only needed when exporting a dynamic quantized model). Defaults to 'S8'.
+ """
+ global op_types_to_quantize
+ op_types_to_quantize=['MatMul', 'Gemm', 'Gather']
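+    # Overall flow: export a temporary FP32 ONNX model, map its nodes back to the dynamically
+    # quantized PyTorch modules, let ONNX Runtime quantize only those nodes, then remove the temp file.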
+
+ # pylint: disable=E1101
+ fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
+ torch_to_fp32_onnx(
+ pt_fp32_model,
+ fp32_onnx_path,
+ example_inputs,
+ opset_version=opset_version,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ verbose=False,
+ )
+
+ module_node_mapping = get_node_mapping(pt_fp32_model, fp32_onnx_path)
+ quantize_nodes = get_quantizable_onnx_ops(pt_int8_model, module_node_mapping)
+
+ REDUCE_RANGE = q_config['reduce_range']
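+    # reduce_range (7-bit weight quantization) is commonly used to avoid saturation on CPUs without VNNI.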
+ if REDUCE_RANGE:
+ logger.info("Reduce range is {}".format(str(REDUCE_RANGE)))
+
+    logger.info("The 'quant_format' option is ignored for dynamic quantization.")
+
+ if weight_type.upper() == 'S8':
+ weight_type = ortq.QuantType.QInt8
+ elif weight_type.upper() == 'U8':
+ weight_type = ortq.QuantType.QUInt8
+ else:
+ assert False, "Right now, we don't support weight type: {}, " \
+ "please use S8/U8.".format(weight_type)
+
+ ortq.quantize_dynamic(
+ fp32_onnx_path,
+ save_path,
+ per_channel=True,
+ reduce_range=REDUCE_RANGE,
+ weight_type=weight_type,
+ nodes_to_quantize=quantize_nodes,
+ nodes_to_exclude=[],
+ extra_options={}
+ )
+
+ os.remove(fp32_onnx_path)
+
+def static_quant_export(
+ pt_int8_model,
+ save_path,
+ example_inputs,
+ q_config,
+ opset_version,
+ dynamic_axes,
+ input_names,
+ output_names,
+ quant_format,
+):
+ """Export static quantized model.
+
+ Args:
+ pt_int8_model (torch.nn.module): PyTorch INT8 model.
+ save_path (str): save path of ONNX model.
+ example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
+ q_config (dict): containing quantization configuration.
+ opset_version (int, optional): opset version. Defaults to 14.
+ dynamic_axes (dict, optional): dynamic axes. Defaults to
+ {"input": {0: "batch_size"}, "output": {0: "batch_size"}}.
+ input_names (dict, optional): input names. Defaults to None.
+ output_names (dict, optional): output names. Defaults to None.
+        quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
+ """
+ input_names, example_inputs = _prepare_inputs(pt_int8_model, input_names, example_inputs)
+
+ def model_wrapper(model_fn):
+ # export doesn't support a dictionary output, so manually turn it into a tuple
+ # refer to https://discuss.tvm.apache.org/t/how-to-deal-with-prim-dictconstruct/11978
+ def wrapper(*args, **kwargs):
+ output = model_fn(*args, **kwargs)
+ if isinstance(output, dict):
+ return tuple(v for v in output.values() if v is not None)
+ else:
+ return output
+ return wrapper
+ pt_int8_model.forward = model_wrapper(pt_int8_model.forward)
+ with torch.no_grad():
+ try:
+ torch.onnx.export(
+ pt_int8_model,
+ input2tuple(example_inputs),
+ save_path,
+ opset_version=opset_version,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ )
+ except TypeError:
+ config_name = "QuantizationAwareTrainingConfig" \
+ if q_config['approach'] == "quant_aware_training" else "PostTrainingQuantConfig"
+            logger.error("Export failed, possibly because of unsupported quantized ops. Check "
+                "neural-compressor/docs/source/export.md#supported-quantized-ops "
+                "for supported ops.")
+            logger.error("Please fall back unsupported quantized ops by setting 'op_type_dict' or "
+                "'op_name_dict' in the '{}' config.".format(config_name))
+ exit(0)
+ except Exception as e:
+ logger.error(e)
+ exit(0)
+
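+    # torch.onnx.export emits QDQ (QuantizeLinear/DequantizeLinear) nodes. For the QOperator format,
+    # ONNX Runtime's extended graph optimizations are relied on to fuse QDQ pairs into QLinear* ops,
+    # overwriting the model at save_path with the optimized graph.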
+ if quant_format != "QDQ":
+ sess_options = ort.SessionOptions()
+ sess_options.graph_optimization_level=ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
+ sess_options.optimized_model_filepath=save_path
+ ort.InferenceSession(save_path, sess_options)
+
def torch_to_fp32_onnx(
- pt_model,
+ pt_fp32_model,
save_path,
example_inputs,
opset_version=14,
@@ -70,7 +294,7 @@ def torch_to_fp32_onnx(
"""Export FP32 PyTorch model into FP32 ONNX model.
Args:
- pt_model (torch.nn.module): PyTorch model.
+ pt_fp32_model (torch.nn.module): PyTorch FP32 model.
save_path (str): save path of ONNX model.
example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
opset_version (int, optional): opset version. Defaults to 14.
@@ -82,14 +306,14 @@ def torch_to_fp32_onnx(
verbose (bool, optional): dump verbose or not. Defaults to True.
"""
from neural_compressor.utils.pytorch import is_int8_model
- assert is_int8_model(pt_model) == False, "The fp32 model is replaced during quantization. " + \
+ assert is_int8_model(pt_fp32_model) == False, "The fp32 model is replaced during quantization. " + \
"please customize a eval_func when quantizing, if not, such as `lambda x: 1`."
- input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs)
+ input_names, example_inputs = _prepare_inputs(pt_fp32_model, input_names, example_inputs)
with torch.no_grad():
torch.onnx.export(
- pt_model,
+ pt_fp32_model,
example_inputs,
save_path,
opset_version=opset_version,
@@ -107,7 +331,8 @@ def torch_to_fp32_onnx(
def torch_to_int8_onnx(
- pt_model,
+ pt_fp32_model,
+ pt_int8_model,
save_path,
example_inputs,
q_config,
@@ -117,12 +342,14 @@ def torch_to_int8_onnx(
input_names=None,
output_names=None,
quant_format: str = 'QDQ',
+ weight_type: str = 'S8',
verbose=True,
):
"""Export INT8 PyTorch model into INT8 ONNX model.
Args:
- pt_model (torch.nn.module): PyTorch model.
+ pt_fp32_model (torch.nn.module): PyTorch FP32 model.
+ pt_int8_model (torch.nn.module): PyTorch INT8 model.
save_path (str): save path of ONNX model.
example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
q_config (dict): containing quantization configuration.
@@ -132,61 +359,32 @@ def torch_to_int8_onnx(
input_names (dict, optional): input names. Defaults to None.
output_names (dict, optional): output names. Defaults to None.
quant_format (str, optional): _quantization format of ONNX model. Defaults to 'QDQ'.
+        weight_type (str, optional): data type of the ONNX model's weights
+            (only needed when exporting a dynamic quantized model). Defaults to 'S8'.
verbose (bool, optional): dump verbose or not. Defaults to True.
"""
from neural_compressor.utils.pytorch import is_int8_model
- assert is_int8_model(pt_model), "The exported model is not INT8 model, "\
+    assert is_int8_model(pt_int8_model), "The exported model is not an INT8 model, "\
"please reset 'dtype' to 'FP32' or check your model."
assert not q_config is None, "'q_config' is needed when export an INT8 model."
- if q_config['approach'] == 'post_training_dynamic_quant': # pragma: no cover
- assert False, "Post training dynamic quantized PyTorch model is not supported " \
- "to export to ONNX directly. Please follow this step to get a post training " \
- "dynamic quantized ONNX model: " \
- "1. export FP32 PyTorch model to FP32 ONNX model. " \
- "2. use FP32 ONNX model as the input model for post training dynamic quantization."
-
- input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs)
-
- def model_wrapper(model_fn):
- # export doesn't support a dictionary output, so manually turn it into a tuple
- # refer to https://discuss.tvm.apache.org/t/how-to-deal-with-prim-dictconstruct/11978
- def wrapper(*args, **kwargs):
- output = model_fn(*args, **kwargs)
- if isinstance(output, dict):
- return tuple(v for v in output.values() if v is not None)
- else:
- return output
- return wrapper
- pt_model.forward = model_wrapper(pt_model.forward)
+ quant_format = quant_format.upper()
+ if quant_format == 'QDQ' and opset_version < 13: # pragma: no cover
+ opset_version = 13
+ logger.warning("QDQ format requires opset_version >= 13, " +
+ "we reset opset_version={} here".format(opset_version))
- with torch.no_grad():
- try:
- torch.onnx.export(
- pt_model,
- input2tuple(example_inputs),
- save_path,
- opset_version=opset_version,
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic_axes,
- )
- except Exception as e:
- config_name = "QuantizationAwareTrainingConfig" \
- if q_config['approach'] == "quant_aware_training" else "PostTrainingQuantConfig"
- logger.error("Export failed, possibly because unsupported quantized ops. Check "
- "neural-compressor/docs/source/export.md#supported-quantized-ops "
- "for supported ops.")
- logger.error("Please fallback unsupported quantized ops by setting 'op_type_dict' or "
- "'op_name_dict' in '{}' config. ".format(config_name))
- return
-
- if quant_format != "QDQ":
- sess_options = ort.SessionOptions()
- sess_options.graph_optimization_level=ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
- sess_options.optimized_model_filepath=save_path
- ort.InferenceSession(save_path, sess_options)
+ if q_config['approach'] == 'post_training_dynamic_quant':
+        # Dynamic quantization export follows two steps:
+        #   1. export the FP32 PyTorch model to an FP32 ONNX model;
+        #   2. run ONNX Runtime post-training dynamic quantization on that FP32 ONNX model.
+        # TODO: remove this path once torch supports exporting dynamically quantized models directly.
+ dynamic_quant_export(pt_fp32_model, pt_int8_model, save_path, example_inputs, q_config,
+ opset_version, dynamic_axes, input_names, output_names, weight_type)
+ else:
+ static_quant_export(pt_int8_model, save_path, example_inputs, q_config, opset_version,
+ dynamic_axes, input_names, output_names, quant_format)
if verbose:
info = "The INT8 ONNX Model exported to path: {0}".format(save_path)
diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index 30841775e35..58011a5466e 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -369,6 +369,7 @@ def export(
if conf.dtype == 'int8':
torch_to_int8_onnx(
+ self.fp32_model,
self.model,
save_path,
conf.example_inputs,
diff --git a/test/export/test_torch2onnx.py b/test/export/test_torch2onnx.py
index c89ae77f784..9c2a623b1f2 100644
--- a/test/export/test_torch2onnx.py
+++ b/test/export/test_torch2onnx.py
@@ -159,14 +159,8 @@ def test_int8_CV_models(self):
dynamic_axes={"input": {0: "batch_size"},
"output": {0: "batch_size"}},
)
- if fake_yaml == "dynamic":
- try:
- q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
- except Exception as e:
- self.assertIsInstance(e, AssertionError)
- elif fake_yaml == "static":
- q_model.export('int8-cv-qdq-model.onnx', int8_onnx_config)
- check_CV_onnx('int8-cv-qdq-model.onnx', self.cv_dataloader)
+ q_model.export('int8-cv-qdq-model.onnx', int8_onnx_config)
+ check_CV_onnx('int8-cv-qdq-model.onnx', self.cv_dataloader)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
@@ -178,14 +172,8 @@ def test_int8_CV_models(self):
dynamic_axes={"input": {0: "batch_size"},
"output": {0: "batch_size"}},
)
- if fake_yaml == "dynamic":
- try:
- q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
- except Exception as e:
- self.assertIsInstance(e, AssertionError)
- elif fake_yaml == "static":
- q_model.export('int8-cv-qlinear-model.onnx', int8_onnx_config)
- check_CV_onnx('int8-cv-qlinear-model.onnx', self.cv_dataloader)
+ q_model.export('int8-cv-qlinear-model.onnx', int8_onnx_config)
+ check_CV_onnx('int8-cv-qlinear-model.onnx', self.cv_dataloader)
def test_fp32_NLP_models(self):
@@ -252,14 +240,8 @@ def test_int8_NLP_models(self):
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
- if fake_yaml == "dynamic":
- try:
- q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
- except Exception as e:
- self.assertIsInstance(e, AssertionError)
- elif fake_yaml == "static":
- q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
- check_NLP_onnx('int8-nlp-qdq-model.onnx', self.nlp_input)
+ q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
+ check_NLP_onnx('int8-nlp-qdq-model.onnx', self.nlp_input)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
@@ -270,16 +252,9 @@ def test_int8_NLP_models(self):
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
- if fake_yaml == "dynamic":
- try:
- q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
- except Exception as e:
- self.assertIsInstance(e, AssertionError)
- elif fake_yaml == "static":
- q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
- check_NLP_onnx('int8-nlp-qlinear-model.onnx', self.nlp_input)
+ q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
+ check_NLP_onnx('int8-nlp-qlinear-model.onnx', self.nlp_input)
if __name__ == "__main__":
unittest.main()
-