Skip to content

Commit

Permalink
add mse_range for weight_only RTN algo (#1157)
Browse files Browse the repository at this point in the history
* add mse_range for weight_only RTN algo

Signed-off-by: Cheng, Zixuan <[email protected]>

* minor fix

Signed-off-by: Cheng, Zixuan <[email protected]>

* fix mse calculation

Signed-off-by: Cheng, Zixuan <[email protected]>

* minor fix

Signed-off-by: Cheng, Zixuan <[email protected]>

* fix for UT coverage

Signed-off-by: Cheng, Zixuan <[email protected]>

* fix code

Signed-off-by: Cheng, Zixuan <[email protected]>

---------

Signed-off-by: Cheng, Zixuan <[email protected]>
  • Loading branch information
violetch24 authored Aug 18, 2023
1 parent 66f7c10 commit 19ab16c
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/source/quantization_weight_only.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ There are many excellent works for weight only quantization to improve its accur
| rtn_args | default value | comments |
|:----------:|:-------------:|:-------------------------------------------------------------------:|
| sym_full_range | False | Whether use -2**(bits-1) in sym scheme, for example, |
| mse_range | False | Whether to search for the best clip ratio over the range [0.805, 1.0] with step 0.005 |
| return_int | False | Whether return compressed model with int data type |

**AWQ arguments**:
| awq_args | default value | comments |
|:----------:|:-------------:|:-------------------------------------------------------------------:|
| auto_scale | True | Whether search for best scales based on activation distribution |
| mse_range | True | Whether search for the best clip range from range [0.89, 1.0, 0.01] |
| mse_range | True | Whether search for the best clip range from range [0.91, 1.0, 0.01] |
| folding | False | False will allow insert mul before linear when the scale cannot be absorbed by last layer, else won't |

**GPTQ arguments**:
Expand Down
7 changes: 5 additions & 2 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,8 +4338,10 @@ def rtn_quantize(self, model, tune_cfg):
logger.info("quantizing with the round-to-nearest algorithm")
if 'rtn_args' in self.recipes:
sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False)
else:
mse_range = self.recipes['rtn_args'].get('mse_range', False)
else: # pragma: no cover
sym_full_range=False
mse_range=False
from .torch_utils.weight_only import rtn_quantize
from .torch_utils.util import fetch_module, set_module
for key, config in tune_cfg['op'].items():
Expand All @@ -4356,7 +4358,8 @@ def rtn_quantize(self, model, tune_cfg):
m = fetch_module(model, op_name)
m = rtn_quantize(m, num_bits, group_size, scheme,
return_int=False,
sym_full_range=sym_full_range)
sym_full_range=sym_full_range,
mse_range=mse_range)
set_module(model, op_name, m)
return model

Expand Down
47 changes: 46 additions & 1 deletion neural_compressor/adaptor/torch_utils/weight_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,50 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
return weight


def search_clip(m, num_bits, group_size, scheme, sym_full_range):
    """Search the best weight-clip ratio for a single linear module.

    Fake-quantizes the module's weight at a series of shrinking clip
    ratios (quantiles) and keeps the ratio whose quantized weight has
    the lowest mean-squared error against the original weight.

    Args:
        m (torch.nn.Module): module whose ``weight`` is searched.
        num_bits (int): number of quantization bits.
        group_size (int): how many elements share one scale/zp.
        scheme (str): ``sym`` or ``asym``.
        sym_full_range (bool): whether the sym scheme uses -2**(bits-1).

    Returns:
        float: the clip ratio (quantile) with the lowest MSE.
    """
    full_precision = m.weight.data
    logger.info("Searching the best clip range with RTN algorithm")
    grid_points = 200
    shrink_limit = 0.2
    best_error = float('inf')
    best_clip_ratio = None
    history = []
    # Candidate ratios walk 1.0, 0.995, ... down to just above
    # 1 - shrink_limit (i.e. the interval (0.8, 1.0] in steps of 0.005).
    for step in range(int(shrink_limit * grid_points)):
        candidate = 1 - step / grid_points
        quantized = quant_weight(
            m.weight.data,
            num_bits=num_bits,
            group_size=group_size,
            scheme=scheme,
            full_range=sym_full_range,
            quantile=candidate,
        )
        mse = (full_precision - quantized).float().pow(2).mean().item()
        history.append(mse)
        # Strict '<' keeps the earliest (largest) ratio on ties.
        if mse < best_error:
            best_error = mse
            best_clip_ratio = candidate
    logger.debug("The loss history of different clip range:{}".format(history))
    logger.debug("The best clip ratio is {}".format(best_clip_ratio))
    return best_clip_ratio

def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
quantile=1.0, weight_config={}, return_int=False,
sym_full_range=False, **kwargs):
sym_full_range=False, mse_range=False, **kwargs):
"""Quant the model with round to nearest method.
Args:
Expand All @@ -234,6 +275,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
Defaults to False.
sym_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
Defaults to False.
mse_range (bool, optional): Whether to search for the best clip range by MSE.
Defaults to False.
Returns:
model: fake quantized torch module
Expand Down Expand Up @@ -264,6 +307,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
logger.info(f"Skip {name}")
continue
weight = m.weight
if mse_range:
quantile = search_clip(m, num_bits, group_size, scheme, sym_full_range)
if return_int:
from .model_wrapper import WeightOnlyLinear
int_weight, scale, zp = quant_weight(
Expand Down
16 changes: 15 additions & 1 deletion test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def test_RTN_quant(self):
compressed_model = q_model.export_compressed_model()
out3 = compressed_model(input)
self.assertTrue(torch.all(out3==out2))

model = Model()
out1 = model(input)

Expand All @@ -108,6 +107,21 @@ def test_RTN_quant(self):
out3 = compressed_model(input)
self.assertTrue(torch.all(out3==out2))

model = Model()
out1 = model(input)
conf = PostTrainingQuantConfig(
approach='weight_only',
recipes={
# By default, sym_full_range is False and 4 bit sym will only use range [-7,7].
# When mse_range is set to True, enable clip for weight by checking mse.
'rtn_args': {'sym_full_range': True, 'mse_range': True}
}
)
q_model = quantization.fit(model, conf)
out2 = q_model(input)
self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1)))
self.assertFalse(torch.all(out1 == out2))

model = Model()
out1 = model(input)
conf = PostTrainingQuantConfig(
Expand Down

0 comments on commit 19ab16c

Please sign in to comment.