Migrate RTN, HQQ and AWQ to Torch new 3.x API (#1765)

Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yuwenzho and pre-commit-ci[bot] authored May 7, 2024
1 parent 84d7055 commit 1a45090
Showing 13 changed files with 778 additions and 410 deletions.
19 changes: 9 additions & 10 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -85,19 +85,18 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         Returns:
             A quantized model.
         """
-        model = self.prepare(model, *args, **kwargs)
-
         run_fn = kwargs.get("run_fn", None)
+        run_args = kwargs.get("run_args", None)
+        assert run_fn is not None, (
+            "Can't find run_func. Please provide run_func to quantize API "
+            "or overwrite quantize member function in your Quantizer class."
+        )
-        if run_fn is not None:
-            run_args = kwargs.get("run_args", None)
-            if run_args:
-                run_fn(model, *run_args)
-            else:
-                run_fn(model)
 
+        model = self.prepare(model, *args, **kwargs)
+        if run_args:
+            run_fn(model, *run_args)
+        else:
+            run_fn(model)
         model = self.convert(model, *args, **kwargs)
-
         return model
 
     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
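The hunk above reverses the order of operations (prepare now runs before the calibration function) and makes run_fn required rather than optional. A minimal sketch of how a subclass plugs into this flow; MyQuantizer, its method bodies, and calib_fn are hypothetical, for illustration only:

    import torch
    from neural_compressor.torch.algorithms import Quantizer

    class MyQuantizer(Quantizer):  # hypothetical subclass, illustration only
        def prepare(self, model, *args, **kwargs):
            # e.g. attach observers or capturing hooks here
            return model

        def convert(self, model, *args, **kwargs):
            # e.g. fold captured statistics into quantized modules here
            return model

    def calib_fn(model):  # the run_fn that quantize() now insists on
        model(torch.randn(2, 8))  # one calibration forward pass

    q_model = MyQuantizer(quant_config={}).quantize(torch.nn.Linear(8, 4), run_fn=calib_fn)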
neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -53,6 +53,7 @@ def __init__(self, quant_config: OrderedDict = {}):
             quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
         """
         super().__init__(quant_config)
+        self.user_cfg = OrderedDict()
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -71,7 +72,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
             model, example_inputs
         )
         # update json file in ipex_config_path; map ipex op_name to pt op_name
-        user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
         model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -94,7 +95,6 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
 
         model.load_qconf_summary(qconf_summary=ipex_config_path)
-        setattr(model, "user_cfg", user_cfg)
         return model
 
     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -110,16 +110,14 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         from neural_compressor.torch.algorithms.static_quant import save
 
-        user_cfg = getattr(model, "user_cfg", OrderedDict())
-
         model.save_qconf_summary(qconf_summary=ipex_config_path)
         model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
 
         with open(ipex_config_path, "r") as f:
             model.tune_cfg = json.load(f)
         model.ipex_config_path = ipex_config_path
 
-        dump_model_op_stats(user_cfg)
+        dump_model_op_stats(self.user_cfg)
 
         logger.info("Static quantization done.")
         model.ori_save = model.save
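These four hunks move user_cfg off the model object and onto the quantizer instance, so prepare and convert share state without mutating the user's model. A minimal sketch of the pattern; the class is hypothetical and illustrates the idea only:

    from collections import OrderedDict

    class StatefulQuantizer:  # hypothetical, shows the state-on-self pattern only
        def __init__(self):
            self.user_cfg = OrderedDict()  # survives between prepare() and convert()

        def prepare(self, model):
            self.user_cfg["fc1"] = {"dtype": "int8"}  # record per-op decisions
            return model

        def convert(self, model):
            assert "fc1" in self.user_cfg  # same instance state, no model attribute needed
            return model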
183 changes: 98 additions & 85 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -15,9 +15,11 @@
 # Copied from neural_compressor/adaptor/torch_utils/awq.py
 
 import copy
+from collections import OrderedDict
 
 import torch
 
+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear
@@ -26,13 +28,13 @@
     get_absorb_layers,
     get_block_prefix,
     get_example_input,
-    get_hidden_states,
     get_module_input_output,
-    model_forward,
+    recover_forward,
+    replace_forward,
     set_module,
 )
 
-__all__ = ["awq_quantize"]
+__all__ = ["AWQQuantizer"]
 
 
 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
@@ -113,15 +115,15 @@ def __init__(
         self,
         model,
         example_inputs=None,
-        calib_func=None,
         dataloader=None,
-        n_samples=128,
         data_type="int",
         bits=4,
         group_size=32,
         scheme="asym",
         use_full_range=False,
         weight_config={},
+        total_block_args=[],
+        total_block_kwargs=[],
     ):
 
         self.example_inputs = example_inputs
@@ -130,11 +132,9 @@
             assert dataloader is not None, "dataloader or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
         self._move_model_and_data_to_device()
-        # Step 1: get hidden states and kwargs of first block.
-        self.total_block_args, self.total_block_kwargs = get_hidden_states(
-            model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
-        )
-        # Step 2: get block list and block prefix, number
+        self.total_block_args = total_block_args
+        self.total_block_kwargs = total_block_kwargs
+        # get block list and block prefix, number
         self.block_prefix, self.block_num = get_block_prefix(model)
         self.block_list = fetch_module(model, self.block_prefix)
         self.data_type = data_type
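The constructor no longer runs calibration itself; the first block's args and kwargs are captured beforehand, during prepare, and passed in. A sketch of the capture idea with a hypothetical helper (this is not the actual replace_forward implementation):

    def capture_first_block_inputs(block):
        # hypothetical: record the (args, kwargs) every forward pass feeds this block
        captured = []
        original_forward = block.forward

        def recording_forward(*args, **kwargs):
            captured.append((args, kwargs))
            return original_forward(*args, **kwargs)

        block.forward = recording_forward
        return captured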
@@ -429,14 +429,15 @@ def apply_quantize_with_clip(self, return_int=False):
"""
# apply quantization and clip
logger.info("Quantizing the AWQ optimized fp32 model")
from .rtn import rtn_quantize
from .rtn import RTNQuantizer

rtn_quantizer = RTNQuantizer(quant_config=self.weight_config)

self.model = rtn_quantize(
self.model = rtn_quantizer.quantize(
self.model,
num_bits=self.bits,
bits=self.bits,
group_size=self.group_size,
scheme=self.scheme,
weight_config=self.weight_config,
return_int=return_int,
use_full_range=self.use_full_range,
)
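Callers migrating their own 2.x code make the same substitution. A hedged sketch, assuming RTNQuantizer.quantize accepts these keywords directly (as the call above suggests); cfg and model stand in for the caller's own objects:

    import torch
    from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer

    cfg = {}  # per-op weight config dict, same schema the 2.x call used
    model = torch.nn.Linear(8, 8)  # placeholder for the caller's model
    # before (2.x): model = rtn_quantize(model, num_bits=4, weight_config=cfg, ...)
    rtn = RTNQuantizer(quant_config=cfg)
    model = rtn.quantize(model, bits=4, group_size=32, scheme="asym")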
@@ -492,78 +493,90 @@ def module_inference(self, model, inputs):
         return total_out
 
 
-@torch.no_grad()
-def awq_quantize(
-    model,
-    bits=4,
-    group_size=32,
-    scheme="asym",
-    weight_config={},
-    example_inputs=None,
-    dataloader=None,
-    n_samples=128,
-    calib_func=None,
-    use_auto_scale=True,
-    use_mse_search=True,
-    folding=False,
-    return_int=False,
-    use_full_range=False,
-    data_type="int",
-):
-    """Quant the model with Activation-aware Weight quantization(AWQ) method.
-
-    Args:
-        model (torch.nn.Module): torch model.
-        example_inputs: example_inputs.
-        weight_config (dict, optional): contains all info required by AWQ. Defaults to {}.
-            For example,
-                weight_config={
-                    'fc2':
-                        {
-                            # 'absorb_layer': 'fc1',
-                            'bits': 4,
-                            'group_size': 32,
-                            'scheme': 'sym'
-                        }
-                }
-        absorb_dict (dict, optional): contains all absorb info required by AWQ. Defaults to {}.
-            For example,
-                absorb_dict = {
-                    # 'absorb_layer': absorbed_layer
-                    'fc1': ['fc1', 'fc2', 'fc3']
-                }  # in this case, fc2 and fc3 need to share the same scale. fc1 is self absorbed.
-                # self absorb module will replace with MulLinear, which contains torch.mul and module.
-        n_samples: calibration sample number.
-        use_auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True.
-        use_mse_search (bool, optional): whether enable clip for weight by checking mse. Defaults to True.
-        calib_func: a custom inference function to replace dataloader and iters.
-        n_blocks: split model into block number to avoid OOM.
-        return_int (bool, optional): Choose return fp32 or int32 model.
-            Defaults to False.
-        use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
-
-    Returns:
-        model: fake quantized model
-    """
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    awq = ActAwareWeightQuant(
-        model,
-        example_inputs=example_inputs,
-        calib_func=calib_func,
-        dataloader=dataloader,
-        n_samples=n_samples,
-        bits=bits,
-        group_size=group_size,
-        scheme=scheme,
-        use_full_range=use_full_range,
-        weight_config=weight_config,
-        data_type=data_type,
-    )
-    qdq_model = awq.quantize(
-        use_auto_scale=use_auto_scale,
-        use_mse_search=use_mse_search,
-        folding=folding,
-        return_int=return_int,
-    )
-    return qdq_model
+class AWQQuantizer(Quantizer):
+    def __init__(self, quant_config: OrderedDict = {}):
+        """Init an AWQQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(self, model, *args, **kwargs):
+        """Prepare a given model to get hidden states and kwargs of first block.
+
+        Args:
+            model: A float torch model.
+
+        Returns:
+            A prepared model.
+        """
+        assert isinstance(model, torch.nn.Module), "AWQ algorithm only supports torch module"
+        model = replace_forward(model)
+        return model
+
+    @torch.no_grad()
+    def convert(
+        self,
+        model,
+        bits=4,
+        group_size=32,
+        scheme="asym",
+        example_inputs=None,
+        dataloader=None,
+        use_auto_scale=True,
+        use_mse_search=True,
+        folding=False,
+        return_int=False,
+        use_full_range=False,
+        data_type="int",
+        *args,
+        **kwargs,
+    ):
+        """Converts a prepared model to a quantized model.
+
+        Args:
+            model: torch model.
+            bits: num bits. Defaults to 4.
+            group_size: how many elements share one scale/zp. Defaults to 32.
+            scheme: sym or asym. Defaults to "asym".
+            example_inputs: example_inputs. Defaults to None.
+            dataloader: dataloader or example_inputs is required. Defaults to None.
+            use_auto_scale: whether enable scale for salient weight. Defaults to True.
+            use_mse_search: whether enable clip for weight by checking mse. Defaults to True.
+            folding: False will allow insert mul before linear when the scale cannot be absorbed
+                by last layer, else won't. Defaults to False.
+            return_int: Choose return fp32 or int32 model. Defaults to False.
+            use_full_range: Choose sym range whether use -2**(bits-1). Defaults to False.
+            data_type: data type. Defaults to "int".
+
+        Returns:
+            model: fake quantized model
+        """
+        model = recover_forward(model)
+        total_block_args = getattr(model, "total_block_args", [])
+        total_block_kwargs = getattr(model, "total_block_kwargs", [])
+        delattr(model, "total_block_args")
+        delattr(model, "total_block_kwargs")
+
+        awq = ActAwareWeightQuant(
+            model,
+            example_inputs=example_inputs,
+            dataloader=dataloader,
+            data_type=data_type,
+            bits=bits,
+            group_size=group_size,
+            scheme=scheme,
+            use_full_range=use_full_range,
+            weight_config=self.quant_config,
+            total_block_args=total_block_args,
+            total_block_kwargs=total_block_kwargs,
+        )
+        qdq_model = awq.quantize(
+            use_auto_scale=use_auto_scale,
+            use_mse_search=use_mse_search,
+            folding=folding,
+            return_int=return_int,
+        )
+        return qdq_model
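Under the new API the calibration forward passes happen between prepare and convert: replace_forward installs a capturing forward, and convert reads the recorded first-block args/kwargs back off the model. A hedged sketch of the intended call sequence; model, my_weight_config, calib_dataloader, and example_inputs are placeholders:

    quantizer = AWQQuantizer(quant_config=my_weight_config)
    model = quantizer.prepare(model)    # swaps in a capturing forward
    for batch in calib_dataloader:      # each pass records first-block args/kwargs
        model(batch)
    model = quantizer.convert(model, example_inputs=example_inputs, bits=4, group_size=32)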
neural_compressor/torch/algorithms/weight_only/hqq/__init__.py
@@ -14,4 +14,3 @@

 from .quantizer import HQQuantizer
 from .config import HQQModuleConfig, QTensorConfig
-from .quant_api import hqq_quantize
63 changes: 0 additions & 63 deletions neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py

This file was deleted.
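With quant_api.py gone, HQQ is reached through the HQQuantizer class exported above. A guess at the call sequence based on the Quantizer base class; HQQ is data-free, so no calibration run_fn should be needed, but the exact entry point may differ and my_hqq_config is illustrative:

    from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer

    quantizer = HQQuantizer(quant_config=my_hqq_config)
    model = quantizer.prepare(model)
    model = quantizer.convert(model)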
