Migrate RTN, HQQ and AWQ to Torch new 3.x API #1765

Merged
merged 12 commits on May 7, 2024
19 changes: 9 additions & 10 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -85,19 +85,18 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
Returns:
A quantized model.
"""
model = self.prepare(model, *args, **kwargs)

run_fn = kwargs.get("run_fn", None)
run_args = kwargs.get("run_args", None)
assert run_fn is not None, (
"Can't find run_fn. Please provide run_fn to the quantize API "
"or override the quantize method in your Quantizer class."
)
if run_fn is not None:
run_args = kwargs.get("run_args", None)
if run_args:
run_fn(model, *run_args)
else:
run_fn(model)

model = self.prepare(model, *args, **kwargs)
if run_args:
run_fn(model, *run_args)
else:
run_fn(model)
model = self.convert(model, *args, **kwargs)

return model

def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any): # pragma: no cover
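With this reordering, quantize() first verifies that a calibration function was supplied, then calls prepare(), runs the calibration, and finally calls convert(). A minimal sketch of how a Quantizer subclass is driven under the new flow; ToyQuantizer and calib_fn are illustrative stand-ins, not part of this PR:

import torch

from neural_compressor.torch.algorithms import Quantizer  # base class changed in this hunk


class ToyQuantizer(Quantizer):
    # Illustrative subclass; the real RTN/AWQ/HQQ quantizers implement these hooks.
    def prepare(self, model, *args, **kwargs):
        # e.g. swap forwards or attach observers before calibration
        return model

    def convert(self, model, *args, **kwargs):
        # e.g. fold calibration results into quantized modules
        return model


def calib_fn(model):
    # user-supplied calibration loop; the name is illustrative
    model(torch.randn(1, 8))


quantizer = ToyQuantizer({})  # quant_config passed positionally, as AWQQuantizer does below
q_model = quantizer.quantize(torch.nn.Linear(8, 8), run_fn=calib_fn)
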
183 changes: 98 additions & 85 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -15,9 +15,11 @@
# Copied from neural_compressor/adaptor/torch_utils/awq.py

import copy
from collections import OrderedDict

import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_device, logger

from .modules import MulLinear
@@ -26,13 +28,13 @@
get_absorb_layers,
get_block_prefix,
get_example_input,
get_hidden_states,
get_module_input_output,
model_forward,
recover_forward,
replace_forward,
set_module,
)

__all__ = ["awq_quantize"]
__all__ = ["AWQQuantizer"]


def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
@@ -113,15 +115,15 @@ def __init__(
self,
model,
example_inputs=None,
calib_func=None,
dataloader=None,
n_samples=128,
data_type="int",
bits=4,
group_size=32,
scheme="asym",
use_full_range=False,
weight_config={},
total_block_args=[],
total_block_kwargs=[],
):

self.example_inputs = example_inputs
Expand All @@ -130,11 +132,9 @@ def __init__(
assert dataloader is not None, "dataloader or example_inputs is required."
self.example_inputs = get_example_input(dataloader)
self._move_model_and_data_to_device()
# Step 1: get hidden states and kwargs of first block.
self.total_block_args, self.total_block_kwargs = get_hidden_states(
model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
)
# Step 2: get block list and block prefix, number
self.total_block_args = total_block_args
self.total_block_kwargs = total_block_kwargs
# get block list and block prefix, number
self.block_prefix, self.block_num = get_block_prefix(model)
self.block_list = fetch_module(model, self.block_prefix)
self.data_type = data_type
@@ -429,14 +429,15 @@ def apply_quantize_with_clip(self, return_int=False):
"""
# apply quantization and clip
logger.info("Quantizing the AWQ optimized fp32 model")
from .rtn import rtn_quantize
from .rtn import RTNQuantizer

rtn_quantizer = RTNQuantizer(quant_config=self.weight_config)

self.model = rtn_quantize(
self.model = rtn_quantizer.quantize(
self.model,
num_bits=self.bits,
bits=self.bits,
group_size=self.group_size,
scheme=self.scheme,
weight_config=self.weight_config,
return_int=return_int,
use_full_range=self.use_full_range,
)
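The hunk above replaces the functional rtn_quantize entry point with the new class-based one: an RTNQuantizer is constructed with the per-op config and its quantize() method is called with the same keyword arguments. A hedged sketch of that pattern in isolation; the import path follows the relative import above, while the toy model and the weight_config keys are illustrative assumptions:

import torch

from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer

# Toy model and hypothetical per-op config; the real config is built by the framework.
model = torch.nn.Sequential()
model.add_module("fc1", torch.nn.Linear(32, 32))
weight_config = {"fc1": {"bits": 4, "group_size": 32, "scheme": "asym"}}

rtn_quantizer = RTNQuantizer(quant_config=weight_config)
q_model = rtn_quantizer.quantize(  # keyword names as used in the hunk above
    model,
    bits=4,
    group_size=32,
    use_full_range=False,
    return_int=False,
)
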
@@ -492,78 +493,90 @@ def module_inference(self, model, inputs):
return total_out


@torch.no_grad()
def awq_quantize(
model,
bits=4,
group_size=32,
scheme="asym",
weight_config={},
example_inputs=None,
dataloader=None,
n_samples=128,
calib_func=None,
use_auto_scale=True,
use_mse_search=True,
folding=False,
return_int=False,
use_full_range=False,
data_type="int",
):
"""Quant the model with Activation-aware Weight quantization(AWQ) method.
class AWQQuantizer(Quantizer):
def __init__(self, quant_config: OrderedDict = {}):
"""Init an AWQQuantizer object.

Args:
model (torch.nn.Module): torch model.
example_inputs: example_inputs.
weight_config (dict, optional): contains all info required by AWQ. Defaults to {}.
For example,
weight_config={
'fc2':
{
# 'absorb_layer': 'fc1',
'bits': 4,
'group_size': 32,
'scheme': 'sym'
}
}
absorb_dict (dict, optional): contains all absorb info required by AWQ. Defaults to {}.
For example,
absorb_dict = {
# 'absorb_layer': absorbed_layer
'fc1': ['fc1', 'fc2', 'fc3']
} # in this case, fc2 and fc3 need to share the same scale. fc1 is self absorbed.
# a self-absorbed module will be replaced with MulLinear, which contains torch.mul and the module.
n_samples: calibration sample number.
use_auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True.
use_mse_search (bool, optional): whether enable clip for weight by checking mse. Defaults to True.
calib_func: a custom inference function to replace dataloader and iters.
n_blocks: split model into block number to avoid OOM.
return_int (bool, optional): Choose return fp32 or int32 model.
Defaults to False.
use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
"""
super().__init__(quant_config)

Returns:
model: fake quantized model
"""
@torch.no_grad()
def prepare(self, model, *args, **kwargs):
"""Prepare a given model to get hidden states and kwargs of first block.

Args:
model: A float torch model.

assert isinstance(model, torch.nn.Module), "only support torch module"
awq = ActAwareWeightQuant(
Returns:
A prepared model.
"""
assert isinstance(model, torch.nn.Module), "AWQ algorithm only supports torch module"
model = replace_forward(model)
return model

@torch.no_grad()
def convert(
self,
model,
example_inputs=example_inputs,
calib_func=calib_func,
dataloader=dataloader,
n_samples=n_samples,
bits=bits,
group_size=group_size,
scheme=scheme,
use_full_range=use_full_range,
weight_config=weight_config,
data_type=data_type,
)
qdq_model = awq.quantize(
use_auto_scale=use_auto_scale,
use_mse_search=use_mse_search,
folding=folding,
return_int=return_int,
)
return qdq_model
bits=4,
group_size=32,
scheme="asym",
example_inputs=None,
dataloader=None,
use_auto_scale=True,
use_mse_search=True,
folding=False,
return_int=False,
use_full_range=False,
data_type="int",
*args,
**kwargs,
):
"""Converts a prepared model to a quantized model.

Args:
model: torch model.
bits: num bits. Defaults to 4.
group_size: how many elements share one scale/zp. Defaults to 32.
scheme: sym or asym. Defaults to "asym".
example_inputs: example_inputs. Defaults to None.
dataloader: calibration dataloader; either dataloader or example_inputs is required. Defaults to None.
use_auto_scale: whether to enable scaling for salient weights. Defaults to True.
use_mse_search: whether to enable weight clipping by checking MSE. Defaults to True.
folding: if False, allow inserting a mul before a linear layer when the scale cannot be
    absorbed by the last layer; otherwise do not insert it. Defaults to False.
return_int: whether to return an int32 model instead of a fake-quantized fp32 model. Defaults to False.
use_full_range: whether the symmetric range uses -2**(bits-1). Defaults to False.
data_type: data type. Defaults to "int".

Returns:
model: fake quantized model
"""
model = recover_forward(model)
total_block_args = getattr(model, "total_block_args", [])
total_block_kwargs = getattr(model, "total_block_kwargs", [])
delattr(model, "total_block_args")
delattr(model, "total_block_kwargs")

awq = ActAwareWeightQuant(
model,
example_inputs=example_inputs,
dataloader=dataloader,
data_type=data_type,
bits=bits,
group_size=group_size,
scheme=scheme,
use_full_range=use_full_range,
weight_config=self.quant_config,
total_block_args=total_block_args,
total_block_kwargs=total_block_kwargs,
)
qdq_model = awq.quantize(
use_auto_scale=use_auto_scale,
use_mse_search=use_mse_search,
folding=folding,
return_int=return_int,
)
return qdq_model
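Taken together, AWQQuantizer splits the old awq_quantize flow into two stages: prepare() installs a capturing forward so that calibration records the first block's inputs, and convert() restores the original forward and runs ActAwareWeightQuant on the captured data. A hedged usage sketch, assuming a block-structured float_model (e.g. a decoder-style transformer), a calib_dataloader, and matching example_inputs already exist; the per-layer quant_config keys are illustrative:

from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer

# Hypothetical per-layer config; in practice it is derived from the user-facing AWQ config.
quant_config = {"fc2": {"bits": 4, "group_size": 32, "scheme": "asym"}}

quantizer = AWQQuantizer(quant_config=quant_config)

model = quantizer.prepare(float_model)  # replace_forward: start capturing block inputs
for batch in calib_dataloader:          # user-driven calibration pass
    model(batch)
q_model = quantizer.convert(            # recover_forward + ActAwareWeightQuant
    model,
    bits=4,
    group_size=32,
    scheme="asym",
    example_inputs=example_inputs,
)

The same flow can also be driven through the base quantize() entry point by passing run_fn, as in the base_algorithm.py change above.
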
neural_compressor/torch/algorithms/weight_only/hqq/__init__.py
@@ -14,4 +14,3 @@

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
from .quant_api import hqq_quantize
63 changes: 0 additions & 63 deletions neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py

This file was deleted.
