From 6d4ea5b114d7af4030626702da7b515a3f7771a9 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Fri, 20 Oct 2023 16:57:40 +0800 Subject: [PATCH] Enable the tuning of WOQ algorithm (#1328) * support WOQ algos tuning --------- Signed-off-by: Kaihui-intel Signed-off-by: yuwenzho Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: yuwenzho --- neural_compressor/adaptor/onnxrt.py | 57 +++++++++++--- neural_compressor/strategy/auto.py | 6 ++ neural_compressor/strategy/basic.py | 6 ++ neural_compressor/strategy/strategy.py | 36 +++++++++ neural_compressor/strategy/utils/constant.py | 10 +++ .../strategy/utils/tuning_sampler.py | 32 ++++++++ neural_compressor/utils/utility.py | 31 ++++++++ .../test_weight_only_adaptor.py | 75 +++++++++++++++++++ 8 files changed, 243 insertions(+), 10 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 59fb352d4d5..7d46b531fd2 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -1632,25 +1632,36 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): Returns: (dict): quantized model """ + if self.performance_only: + tmp_model = model + else: + try: + tmp_model = copy.deepcopy(model) + except Exception as e: # pragma: no cover + logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e))) + tmp_model = model + assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME" for precision in self.query_handler.get_precisions(): if precision == "weight_only_integer": self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision) - self.quantizable_ops = self._query_quantizable_ops(model.model) + self.quantizable_ops = self._query_quantizable_ops(tmp_model.model) + self._update_tune_cfg(tune_cfg, tmp_model.model) quant_config = self._cfg_to_quantize_config(tune_cfg) algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)]) if "GPTQ" in algos: from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize + assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()" percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01) blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128) actorder = self.recipes.get("gptq_args", {}).get("actorder", False) mse = self.recipes.get("gptq_args", {}).get("mse", False) perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) - model = gptq_quantize( - model, + tmp_model = gptq_quantize( + tmp_model, data_loader, quant_config, n_samples=calib_sampling_size, @@ -1663,11 +1674,12 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): if "AWQ" in algos: from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize + assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()" enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True) enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) - model = awq_quantize( - model, + tmp_model = awq_quantize( + tmp_model, data_loader, quant_config, n_samples=calib_sampling_size, @@ -1677,11 +1689,11 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): elif "RTN" in algos: from 
neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize - model = rtn_quantize(model, quant_config) - model.q_config = copy.deepcopy(quant_config) - self._dump_model_op_stats(model, tune_cfg) - model.topological_sort() - return model + tmp_model = rtn_quantize(tmp_model, quant_config) + tmp_model.q_config = copy.deepcopy(quant_config) + self._dump_model_op_stats(tmp_model, tune_cfg) + tmp_model.topological_sort() + return tmp_model def _dump_model_op_stats(self, model, tune_cfg): import re @@ -1747,6 +1759,31 @@ def _cfg_to_quantize_config(self, tune_cfg): return quantize_config + def _update_tune_cfg(self, tune_cfg, model): + """Update tune cfg according to woq_tuning_cfg.""" + if tune_cfg.get("woq_tuning_cfg") is None: + return tune_cfg + + from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS + + woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg") + new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg) + + for node_cfg in tune_cfg["op"].values(): + node_cfg["weight"].update( + {cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]} + ) + + # find last matmul and set to fp32 + if "DISABLE_LAST_MATMUL" in woq_tuning_cfg: + last_matmul = None + fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}} + for node in model.graph.node: + if node.op_type in ["MatMul"]: + last_matmul = (node.name, node.op_type) + if last_matmul in tune_cfg["op"]: + tune_cfg["op"][last_matmul].update(fp32_op_cfg) + def query_fw_capability(self, model): """The function is used to query framework capability. TODO: will be replaced by framework query API diff --git a/neural_compressor/strategy/auto.py b/neural_compressor/strategy/auto.py index 17511030de9..81c500e0f17 100644 --- a/neural_compressor/strategy/auto.py +++ b/neural_compressor/strategy/auto.py @@ -120,6 +120,12 @@ def next_tune_cfg(self): op_tuning_cfg["calib_sampling_size"] = calib_sampling_size_lst[0] if not self.cur_best_tuning_cfg: self.cur_best_tuning_cfg = deepcopy(op_tuning_cfg) + + # try to tune a WeightOnlyQuant algorithm + if self._should_tuning_woq_algo(): + for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)): + yield tune_cfg + # try to tune sq alpha if self._should_tuning_sq_alpha(self.config.recipes): for tune_cfg in self.tuning_sq_alpha(tuning_space, deepcopy(self.cur_best_tuning_cfg), self.config.recipes): diff --git a/neural_compressor/strategy/basic.py b/neural_compressor/strategy/basic.py index 1b708cd8d4e..e7aebbe09f7 100644 --- a/neural_compressor/strategy/basic.py +++ b/neural_compressor/strategy/basic.py @@ -312,6 +312,12 @@ def next_tune_cfg(self): stage1_max = 1e9 # TODO set a more appropriate value if not self.cur_best_tuning_cfg: self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg) + + # try to tune a WeightOnlyQuant algorithm + if self._should_tuning_woq_algo(): + for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)): + yield tune_cfg + # try to tune sq alpha if self._should_tuning_sq_alpha(self.config.recipes): for tune_cfg in self.tuning_sq_alpha( diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 16537baa7b0..730f6e12760 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -46,6 +46,7 @@ DotDict, LazyImport, Statistics, + check_key_exist, dump_table, equal_dicts, fault_tolerant_file, @@ -1153,6 +1154,40 @@ def tuning_sq_alpha(self, tuning_space, 
tuning_cfg, recipes):
         for tune_cfg in sq_sampler:
             yield tune_cfg
 
+    def _should_tuning_woq_algo(self):
+        """Currently, it is only available for the ORT backend when the approach is weight_only.
+
+        It will be triggered only once, when
+        a) quant_level is auto, or quant_level is 1 and the strategy is basic,
+        b) and "algorithm" is not set in op_type_dict,
+        c) and no WOQ tuning config is recorded in the tuning history yet.
+        """
+        return (
+            "onnx" in self.framework.lower()
+            and "weight_only" in self.config.approach
+            and not check_key_exist(self.config.op_type_dict, "algorithm")
+            and not check_key_exist(self.tuning_history, "woq_tuning_cfg")
+        )
+
+    def tuning_woq_algo(self, tuning_space, tuning_cfg):
+        """Tune the weight-only quantization algorithms.
+
+        Args:
+            tuning_space: the tuning space
+            tuning_cfg: the initial tuning config
+
+        Yields:
+            the next tuning config
+        """
+        logger.info("[STRATEGY] Start tuning the Weight Only Quant algorithms.")
+        woq_sampler = tuning_sampler_dict.get_class("woq_algorithm")(tuning_space, [], tuning_cfg)
+        for tune_cfg in woq_sampler:
+            yield tune_cfg
+
+        logger.info(
+            "[Strategy] The best tuning config with WeightOnlyQuant is " f"{self.cur_best_tuning_cfg['woq_tuning_cfg']}."
+        )
+
     def initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg: OpTuningConfig):
         """Init the dynamic tuning config according to the static config.
 
@@ -1322,6 +1357,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
         # For not tuning recipe, tune cfg use it directly
         tune_cfg["recipe_cfgs"].update(self._not_tuning_recipes_values)
         tune_cfg["trial_number"] = deepcopy(self.trials_count)
+        tune_cfg.setdefault("woq_tuning_cfg", op_tuning_cfg.get("woq_tuning_cfg"))
         # The sq-related args comes from user config, current best tuning config
         # TODO simplify the logic for transforming the arguments
         # update the sq-related args from self.cur_best_tuning_cfg
diff --git a/neural_compressor/strategy/utils/constant.py b/neural_compressor/strategy/utils/constant.py
index 842635c4781..9b6afc65d9a 100644
--- a/neural_compressor/strategy/utils/constant.py
+++ b/neural_compressor/strategy/utils/constant.py
@@ -16,6 +16,7 @@
 # limitations under the License.
"""Strategy constant.""" + PRECISION_LIST = ["bf16", "fp16", "fp32"] QUANT_MODE_SET = {"static", "dynamic"} LOWER_BIT_LIST = ["int4"] @@ -56,3 +57,12 @@ "last_conv_or_matmul_quantization", "pre_post_process_quantization", } + + +WOQ_TUNING_ALGOS = { + "RTN_G32ASYM": {"algorithm": "RTN", "group_size": 32, "scheme": "asym"}, + "GPTQ_G32ASYM": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}, + "GPTQ_G32ASYM_DISABLE_LAST_MATMUL": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}, + "GPTQ_G128ASYM": {"algorithm": "GPTQ", "group_size": 128, "scheme": "asym"}, + "AWQ_G32ASYM": {"algorithm": "AWQ", "group_size": 32, "scheme": "asym"}, +} diff --git a/neural_compressor/strategy/utils/tuning_sampler.py b/neural_compressor/strategy/utils/tuning_sampler.py index d7704dbdd82..5261bee988c 100644 --- a/neural_compressor/strategy/utils/tuning_sampler.py +++ b/neural_compressor/strategy/utils/tuning_sampler.py @@ -23,6 +23,7 @@ from typing import Any, Dict, List, Tuple, Union from ...utils import logger +from ..utils.constant import WOQ_TUNING_ALGOS from .tuning_space import TuningSpace, pattern_to_internal, quant_mode_from_pattern from .tuning_structs import OpTuningConfig from .utility import ClassRegister @@ -609,3 +610,34 @@ def __iter__(self): recipe_cfgs["smooth_quant_args"] = {"alpha": alpha} logger.debug(f"[STRATEGY] set smooth quant alpha with: {alpha:.4f}") yield new_tune_cfg + + +@tuning_sampler_dict("woq_algorithm") +class WeightOnlyQuantSampler(TuningSampler): + """Not displayed in API Docs.""" + + def __init__( + self, + tuning_space: TuningSpace, + tuning_order_lst: List[TuningOrder], + initial_op_tuning_cfg: Dict, + ): + """Init tuning sampler. + + Args: + tuning_space: The tuning space. + tuning_order_lst: The traverse orders. + initial_op_tuning_cfg: The initialized tuning config. + """ + super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg) + + def __iter__(self): + """Yield the next tuning config. + + Yields: + The next tuning config. + """ + new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg) + for algo in WOQ_TUNING_ALGOS.keys(): + new_tune_cfg["woq_tuning_cfg"] = algo + yield new_tune_cfg diff --git a/neural_compressor/utils/utility.py b/neural_compressor/utils/utility.py index 78f00382c39..b46446cd5dc 100644 --- a/neural_compressor/utils/utility.py +++ b/neural_compressor/utils/utility.py @@ -1092,3 +1092,34 @@ def mse_metric_gap(fp32_tensor: Any, dequantize_tensor: Any) -> float: diff_tensor = fp32_tensor_norm - dequantize_tensor_norm euclidean_dist = np.sum(diff_tensor**2) # type: ignore return euclidean_dist / fp32_tensor.size + + +def check_key_exist(data, key): + """Recursively checks if a key exists in a dictionary or list. + + Args: + data (dict or list): The dictionary or list to search. + key (any): The key to search for. + + Returns: + bool: True if the key exists in the data structure, False otherwise. 
+ + Examples: + >>> check_key_exist({'a': 1, 'b': {'c': 2}}, 'c') + True + >>> check_key_exist([{'a': 1}, {'b': 2}], 'b') + True + >>> check_key_exist({'a': 1, 'b': [1, 2, 3]}, 'c') + False + """ + if isinstance(data, dict): + if key in data: + return True + for value in data.values(): + if check_key_exist(value, key): + return True + elif isinstance(data, list): + for item in data: + if check_key_exist(item, key): + return True + return False diff --git a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py index eb71711900b..6d2201df104 100644 --- a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py @@ -1,3 +1,4 @@ +import copy import os import shutil import subprocess @@ -9,6 +10,7 @@ from transformers import AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization +from neural_compressor.utils.constant import FP32 def Inference(model, data): @@ -265,6 +267,79 @@ def test_GPTQ_quant(self): ] self.assertTrue(len(rtn_op_names) + 1, len(gptq_op_names)) + def _test_woq_tune_common(self, eval_func, quant_level=1, **kwargs): + from neural_compressor import quantization + from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion + + tuning_criterion = TuningCriterion(max_trials=5) + + fp32_model = copy.deepcopy(self.model) + conf = PostTrainingQuantConfig( + approach="weight_only", quant_level=quant_level, tuning_criterion=tuning_criterion, **kwargs + ) + q_model = quantization.fit( + fp32_model, + conf, + calib_dataloader=self.dataloader, + eval_func=eval_func, + ) + self.assertIsNotNone(q_model) + return q_model + + def _count_woq_matmul(self, q_model, bits=4, group_size=32): + op_names = [ + i.name + for i in q_model.nodes() + if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + ] + return len(op_names) + + def test_woq_tune(self): + from functools import partial + + def fake_eval(model, eval_result_lst): + acc = eval_result_lst.pop(0) + return acc + + quant_levels = ["auto", 1] + for quant_level in quant_levels: + # Expect tuning ends with WOQ algorithm 'RTN_G32ASYM' + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1]) + woq_model_1 = self._test_woq_tune_common(partial_fake_eval, quant_level) + self.assertEqual(self._count_woq_matmul(woq_model_1), 31) + + # Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM' + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 1.1]) + woq_model_2 = self._test_woq_tune_common(partial_fake_eval, quant_level) + self.assertEqual(self._count_woq_matmul(woq_model_2), 31) + + # Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM_DISABLE_LAST_MATMUL' + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 1.1]) + woq_model_3 = self._test_woq_tune_common(partial_fake_eval, quant_level) + self.assertEqual(self._count_woq_matmul(woq_model_3), 30) + + # Expect tuning ends with WOQ algorithm 'GPTQ_G128ASYM' + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 1.1]) + woq_model_4 = self._test_woq_tune_common(partial_fake_eval, quant_level) + self.assertEqual(self._count_woq_matmul(woq_model_4, group_size=128), 31) + + # Expect tuning ends with WOQ algorithm 'AWQ_G32ASYM' + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 0.8, 1.1]) + woq_model_5 = self._test_woq_tune_common(partial_fake_eval, quant_level) + self.assertEqual(self._count_woq_matmul(woq_model_5), 31) + + 
# test WOQ tuning with fallback + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1]) + woq_model = self._test_woq_tune_common( + partial_fake_eval, "auto", op_name_dict={"/transformer/h.*/attn/k_proj/MatMul": FP32} + ) + self.assertEqual(self._count_woq_matmul(woq_model), 26) + + # test 8 bits WOQ + partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1]) + woq_model = self._test_woq_tune_common(partial_fake_eval, "auto", op_type_dict={".*": {"weight": {"bits": 8}}}) + self.assertEqual(self._count_woq_matmul(woq_model, bits=8), 31) + if __name__ == "__main__": unittest.main()
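
Usage note (not part of the patch): a minimal sketch of how the new WOQ tuning path is driven from user code, distilled from the test_woq_tune case added above. Here fp32_model, calib_dataloader, and eval_func are placeholders and are not defined by this patch.

    from neural_compressor import quantization
    from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion

    # Cap the number of trials; each trial applies the next entry of
    # WOQ_TUNING_ALGOS (RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_MATMUL,
    # GPTQ_G128ASYM, AWQ_G32ASYM), as exercised by test_woq_tune.
    tuning_criterion = TuningCriterion(max_trials=5)

    # approach="weight_only" with quant_level "auto" (or 1 with the basic strategy)
    # and no user-specified "algorithm" in op_type_dict triggers _should_tuning_woq_algo().
    conf = PostTrainingQuantConfig(
        approach="weight_only",
        quant_level="auto",
        tuning_criterion=tuning_criterion,
    )

    # fp32_model, calib_dataloader, and eval_func are placeholders; the GPTQ and AWQ
    # trials require a calibration dataloader to be passed to quantization.fit().
    q_model = quantization.fit(
        fp32_model,
        conf,
        calib_dataloader=calib_dataloader,
        eval_func=eval_func,
    )

Tuning stops once eval_func meets the accuracy criterion, and the selected algorithm entry is recorded in the tuning config under "woq_tuning_cfg" by _tune_cfg_converter.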