Fix onnxrt backend recover function (#1788)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored May 17, 2024
1 parent b6237cf commit ee24dba
Showing 3 changed files with 125 additions and 90 deletions.
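
For context, the recover path this commit fixes is exercised roughly as follows. This is a minimal sketch based on the updated tests further down; the model path and calibration dataloader are placeholders, while `set_workspace`, `quantization.fit`, and `recover` are the calls the tests actually use.

```python
from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
from neural_compressor.utils.utility import recover

# history.snapshot is written under the configured workspace during tuning.
set_workspace("nc_workspace")

# "model.onnx" and calib_dataloader are placeholders for a real fp32 ONNX model
# and a calibration dataloader.
config = PostTrainingQuantConfig(approach="static")
q_model = quantization.fit("model.onnx", config, calib_dataloader=calib_dataloader)

# recover() rebuilds the quantized model from the fp32 model plus the saved
# tuning history entry; the fix below makes this path match quantize() again.
recovered = recover("model.onnx", "nc_workspace/history.snapshot", 0)
assert recovered.model == q_model.model
```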
197 changes: 111 additions & 86 deletions neural_compressor/adaptor/onnxrt.py
@@ -274,6 +274,9 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
if ort_version < ONNXRT152_VERSION: # pragma: no cover
logger.warning("Quantize input needs onnxruntime 1.5.2 or newer.")
return model
if ort_version < ONNXRT170_VERSION and self.format == "qdq":
logger.error("QDQ mode needs onnxruntime1.7.0 or newer.")
exit(0)
if model.model.opset_import[0].version < 11: # pragma: no cover
logger.warning("Quantize input needs model opset 11 or newer.")
if self.backend == "DnnlExecutionProvider" and any(
@@ -289,17 +292,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
"please upgrade it manually to run with bf16 data type"
)
exit(0)

from neural_compressor.adaptor.ox_utils.util import QuantizationMode

if self.format == "qlinearops":
format = QuantizationMode.QLinearOps
elif self.format == "qdq":
assert ort_version >= ONNXRT170_VERSION, "QDQ mode needs onnxruntime1.7.0 or newer"
format = "qdq"
else:
format = QuantizationMode.IntegerOps

self.quantizable_ops = self._query_quantizable_ops(model.model)
quantize_config = self._cfg_to_quantize_config(tune_cfg)

@@ -405,43 +397,11 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
)
else:
quantize_params = None
q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
self.quantize_params = quantize_params

from neural_compressor import options
from neural_compressor.adaptor.ox_utils.quantizer import Quantizer

quantizer = Quantizer(
tmp_model,
quantize_config,
format,
self.static,
quantize_params,
self.quantizable_op_types,
self.query_handler.get_fallback_list(),
self.reduce_range,
(
options.onnxrt.qdq_setting.AddQDQPairToWeight
if "add_qdq_pair_to_weight" not in self.recipes
else self.recipes.get("add_qdq_pair_to_weight", False)
),
(
options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
if "optypes_to_exclude_output_quant" not in self.recipes
else self.recipes.get("optypes_to_exclude_output_quant", [])
),
(
options.onnxrt.qdq_setting.DedicatedQDQPair
if "dedicated_qdq_pair" not in self.recipes
else self.recipes.get("dedicated_qdq_pair", False)
),
self.backend,
)
quantizer.quantize_model()
tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
tmp_model.model = quantizer.model.model
self.quantize_config = quantize_config # update so other methods can know current configs
tmp_model = self._quantize_model(tmp_model, quantize_config, quantize_params)
tmp_model.q_config = q_config
self._dump_model_op_stats(tmp_model)
tmp_model.topological_sort()

# if the model is large and acc tuning is required, save it to workspace
if not self.performance_only and tmp_model.is_large_model: # pragma: no cover
@@ -496,13 +456,21 @@ def _get_split_model_quantize_params(
)
return split_quantize_params, dataloder_for_next_split_model

def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged):
"""Quantize split model, and merge the quantized models to generate final model."""
def _quantize_model(self, model, quantize_config, quantize_params):
"""Quantize model."""
from neural_compressor import options
from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
from neural_compressor.adaptor.ox_utils.util import QuantizationMode

if self.format == "qlinearops":
format = QuantizationMode.QLinearOps
elif self.format == "qdq":
format = "qdq"
else:
format = QuantizationMode.IntegerOps

quantizer = Quantizer(
split_model,
model,
quantize_config,
format,
self.static,
@@ -528,14 +496,19 @@ def _quantize_split_model(self, split_model, quantize_config, quantize_params, q
self.backend,
)
quantizer.quantize_model()
split_model.model = quantizer.model.model
split_model.topological_sort()
model.model = quantizer.model.model
self.quantize_config = quantize_config # update so other methods can know current configs
model.topological_sort()
return model

def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged):
"""Quantize split model, and merge the quantized models to generate final model."""
split_model = self._quantize_model(split_model, quantize_config, quantize_params)
if quantized_model_merged is None:
quantized_model_merged = quantizer.model
quantized_model_merged = split_model
quantized_model_merged.write_external_data_to_new_location(overwrite=True)
else:
quantized_model_merged.merge_split_models(quantizer.model)
quantized_model_merged.merge_split_models(split_model)

return quantized_model_merged

@@ -640,57 +613,109 @@ def recover(self, model, q_config):
"""
self._pre_optimize(model)
model = self.pre_optimized_model

ort_version = Version(ort.__version__)
if ort_version < ONNXRT152_VERSION: # pragma: no cover
logger.warning("Quantize input needs onnxruntime 1.5.2 or newer.")
return model
if model.model.opset_import[0].version < 11: # pragma: no cover
logger.warning("Quantize input needs model opset 11 or newer.")
if ort_version < ONNXRT170_VERSION and self.format == "qdq":
logger.error("QDQ mode needs onnxruntime1.7.0 or newer.")
exit(0)
if self.backend == "DnnlExecutionProvider" and any(
[i.domain in ["", "ai.onnx"] and i.version < 15 for i in model.model.opset_import]
): # pragma: no cover
from onnx import version_converter

try:
model = self._rename_node(ONNXModel(version_converter.convert_version(model.model, 15)))
except:
logging.warning(
"Fail to upgrade model opset_import to >= 15, "
"please upgrade it manually to run with bf16 data type"
)
exit(0)

from neural_compressor.adaptor.ox_utils.util import QuantizationMode

if self.format in ["qlinearops"]:
if self.format == "qlinearops":
format = QuantizationMode.QLinearOps
elif self.format == "qdq":
assert ort_version >= ONNXRT170_VERSION, "QDQ mode needs onnxruntime1.7.0 or newer"
format = self.format
format = "qdq"
else:
format = QuantizationMode.IntegerOps
from neural_compressor import options
from neural_compressor.adaptor.ox_utils.quantizer import Quantizer

self.quantizable_ops = self._query_quantizable_ops(model.model)
quantize_params, tune_cfg = self._parse_qconfig(q_config)
quantize_config = self._cfg_to_quantize_config(tune_cfg)
quantizer = Quantizer(
model.model,
quantize_config,
format,
self.static,
quantize_params,
self.quantizable_op_types,
self.query_handler.get_fallback_list(),
self.reduce_range,
(
options.onnxrt.qdq_setting.AddQDQPairToWeight
if not options.onnxrt.qdq_setting.AddQDQPairToWeight
else self.recipes.get("add_qdq_pair_to_weight", False)
),
(
options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
if options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin is not None
else self.recipes.get("optypes_to_exclude_output_quant", [])
),
(
options.onnxrt.qdq_setting.DedicatedQDQPair
if not options.onnxrt.qdq_setting.DedicatedQDQPair
else self.recipes.get("dedicated_qdq_pair", False)
),
)

quantizer.quantize_model()
model.model = quantizer.model.model
model.topological_sort()
if self._need_smooth_quant(tune_cfg):
logger.error("Don't support to recover quantized model with smooth quant from original fp32 model.")
exit(0)

if self.recipes.get("layer_wise_quant", False) and not self.dynamic:
# layer-wise quantization
# details refer to docs/source/quantization_weight_only.md#layer-wise-quantization
_model_to_split = copy.deepcopy(model)

split_nodes = _model_to_split.find_split_nodes()
logger.info(
"Will split model into {} parts to do layer-wise quantization".format(
len([node.name for node in split_nodes]) + 1
)
)
logger.debug(
"Will split model with these nodes for layer-wise quantization: {}".format(
[node.name for node in split_nodes]
)
)

split_idx = 1
model_to_split = [_model_to_split]
quantized_model_merged = None

while len(model_to_split) != 0:
split_model = model_to_split.pop(0)
split_node = split_nodes.pop(0)
save_both_split_models = True if len(split_nodes) == 0 else False
shape_infer = True if split_idx == 1 else False

# split model with given split_node
split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
split_node.name, model.model_path, shape_infer, save_both_split_models
)
if not save_both_split_models:
# append split_model_part_2 to do next split
model_to_split.append(split_model_part_2)

logger.info("Quantize split model {}".format(split_idx))

# quantize split model
quantized_model_merged = self._quantize_split_model(
split_model_part_1, quantize_config, quantize_params, quantized_model_merged
)

split_idx += 1

# if this is the last split, then quantize the last split model
if save_both_split_models:
logger.info("Quantize split model {}".format(split_idx))

# quantize split model
quantized_model_merged = self._quantize_split_model(
split_model_part_2, quantize_config, quantize_params, quantized_model_merged
)
quantized_model_merged.re_org_output(model.output()) # re-org output as the origin output

model.model = quantized_model_merged.model
self._dump_model_op_stats(model)
model.check_is_large_model()

else:
model = self._quantize_model(model, quantize_config, quantize_params)

self._dump_model_op_stats(model)
return model

def _parse_qconfig(self, q_config):
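Net effect of the changes above: `quantize()`, `_quantize_split_model()`, and the non layer-wise path of `recover()` now all funnel through the new `_quantize_model()` helper instead of each assembling its own `Quantizer`. Simplified excerpts of the shared call pattern (taken from the diff, not a standalone script):

```python
# In quantize(): q_config is computed up front and attached after quantization.
q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
tmp_model = self._quantize_model(tmp_model, quantize_config, quantize_params)
tmp_model.q_config = q_config

# In _quantize_split_model(): each split is quantized with the same helper, then merged.
split_model = self._quantize_model(split_model, quantize_config, quantize_params)

# In recover(): the non layer-wise branch reuses the helper as well, so the
# recovered model goes through the same quantization code as a fresh fit.
model = self._quantize_model(model, quantize_config, quantize_params)
```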
12 changes: 9 additions & 3 deletions test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py
@@ -14,13 +14,14 @@
from packaging.version import Version
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
from neural_compressor.adaptor import FRAMEWORKS
from neural_compressor.adaptor.pytorch import get_torch_version
from neural_compressor.conf.config import conf
from neural_compressor.data import DATALOADERS, DataLoader, Datasets
from neural_compressor.experimental import Benchmark, Quantization, common
from neural_compressor.model import Model
from neural_compressor.utils.utility import recover


def build_static_yaml():
@@ -898,6 +899,7 @@ def setUpClass(self):
self.albert_model = onnx.load(self.albert_export_path)
self.gather_matmul_model = build_matmul_gather_model()
build_benchmark()
set_workspace("nc_workspace")

@classmethod
def tearDownClass(self):
@@ -1390,8 +1392,6 @@ def test_adaptor(self):
self.assertNotEqual(q_model, None)

# check recover model function
from neural_compressor.utils.utility import recover

model = recover(self.mb_v2_model, "./nc_workspace/recover/history.snapshot", 0)
self.assertTrue(model.model == q_model.model)

@@ -1489,6 +1489,10 @@ def eval(model):
q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader, eval_func=eval)
self.assertTrue("QLinearMatMul" in [i.op_type for i in q_model.nodes()])

q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader)
recover_model = recover(self.matmul_model, "nc_workspace/history.snapshot", 0)
self.assertTrue(q_model.model == recover_model.model)

config = PostTrainingQuantConfig(approach="dynamic")
q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader, eval_func=eval)
self.assertTrue("MatMulInteger" in [i.op_type for i in q_model.nodes()])
@@ -1535,6 +1539,8 @@ def test_smooth_quant(self):
)
q_model = quantization.fit(self.conv_model, config, calib_dataloader=self.cv_dataloader)
self.assertEqual(len([i for i in q_model.nodes() if i.op_type == "Mul"]), 2)
with self.assertRaises(SystemExit):
recover_model = recover(self.conv_model, "nc_workspace/history.snapshot", 0)

def test_smooth_quant_args(self):
from neural_compressor.model.onnx_model import ONNXModel
6 changes: 5 additions & 1 deletion test/adaptor/onnxrt_adaptor/test_layer_wise.py
@@ -7,8 +7,9 @@
import onnxruntime as ort
from transformers import AutoTokenizer

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
from neural_compressor.utils.constant import FP32
from neural_compressor.utils.utility import recover


def Inference(model_path, data):
@@ -44,6 +45,7 @@ def setUpClass(self):

self.model = onnx.load("tiny-llama/decoder_model.onnx")
self.dataloader = DummyNLPDataloader("yujiepan/llama-2-tiny-3layers-random")
set_workspace("nc_workspace")

@classmethod
def tearDownClass(self):
@@ -57,6 +59,8 @@ def test_layer_wise_W8A8_quant(self):
calibration_sampling_size=[1], recipes={"layer_wise_quant": True}, op_type_dict={"^((?!(MatMul)).)*$": FP32}
)
q_model = quantization.fit("tiny-llama/decoder_model.onnx", config, calib_dataloader=self.dataloader)
recover_model = recover("tiny-llama/decoder_model.onnx", "nc_workspace/history.snapshot", 0)
self.assertTrue(recover_model.model == q_model.model)
q_model.save(layerwise_quantized_model_path)

# not layer-wise quantization
