From 5a0374e7db23cac209af78f1ace9b38d23bebbb0 Mon Sep 17 00:00:00 2001
From: Yi Liu <106061964+yiliu30@users.noreply.github.com>
Date: Tue, 18 Jun 2024 16:21:05 +0800
Subject: [PATCH] Enhance autotune to return the best `q_model` directly
 (#1875)

Signed-off-by: yiliu30
---
 neural_compressor/common/base_tuning.py    |  8 +++-
 neural_compressor/common/utils/utility.py  | 15 ++++++++
 .../tensorflow/quantization/autotune.py    | 15 ++++++--
 .../torch/quantization/autotune.py         |  9 +++--
 .../torch/quantization/quantize.py         |  3 +-
 test/3x/common/test_utility.py             | 22 +++++++++++
 test/3x/torch/test_autotune.py             | 38 +++++++++++++++++++
 7 files changed, 99 insertions(+), 11 deletions(-)

diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py
index 2a1adfa480b..30910a865f7 100644
--- a/neural_compressor/common/base_tuning.py
+++ b/neural_compressor/common/base_tuning.py
@@ -336,13 +336,17 @@ def set_baseline(self, baseline: float):
     def get_number_of_trials(self):
         return len(self.tuning_history)
 
-    def get_best_quant_config(self) -> BaseConfig:
+    def get_best_trial_record(self) -> _TrialRecord:
         assert self.get_number_of_trials() > 0, "No trial record in tuning monitor."
         # Put the record with a higher score at the beginning
         sorted_trials_records: List[_TrialRecord] = sorted(
             self.tuning_history, key=lambda x: x.trial_result, reverse=True
         )
-        return sorted_trials_records[0].quant_config
+        return sorted_trials_records[0]
+
+    def get_best_quant_config(self) -> BaseConfig:
+        best_trial_record = self.get_best_trial_record()
+        return best_trial_record.quant_config
 
     def need_stop(self) -> bool:
         """Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.
diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py
index 773d083398b..82f24243a9b 100644
--- a/neural_compressor/common/utils/utility.py
+++ b/neural_compressor/common/utils/utility.py
@@ -15,9 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import importlib
 import subprocess
 import time
+from typing import Dict
 
 import cpuinfo
 import psutil
@@ -35,6 +37,7 @@
     "LazyImport",
     "CpuInfo",
     "default_tuning_logger",
+    "call_counter",
 ]
 
 
@@ -225,3 +228,15 @@ def inner_wrapper(*args, **kwargs):
         return inner_wrapper
 
     return log_process_wrapper
+
+
+# decorator for recording number of times a function is called
+FUNC_CALL_COUNTS: Dict[str, int] = collections.defaultdict(int)
+
+
+def call_counter(func):
+    def wrapper(*args, **kwargs):
+        FUNC_CALL_COUNTS[func.__name__] += 1
+        return func(*args, **kwargs)
+
+    return wrapper
diff --git a/neural_compressor/tensorflow/quantization/autotune.py b/neural_compressor/tensorflow/quantization/autotune.py
index ab0d3a61949..55b089b923c 100644
--- a/neural_compressor/tensorflow/quantization/autotune.py
+++ b/neural_compressor/tensorflow/quantization/autotune.py
@@ -20,7 +20,7 @@
 from neural_compressor.common import logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
-from neural_compressor.common.utils import dump_elapsed_time
+from neural_compressor.common.utils import call_counter, dump_elapsed_time
 from neural_compressor.tensorflow.quantization import quantize_model
 from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
 from neural_compressor.tensorflow.utils import BaseModel, Model, constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
 
 
 @dump_elapsed_time("Pass auto-tune")
+@call_counter
 def autotune(
     model: Union[str, tf.keras.Model, BaseModel],
     tune_config: TuningConfig,
@@ -52,7 +53,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -65,8 +66,14 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            best_quant_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
+                del q_model
+                best_quant_config: BaseConfig = best_trial_record.quant_config
+                best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            else:
+                best_quant_model = q_model
             break
     tuning_logger.tuning_end()
     return best_quant_model
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index e54a9d97748..bdcbf642e47 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -72,7 +72,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -93,10 +93,11 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            if trial_index == 0:  # recover the best q_model from previous results.
-                logger.info("Reconvering the best quantized model...")
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
                 del q_model  # maybe gc.collect() is needed for memory release
-                best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
+                best_quant_config: BaseConfig = best_trial_record.quant_config
                 # !!! Make sure to use deepcopy only when inplace is set to `True`.
                 q_model = quantize(
                     deepcopy(model),
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index ff8298dad88..bc3020a942c 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -18,7 +18,7 @@
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import Mode, log_process
+from neural_compressor.common.utils import Mode, call_counter, log_process
 from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig
 from neural_compressor.torch.utils import is_ipex_available, logger
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -31,6 +31,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
 
 
 @log_process(mode=Mode.QUANTIZE)
+@call_counter
 def quantize(
     model: torch.nn.Module,
     quant_config: BaseConfig,
diff --git a/test/3x/common/test_utility.py b/test/3x/common/test_utility.py
index 00a3be79514..527f74a4a13 100644
--- a/test/3x/common/test_utility.py
+++ b/test/3x/common/test_utility.py
@@ -11,6 +11,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import options
 from neural_compressor.common.utils import (
     CpuInfo,
@@ -166,5 +167,26 @@ def __init__(self):
         assert instance2.value == 1, "Singleton should return the same instance"
 
 
+class TestCallCounter(unittest.TestCase):
+    def test_call_counter(self):
+        # empty dict
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        @inc_utils.call_counter
+        def add(a, b):
+            return a + b
+
+        # Initial count should be 0
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 0)
+
+        # Call the function multiple times
+        add(1, 2)
+        add(3, 4)
+        add(5, 6)
+
+        # Count should be incremented accordingly
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py
index a9bd8a971c5..eeee6c8f561 100644
--- a/test/3x/torch/test_autotune.py
+++ b/test/3x/torch/test_autotune.py
@@ -6,6 +6,7 @@
 import torch
 import transformers
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import logger
 from neural_compressor.torch.quantization import (
     MixPrecisionConfig,
@@ -163,6 +164,43 @@ def eval_acc_fn(model) -> float:
 
         custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        print(inc_utils.FUNC_CALL_COUNTS)
+        self.assertIsNotNone(best_model)
+
+    def test_autotune_return_qmodel_directly(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 1.1]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 2
+        ), f"quantize should be called twice, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
+        self.assertIsNotNone(best_model)
+
+    def test_autotune_return_re_quant_qmodel(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 0.8]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 3
+        ), f"quantize should be called three times, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
         self.assertIsNotNone(best_model)
 
     @reset_tuning_target
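
The gist of the change, as a standalone sketch: once tuning stops, the q_model produced by the last trial is handed back directly whenever that trial is the best one, and the model is re-quantized with the best config only when an earlier trial won. The names below (autotune_sketch, quantize_fn, _TrialRecord) are illustrative stand-ins, not the neural_compressor API.

# Illustrative sketch only; hypothetical names, independent of neural_compressor.
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class _TrialRecord:
    trial_index: int
    trial_result: float
    quant_config: dict


def autotune_sketch(
    model: Any,
    config_set: List[dict],
    quantize_fn: Callable[[Any, dict], Any],
    eval_fn: Callable[[Any], float],
    baseline: float,
) -> Any:
    history: List[_TrialRecord] = []
    best_model = None
    for trial_index, quant_config in enumerate(config_set, 1):  # 1-based, as in the patch
        q_model = quantize_fn(model, quant_config)
        history.append(_TrialRecord(trial_index, eval_fn(q_model), quant_config))
        # Stop once the accuracy goal is met or the trial budget is exhausted.
        if history[-1].trial_result >= baseline or trial_index == len(config_set):
            best = max(history, key=lambda r: r.trial_result)  # get_best_trial_record()
            if best.trial_index != trial_index:
                # The best config came from an earlier trial: quantize once more with it.
                best_model = quantize_fn(model, best.quant_config)
            else:
                # The last trial is already the best: return its q_model directly,
                # skipping the extra quantization step the old code performed.
                best_model = q_model
            break
    return best_model

With the two evaluation sequences used in the new tests ([0.9, 1.1] and [0.9, 0.8] against a baseline of 1), this sketch takes the direct-return branch and the re-quantize branch respectively, which is what the quantize call counts asserted above (2 vs. 3) reflect.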