Enable the tuning of WOQ algorithm (#1328)
* support WOQ algos tuning
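A minimal usage sketch of the new tuning path (mirroring the added test in test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py; the model, dataloader, and eval_func below are placeholders): with approach="weight_only" and no explicit "algorithm" in op_type_dict, the strategy now cycles through the WOQ presets (RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_MATMUL, GPTQ_G128ASYM, AWQ_G32ASYM) until the accuracy goal is met.

```python
# Sketch only: assumes a user-supplied FP32 ONNX model, calibration dataloader,
# and accuracy eval_func; tuning walks the WOQ presets defined in
# neural_compressor/strategy/utils/constant.py.
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.config import TuningCriterion

conf = PostTrainingQuantConfig(
    approach="weight_only",                     # enables the WOQ algorithm tuning path
    quant_level="auto",                         # or quant_level=1 with the basic strategy
    tuning_criterion=TuningCriterion(max_trials=5),
)
q_model = quantization.fit(
    onnx_fp32_model,                            # placeholder: user's FP32 ONNX model
    conf,
    calib_dataloader=calib_dataloader,          # required by the GPTQ/AWQ presets
    eval_func=eval_func,                        # accuracy feedback drives preset selection
)
```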
---------

Signed-off-by: Kaihui-intel <[email protected]>
Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yuwenzho <[email protected]>
3 people authored and mengniwang95 committed Nov 20, 2023
1 parent 1fecb1c commit 0f8bf5e
Showing 8 changed files with 239 additions and 5 deletions.
48 changes: 43 additions & 5 deletions neural_compressor/adaptor/onnxrt.py
@@ -1628,26 +1628,37 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
Returns:
(dict): quantized model
"""
if self.performance_only:
tmp_model = model
else:
try:
tmp_model = copy.deepcopy(model)
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e)))
tmp_model = model

assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME"
for precision in self.query_handler.get_precisions():
if precision == "weight_only_integer":
self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision)
self.quantizable_ops = self._query_quantizable_ops(model.model)
self.quantizable_ops = self._query_quantizable_ops(tmp_model.model)

self._update_tune_cfg(tune_cfg, tmp_model.model)
quant_config = self._cfg_to_quantize_config(tune_cfg)
algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
if "GPTQ" in algos:
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01)
blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128)
actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
mse = self.recipes.get("gptq_args", {}).get("mse", False)
perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
accuracy_level = self.recipes.get("gptq_args", {}).get("accuracy_level", 0)
calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
model = gptq_quantize(
model,
tmp_model = gptq_quantize(
tmp_model,
data_loader,
quant_config,
n_samples=calib_sampling_size,
@@ -1661,12 +1672,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
if "AWQ" in algos:
from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize

assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
accuracy_level = self.recipes.get("awq_args", {}).get("accuracy_level", 0)
calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
model = awq_quantize(
model,
tmp_model = awq_quantize(
tmp_model,
data_loader,
quant_config,
n_samples=calib_sampling_size,
@@ -1683,6 +1695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
quant_config,
accuracy_level=accuracy_level,
)
tmp_model = rtn_quantize(tmp_model, quant_config)
tmp_model.q_config = copy.deepcopy(quant_config)
self._dump_model_op_stats(tmp_model, tune_cfg)
tmp_model.topological_sort()
@@ -1752,6 +1765,31 @@ def _cfg_to_quantize_config(self, tune_cfg):

return quantize_config

def _update_tune_cfg(self, tune_cfg, model):
"""Update tune cfg according to woq_tuning_cfg."""
if tune_cfg.get("woq_tuning_cfg") is None:
return tune_cfg

from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS

woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg")
new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg)

for node_cfg in tune_cfg["op"].values():
node_cfg["weight"].update(
{cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]}
)

# find last matmul and set to fp32
if "DISABLE_LAST_MATMUL" in woq_tuning_cfg:
last_matmul = None
fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
for node in model.graph.node:
if node.op_type in ["MatMul"]:
last_matmul = (node.name, node.op_type)
if last_matmul in tune_cfg["op"]:
tune_cfg["op"][last_matmul].update(fp32_op_cfg)

def query_fw_capability(self, model):
"""The function is used to query framework capability.
TODO: will be replaced by framework query API
6 changes: 6 additions & 0 deletions neural_compressor/strategy/auto.py
@@ -120,6 +120,12 @@ def next_tune_cfg(self):
op_tuning_cfg["calib_sampling_size"] = calib_sampling_size_lst[0]
if not self.cur_best_tuning_cfg:
self.cur_best_tuning_cfg = deepcopy(op_tuning_cfg)

# try to tune a WeightOnlyQuant algorithm
if self._should_tuning_woq_algo():
for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)):
yield tune_cfg

# try to tune sq alpha
if self._should_tuning_sq_alpha(self.config.recipes):
for tune_cfg in self.tuning_sq_alpha(tuning_space, deepcopy(self.cur_best_tuning_cfg), self.config.recipes):
6 changes: 6 additions & 0 deletions neural_compressor/strategy/basic.py
@@ -312,6 +312,12 @@ def next_tune_cfg(self):
stage1_max = 1e9 # TODO set a more appropriate value
if not self.cur_best_tuning_cfg:
self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)

# try to tune a WeightOnlyQuant algorithm
if self._should_tuning_woq_algo():
for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)):
yield tune_cfg

# try to tune sq alpha
if self._should_tuning_sq_alpha(self.config.recipes):
for tune_cfg in self.tuning_sq_alpha(
36 changes: 36 additions & 0 deletions neural_compressor/strategy/strategy.py
@@ -46,6 +46,7 @@
DotDict,
LazyImport,
Statistics,
check_key_exist,
dump_table,
equal_dicts,
fault_tolerant_file,
@@ -1153,6 +1154,40 @@ def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
for tune_cfg in sq_sampler:
yield tune_cfg

def _should_tuning_woq_algo(self):
"""Currently, it's only available for the ORT backend with approach is weight_only.
It will be triggered when
a) quant_level is auto or quant_level is 1 && strategy is basic
b) and the "algorithm" is not set in op_type_dict
c) and woq will only trigger once
"""
return (
"onnx" in self.framework.lower()
and "weight_only" in self.config.approach
and not check_key_exist(self.config.op_type_dict, "algorithm")
and not check_key_exist(self.tuning_history, "woq_tuning_cfg")
)

def tuning_woq_algo(self, tuning_space, tuning_cfg):
"""Tuning weight only algorithm.
Args:
tuning_space: tuning space
tuning_cfg: the initial tuning config
Yields:
tuning config
"""
logger.info("[STRATEGY] Start tuning Weight Only Quant' algo.")
woq_sampler = tuning_sampler_dict.get_class("woq_algorithm")(tuning_space, [], tuning_cfg)
for tune_cfg in woq_sampler:
yield tune_cfg

logger.info(
"[Strategy] The best tuning config with WeightOnlyQuant is " f"{self.cur_best_tuning_cfg['woq_tuning_cfg']}."
)

def initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg: OpTuningConfig):
"""Init the dynamic tuning config according to the static config.
@@ -1322,6 +1357,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
# For not tuning recipe, tune cfg use it directly
tune_cfg["recipe_cfgs"].update(self._not_tuning_recipes_values)
tune_cfg["trial_number"] = deepcopy(self.trials_count)
tune_cfg.setdefault("woq_tuning_cfg", op_tuning_cfg.get("woq_tuning_cfg"))
# The sq-related args comes from user config, current best tuning config
# TODO simplify the logic for transforming the arguments
# update the sq-related args from self.cur_best_tuning_cfg
10 changes: 10 additions & 0 deletions neural_compressor/strategy/utils/constant.py
@@ -16,6 +16,7 @@
# limitations under the License.
"""Strategy constant."""


PRECISION_LIST = ["bf16", "fp16", "fp32"]
QUANT_MODE_SET = {"static", "dynamic"}
LOWER_BIT_LIST = ["int4"]
@@ -56,3 +57,12 @@
"last_conv_or_matmul_quantization",
"pre_post_process_quantization",
}


WOQ_TUNING_ALGOS = {
"RTN_G32ASYM": {"algorithm": "RTN", "group_size": 32, "scheme": "asym"},
"GPTQ_G32ASYM": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"},
"GPTQ_G32ASYM_DISABLE_LAST_MATMUL": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"},
"GPTQ_G128ASYM": {"algorithm": "GPTQ", "group_size": 128, "scheme": "asym"},
"AWQ_G32ASYM": {"algorithm": "AWQ", "group_size": 32, "scheme": "asym"},
}
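
The mapping above feeds _update_tune_cfg in onnxrt.py (earlier in this diff), which copies only the keys already present in each op's weight config; presets ending in _DISABLE_LAST_MATMUL additionally fall back the last MatMul to fp32. A rough sketch of that merge, with the starting weight config assumed purely for illustration:

```python
# Illustration only (not part of the commit): how a preset is expected to merge
# into an op's weight config. The starting values below are assumed defaults.
from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS

op_weight_cfg = {"dtype": "int", "bits": 4, "group_size": 128, "scheme": "sym", "algorithm": "RTN"}
preset = WOQ_TUNING_ALGOS["GPTQ_G32ASYM"]
# Override only the keys that already exist in the op's weight config.
op_weight_cfg.update({k: v for k, v in preset.items() if k in op_weight_cfg})
# op_weight_cfg -> {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "asym", "algorithm": "GPTQ"}
```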
32 changes: 32 additions & 0 deletions neural_compressor/strategy/utils/tuning_sampler.py
@@ -23,6 +23,7 @@
from typing import Any, Dict, List, Tuple, Union

from ...utils import logger
from ..utils.constant import WOQ_TUNING_ALGOS
from .tuning_space import TuningSpace, pattern_to_internal, quant_mode_from_pattern
from .tuning_structs import OpTuningConfig
from .utility import ClassRegister
@@ -609,3 +610,34 @@ def __iter__(self):
recipe_cfgs["smooth_quant_args"] = {"alpha": alpha}
logger.debug(f"[STRATEGY] set smooth quant alpha with: {alpha:.4f}")
yield new_tune_cfg


@tuning_sampler_dict("woq_algorithm")
class WeightOnlyQuantSampler(TuningSampler):
"""Not displayed in API Docs."""

def __init__(
self,
tuning_space: TuningSpace,
tuning_order_lst: List[TuningOrder],
initial_op_tuning_cfg: Dict,
):
"""Init tuning sampler.
Args:
tuning_space: The tuning space.
tuning_order_lst: The traverse orders.
initial_op_tuning_cfg: The initialized tuning config.
"""
super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)

def __iter__(self):
"""Yield the next tuning config.
Yields:
The next tuning config.
"""
new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
for algo in WOQ_TUNING_ALGOS.keys():
new_tune_cfg["woq_tuning_cfg"] = algo
yield new_tune_cfg
31 changes: 31 additions & 0 deletions neural_compressor/utils/utility.py
@@ -1092,3 +1092,34 @@ def mse_metric_gap(fp32_tensor: Any, dequantize_tensor: Any) -> float:
diff_tensor = fp32_tensor_norm - dequantize_tensor_norm
euclidean_dist = np.sum(diff_tensor**2) # type: ignore
return euclidean_dist / fp32_tensor.size


def check_key_exist(data, key):
"""Recursively checks if a key exists in a dictionary or list.
Args:
data (dict or list): The dictionary or list to search.
key (any): The key to search for.
Returns:
bool: True if the key exists in the data structure, False otherwise.
Examples:
>>> check_key_exist({'a': 1, 'b': {'c': 2}}, 'c')
True
>>> check_key_exist([{'a': 1}, {'b': 2}], 'b')
True
>>> check_key_exist({'a': 1, 'b': [1, 2, 3]}, 'c')
False
"""
if isinstance(data, dict):
if key in data:
return True
for value in data.values():
if check_key_exist(value, key):
return True
elif isinstance(data, list):
for item in data:
if check_key_exist(item, key):
return True
return False
75 changes: 75 additions & 0 deletions test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py
@@ -1,3 +1,4 @@
import copy
import os
import shutil
import subprocess
@@ -9,6 +10,7 @@
from transformers import AutoTokenizer

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.utils.constant import FP32


def Inference(model, data):
@@ -265,6 +267,79 @@ def test_GPTQ_quant(self):
]
self.assertTrue(len(rtn_op_names) + 1, len(gptq_op_names))

def _test_woq_tune_common(self, eval_func, quant_level=1, **kwargs):
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion

tuning_criterion = TuningCriterion(max_trials=5)

fp32_model = copy.deepcopy(self.model)
conf = PostTrainingQuantConfig(
approach="weight_only", quant_level=quant_level, tuning_criterion=tuning_criterion, **kwargs
)
q_model = quantization.fit(
fp32_model,
conf,
calib_dataloader=self.dataloader,
eval_func=eval_func,
)
self.assertIsNotNone(q_model)
return q_model

def _count_woq_matmul(self, q_model, bits=4, group_size=32):
op_names = [
i.name
for i in q_model.nodes()
if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size))
]
return len(op_names)

def test_woq_tune(self):
from functools import partial

def fake_eval(model, eval_result_lst):
acc = eval_result_lst.pop(0)
return acc

quant_levels = ["auto", 1]
for quant_level in quant_levels:
# Expect tuning ends with WOQ algorithm 'RTN_G32ASYM'
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
woq_model_1 = self._test_woq_tune_common(partial_fake_eval, quant_level)
self.assertEqual(self._count_woq_matmul(woq_model_1), 31)

# Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM'
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 1.1])
woq_model_2 = self._test_woq_tune_common(partial_fake_eval, quant_level)
self.assertEqual(self._count_woq_matmul(woq_model_2), 31)

# Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM_DISABLE_LAST_MATMUL'
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 1.1])
woq_model_3 = self._test_woq_tune_common(partial_fake_eval, quant_level)
self.assertEqual(self._count_woq_matmul(woq_model_3), 30)

# Expect tuning ends with WOQ algorithm 'GPTQ_G128ASYM'
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 1.1])
woq_model_4 = self._test_woq_tune_common(partial_fake_eval, quant_level)
self.assertEqual(self._count_woq_matmul(woq_model_4, group_size=128), 31)

# Expect tuning ends with WOQ algorithm 'AWQ_G32ASYM'
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 0.8, 1.1])
woq_model_5 = self._test_woq_tune_common(partial_fake_eval, quant_level)
self.assertEqual(self._count_woq_matmul(woq_model_5), 31)

# test WOQ tuning with fallback
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
woq_model = self._test_woq_tune_common(
partial_fake_eval, "auto", op_name_dict={"/transformer/h.*/attn/k_proj/MatMul": FP32}
)
self.assertEqual(self._count_woq_matmul(woq_model), 26)

# test 8 bits WOQ
partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
woq_model = self._test_woq_tune_common(partial_fake_eval, "auto", op_type_dict={".*": {"weight": {"bits": 8}}})
self.assertEqual(self._count_woq_matmul(woq_model, bits=8), 31)


if __name__ == "__main__":
unittest.main()
