support TEQ quantization method (#1093)
lkk12014402 authored Jul 20, 2023
1 parent 59172ad commit d2f995b
Showing 7 changed files with 645 additions and 4 deletions.
40 changes: 40 additions & 0 deletions neural_compressor/adaptor/pytorch.py
@@ -4522,6 +4522,9 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
        if 'GPTQ' in all_algo:
            q_model._model = self.gptq_quantize(q_model._model, tune_cfg, dataloader)

        if 'TEQ' in all_algo:
            q_model._model = self.teq_quantize(q_model._model, tune_cfg, dataloader, calib_func)

        if 'AWQ' in all_algo:  # includes RTN in AWQ
            q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func)
        elif 'RTN' in all_algo:
@@ -4582,6 +4585,43 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
        )
        return model

    def teq_quantize(self, model, tune_cfg, dataloader, calib_func):
        logger.debug("quantizing with the TEQ algorithm")
        from .torch_utils.weight_only import teq_quantize
        # Get example inputs from the dataloader if they were not provided.
        if self.example_inputs is None:
            if dataloader is None:
                assert False, "Please provide a dataloader or example_inputs for the TEQ algorithm."
            try:
                # Dataloaders that yield (input, label) pairs.
                for idx, (input, label) in enumerate(dataloader):
                    self.example_inputs = input
                    break
            except Exception:
                # Dataloaders that yield bare inputs.
                for idx, input in enumerate(dataloader):
                    self.example_inputs = input
                    break

        # Default TEQ settings; user recipes override them when provided.
        wbits, group_size, sym, folding = 4, 128, False, True
        if 'teq_args' in self.recipes:
            wbits = self.recipes.get('wbits', 4)
            group_size = self.recipes.get('group_size', 128)
            sym = self.recipes.get('scheme', False)
            folding = self.recipes.get('folding', True)

        weight_config = {
            'wbits': wbits,
            'group_size': group_size,
            'sym': sym,
            'folding': folding
        }
        quantizer = teq_quantize(
            model,
            weight_config,
            dataloader,
            example_inputs=self.example_inputs,
            calib_func=calib_func
        )
        return quantizer.model
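
As an aside (not part of the changed files), the call issued by the new hook has roughly the shape sketched below. The toy model, dataloader, and example inputs are placeholders, and whether weight_only.teq_quantize accepts an arbitrary nn.Module rather than a transformer-style model depends on its implementation, which this diff does not include.

import torch
from neural_compressor.adaptor.torch_utils.weight_only import teq_quantize

# Toy stand-ins; a real run goes through the adaptor's quantize() dispatch above.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
dataloader = [(torch.randn(4, 64), torch.zeros(4, dtype=torch.long)) for _ in range(4)]
example_inputs = torch.randn(4, 64)

# Mirrors the defaults used when no 'teq_args' recipe is supplied.
weight_config = {'wbits': 4, 'group_size': 128, 'sym': False, 'folding': True}

quantizer = teq_quantize(
    model,
    weight_config,
    dataloader,
    example_inputs=example_inputs,
    calib_func=None
)
quantized_model = quantizer.model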

    def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
        logger.debug("quantizing with the AWQ algorithm")
        from .torch_utils.weight_only import awq_quantize
94 changes: 94 additions & 0 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -21,6 +21,8 @@
import math
import torch
from torch.nn import functional as F
from torch.autograd import Function
from .weight_only import quant_weight
from packaging.version import Version


@@ -355,3 +357,95 @@ def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bits={}, group_size={}, bias={}'.format(
            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
        )


class FakeAffineTensorQuantFunction(Function):
    """Fake version of affine quantization."""

    @staticmethod
    def forward(ctx, inputs, num_bits=4, group_size=1024):
        """
        As it will only be applied on activations with per-tensor granularity, broadcast is not needed.
        Args:
            ctx: Pytorch convention.
            inputs: A Tensor of type float32.
            num_bits: An integer, the quantization bit-width.
            group_size: An integer, the quantization group size.
        Returns:
            outputs: A fake-quantized Tensor of the same shape as inputs.
        """
        return quant_weight(inputs, num_bits, group_size)

    @staticmethod
    def backward(ctx, grad_outputs):
        """
        Args:
            ctx: Pytorch convention.
            grad_outputs: A tensor of gradients of the outputs.
        Returns:
            grad_inputs: The gradient passed straight through, plus None for the non-tensor arguments.
        """
        return grad_outputs, None, None
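
FakeAffineTensorQuantFunction is a straight-through estimator: the forward pass fake-quantizes its input through quant_weight, while the backward pass returns the incoming gradient unchanged so the trainable scales upstream still receive a training signal. A minimal, self-contained sketch of the same pattern (illustrative only, with torch.round standing in for the repository's quant_weight):

import torch
from torch.autograd import Function

class RoundSTE(Function):
    """Toy straight-through estimator: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, inputs):
        return torch.round(inputs)

    @staticmethod
    def backward(ctx, grad_outputs):
        # Pretend the forward pass was the identity for gradient purposes.
        return grad_outputs

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1., 1.]), even though round() has zero gradient almost everywhere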


class TEQLinearFakeQuant(torch.nn.Module):
    """
    Wrapper that applies trainable-scale fake quantization to a linear layer.
    """

    def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1):
        """
        Wrap a linear module for fake quantization
        :param orig_layer: the original linear module
        :param alpha: trainable alpha/scale
        :param num_bits: quantization bit-width
        :param group_size: group size for fine-grained quantization
        """
        super(TEQLinearFakeQuant, self).__init__()
        self.orig_layer = orig_layer
        self.alpha = alpha

        self.num_bits = num_bits
        self.group_size = group_size

    def forward(self, x):
        # Keep the trainable scale strictly positive.
        alpha = torch.clip(self.alpha, 1e-5)
        shape_len = len(x.shape) - 1
        shape = (1,) * shape_len + (-1,)
        # Divide the activation by alpha and multiply the weight columns by alpha,
        # so the linear output is unchanged before quantization is applied.
        x = x / alpha.view(shape)
        weight = self.orig_layer.weight
        weight = weight * alpha.unsqueeze(dim=0)
        weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size)
        return F.linear(x, weight_q, self.orig_layer.bias)
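
The forward pass above relies on the TEQ identity: dividing the activation by a per-input-channel scale alpha while multiplying the matching weight columns by the same alpha leaves the linear output mathematically unchanged, so only the error introduced by quantizing the scaled weight remains. A quick standalone check of that identity with quantization left out (shapes chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8, 16)          # (batch, seq, in_features)
weight = torch.randn(32, 16)       # (out_features, in_features)
alpha = torch.rand(16) + 0.5       # per-input-channel scale, kept away from zero

shape = (1,) * (x.dim() - 1) + (-1,)
x_scaled = x / alpha.view(shape)             # scale activation channels down
w_scaled = weight * alpha.unsqueeze(dim=0)   # scale matching weight columns up

# Without fake quantization, the transformation is exactly equivalent.
print(torch.allclose(F.linear(x, weight), F.linear(x_scaled, w_scaled), atol=1e-5))  # True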


class TEQMulLinear(torch.nn.Module):
    """
    Trainable Equivalent Transformation (TEQ): linear wrapper that applies a scale to the input.
    """

    def __init__(self, module, input_scale):
        """
        Wrap a linear module and multiply its input by a fixed scale
        :param module: the linear module
        :param input_scale: scale applied to the input
        """
        super().__init__()
        self.register_buffer('input_scale', input_scale)
        self.add_module('sq_linear', module)

    @property
    def weight(self):
        return self.sq_linear.weight

    def forward(self, X):
        X = torch.mul(X, self.input_scale)
        X = self.sq_linear(X)
        return X
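
TEQMulLinear simply multiplies the input by a stored per-channel scale before calling the wrapped linear layer. A short usage sketch, assuming the class above is in scope and with the scale value chosen arbitrarily:

import torch

linear = torch.nn.Linear(16, 32)
input_scale = torch.full((16,), 0.5)     # illustrative per-channel input scale

wrapped = TEQMulLinear(linear, input_scale)
x = torch.randn(4, 16)
out = wrapped(x)                         # equivalent to linear(x * input_scale)
print(torch.allclose(out, linear(x * input_scale)))  # True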