From 9600e1d973bf17e962e55fb0f3817db76ce8fea9 Mon Sep 17 00:00:00 2001
From: Yi30 <106061964+yiliu30@users.noreply.github.com>
Date: Fri, 22 Dec 2023 22:03:35 +0800
Subject: [PATCH] Narrow the tuning space of sq auto-tune (#1489)

Signed-off-by: yiliu30
---
 docs/source/tuning_strategies.md          |  2 +-
 neural_compressor/strategy/strategy.py    | 18 ++++++++++++++++++
 test/algorithm/test_smooth_quant.py       | 14 +++++---------
 test/algorithm/test_smooth_quant_onnx.py  |  6 ------
 4 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/docs/source/tuning_strategies.md b/docs/source/tuning_strategies.md
index 7788149f7e3..6238218892b 100644
--- a/docs/source/tuning_strategies.md
+++ b/docs/source/tuning_strategies.md
@@ -179,7 +179,7 @@ flowchart TD
 
 > `*` INC will detect the block pattern for [transformer-like](https://arxiv.org/abs/1706.03762) model by default.
 
-> For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. The tuning process will take place at the **start stage** of the tuning procedure. For details usage, please refer to the [smooth quantization example](./smooth_quant.md#Example).
+> For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. For details usage, please refer to the [smooth quantization example](./smooth_quant.md#Usage).
 
 > For [weight-only quantization](./quantization_weight_only.md), users can tune the weight-only algorithms from the available [pre-defined configurations](./quantization_weight_only.md#woq-algorithms-tuning). The tuning process will take place at the **start stage** of the tuning procedure, preceding the smooth quantization alpha tuning. For details usage, please refer to the [weight-only quantization example](./quantization_weight_only.md#woq-algorithms-tuning). *Please note that this behavior is specific to the `ONNX Runtime` backend.*
 
diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py
index 95691a38142..df5627be4b3 100644
--- a/neural_compressor/strategy/strategy.py
+++ b/neural_compressor/strategy/strategy.py
@@ -186,6 +186,7 @@ def __init__(
         # track tuning cfg with the current best accuracy
         self.cur_best_tuning_cfg = {}
         self.re_quant = False
+        self.early_stop_sq_tuning_process = False
 
         self._trials_count = 0
         self._capability = None
@@ -1152,6 +1153,9 @@ def _should_tuning_sq_alpha(self, recipes):
     def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
         """Tuning smooth quant's alpha.
 
+        After trying all alpha values, the sq tuning process will stop early, returning the current best qmodel,
+        even if the current best accuracy does not meet the accuracy criterion.
+
         Args:
             tuning_space: tuning space
             tuning_cfg: the initial tuning config
             recipes: recipes specified by user
@@ -1166,8 +1170,12 @@ def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
         ), "Only tune the smooth quant's alpha when user provide the alpha list,\
             but got alpha_list: {alpha_list}"
         logger.info("[STRATEGY] Start tuning smooth quant'alpha.")
+        number_of_alpha = len(sq_alpha_list)
+        sq_trials_cnt = 0
         sq_sampler = tuning_sampler_dict.get_class("smooth_quant")(tuning_space, [], tuning_cfg, sq_alpha_list)
         for tune_cfg in sq_sampler:
+            sq_trials_cnt += 1
+            self.early_stop_sq_tuning_process = sq_trials_cnt == number_of_alpha
             yield tune_cfg
 
     def _should_tuning_woq_algo(self):
@@ -1961,6 +1969,16 @@ def stop(self, timeout, trials_count):
             need_stop = True
         else:
             need_stop = False
+        if not need_stop and self.early_stop_sq_tuning_process:
+            if self.best_tuning_cfg is None:
+                self.best_tuning_cfg = self._tune_cfg_converter(self.cur_best_tuning_cfg)
+            logger.info(
+                "[Strategy] Tried all alpha values but none met the accuracy criterion. "
+                "The tuning process was early stopped and "
+                f"the current best model (accuracy: {self.cur_best_acc}) was returned."
+            )
+
+            need_stop = True
 
         return need_stop
 
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index 08003161662..564fa84c84e 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -1150,6 +1150,8 @@ def _test_sq_tune_alpha_common(self, eval_func, alpha=np.arange(0.1, 0.2, 0.05).
         from neural_compressor import quantization
         from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
 
+        logger.info(f"alpha is: {alpha}")
+
         tuning_criterion = TuningCriterion(max_trials=8)
         fp32_model = DemoModel()
 
@@ -1183,8 +1185,8 @@ def fake_eval(model, eval_result_lst):
         # test for alpha is a list
         for eval_result_lst, note in [
             ([1, 0.8, 1.1, 0.7, 1.1], "Expect tuning ends at 2nd trial with alpha is 0.15"),
-            ([1, 0.8, 0.9, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.15"),
-            ([1, 0.9, 0.8, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.10"),
+            ([1, 0.8, 0.9, 0.7, 1.1], "Expect tuning ends at 2nd trial with alpha is 0.15"),
+            ([1, 0.9, 0.8, 0.7, 1.1], "Expect tuning ends at 1st trial with alpha is 0.10"),
         ]:
             logger.info(f"test_sq_tune_alpha_common with eval_result_lst: {eval_result_lst}")
             logger.info(note)
@@ -1222,13 +1224,7 @@ def fake_eval(model, eval_result_lst):
                 [1, 0.8, 0.9, 0.7, 1.1],
                 np.arange(0.1, 0.2, 0.05).tolist(),
                 "auto",
-                "Expect tuning ends at 4th trial with alpha is 0.15 at basic strategy.",
-            ),
-            (
-                [1, 1.1, 0.8, 0.7, 1.1],
-                np.arange(0.1, 0.2, 0.05).tolist(),
-                0,
-                "Expect tuning ends at 1th trial with alpha is 0.1",
+                "Expect tuning ends at 2nd trial with alpha is 0.15 at basic strategy.",
             ),
         ]:
             logger.info("test_sq_tune_alpha_common with ")
diff --git a/test/algorithm/test_smooth_quant_onnx.py b/test/algorithm/test_smooth_quant_onnx.py
index db2877638ce..6cc67b9803e 100644
--- a/test/algorithm/test_smooth_quant_onnx.py
+++ b/test/algorithm/test_smooth_quant_onnx.py
@@ -279,12 +279,6 @@ def fake_eval(model, eval_result_lst):
                 "auto",
                 "Expect tuning ends at 4th trial with alpha is 0.15 at basic strategy.",
             ),
-            (
-                [1, 1.1, 0.8, 0.7, 1.1],
-                np.arange(0.1, 0.2, 0.05).tolist(),
-                0,
-                "Expect tuning ends at 1th trial with alpha is 0.1",
-            ),
         ]:
             logger.info("test_sq_tune_alpha_common with ")
            logger.info(f"eval_result_lst: {eval_result_lst}, alpha: {alpha}, quant_level: {quant_level}")
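
Usage sketch (illustrative, not part of the patch): the snippet below shows roughly how the alpha-list tuning exercised by the tests above is driven from user code. `TinyModel`, `CalibDataloader`, and `eval_func` are stand-ins for the tests' `DemoModel`/`fake_eval` helpers, and the `smooth_quant`/`smooth_quant_args` recipe keys follow the smooth_quant.md documentation referenced in the docs change; treat the exact values as placeholders.

import numpy as np
import torch

from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion


# Illustrative stand-in for the DemoModel used by the tests.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 4)
        self.fc2 = torch.nn.Linear(4, 3)

    def forward(self, x):
        return self.fc2(self.fc1(x))


# Minimal calibration dataloader yielding (input, label) batches of random data.
class CalibDataloader:
    batch_size = 1

    def __iter__(self):
        for _ in range(10):
            yield torch.randn(1, 3), torch.zeros(1)


def eval_func(model):
    # Placeholder accuracy function; a real one would score a validation set.
    return 1.0


# Providing a list of alpha values narrows the smooth-quant tuning space:
# the strategy tries each alpha once and then stops early, returning the
# best model found so far even if the accuracy criterion has not been met.
alpha_values = np.arange(0.1, 0.2, 0.05).tolist()

conf = PostTrainingQuantConfig(
    quant_level=1,
    tuning_criterion=TuningCriterion(max_trials=8),
    recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": alpha_values}},
)

q_model = quantization.fit(
    TinyModel(), conf, calib_dataloader=CalibDataloader(), eval_func=eval_func
)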