diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index bf58872555a..2480fe54c7b 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -621,6 +621,15 @@ "batch_size": 1, "new_benchmark": true }, + "unet": { + "model_src_dir": "image_recognition/unet/quantization/ptq", + "dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ILSVRC2012_img_val", + "input_model": "/tf_dataset2/models/onnx/unet/model.onnx", + "yaml": "unet.yaml", + "strategy": "basic", + "batch_size": 1, + "new_benchmark": true + }, "BiDAF": { "model_src_dir": "language_translation/onnx_model_zoo/BiDAF/quantization/ptq", "dataset_location": "/tf_dataset2/datasets/squad/dev-v1.1.json", diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/main.py b/examples/onnxrt/image_recognition/unet/quantization/ptq/main.py new file mode 100644 index 00000000000..a35a02fe418 --- /dev/null +++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/main.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+
+
+import logging
+import argparse
+
+import numpy as np
+import onnx
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.WARN)
+
+if __name__ == "__main__":
+    logger.info("Evaluating ONNXRuntime full precision accuracy and performance:")
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        help="path to the pre-trained unet ONNX model"
+    )
+    parser.add_argument(
+        '--benchmark',
+        action='store_true',
+        default=False
+    )
+    parser.add_argument(
+        '--tune',
+        action='store_true',
+        default=False,
+        help="whether to quantize the model"
+    )
+    parser.add_argument(
+        '--config',
+        type=str,
+        help="config yaml path"
+    )
+    parser.add_argument(
+        '--output_model',
+        type=str,
+        help="output model path"
+    )
+    parser.add_argument(
+        '--mode',
+        type=str,
+        default='performance',
+        help="benchmark mode of performance or accuracy"
+    )
+    args = parser.parse_args()
+    if args.benchmark:
+        from neural_compressor.experimental import Benchmark, common
+        evaluator = Benchmark(args.config)
+        evaluator.model = common.Model(args.model_path)
+        evaluator(args.mode)
+
+    if args.tune:
+        from neural_compressor.experimental import Quantization, common
+
+        quantize = Quantization(args.config)
+        quantize.model = common.Model(args.model_path)
+        q_model = quantize()
+        q_model.save(args.output_model)
+
diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/readme.md b/examples/onnxrt/image_recognition/unet/quantization/ptq/readme.md
new file mode 100644
index 00000000000..4c28538295a
--- /dev/null
+++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/readme.md
@@ -0,0 +1,31 @@
+# Evaluate performance of ONNX Runtime (unet)
+
+This is an experimental example to quantize the unet model. We use dummy data for quantization and evaluation, so the accuracy is not guaranteed.
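+
+The dummy calibration data mirrors the three UNet inputs (the latent sample, the timestep, and the text-encoder hidden states) with the shapes and dtypes declared in `unet.yaml`. For reference, the sketch below shows roughly what the dummy dataloader feeds the exported model; the input names are illustrative for the diffusers export, so check `session.get_inputs()` for the actual names:
+
+```python
+import numpy as np
+import onnxruntime as ort
+
+session = ort.InferenceSession("/workdir/output_path/unet/model.onnx",
+                               providers=["CPUExecutionProvider"])
+# Shapes and dtypes follow the dummy dataset defined in unet.yaml.
+feeds = {
+    "sample": np.random.rand(1, 4, 64, 64).astype(np.float32),
+    "timestep": np.array([1], dtype=np.int64),
+    "encoder_hidden_states": np.random.rand(1, 77, 768).astype(np.float32),
+}
+outputs = session.run(None, feeds)
+```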
+ +### Environment +onnx: 1.12.0 +onnxruntime: 1.12.1 + +### Prepare model + +```bash +git clone https://github.com/huggingface/diffusers.git +cd diffusers/scripts/ +python convert_stable_diffusion_checkpoint_to_onnx.py --model_path "CompVis/stable-diffusion-v1-4" --output_path /workdir/output_path +``` + +### Quantization + +```bash +bash run_tuning.sh --input_model=/workdir/output_path/unet/model.onnx \ + --config=unet.yaml \ + --output_model=path/to/save +``` + +### Benchmark + +```bash +bash run_benchmark.sh --input_model=/workdir/output_path/unet/model.onnx \ + --config=unet.yaml \ + --mode=performance +``` diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/requirements.txt b/examples/onnxrt/image_recognition/unet/quantization/ptq/requirements.txt new file mode 100644 index 00000000000..e204c53486c --- /dev/null +++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/requirements.txt @@ -0,0 +1,3 @@ +onnx==1.12.0 +onnxruntime==1.12.0 +onnxruntime-extensions; python_version < '3.10' diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/run_benchmark.sh b/examples/onnxrt/image_recognition/unet/quantization/ptq/run_benchmark.sh new file mode 100644 index 00000000000..2b7d99703d3 --- /dev/null +++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/run_benchmark.sh @@ -0,0 +1,41 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + + for var in "$@" + do + case $var in + --config=*) + config=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_benchmark +function run_benchmark { + + python main.py \ + --model_path ${input_model} \ + --config ${config} \ + --mode=${mode} \ + --benchmark + +} + +main "$@" diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/run_tuning.sh b/examples/onnxrt/image_recognition/unet/quantization/ptq/run_tuning.sh new file mode 100644 index 00000000000..97d06dab599 --- /dev/null +++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/run_tuning.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_tuning + +} + +# init params +function init_params { + + for var in "$@" + do + case $var in + --config=*) + config=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_tuning { + python main.py \ + --model_path ${input_model} \ + --output_model ${output_model} \ + --config ${config} \ + --tune +} + +main "$@" diff --git a/examples/onnxrt/image_recognition/unet/quantization/ptq/unet.yaml b/examples/onnxrt/image_recognition/unet/quantization/ptq/unet.yaml new file mode 100644 index 00000000000..f5d35489915 --- /dev/null +++ b/examples/onnxrt/image_recognition/unet/quantization/ptq/unet.yaml @@ -0,0 +1,57 @@ +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +version: 1.0 + +model: # mandatory. used to specify model specific information. + name: unet + framework: onnxrt_qlinearops # mandatory. supported values are tensorflow, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension. + +quantization: # optional. tuning constraints on model-wise for advance user to reduce tuning space. + approach: post_training_static_quant # optional. default value is post_training_static_quant. + calibration: + dataloader: + batch_size: 1 + dataset: + dummy: + shape: [[1, 4, 64, 64], [1], [1, 77, 768]] + dtype: ['float32', 'int64', 'float32'] + +evaluation: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. + accuracy: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. + dataloader: + batch_size: 1 + dataset: + dummy: + shape: [[1, 4, 64, 64], [1], [1, 77, 768]] + dtype: ['float32', 'int64', 'float32'] + + performance: # optional. used to benchmark performance of passing model. + warmup: 10 + iteration: 500 + configs: + cores_per_instance: 4 + num_of_instance: 7 + dataloader: + batch_size: 1 + dataset: + dummy: + shape: [[1, 4, 64, 64], [1], [1, 77, 768]] + dtype: ['float32', 'int64', 'float32'] + +tuning: + exit_policy: + performance_only: True + random_seed: 9527 # optional. random seed for deterministic tuning. diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 370936e0b4c..21406f7aaf2 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -142,7 +142,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): break tmp_iterations = int(math.ceil(calib_sampling_size / calib_batch_size)) data_loader.batch(calib_batch_size) - quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \ + quantize_params = self._get_quantize_params(tmp_model, data_loader, \ quantize_config, tmp_iterations) except Exception as e: # pragma: no cover if 'Got invalid dimensions for input' in str(e): @@ -153,7 +153,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): "Fail to forward with batch size={}, set to {} now.". format(batch_size, 1)) data_loader.batch(1) - quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \ + quantize_params = self._get_quantize_params(tmp_model, data_loader, \ quantize_config, calib_sampling_size) else: # pragma: no cover if hasattr(data_loader, 'batch_size') and \ @@ -164,13 +164,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): "So the real sampling size is {}.". 
format(calib_sampling_size, data_loader.batch_size, data_loader.batch_size * iterations)) - quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \ + quantize_params = self._get_quantize_params(tmp_model, data_loader, \ quantize_config, iterations) else: quantize_params = None self.quantize_params = quantize_params from neural_compressor.adaptor.ox_utils.quantizer import Quantizer - quantizer = Quantizer(tmp_model.model, + quantizer = Quantizer(copy.deepcopy(model), quantize_config, backend, self.static, @@ -459,15 +459,25 @@ def _pre_optimize(self, model, level=1): if sys.version_info < (3,10) and find_spec('onnxruntime_extensions'): # pragma: no cover from onnxruntime_extensions import get_library_path sess_options.register_custom_ops_library(get_library_path()) - _ = ort.InferenceSession(model.model.SerializeToString(), sess_options) - tmp_model = onnx.load(sess_options.optimized_model_filepath) + if not model.large_size: + ort.InferenceSession(model.model.SerializeToString(), sess_options) + elif model.model_path is not None: # pragma: no cover + ort.InferenceSession(model.model_path, sess_options) + else: # pragma: no cover + logger.warning('Please use model path instead of onnx model object to quantize') + + tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False) + if model.large_size: # pragma: no cover + from onnx.external_data_helper import load_external_data_for_model + load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0]) + model.model_path = sess_options.optimized_model_filepath model.model = self._replace_gemm_with_matmul(tmp_model).model \ if self.graph_optimization.gemm2matmul else tmp_model model.model = self._rename_node(model.model) model = self._revert_fusedconv(model) model = split_shared_bias(model) model.topological_sort() - self.pre_optimized_model = model + self.pre_optimized_model = copy.deepcopy(model) def _revert_fusedconv(self, model): from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg @@ -787,6 +797,13 @@ def evaluate(self, input_graph, dataloader, postprocess=None, Returns: (float) evaluation results. acc, f1 e.g. 
""" + if input_graph.large_size: # pragma: no cover + onnx.save_model(input_graph.model, + self.work_space + 'eval.onnx', + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False) sess_options = ort.SessionOptions() if measurer: # https://github.com/microsoft/onnxruntime/issues/7347 @@ -796,7 +813,9 @@ def evaluate(self, input_graph, dataloader, postprocess=None, if sys.version_info < (3,10) and find_spec('onnxruntime_extensions'): # pragma: no cover from onnxruntime_extensions import get_library_path sess_options.register_custom_ops_library(get_library_path()) - session = ort.InferenceSession(input_graph.model.SerializeToString(), sess_options) + session = ort.InferenceSession(self.work_space + 'eval.onnx', sess_options) if \ + input_graph.large_size else \ + ort.InferenceSession(input_graph.model.SerializeToString(), sess_options) results = [] if metrics: for metric in metrics: diff --git a/neural_compressor/adaptor/ox_utils/calibration.py b/neural_compressor/adaptor/ox_utils/calibration.py index fb1cbadb309..6b9b76077d0 100644 --- a/neural_compressor/adaptor/ox_utils/calibration.py +++ b/neural_compressor/adaptor/ox_utils/calibration.py @@ -137,10 +137,12 @@ def augment_graph(self, activation_only=False, weight_only=False): elif activation_only: tensors_to_dump.update(node.output) + model_inputs = [i.name for i in model.graph.input] for tensor in tensors_to_dump: + if tensor not in node_outputs and tensor not in initializers and \ + tensor not in model_inputs: + continue if self.augment_nodes: - if tensor not in node_outputs and tensor not in initializers: - continue for augment_node_type in self.augment_nodes: if augment_node_type in ['DequantizeLinear']: # insert DequantizeLinear node as output @@ -183,6 +185,13 @@ def augment_graph(self, activation_only=False, weight_only=False): model.graph.output.extend(added_outputs) # pylint: disable=no-member self.augmented_model = model + if self.model_wrapper.large_size: # pragma: no cover + onnx.save_model(model, + self.model_wrapper.model_path + '_augment.onnx', + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False) def get_intermediate_outputs(self, calib_mode=None): ''' @@ -196,7 +205,9 @@ def get_intermediate_outputs(self, calib_mode=None): from onnxruntime_extensions import get_library_path so.register_custom_ops_library(get_library_path()) - session = onnxruntime.InferenceSession(self.augmented_model.SerializeToString(), so) + session = onnxruntime.InferenceSession(self.augmented_model.SerializeToString(), so) if \ + not self.model_wrapper.large_size else \ + onnxruntime.InferenceSession(self.model_wrapper.model_path + '_augment.onnx', so) intermediate_outputs = [] len_inputs = len(session.get_inputs()) diff --git a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py index bcddb43aa3a..00522c178a1 100644 --- a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py +++ b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py @@ -31,7 +31,7 @@ def quantize_check(self): def quantize(self): node = self.node - self.quantizer.quantize_inputs(self.node, direct_int8=True) + self.quantizer.quantize_inputs(self.node, [0], direct_int8=True) if not self.disable_qdq_for_node_output or self.quantizer.mode != 'qdq': self.quantizer.quantize_outputs(self.node, direct_int8=True) node.name = node.name + "_quant" @@ -82,4 +82,4 @@ def cast(self): node = self.node 
if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]: return - self.quantizer.dtype_cast(self.node, self.dtype) \ No newline at end of file + self.quantizer.dtype_cast(self.node, self.dtype) diff --git a/neural_compressor/adaptor/ox_utils/quantizer.py b/neural_compressor/adaptor/ox_utils/quantizer.py index 829dff40045..65a5ec4d080 100644 --- a/neural_compressor/adaptor/ox_utils/quantizer.py +++ b/neural_compressor/adaptor/ox_utils/quantizer.py @@ -20,6 +20,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import os import onnx import logging import numpy as np @@ -45,8 +46,9 @@ class Quantizer: def __init__(self, model, q_config, mode, static, quantization_params, op_types_to_quantize, fallback_list=['fp32'], reduce_range=None): - model = onnx.shape_inference.infer_shapes(model) - self.model = ONNXModel(model) + self.model = ONNXModel(model) if not isinstance(model, ONNXModel) else model + model = onnx.shape_inference.infer_shapes(self.model.model) if \ + not self.model.large_size else self.model.model self.config = q_config self.reduce_range = reduce_range if reduce_range is not None \ else False if CpuInfo().vnni else True @@ -500,8 +502,7 @@ def quantize_inputs(self, node, indices=None, "In static mode quantization params for inputs and outputs \ of nodes to be quantized are required.".format(tensor_name)) if direct_int8: - parent = self.model.get_parents(node)[0] - if not parent.output[0].endswith('_QuantizeInput'): + if node.input[0] not in self.quantized_value_map: return q_input = tensor_name q_output = tensor_name + "_" + node.name + "_QuantizeLinear" if \ @@ -510,6 +511,8 @@ def quantize_inputs(self, node, indices=None, dq_output = tensor_name + "_" + node.name + "_dequantized" if \ tensor_name not in self.model.input() else tensor_name + "_dequantized" self.replace_input.append([node, tensor_name, dq_output]) + if tensor_name in self.model.input() and tensor_name in self.quantized_value_map: + continue quant_node_name = tensor_name + "_" + node.name + "_QuantizeLinear" dequant_node_name = tensor_name + "_" + node.name + "_DequantizeLinear" diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 6b6c9eea3e0..432187aad72 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -138,9 +138,12 @@ def _is_onnxruntime(model): ort.InferenceSession(model, so) else: ort.InferenceSession(model.SerializeToString(), so) - except: - logger.warning("If you use an onnx model with custom_ops to do quantiztaion, " - "please ensure onnxruntime-extensions is installed") + except Exception as e: # pragma: no cover + if 'Message onnx.ModelProto exceeds maximum protobuf size of 2GB' in str(e): + logger.warning('Please use model path instead of onnx model object to quantize') + else: + logger.warning("If you use an onnx model with custom_ops to do quantiztaion, " + "please ensure onnxruntime-extensions is installed") else: return 'onnxruntime' return 'NA' diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index f210d2e96d8..37baccf69cc 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -30,6 +30,16 @@ class ONNXModel(BaseModel): def __init__(self, model, **kwargs): self._model = model if not isinstance(model, str) else onnx.load(model) + self._model_path = None if not isinstance(model, str) 
else model + self._large_size = False + try: + ort.InferenceSession(self._model.SerializeToString()) + except Exception as e: # pragma: no cover + if 'Message onnx.ModelProto exceeds maximum protobuf size of 2GB' in str(e): + self._large_size = True + if self._model_path is None: + logger.warning('Please use model path instead of onnx model ' + 'object to quantize') self.node_name_counter = {} self._output_name_to_node = {} self._input_name_to_nodes = {} @@ -39,6 +49,18 @@ def __init__(self, model, **kwargs): self._get_graph_info() self._q_config = None + @property + def large_size(self): + return self._large_size + + @property + def model_path(self): + return self._model_path + + @model_path.setter + def model_path(self, path): + self._model_path = path + def framework(self): return 'onnxruntime' @@ -89,6 +111,14 @@ def _get_graph_info(self): def save(self, root): if os.path.split(root)[0] != '' and not os.path.exists(os.path.split(root)[0]): raise ValueError('"root" directory does not exists.') + if self.large_size: # pragma: no cover + from onnx.external_data_helper import convert_model_to_external_data, \ + load_external_data_for_model + load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) + convert_model_to_external_data(self._model, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False) onnx.save(self._model, root) def nodes(self): @@ -331,6 +361,7 @@ def remove_unused_constant(self): def topological_sort(self, enable_subgraph=False): from collections import deque from functools import reduce + import copy if not enable_subgraph: input_name_to_nodes = {} output_name_to_node = {} @@ -347,26 +378,32 @@ def topological_sort(self, enable_subgraph=False): output_name_to_node = self._output_name_to_node all_nodes = {} - q = deque([output_name_to_node[i.name] for i in self.model.graph.output]) + q = deque() + wait = deque() + for inp in self.model.graph.input: + q.extend(input_name_to_nodes[inp.name]) + for n in self.model.graph.node: + if all([i not in output_name_to_node and i not in self.input() for i in n.input]): + q.append(n) + while q: n = q.popleft() - flag = True - for out in n.output: - if out in input_name_to_nodes and \ - any([i.name not in all_nodes for i in input_name_to_nodes[out]]): - q.extend(filter(lambda x:x.name not in all_nodes and x not in q, \ - input_name_to_nodes[out])) - flag = False - if not flag: - q.append(n) - if flag and n.name not in all_nodes: - all_nodes[n.name] = n - for inp in n.input: - if inp in output_name_to_node: - q.append(output_name_to_node[inp]) + if not all([output_name_to_node[i].name in all_nodes for \ + i in n.input if i in output_name_to_node]): + if n not in wait: + wait.append(n) + continue + all_nodes[n.name] = n + for out in n.output: + if out in input_name_to_nodes: + q.extend([i for i in input_name_to_nodes[out] if \ + i.name not in all_nodes and i not in q]) + if len(q) == 0 and len(wait) != 0: + q = copy.deepcopy(wait) + wait.clear() nodes = [i[1] for i in all_nodes.items()] - nodes.reverse() + assert len(nodes) == len(self.model.graph.node) self.model.graph.ClearField('node') self.model.graph.node.extend(nodes)
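
Note: several of the adaptor and model changes above exist to handle ONNX models that exceed the 2 GB protobuf limit (such as the stable-diffusion UNet). In that case the model is expected to be passed by file path so the external weight data can be located. A minimal usage sketch under that assumption, with illustrative paths taken from the readme above:

```python
from neural_compressor.experimental import Quantization, common

quantizer = Quantization("unet.yaml")
# Pass the file path rather than a loaded onnx.ModelProto: models over 2 GB
# cannot be serialized into a single protobuf, so the adaptor resolves the
# external weight data relative to this path (see ONNXModel.large_size above).
quantizer.model = common.Model("/workdir/output_path/unet/model.onnx")
q_model = quantizer()
# The output directory must already exist; for large models the quantized
# weights are written as external data (weights.pb) next to the output file.
q_model.save("/workdir/output_path/unet-int8/model.onnx")
```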