
Enhance autotune to return the best q_model directly (#1875)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jun 18, 2024
1 parent 90fb431 commit 5a0374e
Showing 7 changed files with 99 additions and 11 deletions.
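In short, `autotune` now tracks the best trial record and returns its quantized model directly, re-quantizing with the best config only when that config did not come from the most recent trial. A rough usage sketch follows; the toy model and evaluation function are placeholders, and the import path for these symbols is assumed from the test file changed in this commit:

```python
# Rough usage sketch; the toy model and eval function are placeholders, and the
# import path for these symbols is assumed from the test file changed in this commit.
import torch

from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune

toy_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

def toy_eval_fn(model) -> float:
    return 1.0  # a real eval_fn would return the model's accuracy on your data

tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
best_model = autotune(model=toy_model, tune_config=tune_config, eval_fn=toy_eval_fn)
# best_model is the quantized model from the best-scoring trial, returned directly.
```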
8 changes: 6 additions & 2 deletions neural_compressor/common/base_tuning.py
@@ -336,13 +336,17 @@ def set_baseline(self, baseline: float):
     def get_number_of_trials(self):
         return len(self.tuning_history)
 
-    def get_best_quant_config(self) -> BaseConfig:
+    def get_best_trial_record(self) -> _TrialRecord:
         assert self.get_number_of_trials() > 0, "No trial record in tuning monitor."
         # Put the record with a higher score at the beginning
         sorted_trials_records: List[_TrialRecord] = sorted(
             self.tuning_history, key=lambda x: x.trial_result, reverse=True
         )
-        return sorted_trials_records[0].quant_config
+        return sorted_trials_records[0]
+
+    def get_best_quant_config(self) -> BaseConfig:
+        best_trial_record = self.get_best_trial_record()
+        return best_trial_record.quant_config
 
     def need_stop(self) -> bool:
         """Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.
15 changes: 15 additions & 0 deletions neural_compressor/common/utils/utility.py
@@ -15,9 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import importlib
 import subprocess
 import time
+from typing import Dict
 
 import cpuinfo
 import psutil
@@ -35,6 +37,7 @@
     "LazyImport",
     "CpuInfo",
     "default_tuning_logger",
+    "call_counter",
 ]
 
 
@@ -225,3 +228,15 @@ def inner_wrapper(*args, **kwargs):
         return inner_wrapper
 
     return log_process_wrapper
+
+
+# decorator for recording number of times a function is called
+FUNC_CALL_COUNTS: Dict[str, int] = collections.defaultdict(int)
+
+
+def call_counter(func):
+    def wrapper(*args, **kwargs):
+        FUNC_CALL_COUNTS[func.__name__] += 1
+        return func(*args, **kwargs)
+
+    return wrapper
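A minimal usage sketch of the new counter; the decorated function below is an arbitrary example, not part of this commit, and mirrors the unit test added further down:

```python
# Minimal sketch of the new call_counter utility; dummy_fn is an arbitrary example.
from neural_compressor.common.utils.utility import FUNC_CALL_COUNTS, call_counter

@call_counter
def dummy_fn(x):
    return x * 2

dummy_fn(1)
dummy_fn(2)
print(FUNC_CALL_COUNTS["dummy_fn"])  # 2
```

Note that counts are keyed by `func.__name__`, so differently-scoped functions that happen to share a name would share the same counter bucket.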
15 changes: 11 additions & 4 deletions neural_compressor/tensorflow/quantization/autotune.py
@@ -20,7 +20,7 @@
 from neural_compressor.common import logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
-from neural_compressor.common.utils import dump_elapsed_time
+from neural_compressor.common.utils import call_counter, dump_elapsed_time
 from neural_compressor.tensorflow.quantization import quantize_model
 from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
 from neural_compressor.tensorflow.utils import BaseModel, Model, constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
 
 
 @dump_elapsed_time("Pass auto-tune")
+@call_counter
 def autotune(
     model: Union[str, tf.keras.Model, BaseModel],
     tune_config: TuningConfig,
@@ -52,7 +53,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -65,8 +66,14 @@
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            best_quant_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
+                del q_model
+                best_quant_config: BaseConfig = best_trial_record.quant_config
+                best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            else:
+                best_quant_model = q_model
             break
     tuning_logger.tuning_end()
     return best_quant_model
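To make the new control flow concrete, here is a self-contained toy sketch of the decision both backends now make when tuning stops; the class and function below are illustrative stand-ins (the `_TrialRecord` fields are inferred from their usage in base_tuning.py), not neural_compressor APIs:

```python
# Toy illustration of the "return directly vs. re-quantize" decision added in this commit;
# TrialRecord and pick_return_path are stand-ins, not neural_compressor APIs.
from dataclasses import dataclass

@dataclass
class TrialRecord:  # stand-in for _TrialRecord (fields inferred from the diff)
    trial_index: int
    trial_result: float
    quant_config: str

def pick_return_path(history, last_trial_index):
    # Same idea as get_best_trial_record: highest score wins.
    best = max(history, key=lambda r: r.trial_result)
    if best.trial_index != last_trial_index:
        return f"re-quantize with {best.quant_config}"    # one extra quantize call
    return "return the last q_model directly"             # no extra quantize call

history = [TrialRecord(1, 0.9, "cfg-1"), TrialRecord(2, 1.1, "cfg-2")]
print(pick_return_path(history, last_trial_index=2))  # return the last q_model directly

history = [TrialRecord(1, 0.9, "cfg-1"), TrialRecord(2, 0.8, "cfg-2")]
print(pick_return_path(history, last_trial_index=2))  # re-quantize with cfg-1
```

The switch to `enumerate(config_loader, 1)` makes the in-loop `trial_index` start at 1, presumably so it lines up with the index stored on each trial record for this comparison. This is also what the new PyTorch tests exercise: with two trials, `quantize` runs twice when the last trial is already the best and a third time when an earlier trial's config has to be re-applied.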
9 changes: 5 additions & 4 deletions neural_compressor/torch/quantization/autotune.py
@@ -72,7 +72,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -93,10 +93,11 @@
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            if trial_index == 0:  # recover the best q_model from previous results.
-                logger.info("Reconvering the best quantized model...")
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
                 del q_model  # maybe gc.collect() is needed for memory release
-                best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
+                best_quant_config: BaseConfig = best_trial_record.quant_config
                 # !!! Make sure to use deepcopy only when inplace is set to `True`.
                 q_model = quantize(
                     deepcopy(model),
3 changes: 2 additions & 1 deletion neural_compressor/torch/quantization/quantize.py
@@ -18,7 +18,7 @@
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import Mode, log_process
+from neural_compressor.common.utils import Mode, call_counter, log_process
 from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig
 from neural_compressor.torch.utils import is_ipex_available, logger
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -31,6 +31,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
 
 
 @log_process(mode=Mode.QUANTIZE)
+@call_counter
 def quantize(
     model: torch.nn.Module,
     quant_config: BaseConfig,
22 changes: 22 additions & 0 deletions test/3x/common/test_utility.py
@@ -11,6 +11,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import options
 from neural_compressor.common.utils import (
     CpuInfo,
@@ -166,5 +167,26 @@ def __init__(self):
         assert instance2.value == 1, "Singleton should return the same instance"
 
 
+class TestCallCounter(unittest.TestCase):
+    def test_call_counter(self):
+        # empty dict
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        @inc_utils.call_counter
+        def add(a, b):
+            return a + b
+
+        # Initial count should be 0
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 0)
+
+        # Call the function multiple times
+        add(1, 2)
+        add(3, 4)
+        add(5, 6)
+
+        # Count should be incremented accordingly
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)
+
+
 if __name__ == "__main__":
     unittest.main()
38 changes: 38 additions & 0 deletions test/3x/torch/test_autotune.py
@@ -6,6 +6,7 @@
 import torch
 import transformers
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import logger
 from neural_compressor.torch.quantization import (
     MixPrecisionConfig,
@@ -163,6 +164,43 @@ def eval_acc_fn(model) -> float:
 
         custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        print(inc_utils.FUNC_CALL_COUNTS)
         self.assertIsNotNone(best_model)
 
+    def test_autotune_return_qmodel_directly(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 1.1]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 2
+        ), f"quantize should be called twice, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
+        self.assertIsNotNone(best_model)
+
+    def test_autotune_return_re_quant_qmodel(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 0.8]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 3
+        ), f"quantize should be called three times, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
+        self.assertIsNotNone(best_model)
+
     @reset_tuning_target
